Improved the /v1/models endpoint

This commit is contained in:
Luis Pater
2025-08-27 20:30:17 +08:00
parent ed8873fbb0
commit dff31a7a4c
13 changed files with 757 additions and 136 deletions

View File

@@ -16,6 +16,7 @@ import (
"github.com/luispater/CLIProxyAPI/internal/api/handlers" "github.com/luispater/CLIProxyAPI/internal/api/handlers"
. "github.com/luispater/CLIProxyAPI/internal/constant" . "github.com/luispater/CLIProxyAPI/internal/constant"
"github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/interfaces"
"github.com/luispater/CLIProxyAPI/internal/registry"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -47,7 +48,9 @@ func (h *ClaudeCodeAPIHandler) HandlerType() string {
// Models returns a list of models supported by this handler. // Models returns a list of models supported by this handler.
func (h *ClaudeCodeAPIHandler) Models() []map[string]any { func (h *ClaudeCodeAPIHandler) Models() []map[string]any {
return make([]map[string]any, 0) // Get dynamic models from the global registry
modelRegistry := registry.GetGlobalRegistry()
return modelRegistry.GetAvailableModels("claude")
} }
// ClaudeMessages handles Claude-compatible streaming chat completions. // ClaudeMessages handles Claude-compatible streaming chat completions.
@@ -79,6 +82,17 @@ func (h *ClaudeCodeAPIHandler) ClaudeMessages(c *gin.Context) {
h.handleStreamingResponse(c, rawJSON) h.handleStreamingResponse(c, rawJSON)
} }
// ClaudeModels handles the Claude models listing endpoint.
// It returns a JSON response containing available Claude models and their specifications.
//
// Parameters:
// - c: The Gin context for the request.
func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"data": h.Models(),
})
}
// handleStreamingResponse streams Claude-compatible responses backed by Gemini. // handleStreamingResponse streams Claude-compatible responses backed by Gemini.
// It sets up SSE, selects a backend client with rotation/quota logic, // It sets up SSE, selects a backend client with rotation/quota logic,
// forwards chunks, and translates them to Claude CLI format. // forwards chunks, and translates them to Claude CLI format.

View File

