Files
CLIProxyAPI/sdk/cliproxy/auth/model_name_mappings.go

254 lines
7.2 KiB
Go

package auth
import (
"strings"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
)
// modelMappingEntry is the minimal view of a configured model that alias resolution needs.
// resolveModelAliasFromConfigModels compares a requested model against both values
// (case-insensitively) and resolves a match to the entry's name.
type modelMappingEntry interface {
// GetName returns the model name a matching request is resolved to.
GetName() string
// GetAlias returns the alternative name under which the model may be requested.
GetAlias() string
}
// modelNameMappingTable is the compiled, lookup-ready form of the OAuth model name mappings.
// compileModelNameMappingTable builds it and resolveUpstreamModelFromMappingTable reads it.
type modelNameMappingTable struct {
// reverse maps channel -> alias (lower) -> original upstream model name.
// It is left nil when the configuration yields no usable mapping.
reverse map[string]map[string]string
}
// compileModelNameMappingTable converts the raw per-channel mapping configuration into a
// lookup table keyed by channel, then by lower-cased alias, yielding the upstream model name.
//
// Normalization rules:
//   - channel keys are trimmed and lower-cased; empty channels are skipped;
//   - names and aliases are trimmed; entries with an empty name or alias are skipped;
//   - entries whose alias equals the name (case-insensitively) are skipped;
//   - within one config key the first entry for an alias wins.
//
// Several config keys can normalize to the same channel (e.g. "Claude" and "claude"). Their
// entries are merged instead of letting whichever key the map iteration happens to visit last
// replace the others. When merged keys disagree about an alias, the entry coming from the
// lexicographically smallest raw key wins, so the result does not depend on Go's unspecified
// map iteration order.
//
// The returned table is never nil; its reverse map is nil when no usable entry exists.
func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable {
	if len(mappings) == 0 {
		return &modelNameMappingTable{}
	}

	reverse := make(map[string]map[string]string, len(mappings))
	// owners mirrors reverse but stores the raw config key that supplied each alias, which is
	// what makes conflict resolution between colliding keys order-independent.
	owners := make(map[string]map[string]string, len(mappings))

	for rawChannel, entries := range mappings {
		channel := strings.ToLower(strings.TrimSpace(rawChannel))
		if channel == "" || len(entries) == 0 {
			continue
		}

		// Both maps are created lazily so channels without a single valid entry never
		// appear in the table. Reading from the nil maps below is safe.
		rev := reverse[channel]
		own := owners[channel]

		for _, entry := range entries {
			name := strings.TrimSpace(entry.Name)
			alias := strings.TrimSpace(entry.Alias)
			if name == "" || alias == "" || strings.EqualFold(name, alias) {
				continue
			}

			aliasKey := strings.ToLower(alias)
			// Keep the existing entry when it comes from this same key (first entry wins)
			// or from a key that sorts before this one.
			if prev, exists := own[aliasKey]; exists && prev <= rawChannel {
				continue
			}

			if rev == nil {
				rev = make(map[string]string, len(entries))
				own = make(map[string]string, len(entries))
				reverse[channel] = rev
				owners[channel] = own
			}
			rev[aliasKey] = name
			own[aliasKey] = rawChannel
		}
	}

	if len(reverse) == 0 {
		return &modelNameMappingTable{}
	}
	return &modelNameMappingTable{reverse: reverse}
}
// SetOAuthModelMappings compiles the given OAuth model name mappings and installs the result
// as the table consulted during execution. Lookups are done per auth channel and only affect
// the upstream model name; the client-visible model name used for translation/response
// formatting is left alone. Calling it on a nil Manager does nothing.
func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.ModelNameMapping) {
	if m == nil {
		return
	}
	compiled := compileModelNameMappingTable(mappings)
	if compiled == nil {
		// The store must never receive nil (atomic.Value rejects it), so substitute an
		// empty table.
		compiled = &modelNameMappingTable{}
	}
	m.modelNameMappings.Store(compiled)
}
// applyOAuthModelMapping returns the model name to send upstream for the given auth.
// When the OAuth model mappings resolve requestedModel to an upstream model, that model is
// returned; otherwise requestedModel is returned unchanged.
func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string) string {
	if resolved := m.resolveOAuthUpstreamModel(auth, requestedModel); resolved != "" {
		return resolved
	}
	return requestedModel
}
// resolveModelAliasFromConfigModels resolves requestedModel against a list of configured
// models and returns the matching entry's name, or "" when nothing matches (or when the
// request or the list is empty).
//
// Matching is case-insensitive and is tried first with the suffix-less base model and then,
// if different, with the trimmed request as given. Entries are checked in order; for each
// entry the alias is compared before the name. An alias match on an entry without a name
// resolves to the matched candidate itself.
//
// Suffix handling: a suffix already present on the resolved name is kept as is; otherwise
// the suffix parsed from the request, if any, is re-attached as "(suffix)".
func resolveModelAliasFromConfigModels(requestedModel string, models []modelMappingEntry) string {
	requestedModel = strings.TrimSpace(requestedModel)
	if requestedModel == "" || len(models) == 0 {
		return ""
	}

	parsed := thinking.ParseSuffix(requestedModel)
	lookups := make([]string, 0, 2)
	lookups = append(lookups, parsed.ModelName)
	if parsed.ModelName != requestedModel {
		lookups = append(lookups, requestedModel)
	}

	// withSuffix finalizes a resolved name according to the suffix rules above.
	withSuffix := func(resolved string) string {
		resolved = strings.TrimSpace(resolved)
		switch {
		case resolved == "":
			return ""
		case thinking.ParseSuffix(resolved).HasSuffix:
			return resolved
		case parsed.HasSuffix && parsed.RawSuffix != "":
			return resolved + "(" + parsed.RawSuffix + ")"
		default:
			return resolved
		}
	}

	for _, model := range models {
		name := strings.TrimSpace(model.GetName())
		alias := strings.TrimSpace(model.GetAlias())
		for _, candidate := range lookups {
			if candidate == "" {
				continue
			}
			switch {
			case alias != "" && strings.EqualFold(alias, candidate):
				if name == "" {
					return withSuffix(candidate)
				}
				return withSuffix(name)
			case name != "" && strings.EqualFold(name, candidate):
				return withSuffix(name)
			}
		}
	}
	return ""
}
// resolveOAuthUpstreamModel looks requestedModel up in the OAuth model mappings of the
// channel that auth belongs to and returns the original (upstream) model name behind the
// requested alias, or "" when no mapping applies.
//
// A thinking suffix on the request (e.g. "gemini-2.5-pro(8192)") is carried over to the
// returned model name, unless the mapping's original name already has a suffix of its own,
// in which case the configured suffix takes priority.
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
	channel := modelMappingChannel(auth)
	return resolveUpstreamModelFromMappingTable(m, auth, requestedModel, channel)
}
// resolveUpstreamModelFromMappingTable resolves requestedModel to its upstream model name
// using the manager's compiled mapping table for the given channel.
//
// It returns "" when m or auth is nil, when channel is empty, when no table or no entries
// exist for the channel, when no alias matches, or when the mapped name equals the requested
// base model (case-insensitively), i.e. the mapping would change nothing.
//
// Lookup keys are the suffix-less base model first and then, if different, the raw request.
// A suffix configured on the mapped name takes priority; otherwise the suffix parsed from the
// request, if any, is re-attached as "(suffix)".
func resolveUpstreamModelFromMappingTable(m *Manager, auth *Auth, requestedModel, channel string) string {
	if m == nil || auth == nil || channel == "" {
		return ""
	}

	// Split off the thinking suffix so the alias lookup can use the bare model name.
	parsed := thinking.ParseSuffix(requestedModel)
	lookups := []string{parsed.ModelName}
	if requestedModel != parsed.ModelName {
		// Also try the raw input to cover suffix-parsing edge cases.
		lookups = append(lookups, requestedModel)
	}

	table, ok := m.modelNameMappings.Load().(*modelNameMappingTable)
	if !ok || table == nil || table.reverse == nil {
		return ""
	}
	aliases := table.reverse[channel]
	if aliases == nil {
		return ""
	}

	for _, candidate := range lookups {
		key := strings.ToLower(strings.TrimSpace(candidate))
		if key == "" {
			continue
		}
		upstream := strings.TrimSpace(aliases[key])
		switch {
		case upstream == "":
			// No mapping under this key; try the next candidate.
			continue
		case strings.EqualFold(upstream, parsed.ModelName):
			return ""
		case thinking.ParseSuffix(upstream).HasSuffix:
			// A suffix that is part of the configuration wins over the request's.
			return upstream
		case parsed.HasSuffix && parsed.RawSuffix != "":
			// Carry the caller's thinking suffix over to the resolved model.
			return upstream + "(" + parsed.RawSuffix + ")"
		default:
			return upstream
		}
	}
	return ""
}
// modelMappingChannel derives the OAuth model mapping channel for an Auth object.
// The provider comes from auth.Provider and the auth kind from the "auth_kind" attribute;
// when that attribute is missing or blank, an account kind of "api_key" reported by
// AccountInfo is treated as auth kind "apikey". The channel itself is decided by
// OAuthModelMappingChannel. A nil auth yields "".
func modelMappingChannel(auth *Auth) string {
	if auth == nil {
		return ""
	}

	providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))

	var kindKey string
	if attrs := auth.Attributes; attrs != nil {
		kindKey = strings.ToLower(strings.TrimSpace(attrs["auth_kind"]))
	}
	if kindKey == "" {
		accountKind, _ := auth.AccountInfo()
		if strings.EqualFold(accountKind, "api_key") {
			kindKey = "apikey"
		}
	}

	return OAuthModelMappingChannel(providerKey, kindKey)
}
// OAuthModelMappingChannel maps a provider and auth kind to the name of the OAuth model
// mapping channel, or to "" when that combination does not take part in OAuth model
// mappings (for example API key authentication). Both arguments are trimmed and compared
// case-insensitively.
//
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
func OAuthModelMappingChannel(provider, authKind string) string {
	channel := strings.ToLower(strings.TrimSpace(provider))
	usesAPIKey := strings.ToLower(strings.TrimSpace(authKind)) == "apikey"

	switch channel {
	case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
		// These providers map to a channel regardless of the auth kind.
		return channel
	case "vertex", "claude", "codex":
		// These providers only have a channel when not authenticated by API key.
		if usesAPIKey {
			return ""
		}
		return channel
	}

	// Everything else has no channel. That includes "gemini": it uses the gemini-api-key
	// config rather than oauth-model-mappings, and OAuth-based gemini auth is converted to
	// "gemini-cli" by the synthesizer.
	return ""
}