mirror of https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Improved the /v1/models endpoint
@@ -16,6 +16,7 @@ import (
     "github.com/luispater/CLIProxyAPI/internal/api/handlers"
     . "github.com/luispater/CLIProxyAPI/internal/constant"
     "github.com/luispater/CLIProxyAPI/internal/interfaces"
+    "github.com/luispater/CLIProxyAPI/internal/registry"
     log "github.com/sirupsen/logrus"
     "github.com/tidwall/gjson"
 )
@@ -47,7 +48,9 @@ func (h *ClaudeCodeAPIHandler) HandlerType() string {
 
 // Models returns a list of models supported by this handler.
 func (h *ClaudeCodeAPIHandler) Models() []map[string]any {
-    return make([]map[string]any, 0)
+    // Get dynamic models from the global registry
+    modelRegistry := registry.GetGlobalRegistry()
+    return modelRegistry.GetAvailableModels("claude")
 }
 
 // ClaudeMessages handles Claude-compatible streaming chat completions.
@@ -79,6 +82,17 @@ func (h *ClaudeCodeAPIHandler) ClaudeMessages(c *gin.Context) {
     h.handleStreamingResponse(c, rawJSON)
 }
 
+// ClaudeModels handles the Claude models listing endpoint.
+// It returns a JSON response containing available Claude models and their specifications.
+//
+// Parameters:
+//   - c: The Gin context for the request.
+func (h *ClaudeCodeAPIHandler) ClaudeModels(c *gin.Context) {
+    c.JSON(http.StatusOK, gin.H{
+        "data": h.Models(),
+    })
+}
+
 // handleStreamingResponse streams Claude-compatible responses backed by Gemini.
 // It sets up SSE, selects a backend client with rotation/quota logic,
 // forwards chunks, and translates them to Claude CLI format.
@@ -16,6 +16,7 @@ import (
     "github.com/luispater/CLIProxyAPI/internal/api/handlers"
     . "github.com/luispater/CLIProxyAPI/internal/constant"
     "github.com/luispater/CLIProxyAPI/internal/interfaces"
+    "github.com/luispater/CLIProxyAPI/internal/registry"
     log "github.com/sirupsen/logrus"
 )
 
@@ -40,62 +41,9 @@ func (h *GeminiAPIHandler) HandlerType() string {
 
 // Models returns the Gemini-compatible model metadata supported by this handler.
 func (h *GeminiAPIHandler) Models() []map[string]any {
-    return []map[string]any{
-        {
-            "name": "models/gemini-2.5-flash",
-            "version": "001",
-            "displayName": "Gemini 2.5 Flash",
-            "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
-            "inputTokenLimit": 1048576,
-            "outputTokenLimit": 65536,
-            "supportedGenerationMethods": []string{
-                "generateContent",
-                "countTokens",
-                "createCachedContent",
-                "batchGenerateContent",
-            },
-            "temperature": 1,
-            "topP": 0.95,
-            "topK": 64,
-            "maxTemperature": 2,
-            "thinking": true,
-        },
-        {
-            "name": "models/gemini-2.5-pro",
-            "version": "2.5",
-            "displayName": "Gemini 2.5 Pro",
-            "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
-            "inputTokenLimit": 1048576,
-            "outputTokenLimit": 65536,
-            "supportedGenerationMethods": []string{
-                "generateContent",
-                "countTokens",
-                "createCachedContent",
-                "batchGenerateContent",
-            },
-            "temperature": 1,
-            "topP": 0.95,
-            "topK": 64,
-            "maxTemperature": 2,
-            "thinking": true,
-        },
-        {
-            "name": "gpt-5",
-            "version": "001",
-            "displayName": "GPT 5",
-            "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
-            "inputTokenLimit": 400000,
-            "outputTokenLimit": 128000,
-            "supportedGenerationMethods": []string{
-                "generateContent",
-            },
-            "temperature": 1,
-            "topP": 0.95,
-            "topK": 64,
-            "maxTemperature": 2,
-            "thinking": true,
-        },
-    }
+    // Get dynamic models from the global registry
+    modelRegistry := registry.GetGlobalRegistry()
+    return modelRegistry.GetAvailableModels("gemini")
 }
 
 // GeminiModels handles the Gemini models listing endpoint.
@@ -16,6 +16,7 @@ import (
     "github.com/luispater/CLIProxyAPI/internal/api/handlers"
     . "github.com/luispater/CLIProxyAPI/internal/constant"
     "github.com/luispater/CLIProxyAPI/internal/interfaces"
+    "github.com/luispater/CLIProxyAPI/internal/registry"
     log "github.com/sirupsen/logrus"
     "github.com/tidwall/gjson"
 )
@@ -47,89 +48,17 @@ func (h *OpenAIAPIHandler) HandlerType() string {
 
 // Models returns the OpenAI-compatible model metadata supported by this handler.
 func (h *OpenAIAPIHandler) Models() []map[string]any {
-    return []map[string]any{
-        {
-            "id": "gemini-2.5-pro",
-            "object": "model",
-            "version": "2.5",
-            "name": "Gemini 2.5 Pro",
-            "description": "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
-            "context_length": 1_048_576,
-            "max_completion_tokens": 65_536,
-            "supported_parameters": []string{
-                "tools",
-                "temperature",
-                "top_p",
-                "top_k",
-            },
-            "temperature": 1,
-            "topP": 0.95,
-            "topK": 64,
-            "maxTemperature": 2,
-            "thinking": true,
-        },
-        {
-            "id": "gemini-2.5-flash",
-            "object": "model",
-            "version": "001",
-            "name": "Gemini 2.5 Flash",
-            "description": "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
-            "context_length": 1_048_576,
-            "max_completion_tokens": 65_536,
-            "supported_parameters": []string{
-                "tools",
-                "temperature",
-                "top_p",
-                "top_k",
-            },
-            "temperature": 1,
-            "topP": 0.95,
-            "topK": 64,
-            "maxTemperature": 2,
-            "thinking": true,
-        },
-        {
-            "id": "gpt-5",
-            "object": "model",
-            "version": "gpt-5-2025-08-07",
-            "name": "GPT 5",
-            "description": "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
-            "context_length": 400_000,
-            "max_completion_tokens": 128_000,
-            "supported_parameters": []string{
-                "tools",
-            },
-            "temperature": 1,
-            "topP": 0.95,
-            "topK": 64,
-            "maxTemperature": 2,
-            "thinking": true,
-        },
-        {
-            "id": "claude-opus-4-1-20250805",
-            "object": "model",
-            "version": "claude-opus-4-1-20250805",
-            "name": "Claude Opus 4.1",
-            "description": "Anthropic's most capable model.",
-            "context_length": 200_000,
-            "max_completion_tokens": 32_000,
-            "supported_parameters": []string{
-                "tools",
-            },
-            "temperature": 1,
-            "topP": 0.95,
-            "topK": 64,
-            "maxTemperature": 2,
-            "thinking": true,
-        },
-    }
+    // Get dynamic models from the global registry
+    modelRegistry := registry.GetGlobalRegistry()
+    return modelRegistry.GetAvailableModels("openai")
 }
 
 // OpenAIModels handles the /v1/models endpoint.
-// It returns a hardcoded list of available AI models with their capabilities
+// It returns a list of available AI models with their capabilities
 // and specifications in OpenAI-compatible format.
 func (h *OpenAIAPIHandler) OpenAIModels(c *gin.Context) {
     c.JSON(http.StatusOK, gin.H{
+        "object": "list",
         "data": h.Models(),
     })
 }
@@ -102,7 +102,7 @@ func (s *Server) setupRoutes() {
     v1 := s.engine.Group("/v1")
     v1.Use(AuthMiddleware(s.cfg))
     {
-        v1.GET("/models", openaiHandlers.OpenAIModels)
+        v1.GET("/models", s.unifiedModelsHandler(openaiHandlers, claudeCodeHandlers))
         v1.POST("/chat/completions", openaiHandlers.ChatCompletions)
         v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
     }
@@ -130,6 +130,25 @@ func (s *Server) setupRoutes() {
     s.engine.POST("/v1internal:method", geminiCLIHandlers.CLIHandler)
 }
 
+// unifiedModelsHandler creates a unified handler for the /v1/models endpoint
+// that routes to different handlers based on the User-Agent header.
+// If User-Agent starts with "claude-cli", it routes to Claude handler,
+// otherwise it routes to OpenAI handler.
+func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, claudeHandler *claude.ClaudeCodeAPIHandler) gin.HandlerFunc {
+    return func(c *gin.Context) {
+        userAgent := c.GetHeader("User-Agent")
+
+        // Route to Claude handler if User-Agent starts with "claude-cli"
+        if strings.HasPrefix(userAgent, "claude-cli") {
+            log.Debugf("Routing /v1/models to Claude handler for User-Agent: %s", userAgent)
+            claudeHandler.ClaudeModels(c)
+        } else {
+            log.Debugf("Routing /v1/models to OpenAI handler for User-Agent: %s", userAgent)
+            openaiHandler.OpenAIModels(c)
+        }
+    }
+}
+
 // Start begins listening for and serving HTTP requests.
 // It's a blocking call and will only return on an unrecoverable error.
 //
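For reference, a minimal sketch of how the new routing behaves under Gin's test mode. The stub list handlers and the sample User-Agent strings below are illustrative assumptions, not code from this commit; only the "claude-cli" prefix rule comes from the diff above.

package main

import (
    "fmt"
    "net/http"
    "net/http/httptest"
    "strings"

    "github.com/gin-gonic/gin"
)

// unifiedModels mirrors the routing rule added above: a User-Agent that starts
// with "claude-cli" gets the Claude-format list, everything else the OpenAI one.
func unifiedModels(openaiList, claudeList gin.HandlerFunc) gin.HandlerFunc {
    return func(c *gin.Context) {
        if strings.HasPrefix(c.GetHeader("User-Agent"), "claude-cli") {
            claudeList(c)
            return
        }
        openaiList(c)
    }
}

func main() {
    gin.SetMode(gin.TestMode)
    r := gin.New()

    // Stand-ins for OpenAIModels and ClaudeModels, returning the two response shapes.
    openaiList := func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"object": "list", "data": []any{}}) }
    claudeList := func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"data": []any{}}) }
    r.GET("/v1/models", unifiedModels(openaiList, claudeList))

    for _, ua := range []string{"claude-cli/1.0 (external)", "curl/8.4.0"} {
        req := httptest.NewRequest(http.MethodGet, "/v1/models", nil)
        req.Header.Set("User-Agent", ua)
        w := httptest.NewRecorder()
        r.ServeHTTP(w, req)
        fmt.Printf("%s -> %s\n", ua, w.Body.String())
    }
}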
@@ -23,6 +23,7 @@ import (
     . "github.com/luispater/CLIProxyAPI/internal/constant"
     "github.com/luispater/CLIProxyAPI/internal/interfaces"
     "github.com/luispater/CLIProxyAPI/internal/misc"
+    "github.com/luispater/CLIProxyAPI/internal/registry"
     "github.com/luispater/CLIProxyAPI/internal/translator/translator"
     "github.com/luispater/CLIProxyAPI/internal/util"
     log "github.com/sirupsen/logrus"
@@ -55,6 +56,10 @@ type ClaudeClient struct {
 // - *ClaudeClient: A new Claude client instance.
 func NewClaudeClient(cfg *config.Config, ts *claude.ClaudeTokenStorage) *ClaudeClient {
     httpClient := util.SetProxy(cfg, &http.Client{})
+
+    // Generate unique client ID
+    clientID := fmt.Sprintf("claude-%d", time.Now().UnixNano())
+
     client := &ClaudeClient{
         ClientBase: ClientBase{
             RequestMutex: &sync.Mutex{},
@@ -67,6 +72,10 @@ func NewClaudeClient(cfg *config.Config, ts *claude.ClaudeTokenStorage) *ClaudeC
         apiKeyIndex: -1,
     }
 
+    // Initialize model registry and register Claude models
+    client.InitializeModelRegistry(clientID)
+    client.RegisterModels("claude", registry.GetClaudeModels())
+
     return client
 }
 
@@ -82,6 +91,10 @@ func NewClaudeClient(cfg *config.Config, ts *claude.ClaudeTokenStorage) *ClaudeC
 // - *ClaudeClient: A new Claude client instance.
 func NewClaudeClientWithKey(cfg *config.Config, apiKeyIndex int) *ClaudeClient {
     httpClient := util.SetProxy(cfg, &http.Client{})
+
+    // Generate unique client ID for API key client
+    clientID := fmt.Sprintf("claude-apikey-%d-%d", apiKeyIndex, time.Now().UnixNano())
+
     client := &ClaudeClient{
         ClientBase: ClientBase{
             RequestMutex: &sync.Mutex{},
@@ -94,6 +107,10 @@ func NewClaudeClientWithKey(cfg *config.Config, apiKeyIndex int) *ClaudeClient {
         apiKeyIndex: apiKeyIndex,
     }
 
+    // Initialize model registry and register Claude models
+    client.InitializeModelRegistry(clientID)
+    client.RegisterModels("claude", registry.GetClaudeModels())
+
     return client
 }
 
@@ -174,10 +191,14 @@ func (c *ClaudeClient) SendRawMessage(ctx context.Context, modelName string, raw
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
     }
     return nil, err
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 bodyBytes, errReadAll := io.ReadAll(respBody)
 if errReadAll != nil {
     return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -234,11 +255,15 @@ func (c *ClaudeClient) SendRawMessageStream(ctx context.Context, modelName strin
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
     }
     errChan <- err
     return
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 defer func() {
     _ = stream.Close()
 }()
@@ -13,6 +13,7 @@ import (
     "github.com/gin-gonic/gin"
     "github.com/luispater/CLIProxyAPI/internal/auth"
     "github.com/luispater/CLIProxyAPI/internal/config"
+    "github.com/luispater/CLIProxyAPI/internal/registry"
 )
 
 // ClientBase provides a common base structure for all AI API clients.
@@ -34,6 +35,12 @@ type ClientBase struct {
     // modelQuotaExceeded tracks when models have exceeded their quota.
     // The map key is the model name, and the value is the time when the quota was exceeded.
     modelQuotaExceeded map[string]*time.Time
+
+    // clientID is the unique identifier for this client instance.
+    clientID string
+
+    // modelRegistry is the global model registry for tracking model availability.
+    modelRegistry *registry.ModelRegistry
 }
 
 // GetRequestMutex returns the mutex used to synchronize requests for this client.
@@ -71,3 +78,50 @@ func (c *ClientBase) AddAPIResponseData(ctx context.Context, line []byte) {
         }
     }
 }
+
+// InitializeModelRegistry initializes the model registry for this client
+// This should be called by all client implementations during construction
+func (c *ClientBase) InitializeModelRegistry(clientID string) {
+    c.clientID = clientID
+    c.modelRegistry = registry.GetGlobalRegistry()
+}
+
+// RegisterModels registers the models that this client can provide
+// Parameters:
+//   - provider: The provider name (e.g., "gemini", "claude", "openai")
+//   - models: The list of models this client supports
+func (c *ClientBase) RegisterModels(provider string, models []*registry.ModelInfo) {
+    if c.modelRegistry != nil && c.clientID != "" {
+        c.modelRegistry.RegisterClient(c.clientID, provider, models)
+    }
+}
+
+// UnregisterClient removes this client from the model registry
+func (c *ClientBase) UnregisterClient() {
+    if c.modelRegistry != nil && c.clientID != "" {
+        c.modelRegistry.UnregisterClient(c.clientID)
+    }
+}
+
+// SetModelQuotaExceeded marks a model as quota exceeded in the registry
+// Parameters:
+//   - modelID: The model that exceeded quota
+func (c *ClientBase) SetModelQuotaExceeded(modelID string) {
+    if c.modelRegistry != nil && c.clientID != "" {
+        c.modelRegistry.SetModelQuotaExceeded(c.clientID, modelID)
+    }
+}
+
+// ClearModelQuotaExceeded clears quota exceeded status for a model
+// Parameters:
+//   - modelID: The model to clear quota status for
+func (c *ClientBase) ClearModelQuotaExceeded(modelID string) {
+    if c.modelRegistry != nil && c.clientID != "" {
+        c.modelRegistry.ClearModelQuotaExceeded(c.clientID, modelID)
+    }
+}
+
+// GetClientID returns the unique identifier for this client
+func (c *ClientBase) GetClientID() string {
+    return c.clientID
+}
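The ClientBase helpers above are thin wrappers over the registry package added later in this commit. A minimal sketch of the lifecycle they drive, using the registry API directly; the client IDs here are made up for illustration (real clients generate IDs such as "claude-<nanoseconds>"):

package main

import (
    "fmt"

    "github.com/luispater/CLIProxyAPI/internal/registry"
)

func main() {
    reg := registry.GetGlobalRegistry()

    // Two clients advertise the same Claude models; the registry reference-counts them.
    reg.RegisterClient("claude-demo-1", "claude", registry.GetClaudeModels())
    reg.RegisterClient("claude-demo-2", "claude", registry.GetClaudeModels())
    fmt.Println(reg.GetModelCount("claude-opus-4-1-20250805")) // 2

    // A 429 is recorded per client and hides only that client's capacity until the
    // quota window used by GetAvailableModels/GetModelCount (5 minutes) has passed.
    reg.SetModelQuotaExceeded("claude-demo-1", "claude-opus-4-1-20250805")
    fmt.Println(reg.GetModelCount("claude-opus-4-1-20250805")) // 1

    // Once the last client unregisters, the model drops out of /v1/models entirely.
    reg.UnregisterClient("claude-demo-1")
    reg.UnregisterClient("claude-demo-2")
    fmt.Println(len(reg.GetAvailableModels("claude"))) // 0
}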
@@ -22,6 +22,7 @@ import (
     "github.com/luispater/CLIProxyAPI/internal/config"
     . "github.com/luispater/CLIProxyAPI/internal/constant"
     "github.com/luispater/CLIProxyAPI/internal/interfaces"
+    "github.com/luispater/CLIProxyAPI/internal/registry"
     "github.com/luispater/CLIProxyAPI/internal/translator/translator"
     "github.com/luispater/CLIProxyAPI/internal/util"
     log "github.com/sirupsen/logrus"
@@ -50,6 +51,10 @@ type CodexClient struct {
 // - error: An error if the client creation fails.
 func NewCodexClient(cfg *config.Config, ts *codex.CodexTokenStorage) (*CodexClient, error) {
     httpClient := util.SetProxy(cfg, &http.Client{})
+
+    // Generate unique client ID
+    clientID := fmt.Sprintf("codex-%d", time.Now().UnixNano())
+
     client := &CodexClient{
         ClientBase: ClientBase{
             RequestMutex: &sync.Mutex{},
@@ -61,6 +66,10 @@ func NewCodexClient(cfg *config.Config, ts *codex.CodexTokenStorage) (*CodexClie
         codexAuth: codex.NewCodexAuth(cfg),
     }
 
+    // Initialize model registry and register OpenAI models
+    client.InitializeModelRegistry(clientID)
+    client.RegisterModels("codex", registry.GetOpenAIModels())
+
     return client, nil
 }
 
@@ -123,10 +132,14 @@ func (c *CodexClient) SendRawMessage(ctx context.Context, modelName string, rawJ
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
     }
     return nil, err
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 bodyBytes, errReadAll := io.ReadAll(respBody)
 if errReadAll != nil {
     return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -184,11 +197,15 @@ func (c *CodexClient) SendRawMessageStream(ctx context.Context, modelName string
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
     }
     errChan <- err
     return
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 defer func() {
     _ = stream.Close()
 }()
@@ -22,6 +22,7 @@ import (
     "github.com/luispater/CLIProxyAPI/internal/config"
     . "github.com/luispater/CLIProxyAPI/internal/constant"
     "github.com/luispater/CLIProxyAPI/internal/interfaces"
+    "github.com/luispater/CLIProxyAPI/internal/registry"
     "github.com/luispater/CLIProxyAPI/internal/translator/translator"
     "github.com/luispater/CLIProxyAPI/internal/util"
     log "github.com/sirupsen/logrus"
@@ -57,6 +58,9 @@ type GeminiCLIClient struct {
 // Returns:
 // - *GeminiCLIClient: A new Gemini CLI client instance.
 func NewGeminiCLIClient(httpClient *http.Client, ts *geminiAuth.GeminiTokenStorage, cfg *config.Config) *GeminiCLIClient {
+    // Generate unique client ID
+    clientID := fmt.Sprintf("gemini-cli-%d", time.Now().UnixNano())
+
     client := &GeminiCLIClient{
         ClientBase: ClientBase{
             RequestMutex: &sync.Mutex{},
@@ -66,6 +70,11 @@ func NewGeminiCLIClient(httpClient *http.Client, ts *geminiAuth.GeminiTokenStora
             modelQuotaExceeded: make(map[string]*time.Time),
         },
     }
+
+    // Initialize model registry and register Gemini models
+    client.InitializeModelRegistry(clientID)
+    client.RegisterModels("gemini-cli", registry.GetGeminiModels())
+
     return client
 }
 
@@ -426,6 +435,8 @@ func (c *GeminiCLIClient) SendRawTokenCount(ctx context.Context, modelName strin
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
         if c.cfg.QuotaExceeded.SwitchPreviewModel {
             continue
         }
@@ -433,6 +444,8 @@ func (c *GeminiCLIClient) SendRawTokenCount(ctx context.Context, modelName strin
     return nil, err
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 bodyBytes, errReadAll := io.ReadAll(respBody)
 if errReadAll != nil {
     return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -485,6 +498,8 @@ func (c *GeminiCLIClient) SendRawMessage(ctx context.Context, modelName string,
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
         if c.cfg.QuotaExceeded.SwitchPreviewModel {
             continue
         }
@@ -492,6 +507,8 @@ func (c *GeminiCLIClient) SendRawMessage(ctx context.Context, modelName string,
     return nil, err
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 bodyBytes, errReadAll := io.ReadAll(respBody)
 if errReadAll != nil {
     return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -562,6 +579,8 @@ func (c *GeminiCLIClient) SendRawMessageStream(ctx context.Context, modelName st
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
         if c.cfg.QuotaExceeded.SwitchPreviewModel {
             continue
         }
@@ -570,6 +589,8 @@ func (c *GeminiCLIClient) SendRawMessageStream(ctx context.Context, modelName st
     return
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 break
 }
 defer func() {
@@ -18,6 +18,7 @@ import (
     "github.com/luispater/CLIProxyAPI/internal/config"
     . "github.com/luispater/CLIProxyAPI/internal/constant"
     "github.com/luispater/CLIProxyAPI/internal/interfaces"
+    "github.com/luispater/CLIProxyAPI/internal/registry"
     "github.com/luispater/CLIProxyAPI/internal/translator/translator"
     "github.com/luispater/CLIProxyAPI/internal/util"
     log "github.com/sirupsen/logrus"
@@ -44,6 +45,9 @@ type GeminiClient struct {
 // Returns:
 // - *GeminiClient: A new Gemini client instance.
 func NewGeminiClient(httpClient *http.Client, cfg *config.Config, glAPIKey string) *GeminiClient {
+    // Generate unique client ID
+    clientID := fmt.Sprintf("gemini-apikey-%s-%d", glAPIKey[:8], time.Now().UnixNano()) // Use first 8 chars of API key
+
     client := &GeminiClient{
         ClientBase: ClientBase{
             RequestMutex: &sync.Mutex{},
@@ -53,6 +57,11 @@ func NewGeminiClient(httpClient *http.Client, cfg *config.Config, glAPIKey strin
         },
         glAPIKey: glAPIKey,
     }
+
+    // Initialize model registry and register Gemini models
+    client.InitializeModelRegistry(clientID)
+    client.RegisterModels("gemini", registry.GetGeminiModels())
+
     return client
 }
 
@@ -195,10 +204,14 @@ func (c *GeminiClient) SendRawTokenCount(ctx context.Context, modelName string,
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
     }
     return nil, err
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 bodyBytes, errReadAll := io.ReadAll(respBody)
 if errReadAll != nil {
     return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -240,10 +253,14 @@ func (c *GeminiClient) SendRawMessage(ctx context.Context, modelName string, raw
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
     }
     return nil, err
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 bodyBytes, errReadAll := io.ReadAll(respBody)
 if errReadAll != nil {
     return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -297,11 +314,15 @@ func (c *GeminiClient) SendRawMessageStream(ctx context.Context, modelName strin
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
     }
     errChan <- err
     return
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 defer func() {
     _ = stream.Close()
 }()
@@ -19,6 +19,7 @@ import (
     "github.com/luispater/CLIProxyAPI/internal/config"
     . "github.com/luispater/CLIProxyAPI/internal/constant"
     "github.com/luispater/CLIProxyAPI/internal/interfaces"
+    "github.com/luispater/CLIProxyAPI/internal/registry"
     "github.com/luispater/CLIProxyAPI/internal/translator/translator"
     "github.com/luispater/CLIProxyAPI/internal/util"
     log "github.com/sirupsen/logrus"
@@ -53,6 +54,10 @@ func NewOpenAICompatibilityClient(cfg *config.Config, compatConfig *config.OpenA
     }
 
     httpClient := util.SetProxy(cfg, &http.Client{})
+
+    // Generate unique client ID
+    clientID := fmt.Sprintf("openai-compatibility-%s-%d", compatConfig.Name, time.Now().UnixNano())
+
     client := &OpenAICompatibilityClient{
         ClientBase: ClientBase{
             RequestMutex: &sync.Mutex{},
@@ -64,6 +69,25 @@ func NewOpenAICompatibilityClient(cfg *config.Config, compatConfig *config.OpenA
         currentAPIKeyIndex: 0,
     }
 
+    // Initialize model registry
+    client.InitializeModelRegistry(clientID)
+
+    // Convert compatibility models to registry models and register them
+    registryModels := make([]*registry.ModelInfo, 0, len(compatConfig.Models))
+    for _, model := range compatConfig.Models {
+        registryModel := &registry.ModelInfo{
+            ID: model.Alias,
+            Object: "model",
+            Created: time.Now().Unix(),
+            OwnedBy: compatConfig.Name,
+            Type: "openai-compatibility",
+            DisplayName: model.Name,
+        }
+        registryModels = append(registryModels, registryModel)
+    }
+
+    client.RegisterModels(compatConfig.Name, registryModels)
+
     return client, nil
 }
 
@@ -216,10 +240,14 @@ func (c *OpenAICompatibilityClient) SendRawMessage(ctx context.Context, modelNam
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
     }
     return nil, err
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 bodyBytes, errReadAll := io.ReadAll(respBody)
 if errReadAll != nil {
     return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -270,11 +298,15 @@ func (c *OpenAICompatibilityClient) SendRawMessageStream(ctx context.Context, mo
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
     }
     errChan <- err
     return
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 defer func() {
     _ = stream.Close()
 }()
@@ -22,6 +22,7 @@ import (
     "github.com/luispater/CLIProxyAPI/internal/config"
     . "github.com/luispater/CLIProxyAPI/internal/constant"
     "github.com/luispater/CLIProxyAPI/internal/interfaces"
+    "github.com/luispater/CLIProxyAPI/internal/registry"
     "github.com/luispater/CLIProxyAPI/internal/translator/translator"
     "github.com/luispater/CLIProxyAPI/internal/util"
     log "github.com/sirupsen/logrus"
@@ -49,6 +50,10 @@ type QwenClient struct {
 // - *QwenClient: A new Qwen client instance.
 func NewQwenClient(cfg *config.Config, ts *qwen.QwenTokenStorage) *QwenClient {
     httpClient := util.SetProxy(cfg, &http.Client{})
+
+    // Generate unique client ID
+    clientID := fmt.Sprintf("qwen-%d", time.Now().UnixNano())
+
     client := &QwenClient{
         ClientBase: ClientBase{
             RequestMutex: &sync.Mutex{},
@@ -60,6 +65,10 @@ func NewQwenClient(cfg *config.Config, ts *qwen.QwenTokenStorage) *QwenClient {
         qwenAuth: qwen.NewQwenAuth(cfg),
     }
 
+    // Initialize model registry and register Qwen models
+    client.InitializeModelRegistry(clientID)
+    client.RegisterModels("qwen", registry.GetQwenModels())
+
     return client
 }
 
@@ -119,10 +128,14 @@ func (c *QwenClient) SendRawMessage(ctx context.Context, modelName string, rawJS
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
     }
     return nil, err
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 bodyBytes, errReadAll := io.ReadAll(respBody)
 if errReadAll != nil {
     return nil, &interfaces.ErrorMessage{StatusCode: 500, Error: errReadAll}
@@ -182,11 +195,15 @@ func (c *QwenClient) SendRawMessageStream(ctx context.Context, modelName string,
     if err.StatusCode == 429 {
         now := time.Now()
         c.modelQuotaExceeded[modelName] = &now
+        // Update model registry quota status
+        c.SetModelQuotaExceeded(modelName)
     }
     errChan <- err
     return
 }
 delete(c.modelQuotaExceeded, modelName)
+// Clear quota status in model registry
+c.ClearModelQuotaExceeded(modelName)
 defer func() {
     _ = stream.Close()
 }()
internal/registry/model_definitions.go (new file, 150 lines)
@@ -0,0 +1,150 @@
+// Package registry provides model definitions for various AI service providers.
+// This file contains static model definitions that can be used by clients
+// when registering their supported models.
+package registry
+
+import "time"
+
+// GetClaudeModels returns the standard Claude model definitions
+func GetClaudeModels() []*ModelInfo {
+    return []*ModelInfo{
+        {
+            ID: "claude-opus-4-1-20250805",
+            Object: "model",
+            Created: 1722945600, // 2025-08-05
+            OwnedBy: "anthropic",
+            Type: "claude",
+            DisplayName: "Claude 4.1 Opus",
+        },
+        {
+            ID: "claude-opus-4-20250514",
+            Object: "model",
+            Created: 1715644800, // 2025-05-14
+            OwnedBy: "anthropic",
+            Type: "claude",
+            DisplayName: "Claude 4 Opus",
+        },
+        {
+            ID: "claude-sonnet-4-20250514",
+            Object: "model",
+            Created: 1715644800, // 2025-05-14
+            OwnedBy: "anthropic",
+            Type: "claude",
+            DisplayName: "Claude 4 Sonnet",
+        },
+        {
+            ID: "claude-3-7-sonnet-20250219",
+            Object: "model",
+            Created: 1708300800, // 2025-02-19
+            OwnedBy: "anthropic",
+            Type: "claude",
+            DisplayName: "Claude 3.7 Sonnet",
+        },
+        {
+            ID: "claude-3-5-haiku-20241022",
+            Object: "model",
+            Created: 1729555200, // 2024-10-22
+            OwnedBy: "anthropic",
+            Type: "claude",
+            DisplayName: "Claude 3.5 Haiku",
+        },
+    }
+}
+
+// GetGeminiModels returns the standard Gemini model definitions
+func GetGeminiModels() []*ModelInfo {
+    return []*ModelInfo{
+        {
+            ID: "gemini-2.5-flash",
+            Object: "model",
+            Created: time.Now().Unix(),
+            OwnedBy: "google",
+            Type: "gemini",
+            Name: "models/gemini-2.5-flash",
+            Version: "001",
+            DisplayName: "Gemini 2.5 Flash",
+            Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
+            InputTokenLimit: 1048576,
+            OutputTokenLimit: 65536,
+            SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
+        },
+        {
+            ID: "gemini-2.5-pro",
+            Object: "model",
+            Created: time.Now().Unix(),
+            OwnedBy: "google",
+            Type: "gemini",
+            Name: "models/gemini-2.5-pro",
+            Version: "2.5",
+            DisplayName: "Gemini 2.5 Pro",
+            Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
+            InputTokenLimit: 1048576,
+            OutputTokenLimit: 65536,
+            SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
+        },
+    }
+}
+
+// GetOpenAIModels returns the standard OpenAI model definitions
+func GetOpenAIModels() []*ModelInfo {
+    return []*ModelInfo{
+        {
+            ID: "gpt-5",
+            Object: "model",
+            Created: time.Now().Unix(),
+            OwnedBy: "openai",
+            Type: "openai",
+            Version: "gpt-5-2025-08-07",
+            DisplayName: "GPT 5",
+            Description: "Stable version of GPT 5, The best model for coding and agentic tasks across domains.",
+            ContextLength: 400000,
+            MaxCompletionTokens: 128000,
+            SupportedParameters: []string{"tools"},
+        },
+        {
+            ID: "codex-mini-latest",
+            Object: "model",
+            Created: time.Now().Unix(),
+            OwnedBy: "openai",
+            Type: "openai",
+            Version: "1.0",
+            DisplayName: "Codex Mini",
+            Description: "Lightweight code generation model",
+            ContextLength: 4096,
+            MaxCompletionTokens: 2048,
+            SupportedParameters: []string{"temperature", "max_tokens", "stream", "stop"},
+        },
+    }
+}
+
+// GetQwenModels returns the standard Qwen model definitions
+func GetQwenModels() []*ModelInfo {
+    return []*ModelInfo{
+        {
+            ID: "qwen3-coder-plus",
+            Object: "model",
+            Created: time.Now().Unix(),
+            OwnedBy: "qwen",
+            Type: "qwen",
+            Version: "3.0",
+            DisplayName: "Qwen3 Coder Plus",
+            Description: "Advanced code generation and understanding model",
+            ContextLength: 32768,
+            MaxCompletionTokens: 8192,
+            SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
+        },
+        {
+            ID: "qwen3-coder-flash",
+            Object: "model",
+            Created: time.Now().Unix(),
+            OwnedBy: "qwen",
+            Type: "qwen",
+            Version: "3.0",
+            DisplayName: "Qwen3 Coder Flash",
+            Description: "Fast code generation model",
+            ContextLength: 8192,
+            MaxCompletionTokens: 2048,
+            SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
+        },
+    }
+}
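The definition lists above all follow one shape. A hypothetical provider added in the same style would look like the sketch below; the ID, limits, and description are illustrative only and are not models shipped by this commit.

package registry

import "time"

// GetExampleModels is a hypothetical definition list in the model_definitions.go
// style; every value here is illustrative only.
func GetExampleModels() []*ModelInfo {
    return []*ModelInfo{
        {
            ID: "example-coder-1",
            Object: "model",
            Created: time.Now().Unix(),
            OwnedBy: "example",
            Type: "example",
            Version: "1.0",
            DisplayName: "Example Coder 1",
            Description: "Illustrative entry showing the fields a client registers",
            ContextLength: 32768,
            MaxCompletionTokens: 4096,
            SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream"},
        },
    }
}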
internal/registry/model_registry.go (new file, 374 lines)
@@ -0,0 +1,374 @@
+// Package registry provides centralized model management for all AI service providers.
+// It implements a dynamic model registry with reference counting to track active clients
+// and automatically hide models when no clients are available or when quota is exceeded.
+package registry
+
+import (
+    "sync"
+    "time"
+
+    log "github.com/sirupsen/logrus"
+)
+
+// ModelInfo represents information about an available model
+type ModelInfo struct {
+    // ID is the unique identifier for the model
+    ID string `json:"id"`
+    // Object type for the model (typically "model")
+    Object string `json:"object"`
+    // Created timestamp when the model was created
+    Created int64 `json:"created"`
+    // OwnedBy indicates the organization that owns the model
+    OwnedBy string `json:"owned_by"`
+    // Type indicates the model type (e.g., "claude", "gemini", "openai")
+    Type string `json:"type"`
+    // DisplayName is the human-readable name for the model
+    DisplayName string `json:"display_name,omitempty"`
+    // Name is used for Gemini-style model names
+    Name string `json:"name,omitempty"`
+    // Version is the model version
+    Version string `json:"version,omitempty"`
+    // Description provides detailed information about the model
+    Description string `json:"description,omitempty"`
+    // InputTokenLimit is the maximum input token limit
+    InputTokenLimit int `json:"inputTokenLimit,omitempty"`
+    // OutputTokenLimit is the maximum output token limit
+    OutputTokenLimit int `json:"outputTokenLimit,omitempty"`
+    // SupportedGenerationMethods lists supported generation methods
+    SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
+    // ContextLength is the context window size
+    ContextLength int `json:"context_length,omitempty"`
+    // MaxCompletionTokens is the maximum completion tokens
+    MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
+    // SupportedParameters lists supported parameters
+    SupportedParameters []string `json:"supported_parameters,omitempty"`
+}
+
+// ModelRegistration tracks a model's availability
+type ModelRegistration struct {
+    // Info contains the model metadata
+    Info *ModelInfo
+    // Count is the number of active clients that can provide this model
+    Count int
+    // LastUpdated tracks when this registration was last modified
+    LastUpdated time.Time
+    // QuotaExceededClients tracks which clients have exceeded quota for this model
+    QuotaExceededClients map[string]*time.Time
+}
+
+// ModelRegistry manages the global registry of available models
+type ModelRegistry struct {
+    // models maps model ID to registration information
+    models map[string]*ModelRegistration
+    // clientModels maps client ID to the models it provides
+    clientModels map[string][]string
+    // mutex ensures thread-safe access to the registry
+    mutex *sync.RWMutex
+}
+
+// Global model registry instance
+var globalRegistry *ModelRegistry
+var registryOnce sync.Once
+
+// GetGlobalRegistry returns the global model registry instance
+func GetGlobalRegistry() *ModelRegistry {
+    registryOnce.Do(func() {
+        globalRegistry = &ModelRegistry{
+            models: make(map[string]*ModelRegistration),
+            clientModels: make(map[string][]string),
+            mutex: &sync.RWMutex{},
+        }
+    })
+    return globalRegistry
+}
+
+// RegisterClient registers a client and its supported models
+// Parameters:
+//   - clientID: Unique identifier for the client
+//   - clientProvider: Provider name (e.g., "gemini", "claude", "openai")
+//   - models: List of models that this client can provide
+func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
+    r.mutex.Lock()
+    defer r.mutex.Unlock()
+
+    // Remove any existing registration for this client
+    r.unregisterClientInternal(clientID)
+
+    modelIDs := make([]string, 0, len(models))
+    now := time.Now()
+
+    for _, model := range models {
+        modelIDs = append(modelIDs, model.ID)
+
+        if existing, exists := r.models[model.ID]; exists {
+            // Model already exists, increment count
+            existing.Count++
+            existing.LastUpdated = now
+            log.Debugf("Incremented count for model %s, now %d clients", model.ID, existing.Count)
+        } else {
+            // New model, create registration
+            r.models[model.ID] = &ModelRegistration{
+                Info: model,
+                Count: 1,
+                LastUpdated: now,
+                QuotaExceededClients: make(map[string]*time.Time),
+            }
+            log.Debugf("Registered new model %s from provider %s", model.ID, clientProvider)
+        }
+    }
+
+    r.clientModels[clientID] = modelIDs
+    log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(models))
+}
+
+// UnregisterClient removes a client and decrements counts for its models
+// Parameters:
+//   - clientID: Unique identifier for the client to remove
+func (r *ModelRegistry) UnregisterClient(clientID string) {
+    r.mutex.Lock()
+    defer r.mutex.Unlock()
+    r.unregisterClientInternal(clientID)
+}
+
+// unregisterClientInternal performs the actual client unregistration (internal, no locking)
+func (r *ModelRegistry) unregisterClientInternal(clientID string) {
+    models, exists := r.clientModels[clientID]
+    if !exists {
+        return
+    }
+
+    now := time.Now()
+    for _, modelID := range models {
+        if registration, isExists := r.models[modelID]; isExists {
+            registration.Count--
+            registration.LastUpdated = now
+
+            // Remove quota tracking for this client
+            delete(registration.QuotaExceededClients, clientID)
+
+            log.Debugf("Decremented count for model %s, now %d clients", modelID, registration.Count)
+
+            // Remove model if no clients remain
+            if registration.Count <= 0 {
+                delete(r.models, modelID)
+                log.Debugf("Removed model %s as no clients remain", modelID)
+            }
+        }
+    }
+
+    delete(r.clientModels, clientID)
+    log.Debugf("Unregistered client %s", clientID)
+}
+
+// SetModelQuotaExceeded marks a model as quota exceeded for a specific client
+// Parameters:
+//   - clientID: The client that exceeded quota
+//   - modelID: The model that exceeded quota
+func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
+    r.mutex.Lock()
+    defer r.mutex.Unlock()
+
+    if registration, exists := r.models[modelID]; exists {
+        now := time.Now()
+        registration.QuotaExceededClients[clientID] = &now
+        log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
+    }
+}
+
+// ClearModelQuotaExceeded removes quota exceeded status for a model and client
+// Parameters:
+//   - clientID: The client to clear quota status for
+//   - modelID: The model to clear quota status for
+func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
+    r.mutex.Lock()
+    defer r.mutex.Unlock()
+
+    if registration, exists := r.models[modelID]; exists {
+        delete(registration.QuotaExceededClients, clientID)
+        log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
+    }
+}
+
+// GetAvailableModels returns all models that have at least one available client
+// Parameters:
+//   - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini")
+//
+// Returns:
+//   - []map[string]any: List of available models in the requested format
+func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
+    r.mutex.RLock()
+    defer r.mutex.RUnlock()
+
+    models := make([]map[string]any, 0)
+    quotaExpiredDuration := 5 * time.Minute
+
+    for _, registration := range r.models {
+        // Check if model has any non-quota-exceeded clients
+        availableClients := registration.Count
+        now := time.Now()
+
+        // Count clients that have exceeded quota but haven't recovered yet
+        expiredClients := 0
+        for _, quotaTime := range registration.QuotaExceededClients {
+            if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
+                expiredClients++
+            }
+        }
+
+        effectiveClients := availableClients - expiredClients
+
+        // Only include models that have available clients
+        if effectiveClients > 0 {
+            model := r.convertModelToMap(registration.Info, handlerType)
+            if model != nil {
+                models = append(models, model)
+            }
+        }
+    }
+
+    return models
+}
+
+// GetModelCount returns the number of available clients for a specific model
+// Parameters:
+//   - modelID: The model ID to check
+//
+// Returns:
+//   - int: Number of available clients for the model
+func (r *ModelRegistry) GetModelCount(modelID string) int {
+    r.mutex.RLock()
+    defer r.mutex.RUnlock()
+
+    if registration, exists := r.models[modelID]; exists {
+        now := time.Now()
+        quotaExpiredDuration := 5 * time.Minute
+
+        // Count clients that have exceeded quota but haven't recovered yet
+        expiredClients := 0
+        for _, quotaTime := range registration.QuotaExceededClients {
+            if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
+                expiredClients++
+            }
+        }
+
+        return registration.Count - expiredClients
+    }
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// convertModelToMap converts ModelInfo to the appropriate format for different handler types
|
||||||
|
func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) map[string]any {
|
||||||
|
if model == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
switch handlerType {
|
||||||
|
case "openai":
|
||||||
|
result := map[string]any{
|
||||||
|
"id": model.ID,
|
||||||
|
"object": "model",
|
||||||
|
"owned_by": model.OwnedBy,
|
||||||
|
}
|
||||||
|
if model.Created > 0 {
|
||||||
|
result["created"] = model.Created
|
||||||
|
}
|
||||||
|
if model.Type != "" {
|
||||||
|
result["type"] = model.Type
|
||||||
|
}
|
||||||
|
if model.DisplayName != "" {
|
||||||
|
result["display_name"] = model.DisplayName
|
||||||
|
}
|
||||||
|
if model.Version != "" {
|
||||||
|
result["version"] = model.Version
|
||||||
|
}
|
||||||
|
if model.Description != "" {
|
||||||
|
result["description"] = model.Description
|
||||||
|
}
|
||||||
|
if model.ContextLength > 0 {
|
||||||
|
result["context_length"] = model.ContextLength
|
||||||
|
}
|
||||||
|
if model.MaxCompletionTokens > 0 {
|
||||||
|
result["max_completion_tokens"] = model.MaxCompletionTokens
|
||||||
|
}
|
||||||
|
if len(model.SupportedParameters) > 0 {
|
||||||
|
result["supported_parameters"] = model.SupportedParameters
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
case "claude":
|
||||||
|
result := map[string]any{
|
||||||
|
"id": model.ID,
|
||||||
|
"object": "model",
|
||||||
|
"owned_by": model.OwnedBy,
|
||||||
|
}
|
||||||
|
if model.Created > 0 {
|
||||||
|
result["created"] = model.Created
|
||||||
|
}
|
||||||
|
if model.Type != "" {
|
||||||
|
result["type"] = model.Type
|
||||||
|
}
|
||||||
|
if model.DisplayName != "" {
|
||||||
|
result["display_name"] = model.DisplayName
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
case "gemini":
|
||||||
|
result := map[string]any{}
|
||||||
|
if model.Name != "" {
|
||||||
|
result["name"] = model.Name
|
||||||
|
} else {
|
||||||
|
result["name"] = model.ID
|
||||||
|
}
|
||||||
|
if model.Version != "" {
|
||||||
|
result["version"] = model.Version
|
||||||
|
}
|
||||||
|
if model.DisplayName != "" {
|
||||||
|
result["displayName"] = model.DisplayName
|
||||||
|
}
|
||||||
|
if model.Description != "" {
|
||||||
|
result["description"] = model.Description
|
||||||
|
}
|
||||||
|
if model.InputTokenLimit > 0 {
|
||||||
|
result["inputTokenLimit"] = model.InputTokenLimit
|
||||||
|
}
|
||||||
|
if model.OutputTokenLimit > 0 {
|
||||||
|
result["outputTokenLimit"] = model.OutputTokenLimit
|
||||||
|
}
|
||||||
|
if len(model.SupportedGenerationMethods) > 0 {
|
||||||
|
result["supportedGenerationMethods"] = model.SupportedGenerationMethods
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
|
||||||
|
default:
|
||||||
|
// Generic format
|
||||||
|
result := map[string]any{
|
||||||
|
"id": model.ID,
|
||||||
|
"object": "model",
|
||||||
|
}
|
||||||
|
if model.OwnedBy != "" {
|
||||||
|
result["owned_by"] = model.OwnedBy
|
||||||
|
}
|
||||||
|
if model.Type != "" {
|
||||||
|
result["type"] = model.Type
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupExpiredQuotas removes expired quota tracking entries
|
||||||
|
func (r *ModelRegistry) CleanupExpiredQuotas() {
|
||||||
|
r.mutex.Lock()
|
||||||
|
defer r.mutex.Unlock()
|
||||||
|
|
||||||
|
now := time.Now()
|
||||||
|
quotaExpiredDuration := 5 * time.Minute
|
||||||
|
|
||||||
|
for modelID, registration := range r.models {
|
||||||
|
for clientID, quotaTime := range registration.QuotaExceededClients {
|
||||||
|
if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
|
||||||
|
delete(registration.QuotaExceededClients, clientID)
|
||||||
|
log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
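Usage sketch (illustrative, not part of the commit): assuming the package exposes GetGlobalRegistry() returning a *ModelRegistry and that ModelInfo has exported ID, OwnedBy, and DisplayName fields as referenced above, code inside the same module could exercise the registration and quota paths roughly like this. The client and model IDs are placeholders.

package main

import (
	"fmt"

	"github.com/luispater/CLIProxyAPI/internal/registry"
)

func main() {
	r := registry.GetGlobalRegistry()

	// Register one client that provides two placeholder models.
	r.RegisterClient("client-1", "claude", []*registry.ModelInfo{
		{ID: "example-model-a", OwnedBy: "example", DisplayName: "Example Model A"},
		{ID: "example-model-b", OwnedBy: "example", DisplayName: "Example Model B"},
	})

	// Marking a model as quota exceeded hides it from GetAvailableModels
	// until the 5-minute window elapses or the status is cleared.
	r.SetModelQuotaExceeded("client-1", "example-model-a")
	fmt.Println(r.GetModelCount("example-model-a"))  // 0 while quota exceeded
	fmt.Println(len(r.GetAvailableModels("claude"))) // 1 (only example-model-b)

	r.ClearModelQuotaExceeded("client-1", "example-model-a")
	r.UnregisterClient("client-1")
}

Because registrations are reference-counted per model, a model only drops out of the listing when its last client unregisters or when every remaining client is inside the quota-exceeded window.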