refactor(config): rename model-name-mappings to oauth-model-mappings

This commit is contained in:
hkfires
2025-12-30 10:14:27 +08:00
parent f6ab6d97b9
commit 7be3f1c36c
7 changed files with 85 additions and 102 deletions

View File

@@ -413,7 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -475,7 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -537,7 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata)
execReq.Metadata = m.applyOAuthModelMappingMetadata(auth, execReq.Model, execReq.Metadata)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
rerr := &Error{Message: errStream.Error()}

View File

@@ -50,10 +50,10 @@ func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelName
return out
}
// SetGlobalModelNameMappings updates the global model name mapping table used during execution.
// SetOAuthModelMappings updates the OAuth model name mapping table used during execution.
// The mapping is applied per-auth channel to resolve the upstream model name while keeping the
// client-visible model name unchanged for translation/response formatting.
func (m *Manager) SetGlobalModelNameMappings(mappings map[string][]internalconfig.ModelNameMapping) {
func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.ModelNameMapping) {
if m == nil {
return
}
@@ -65,8 +65,8 @@ func (m *Manager) SetGlobalModelNameMappings(mappings map[string][]internalconfi
m.modelNameMappings.Store(table)
}
func (m *Manager) applyGlobalModelNameMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any {
original := m.resolveGlobalUpstreamModelForAuth(auth, requestedModel)
func (m *Manager) applyOAuthModelMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any {
original := m.resolveOAuthUpstreamModel(auth, requestedModel)
if original == "" {
return metadata
}
@@ -88,11 +88,11 @@ func (m *Manager) applyGlobalModelNameMappingMetadata(auth *Auth, requestedModel
return out
}
func (m *Manager) resolveGlobalUpstreamModelForAuth(auth *Auth, requestedModel string) string {
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
if m == nil || auth == nil {
return ""
}
channel := globalModelMappingChannelForAuth(auth)
channel := modelMappingChannel(auth)
if channel == "" {
return ""
}
@@ -116,7 +116,10 @@ func (m *Manager) resolveGlobalUpstreamModelForAuth(auth *Auth, requestedModel s
return original
}
func globalModelMappingChannelForAuth(auth *Auth) string {
// modelMappingChannel extracts the OAuth model mapping channel from an Auth object.
// It determines the provider and auth kind from the Auth's attributes and delegates
// to OAuthModelMappingChannel for the actual channel resolution.
func modelMappingChannel(auth *Auth) string {
if auth == nil {
return ""
}
@@ -130,32 +133,38 @@ func globalModelMappingChannelForAuth(auth *Auth) string {
authKind = "apikey"
}
}
return globalModelMappingChannel(provider, authKind)
return OAuthModelMappingChannel(provider, authKind)
}
func globalModelMappingChannel(provider, authKind string) string {
// OAuthModelMappingChannel returns the OAuth model mapping channel name for a given provider
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
// OAuth model mappings (e.g., API key authentication).
//
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
func OAuthModelMappingChannel(provider, authKind string) string {
provider = strings.ToLower(strings.TrimSpace(provider))
authKind = strings.ToLower(strings.TrimSpace(authKind))
switch provider {
case "gemini":
if authKind == "apikey" {
return "apikey-gemini"
}
return "gemini"
case "codex":
if authKind == "apikey" {
return ""
}
return "codex"
case "claude":
if authKind == "apikey" {
return ""
}
return "claude"
// gemini provider uses gemini-api-key config, not oauth-model-mappings.
// OAuth-based gemini auth is converted to "gemini-cli" by the synthesizer.
return ""
case "vertex":
if authKind == "apikey" {
return ""
}
return "vertex"
case "antigravity", "qwen", "iflow":
case "claude":
if authKind == "apikey" {
return ""
}
return "claude"
case "codex":
if authKind == "apikey" {
return ""
}
return "codex"
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
return provider
default:
return ""

View File

@@ -215,7 +215,7 @@ func (b *Builder) Build() (*Service, error) {
}
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
coreManager.SetGlobalModelNameMappings(b.cfg.ModelNameMappings)
coreManager.SetOAuthModelMappings(b.cfg.OAuthModelMappings)
service := &Service{
cfg: b.cfg,

View File

@@ -553,7 +553,7 @@ func (s *Service) Run(ctx context.Context) error {
s.cfg = newCfg
s.cfgMu.Unlock()
if s.coreManager != nil {
s.coreManager.SetGlobalModelNameMappings(newCfg.ModelNameMappings)
s.coreManager.SetOAuthModelMappings(newCfg.OAuthModelMappings)
}
s.rebindExecutors()
}
@@ -844,7 +844,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
}
}
models = applyGlobalModelNameMappings(s.cfg, provider, authKind, models)
models = applyOAuthModelMappings(s.cfg, provider, authKind, models)
if len(models) > 0 {
key := provider
if key == "" {
@@ -1154,37 +1154,6 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
return out
}
func globalModelMappingChannel(provider, authKind string) string {
provider = strings.ToLower(strings.TrimSpace(provider))
authKind = strings.ToLower(strings.TrimSpace(authKind))
switch provider {
case "gemini":
if authKind == "apikey" {
return "apikey-gemini"
}
return "gemini"
case "codex":
if authKind == "apikey" {
return ""
}
return "codex"
case "claude":
if authKind == "apikey" {
return ""
}
return "claude"
case "vertex":
if authKind == "apikey" {
return ""
}
return "vertex"
case "antigravity", "qwen", "iflow":
return provider
default:
return ""
}
}
func rewriteModelInfoName(name, oldID, newID string) string {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
@@ -1208,15 +1177,15 @@ func rewriteModelInfoName(name, oldID, newID string) string {
return name
}
func applyGlobalModelNameMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
if cfg == nil || len(models) == 0 {
return models
}
channel := globalModelMappingChannel(provider, authKind)
if channel == "" || len(cfg.ModelNameMappings) == 0 {
channel := coreauth.OAuthModelMappingChannel(provider, authKind)
if channel == "" || len(cfg.OAuthModelMappings) == 0 {
return models
}
mappings := cfg.ModelNameMappings[channel]
mappings := cfg.OAuthModelMappings[channel]
if len(mappings) == 0 {
return models
}