From dff31a7a4c60b535686338f3f0393ecddc487756 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Wed, 27 Aug 2025 20:30:17 +0800 Subject: [PATCH] Improved the `/v1/models` endpoint --- internal/api/handlers/claude/code_handlers.go | 16 +- .../api/handlers/gemini/gemini_handlers.go | 60 +-- .../api/handlers/openai/openai_handlers.go | 85 +--- internal/api/server.go | 21 +- internal/client/claude_client.go | 25 ++ internal/client/client.go | 54 +++ internal/client/codex_client.go | 17 + internal/client/gemini-cli_client.go | 21 + internal/client/gemini_client.go | 21 + .../client/openai-compatibility_client.go | 32 ++ internal/client/qwen_client.go | 17 + internal/registry/model_definitions.go | 150 +++++++ internal/registry/model_registry.go | 374 ++++++++++++++++++ 13 files changed, 757 insertions(+), 136 deletions(-) create mode 100644 internal/registry/model_definitions.go create mode 100644 internal/registry/model_registry.go diff --git a/internal/api/handlers/claude/code_handlers.go b/internal/api/handlers/claude/code_handlers.go index 620eb17c..d1e4eaeb 100644 --- a/internal/api/handlers/claude/code_handlers.go +++ b/internal/api/handlers/claude/code_handlers.go @@ -16,6 +16,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/api/handlers" . "github.com/luispater/CLIProxyAPI/internal/constant" "github.com/luispater/CLIProxyAPI/internal/interfaces" + "github.com/luispater/CLIProxyAPI/internal/registry" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) @@ -47,7 +48,9 @@ func (h *ClaudeCodeAPIHandler) HandlerType() string { // Models returns a list of models supported by this handler. func (h *ClaudeCodeAPIHandler) Models() []map[string]any { - return make([]map[string]any, 0) + // Get dynamic models from the global registry + modelRegistry := registry.GetGlobalRegistry() + return modelRegistry.GetAvailableModels("claude") } // ClaudeMessages handles Claude-compatible streaming chat completions. 
@@ -79,6 +82,17 @@ func (h *ClaudeCodeAPIHandler) ClaudeMessages(c *gin.Context) { h.handleStreamingResponse(c, rawJSON) } +// ClaudeModels handles the Claude models listing endpoint. +// It returns a JSON response containing available Claude models and their specifications. +// +// Parameters: +// - c: The Gin context for the request. +func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) { + c.JSON(http.StatusOK, gin.H{ + "data": h.Models(), + }) +} + // handleStreamingResponse streams Claude-compatible responses backed by Gemini. // It sets up SSE, selects a backend client with rotation/quota logic, // forwards chunks, and translates them to Claude CLI format. diff --git a/internal/api/handlers/gemini/gemini_handlers.go b/internal/api/handlers/gemini/gemini_handlers.go index ec9b9346..c8ee072d 100644 --- a/internal/api/handlers/gemini/gemini_handlers.go +++ b/internal/api/handlers/gemini/gemini_handlers.go @@ -16,6 +16,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/api/handlers" . "github.com/luispater/CLIProxyAPI/internal/constant" "github.com/luispater/CLIProxyAPI/internal/interfaces" + "github.com/luispater/CLIProxyAPI/internal/registry" log "github.com/sirupsen/logrus" ) @@ -40,62 +41,9 @@ func (h *GeminiAPIHandler) HandlerType() string { // Models returns the Gemini-compatible model metadata supported by this handler. 
func (h *GeminiAPIHandler) Models() []map[string]any { - return []map[string]any{ - { - "name": "models/gemini-2.5-flash", - "version": "001", - "displayName": "Gemini 2.5 Flash", - "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", - "inputTokenLimit": 1048576, - "outputTokenLimit": 65536, - "supportedGenerationMethods": []string{ - "generateContent", - "countTokens", - "createCachedContent", - "batchGenerateContent", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, - { - "name": "models/gemini-2.5-pro", - "version": "2.5", - "displayName": "Gemini 2.5 Pro", - "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro", - "inputTokenLimit": 1048576, - "outputTokenLimit": 65536, - "supportedGenerationMethods": []string{ - "generateContent", - "countTokens", - "createCachedContent", - "batchGenerateContent", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, - { - "name": "gpt-5", - "version": "001", - "displayName": "GPT 5", - "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", - "inputTokenLimit": 400000, - "outputTokenLimit": 128000, - "supportedGenerationMethods": []string{ - "generateContent", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, - } + // Get dynamic models from the global registry + modelRegistry := registry.GetGlobalRegistry() + return modelRegistry.GetAvailableModels("gemini") } // GeminiModels handles the Gemini models listing endpoint. 
diff --git a/internal/api/handlers/openai/openai_handlers.go b/internal/api/handlers/openai/openai_handlers.go index 2011c723..bd8d0aef 100644 --- a/internal/api/handlers/openai/openai_handlers.go +++ b/internal/api/handlers/openai/openai_handlers.go @@ -16,6 +16,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/api/handlers" . "github.com/luispater/CLIProxyAPI/internal/constant" "github.com/luispater/CLIProxyAPI/internal/interfaces" + "github.com/luispater/CLIProxyAPI/internal/registry" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" ) @@ -47,90 +48,18 @@ func (h *OpenAIAPIHandler) HandlerType() string { // Models returns the OpenAI-compatible model metadata supported by this handler. func (h *OpenAIAPIHandler) Models() []map[string]any { - return []map[string]any{ - { - "id": "gemini-2.5-pro", - "object": "model", - "version": "2.5", - "name": "Gemini 2.5 Pro", - "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro", - "context_length": 1_048_576, - "max_completion_tokens": 65_536, - "supported_parameters": []string{ - "tools", - "temperature", - "top_p", - "top_k", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, - { - "id": "gemini-2.5-flash", - "object": "model", - "version": "001", - "name": "Gemini 2.5 Flash", - "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.", - "context_length": 1_048_576, - "max_completion_tokens": 65_536, - "supported_parameters": []string{ - "tools", - "temperature", - "top_p", - "top_k", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, - { - "id": "gpt-5", - "object": "model", - "version": "gpt-5-2025-08-07", - "name": "GPT 5", - "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.", - "context_length": 400_000, - "max_completion_tokens": 128_000, - 
"supported_parameters": []string{ - "tools", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, - { - "id": "claude-opus-4-1-20250805", - "object": "model", - "version": "claude-opus-4-1-20250805", - "name": "Claude Opus 4.1", - "description": "Anthropic's most capable model.", - "context_length": 200_000, - "max_completion_tokens": 32_000, - "supported_parameters": []string{ - "tools", - }, - "temperature": 1, - "topP": 0.95, - "topK": 64, - "maxTemperature": 2, - "thinking": true, - }, - } + // Get dynamic models from the global registry + modelRegistry := registry.GetGlobalRegistry() + return modelRegistry.GetAvailableModels("openai") } // OpenAIModels handles the /v1/models endpoint. -// It returns a hardcoded list of available AI models with their capabilities +// It returns a list of available AI models with their capabilities // and specifications in OpenAI-compatible format. func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) { c.JSON(http.StatusOK, gin.H{ - "data": h.Models(), + "object": "list", + "data": h.Models(), }) } diff --git a/internal/api/server.go b/internal/api/server.go index 374a25ea..cf5328de 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -102,7 +102,7 @@ func (s *Server) setupRoutes() { v1 := s.engine.Group("/v1") v1.Use(AuthMiddleware(s.cfg)) { - v1.GET("/models", openaiHandlers.OpenAIModels) + v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers)) v1.POST("/chat/completions", openaiHandlers.ChatCompletions) v1.POST("/messages", claudeCodeHandlers.ClaudeMessages) } @@ -130,6 +130,25 @@ func (s *Server) setupRoutes() { s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler) } +// unifiedModelsHandler creates a unified handler for the /v1/models endpoint +// that routes to different handlers based on the User-Agent header. 
+// If User-Agent starts with "claude-cli", it routes to Claude handler, +// otherwise it routes to OpenAI handler. +func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, claudeHandler *claude.ClaudeCodeAPIHandler) gin.HandlerFunc { + return func(c *gin.Context) { + userAgent := c.GetHeader("User-Agent") + + // Route to Claude handler if User-Agent starts with "claude-cli" + if strings.HasPrefix(userAgent, "claude-cli") { + log.Debugf("Routing /v1/models to Claude handler for User-Agent: %s", userAgent) + claudeHandler.ClaudeModels(c) + } else { + log.Debugf("Routing /v1/models to OpenAI handler for User-Agent: %s", userAgent) + openaiHandler.OpenAIModels(c) + } + } +} + // Start begins listening for and serving HTTP requests. // It's a blocking call and will only return on an unrecoverable error. // diff --git a/internal/client/claude_client.go b/internal/client/claude_client.go index cdfe54a9..6da0486a 100644 --- a/internal/client/claude_client.go +++ b/internal/client/claude_client.go @@ -23,6 +23,7 @@ import ( . "github.com/luispater/CLIProxyAPI/internal/constant" "github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/misc" + "github.com/luispater/CLIProxyAPI/internal/registry" "github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" @@ -55,6 +56,10 @@ type ClaudeClient struct { // - *ClaudeClient: A new Claude client instance. 
func NewClaudeClient(cfg *config.Config, ts *claude.ClaudeTokenStorage) *ClaudeClient { httpClient := util.SetProxy(cfg, &http.Client{}) + + // Generate unique client ID + clientID := fmt.Sprintf("claude-%d", time.Now().UnixNano()) + client := &ClaudeClient{ ClientBase: ClientBase{ RequestMutex: &sync.Mutex{}, @@ -67,6 +72,10 @@ func NewClaudeClient(cfg *config.Config, ts *claude.ClaudeTokenStorage) *ClaudeC apiKeyIndex: -1, } + // Initialize model registry and register Claude models + client.InitializeModelRegistry(clientID) + client.RegisterModels("claude", registry.GetClaudeModels()) + return client } @@ -82,6 +91,10 @@ func NewClaudeClient(cfg *config.Config, ts *claude.ClaudeTokenStorage) *ClaudeC // - *ClaudeClient: A new Claude client instance. func NewClaudeClientWithKey(cfg *config.Config, apiKeyIndex int) *ClaudeClient { httpClient := util.SetProxy(cfg, &http.Client{}) + + // Generate unique client ID for API key client + clientID := fmt.Sprintf("claude-apikey-%d-%d", apiKeyIndex, time.Now().UnixNano()) + client := &ClaudeClient{ ClientBase: ClientBase{ RequestMutex: &sync.Mutex{}, @@ -94,6 +107,10 @@ func NewClaudeClientWithKey(cfg *config.Config, apiKeyIndex int) *ClaudeClient { apiKeyIndex: apiKeyIndex, } + // Initialize model registry and register Claude models + client.InitializeModelRegistry(clientID) + client.RegisterModels("claude", registry.GetClaudeModels()) + return client } @@ -174,10 +191,14 @@ func (c *ClaudeClient) SendRawMessage(ctx context.Context, modelName string, raw if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) } return nil, err } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + c.ClearModelQuotaExceeded(modelName) bodyBytes, errReadAll := io.ReadAll(respBody) if errReadAll != nil { return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} @@ -234,11 +255,15 @@ func (c 
*ClaudeClient) SendRawMessageStream(ctx context.Context, modelName strin if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) } errChan <- err return } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + c.ClearModelQuotaExceeded(modelName) defer func() { _ = stream.Close() }() diff --git a/internal/client/client.go b/internal/client/client.go index 60201db2..6dd7fa56 100644 --- a/internal/client/client.go +++ b/internal/client/client.go @@ -13,6 +13,7 @@ import ( "github.com/gin-gonic/gin" "github.com/luispater/CLIProxyAPI/internal/auth" "github.com/luispater/CLIProxyAPI/internal/config" + "github.com/luispater/CLIProxyAPI/internal/registry" ) // ClientBase provides a common base structure for all AI API clients. @@ -34,6 +35,12 @@ type ClientBase struct { // modelQuotaExceeded tracks when models have exceeded their quota. // The map key is the model name, and the value is the time when the quota was exceeded. modelQuotaExceeded map[string]*time.Time + + // clientID is the unique identifier for this client instance. + clientID string + + // modelRegistry is the global model registry for tracking model availability. + modelRegistry *registry.ModelRegistry } // GetRequestMutex returns the mutex used to synchronize requests for this client. 
@@ -71,3 +78,50 @@ func (c *ClientBase) AddAPIResponseData(ctx context.Context, line []byte) { } } } + +// InitializeModelRegistry initializes the model registry for this client +// This should be called by all client implementations during construction +func (c *ClientBase) InitializeModelRegistry(clientID string) { + c.clientID = clientID + c.modelRegistry = registry.GetGlobalRegistry() +} + +// RegisterModels registers the models that this client can provide +// Parameters: +// - provider: The provider name (e.g., "gemini", "claude", "openai") +// - models: The list of models this client supports +func (c *ClientBase) RegisterModels(provider string, models []*registry.ModelInfo) { + if c.modelRegistry != nil && c.clientID != "" { + c.modelRegistry.RegisterClient(c.clientID, provider, models) + } +} + +// UnregisterClient removes this client from the model registry +func (c *ClientBase) UnregisterClient() { + if c.modelRegistry != nil && c.clientID != "" { + c.modelRegistry.UnregisterClient(c.clientID) + } +} + +// SetModelQuotaExceeded marks a model as quota exceeded in the registry +// Parameters: +// - modelID: The model that exceeded quota +func (c *ClientBase) SetModelQuotaExceeded(modelID string) { + if c.modelRegistry != nil && c.clientID != "" { + c.modelRegistry.SetModelQuotaExceeded(c.clientID, modelID) + } +} + +// ClearModelQuotaExceeded clears quota exceeded status for a model +// Parameters: +// - modelID: The model to clear quota status for +func (c *ClientBase) ClearModelQuotaExceeded(modelID string) { + if c.modelRegistry != nil && c.clientID != "" { + c.modelRegistry.ClearModelQuotaExceeded(c.clientID, modelID) + } +} + +// GetClientID returns the unique identifier for this client +func (c *ClientBase) GetClientID() string { + return c.clientID +} diff --git a/internal/client/codex_client.go b/internal/client/codex_client.go index 1122ff2c..14acc9a1 100644 --- a/internal/client/codex_client.go +++ b/internal/client/codex_client.go @@ -22,6 +22,7 
@@ import ( "github.com/luispater/CLIProxyAPI/internal/config" . "github.com/luispater/CLIProxyAPI/internal/constant" "github.com/luispater/CLIProxyAPI/internal/interfaces" + "github.com/luispater/CLIProxyAPI/internal/registry" "github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" @@ -50,6 +51,10 @@ type CodexClient struct { // - error: An error if the client creation fails. func NewCodexClient(cfg *config.Config, ts *codex.CodexTokenStorage) (*CodexClient, error) { httpClient := util.SetProxy(cfg, &http.Client{}) + + // Generate unique client ID + clientID := fmt.Sprintf("codex-%d", time.Now().UnixNano()) + client := &CodexClient{ ClientBase: ClientBase{ RequestMutex: &sync.Mutex{}, @@ -61,6 +66,10 @@ func NewCodexClient(cfg *config.Config, ts *codex.CodexTokenStorage) (*CodexClie codexAuth: codex.NewCodexAuth(cfg), } + // Initialize model registry and register OpenAI models + client.InitializeModelRegistry(clientID) + client.RegisterModels("codex", registry.GetOpenAIModels()) + return client, nil } @@ -123,10 +132,14 @@ func (c *CodexClient) SendRawMessage(ctx context.Context, modelName string, rawJ if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) } return nil, err } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + c.ClearModelQuotaExceeded(modelName) bodyBytes, errReadAll := io.ReadAll(respBody) if errReadAll != nil { return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} @@ -184,11 +197,15 @@ func (c *CodexClient) SendRawMessageStream(ctx context.Context, modelName string if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) } errChan <- err return } delete(c.modelQuotaExceeded, modelName) + // Clear 
quota status in model registry + c.ClearModelQuotaExceeded(modelName) defer func() { _ = stream.Close() }() diff --git a/internal/client/gemini-cli_client.go b/internal/client/gemini-cli_client.go index d0ee814f..d1c6bdf9 100644 --- a/internal/client/gemini-cli_client.go +++ b/internal/client/gemini-cli_client.go @@ -22,6 +22,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/config" . "github.com/luispater/CLIProxyAPI/internal/constant" "github.com/luispater/CLIProxyAPI/internal/interfaces" + "github.com/luispater/CLIProxyAPI/internal/registry" "github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" @@ -57,6 +58,9 @@ type GeminiCLIClient struct { // Returns: // - *GeminiCLIClient: A new Gemini CLI client instance. func NewGeminiCLIClient(httpClient *http.Client, ts *geminiAuth.GeminiTokenStorage, cfg *config.Config) *GeminiCLIClient { + // Generate unique client ID + clientID := fmt.Sprintf("gemini-cli-%d", time.Now().UnixNano()) + client := &GeminiCLIClient{ ClientBase: ClientBase{ RequestMutex: &sync.Mutex{}, @@ -66,6 +70,11 @@ func NewGeminiCLIClient(httpClient *http.Client, ts *geminiAuth.GeminiTokenStora modelQuotaExceeded: make(map[string]*time.Time), }, } + + // Initialize model registry and register Gemini models + client.InitializeModelRegistry(clientID) + client.RegisterModels("gemini-cli", registry.GetGeminiModels()) + return client } @@ -426,6 +435,8 @@ func (c *GeminiCLIClient) SendRawTokenCount(ctx context.Context, modelName strin if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) if c.cfg.QuotaExceeded.SwitchPreviewModel { continue } @@ -433,6 +444,8 @@ func (c *GeminiCLIClient) SendRawTokenCount(ctx context.Context, modelName strin return nil, err } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + 
c.ClearModelQuotaExceeded(modelName) bodyBytes, errReadAll := io.ReadAll(respBody) if errReadAll != nil { return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} @@ -485,6 +498,8 @@ func (c *GeminiCLIClient) SendRawMessage(ctx context.Context, modelName string, if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) if c.cfg.QuotaExceeded.SwitchPreviewModel { continue } @@ -492,6 +507,8 @@ func (c *GeminiCLIClient) SendRawMessage(ctx context.Context, modelName string, return nil, err } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + c.ClearModelQuotaExceeded(modelName) bodyBytes, errReadAll := io.ReadAll(respBody) if errReadAll != nil { return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} @@ -562,6 +579,8 @@ func (c *GeminiCLIClient) SendRawMessageStream(ctx context.Context, modelName st if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) if c.cfg.QuotaExceeded.SwitchPreviewModel { continue } @@ -570,6 +589,8 @@ func (c *GeminiCLIClient) SendRawMessageStream(ctx context.Context, modelName st return } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + c.ClearModelQuotaExceeded(modelName) break } defer func() { diff --git a/internal/client/gemini_client.go b/internal/client/gemini_client.go index fd6442a0..cf4b9d5a 100644 --- a/internal/client/gemini_client.go +++ b/internal/client/gemini_client.go @@ -18,6 +18,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/config" . 
"github.com/luispater/CLIProxyAPI/internal/constant" "github.com/luispater/CLIProxyAPI/internal/interfaces" + "github.com/luispater/CLIProxyAPI/internal/registry" "github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" @@ -44,6 +45,9 @@ type GeminiClient struct { // Returns: // - *GeminiClient: A new Gemini client instance. func NewGeminiClient(httpClient *http.Client, cfg *config.Config, glAPIKey string) *GeminiClient { + // Generate unique client ID + clientID := fmt.Sprintf("gemini-apikey-%s-%d", glAPIKey[:8], time.Now().UnixNano()) // Use first 8 chars of API key + client := &GeminiClient{ ClientBase: ClientBase{ RequestMutex: &sync.Mutex{}, @@ -53,6 +57,11 @@ func NewGeminiClient(httpClient *http.Client, cfg *config.Config, glAPIKey strin }, glAPIKey: glAPIKey, } + + // Initialize model registry and register Gemini models + client.InitializeModelRegistry(clientID) + client.RegisterModels("gemini", registry.GetGeminiModels()) + return client } @@ -195,10 +204,14 @@ func (c *GeminiClient) SendRawTokenCount(ctx context.Context, modelName string, if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) } return nil, err } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + c.ClearModelQuotaExceeded(modelName) bodyBytes, errReadAll := io.ReadAll(respBody) if errReadAll != nil { return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} @@ -240,10 +253,14 @@ func (c *GeminiClient) SendRawMessage(ctx context.Context, modelName string, raw if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) } return nil, err } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + 
c.ClearModelQuotaExceeded(modelName) bodyBytes, errReadAll := io.ReadAll(respBody) if errReadAll != nil { return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} @@ -297,11 +314,15 @@ func (c *GeminiClient) SendRawMessageStream(ctx context.Context, modelName strin if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) } errChan <- err return } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + c.ClearModelQuotaExceeded(modelName) defer func() { _ = stream.Close() }() diff --git a/internal/client/openai-compatibility_client.go b/internal/client/openai-compatibility_client.go index 22370ad2..6ae82401 100644 --- a/internal/client/openai-compatibility_client.go +++ b/internal/client/openai-compatibility_client.go @@ -19,6 +19,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/config" . "github.com/luispater/CLIProxyAPI/internal/constant" "github.com/luispater/CLIProxyAPI/internal/interfaces" + "github.com/luispater/CLIProxyAPI/internal/registry" "github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" @@ -53,6 +54,10 @@ func NewOpenAICompatibilityClient(cfg *config.Config, compatConfig *config.OpenA } httpClient := util.SetProxy(cfg, &http.Client{}) + + // Generate unique client ID + clientID := fmt.Sprintf("openai-compatibility-%s-%d", compatConfig.Name, time.Now().UnixNano()) + client := &OpenAICompatibilityClient{ ClientBase: ClientBase{ RequestMutex: &sync.Mutex{}, @@ -64,6 +69,25 @@ func NewOpenAICompatibilityClient(cfg *config.Config, compatConfig *config.OpenA currentAPIKeyIndex: 0, } + // Initialize model registry + client.InitializeModelRegistry(clientID) + + // Convert compatibility models to registry models and register them + registryModels := make([]*registry.ModelInfo, 0, len(compatConfig.Models)) + for _, 
model := range compatConfig.Models { + registryModel := &registry.ModelInfo{ + ID: model.Alias, + Object: "model", + Created: time.Now().Unix(), + OwnedBy: compatConfig.Name, + Type: "openai-compatibility", + DisplayName: model.Name, + } + registryModels = append(registryModels, registryModel) + } + + client.RegisterModels(compatConfig.Name, registryModels) + return client, nil } @@ -216,10 +240,14 @@ func (c *OpenAICompatibilityClient) SendRawMessage(ctx context.Context, modelNam if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) } return nil, err } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + c.ClearModelQuotaExceeded(modelName) bodyBytes, errReadAll := io.ReadAll(respBody) if errReadAll != nil { return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} @@ -270,11 +298,15 @@ func (c *OpenAICompatibilityClient) SendRawMessageStream(ctx context.Context, mo if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) } errChan <- err return } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + c.ClearModelQuotaExceeded(modelName) defer func() { _ = stream.Close() }() diff --git a/internal/client/qwen_client.go b/internal/client/qwen_client.go index 72007b9b..86d456ea 100644 --- a/internal/client/qwen_client.go +++ b/internal/client/qwen_client.go @@ -22,6 +22,7 @@ import ( "github.com/luispater/CLIProxyAPI/internal/config" .
"github.com/luispater/CLIProxyAPI/internal/constant" "github.com/luispater/CLIProxyAPI/internal/interfaces" + "github.com/luispater/CLIProxyAPI/internal/registry" "github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/util" log "github.com/sirupsen/logrus" @@ -49,6 +50,10 @@ type QwenClient struct { // - *QwenClient: A new Qwen client instance. func NewQwenClient(cfg *config.Config, ts *qwen.QwenTokenStorage) *QwenClient { httpClient := util.SetProxy(cfg, &http.Client{}) + + // Generate unique client ID + clientID := fmt.Sprintf("qwen-%d", time.Now().UnixNano()) + client := &QwenClient{ ClientBase: ClientBase{ RequestMutex: &sync.Mutex{}, @@ -60,6 +65,10 @@ func NewQwenClient(cfg *config.Config, ts *qwen.QwenTokenStorage) *QwenClient { qwenAuth: qwen.NewQwenAuth(cfg), } + // Initialize model registry and register Qwen models + client.InitializeModelRegistry(clientID) + client.RegisterModels("qwen", registry.GetQwenModels()) + return client } @@ -119,10 +128,14 @@ func (c *QwenClient) SendRawMessage(ctx context.Context, modelName string, rawJS if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) } return nil, err } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + c.ClearModelQuotaExceeded(modelName) bodyBytes, errReadAll := io.ReadAll(respBody) if errReadAll != nil { return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} @@ -182,11 +195,15 @@ func (c *QwenClient) SendRawMessageStream(ctx context.Context, modelName string, if err.StatusCode == 429 { now := time.Now() c.modelQuotaExceeded[modelName] = &now + // Update model registry quota status + c.SetModelQuotaExceeded(modelName) } errChan <- err return } delete(c.modelQuotaExceeded, modelName) + // Clear quota status in model registry + c.ClearModelQuotaExceeded(modelName) defer func() { _ = 
stream.Close() }()
diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go
new file mode 100644
index 00000000..2ee4cd2d
--- /dev/null
+++ b/internal/registry/model_definitions.go
@@ -0,0 +1,150 @@
+// Package registry provides model definitions for various AI service providers.
+// This file contains static model definitions that can be used by clients
+// when registering their supported models.
+package registry
+
+import "time"
+
+// GetClaudeModels returns the standard Claude model definitions
+func GetClaudeModels() []*ModelInfo {
+	return []*ModelInfo{
+		{
+			ID:          "claude-opus-4-1-20250805",
+			Object:      "model",
+			Created:     1754352000, // 2025-08-05 00:00:00 UTC
+			OwnedBy:     "anthropic",
+			Type:        "claude",
+			DisplayName: "Claude 4.1 Opus",
+		},
+		{
+			ID:          "claude-opus-4-20250514",
+			Object:      "model",
+			Created:     1747180800, // 2025-05-14 00:00:00 UTC
+			OwnedBy:     "anthropic",
+			Type:        "claude",
+			DisplayName: "Claude 4 Opus",
+		},
+		{
+			ID:          "claude-sonnet-4-20250514",
+			Object:      "model",
+			Created:     1747180800, // 2025-05-14 00:00:00 UTC
+			OwnedBy:     "anthropic",
+			Type:        "claude",
+			DisplayName: "Claude 4 Sonnet",
+		},
+		{
+			ID:          "claude-3-7-sonnet-20250219",
+			Object:      "model",
+			Created:     1739923200, // 2025-02-19 00:00:00 UTC
+			OwnedBy:     "anthropic",
+			Type:        "claude",
+			DisplayName: "Claude 3.7 Sonnet",
+		},
+		{
+			ID:          "claude-3-5-haiku-20241022",
+			Object:      "model",
+			Created:     1729555200, // 2024-10-22 00:00:00 UTC
+			OwnedBy:     "anthropic",
+			Type:        "claude",
+			DisplayName: "Claude 3.5 Haiku",
+		},
+	}
+}
+
+// GetGeminiModels returns the standard Gemini model definitions
+func GetGeminiModels() []*ModelInfo {
+	return []*ModelInfo{
+		{
+			ID:                         "gemini-2.5-flash",
+			Object:                     "model",
+			Created:                    time.Now().Unix(),
+			OwnedBy:                    "google",
+			Type:                       "gemini",
+			Name:                       "models/gemini-2.5-flash",
+			Version:                    "001",
+			DisplayName:                "Gemini 2.5 Flash",
+			Description:                "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
+			InputTokenLimit:            1048576,
+			OutputTokenLimit:           65536,
+			SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
+		},
+		{
+			ID:                         "gemini-2.5-pro",
+			Object:                     "model",
+			Created:                    time.Now().Unix(),
+			OwnedBy:                    "google",
+			Type:                       "gemini",
+			Name:                       "models/gemini-2.5-pro",
+			Version:                    "2.5",
+			DisplayName:                "Gemini 2.5 Pro",
+			Description:                "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
+			InputTokenLimit:            1048576,
+			OutputTokenLimit:           65536,
+			SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
+		},
+	}
+}
+
+// GetOpenAIModels returns the standard OpenAI model definitions
+func GetOpenAIModels() []*ModelInfo {
+	return []*ModelInfo{
+		{
+			ID:                  "gpt-5",
+			Object:              "model",
+			Created:             time.Now().Unix(),
+			OwnedBy:             "openai",
+			Type:                "openai",
+			Version:             "gpt-5-2025-08-07",
+			DisplayName:         "GPT 5",
+			Description:         "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
+			ContextLength:       400000,
+			MaxCompletionTokens: 128000,
+			SupportedParameters: []string{"tools"},
+		},
+		{
+			ID:                  "codex-mini-latest",
+			Object:              "model",
+			Created:             time.Now().Unix(),
+			OwnedBy:             "openai",
+			Type:                "openai",
+			Version:             "1.0",
+			DisplayName:         "Codex Mini",
+			Description:         "Lightweight code generation model",
+			ContextLength:       4096,
+			MaxCompletionTokens: 2048,
+			SupportedParameters: []string{"temperature", "max_tokens", "stream", "stop"},
+		},
+	}
+}
+
+// GetQwenModels returns the standard Qwen model definitions
+func GetQwenModels() []*ModelInfo {
+	return []*ModelInfo{
+		{
+			ID:                  "qwen3-coder-plus",
+			Object:              "model",
+			Created:             time.Now().Unix(),
+			OwnedBy:             "qwen",
+			Type:                "qwen",
+			Version:             "3.0",
+			DisplayName:         "Qwen3 Coder Plus",
+			Description:         "Advanced code generation and understanding model",
+			ContextLength:       32768,
+			MaxCompletionTokens: 8192,
+			SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
+		},
+		{
+			ID:                  "qwen3-coder-flash",
+			Object:              "model",
+			Created:             time.Now().Unix(),
+			OwnedBy:             "qwen",
+			Type:                "qwen",
+			Version:             "3.0",
+			DisplayName:         "Qwen3 Coder Flash",
+			Description:         "Fast code generation model",
+			ContextLength:       8192,
+			MaxCompletionTokens: 2048,
+			SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
+		},
+	}
+}
diff --git a/internal/registry/model_registry.go b/internal/registry/model_registry.go
new file mode 100644
index 00000000..0221f5fb
--- /dev/null
+++ b/internal/registry/model_registry.go
@@ -0,0 +1,374 @@
+// Package registry provides centralized model management for all AI service providers.
+// It implements a dynamic model registry with reference counting to track active clients
+// and automatically hide models when no clients are available or when quota is exceeded.
+package registry
+
+import (
+	"sync"
+	"time"
+
+	log "github.com/sirupsen/logrus"
+)
+
+// ModelInfo represents information about an available model
+type ModelInfo struct {
+	// ID is the unique identifier for the model
+	ID string `json:"id"`
+	// Object type for the model (typically "model")
+	Object string `json:"object"`
+	// Created timestamp when the model was created
+	Created int64 `json:"created"`
+	// OwnedBy indicates the organization that owns the model
+	OwnedBy string `json:"owned_by"`
+	// Type indicates the model type (e.g., "claude", "gemini", "openai")
+	Type string `json:"type"`
+	// DisplayName is the human-readable name for the model
+	DisplayName string `json:"display_name,omitempty"`
+	// Name is used for Gemini-style model names
+	Name string `json:"name,omitempty"`
+	// Version is the model version
+	Version string `json:"version,omitempty"`
+	// Description provides detailed information about the model
+	Description string `json:"description,omitempty"`
+	// InputTokenLimit is the maximum input token limit
+	InputTokenLimit int `json:"inputTokenLimit,omitempty"`
+	// OutputTokenLimit is the maximum output token limit
+	OutputTokenLimit int `json:"outputTokenLimit,omitempty"`
+	// SupportedGenerationMethods lists supported generation methods
+	SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
+	// ContextLength is the context window size
+	ContextLength int `json:"context_length,omitempty"`
+	// MaxCompletionTokens is the maximum completion tokens
+	MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
+	// SupportedParameters lists supported parameters
+	SupportedParameters []string `json:"supported_parameters,omitempty"`
+}
+
+// ModelRegistration tracks a model's availability
+type ModelRegistration struct {
+	// Info contains the model metadata
+	Info *ModelInfo
+	// Count is the number of active clients that can provide this model
+	Count int
+	// LastUpdated tracks when this registration was last modified
+	LastUpdated time.Time
+	// QuotaExceededClients tracks which clients have exceeded quota for this model
+	QuotaExceededClients map[string]*time.Time
+}
+
+// ModelRegistry manages the global registry of available models
+type ModelRegistry struct {
+	// models maps model ID to registration information
+	models map[string]*ModelRegistration
+	// clientModels maps client ID to the models it provides
+	clientModels map[string][]string
+	// mutex ensures thread-safe access to the registry
+	mutex *sync.RWMutex
+}
+
+// quotaExpiredDuration is how long a quota-exceeded mark keeps a client counted as unavailable
+const quotaExpiredDuration = 5 * time.Minute
+
+// Global model registry instance
+var globalRegistry *ModelRegistry
+var registryOnce sync.Once
+
+// GetGlobalRegistry returns the global model registry instance
+func GetGlobalRegistry() *ModelRegistry {
+	registryOnce.Do(func() {
+		globalRegistry = &ModelRegistry{
+			models:       make(map[string]*ModelRegistration),
+			clientModels: make(map[string][]string),
+			mutex:        &sync.RWMutex{},
+		}
+	})
+	return globalRegistry
+}
+
+// RegisterClient registers a client and its supported models
+// Parameters:
+//   - clientID: Unique identifier for the client
+//   - clientProvider: Provider name (e.g., "gemini", "claude", "openai")
+//   - models: List of models that this client can provide
+func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	// Remove any existing registration for this client
+	r.unregisterClientInternal(clientID)
+
+	modelIDs := make([]string, 0, len(models))
+	now := time.Now()
+
+	for _, model := range models {
+		modelIDs = append(modelIDs, model.ID)
+
+		if existing, exists := r.models[model.ID]; exists {
+			// Model already exists, increment count
+			existing.Count++
+			existing.LastUpdated = now
+			log.Debugf("Incremented count for model %s, now %d clients", model.ID, existing.Count)
+		} else {
+			// New model, create registration
+			r.models[model.ID] = &ModelRegistration{
+				Info:                 model,
+				Count:                1,
+				LastUpdated:          now,
+				QuotaExceededClients: make(map[string]*time.Time),
+			}
+			log.Debugf("Registered new model %s from provider %s", model.ID, clientProvider)
+		}
+	}
+
+	r.clientModels[clientID] = modelIDs
+	log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(models))
+}
+
+// UnregisterClient removes a client and decrements counts for its models
+// Parameters:
+//   - clientID: Unique identifier for the client to remove
+func (r *ModelRegistry) UnregisterClient(clientID string) {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+	r.unregisterClientInternal(clientID)
+}
+
+// unregisterClientInternal performs the actual client unregistration (internal, no locking)
+func (r *ModelRegistry) unregisterClientInternal(clientID string) {
+	models, exists := r.clientModels[clientID]
+	if !exists {
+		return
+	}
+
+	now := time.Now()
+	for _, modelID := range models {
+		if registration, isExists := r.models[modelID]; isExists {
+			registration.Count--
+			registration.LastUpdated = now
+
+			// Remove quota tracking for this client
+			delete(registration.QuotaExceededClients, clientID)
+
+			log.Debugf("Decremented count for model %s, now %d clients", modelID, registration.Count)
+
+			// Remove model if no clients remain
+			if registration.Count <= 0 {
+				delete(r.models, modelID)
+				log.Debugf("Removed model %s as no clients remain", modelID)
+			}
+		}
+	}
+
+	delete(r.clientModels, clientID)
+	log.Debugf("Unregistered client %s", clientID)
+}
+
+// SetModelQuotaExceeded marks a model as quota exceeded for a specific client
+// Parameters:
+//   - clientID: The client that exceeded quota
+//   - modelID: The model that exceeded quota
+func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	if registration, exists := r.models[modelID]; exists {
+		now := time.Now()
+		registration.QuotaExceededClients[clientID] = &now
+		log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
+	}
+}
+
+// ClearModelQuotaExceeded removes quota exceeded status for a model and client
+// Parameters:
+//   - clientID: The client to clear quota status for
+//   - modelID: The model to clear quota status for
+func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	if registration, exists := r.models[modelID]; exists {
+		delete(registration.QuotaExceededClients, clientID)
+		log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
+	}
+}
+
+// GetAvailableModels returns all models that have at least one available client
+// Parameters:
+//   - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini")
+//
+// Returns:
+//   - []map[string]any: List of available models in the requested format
+func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
+	r.mutex.RLock()
+	defer r.mutex.RUnlock()
+
+	models := make([]map[string]any, 0)
+	now := time.Now()
+
+	for _, registration := range r.models {
+		// Check if model has any non-quota-exceeded clients
+		availableClients := registration.Count
+
+		// Count clients whose quota-exceeded mark is still inside the cooldown window
+		blockedClients := 0
+		for _, quotaTime := range registration.QuotaExceededClients {
+			if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
+				blockedClients++
+			}
+		}
+
+		effectiveClients := availableClients - blockedClients
+
+		// Only include models that have available clients
+		if effectiveClients > 0 {
+			model := r.convertModelToMap(registration.Info, handlerType)
+			if model != nil {
+				models = append(models, model)
+			}
+		}
+	}
+
+	return models
+}
+
+// GetModelCount returns the number of available clients for a specific model
+// Parameters:
+//   - modelID: The model ID to check
+//
+// Returns:
+//   - int: Number of available clients for the model
+func (r *ModelRegistry) GetModelCount(modelID string) int {
+	r.mutex.RLock()
+	defer r.mutex.RUnlock()
+
+	if registration, exists := r.models[modelID]; exists {
+		now := time.Now()
+
+		// Count clients whose quota-exceeded mark is still inside the cooldown window
+		blockedClients := 0
+		for _, quotaTime := range registration.QuotaExceededClients {
+			if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
+				blockedClients++
+			}
+		}
+
+		return registration.Count - blockedClients
+	}
+	return 0
+}
+
+// convertModelToMap converts ModelInfo to the appropriate format for different handler types
+func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) map[string]any {
+	if model == nil {
+		return nil
+	}
+
+	switch handlerType {
+	case "openai":
+		result := map[string]any{
+			"id":       model.ID,
+			"object":   "model",
+			"owned_by": model.OwnedBy,
+		}
+		if model.Created > 0 {
+			result["created"] = model.Created
+		}
+		if model.Type != "" {
+			result["type"] = model.Type
+		}
+		if model.DisplayName != "" {
+			result["display_name"] = model.DisplayName
+		}
+		if model.Version != "" {
+			result["version"] = model.Version
+		}
+		if model.Description != "" {
+			result["description"] = model.Description
+		}
+		if model.ContextLength > 0 {
+			result["context_length"] = model.ContextLength
+		}
+		if model.MaxCompletionTokens > 0 {
+			result["max_completion_tokens"] = model.MaxCompletionTokens
+		}
+		if len(model.SupportedParameters) > 0 {
+			result["supported_parameters"] = model.SupportedParameters
+		}
+		return result
+
+	case "claude":
+		result := map[string]any{
+			"id":       model.ID,
+			"object":   "model",
+			"owned_by": model.OwnedBy,
+		}
+		if model.Created > 0 {
+			result["created"] = model.Created
+		}
+		if model.Type != "" {
+			result["type"] = model.Type
+		}
+		if model.DisplayName != "" {
+			result["display_name"] = model.DisplayName
+		}
+		return result
+
+	case "gemini":
+		result := map[string]any{}
+		if model.Name != "" {
+			result["name"] = model.Name
+		} else {
+			result["name"] = model.ID
+		}
+		if model.Version != "" {
+			result["version"] = model.Version
+		}
+		if model.DisplayName != "" {
+			result["displayName"] = model.DisplayName
+		}
+		if model.Description != "" {
+			result["description"] = model.Description
+		}
+		if model.InputTokenLimit > 0 {
+			result["inputTokenLimit"] = model.InputTokenLimit
+		}
+		if model.OutputTokenLimit > 0 {
+			result["outputTokenLimit"] = model.OutputTokenLimit
+		}
+		if len(model.SupportedGenerationMethods) > 0 {
+			result["supportedGenerationMethods"] = model.SupportedGenerationMethods
+		}
+		return result
+
+	default:
+		// Generic format
+		result := map[string]any{
+			"id":     model.ID,
+			"object": "model",
+		}
+		if model.OwnedBy != "" {
+			result["owned_by"] = model.OwnedBy
+		}
+		if model.Type != "" {
+			result["type"] = model.Type
+		}
+		return result
+	}
+}
+
+// CleanupExpiredQuotas removes expired quota tracking entries
+func (r *ModelRegistry) CleanupExpiredQuotas() {
+	r.mutex.Lock()
+	defer r.mutex.Unlock()
+
+	now := time.Now()
+
+	for modelID, registration := range r.models {
+		for clientID, quotaTime := range registration.QuotaExceededClients {
+			if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
+				delete(registration.QuotaExceededClients, clientID)
+				log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
+			}
+		}
+	}
+}