mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat(openai-compat): enhance provider key handling and model resolution
- Introduced dynamic `providerKey` resolution for OpenAI-compatible providers, incorporating attributes like `provider_key` and `compat_name`. - Implemented upstream model overrides via `resolveUpstreamModel` and `overrideModel` methods in the OpenAI executor. - Updated registry logic to correctly store provider mappings and register clients using normalized keys. - Ensured consistency in handling empty or default provider names across components.
This commit is contained in:
@@ -14,6 +14,7 @@ import (
|
|||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
)
|
)
|
||||||
|
|
||||||
// OpenAICompatExecutor implements a stateless executor for OpenAI-compatible providers.
|
// OpenAICompatExecutor implements a stateless executor for OpenAI-compatible providers.
|
||||||
@@ -47,6 +48,9 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
|||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
|
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), opts.Stream)
|
||||||
|
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||||
|
translated = e.overrideModel(translated, modelOverride)
|
||||||
|
}
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||||
recordAPIRequest(ctx, e.cfg, translated)
|
recordAPIRequest(ctx, e.cfg, translated)
|
||||||
@@ -91,6 +95,9 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
|||||||
from := opts.SourceFormat
|
from := opts.SourceFormat
|
||||||
to := sdktranslator.FromString("openai")
|
to := sdktranslator.FromString("openai")
|
||||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||||
|
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||||
|
translated = e.overrideModel(translated, modelOverride)
|
||||||
|
}
|
||||||
|
|
||||||
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
url := strings.TrimSuffix(baseURL, "/") + "/chat/completions"
|
||||||
recordAPIRequest(ctx, e.cfg, translated)
|
recordAPIRequest(ctx, e.cfg, translated)
|
||||||
@@ -164,6 +171,67 @@ func (e *OpenAICompatExecutor) resolveCredentials(auth *cliproxyauth.Auth) (base
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *OpenAICompatExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||||
|
if alias == "" || auth == nil || e.cfg == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
compat := e.resolveCompatConfig(auth)
|
||||||
|
if compat == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
for i := range compat.Models {
|
||||||
|
model := compat.Models[i]
|
||||||
|
if model.Alias != "" {
|
||||||
|
if strings.EqualFold(model.Alias, alias) {
|
||||||
|
if model.Name != "" {
|
||||||
|
return model.Name
|
||||||
|
}
|
||||||
|
return alias
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.EqualFold(model.Name, alias) {
|
||||||
|
return model.Name
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *OpenAICompatExecutor) resolveCompatConfig(auth *cliproxyauth.Auth) *config.OpenAICompatibility {
|
||||||
|
if auth == nil || e.cfg == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
candidates := make([]string, 0, 3)
|
||||||
|
if auth.Attributes != nil {
|
||||||
|
if v := strings.TrimSpace(auth.Attributes["compat_name"]); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(auth.Attributes["provider_key"]); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(auth.Provider); v != "" {
|
||||||
|
candidates = append(candidates, v)
|
||||||
|
}
|
||||||
|
for i := range e.cfg.OpenAICompatibility {
|
||||||
|
compat := &e.cfg.OpenAICompatibility[i]
|
||||||
|
for _, candidate := range candidates {
|
||||||
|
if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
|
||||||
|
return compat
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *OpenAICompatExecutor) overrideModel(payload []byte, model string) []byte {
|
||||||
|
if len(payload) == 0 || model == "" {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
payload, _ = sjson.SetBytes(payload, "model", model)
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
type statusErr struct {
|
type statusErr struct {
|
||||||
code int
|
code int
|
||||||
msg string
|
msg string
|
||||||
|
|||||||
@@ -474,19 +474,24 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
}
|
}
|
||||||
for i := range cfg.OpenAICompatibility {
|
for i := range cfg.OpenAICompatibility {
|
||||||
compat := &cfg.OpenAICompatibility[i]
|
compat := &cfg.OpenAICompatibility[i]
|
||||||
|
providerName := strings.ToLower(strings.TrimSpace(compat.Name))
|
||||||
|
if providerName == "" {
|
||||||
|
providerName = "openai-compatibility"
|
||||||
|
}
|
||||||
base := compat.BaseURL
|
base := compat.BaseURL
|
||||||
for j := range compat.APIKeys {
|
for j := range compat.APIKeys {
|
||||||
key := compat.APIKeys[j]
|
key := compat.APIKeys[j]
|
||||||
a := &coreauth.Auth{
|
a := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("openai-compatibility:%s:%d", compat.Name, j),
|
ID: fmt.Sprintf("openai-compatibility:%s:%d", compat.Name, j),
|
||||||
Provider: "openai-compatibility",
|
Provider: providerName,
|
||||||
Label: compat.Name,
|
Label: compat.Name,
|
||||||
Status: coreauth.StatusActive,
|
Status: coreauth.StatusActive,
|
||||||
Attributes: map[string]string{
|
Attributes: map[string]string{
|
||||||
"source": fmt.Sprintf("config:%s#%d", compat.Name, j),
|
"source": fmt.Sprintf("config:%s#%d", compat.Name, j),
|
||||||
"base_url": base,
|
"base_url": base,
|
||||||
"api_key": key,
|
"api_key": key,
|
||||||
"compat_name": compat.Name,
|
"compat_name": compat.Name,
|
||||||
|
"provider_key": providerName,
|
||||||
},
|
},
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
UpdatedAt: now,
|
UpdatedAt: now,
|
||||||
|
|||||||
@@ -295,7 +295,11 @@ func (s *Service) syncCoreAuthFromAuths(ctx context.Context, auths []*coreauth.A
|
|||||||
case "qwen":
|
case "qwen":
|
||||||
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
||||||
default:
|
default:
|
||||||
s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor("openai-compatibility", s.cfg))
|
providerKey := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||||
|
if providerKey == "" {
|
||||||
|
providerKey = "openai-compatibility"
|
||||||
|
}
|
||||||
|
s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(providerKey, s.cfg))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Preserve existing temporal fields
|
// Preserve existing temporal fields
|
||||||
@@ -341,7 +345,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
provider := strings.ToLower(a.Provider)
|
provider := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||||
var models []*ModelInfo
|
var models []*ModelInfo
|
||||||
switch provider {
|
switch provider {
|
||||||
case "gemini":
|
case "gemini":
|
||||||
@@ -359,11 +363,19 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
default:
|
default:
|
||||||
// Handle OpenAI-compatibility providers by name using config
|
// Handle OpenAI-compatibility providers by name using config
|
||||||
if s.cfg != nil {
|
if s.cfg != nil {
|
||||||
// When provider is normalized to "openai-compatibility", read the original name from attributes.
|
providerKey := provider
|
||||||
compatName := a.Provider
|
compatName := strings.TrimSpace(a.Provider)
|
||||||
if strings.EqualFold(compatName, "openai-compatibility") {
|
if strings.EqualFold(providerKey, "openai-compatibility") {
|
||||||
if a.Attributes != nil && a.Attributes["compat_name"] != "" {
|
if a.Attributes != nil {
|
||||||
compatName = a.Attributes["compat_name"]
|
if v := strings.TrimSpace(a.Attributes["compat_name"]); v != "" {
|
||||||
|
compatName = v
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(a.Attributes["provider_key"]); v != "" {
|
||||||
|
providerKey = strings.ToLower(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if providerKey == "openai-compatibility" && compatName != "" {
|
||||||
|
providerKey = strings.ToLower(compatName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
for i := range s.cfg.OpenAICompatibility {
|
for i := range s.cfg.OpenAICompatibility {
|
||||||
@@ -384,7 +396,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
}
|
}
|
||||||
// Register and return
|
// Register and return
|
||||||
if len(ms) > 0 {
|
if len(ms) > 0 {
|
||||||
GlobalModelRegistry().RegisterClient(a.ID, a.Provider, ms)
|
if providerKey == "" {
|
||||||
|
providerKey = "openai-compatibility"
|
||||||
|
}
|
||||||
|
GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -392,6 +407,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(models) > 0 {
|
if len(models) > 0 {
|
||||||
GlobalModelRegistry().RegisterClient(a.ID, a.Provider, models)
|
key := provider
|
||||||
|
if key == "" {
|
||||||
|
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
||||||
|
}
|
||||||
|
GlobalModelRegistry().RegisterClient(a.ID, key, models)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user