@@ -16,6 +16,7 @@ import (
"github.com/luispater/CLIProxyAPI/internal/api/handlers" "github.com/luispater/CLIProxyAPI/internal/api/handlers"
. "github.com/luispater/CLIProxyAPI/internal/constant" . "github.com/luispater/CLIProxyAPI/internal/constant"
"github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/interfaces"
"github.com/luispater/CLIProxyAPI/internal/registry"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
@@ -40,62 +41,9 @@ func (h *GeminiAPIHandler) HandlerType() string {
// Models returns the Gemini-compatible model metadata supported by this handler. // Models returns the Gemini-compatible model metadata supported by this handler.
func (h *GeminiAPIHandler) Models() []map[string]any { func (h *GeminiAPIHandler) Models() []map[string]any {
return []map[string]any{ // Get dynamic models from the global registry
{ modelRegistry := registry.GetGlobalRegistry()
"name": "models/gemini-2.5-flash", return modelRegistry.GetAvailableModels("gemini")
"version": "001",
"displayName": "Gemini 2.5 Flash",
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": []string{
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"name": "models/gemini-2.5-pro",
"version": "2.5",
"displayName": "Gemini 2.5 Pro",
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
"inputTokenLimit": 1048576,
"outputTokenLimit": 65536,
"supportedGenerationMethods": []string{
"generateContent",
"countTokens",
"createCachedContent",
"batchGenerateContent",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"name": "gpt-5",
"version": "001",
"displayName": "GPT 5",
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
"inputTokenLimit": 400000,
"outputTokenLimit": 128000,
"supportedGenerationMethods": []string{
"generateContent",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
}
} }
// GeminiModels handles the Gemini models listing endpoint. // GeminiModels handles the Gemini models listing endpoint.

View File

@@ -16,6 +16,7 @@ import (
"github.com/luispater/CLIProxyAPI/internal/api/handlers" "github.com/luispater/CLIProxyAPI/internal/api/handlers"
. "github.com/luispater/CLIProxyAPI/internal/constant" . "github.com/luispater/CLIProxyAPI/internal/constant"
"github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/interfaces"
"github.com/luispater/CLIProxyAPI/internal/registry"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
) )
@@ -47,90 +48,18 @@ func (h *OpenAIAPIHandler) HandlerType() string {
// Models returns the OpenAI-compatible model metadata supported by this handler. // Models returns the OpenAI-compatible model metadata supported by this handler.
func (h *OpenAIAPIHandler) Models() []map[string]any { func (h *OpenAIAPIHandler) Models() []map[string]any {
return []map[string]any{ // Get dynamic models from the global registry
{ modelRegistry := registry.GetGlobalRegistry()
"id": "gemini-2.5-pro", return modelRegistry.GetAvailableModels("openai")
"object": "model",
"version": "2.5",
"name": "Gemini 2.5 Pro",
"description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
"context_length": 1_048_576,
"max_completion_tokens": 65_536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gemini-2.5-flash",
"object": "model",
"version": "001",
"name": "Gemini 2.5 Flash",
"description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
"context_length": 1_048_576,
"max_completion_tokens": 65_536,
"supported_parameters": []string{
"tools",
"temperature",
"top_p",
"top_k",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "gpt-5",
"object": "model",
"version": "gpt-5-2025-08-07",
"name": "GPT 5",
"description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
"context_length": 400_000,
"max_completion_tokens": 128_000,
"supported_parameters": []string{
"tools",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
{
"id": "claude-opus-4-1-20250805",
"object": "model",
"version": "claude-opus-4-1-20250805",
"name": "Claude Opus 4.1",
"description": "Anthropic's most capable model.",
"context_length": 200_000,
"max_completion_tokens": 32_000,
"supported_parameters": []string{
"tools",
},
"temperature": 1,
"topP": 0.95,
"topK": 64,
"maxTemperature": 2,
"thinking": true,
},
}
} }
// OpenAIModels handles the /v1/models endpoint. // OpenAIModels handles the /v1/models endpoint.
// It returns a hardcoded list of available AI models with their capabilities // It returns a list of available AI models with their capabilities
// and specifications in OpenAI-compatible format. // and specifications in OpenAI-compatible format.
func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) { func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{ c.JSON(http.StatusOK, gin.H{
"data": h.Models(), "object": "list",
"data": h.Models(),
}) })
} }

View File

@@ -102,7 +102,7 @@ func (s *Server) setupRoutes() {
v1 := s.engine.Group("/v1") v1 := s.engine.Group("/v1")
v1.Use(AuthMiddleware(s.cfg)) v1.Use(AuthMiddleware(s.cfg))
{ {
v1.GET("/models", openaiHandlers.OpenAIModels) v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers))
v1.POST("/chat/completions", openaiHandlers.ChatCompletions) v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages) v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
} }
@@ -130,6 +130,25 @@ func (s *Server) setupRoutes() {
s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler) s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
} }
// unifiedModelsHandler creates a unified handler for the /v1/models endpoint
// that routes to different handlers based on the User-Agent header.
// If User-Agent starts with "claude-cli", it routes to Claude handler,
// otherwise it routes to OpenAI handler.
func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, claudeHandler *claude.ClaudeCodeAPIHandler) gin.HandlerFunc {
return func(c *gin.Context) {
userAgent := c.GetHeader("User-Agent")
// Route to Claude handler if User-Agent starts with "claude-cli"
if strings.HasPrefix(userAgent, "claude-cli") {
log.Debugf("Routing /v1/models to Claude handler for User-Agent: %s", userAgent)
claudeHandler.ClaudeModels(c)
} else {
log.Debugf("Routing /v1/models to OpenAI handler for User-Agent: %s", userAgent)
openaiHandler.OpenAIModels(c)
}
}
}
// Start begins listening for and serving HTTP requests. // Start begins listening for and serving HTTP requests.
// It's a blocking call and will only return on an unrecoverable error. // It's a blocking call and will only return on an unrecoverable error.
// //

View File

