diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go index a2c154ca..a4e9acdf 100644 --- a/internal/registry/model_registry.go +++ b/internal/registry/model_registry.go @@ -4,6 +4,7 @@ package registry import ( + "context" "fmt" "sort" "strings" @@ -84,6 +85,13 @@ type ModelRegistration struct { SuspendedClients map[string]string } +// ModelRegistryHook provides optional callbacks for external integrations to track model list changes. +// Hook implementations must be non-blocking and resilient; calls are executed asynchronously and panics are recovered. +type ModelRegistryHook interface { + OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) + OnModelsUnregistered(ctx context.Context, provider, clientID string) +} + // ModelRegistry manages the global registry of available models type ModelRegistry struct { // models maps model ID to registration information @@ -97,6 +105,8 @@ type ModelRegistry struct { clientProviders map[string]string // mutex ensures thread-safe access to the registry mutex *sync.RWMutex + // hook is an optional callback sink for model registration changes + hook ModelRegistryHook } // Global model registry instance @@ -117,6 +127,53 @@ func GetGlobalRegistry() *ModelRegistry { return globalRegistry } +// SetHook sets an optional hook for observing model registration changes. +func (r *ModelRegistry) SetHook(hook ModelRegistryHook) { + if r == nil { + return + } + r.mutex.Lock() + defer r.mutex.Unlock() + r.hook = hook +} + +const defaultModelRegistryHookTimeout = 5 * time.Second + +func (r *ModelRegistry) triggerModelsRegistered(provider, clientID string, models []*ModelInfo) { + hook := r.hook + if hook == nil { + return + } + modelsCopy := cloneModelInfosUnique(models) + go func() { + defer func() { + if recovered := recover(); recovered != nil { + log.Errorf("model registry hook OnModelsRegistered panic: %v", recovered) + } + }() + ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout) + defer cancel() + hook.OnModelsRegistered(ctx, provider, clientID, modelsCopy) + }() +} + +func (r *ModelRegistry) triggerModelsUnregistered(provider, clientID string) { + hook := r.hook + if hook == nil { + return + } + go func() { + defer func() { + if recovered := recover(); recovered != nil { + log.Errorf("model registry hook OnModelsUnregistered panic: %v", recovered) + } + }() + ctx, cancel := context.WithTimeout(context.Background(), defaultModelRegistryHookTimeout) + defer cancel() + hook.OnModelsUnregistered(ctx, provider, clientID) + }() +} + // RegisterClient registers a client and its supported models // Parameters: // - clientID: Unique identifier for the client @@ -177,6 +234,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ } else { delete(r.clientProviders, clientID) } + r.triggerModelsRegistered(provider, clientID, models) log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs)) misc.LogCredentialSeparator() return @@ -310,6 +368,7 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [ delete(r.clientProviders, clientID) } + r.triggerModelsRegistered(provider, clientID, models) if len(added) == 0 && len(removed) == 0 && !providerChanged { // Only metadata (e.g., display name) changed; skip separator when no log output. return @@ -400,6 +459,25 @@ func cloneModelInfo(model *ModelInfo) *ModelInfo { return ©Model } +func cloneModelInfosUnique(models []*ModelInfo) []*ModelInfo { + if len(models) == 0 { + return nil + } + cloned := make([]*ModelInfo, 0, len(models)) + seen := make(map[string]struct{}, len(models)) + for _, model := range models { + if model == nil || model.ID == "" { + continue + } + if _, exists := seen[model.ID]; exists { + continue + } + seen[model.ID] = struct{}{} + cloned = append(cloned, cloneModelInfo(model)) + } + return cloned +} + // UnregisterClient removes a client and decrements counts for its models // Parameters: // - clientID: Unique identifier for the client to remove @@ -460,6 +538,7 @@ func (r *ModelRegistry) unregisterClientInternal(clientID string) { log.Debugf("Unregistered client %s", clientID) // Separator line after completing client unregistration (after the summary line) misc.LogCredentialSeparator() + r.triggerModelsUnregistered(provider, clientID) } // SetModelQuotaExceeded marks a model as quota exceeded for a specific client diff --git a/internal/registry/model_registry_hook_test.go b/internal/registry/model_registry_hook_test.go new file mode 100644 index 00000000..70226b9e --- /dev/null +++ b/internal/registry/model_registry_hook_test.go @@ -0,0 +1,204 @@ +package registry + +import ( + "context" + "sync" + "testing" + "time" +) + +func newTestModelRegistry() *ModelRegistry { + return &ModelRegistry{ + models: make(map[string]*ModelRegistration), + clientModels: make(map[string][]string), + clientModelInfos: make(map[string]map[string]*ModelInfo), + clientProviders: make(map[string]string), + mutex: &sync.RWMutex{}, + } +} + +type registeredCall struct { + provider string + clientID string + models []*ModelInfo +} + +type unregisteredCall struct { + provider string + clientID string +} + +type capturingHook struct { + registeredCh chan registeredCall + unregisteredCh chan unregisteredCall +} + +func (h *capturingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) { + h.registeredCh <- registeredCall{provider: provider, clientID: clientID, models: models} +} + +func (h *capturingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) { + h.unregisteredCh <- unregisteredCall{provider: provider, clientID: clientID} +} + +func TestModelRegistryHook_OnModelsRegisteredCalled(t *testing.T) { + r := newTestModelRegistry() + hook := &capturingHook{ + registeredCh: make(chan registeredCall, 1), + unregisteredCh: make(chan unregisteredCall, 1), + } + r.SetHook(hook) + + inputModels := []*ModelInfo{ + {ID: "m1", DisplayName: "Model One"}, + {ID: "m2", DisplayName: "Model Two"}, + } + r.RegisterClient("client-1", "OpenAI", inputModels) + + select { + case call := <-hook.registeredCh: + if call.provider != "openai" { + t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai") + } + if call.clientID != "client-1" { + t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1") + } + if len(call.models) != 2 { + t.Fatalf("models length mismatch: got %d, want %d", len(call.models), 2) + } + if call.models[0] == nil || call.models[0].ID != "m1" { + t.Fatalf("models[0] mismatch: got %#v, want ID=%q", call.models[0], "m1") + } + if call.models[1] == nil || call.models[1].ID != "m2" { + t.Fatalf("models[1] mismatch: got %#v, want ID=%q", call.models[1], "m2") + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for OnModelsRegistered hook call") + } +} + +func TestModelRegistryHook_OnModelsUnregisteredCalled(t *testing.T) { + r := newTestModelRegistry() + hook := &capturingHook{ + registeredCh: make(chan registeredCall, 1), + unregisteredCh: make(chan unregisteredCall, 1), + } + r.SetHook(hook) + + r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}}) + select { + case <-hook.registeredCh: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for OnModelsRegistered hook call") + } + + r.UnregisterClient("client-1") + + select { + case call := <-hook.unregisteredCh: + if call.provider != "openai" { + t.Fatalf("provider mismatch: got %q, want %q", call.provider, "openai") + } + if call.clientID != "client-1" { + t.Fatalf("clientID mismatch: got %q, want %q", call.clientID, "client-1") + } + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for OnModelsUnregistered hook call") + } +} + +type blockingHook struct { + started chan struct{} + unblock chan struct{} +} + +func (h *blockingHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) { + select { + case <-h.started: + default: + close(h.started) + } + <-h.unblock +} + +func (h *blockingHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) {} + +func TestModelRegistryHook_DoesNotBlockRegisterClient(t *testing.T) { + r := newTestModelRegistry() + hook := &blockingHook{ + started: make(chan struct{}), + unblock: make(chan struct{}), + } + r.SetHook(hook) + defer close(hook.unblock) + + done := make(chan struct{}) + go func() { + r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}}) + close(done) + }() + + select { + case <-hook.started: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for hook to start") + } + + select { + case <-done: + case <-time.After(200 * time.Millisecond): + t.Fatal("RegisterClient appears to be blocked by hook") + } + + if !r.ClientSupportsModel("client-1", "m1") { + t.Fatal("model registration failed; expected client to support model") + } +} + +type panicHook struct { + registeredCalled chan struct{} + unregisteredCalled chan struct{} +} + +func (h *panicHook) OnModelsRegistered(ctx context.Context, provider, clientID string, models []*ModelInfo) { + if h.registeredCalled != nil { + h.registeredCalled <- struct{}{} + } + panic("boom") +} + +func (h *panicHook) OnModelsUnregistered(ctx context.Context, provider, clientID string) { + if h.unregisteredCalled != nil { + h.unregisteredCalled <- struct{}{} + } + panic("boom") +} + +func TestModelRegistryHook_PanicDoesNotAffectRegistry(t *testing.T) { + r := newTestModelRegistry() + hook := &panicHook{ + registeredCalled: make(chan struct{}, 1), + unregisteredCalled: make(chan struct{}, 1), + } + r.SetHook(hook) + + r.RegisterClient("client-1", "OpenAI", []*ModelInfo{{ID: "m1"}}) + + select { + case <-hook.registeredCalled: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for OnModelsRegistered hook call") + } + + if !r.ClientSupportsModel("client-1", "m1") { + t.Fatal("model registration failed; expected client to support model") + } + + r.UnregisterClient("client-1") + + select { + case <-hook.unregisteredCalled: + case <-time.After(2 * time.Second): + t.Fatal("timeout waiting for OnModelsUnregistered hook call") + } +} diff --git a/sdk/cliproxy/model_registry.go b/sdk/cliproxy/model_registry.go index 3cd57842..01cea5b7 100644 --- a/sdk/cliproxy/model_registry.go +++ b/sdk/cliproxy/model_registry.go @@ -5,6 +5,9 @@ import "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" // ModelInfo re-exports the registry model info structure. type ModelInfo = registry.ModelInfo +// ModelRegistryHook re-exports the registry hook interface for external integrations. +type ModelRegistryHook = registry.ModelRegistryHook + // ModelRegistry describes registry operations consumed by external callers. type ModelRegistry interface { RegisterClient(clientID, clientProvider string, models []*ModelInfo) @@ -20,3 +23,8 @@ type ModelRegistry interface { func GlobalModelRegistry() ModelRegistry { return registry.GetGlobalRegistry() } + +// SetGlobalModelRegistryHook registers an optional hook on the shared global registry instance. +func SetGlobalModelRegistryHook(hook ModelRegistryHook) { + registry.GetGlobalRegistry().SetHook(hook) +}