mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 04:20:50 +08:00
Added support for external hooks to observe model registry events using the `ModelRegistryHook` interface. Implemented thread-safe, non-blocking execution of hooks with panic recovery. Comprehensive tests added to verify hook behavior during registration, unregistration, blocking, and panic scenarios.
205 lines
5.3 KiB
Go
205 lines
5.3 KiB
Go
package registry
|
|
|
|
import (
|
|
"context"
|
|
"sync"
|
|
"testing"
|
|
"time"
|
|
)
|
|
|
|
func newTestModelRegistry() *ModelRegistry {
|
|
return &ModelRegistry{
|
|
models: make(map[string]*ModelRegistration),
|
|
clientModels: make(map[string][]string),
|
|
clientModelInfos: make(map[string]map[string]*ModelInfo),
|
|
clientProviders: make(map[string]string),
|
|
mutex: &sync.RWMutex{},
|
|
}
|
|
}
|
|
|
|
type registeredCall struct {
|
|
provider string
|
|
clientID string
|
|
models []*ModelInfo
|
|
}
|
|
|
|
// unregisteredCall captures the arguments of one OnModelsUnregistered hook
// invocation so a test can inspect them afterwards.
type unregisteredCall struct {
	provider, clientID string // provider and client ID as received by the hook
}
|
|
|
|
type capturingHook struct {
|
|
registeredCh chan registeredCall
|
|
unregisteredCh chan unregisteredCall
|
|
}
|
|
|
|
func (h *capturingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
|
|
h.registeredCh <- registeredCall{provider: provider, clientID: clientID, models: models}
|
|
}
|
|
|
|
func (h *capturingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
|
|
h.unregisteredCh <- unregisteredCall{provider: provider, clientID: clientID}
|
|
}
|
|
|
|
func TestModelRegistryHook_OnModelsRegisteredCalled(t *testing.T) {
|
|
r := newTestModelRegistry()
|
|
hook := &capturingHook{
|
|
registeredCh: make(chan registeredCall, 1),
|
|
unregisteredCh: make(chan unregisteredCall, 1),
|
|
}
|
|
r.SetHook(hook)
|
|
|
|
inputModels := []*ModelInfo{
|
|
{ID: "m1", DisplayName: "Model One"},
|
|
{ID: "m2", DisplayName: "Model Two"},
|
|
}
|
|
r.RegisterClient("client-1", "OpenAI", inputModels)
|
|
|
|
select {
|
|
case call := <-hook.registeredCh:
|
|
if call.provider != "openai" {
|
|
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
|
|
}
|
|
if call.clientID != "client-1" {
|
|
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
|
|
}
|
|
if len(call.models) != 2 {
|
|
t.Fatalf("models length mismatch: got %d, want %d", len(call.models), 2)
|
|
}
|
|
if call.models[0] == nil || call.models[0].ID != "m1" {
|
|
t.Fatalf("models[0] mismatch: got %#v, want ID=%q", call.models[0], "m1")
|
|
}
|
|
if call.models[1] == nil || call.models[1].ID != "m2" {
|
|
t.Fatalf("models[1] mismatch: got %#v, want ID=%q", call.models[1], "m2")
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for OnModelsRegistered hook call")
|
|
}
|
|
}
|
|
|
|
func TestModelRegistryHook_OnModelsUnregisteredCalled(t *testing.T) {
|
|
r := newTestModelRegistry()
|
|
hook := &capturingHook{
|
|
registeredCh: make(chan registeredCall, 1),
|
|
unregisteredCh: make(chan unregisteredCall, 1),
|
|
}
|
|
r.SetHook(hook)
|
|
|
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
|
|
select {
|
|
case <-hook.registeredCh:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for OnModelsRegistered hook call")
|
|
}
|
|
|
|
r.UnregisterClient("client-1")
|
|
|
|
select {
|
|
case call := <-hook.unregisteredCh:
|
|
if call.provider != "openai" {
|
|
t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai")
|
|
}
|
|
if call.clientID != "client-1" {
|
|
t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1")
|
|
}
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
|
|
}
|
|
}
|
|
|
|
// blockingHook is a test hook whose OnModelsRegistered callback parks until
// the test releases it, used to check that a stuck hook does not stall the
// caller.
type blockingHook struct {
	// started is closed by OnModelsRegistered to signal that the callback
	// has begun running.
	started chan struct{}
	// unblock is what OnModelsRegistered waits on before returning.
	unblock chan struct{}
}
|
|
|
|
func (h *blockingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
|
|
select {
|
|
case <-h.started:
|
|
default:
|
|
close(h.started)
|
|
}
|
|
<-h.unblock
|
|
}
|
|
|
|
// OnModelsUnregistered does nothing; only the registration callback of
// blockingHook blocks.
func (h *blockingHook) OnModelsUnregistered(ctx context.Context, provider string, clientID string) {
	// Intentionally empty.
}
|
|
|
|
func TestModelRegistryHook_DoesNotBlockRegisterClient(t *testing.T) {
|
|
r := newTestModelRegistry()
|
|
hook := &blockingHook{
|
|
started: make(chan struct{}),
|
|
unblock: make(chan struct{}),
|
|
}
|
|
r.SetHook(hook)
|
|
defer close(hook.unblock)
|
|
|
|
done := make(chan struct{})
|
|
go func() {
|
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
|
|
close(done)
|
|
}()
|
|
|
|
select {
|
|
case <-hook.started:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for hook to start")
|
|
}
|
|
|
|
select {
|
|
case <-done:
|
|
case <-time.After(200 * time.Millisecond):
|
|
t.Fatal("RegisterClient appears to be blocked by hook")
|
|
}
|
|
|
|
if !r.ClientSupportsModel("client-1", "m1") {
|
|
t.Fatal("model registration failed; expected client to support model")
|
|
}
|
|
}
|
|
|
|
// panicHook is a test hook whose callbacks always panic, optionally
// announcing on a channel that they were entered first.
type panicHook struct {
	// registeredCalled, when non-nil, receives a value each time
	// OnModelsRegistered is entered.
	registeredCalled chan struct{}
	// unregisteredCalled, when non-nil, receives a value each time
	// OnModelsUnregistered is entered.
	unregisteredCalled chan struct{}
}
|
|
|
|
func (h *panicHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) {
|
|
if h.registeredCalled != nil {
|
|
h.registeredCalled <- struct{}{}
|
|
}
|
|
panic("boom")
|
|
}
|
|
|
|
func (h *panicHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {
|
|
if h.unregisteredCalled != nil {
|
|
h.unregisteredCalled <- struct{}{}
|
|
}
|
|
panic("boom")
|
|
}
|
|
|
|
func TestModelRegistryHook_PanicDoesNotAffectRegistry(t *testing.T) {
|
|
r := newTestModelRegistry()
|
|
hook := &panicHook{
|
|
registeredCalled: make(chan struct{}, 1),
|
|
unregisteredCalled: make(chan struct{}, 1),
|
|
}
|
|
r.SetHook(hook)
|
|
|
|
r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}})
|
|
|
|
select {
|
|
case <-hook.registeredCalled:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for OnModelsRegistered hook call")
|
|
}
|
|
|
|
if !r.ClientSupportsModel("client-1", "m1") {
|
|
t.Fatal("model registration failed; expected client to support model")
|
|
}
|
|
|
|
r.UnregisterClient("client-1")
|
|
|
|
select {
|
|
case <-hook.unregisteredCalled:
|
|
case <-time.After(2 * time.Second):
|
|
t.Fatal("timeout waiting for OnModelsUnregistered hook call")
|
|
}
|
|
}
|