@@ -23,6 +23,7 @@ import (
. "github.com/luispater/CLIProxyAPI/internal/constant" . "github.com/luispater/CLIProxyAPI/internal/constant"
"github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/interfaces"
"github.com/luispater/CLIProxyAPI/internal/misc" "github.com/luispater/CLIProxyAPI/internal/misc"
"github.com/luispater/CLIProxyAPI/internal/registry"
"github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/translator/translator"
"github.com/luispater/CLIProxyAPI/internal/util" "github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -55,6 +56,10 @@ type ClaudeClient struct {
// - *ClaudeClient: A new Claude client instance. // - *ClaudeClient: A new Claude client instance.
func NewClaudeClient(cfg *config.Config, ts *claude.ClaudeTokenStorage) *ClaudeClient { func NewClaudeClient(cfg *config.Config, ts *claude.ClaudeTokenStorage) *ClaudeClient {
httpClient := util.SetProxy(cfg, &http.Client{}) httpClient := util.SetProxy(cfg, &http.Client{})
// Generate unique client ID
clientID := fmt.Sprintf("claude-%d", time.Now().UnixNano())
client := &ClaudeClient{ client := &ClaudeClient{
ClientBase: ClientBase{ ClientBase: ClientBase{
RequestMutex: &sync.Mutex{}, RequestMutex: &sync.Mutex{},
@@ -67,6 +72,10 @@ func NewClaudeClient(cfg *config.Config, ts *claude.ClaudeTokenStorage) *ClaudeC
apiKeyIndex: -1, apiKeyIndex: -1,
} }
// Initialize model registry and register Claude models
client.InitializeModelRegistry(clientID)
client.RegisterModels("claude", registry.GetClaudeModels())
return client return client
} }
@@ -82,6 +91,10 @@ func NewClaudeClient(cfg *config.Config, ts *claude.ClaudeTokenStorage) *ClaudeC
// - *ClaudeClient: A new Claude client instance. // - *ClaudeClient: A new Claude client instance.
func NewClaudeClientWithKey(cfg *config.Config, apiKeyIndex int) *ClaudeClient { func NewClaudeClientWithKey(cfg *config.Config, apiKeyIndex int) *ClaudeClient {
httpClient := util.SetProxy(cfg, &http.Client{}) httpClient := util.SetProxy(cfg, &http.Client{})
// Generate unique client ID for API key client
clientID := fmt.Sprintf("claude-apikey-%d-%d", apiKeyIndex, time.Now().UnixNano())
client := &ClaudeClient{ client := &ClaudeClient{
ClientBase: ClientBase{ ClientBase: ClientBase{
RequestMutex: &sync.Mutex{}, RequestMutex: &sync.Mutex{},
@@ -94,6 +107,10 @@ func NewClaudeClientWithKey(cfg *config.Config, apiKeyIndex int) *ClaudeClient {
apiKeyIndex: apiKeyIndex, apiKeyIndex: apiKeyIndex,
} }
// Initialize model registry and register Claude models
client.InitializeModelRegistry(clientID)
client.RegisterModels("claude", registry.GetClaudeModels())
return client return client
} }
@@ -174,10 +191,14 @@ func (c *ClaudeClient) SendRawMessage(ctx context.Context, modelName string, raw
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
} }
return nil, err return nil, err
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
bodyBytes, errReadAll := io.ReadAll(respBody) bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil { if errReadAll != nil {
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -234,11 +255,15 @@ func (c *ClaudeClient) SendRawMessageStream(ctx context.Context, modelName strin
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
} }
errChan <- err errChan <- err
return return
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
defer func() { defer func() {
_ = stream.Close() _ = stream.Close()
}() }()

View File

@@ -13,6 +13,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/luispater/CLIProxyAPI/internal/auth" "github.com/luispater/CLIProxyAPI/internal/auth"
"github.com/luispater/CLIProxyAPI/internal/config" "github.com/luispater/CLIProxyAPI/internal/config"
"github.com/luispater/CLIProxyAPI/internal/registry"
) )
// ClientBase provides a common base structure for all AI API clients. // ClientBase provides a common base structure for all AI API clients.
@@ -34,6 +35,12 @@ type ClientBase struct {
// modelQuotaExceeded tracks when models have exceeded their quota. // modelQuotaExceeded tracks when models have exceeded their quota.
// The map key is the model name, and the value is the time when the quota was exceeded. // The map key is the model name, and the value is the time when the quota was exceeded.
modelQuotaExceeded map[string]*time.Time modelQuotaExceeded map[string]*time.Time
// clientID is the unique identifier for this client instance.
clientID string
// modelRegistry is the global model registry for tracking model availability.
modelRegistry *registry.ModelRegistry
} }
// GetRequestMutex returns the mutex used to synchronize requests for this client. // GetRequestMutex returns the mutex used to synchronize requests for this client.
@@ -71,3 +78,50 @@ func (c *ClientBase) AddAPIResponseData(ctx context.Context, line []byte) {
} }
} }
} }
// InitializeModelRegistry initializes the model registry for this client
// This should be called by all client implementations during construction
func (c *ClientBase) InitializeModelRegistry(clientID string) {
c.clientID = clientID
c.modelRegistry = registry.GetGlobalRegistry()
}
// RegisterModels registers the models that this client can provide
// Parameters:
// - provider: The provider name (e.g., "gemini", "claude", "openai")
// - models: The list of models this client supports
func (c *ClientBase) RegisterModels(provider string, models []*registry.ModelInfo) {
if c.modelRegistry != nil && c.clientID != "" {
c.modelRegistry.RegisterClient(c.clientID, provider, models)
}
}
// UnregisterClient removes this client from the model registry
func (c *ClientBase) UnregisterClient() {
if c.modelRegistry != nil && c.clientID != "" {
c.modelRegistry.UnregisterClient(c.clientID)
}
}
// SetModelQuotaExceeded marks a model as quota exceeded in the registry
// Parameters:
// - modelID: The model that exceeded quota
func (c *ClientBase) SetModelQuotaExceeded(modelID string) {
if c.modelRegistry != nil && c.clientID != "" {
c.modelRegistry.SetModelQuotaExceeded(c.clientID, modelID)
}
}
// ClearModelQuotaExceeded clears quota exceeded status for a model
// Parameters:
// - modelID: The model to clear quota status for
func (c *ClientBase) ClearModelQuotaExceeded(modelID string) {
if c.modelRegistry != nil && c.clientID != "" {
c.modelRegistry.ClearModelQuotaExceeded(c.clientID, modelID)
}
}
// GetClientID returns the unique identifier for this client
func (c *ClientBase) GetClientID() string {
return c.clientID
}

View File

