package auth import ( "strings" internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" ) type modelNameMappingTable struct { // reverse maps channel -> alias (lower) -> original upstream model name. reverse map[string]map[string]string } func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable { if len(mappings) == 0 { return &modelNameMappingTable{} } out := &modelNameMappingTable{ reverse: make(map[string]map[string]string, len(mappings)), } for rawChannel, entries := range mappings { channel := strings.ToLower(strings.TrimSpace(rawChannel)) if channel == "" || len(entries) == 0 { continue } rev := make(map[string]string, len(entries)) for _, entry := range entries { from := strings.TrimSpace(entry.From) to := strings.TrimSpace(entry.To) if from == "" || to == "" { continue } if strings.EqualFold(from, to) { continue } aliasKey := strings.ToLower(to) if _, exists := rev[aliasKey]; exists { continue } rev[aliasKey] = from } if len(rev) > 0 { out.reverse[channel] = rev } } if len(out.reverse) == 0 { out.reverse = nil } return out } // SetOAuthModelMappings updates the OAuth model name mapping table used during execution. // The mapping is applied per-auth channel to resolve the upstream model name while keeping the // client-visible model name unchanged for translation/response formatting. func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.ModelNameMapping) { if m == nil { return } table := compileModelNameMappingTable(mappings) // atomic.Value requires non-nil store values. if table == nil { table = &modelNameMappingTable{} } m.modelNameMappings.Store(table) } func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any { original := m.resolveOAuthUpstreamModel(auth, requestedModel) if original == "" { return metadata } if metadata != nil { if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok { if s, okStr := v.(string); okStr && strings.EqualFold(s, original) { return metadata } } } out := make(map[string]any, 1) if len(metadata) > 0 { out = make(map[string]any, len(metadata)+1) for k, v := range metadata { out[k] = v } } out[util.ModelMappingOriginalModelMetadataKey] = original return out } func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string { if m == nil || auth == nil { return "" } channel := modelMappingChannel(auth) if channel == "" { return "" } key := strings.ToLower(strings.TrimSpace(requestedModel)) if key == "" { return "" } raw := m.modelNameMappings.Load() table, _ := raw.(*modelNameMappingTable) if table == nil || table.reverse == nil { return "" } rev := table.reverse[channel] if rev == nil { return "" } original := strings.TrimSpace(rev[key]) if original == "" || strings.EqualFold(original, requestedModel) { return "" } return original } // modelMappingChannel extracts the OAuth model mapping channel from an Auth object. // It determines the provider and auth kind from the Auth's attributes and delegates // to OAuthModelMappingChannel for the actual channel resolution. func modelMappingChannel(auth *Auth) string { if auth == nil { return "" } provider := strings.ToLower(strings.TrimSpace(auth.Provider)) authKind := "" if auth.Attributes != nil { authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"])) } if authKind == "" { if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") { authKind = "apikey" } } return OAuthModelMappingChannel(provider, authKind) } // OAuthModelMappingChannel returns the OAuth model mapping channel name for a given provider // and auth kind. Returns empty string if the provider/authKind combination doesn't support // OAuth model mappings (e.g., API key authentication). // // Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow. func OAuthModelMappingChannel(provider, authKind string) string { provider = strings.ToLower(strings.TrimSpace(provider)) authKind = strings.ToLower(strings.TrimSpace(authKind)) switch provider { case "gemini": // gemini provider uses gemini-api-key config, not oauth-model-mappings. // OAuth-based gemini auth is converted to "gemini-cli" by the synthesizer. return "" case "vertex": if authKind == "apikey" { return "" } return "vertex" case "claude": if authKind == "apikey" { return "" } return "claude" case "codex": if authKind == "apikey" { return "" } return "codex" case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow": return provider default: return "" } }