From e41d1277326c5e5b02e6f0f3cc540cd673664a12 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 22 Sep 2025 22:54:21 +0800 Subject: [PATCH] feat(openai-compat): enhance provider key handling and model resolution - Introduced dynamic `providerKey` resolution for OpenAI-compatible providers, incorporating attributes like `provider_key` and `compat_name`. - Implemented upstream model overrides via `resolveUpstreamModel` and `overrideModel` methods in the OpenAI executor. - Updated registry logic to correctly store provider mappings and register clients using normalized keys. - Ensured consistency in handling empty or default provider names across components. --- .../executor/openai_compat_executor.go | 68 +++++++++++++++++++ internal/watcher/watcher.go | 15 ++-- sdk/cliproxy/service.go | 37 +++++++--- 3 files changed, 106 insertions(+), 14 deletions(-) diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index b1abde84..0c23c4d1 100644 --- a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -14,6 +14,7 @@ import ( cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" log "github.com/sirupsen/logrus" + "github.com/tidwall/sjson" ) // OpenAICompatExecutor implements a stateless executor for OpenAI-compatible providers. @@ -47,6 +48,9 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A from := opts.SourceFormat to := sdktranslator.FromString("openai") translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream) + if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { + translated = e.overrideModel(translated, modelOverride) + } url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" recordAPIRequest(ctx, e.cfg, translated) @@ -91,6 +95,9 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy from := opts.SourceFormat to := sdktranslator.FromString("openai") translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" { + translated = e.overrideModel(translated, modelOverride) + } url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" recordAPIRequest(ctx, e.cfg, translated) @@ -164,6 +171,67 @@ func (e *OpenAICompatExecutor) resolveCredentials(auth *cliproxyauth.Auth) (base return } +func (e *OpenAICompatExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string { + if alias == "" || auth == nil || e.cfg == nil { + return "" + } + compat := e.resolveCompatConfig(auth) + if compat == nil { + return "" + } + for i := range compat.Models { + model := compat.Models[i] + if model.Alias != "" { + if strings.EqualFold(model.Alias, alias) { + if model.Name != "" { + return model.Name + } + return alias + } + continue + } + if strings.EqualFold(model.Name, alias) { + return model.Name + } + } + return "" +} + +func (e *OpenAICompatExecutor) resolveCompatConfig(auth *cliproxyauth.Auth) *config.OpenAICompatibility { + if auth == nil || e.cfg == nil { + return nil + } + candidates := make([]string, 0, 3) + if auth.Attributes != nil { + if v := strings.TrimSpace(auth.Attributes["compat_name"]); v != "" { + candidates = append(candidates, v) + } + if v := strings.TrimSpace(auth.Attributes["provider_key"]); v != "" { + candidates = append(candidates, v) + } + } + if v := strings.TrimSpace(auth.Provider); v != "" { + candidates = append(candidates, v) + } + for i := range e.cfg.OpenAICompatibility { + compat := &e.cfg.OpenAICompatibility[i] + for _, candidate := range candidates { + if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) { + return compat + } + } + } + return nil +} + +func (e *OpenAICompatExecutor) overrideModel(payload []byte, model string) []byte { + if len(payload) == 0 || model == "" { + return payload + } + payload, _ = sjson.SetBytes(payload, "model", model) + return payload +} + type statusErr struct { code int msg string diff --git a/internal/watcher/watcher.go b/internal/watcher/watcher.go index 458b8cf5..6f00b27b 100644 --- a/internal/watcher/watcher.go +++ b/internal/watcher/watcher.go @@ -474,19 +474,24 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { } for i := range cfg.OpenAICompatibility { compat := &cfg.OpenAICompatibility[i] + providerName := strings.ToLower(strings.TrimSpace(compat.Name)) + if providerName == "" { + providerName = "openai-compatibility" + } base := compat.BaseURL for j := range compat.APIKeys { key := compat.APIKeys[j] a := &coreauth.Auth{ ID: fmt.Sprintf("openai-compatibility:%s:%d", compat.Name, j), - Provider: "openai-compatibility", + Provider: providerName, Label: compat.Name, Status: coreauth.StatusActive, Attributes: map[string]string{ - "source": fmt.Sprintf("config:%s#%d", compat.Name, j), - "base_url": base, - "api_key": key, - "compat_name": compat.Name, + "source": fmt.Sprintf("config:%s#%d", compat.Name, j), + "base_url": base, + "api_key": key, + "compat_name": compat.Name, + "provider_key": providerName, }, CreatedAt: now, UpdatedAt: now, diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index cd619041..80990122 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -295,7 +295,11 @@ func (s *Service) syncCoreAuthFromAuths(ctx context.Context, auths []*coreauth.A case "qwen": s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg)) default: - s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor("openai-compatibility", s.cfg)) + providerKey := strings.ToLower(strings.TrimSpace(a.Provider)) + if providerKey == "" { + providerKey = "openai-compatibility" + } + s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(providerKey, s.cfg)) } // Preserve existing temporal fields @@ -341,7 +345,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { } } } - provider := strings.ToLower(a.Provider) + provider := strings.ToLower(strings.TrimSpace(a.Provider)) var models []*ModelInfo switch provider { case "gemini": @@ -359,11 +363,19 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { default: // Handle OpenAI-compatibility providers by name using config if s.cfg != nil { - // When provider is normalized to "openai-compatibility", read the original name from attributes. - compatName := a.Provider - if strings.EqualFold(compatName, "openai-compatibility") { - if a.Attributes != nil && a.Attributes["compat_name"] != "" { - compatName = a.Attributes["compat_name"] + providerKey := provider + compatName := strings.TrimSpace(a.Provider) + if strings.EqualFold(providerKey, "openai-compatibility") { + if a.Attributes != nil { + if v := strings.TrimSpace(a.Attributes["compat_name"]); v != "" { + compatName = v + } + if v := strings.TrimSpace(a.Attributes["provider_key"]); v != "" { + providerKey = strings.ToLower(v) + } + } + if providerKey == "openai-compatibility" && compatName != "" { + providerKey = strings.ToLower(compatName) } } for i := range s.cfg.OpenAICompatibility { @@ -384,7 +396,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { } // Register and return if len(ms) > 0 { - GlobalModelRegistry().RegisterClient(a.ID, a.Provider, ms) + if providerKey == "" { + providerKey = "openai-compatibility" + } + GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms) } return } @@ -392,6 +407,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) { } } if len(models) > 0 { - GlobalModelRegistry().RegisterClient(a.ID, a.Provider, models) + key := provider + if key == "" { + key = strings.ToLower(strings.TrimSpace(a.Provider)) + } + GlobalModelRegistry().RegisterClient(a.ID, key, models) } }