@@ -22,6 +22,7 @@ import (
"github.com/luispater/CLIProxyAPI/internal/config" "github.com/luispater/CLIProxyAPI/internal/config"
. "github.com/luispater/CLIProxyAPI/internal/constant" . "github.com/luispater/CLIProxyAPI/internal/constant"
"github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/interfaces"
"github.com/luispater/CLIProxyAPI/internal/registry"
"github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/translator/translator"
"github.com/luispater/CLIProxyAPI/internal/util" "github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -50,6 +51,10 @@ type CodexClient struct {
// - error: An error if the client creation fails. // - error: An error if the client creation fails.
func NewCodexClient(cfg *config.Config, ts *codex.CodexTokenStorage) (*CodexClient, error) { func NewCodexClient(cfg *config.Config, ts *codex.CodexTokenStorage) (*CodexClient, error) {
httpClient := util.SetProxy(cfg, &http.Client{}) httpClient := util.SetProxy(cfg, &http.Client{})
// Generate unique client ID
clientID := fmt.Sprintf("codex-%d", time.Now().UnixNano())
client := &CodexClient{ client := &CodexClient{
ClientBase: ClientBase{ ClientBase: ClientBase{
RequestMutex: &sync.Mutex{}, RequestMutex: &sync.Mutex{},
@@ -61,6 +66,10 @@ func NewCodexClient(cfg *config.Config, ts *codex.CodexTokenStorage) (*CodexClie
codexAuth: codex.NewCodexAuth(cfg), codexAuth: codex.NewCodexAuth(cfg),
} }
// Initialize model registry and register OpenAI models
client.InitializeModelRegistry(clientID)
client.RegisterModels("codex", registry.GetOpenAIModels())
return client, nil return client, nil
} }
@@ -123,10 +132,14 @@ func (c *CodexClient) SendRawMessage(ctx context.Context, modelName string, rawJ
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
} }
return nil, err return nil, err
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
bodyBytes, errReadAll := io.ReadAll(respBody) bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil { if errReadAll != nil {
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -184,11 +197,15 @@ func (c *CodexClient) SendRawMessageStream(ctx context.Context, modelName string
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
} }
errChan <- err errChan <- err
return return
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
defer func() { defer func() {
_ = stream.Close() _ = stream.Close()
}() }()

View File

@@ -22,6 +22,7 @@ import (
"github.com/luispater/CLIProxyAPI/internal/config" "github.com/luispater/CLIProxyAPI/internal/config"
. "github.com/luispater/CLIProxyAPI/internal/constant" . "github.com/luispater/CLIProxyAPI/internal/constant"
"github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/interfaces"
"github.com/luispater/CLIProxyAPI/internal/registry"
"github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/translator/translator"
"github.com/luispater/CLIProxyAPI/internal/util" "github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -57,6 +58,9 @@ type GeminiCLIClient struct {
// Returns: // Returns:
// - *GeminiCLIClient: A new Gemini CLI client instance. // - *GeminiCLIClient: A new Gemini CLI client instance.
func NewGeminiCLIClient(httpClient *http.Client, ts *geminiAuth.GeminiTokenStorage, cfg *config.Config) *GeminiCLIClient { func NewGeminiCLIClient(httpClient *http.Client, ts *geminiAuth.GeminiTokenStorage, cfg *config.Config) *GeminiCLIClient {
// Generate unique client ID
clientID := fmt.Sprintf("gemini-cli-%d", time.Now().UnixNano())
client := &GeminiCLIClient{ client := &GeminiCLIClient{
ClientBase: ClientBase{ ClientBase: ClientBase{
RequestMutex: &sync.Mutex{}, RequestMutex: &sync.Mutex{},
@@ -66,6 +70,11 @@ func NewGeminiCLIClient(httpClient *http.Client, ts *geminiAuth.GeminiTokenStora
modelQuotaExceeded: make(map[string]*time.Time), modelQuotaExceeded: make(map[string]*time.Time),
}, },
} }
// Initialize model registry and register Gemini models
client.InitializeModelRegistry(clientID)
client.RegisterModels("gemini-cli", registry.GetGeminiModels())
return client return client
} }
@@ -426,6 +435,8 @@ func (c *GeminiCLIClient) SendRawTokenCount(ctx context.Context, modelName strin
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
if c.cfg.QuotaExceeded.SwitchPreviewModel { if c.cfg.QuotaExceeded.SwitchPreviewModel {
continue continue
} }
@@ -433,6 +444,8 @@ func (c *GeminiCLIClient) SendRawTokenCount(ctx context.Context, modelName strin
return nil, err return nil, err
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
bodyBytes, errReadAll := io.ReadAll(respBody) bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil { if errReadAll != nil {
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -485,6 +498,8 @@ func (c *GeminiCLIClient) SendRawMessage(ctx context.Context, modelName string,
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
if c.cfg.QuotaExceeded.SwitchPreviewModel { if c.cfg.QuotaExceeded.SwitchPreviewModel {
continue continue
} }
@@ -492,6 +507,8 @@ func (c *GeminiCLIClient) SendRawMessage(ctx context.Context, modelName string,
return nil, err return nil, err
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
bodyBytes, errReadAll := io.ReadAll(respBody) bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil { if errReadAll != nil {
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -562,6 +579,8 @@ func (c *GeminiCLIClient) SendRawMessageStream(ctx context.Context, modelName st
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
if c.cfg.QuotaExceeded.SwitchPreviewModel { if c.cfg.QuotaExceeded.SwitchPreviewModel {
continue continue
} }
@@ -570,6 +589,8 @@ func (c *GeminiCLIClient) SendRawMessageStream(ctx context.Context, modelName st
return return
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
break break
} }
defer func() { defer func() {

View File

@@ -18,6 +18,7 @@ import (
"github.com/luispater/CLIProxyAPI/internal/config" "github.com/luispater/CLIProxyAPI/internal/config"
. "github.com/luispater/CLIProxyAPI/internal/constant" . "github.com/luispater/CLIProxyAPI/internal/constant"
"github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/interfaces"
"github.com/luispater/CLIProxyAPI/internal/registry"
"github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/translator/translator"
"github.com/luispater/CLIProxyAPI/internal/util" "github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -44,6 +45,9 @@ type GeminiClient struct {
// Returns: // Returns:
// - *GeminiClient: A new Gemini client instance. // - *GeminiClient: A new Gemini client instance.
func NewGeminiClient(httpClient *http.Client, cfg *config.Config, glAPIKey string) *GeminiClient { func NewGeminiClient(httpClient *http.Client, cfg *config.Config, glAPIKey string) *GeminiClient {
// Generate unique client ID
clientID := fmt.Sprintf("gemini-apikey-%s-%d", glAPIKey[:8], time.Now().UnixNano()) // Use first 8 chars of API key
client := &GeminiClient{ client := &GeminiClient{
ClientBase: ClientBase{ ClientBase: ClientBase{
RequestMutex: &sync.Mutex{}, RequestMutex: &sync.Mutex{},
@@ -53,6 +57,11 @@ func NewGeminiClient(httpClient *http.Client, cfg *config.Config, glAPIKey strin
}, },
glAPIKey: glAPIKey, glAPIKey: glAPIKey,
} }
// Initialize model registry and register Gemini models
client.InitializeModelRegistry(clientID)
client.RegisterModels("gemini", registry.GetGeminiModels())
return client return client
} }
@@ -195,10 +204,14 @@ func (c *GeminiClient) SendRawTokenCount(ctx context.Context, modelName string,
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
} }
return nil, err return nil, err
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
bodyBytes, errReadAll := io.ReadAll(respBody) bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil { if errReadAll != nil {
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -240,10 +253,14 @@ func (c *GeminiClient) SendRawMessage(ctx context.Context, modelName string, raw
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
} }
return nil, err return nil, err
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
bodyBytes, errReadAll := io.ReadAll(respBody) bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil { if errReadAll != nil {
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -297,11 +314,15 @@ func (c *GeminiClient) SendRawMessageStream(ctx context.Context, modelName strin
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
} }
errChan <- err errChan <- err
return return
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
defer func() { defer func() {
_ = stream.Close() _ = stream.Close()
}() }()

View File

@@ -19,6 +19,7 @@ import (
"github.com/luispater/CLIProxyAPI/internal/config" "github.com/luispater/CLIProxyAPI/internal/config"
. "github.com/luispater/CLIProxyAPI/internal/constant" . "github.com/luispater/CLIProxyAPI/internal/constant"
"github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/interfaces"
"github.com/luispater/CLIProxyAPI/internal/registry"
"github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/translator/translator"
"github.com/luispater/CLIProxyAPI/internal/util" "github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -53,6 +54,10 @@ func NewOpenAICompatibilityClient(cfg *config.Config, compatConfig *config.OpenA
} }
httpClient := util.SetProxy(cfg, &http.Client{}) httpClient := util.SetProxy(cfg, &http.Client{})
// Generate unique client ID
clientID := fmt.Sprintf("openai-compatibility-%s-%d", compatConfig.Name, time.Now().UnixNano())
client := &OpenAICompatibilityClient{ client := &OpenAICompatibilityClient{
ClientBase: ClientBase{ ClientBase: ClientBase{
RequestMutex: &sync.Mutex{}, RequestMutex: &sync.Mutex{},
@@ -64,6 +69,25 @@ func NewOpenAICompatibilityClient(cfg *config.Config, compatConfig *config.OpenA
currentAPIKeyIndex: 0, currentAPIKeyIndex: 0,
} }
// Initialize model registry
client.InitializeModelRegistry(clientID)
// Convert compatibility models to registry models and register them
registryModels := make([]*registry.ModelInfo, 0, len(compatConfig.Models))
for _, model := range compatConfig.Models {
registryModel := &registry.ModelInfo{
ID: model.Alias,
Object: "model",
Created: time.Now().Unix(),
OwnedBy: compatConfig.Name,
Type: "openai-compatibility",
DisplayName: model.Name,
}
registryModels = append(registryModels, registryModel)
}
client.RegisterModels(compatConfig.Name, registryModels)
return client, nil return client, nil
} }
@@ -216,10 +240,14 @@ func (c *OpenAICompatibilityClient) SendRawMessage(ctx context.Context, modelNam
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
} }
return nil, err return nil, err
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
bodyBytes, errReadAll := io.ReadAll(respBody) bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil { if errReadAll != nil {
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -270,11 +298,15 @@ func (c *OpenAICompatibilityClient) SendRawMessageStream(ctx context.Context, mo
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
} }
errChan <- err errChan <- err
return return
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
defer func() { defer func() {
_ = stream.Close() _ = stream.Close()
}() }()

View File

@@ -22,6 +22,7 @@ import (
"github.com/luispater/CLIProxyAPI/internal/config" "github.com/luispater/CLIProxyAPI/internal/config"
. "github.com/luispater/CLIProxyAPI/internal/constant" . "github.com/luispater/CLIProxyAPI/internal/constant"
"github.com/luispater/CLIProxyAPI/internal/interfaces" "github.com/luispater/CLIProxyAPI/internal/interfaces"
"github.com/luispater/CLIProxyAPI/internal/registry"
"github.com/luispater/CLIProxyAPI/internal/translator/translator" "github.com/luispater/CLIProxyAPI/internal/translator/translator"
"github.com/luispater/CLIProxyAPI/internal/util" "github.com/luispater/CLIProxyAPI/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
@@ -49,6 +50,10 @@ type QwenClient struct {
// - *QwenClient: A new Qwen client instance. // - *QwenClient: A new Qwen client instance.
func NewQwenClient(cfg *config.Config, ts *qwen.QwenTokenStorage) *QwenClient { func NewQwenClient(cfg *config.Config, ts *qwen.QwenTokenStorage) *QwenClient {
httpClient := util.SetProxy(cfg, &http.Client{}) httpClient := util.SetProxy(cfg, &http.Client{})
// Generate unique client ID
clientID := fmt.Sprintf("qwen-%d", time.Now().UnixNano())
client := &QwenClient{ client := &QwenClient{
ClientBase: ClientBase{ ClientBase: ClientBase{
RequestMutex: &sync.Mutex{}, RequestMutex: &sync.Mutex{},
@@ -60,6 +65,10 @@ func NewQwenClient(cfg *config.Config, ts *qwen.QwenTokenStorage) *QwenClient {
qwenAuth: qwen.NewQwenAuth(cfg), qwenAuth: qwen.NewQwenAuth(cfg),
} }
// Initialize model registry and register Qwen models
client.InitializeModelRegistry(clientID)
client.RegisterModels("qwen", registry.GetQwenModels())
return client return client
} }
@@ -119,10 +128,14 @@ func (c *QwenClient) SendRawMessage(ctx context.Context, modelName string, rawJS
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
} }
return nil, err return nil, err
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
bodyBytes, errReadAll := io.ReadAll(respBody) bodyBytes, errReadAll := io.ReadAll(respBody)
if errReadAll != nil { if errReadAll != nil {
return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll} return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -182,11 +195,15 @@ func (c *QwenClient) SendRawMessageStream(ctx context.Context, modelName string,
if err.StatusCode == 429 { if err.StatusCode == 429 {
now := time.Now() now := time.Now()
c.modelQuotaExceeded[modelName] = &now c.modelQuotaExceeded[modelName] = &now
// Update model registry quota status
c.SetModelQuotaExceeded(modelName)
} }
errChan <- err errChan <- err
return return
} }
delete(c.modelQuotaExceeded, modelName) delete(c.modelQuotaExceeded, modelName)
// Clear quota status in model registry
c.ClearModelQuotaExceeded(modelName)
defer func() { defer func() {
_ = stream.Close() _ = stream.Close()
}() }()

View File

@@ -0,0 +1,150 @@
// Package registry provides model definitions for various AI service providers.
// This file contains static model definitions that can be used by clients
// when registering their supported models.
package registry
import "time"
// GetClaudeModels returns the standard Claude model definitions
func GetClaudeModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "claude-opus-4-1-20250805",
Object: "model",
Created: 1722945600, // 2025-08-05
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.1 Opus",
},
{
ID: "claude-opus-4-20250514",
Object: "model",
Created: 1715644800, // 2025-05-14
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4 Opus",
},
{
ID: "claude-sonnet-4-20250514",
Object: "model",
Created: 1715644800, // 2025-05-14
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4 Sonnet",
},
{
ID: "claude-3-7-sonnet-20250219",
Object: "model",
Created: 1708300800, // 2025-02-19
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 3.7 Sonnet",
},
{
ID: "claude-3-5-haiku-20241022",
Object: "model",
Created: 1729555200, // 2024-10-22
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 3.5 Haiku",
},
}
}
// GetGeminiModels returns the standard Gemini model definitions
func GetGeminiModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "gemini-2.5-flash",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash",
Version: "001",
DisplayName: "Gemini 2.5 Flash",
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
},
{
ID: "gemini-2.5-pro",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-pro",
Version: "2.5",
DisplayName: "Gemini 2.5 Pro",
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
},
}
}
// GetOpenAIModels returns the standard OpenAI model definitions
func GetOpenAIModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "gpt-5",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "openai",
Type: "openai",
Version: "gpt-5-2025-08-07",
DisplayName: "GPT 5",
Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
ContextLength: 400000,
MaxCompletionTokens: 128000,
SupportedParameters: []string{"tools"},
},
{
ID: "codex-mini-latest",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "openai",
Type: "openai",
Version: "1.0",
DisplayName: "Codex Mini",
Description: "Lightweight code generation model",
ContextLength: 4096,
MaxCompletionTokens: 2048,
SupportedParameters: []string{"temperature", "max_tokens", "stream", "stop"},
},
}
}
// GetQwenModels returns the standard Qwen model definitions
func GetQwenModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "qwen3-coder-plus",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "qwen",
Type: "qwen",
Version: "3.0",
DisplayName: "Qwen3 Coder Plus",
Description: "Advanced code generation and understanding model",
ContextLength: 32768,
MaxCompletionTokens: 8192,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "qwen3-coder-flash",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "qwen",
Type: "qwen",
Version: "3.0",
DisplayName: "Qwen3 Coder Flash",
Description: "Fast code generation model",
ContextLength: 8192,
MaxCompletionTokens: 2048,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
}
}

View File

@@ -0,0 +1,374 @@
// Package registry provides centralized model management for all AI service providers.
// It implements a dynamic model registry with reference counting to track active clients
// and automatically hide models when no clients are available or when quota is exceeded.
package registry
import (
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// ModelInfo represents information about an available model
type ModelInfo struct {
// ID is the unique identifier for the model
ID string `json:"id"`
// Object type for the model (typically "model")
Object string `json:"object"`
// Created timestamp when the model was created
Created int64 `json:"created"`
// OwnedBy indicates the organization that owns the model
OwnedBy string `json:"owned_by"`
// Type indicates the model type (e.g., "claude", "gemini", "openai")
Type string `json:"type"`
// DisplayName is the human-readable name for the model
DisplayName string `json:"display_name,omitempty"`
// Name is used for Gemini-style model names
Name string `json:"name,omitempty"`
// Version is the model version
Version string `json:"version,omitempty"`
// Description provides detailed information about the model
Description string `json:"description,omitempty"`
// InputTokenLimit is the maximum input token limit
InputTokenLimit int `json:"inputTokenLimit,omitempty"`
// OutputTokenLimit is the maximum output token limit
OutputTokenLimit int `json:"outputTokenLimit,omitempty"`
// SupportedGenerationMethods lists supported generation methods
SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
// ContextLength is the context window size
ContextLength int `json:"context_length,omitempty"`
// MaxCompletionTokens is the maximum completion tokens
MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
// SupportedParameters lists supported parameters
SupportedParameters []string `json:"supported_parameters,omitempty"`
}
// ModelRegistration tracks a model's availability
type ModelRegistration struct {
// Info contains the model metadata
Info *ModelInfo
// Count is the number of active clients that can provide this model
Count int
// LastUpdated tracks when this registration was last modified
LastUpdated time.Time
// QuotaExceededClients tracks which clients have exceeded quota for this model
QuotaExceededClients map[string]*time.Time
}
// ModelRegistry manages the global registry of available models
type ModelRegistry struct {
// models maps model ID to registration information
models map[string]*ModelRegistration
// clientModels maps client ID to the models it provides
clientModels map[string][]string
// mutex ensures thread-safe access to the registry
mutex *sync.RWMutex
}
// Global model registry instance
var globalRegistry *ModelRegistry
var registryOnce sync.Once
// GetGlobalRegistry returns the global model registry instance
func GetGlobalRegistry() *ModelRegistry {
registryOnce.Do(func() {
globalRegistry = &ModelRegistry{
models: make(map[string]*ModelRegistration),
clientModels: make(map[string][]string),
mutex: &sync.RWMutex{},
}
})
return globalRegistry
}
// RegisterClient registers a client and its supported models
// Parameters:
// - clientID: Unique identifier for the client
// - clientProvider: Provider name (e.g., "gemini", "claude", "openai")
// - models: List of models that this client can provide
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
r.mutex.Lock()
defer r.mutex.Unlock()
// Remove any existing registration for this client
r.unregisterClientInternal(clientID)
modelIDs := make([]string, 0, len(models))
now := time.Now()
for _, model := range models {
modelIDs = append(modelIDs, model.ID)
if existing, exists := r.models[model.ID]; exists {
// Model already exists, increment count
existing.Count++
existing.LastUpdated = now
log.Debugf("Incremented count for model %s, now %d clients", model.ID, existing.Count)
} else {
// New model, create registration
r.models[model.ID] = &ModelRegistration{
Info: model,
Count: 1,
LastUpdated: now,
QuotaExceededClients: make(map[string]*time.Time),
}
log.Debugf("Registered new model %s from provider %s", model.ID, clientProvider)
}
}
r.clientModels[clientID] = modelIDs
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(models))
}
// UnregisterClient removes a client and decrements counts for its models
// Parameters:
// - clientID: Unique identifier for the client to remove
func (r *ModelRegistry) UnregisterClient(clientID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
r.unregisterClientInternal(clientID)
}
// unregisterClientInternal performs the actual client unregistration (internal, no locking)
func (r *ModelRegistry) unregisterClientInternal(clientID string) {
models, exists := r.clientModels[clientID]
if !exists {
return
}
now := time.Now()
for _, modelID := range models {
if registration, isExists := r.models[modelID]; isExists {
registration.Count--
registration.LastUpdated = now
// Remove quota tracking for this client
delete(registration.QuotaExceededClients, clientID)
log.Debugf("Decremented count for model %s, now %d clients", modelID, registration.Count)
// Remove model if no clients remain
if registration.Count <= 0 {
delete(r.models, modelID)
log.Debugf("Removed model %s as no clients remain", modelID)
}
}
}
delete(r.clientModels, clientID)
log.Debugf("Unregistered client %s", clientID)
}
// SetModelQuotaExceeded marks a model as quota exceeded for a specific client
// Parameters:
// - clientID: The client that exceeded quota
// - modelID: The model that exceeded quota
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
if registration, exists := r.models[modelID]; exists {
now := time.Now()
registration.QuotaExceededClients[clientID] = &now
log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
}
}
// ClearModelQuotaExceeded removes quota exceeded status for a model and client
// Parameters:
// - clientID: The client to clear quota status for
// - modelID: The model to clear quota status for
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
r.mutex.Lock()
defer r.mutex.Unlock()
if registration, exists := r.models[modelID]; exists {
delete(registration.QuotaExceededClients, clientID)
log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
}
}
// GetAvailableModels returns all models that have at least one available client
// Parameters:
// - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini")
//
// Returns:
// - []map[string]any: List of available models in the requested format
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
r.mutex.RLock()
defer r.mutex.RUnlock()
models := make([]map[string]any, 0)
quotaExpiredDuration := 5 * time.Minute
for _, registration := range r.models {
// Check if model has any non-quota-exceeded clients
availableClients := registration.Count
now := time.Now()
// Count clients that have exceeded quota but haven't recovered yet
expiredClients := 0
for _, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
expiredClients++
}
}
effectiveClients := availableClients - expiredClients
// Only include models that have available clients
if effectiveClients > 0 {
model := r.convertModelToMap(registration.Info, handlerType)
if model != nil {
models = append(models, model)
}
}
}
return models
}
// GetModelCount returns the number of available clients for a specific model
// Parameters:
// - modelID: The model ID to check
//
// Returns:
// - int: Number of available clients for the model
func (r *ModelRegistry) GetModelCount(modelID string) int {
r.mutex.RLock()
defer r.mutex.RUnlock()
if registration, exists := r.models[modelID]; exists {
now := time.Now()
quotaExpiredDuration := 5 * time.Minute
// Count clients that have exceeded quota but haven't recovered yet
expiredClients := 0
for _, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
expiredClients++
}
}
return registration.Count - expiredClients
}
return 0
}
// convertModelToMap converts ModelInfo to the appropriate format for different handler types
func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) map[string]any {
if model == nil {
return nil
}
switch handlerType {
case "openai":
result := map[string]any{
"id": model.ID,
"object": "model",
"owned_by": model.OwnedBy,
}
if model.Created > 0 {
result["created"] = model.Created
}
if model.Type != "" {
result["type"] = model.Type
}
if model.DisplayName != "" {
result["display_name"] = model.DisplayName
}
if model.Version != "" {
result["version"] = model.Version
}
if model.Description != "" {
result["description"] = model.Description
}
if model.ContextLength > 0 {
result["context_length"] = model.ContextLength
}
if model.MaxCompletionTokens > 0 {
result["max_completion_tokens"] = model.MaxCompletionTokens
}
if len(model.SupportedParameters) > 0 {
result["supported_parameters"] = model.SupportedParameters
}
return result
case "claude":
result := map[string]any{
"id": model.ID,
"object": "model",
"owned_by": model.OwnedBy,
}
if model.Created > 0 {
result["created"] = model.Created
}
if model.Type != "" {
result["type"] = model.Type
}
if model.DisplayName != "" {
result["display_name"] = model.DisplayName
}
return result
case "gemini":
result := map[string]any{}
if model.Name != "" {
result["name"] = model.Name
} else {
result["name"] = model.ID
}
if model.Version != "" {
result["version"] = model.Version
}
if model.DisplayName != "" {
result["displayName"] = model.DisplayName
}
if model.Description != "" {
result["description"] = model.Description
}
if model.InputTokenLimit > 0 {
result["inputTokenLimit"] = model.InputTokenLimit
}
if model.OutputTokenLimit > 0 {
result["outputTokenLimit"] = model.OutputTokenLimit
}
if len(model.SupportedGenerationMethods) > 0 {
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
}
return result
default:
// Generic format
result := map[string]any{
"id": model.ID,
"object": "model",
}
if model.OwnedBy != "" {
result["owned_by"] = model.OwnedBy
}
if model.Type != "" {
result["type"] = model.Type
}
return result
}
}
// CleanupExpiredQuotas removes expired quota tracking entries
func (r *ModelRegistry) CleanupExpiredQuotas() {
r.mutex.Lock()
defer r.mutex.Unlock()
now := time.Now()
quotaExpiredDuration := 5 * time.Minute
for modelID, registration := range r.models {
for clientID, quotaTime := range registration.QuotaExceededClients {
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
delete(registration.QuotaExceededClients, clientID)
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
}
}
}
}