feat(openai-compat): enhance provider key handling and model resolution

- Introduced dynamic `providerKey` resolution for OpenAI-compatible providers, incorporating attributes like `provider_key` and `compat_name`.
- Implemented upstream model overrides via `resolveUpstreamModel` and `overrideModel` methods in the OpenAI executor.
- Updated registry logic to correctly store provider mappings and register clients using normalized keys.
- Ensured consistency in handling empty or default provider names across components.
This commit is contained in:
Luis Pater
2025-09-22 22:54:21 +08:00
parent f1c4caf14a
commit e41d127732
3 changed files with 106 additions and 14 deletions

View File

@@ -14,6 +14,7 @@ import (
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/sjson"
) )
// OpenAICompatExecutor implements a stateless executor for OpenAI-compatible providers. // OpenAICompatExecutor implements a stateless executor for OpenAI-compatible providers.
@@ -47,6 +48,9 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
}
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
recordAPIRequest(ctx, e.cfg, translated) recordAPIRequest(ctx, e.cfg, translated)
@@ -91,6 +95,9 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
from := opts.SourceFormat from := opts.SourceFormat
to := sdktranslator.FromString("openai") to := sdktranslator.FromString("openai")
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
translated = e.overrideModel(translated, modelOverride)
}
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions" url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
recordAPIRequest(ctx, e.cfg, translated) recordAPIRequest(ctx, e.cfg, translated)
@@ -164,6 +171,67 @@ func (e *OpenAICompatExecutor) resolveCredentials(auth *cliproxyauth.Auth) (base
return return
} }
func (e *OpenAICompatExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
if alias == "" || auth == nil || e.cfg == nil {
return ""
}
compat := e.resolveCompatConfig(auth)
if compat == nil {
return ""
}
for i := range compat.Models {
model := compat.Models[i]
if model.Alias != "" {
if strings.EqualFold(model.Alias, alias) {
if model.Name != "" {
return model.Name
}
return alias
}
continue
}
if strings.EqualFold(model.Name, alias) {
return model.Name
}
}
return ""
}
func (e *OpenAICompatExecutor) resolveCompatConfig(auth *cliproxyauth.Auth) *config.OpenAICompatibility {
if auth == nil || e.cfg == nil {
return nil
}
candidates := make([]string, 0, 3)
if auth.Attributes != nil {
if v := strings.TrimSpace(auth.Attributes["compat_name"]); v != "" {
candidates = append(candidates, v)
}
if v := strings.TrimSpace(auth.Attributes["provider_key"]); v != "" {
candidates = append(candidates, v)
}
}
if v := strings.TrimSpace(auth.Provider); v != "" {
candidates = append(candidates, v)
}
for i := range e.cfg.OpenAICompatibility {
compat := &e.cfg.OpenAICompatibility[i]
for _, candidate := range candidates {
if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
return compat
}
}
}
return nil
}
func (e *OpenAICompatExecutor) overrideModel(payload []byte, model string) []byte {
if len(payload) == 0 || model == "" {
return payload
}
payload, _ = sjson.SetBytes(payload, "model", model)
return payload
}
type statusErr struct { type statusErr struct {
code int code int
msg string msg string

View File

@@ -474,19 +474,24 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
} }
for i := range cfg.OpenAICompatibility { for i := range cfg.OpenAICompatibility {
compat := &cfg.OpenAICompatibility[i] compat := &cfg.OpenAICompatibility[i]
providerName := strings.ToLower(strings.TrimSpace(compat.Name))
if providerName == "" {
providerName = "openai-compatibility"
}
base := compat.BaseURL base := compat.BaseURL
for j := range compat.APIKeys { for j := range compat.APIKeys {
key := compat.APIKeys[j] key := compat.APIKeys[j]
a := &coreauth.Auth{ a := &coreauth.Auth{
ID: fmt.Sprintf("openai-compatibility:%s:%d", compat.Name, j), ID: fmt.Sprintf("openai-compatibility:%s:%d", compat.Name, j),
Provider: "openai-compatibility", Provider: providerName,
Label: compat.Name, Label: compat.Name,
Status: coreauth.StatusActive, Status: coreauth.StatusActive,
Attributes: map[string]string{ Attributes: map[string]string{
"source": fmt.Sprintf("config:%s#%d", compat.Name, j), "source": fmt.Sprintf("config:%s#%d", compat.Name, j),
"base_url": base, "base_url": base,
"api_key": key, "api_key": key,
"compat_name": compat.Name, "compat_name": compat.Name,
"provider_key": providerName,
}, },
CreatedAt: now, CreatedAt: now,
UpdatedAt: now, UpdatedAt: now,

View File

@@ -295,7 +295,11 @@ func (s *Service) syncCoreAuthFromAuths(ctx context.Context, auths []*coreauth.A
case "qwen": case "qwen":
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg)) s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
default: default:
s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor("openai-compatibility", s.cfg)) providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
if providerKey == "" {
providerKey = "openai-compatibility"
}
s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(providerKey, s.cfg))
} }
// Preserve existing temporal fields // Preserve existing temporal fields
@@ -341,7 +345,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
} }
} }
} }
provider := strings.ToLower(a.Provider) provider := strings.ToLower(strings.TrimSpace(a.Provider))
var models []*ModelInfo var models []*ModelInfo
switch provider { switch provider {
case "gemini": case "gemini":
@@ -359,11 +363,19 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
default: default:
// Handle OpenAI-compatibility providers by name using config // Handle OpenAI-compatibility providers by name using config
if s.cfg != nil { if s.cfg != nil {
// When provider is normalized to "openai-compatibility", read the original name from attributes. providerKey := provider
compatName := a.Provider compatName := strings.TrimSpace(a.Provider)
if strings.EqualFold(compatName, "openai-compatibility") { if strings.EqualFold(providerKey, "openai-compatibility") {
if a.Attributes != nil && a.Attributes["compat_name"] != "" { if a.Attributes != nil {
compatName = a.Attributes["compat_name"] if v := strings.TrimSpace(a.Attributes["compat_name"]); v != "" {
compatName = v
}
if v := strings.TrimSpace(a.Attributes["provider_key"]); v != "" {
providerKey = strings.ToLower(v)
}
}
if providerKey == "openai-compatibility" && compatName != "" {
providerKey = strings.ToLower(compatName)
} }
} }
for i := range s.cfg.OpenAICompatibility { for i := range s.cfg.OpenAICompatibility {
@@ -384,7 +396,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
} }
// Register and return // Register and return
if len(ms) > 0 { if len(ms) > 0 {
GlobalModelRegistry().RegisterClient(a.ID, a.Provider, ms) if providerKey == "" {
providerKey = "openai-compatibility"
}
GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms)
} }
return return
} }
@@ -392,6 +407,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
} }
} }
if len(models) > 0 { if len(models) > 0 {
GlobalModelRegistry().RegisterClient(a.ID, a.Provider, models) key := provider
if key == "" {
key = strings.ToLower(strings.TrimSpace(a.Provider))
}
GlobalModelRegistry().RegisterClient(a.ID, key, models)
} }
} }