refactor: improve thinking logic

This commit is contained in:
hkfires
2026-01-14 08:32:02 +08:00
parent 5a7e5bd870
commit 0b06d637e7
76 changed files with 8712 additions and 1815 deletions

View File

@@ -0,0 +1,201 @@
package auth
import (
"context"
"testing"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func TestLookupAPIKeyUpstreamModel(t *testing.T) {
cfg := &internalconfig.Config{
GeminiKey: []internalconfig.GeminiKey{
{
APIKey: "k",
BaseURL: "https://example.com",
Models: []internalconfig.GeminiModel{
{Name: "gemini-2.5-pro-exp-03-25", Alias: "g25p"},
{Name: "gemini-2.5-flash(low)", Alias: "g25f"},
},
},
},
}
mgr := NewManager(nil, nil, nil)
mgr.SetConfig(cfg)
ctx := context.Background()
_, _ = mgr.Register(ctx, &Auth{ID: "a1", Provider: "gemini", Attributes: map[string]string{"api_key": "k", "base_url": "https://example.com"}})
tests := []struct {
name string
authID string
input string
want string
}{
// Fast path + suffix preservation
{"alias with suffix", "a1", "g25p(8192)", "gemini-2.5-pro-exp-03-25(8192)"},
{"alias without suffix", "a1", "g25p", "gemini-2.5-pro-exp-03-25"},
// Config suffix takes priority
{"config suffix priority", "a1", "g25f(high)", "gemini-2.5-flash(low)"},
{"config suffix no user suffix", "a1", "g25f", "gemini-2.5-flash(low)"},
// Case insensitive
{"uppercase alias", "a1", "G25P", "gemini-2.5-pro-exp-03-25"},
{"mixed case with suffix", "a1", "G25p(4096)", "gemini-2.5-pro-exp-03-25(4096)"},
// Direct name lookup
{"upstream name direct", "a1", "gemini-2.5-pro-exp-03-25", "gemini-2.5-pro-exp-03-25"},
{"upstream name with suffix", "a1", "gemini-2.5-pro-exp-03-25(8192)", "gemini-2.5-pro-exp-03-25(8192)"},
// Cache miss scenarios
{"non-existent auth", "non-existent", "g25p", ""},
{"unknown alias", "a1", "unknown-alias", ""},
{"empty auth ID", "", "g25p", ""},
{"empty model", "a1", "", ""},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
resolved := mgr.lookupAPIKeyUpstreamModel(tt.authID, tt.input)
if resolved != tt.want {
t.Errorf("lookupAPIKeyUpstreamModel(%q, %q) = %q, want %q", tt.authID, tt.input, resolved, tt.want)
}
})
}
}
func TestAPIKeyModelMappings_ConfigHotReload(t *testing.T) {
cfg := &internalconfig.Config{
GeminiKey: []internalconfig.GeminiKey{
{
APIKey: "k",
Models: []internalconfig.GeminiModel{{Name: "gemini-2.5-pro-exp-03-25", Alias: "g25p"}},
},
},
}
mgr := NewManager(nil, nil, nil)
mgr.SetConfig(cfg)
ctx := context.Background()
_, _ = mgr.Register(ctx, &Auth{ID: "a1", Provider: "gemini", Attributes: map[string]string{"api_key": "k"}})
// Initial mapping
if resolved := mgr.lookupAPIKeyUpstreamModel("a1", "g25p"); resolved != "gemini-2.5-pro-exp-03-25" {
t.Fatalf("before reload: got %q, want %q", resolved, "gemini-2.5-pro-exp-03-25")
}
// Hot reload with new mapping
mgr.SetConfig(&internalconfig.Config{
GeminiKey: []internalconfig.GeminiKey{
{
APIKey: "k",
Models: []internalconfig.GeminiModel{{Name: "gemini-2.5-flash", Alias: "g25p"}},
},
},
})
// New mapping should take effect
if resolved := mgr.lookupAPIKeyUpstreamModel("a1", "g25p"); resolved != "gemini-2.5-flash" {
t.Fatalf("after reload: got %q, want %q", resolved, "gemini-2.5-flash")
}
}
func TestAPIKeyModelMappings_MultipleProviders(t *testing.T) {
cfg := &internalconfig.Config{
GeminiKey: []internalconfig.GeminiKey{{APIKey: "gemini-key", Models: []internalconfig.GeminiModel{{Name: "gemini-2.5-pro", Alias: "gp"}}}},
ClaudeKey: []internalconfig.ClaudeKey{{APIKey: "claude-key", Models: []internalconfig.ClaudeModel{{Name: "claude-sonnet-4", Alias: "cs4"}}}},
CodexKey: []internalconfig.CodexKey{{APIKey: "codex-key", Models: []internalconfig.CodexModel{{Name: "o3", Alias: "o"}}}},
}
mgr := NewManager(nil, nil, nil)
mgr.SetConfig(cfg)
ctx := context.Background()
_, _ = mgr.Register(ctx, &Auth{ID: "gemini-auth", Provider: "gemini", Attributes: map[string]string{"api_key": "gemini-key"}})
_, _ = mgr.Register(ctx, &Auth{ID: "claude-auth", Provider: "claude", Attributes: map[string]string{"api_key": "claude-key"}})
_, _ = mgr.Register(ctx, &Auth{ID: "codex-auth", Provider: "codex", Attributes: map[string]string{"api_key": "codex-key"}})
tests := []struct {
authID, input, want string
}{
{"gemini-auth", "gp", "gemini-2.5-pro"},
{"claude-auth", "cs4", "claude-sonnet-4"},
{"codex-auth", "o", "o3"},
}
for _, tt := range tests {
if resolved := mgr.lookupAPIKeyUpstreamModel(tt.authID, tt.input); resolved != tt.want {
t.Errorf("lookupAPIKeyUpstreamModel(%q, %q) = %q, want %q", tt.authID, tt.input, resolved, tt.want)
}
}
}
func TestApplyAPIKeyModelMapping(t *testing.T) {
cfg := &internalconfig.Config{
GeminiKey: []internalconfig.GeminiKey{
{APIKey: "k", Models: []internalconfig.GeminiModel{{Name: "gemini-2.5-pro-exp-03-25", Alias: "g25p"}}},
},
}
mgr := NewManager(nil, nil, nil)
mgr.SetConfig(cfg)
ctx := context.Background()
apiKeyAuth := &Auth{ID: "a1", Provider: "gemini", Attributes: map[string]string{"api_key": "k"}}
oauthAuth := &Auth{ID: "oauth-auth", Provider: "gemini", Attributes: map[string]string{"auth_kind": "oauth"}}
_, _ = mgr.Register(ctx, apiKeyAuth)
tests := []struct {
name string
auth *Auth
inputModel string
wantModel string
wantOriginal string
expectMapping bool
}{
{
name: "api_key auth with alias",
auth: apiKeyAuth,
inputModel: "g25p(8192)",
wantModel: "gemini-2.5-pro-exp-03-25(8192)",
wantOriginal: "g25p(8192)",
expectMapping: true,
},
{
name: "oauth auth passthrough",
auth: oauthAuth,
inputModel: "some-model",
wantModel: "some-model",
expectMapping: false,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
metadata := map[string]any{"existing": "value"}
resolvedModel, resultMeta := mgr.applyAPIKeyModelMapping(tt.auth, tt.inputModel, metadata)
if resolvedModel != tt.wantModel {
t.Errorf("model = %q, want %q", resolvedModel, tt.wantModel)
}
if resultMeta["existing"] != "value" {
t.Error("existing metadata not preserved")
}
original, hasOriginal := resultMeta["model_mapping_original_model"].(string)
if tt.expectMapping {
if !hasOriginal || original != tt.wantOriginal {
t.Errorf("original model = %q, want %q", original, tt.wantOriginal)
}
} else {
if hasOriginal {
t.Error("should not set model_mapping_original_model for non-api_key auth")
}
}
})
}
}

