feat(cliproxy): introduce global model name mappings for improved aliasing and routing

This commit is contained in:
Luis Pater
2025-12-30 08:13:06 +08:00
parent a8cb01819d
commit 50e6d845f4
10 changed files with 431 additions and 15 deletions

View File

@@ -111,6 +111,9 @@ type Manager struct {
requestRetry atomic.Int32
maxRetryInterval atomic.Int64
// modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel.
modelNameMappings atomic.Value
// Optional HTTP RoundTripper provider injected by host.
rtProvider RoundTripperProvider
@@ -410,6 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -471,6 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -532,6 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
rerr := &Error{Message: errStream.Error()}
@@ -592,6 +598,7 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]
keys := []string{
util.ThinkingOriginalModelMetadataKey,
util.GeminiOriginalModelMetadataKey,
util.ModelMappingOriginalModelMetadataKey,
}
var out map[string]any
for _, key := range keys {

View File

@@ -0,0 +1,163 @@
package auth
import (
"strings"
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
type modelNameMappingTable struct {
// reverse maps channel -> alias (lower) -> original upstream model name.
reverse map[string]map[string]string
}
func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable {
if len(mappings) == 0 {
return &modelNameMappingTable{}
}
out := &modelNameMappingTable{
reverse: make(map[string]map[string]string, len(mappings)),
}
for rawChannel, entries := range mappings {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(entries) == 0 {
continue
}
rev := make(map[string]string, len(entries))
for _, entry := range entries {
from := strings.TrimSpace(entry.From)
to := strings.TrimSpace(entry.To)
if from == "" || to == "" {
continue
}
if strings.EqualFold(from, to) {
continue
}
aliasKey := strings.ToLower(to)
if _, exists := rev[aliasKey]; exists {
continue
}
rev[aliasKey] = from
}
if len(rev) > 0 {
out.reverse[channel] = rev
}
}
if len(out.reverse) == 0 {
out.reverse = nil
}
return out
}
// SetGlobalModelNameMappings updates the global model name mapping table used during execution.
// The mapping is applied per-auth channel to resolve the upstream model name while keeping the
// client-visible model name unchanged for translation/response formatting.
func (m *Manager) SetGlobalModelNameMappings(mappings map[string][]internalconfig.ModelNameMapping) {
if m == nil {
return
}
table := compileModelNameMappingTable(mappings)
// atomic.Value requires non-nil store values.
if table == nil {
table = &modelNameMappingTable{}
}
m.modelNameMappings.Store(table)
}
func (m *Manager) applyGlobalModelNameMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any {
original := m.resolveGlobalUpstreamModelForAuth(auth, requestedModel)
if original == "" {
return metadata
}
if metadata != nil {
if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok {
if s, okStr := v.(string); okStr && strings.EqualFold(s, original) {
return metadata
}
}
}
out := make(map[string]any, 1)
if len(metadata) > 0 {
out = make(map[string]any, len(metadata)+1)
for k, v := range metadata {
out[k] = v
}
}
out[util.ModelMappingOriginalModelMetadataKey] = original
return out
}
func (m *Manager) resolveGlobalUpstreamModelForAuth(auth *Auth, requestedModel string) string {
if m == nil || auth == nil {
return ""
}
channel := globalModelMappingChannelForAuth(auth)
if channel == "" {
return ""
}
key := strings.ToLower(strings.TrimSpace(requestedModel))
if key == "" {
return ""
}
raw := m.modelNameMappings.Load()
table, _ := raw.(*modelNameMappingTable)
if table == nil || table.reverse == nil {
return ""
}
rev := table.reverse[channel]
if rev == nil {
return ""
}
original := strings.TrimSpace(rev[key])
if original == "" || strings.EqualFold(original, requestedModel) {
return ""
}
return original
}
func globalModelMappingChannelForAuth(auth *Auth) string {
if auth == nil {
return ""
}
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
authKind := ""
if auth.Attributes != nil {
authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"]))
}
if authKind == "" {
if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") {
authKind = "apikey"
}
}
return globalModelMappingChannel(provider, authKind)
}
func globalModelMappingChannel(provider, authKind string) string {
switch provider {
case "gemini":
if authKind == "apikey" {
return "apikey-gemini"
}
return "gemini"
case "codex":
if authKind == "apikey" {
return ""
}
return "codex"
case "claude":
if authKind == "apikey" {
return ""
}
return "claude"
case "vertex":
if authKind == "apikey" {
return ""
}
return "vertex"
case "antigravity", "qwen", "iflow":
return provider
default:
return ""
}
}

View File

@@ -215,6 +215,7 @@ func (b *Builder) Build() (*Service, error) {
}
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
coreManager.SetGlobalModelNameMappings(b.cfg.ModelNameMappings)
service := &Service{
cfg: b.cfg,

View File

@@ -552,6 +552,9 @@ func (s *Service) Run(ctx context.Context) error {
s.cfgMu.Lock()
s.cfg = newCfg
s.cfgMu.Unlock()
if s.coreManager != nil {
s.coreManager.SetGlobalModelNameMappings(newCfg.ModelNameMappings)
}
s.rebindExecutors()
}
@@ -677,6 +680,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
return
}
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
if authKind == "" {
if kind, _ := a.AccountInfo(); strings.EqualFold(kind, "api_key") {
authKind = "apikey"
}
}
if a.Attributes != nil {
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
GlobalModelRegistry().UnregisterClient(a.ID)
@@ -836,6 +844,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
}
}
models = applyGlobalModelNameMappings(s.cfg, provider, authKind, models)
if len(models) > 0 {
key := provider
if key == "" {
@@ -1145,6 +1154,124 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
return out
}
func globalModelMappingChannel(provider, authKind string) string {
provider = strings.ToLower(strings.TrimSpace(provider))
authKind = strings.ToLower(strings.TrimSpace(authKind))
switch provider {
case "gemini":
if authKind == "apikey" {
return "apikey-gemini"
}
return "gemini"
case "codex":
if authKind == "apikey" {
return ""
}
return "codex"
case "claude":
if authKind == "apikey" {
return ""
}
return "claude"
case "vertex":
if authKind == "apikey" {
return ""
}
return "vertex"
case "antigravity", "qwen", "iflow":
return provider
default:
return ""
}
}
func rewriteModelInfoName(name, oldID, newID string) string {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return name
}
oldID = strings.TrimSpace(oldID)
newID = strings.TrimSpace(newID)
if oldID == "" || newID == "" {
return name
}
if strings.EqualFold(oldID, newID) {
return name
}
if strings.HasSuffix(trimmed, "/"+oldID) {
prefix := strings.TrimSuffix(trimmed, oldID)
return prefix + newID
}
if trimmed == "models/"+oldID {
return "models/" + newID
}
return name
}
func applyGlobalModelNameMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
if cfg == nil || len(models) == 0 {
return models
}
channel := globalModelMappingChannel(provider, authKind)
if channel == "" || len(cfg.ModelNameMappings) == 0 {
return models
}
mappings := cfg.ModelNameMappings[channel]
if len(mappings) == 0 {
return models
}
forward := make(map[string]string, len(mappings))
for i := range mappings {
from := strings.TrimSpace(mappings[i].From)
to := strings.TrimSpace(mappings[i].To)
if from == "" || to == "" {
continue
}
if strings.EqualFold(from, to) {
continue
}
key := strings.ToLower(from)
if _, exists := forward[key]; exists {
continue
}
forward[key] = to
}
if len(forward) == 0 {
return models
}
out := make([]*ModelInfo, 0, len(models))
seen := make(map[string]struct{}, len(models))
for _, model := range models {
if model == nil {
continue
}
id := strings.TrimSpace(model.ID)
if id == "" {
continue
}
mappedID := id
if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" {
mappedID = strings.TrimSpace(to)
}
uniqueKey := strings.ToLower(mappedID)
if _, exists := seen[uniqueKey]; exists {
continue
}
seen[uniqueKey] = struct{}{}
if mappedID == id {
out = append(out, model)
continue
}
clone := *model
clone.ID = mappedID
if clone.Name != "" {
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
}
out = append(out, &clone)
}
return out
}
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
if entry == nil || len(entry.Models) == 0 {
return nil