Compare commits

...

8 Commits

Author SHA1 Message Date
Luis Pater
0defb68c6c fix(translator): improve error handling for function parameters schema transformation
- Added fallback to set default `parametersJsonSchema` when `parameters` key is absent.
- Enhanced logging to capture detailed errors during schema transformation.
- Refined tool declaration appending logic for robustness.
2025-10-28 22:57:26 +08:00
Luis Pater
d6272d3300 Merge pull request #177 from router-for-me/aistudio
feat(registry): unify Gemini models and add AI Studio set
2025-10-28 21:57:18 +08:00
hkfires
c99d0dfb33 fix(aistudio): remove no-op executor unregister on WS disconnect 2025-10-28 19:51:05 +08:00
hkfires
663b9b35ab fix(executor): pass authID to relay instead of provider 2025-10-28 19:28:26 +08:00
hkfires
5dced4c0a6 feat(registry): unify Gemini models and add AI Studio set 2025-10-28 19:00:25 +08:00
Luis Pater
5891785125 docs(readme): clarify model definition and add usage example for undefined models
- Updated `README.md` and `README_CN.md` to include additional instructions on requesting undefined models using the `openrouter://` format.
- Added example for `moonshotai/kimi-k2:free` usage.
2025-10-28 09:09:19 +08:00
Luis Pater
ac3d47e8c0 Merge pull request #173 from tobwen/feature/dynamic-model-routing
Add support for dynamic model providers
2025-10-28 08:55:08 +08:00
tobwen
e5ed2cba4a Add support for dynamic model providers
Implements functionality to parse model names with provider information in the format "provider://model" This allows dynamic provider selection rather than relying only on predefined mappings.

The change affects all execution methods to properly handle these dynamic model specifications while maintaining compatibility with the existing approach for standard model names.
2025-10-28 01:41:54 +01:00
11 changed files with 233 additions and 189 deletions

View File

@@ -415,7 +415,7 @@ openai-compatibility:
# api-keys:
# - "sk-or-v1-...b780"
# - "sk-or-v1-...b781"
models: # The models supported by the provider.
models: # The models supported by the provider. Or you can use a format such as openrouter://moonshotai/kimi-k2:free to request undefined models
- name: "moonshotai/kimi-k2:free" # The actual model name.
alias: "kimi-k2" # The alias used in the API.
```

View File

@@ -428,7 +428,7 @@ openai-compatibility:
# api-keys:
# - "sk-or-v1-...b780"
# - "sk-or-v1-...b781"
models: # 提供商支持的模型。
models: # 提供商支持的模型。或者你可以使用类似 openrouter://moonshotai/kimi-k2:free 这样的格式来请求未在这里定义的模型
- name: "moonshotai/kimi-k2:free" # 实际的模型名称。
alias: "kimi-k2" # 在API中使用的别名。
```

View File

@@ -146,6 +146,10 @@ func (MyExecutor) ExecuteStream(ctx context.Context, a *coreauth.Auth, req clipe
return ch, nil
}
func (MyExecutor) CountTokens(ctx context.Context, a *coreauth.Auth, req clipexec.Request, opts clipexec.Options) (clipexec.Response, error) {
return clipexec.Response{}, errors.New("not implemented")
}
func (MyExecutor) Refresh(ctx context.Context, a *coreauth.Auth) (*coreauth.Auth, error) {
return a, nil
}

View File

@@ -225,9 +225,13 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
// Create server instance
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
for _, p := range cfg.OpenAICompatibility {
providerNames = append(providerNames, p.Name)
}
s := &Server{
engine: engine,
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
cfg: cfg,
accessManager: accessManager,
requestLogger: requestLogger,
@@ -823,6 +827,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
managementasset.SetCurrentConfig(cfg)
// Save YAML snapshot for next comparison
s.oldConfigYaml, _ = yaml.Marshal(cfg)
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
for _, p := range cfg.OpenAICompatibility {
providerNames = append(providerNames, p.Name)
}
s.handlers.OpenAICompatProviders = providerNames
s.handlers.UpdateClients(&cfg.SDKConfig)
if !cfg.RemoteManagement.DisableControlPanel {
@@ -903,5 +914,3 @@ func AuthMiddleware(manager *sdkaccess.Manager) gin.HandlerFunc {
}
}
}
// legacy clientsToSlice removed; handlers no longer consume legacy client slices

View File

