Mirror of https://github.com/router-for-me/CLIProxyAPI.git (synced 2026-02-03 04:50:52 +08:00)
Improved the /v1/models endpoint
internal/registry/model_definitions.go (new file, 150 lines added)
@@ -0,0 +1,150 @@
// Package registry provides model definitions for various AI service providers.
// This file contains static model definitions that can be used by clients
// when registering their supported models.
package registry

import "time"

// GetClaudeModels returns the standard Claude model definitions.
func GetClaudeModels() []*ModelInfo {
	return []*ModelInfo{
		{
			ID:          "claude-opus-4-1-20250805",
			Object:      "model",
			Created:     1754352000, // 2025-08-05
			OwnedBy:     "anthropic",
			Type:        "claude",
			DisplayName: "Claude 4.1 Opus",
		},
		{
			ID:          "claude-opus-4-20250514",
			Object:      "model",
			Created:     1747180800, // 2025-05-14
			OwnedBy:     "anthropic",
			Type:        "claude",
			DisplayName: "Claude 4 Opus",
		},
		{
			ID:          "claude-sonnet-4-20250514",
			Object:      "model",
			Created:     1747180800, // 2025-05-14
			OwnedBy:     "anthropic",
			Type:        "claude",
			DisplayName: "Claude 4 Sonnet",
		},
		{
			ID:          "claude-3-7-sonnet-20250219",
			Object:      "model",
			Created:     1739923200, // 2025-02-19
			OwnedBy:     "anthropic",
			Type:        "claude",
			DisplayName: "Claude 3.7 Sonnet",
		},
		{
			ID:          "claude-3-5-haiku-20241022",
			Object:      "model",
			Created:     1729555200, // 2024-10-22
			OwnedBy:     "anthropic",
			Type:        "claude",
			DisplayName: "Claude 3.5 Haiku",
		},
	}
}

// GetGeminiModels returns the standard Gemini model definitions.
func GetGeminiModels() []*ModelInfo {
	return []*ModelInfo{
		{
			ID:                         "gemini-2.5-flash",
			Object:                     "model",
			Created:                    time.Now().Unix(),
			OwnedBy:                    "google",
			Type:                       "gemini",
			Name:                       "models/gemini-2.5-flash",
			Version:                    "001",
			DisplayName:                "Gemini 2.5 Flash",
			Description:                "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
			InputTokenLimit:            1048576,
			OutputTokenLimit:           65536,
			SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
		},
		{
			ID:                         "gemini-2.5-pro",
			Object:                     "model",
			Created:                    time.Now().Unix(),
			OwnedBy:                    "google",
			Type:                       "gemini",
			Name:                       "models/gemini-2.5-pro",
			Version:                    "2.5",
			DisplayName:                "Gemini 2.5 Pro",
			Description:                "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
			InputTokenLimit:            1048576,
			OutputTokenLimit:           65536,
			SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
		},
	}
}

// GetOpenAIModels returns the standard OpenAI model definitions.
func GetOpenAIModels() []*ModelInfo {
	return []*ModelInfo{
		{
			ID:                  "gpt-5",
			Object:              "model",
			Created:             time.Now().Unix(),
			OwnedBy:             "openai",
			Type:                "openai",
			Version:             "gpt-5-2025-08-07",
			DisplayName:         "GPT 5",
			Description:         "Stable version of GPT 5, the best model for coding and agentic tasks across domains.",
			ContextLength:       400000,
			MaxCompletionTokens: 128000,
			SupportedParameters: []string{"tools"},
		},
		{
			ID:                  "codex-mini-latest",
			Object:              "model",
			Created:             time.Now().Unix(),
			OwnedBy:             "openai",
			Type:                "openai",
			Version:             "1.0",
			DisplayName:         "Codex Mini",
			Description:         "Lightweight code generation model",
			ContextLength:       4096,
			MaxCompletionTokens: 2048,
			SupportedParameters: []string{"temperature", "max_tokens", "stream", "stop"},
		},
	}
}

// GetQwenModels returns the standard Qwen model definitions.
func GetQwenModels() []*ModelInfo {
	return []*ModelInfo{
		{
			ID:                  "qwen3-coder-plus",
			Object:              "model",
			Created:             time.Now().Unix(),
			OwnedBy:             "qwen",
			Type:                "qwen",
			Version:             "3.0",
			DisplayName:         "Qwen3 Coder Plus",
			Description:         "Advanced code generation and understanding model",
			ContextLength:       32768,
			MaxCompletionTokens: 8192,
			SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
		},
		{
			ID:                  "qwen3-coder-flash",
			Object:              "model",
			Created:             time.Now().Unix(),
			OwnedBy:             "qwen",
			Type:                "qwen",
			Version:             "3.0",
			DisplayName:         "Qwen3 Coder Flash",
			Description:         "Fast code generation model",
			ContextLength:       8192,
			MaxCompletionTokens: 2048,
			SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
		},
	}
}
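
These static definitions are meant to be fed into the registry added in the next file. As a rough sketch of the intended wiring (the client IDs and the module import path github.com/router-for-me/CLIProxyAPI are assumptions, not shown in this commit), a client could register its models like this:

package main

import (
	"fmt"

	// Assumed module path; the commit only shows the package directory internal/registry.
	"github.com/router-for-me/CLIProxyAPI/internal/registry"
)

func main() {
	reg := registry.GetGlobalRegistry()

	// Hypothetical client IDs; each authenticated upstream account would use its own ID.
	reg.RegisterClient("claude-account-1", "claude", registry.GetClaudeModels())
	reg.RegisterClient("gemini-account-1", "gemini", registry.GetGeminiModels())

	fmt.Println("clients for claude-opus-4-1-20250805:",
		reg.GetModelCount("claude-opus-4-1-20250805"))
}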
internal/registry/model_registry.go (new file, 374 lines added)
@@ -0,0 +1,374 @@
// Package registry provides centralized model management for all AI service providers.
// It implements a dynamic model registry with reference counting to track active clients
// and automatically hide models when no clients are available or when quota is exceeded.
package registry

import (
	"sync"
	"time"

	log "github.com/sirupsen/logrus"
)

// ModelInfo represents information about an available model
type ModelInfo struct {
	// ID is the unique identifier for the model
	ID string `json:"id"`
	// Object is the object type for the model (typically "model")
	Object string `json:"object"`
	// Created is the timestamp when the model was created
	Created int64 `json:"created"`
	// OwnedBy indicates the organization that owns the model
	OwnedBy string `json:"owned_by"`
	// Type indicates the model type (e.g., "claude", "gemini", "openai")
	Type string `json:"type"`
	// DisplayName is the human-readable name for the model
	DisplayName string `json:"display_name,omitempty"`
	// Name is used for Gemini-style model names
	Name string `json:"name,omitempty"`
	// Version is the model version
	Version string `json:"version,omitempty"`
	// Description provides detailed information about the model
	Description string `json:"description,omitempty"`
	// InputTokenLimit is the maximum input token limit
	InputTokenLimit int `json:"inputTokenLimit,omitempty"`
	// OutputTokenLimit is the maximum output token limit
	OutputTokenLimit int `json:"outputTokenLimit,omitempty"`
	// SupportedGenerationMethods lists supported generation methods
	SupportedGenerationMethods []string `json:"supportedGenerationMethods,omitempty"`
	// ContextLength is the context window size
	ContextLength int `json:"context_length,omitempty"`
	// MaxCompletionTokens is the maximum number of completion tokens
	MaxCompletionTokens int `json:"max_completion_tokens,omitempty"`
	// SupportedParameters lists supported parameters
	SupportedParameters []string `json:"supported_parameters,omitempty"`
}
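
Note that the json tags mix OpenAI-style snake_case keys with Gemini-style camelCase keys: the struct is a superset of the fields each provider's listing needs, and omitempty drops whatever a given definition does not set. A minimal sketch of what one marshaled entry looks like (module path assumed as above):

package main

import (
	"encoding/json"
	"fmt"

	"github.com/router-for-me/CLIProxyAPI/internal/registry" // assumed module path
)

func main() {
	// Marshal one of the static Gemini definitions; unset fields are omitted.
	b, _ := json.MarshalIndent(registry.GetGeminiModels()[0], "", "  ")
	fmt.Println(string(b))
	// The output mixes "id" and "owned_by" (snake_case) with
	// "inputTokenLimit" and "supportedGenerationMethods" (camelCase).
}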

// ModelRegistration tracks a model's availability
type ModelRegistration struct {
	// Info contains the model metadata
	Info *ModelInfo
	// Count is the number of active clients that can provide this model
	Count int
	// LastUpdated tracks when this registration was last modified
	LastUpdated time.Time
	// QuotaExceededClients tracks which clients have exceeded quota for this model
	QuotaExceededClients map[string]*time.Time
}

// ModelRegistry manages the global registry of available models
type ModelRegistry struct {
	// models maps model ID to registration information
	models map[string]*ModelRegistration
	// clientModels maps client ID to the models it provides
	clientModels map[string][]string
	// mutex ensures thread-safe access to the registry
	mutex *sync.RWMutex
}

// Global model registry instance
var globalRegistry *ModelRegistry
var registryOnce sync.Once

// GetGlobalRegistry returns the global model registry instance
func GetGlobalRegistry() *ModelRegistry {
	registryOnce.Do(func() {
		globalRegistry = &ModelRegistry{
			models:       make(map[string]*ModelRegistration),
			clientModels: make(map[string][]string),
			mutex:        &sync.RWMutex{},
		}
	})
	return globalRegistry
}

// RegisterClient registers a client and its supported models
// Parameters:
//   - clientID: Unique identifier for the client
//   - clientProvider: Provider name (e.g., "gemini", "claude", "openai")
//   - models: List of models that this client can provide
func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models []*ModelInfo) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	// Remove any existing registration for this client
	r.unregisterClientInternal(clientID)

	modelIDs := make([]string, 0, len(models))
	now := time.Now()

	for _, model := range models {
		modelIDs = append(modelIDs, model.ID)

		if existing, exists := r.models[model.ID]; exists {
			// Model already exists, increment count
			existing.Count++
			existing.LastUpdated = now
			log.Debugf("Incremented count for model %s, now %d clients", model.ID, existing.Count)
		} else {
			// New model, create registration
			r.models[model.ID] = &ModelRegistration{
				Info:                 model,
				Count:                1,
				LastUpdated:          now,
				QuotaExceededClients: make(map[string]*time.Time),
			}
			log.Debugf("Registered new model %s from provider %s", model.ID, clientProvider)
		}
	}

	r.clientModels[clientID] = modelIDs
	log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(models))
}

// UnregisterClient removes a client and decrements counts for its models
// Parameters:
//   - clientID: Unique identifier for the client to remove
func (r *ModelRegistry) UnregisterClient(clientID string) {
	r.mutex.Lock()
	defer r.mutex.Unlock()
	r.unregisterClientInternal(clientID)
}

// unregisterClientInternal performs the actual client unregistration (internal, no locking)
func (r *ModelRegistry) unregisterClientInternal(clientID string) {
	models, exists := r.clientModels[clientID]
	if !exists {
		return
	}

	now := time.Now()
	for _, modelID := range models {
		if registration, isExists := r.models[modelID]; isExists {
			registration.Count--
			registration.LastUpdated = now

			// Remove quota tracking for this client
			delete(registration.QuotaExceededClients, clientID)

			log.Debugf("Decremented count for model %s, now %d clients", modelID, registration.Count)

			// Remove model if no clients remain
			if registration.Count <= 0 {
				delete(r.models, modelID)
				log.Debugf("Removed model %s as no clients remain", modelID)
			}
		}
	}

	delete(r.clientModels, clientID)
	log.Debugf("Unregistered client %s", clientID)
}
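
Because RegisterClient first drops any previous registration for the same clientID, re-registering a client does not double-count its models, while distinct clients each add to a model's reference count. A rough sketch of the counting behavior (client IDs and module path are illustrative assumptions):

package main

import (
	"fmt"

	"github.com/router-for-me/CLIProxyAPI/internal/registry" // assumed module path
)

func main() {
	reg := registry.GetGlobalRegistry()

	// Two hypothetical Anthropic accounts providing the same model set.
	reg.RegisterClient("claude-a", "claude", registry.GetClaudeModels())
	reg.RegisterClient("claude-b", "claude", registry.GetClaudeModels())
	fmt.Println(reg.GetModelCount("claude-3-5-haiku-20241022")) // 2

	// Removing one client decrements the count; removing both hides the model entirely.
	reg.UnregisterClient("claude-a")
	fmt.Println(reg.GetModelCount("claude-3-5-haiku-20241022")) // 1
}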

// SetModelQuotaExceeded marks a model as quota exceeded for a specific client
// Parameters:
//   - clientID: The client that exceeded quota
//   - modelID: The model that exceeded quota
func (r *ModelRegistry) SetModelQuotaExceeded(clientID, modelID string) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	if registration, exists := r.models[modelID]; exists {
		now := time.Now()
		registration.QuotaExceededClients[clientID] = &now
		log.Debugf("Marked model %s as quota exceeded for client %s", modelID, clientID)
	}
}

// ClearModelQuotaExceeded removes quota exceeded status for a model and client
// Parameters:
//   - clientID: The client to clear quota status for
//   - modelID: The model to clear quota status for
func (r *ModelRegistry) ClearModelQuotaExceeded(clientID, modelID string) {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	if registration, exists := r.models[modelID]; exists {
		delete(registration.QuotaExceededClients, clientID)
		log.Debugf("Cleared quota exceeded status for model %s and client %s", modelID, clientID)
	}
}
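
These markers interact with the 5-minute recovery window used by GetAvailableModels and GetModelCount below: a model backed by a single marked client drops out of the listing until the marker ages out or is cleared. A hedged sketch of how an upstream 429 might be reported (the hook function itself is hypothetical and not part of this commit):

package main

import (
	"github.com/router-for-me/CLIProxyAPI/internal/registry" // assumed module path
)

// reportUpstreamStatus is a hypothetical hook called after each upstream request.
func reportUpstreamStatus(clientID, modelID string, statusCode int) {
	reg := registry.GetGlobalRegistry()
	if statusCode == 429 {
		// Hide this client's capacity for the model; it recovers automatically
		// after the registry's quota window, or explicitly on the next success.
		reg.SetModelQuotaExceeded(clientID, modelID)
		return
	}
	reg.ClearModelQuotaExceeded(clientID, modelID)
}

func main() {
	reportUpstreamStatus("claude-a", "claude-3-5-haiku-20241022", 429)
}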

// GetAvailableModels returns all models that have at least one available client
// Parameters:
//   - handlerType: The handler type to filter models for (e.g., "openai", "claude", "gemini")
//
// Returns:
//   - []map[string]any: List of available models in the requested format
func (r *ModelRegistry) GetAvailableModels(handlerType string) []map[string]any {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	models := make([]map[string]any, 0)
	quotaExpiredDuration := 5 * time.Minute

	for _, registration := range r.models {
		// Check if model has any non-quota-exceeded clients
		availableClients := registration.Count
		now := time.Now()

		// Count clients that have exceeded quota but haven't recovered yet
		expiredClients := 0
		for _, quotaTime := range registration.QuotaExceededClients {
			if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
				expiredClients++
			}
		}

		effectiveClients := availableClients - expiredClients

		// Only include models that have available clients
		if effectiveClients > 0 {
			model := r.convertModelToMap(registration.Info, handlerType)
			if model != nil {
				models = append(models, model)
			}
		}
	}

	return models
}
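
This is the method a /v1/models handler can serve directly. The handler below is only an illustrative sketch of that wiring, since the project's actual HTTP layer is not part of this diff; the route, port, and OpenAI-style list envelope are assumptions:

package main

import (
	"encoding/json"
	"net/http"

	"github.com/router-for-me/CLIProxyAPI/internal/registry" // assumed module path
)

func main() {
	http.HandleFunc("/v1/models", func(w http.ResponseWriter, r *http.Request) {
		// OpenAI-style listing: {"object":"list","data":[...]}
		payload := map[string]any{
			"object": "list",
			"data":   registry.GetGlobalRegistry().GetAvailableModels("openai"),
		}
		w.Header().Set("Content-Type", "application/json")
		_ = json.NewEncoder(w).Encode(payload)
	})
	_ = http.ListenAndServe(":8080", nil)
}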

// GetModelCount returns the number of available clients for a specific model
// Parameters:
//   - modelID: The model ID to check
//
// Returns:
//   - int: Number of available clients for the model
func (r *ModelRegistry) GetModelCount(modelID string) int {
	r.mutex.RLock()
	defer r.mutex.RUnlock()

	if registration, exists := r.models[modelID]; exists {
		now := time.Now()
		quotaExpiredDuration := 5 * time.Minute

		// Count clients that have exceeded quota but haven't recovered yet
		expiredClients := 0
		for _, quotaTime := range registration.QuotaExceededClients {
			if quotaTime != nil && now.Sub(*quotaTime) < quotaExpiredDuration {
				expiredClients++
			}
		}

		return registration.Count - expiredClients
	}
	return 0
}

// convertModelToMap converts ModelInfo to the appropriate format for different handler types
func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string) map[string]any {
	if model == nil {
		return nil
	}

	switch handlerType {
	case "openai":
		result := map[string]any{
			"id":       model.ID,
			"object":   "model",
			"owned_by": model.OwnedBy,
		}
		if model.Created > 0 {
			result["created"] = model.Created
		}
		if model.Type != "" {
			result["type"] = model.Type
		}
		if model.DisplayName != "" {
			result["display_name"] = model.DisplayName
		}
		if model.Version != "" {
			result["version"] = model.Version
		}
		if model.Description != "" {
			result["description"] = model.Description
		}
		if model.ContextLength > 0 {
			result["context_length"] = model.ContextLength
		}
		if model.MaxCompletionTokens > 0 {
			result["max_completion_tokens"] = model.MaxCompletionTokens
		}
		if len(model.SupportedParameters) > 0 {
			result["supported_parameters"] = model.SupportedParameters
		}
		return result

	case "claude":
		result := map[string]any{
			"id":       model.ID,
			"object":   "model",
			"owned_by": model.OwnedBy,
		}
		if model.Created > 0 {
			result["created"] = model.Created
		}
		if model.Type != "" {
			result["type"] = model.Type
		}
		if model.DisplayName != "" {
			result["display_name"] = model.DisplayName
		}
		return result

	case "gemini":
		result := map[string]any{}
		if model.Name != "" {
			result["name"] = model.Name
		} else {
			result["name"] = model.ID
		}
		if model.Version != "" {
			result["version"] = model.Version
		}
		if model.DisplayName != "" {
			result["displayName"] = model.DisplayName
		}
		if model.Description != "" {
			result["description"] = model.Description
		}
		if model.InputTokenLimit > 0 {
			result["inputTokenLimit"] = model.InputTokenLimit
		}
		if model.OutputTokenLimit > 0 {
			result["outputTokenLimit"] = model.OutputTokenLimit
		}
		if len(model.SupportedGenerationMethods) > 0 {
			result["supportedGenerationMethods"] = model.SupportedGenerationMethods
		}
		return result

	default:
		// Generic format
		result := map[string]any{
			"id":     model.ID,
			"object": "model",
		}
		if model.OwnedBy != "" {
			result["owned_by"] = model.OwnedBy
		}
		if model.Type != "" {
			result["type"] = model.Type
		}
		return result
	}
}
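
Since convertModelToMap is unexported, the natural place to pin down the per-handler output shapes is an in-package test. A minimal sketch (the test file itself is not part of this commit):

package registry

import "testing"

// TestConvertModelToMapGemini is a hypothetical test sketching the Gemini-format output.
func TestConvertModelToMapGemini(t *testing.T) {
	r := GetGlobalRegistry()
	model := GetGeminiModels()[0]

	m := r.convertModelToMap(model, "gemini")
	if m["name"] != "models/gemini-2.5-flash" {
		t.Fatalf("unexpected name: %v", m["name"])
	}
	if m["displayName"] != "Gemini 2.5 Flash" {
		t.Fatalf("unexpected displayName: %v", m["displayName"])
	}
	// The Gemini listing never carries an "id" key; the OpenAI and Claude formats do.
	if _, ok := m["id"]; ok {
		t.Fatalf("gemini format should not contain an id field")
	}
}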

// CleanupExpiredQuotas removes expired quota tracking entries
func (r *ModelRegistry) CleanupExpiredQuotas() {
	r.mutex.Lock()
	defer r.mutex.Unlock()

	now := time.Now()
	quotaExpiredDuration := 5 * time.Minute

	for modelID, registration := range r.models {
		for clientID, quotaTime := range registration.QuotaExceededClients {
			if quotaTime != nil && now.Sub(*quotaTime) >= quotaExpiredDuration {
				delete(registration.QuotaExceededClients, clientID)
				log.Debugf("Cleaned up expired quota tracking for model %s, client %s", modelID, clientID)
			}
		}
	}
}
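
Nothing in this commit wires CleanupExpiredQuotas to a schedule, and the availability checks already ignore stale markers, so the cleanup is pure housekeeping. If it is to be run periodically, a background ticker along these lines would do (sketch only, assuming the same module path):

package main

import (
	"time"

	"github.com/router-for-me/CLIProxyAPI/internal/registry" // assumed module path
)

func main() {
	go func() {
		ticker := time.NewTicker(time.Minute)
		defer ticker.Stop()
		for range ticker.C {
			// Drop quota markers older than the registry's 5-minute window.
			registry.GetGlobalRegistry().CleanupExpiredQuotas()
		}
	}()

	select {} // block forever; a real service would run its HTTP server here
}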