View File

@@ -15,8 +15,10 @@ import (
"time"
"github.com/google/uuid"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
@@ -120,6 +122,14 @@ type Manager struct {
// modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel.
modelNameMappings atomic.Value
// runtimeConfig stores the latest application config for request-time decisions.
// It is initialized in NewManager; never Load() before first Store().
runtimeConfig atomic.Value
// apiKeyModelMappings caches resolved model alias mappings for API-key auths.
// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
apiKeyModelMappings atomic.Value
// Optional HTTP RoundTripper provider injected by host.
rtProvider RoundTripperProvider
@@ -135,7 +145,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
if hook == nil {
hook = NoopHook{}
}
return &Manager{
manager := &Manager{
store: store,
executors: make(map[string]ProviderExecutor),
selector: selector,
@@ -143,6 +153,10 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
auths: make(map[string]*Auth),
providerOffsets: make(map[string]int),
}
// atomic.Value requires non-nil initial value.
manager.runtimeConfig.Store(&internalconfig.Config{})
manager.apiKeyModelMappings.Store(apiKeyModelMappingTable(nil))
return manager
}
func (m *Manager) SetSelector(selector Selector) {
@@ -171,6 +185,181 @@ func (m *Manager) SetRoundTripperProvider(p RoundTripperProvider) {
m.mu.Unlock()
}
// SetConfig updates the runtime config snapshot used by request-time helpers.
// Callers should provide the latest config on reload so per-credential alias mapping stays in sync.
func (m *Manager) SetConfig(cfg *internalconfig.Config) {
if m == nil {
return
}
if cfg == nil {
cfg = &internalconfig.Config{}
}
m.runtimeConfig.Store(cfg)
m.rebuildAPIKeyModelMappingsFromRuntimeConfig()
}
func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) string {
if m == nil {
return ""
}
authID = strings.TrimSpace(authID)
if authID == "" {
return ""
}
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return ""
}
table, _ := m.apiKeyModelMappings.Load().(apiKeyModelMappingTable)
if table == nil {
return ""
}
byAlias := table[authID]
if len(byAlias) == 0 {
return ""
}
key := strings.ToLower(thinking.ParseSuffix(requestedModel).ModelName)
if key == "" {
key = strings.ToLower(requestedModel)
}
resolved := strings.TrimSpace(byAlias[key])
if resolved == "" {
return ""
}
// Preserve thinking suffix from the client's requested model unless config already has one.
requestResult := thinking.ParseSuffix(requestedModel)
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
return resolved
}
func (m *Manager) rebuildAPIKeyModelMappingsFromRuntimeConfig() {
if m == nil {
return
}
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
if cfg == nil {
cfg = &internalconfig.Config{}
}
m.mu.Lock()
defer m.mu.Unlock()
m.rebuildAPIKeyModelMappingsLocked(cfg)
}
func (m *Manager) rebuildAPIKeyModelMappingsLocked(cfg *internalconfig.Config) {
if m == nil {
return
}
if cfg == nil {
cfg = &internalconfig.Config{}
}
out := make(apiKeyModelMappingTable)
for _, auth := range m.auths {
if auth == nil {
continue
}
if strings.TrimSpace(auth.ID) == "" {
continue
}
kind, _ := auth.AccountInfo()
if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
continue
}
byAlias := make(map[string]string)
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
switch provider {
case "gemini":
if entry := resolveGeminiAPIKeyConfig(cfg, auth); entry != nil {
compileAPIKeyModelMappingsForModels(byAlias, entry.Models)
}
case "claude":
if entry := resolveClaudeAPIKeyConfig(cfg, auth); entry != nil {
compileAPIKeyModelMappingsForModels(byAlias, entry.Models)
}
case "codex":
if entry := resolveCodexAPIKeyConfig(cfg, auth); entry != nil {
compileAPIKeyModelMappingsForModels(byAlias, entry.Models)
}
case "vertex":
if entry := resolveVertexAPIKeyConfig(cfg, auth); entry != nil {
compileAPIKeyModelMappingsForModels(byAlias, entry.Models)
}
default:
// OpenAI-compat uses config selection from auth.Attributes.
providerKey := ""
compatName := ""
if auth.Attributes != nil {
providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
compatName = strings.TrimSpace(auth.Attributes["compat_name"])
}
if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
if entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider); entry != nil {
compileAPIKeyModelMappingsForModels(byAlias, entry.Models)
}
}
}
if len(byAlias) > 0 {
out[auth.ID] = byAlias
}
}
m.apiKeyModelMappings.Store(out)
}
func compileAPIKeyModelMappingsForModels[T interface {
GetName() string
GetAlias() string
}](out map[string]string, models []T) {
if out == nil {
return
}
for i := range models {
alias := strings.TrimSpace(models[i].GetAlias())
name := strings.TrimSpace(models[i].GetName())
if alias == "" || name == "" {
continue
}
aliasKey := strings.ToLower(thinking.ParseSuffix(alias).ModelName)
if aliasKey == "" {
aliasKey = strings.ToLower(alias)
}
// Config priority: first alias wins.
if _, exists := out[aliasKey]; exists {
continue
}
out[aliasKey] = name
// Also allow direct lookup by upstream name (case-insensitive), so lookups on already-upstream
// models remain a cheap no-op.
nameKey := strings.ToLower(thinking.ParseSuffix(name).ModelName)
if nameKey == "" {
nameKey = strings.ToLower(name)
}
if nameKey != "" {
if _, exists := out[nameKey]; !exists {
out[nameKey] = name
}
}
// Preserve config suffix priority by seeding a base-name lookup when name already has suffix.
nameResult := thinking.ParseSuffix(name)
if nameResult.HasSuffix {
baseKey := strings.ToLower(strings.TrimSpace(nameResult.ModelName))
if baseKey != "" {
if _, exists := out[baseKey]; !exists {
out[baseKey] = name
}
}
}
}
}
// SetRetryConfig updates retry attempts and cooldown wait interval.
func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration) {
if m == nil {
@@ -219,6 +408,7 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
m.mu.Lock()
m.auths[auth.ID] = auth.Clone()
m.mu.Unlock()
m.rebuildAPIKeyModelMappingsFromRuntimeConfig()
_ = m.persist(ctx, auth)
m.hook.OnAuthRegistered(ctx, auth.Clone())
return auth.Clone(), nil
@@ -237,6 +427,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
auth.EnsureIndex()
m.auths[auth.ID] = auth.Clone()
m.mu.Unlock()
m.rebuildAPIKeyModelMappingsFromRuntimeConfig()
_ = m.persist(ctx, auth)
m.hook.OnAuthUpdated(ctx, auth.Clone())
return auth.Clone(), nil
@@ -261,6 +452,11 @@ func (m *Manager) Load(ctx context.Context) error {
auth.EnsureIndex()
m.auths[auth.ID] = auth.Clone()
}
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
if cfg == nil {
cfg = &internalconfig.Config{}
}
m.rebuildAPIKeyModelMappingsLocked(cfg)
return nil
}
@@ -558,6 +754,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
execReq.Model, execReq.Metadata = m.applyAPIKeyModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -606,6 +803,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
execReq.Model, execReq.Metadata = m.applyAPIKeyModelMapping(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -654,6 +852,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
execReq.Model, execReq.Metadata = m.applyAPIKeyModelMapping(auth, execReq.Model, execReq.Metadata)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
rerr := &Error{Message: errStream.Error()}
@@ -712,7 +911,6 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]
return metadata
}
keys := []string{
util.ThinkingOriginalModelMetadataKey,
util.GeminiOriginalModelMetadataKey,
util.ModelMappingOriginalModelMetadataKey,
}
@@ -740,6 +938,215 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]
return out
}
func (m *Manager) applyAPIKeyModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
if m == nil || auth == nil {
return requestedModel, metadata
}
kind, _ := auth.AccountInfo()
if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
return requestedModel, metadata
}
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return requestedModel, metadata
}
// Fast path: lookup per-auth mapping table (keyed by auth.ID).
if resolved := m.lookupAPIKeyUpstreamModel(auth.ID, requestedModel); resolved != "" {
return applyUpstreamModelOverride(requestedModel, resolved, metadata)
}
// Slow path: scan config for the matching credential entry and resolve alias.
// This acts as a safety net if mappings are stale or auth.ID is missing.
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
if cfg == nil {
cfg = &internalconfig.Config{}
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
upstreamModel := ""
switch provider {
case "gemini":
upstreamModel = resolveUpstreamModelForGeminiAPIKey(cfg, auth, requestedModel)
case "claude":
upstreamModel = resolveUpstreamModelForClaudeAPIKey(cfg, auth, requestedModel)
case "codex":
upstreamModel = resolveUpstreamModelForCodexAPIKey(cfg, auth, requestedModel)
case "vertex":
upstreamModel = resolveUpstreamModelForVertexAPIKey(cfg, auth, requestedModel)
default:
upstreamModel = resolveUpstreamModelForOpenAICompatAPIKey(cfg, auth, requestedModel)
}
// applyUpstreamModelOverride lives in model_name_mappings.go.
return applyUpstreamModelOverride(requestedModel, upstreamModel, metadata)
}
// APIKeyConfigEntry is a generic interface for API key configurations.
type APIKeyConfigEntry interface {
GetAPIKey() string
GetBaseURL() string
}
func resolveAPIKeyConfig[T APIKeyConfigEntry](entries []T, auth *Auth) *T {
if auth == nil || len(entries) == 0 {
return nil
}
attrKey, attrBase := "", ""
if auth.Attributes != nil {
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
}
for i := range entries {
entry := &entries[i]
cfgKey := strings.TrimSpace((*entry).GetAPIKey())
cfgBase := strings.TrimSpace((*entry).GetBaseURL())
if attrKey != "" && attrBase != "" {
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
return entry
}
continue
}
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
return entry
}
}
if attrKey != "" {
for i := range entries {
entry := &entries[i]
if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) {
return entry
}
}
}
return nil
}
func resolveGeminiAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.GeminiKey {
if cfg == nil {
return nil
}
return resolveAPIKeyConfig(cfg.GeminiKey, auth)
}
func resolveClaudeAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.ClaudeKey {
if cfg == nil {
return nil
}
return resolveAPIKeyConfig(cfg.ClaudeKey, auth)
}
func resolveCodexAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.CodexKey {
if cfg == nil {
return nil
}
return resolveAPIKeyConfig(cfg.CodexKey, auth)
}
func resolveVertexAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.VertexCompatKey {
if cfg == nil {
return nil
}
return resolveAPIKeyConfig(cfg.VertexCompatAPIKey, auth)
}
func resolveUpstreamModelForGeminiAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
entry := resolveGeminiAPIKeyConfig(cfg, auth)
if entry == nil {
return ""
}
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
}
func resolveUpstreamModelForClaudeAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
entry := resolveClaudeAPIKeyConfig(cfg, auth)
if entry == nil {
return ""
}
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
}
func resolveUpstreamModelForCodexAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
entry := resolveCodexAPIKeyConfig(cfg, auth)
if entry == nil {
return ""
}
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
}
func resolveUpstreamModelForVertexAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
entry := resolveVertexAPIKeyConfig(cfg, auth)
if entry == nil {
return ""
}
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
}
func resolveUpstreamModelForOpenAICompatAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
providerKey := ""
compatName := ""
if auth != nil && len(auth.Attributes) > 0 {
providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
compatName = strings.TrimSpace(auth.Attributes["compat_name"])
}
if compatName == "" && !strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
return ""
}
entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
if entry == nil {
return ""
}
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
}
type apiKeyModelMappingTable map[string]map[string]string
func resolveOpenAICompatConfig(cfg *internalconfig.Config, providerKey, compatName, authProvider string) *internalconfig.OpenAICompatibility {
if cfg == nil {
return nil
}
candidates := make([]string, 0, 3)
if v := strings.TrimSpace(compatName); v != "" {
candidates = append(candidates, v)
}
if v := strings.TrimSpace(providerKey); v != "" {
candidates = append(candidates, v)
}
if v := strings.TrimSpace(authProvider); v != "" {
candidates = append(candidates, v)
}
for i := range cfg.OpenAICompatibility {
compat := &cfg.OpenAICompatibility[i]
for _, candidate := range candidates {
if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
return compat
}
}
}
return nil
}
func asModelAliasEntries[T interface {
GetName() string
GetAlias() string
}](models []T) []modelMappingEntry {
if len(models) == 0 {
return nil
}
out := make([]modelMappingEntry, 0, len(models))
for i := range models {
out = append(out, models[i])
}
return out
}
func (m *Manager) normalizeProviders(providers []string) []string {
if len(providers) == 0 {
return nil

View File

@@ -4,9 +4,15 @@ import (
"strings"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
type modelMappingEntry interface {
GetName() string
GetAlias() string
}
type modelNameMappingTable struct {
// reverse maps channel -> alias (lower) -> original upstream model name.
reverse map[string]map[string]string
@@ -71,9 +77,14 @@ func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.Mod
// requested model for response translation.
func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel)
return applyUpstreamModelOverride(requestedModel, upstreamModel, metadata)
}
func applyUpstreamModelOverride(requestedModel, upstreamModel string, metadata map[string]any) (string, map[string]any) {
if upstreamModel == "" {
return requestedModel, metadata
}
out := make(map[string]any, 1)
if len(metadata) > 0 {
out = make(map[string]any, len(metadata)+1)
@@ -81,24 +92,92 @@ func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, meta
out[k] = v
}
}
// Store the requested alias (e.g., "gp") so downstream can use it to look up
// model metadata from the global registry where it was registered under this alias.
// Preserve the original client model string (including any suffix) for downstream.
out[util.ModelMappingOriginalModelMetadataKey] = requestedModel
return upstreamModel, out
}
func resolveModelAliasFromConfigModels(requestedModel string, models []modelMappingEntry) string {
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return ""
}
if len(models) == 0 {
return ""
}
requestResult := thinking.ParseSuffix(requestedModel)
base := requestResult.ModelName
candidates := []string{base}
if base != requestedModel {
candidates = append(candidates, requestedModel)
}
preserveSuffix := func(resolved string) string {
resolved = strings.TrimSpace(resolved)
if resolved == "" {
return ""
}
if thinking.ParseSuffix(resolved).HasSuffix {
return resolved
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return resolved + "(" + requestResult.RawSuffix + ")"
}
return resolved
}
for i := range models {
name := strings.TrimSpace(models[i].GetName())
alias := strings.TrimSpace(models[i].GetAlias())
for _, candidate := range candidates {
if candidate == "" {
continue
}
if alias != "" && strings.EqualFold(alias, candidate) {
if name != "" {
return preserveSuffix(name)
}
return preserveSuffix(candidate)
}
if name != "" && strings.EqualFold(name, candidate) {
return preserveSuffix(name)
}
}
}
return ""
}
// resolveOAuthUpstreamModel resolves the upstream model name from OAuth model mappings.
// If a mapping exists, returns the original (upstream) model name that corresponds
// to the requested alias.
//
// If the requested model contains a thinking suffix (e.g., "gemini-2.5-pro(8192)"),
// the suffix is preserved in the returned model name. However, if the mapping's
// original name already contains a suffix, the config suffix takes priority.
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
return resolveUpstreamModelFromMappingTable(m, auth, requestedModel, modelMappingChannel(auth))
}
func resolveUpstreamModelFromMappingTable(m *Manager, auth *Auth, requestedModel, channel string) string {
if m == nil || auth == nil {
return ""
}
channel := modelMappingChannel(auth)
if channel == "" {
return ""
}
key := strings.ToLower(strings.TrimSpace(requestedModel))
if key == "" {
return ""
// Extract thinking suffix from requested model using ParseSuffix
requestResult := thinking.ParseSuffix(requestedModel)
baseModel := requestResult.ModelName
// Candidate keys to match: base model and raw input (handles suffix-parsing edge cases).
candidates := []string{baseModel}
if baseModel != requestedModel {
candidates = append(candidates, requestedModel)
}
raw := m.modelNameMappings.Load()
table, _ := raw.(*modelNameMappingTable)
if table == nil || table.reverse == nil {
@@ -108,11 +187,32 @@ func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) s
if rev == nil {
return ""
}
original := strings.TrimSpace(rev[key])
if original == "" || strings.EqualFold(original, requestedModel) {
return ""
for _, candidate := range candidates {
key := strings.ToLower(strings.TrimSpace(candidate))
if key == "" {
continue
}
original := strings.TrimSpace(rev[key])
if original == "" {
continue
}
if strings.EqualFold(original, baseModel) {
return ""
}
// If config already has suffix, it takes priority.
if thinking.ParseSuffix(original).HasSuffix {
return original
}
// Preserve user's thinking suffix on the resolved model.
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return original + "(" + requestResult.RawSuffix + ")"
}
return original
}
return original
return ""
}
// modelMappingChannel extracts the OAuth model mapping channel from an Auth object.

View File

@@ -0,0 +1,187 @@
package auth
import (
"testing"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
t.Parallel()
tests := []struct {
name string
mappings map[string][]internalconfig.ModelNameMapping
channel string
input string
want string
}{
{
name: "numeric suffix preserved",
mappings: map[string][]internalconfig.ModelNameMapping{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro(8192)",
want: "gemini-2.5-pro-exp-03-25(8192)",
},
{
name: "level suffix preserved",
mappings: map[string][]internalconfig.ModelNameMapping{
"claude": {{Name: "claude-sonnet-4-5-20250514", Alias: "claude-sonnet-4-5"}},
},
channel: "claude",
input: "claude-sonnet-4-5(high)",
want: "claude-sonnet-4-5-20250514(high)",
},
{
name: "no suffix unchanged",
mappings: map[string][]internalconfig.ModelNameMapping{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro",
want: "gemini-2.5-pro-exp-03-25",
},
{
name: "config suffix takes priority",
mappings: map[string][]internalconfig.ModelNameMapping{
"claude": {{Name: "claude-sonnet-4-5-20250514(low)", Alias: "claude-sonnet-4-5"}},
},
channel: "claude",
input: "claude-sonnet-4-5(high)",
want: "claude-sonnet-4-5-20250514(low)",
},
{
name: "auto suffix preserved",
mappings: map[string][]internalconfig.ModelNameMapping{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro(auto)",
want: "gemini-2.5-pro-exp-03-25(auto)",
},
{
name: "none suffix preserved",
mappings: map[string][]internalconfig.ModelNameMapping{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro(none)",
want: "gemini-2.5-pro-exp-03-25(none)",
},
{
name: "case insensitive alias lookup with suffix",
mappings: map[string][]internalconfig.ModelNameMapping{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "Gemini-2.5-Pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro(high)",
want: "gemini-2.5-pro-exp-03-25(high)",
},
{
name: "no mapping returns empty",
mappings: map[string][]internalconfig.ModelNameMapping{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "unknown-model(high)",
want: "",
},
{
name: "wrong channel returns empty",
mappings: map[string][]internalconfig.ModelNameMapping{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "claude",
input: "gemini-2.5-pro(high)",
want: "",
},
{
name: "empty suffix filtered out",
mappings: map[string][]internalconfig.ModelNameMapping{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro()",
want: "gemini-2.5-pro-exp-03-25",
},
{
name: "incomplete suffix treated as no suffix",
mappings: map[string][]internalconfig.ModelNameMapping{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro(high"}},
},
channel: "gemini-cli",
input: "gemini-2.5-pro(high",
want: "gemini-2.5-pro-exp-03-25",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
t.Parallel()
mgr := NewManager(nil, nil, nil)
mgr.SetConfig(&internalconfig.Config{})
mgr.SetOAuthModelMappings(tt.mappings)
auth := createAuthForChannel(tt.channel)
got := mgr.resolveOAuthUpstreamModel(auth, tt.input)
if got != tt.want {
t.Errorf("resolveOAuthUpstreamModel(%q) = %q, want %q", tt.input, got, tt.want)
}
})
}
}
func createAuthForChannel(channel string) *Auth {
switch channel {
case "gemini-cli":
return &Auth{Provider: "gemini-cli"}
case "claude":
return &Auth{Provider: "claude", Attributes: map[string]string{"auth_kind": "oauth"}}
case "vertex":
return &Auth{Provider: "vertex", Attributes: map[string]string{"auth_kind": "oauth"}}
case "codex":
return &Auth{Provider: "codex", Attributes: map[string]string{"auth_kind": "oauth"}}
case "aistudio":
return &Auth{Provider: "aistudio"}
case "antigravity":
return &Auth{Provider: "antigravity"}
case "qwen":
return &Auth{Provider: "qwen"}
case "iflow":
return &Auth{Provider: "iflow"}
default:
return &Auth{Provider: channel}
}
}
func TestApplyOAuthModelMapping_SuffixPreservation(t *testing.T) {
t.Parallel()
mappings := map[string][]internalconfig.ModelNameMapping{
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
}
mgr := NewManager(nil, nil, nil)
mgr.SetConfig(&internalconfig.Config{})
mgr.SetOAuthModelMappings(mappings)
auth := &Auth{ID: "test-auth-id", Provider: "gemini-cli"}
metadata := map[string]any{"existing": "value"}
resolvedModel, resultMeta := mgr.applyOAuthModelMapping(auth, "gemini-2.5-pro(8192)", metadata)
if resolvedModel != "gemini-2.5-pro-exp-03-25(8192)" {
t.Errorf("applyOAuthModelMapping() model = %q, want %q", resolvedModel, "gemini-2.5-pro-exp-03-25(8192)")
}
originalModel, ok := resultMeta["model_mapping_original_model"].(string)
if !ok || originalModel != "gemini-2.5-pro(8192)" {
t.Errorf("applyOAuthModelMapping() metadata[model_mapping_original_model] = %v, want %q", resultMeta["model_mapping_original_model"], "gemini-2.5-pro(8192)")
}
if resultMeta["existing"] != "value" {
t.Errorf("applyOAuthModelMapping() metadata[existing] = %v, want %q", resultMeta["existing"], "value")
}
}