@@ -68,84 +68,8 @@ func GetClaudeModels() []*ModelInfo {
}
}
// GetGeminiModels returns the standard Gemini model definitions
func GetGeminiModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "gemini-2.5-flash",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash",
Version: "001",
DisplayName: "Gemini 2.5 Flash",
Description: "Stable version of Gemini 2.5 Flash, our mid-size multimodal model that supports up to 1 million tokens, released in June of 2025.",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
},
{
ID: "gemini-2.5-pro",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-pro",
Version: "2.5",
DisplayName: "Gemini 2.5 Pro",
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Pro",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
},
{
ID: "gemini-2.5-flash-lite",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash-lite",
Version: "2.5",
DisplayName: "Gemini 2.5 Flash Lite",
Description: "Stable release (June 17th, 2025) of Gemini 2.5 Flash Lite",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
},
{
ID: "gemini-2.5-flash-image-preview",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash-image-preview",
Version: "2.5",
DisplayName: "Gemini 2.5 Flash Image Preview",
Description: "State-of-the-art image generation and editing model.",
InputTokenLimit: 1048576,
OutputTokenLimit: 8192,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
},
{
ID: "gemini-2.5-flash-image",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-2.5-flash-image",
Version: "2.5",
DisplayName: "Gemini 2.5 Flash Image",
Description: "State-of-the-art image generation and editing model.",
InputTokenLimit: 1048576,
OutputTokenLimit: 8192,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
},
}
}
// GetGeminiCLIModels returns the standard Gemini model definitions
func GetGeminiCLIModels() []*ModelInfo {
// GeminiModels returns the shared base Gemini model set used by multiple providers.
func GeminiModels() []*ModelInfo {
return []*ModelInfo{
{
ID: "gemini-2.5-flash",
@@ -220,6 +144,63 @@ func GetGeminiCLIModels() []*ModelInfo {
}
}
// GetGeminiModels returns the standard Gemini model definitions
func GetGeminiModels() []*ModelInfo { return GeminiModels() }
// GetGeminiCLIModels returns the standard Gemini model definitions
func GetGeminiCLIModels() []*ModelInfo { return GeminiModels() }
// GetAIStudioModels returns the Gemini model definitions for AI Studio integrations
func GetAIStudioModels() []*ModelInfo {
models := make([]*ModelInfo, 0, 8)
models = append(models, GeminiModels()...)
models = append(models,
&ModelInfo{
ID: "gemini-pro-latest",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-pro-latest",
Version: "2.5",
DisplayName: "Gemini Pro Latest",
Description: "Latest release of Gemini Pro",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
},
&ModelInfo{
ID: "gemini-flash-latest",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-flash-latest",
Version: "2.5",
DisplayName: "Gemini Flash Latest",
Description: "Latest release of Gemini Flash",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
},
&ModelInfo{
ID: "gemini-flash-lite-latest",
Object: "model",
Created: time.Now().Unix(),
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-flash-lite-latest",
Version: "2.5",
DisplayName: "Gemini Flash-Lite Latest",
Description: "Latest release of Gemini Flash-Lite",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
},
)
return models
}
// GetOpenAIModels returns the standard OpenAI model definitions
func GetOpenAIModels() []*ModelInfo {
return []*ModelInfo{
@@ -417,7 +398,6 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language"},
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build"},
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905"},
{ID: "glm-4.5", DisplayName: "GLM-4.5", Description: "Zhipu GLM 4.5 general model"},
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model"},
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model"},
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental"},

View File

@@ -19,27 +19,27 @@ import (
"github.com/tidwall/sjson"
)
// AistudioExecutor routes AI Studio requests through a websocket-backed transport.
type AistudioExecutor struct {
// AIStudioExecutor routes AI Studio requests through a websocket-backed transport.
type AIStudioExecutor struct {
provider string
relay *wsrelay.Manager
cfg *config.Config
}
// NewAistudioExecutor constructs a websocket executor for the provider name.
func NewAistudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AistudioExecutor {
return &AistudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
// NewAIStudioExecutor constructs a websocket executor for the provider name.
func NewAIStudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AIStudioExecutor {
return &AIStudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
}
// Identifier returns the provider key served by this executor.
func (e *AistudioExecutor) Identifier() string { return e.provider }
// Identifier returns the logical provider key for routing.
func (e *AIStudioExecutor) Identifier() string { return "aistudio" }
// PrepareRequest is a no-op because websocket transport already injects headers.
func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
func (e *AIStudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
return nil
}
func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
@@ -66,14 +66,14 @@ func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: bytes.Clone(body.payload),
Provider: e.provider,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
wsResp, err := e.relay.NonStream(ctx, e.provider, wsReq)
wsResp, err := e.relay.NonStream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return resp, err
@@ -92,7 +92,7 @@ func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
return resp, nil
}
func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
@@ -118,13 +118,13 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: bytes.Clone(body.payload),
Provider: e.provider,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
wsStream, err := e.relay.Stream(ctx, e.provider, wsReq)
wsStream, err := e.relay.Stream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return nil, err
@@ -151,7 +151,7 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
case wsrelay.MessageTypeStreamChunk:
if len(event.Payload) > 0 {
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
filtered := filterAistudioUsageMetadata(event.Payload)
filtered := filterAIStudioUsageMetadata(event.Payload)
if detail, ok := parseGeminiStreamUsage(filtered); ok {
reporter.publish(ctx, detail)
}
@@ -188,7 +188,7 @@ func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
return stream, nil
}
func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
_, body, err := e.translateRequest(req, opts, false)
if err != nil {
return cliproxyexecutor.Response{}, err
@@ -215,13 +215,13 @@ func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
Method: http.MethodPost,
Headers: wsReq.Headers.Clone(),
Body: bytes.Clone(body.payload),
Provider: e.provider,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
resp, err := e.relay.NonStream(ctx, e.provider, wsReq)
resp, err := e.relay.NonStream(ctx, authID, wsReq)
if err != nil {
recordAPIResponseError(ctx, e.cfg, err)
return cliproxyexecutor.Response{}, err
@@ -241,7 +241,7 @@ func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
func (e *AIStudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
_ = ctx
return auth, nil
}
@@ -252,7 +252,7 @@ type translatedPayload struct {
toFormat sdktranslator.Format
}
func (e *AistudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
@@ -275,7 +275,7 @@ func (e *AistudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
return payload, translatedPayload{payload: payload, action: action, toFormat: to}, nil
}
func (e *AistudioExecutor) buildEndpoint(model, action, alt string) string {
func (e *AIStudioExecutor) buildEndpoint(model, action, alt string) string {
base := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, model, action)
if action == "streamGenerateContent" {
if alt == "" {
@@ -289,9 +289,9 @@ func (e *AistudioExecutor) buildEndpoint(model, action, alt string) string {
return base
}
// filterAistudioUsageMetadata removes usageMetadata from intermediate SSE events so that
// filterAIStudioUsageMetadata removes usageMetadata from intermediate SSE events so that
// only the terminal chunk retains token statistics.
func filterAistudioUsageMetadata(payload []byte) []byte {
func filterAIStudioUsageMetadata(payload []byte) []byte {
if len(payload) == 0 {
return payload
}

View File

@@ -250,8 +250,34 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if t.Get("type").String() == "function" {
fn := t.Get("function")
if fn.Exists() && fn.IsObject() {
parametersJsonSchema, _ := util.RenameKey(fn.Raw, "parameters", "parametersJsonSchema")
out, _ = sjson.SetRawBytes(out, fdPath+".-1", []byte(parametersJsonSchema))
fnRaw := fn.Raw
if fn.Get("parameters").Exists() {
renamed, errRename := util.RenameKey(fnRaw, "parameters", "parametersJsonSchema")
if errRename != nil {
log.Warnf("Failed to rename parameters for tool '%s': %v", fn.Get("name").String(), errRename)
} else {
fnRaw = renamed
}
} else {
var errSet error
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.type", "object")
if errSet != nil {
log.Warnf("Failed to set default schema type for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
fnRaw, errSet = sjson.Set(fnRaw, "parametersJsonSchema.properties", map[string]interface{}{})
if errSet != nil {
log.Warnf("Failed to set default schema properties for tool '%s': %v", fn.Get("name").String(), errSet)
continue
}
}
tmp, errSet := sjson.SetRawBytes(out, fdPath+".-1", []byte(fnRaw))
if errSet != nil {
log.Warnf("Failed to append tool declaration for '%s': %v", fn.Get("name").String(), errSet)
continue
}
out = tmp
}
}
}

View File

@@ -4,6 +4,7 @@ package chat_completions
import (
"bytes"
"github.com/tidwall/sjson"
)
// ConvertOpenAIRequestToOpenAI converts an OpenAI Chat Completions request (raw JSON)
@@ -17,5 +18,14 @@ import (
// Returns:
// - []byte: The transformed request data in Gemini CLI API format
func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool) []byte {
return bytes.Clone(inputRawJSON)
// Update the "model" field in the JSON payload with the provided modelName
// The sjson.SetBytes function returns a new byte slice with the updated JSON.
updatedJSON, err := sjson.SetBytes(inputRawJSON, "model", modelName)
if err != nil {
// If there's an error, return the original JSON or handle the error appropriately.
// For now, we'll return the original, but in a real scenario, logging or a more robust error
// handling mechanism would be needed.
return bytes.Clone(inputRawJSON)
}
return updatedJSON
}

View File

@@ -6,6 +6,7 @@ package handlers
import (
"fmt"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
@@ -46,6 +47,9 @@ type BaseAPIHandler struct {
// Cfg holds the current application configuration.
Cfg *config.SDKConfig
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
OpenAICompatProviders []string
}
// NewBaseAPIHandlers creates a new API handlers instance.
@@ -57,10 +61,11 @@ type BaseAPIHandler struct {
//
// Returns:
// - *BaseAPIHandler: A new API handlers instance
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
return &BaseAPIHandler{
Cfg: cfg,
AuthManager: authManager,
Cfg: cfg,
AuthManager: authManager,
OpenAICompatProviders: openAICompatProviders,
}
}
@@ -133,10 +138,9 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
normalizedModel, metadata := normalizeModelMetadata(modelName)
providers := util.GetProviderName(normalizedModel)
if len(providers) == 0 {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
if errMsg != nil {
return nil, errMsg
}
req := coreexecutor.Request{
Model: normalizedModel,
@@ -176,10 +180,9 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
normalizedModel, metadata := normalizeModelMetadata(modelName)
providers := util.GetProviderName(normalizedModel)
if len(providers) == 0 {
return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
if errMsg != nil {
return nil, errMsg
}
req := coreexecutor.Request{
Model: normalizedModel,
@@ -219,11 +222,10 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
// This path is the only supported execution route.
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
normalizedModel, metadata := normalizeModelMetadata(modelName)
providers := util.GetProviderName(normalizedModel)
if len(providers) == 0 {
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
if errMsg != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
errChan <- errMsg
close(errChan)
return nil, errChan
}
@@ -292,6 +294,58 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
return dataChan, errChan
}
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
// First, normalize the model name to handle suffixes like "-thinking-128"
// This needs to happen before determining the provider for non-dynamic models.
normalizedModel, metadata = normalizeModelMetadata(modelName)
if isDynamic {
providers = []string{providerName}
// For dynamic models, the extractedModelName is already normalized by parseDynamicModel
// so we use it as the final normalizedModel.
normalizedModel = extractedModelName
} else {
// For non-dynamic models, use the normalizedModel to get the provider name.
providers = util.GetProviderName(normalizedModel)
}
if len(providers) == 0 {
return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
}
// If it's a dynamic model, the normalizedModel was already set to extractedModelName.
// If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata.
// So, normalizedModel is already correctly set at this point.
return providers, normalizedModel, metadata, nil
}
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
var providerPart, modelPart string
for _, sep := range []string{"://"} {
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
providerPart = parts[0]
modelPart = parts[1]
break
}
}
if providerPart == "" {
return "", modelName, false
}
// Check if the provider is a configured openai-compatibility provider
for _, pName := range h.OpenAICompatProviders {
if pName == providerPart {
return providerPart, modelPart, true
}
}
return "", modelName, false
}
func cloneBytes(src []byte) []byte {
if len(src) == 0 {
return nil

View File

@@ -157,15 +157,6 @@ func (a *Auth) AccountInfo() (string, string) {
return "oauth", v
}
}
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") {
if label := strings.TrimSpace(a.Label); label != "" {
return "oauth", label
}
if id := strings.TrimSpace(a.ID); id != "" {
return "oauth", id
}
return "oauth", "aistudio"
}
if a.Attributes != nil {
if v := a.Attributes["api_key"]; v != "" {
return "api_key", v

View File

@@ -194,15 +194,15 @@ func (s *Service) ensureWebsocketGateway() {
s.wsGateway = wsrelay.NewManager(opts)
}
func (s *Service) wsOnConnected(provider string) {
if s == nil || provider == "" {
func (s *Service) wsOnConnected(channelID string) {
if s == nil || channelID == "" {
return
}
if !strings.HasPrefix(strings.ToLower(provider), "aistudio-") {
if !strings.HasPrefix(strings.ToLower(channelID), "aistudio-") {
return
}
if s.coreManager != nil {
if existing, ok := s.coreManager.GetByID(provider); ok && existing != nil {
if existing, ok := s.coreManager.GetByID(channelID); ok && existing != nil {
if !existing.Disabled && existing.Status == coreauth.StatusActive {
return
}
@@ -210,36 +210,33 @@ func (s *Service) wsOnConnected(provider string) {
}
now := time.Now().UTC()
auth := &coreauth.Auth{
ID: provider,
Provider: provider,
Label: provider,
Status: coreauth.StatusActive,
CreatedAt: now,
UpdatedAt: now,
Attributes: map[string]string{"ws_provider": "gemini"},
ID: channelID, // keep channel identifier as ID
Provider: "aistudio", // logical provider for switch routing
Label: channelID, // display original channel id
Status: coreauth.StatusActive,
CreatedAt: now,
UpdatedAt: now,
Metadata: map[string]any{"email": channelID}, // inject email inline
}
log.Infof("websocket provider connected: %s", provider)
log.Infof("websocket provider connected: %s", channelID)
s.applyCoreAuthAddOrUpdate(context.Background(), auth)
}
func (s *Service) wsOnDisconnected(provider string, reason error) {
if s == nil || provider == "" {
func (s *Service) wsOnDisconnected(channelID string, reason error) {
if s == nil || channelID == "" {
return
}
if reason != nil {
if strings.Contains(reason.Error(), "replaced by new connection") {
log.Infof("websocket provider replaced: %s", provider)
log.Infof("websocket provider replaced: %s", channelID)
return
}
log.Warnf("websocket provider disconnected: %s (%v)", provider, reason)
log.Warnf("websocket provider disconnected: %s (%v)", channelID, reason)
} else {
log.Infof("websocket provider disconnected: %s", provider)
log.Infof("websocket provider disconnected: %s", channelID)
}
ctx := context.Background()
s.applyCoreAuthRemoval(ctx, provider)
if s.coreManager != nil {
s.coreManager.UnregisterExecutor(provider)
}
s.applyCoreAuthRemoval(ctx, channelID)
}
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
@@ -317,17 +314,16 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(compatProviderKey, s.cfg))
return
}
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") {
if s.wsGateway != nil {
s.coreManager.RegisterExecutor(executor.NewAistudioExecutor(s.cfg, a.Provider, s.wsGateway))
}
return
}
switch strings.ToLower(a.Provider) {
case "gemini":
s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
case "gemini-cli":
s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg))
case "aistudio":
if s.wsGateway != nil {
s.coreManager.RegisterExecutor(executor.NewAIStudioExecutor(s.cfg, a.ID, s.wsGateway))
}
return
case "claude":
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
case "codex":
@@ -609,13 +605,6 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
provider := strings.ToLower(strings.TrimSpace(a.Provider))
compatProviderKey, compatDisplayName, compatDetected := openAICompatInfoFromAuth(a)
if a.Attributes != nil {
if strings.EqualFold(a.Attributes["ws_provider"], "gemini") {
models := mergeGeminiModels()
GlobalModelRegistry().RegisterClient(a.ID, provider, models)
return
}
}
if compatDetected {
provider = "openai-compatibility"
}
@@ -625,6 +614,8 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
models = registry.GetGeminiModels()
case "gemini-cli":
models = registry.GetGeminiCLIModels()
case "aistudio":
models = registry.GetAIStudioModels()
case "claude":
models = registry.GetClaudeModels()
if entry := s.resolveConfigClaudeKey(a); entry != nil && len(entry.Models) > 0 {
@@ -726,27 +717,6 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
}
func mergeGeminiModels() []*ModelInfo {
models := make([]*ModelInfo, 0, 16)
seen := make(map[string]struct{})
appendModels := func(items []*ModelInfo) {
for i := range items {
m := items[i]
if m == nil || m.ID == "" {
continue
}
if _, ok := seen[m.ID]; ok {
continue
}
seen[m.ID] = struct{}{}
models = append(models, m)
}
}
appendModels(registry.GetGeminiModels())
appendModels(registry.GetGeminiCLIModels())
return models
}
func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey {
if auth == nil || s.cfg == nil {
return nil