mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 13:00:52 +08:00
feat(cliproxy): introduce global model name mappings for improved aliasing and routing
This commit is contained in:
@@ -111,6 +111,9 @@ type Manager struct {
|
||||
requestRetry atomic.Int32
|
||||
maxRetryInterval atomic.Int64
|
||||
|
||||
// modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel.
|
||||
modelNameMappings atomic.Value
|
||||
|
||||
// Optional HTTP RoundTripper provider injected by host.
|
||||
rtProvider RoundTripperProvider
|
||||
|
||||
@@ -410,6 +413,7 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
@@ -471,6 +475,7 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
@@ -532,6 +537,7 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Metadata = m.applyGlobalModelNameMappingMetadata(auth, execReq.Model, execReq.Metadata)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
@@ -592,6 +598,7 @@ func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]
|
||||
keys := []string{
|
||||
util.ThinkingOriginalModelMetadataKey,
|
||||
util.GeminiOriginalModelMetadataKey,
|
||||
util.ModelMappingOriginalModelMetadataKey,
|
||||
}
|
||||
var out map[string]any
|
||||
for _, key := range keys {
|
||||
|
||||
163
sdk/cliproxy/auth/model_name_mappings.go
Normal file
163
sdk/cliproxy/auth/model_name_mappings.go
Normal file
@@ -0,0 +1,163 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
type modelNameMappingTable struct {
|
||||
// reverse maps channel -> alias (lower) -> original upstream model name.
|
||||
reverse map[string]map[string]string
|
||||
}
|
||||
|
||||
func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable {
|
||||
if len(mappings) == 0 {
|
||||
return &modelNameMappingTable{}
|
||||
}
|
||||
out := &modelNameMappingTable{
|
||||
reverse: make(map[string]map[string]string, len(mappings)),
|
||||
}
|
||||
for rawChannel, entries := range mappings {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
rev := make(map[string]string, len(entries))
|
||||
for _, entry := range entries {
|
||||
from := strings.TrimSpace(entry.From)
|
||||
to := strings.TrimSpace(entry.To)
|
||||
if from == "" || to == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(from, to) {
|
||||
continue
|
||||
}
|
||||
aliasKey := strings.ToLower(to)
|
||||
if _, exists := rev[aliasKey]; exists {
|
||||
continue
|
||||
}
|
||||
rev[aliasKey] = from
|
||||
}
|
||||
if len(rev) > 0 {
|
||||
out.reverse[channel] = rev
|
||||
}
|
||||
}
|
||||
if len(out.reverse) == 0 {
|
||||
out.reverse = nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SetGlobalModelNameMappings updates the global model name mapping table used during execution.
|
||||
// The mapping is applied per-auth channel to resolve the upstream model name while keeping the
|
||||
// client-visible model name unchanged for translation/response formatting.
|
||||
func (m *Manager) SetGlobalModelNameMappings(mappings map[string][]internalconfig.ModelNameMapping) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
table := compileModelNameMappingTable(mappings)
|
||||
// atomic.Value requires non-nil store values.
|
||||
if table == nil {
|
||||
table = &modelNameMappingTable{}
|
||||
}
|
||||
m.modelNameMappings.Store(table)
|
||||
}
|
||||
|
||||
func (m *Manager) applyGlobalModelNameMappingMetadata(auth *Auth, requestedModel string, metadata map[string]any) map[string]any {
|
||||
original := m.resolveGlobalUpstreamModelForAuth(auth, requestedModel)
|
||||
if original == "" {
|
||||
return metadata
|
||||
}
|
||||
if metadata != nil {
|
||||
if v, ok := metadata[util.ModelMappingOriginalModelMetadataKey]; ok {
|
||||
if s, okStr := v.(string); okStr && strings.EqualFold(s, original) {
|
||||
return metadata
|
||||
}
|
||||
}
|
||||
}
|
||||
out := make(map[string]any, 1)
|
||||
if len(metadata) > 0 {
|
||||
out = make(map[string]any, len(metadata)+1)
|
||||
for k, v := range metadata {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
out[util.ModelMappingOriginalModelMetadataKey] = original
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) resolveGlobalUpstreamModelForAuth(auth *Auth, requestedModel string) string {
|
||||
if m == nil || auth == nil {
|
||||
return ""
|
||||
}
|
||||
channel := globalModelMappingChannelForAuth(auth)
|
||||
if channel == "" {
|
||||
return ""
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
raw := m.modelNameMappings.Load()
|
||||
table, _ := raw.(*modelNameMappingTable)
|
||||
if table == nil || table.reverse == nil {
|
||||
return ""
|
||||
}
|
||||
rev := table.reverse[channel]
|
||||
if rev == nil {
|
||||
return ""
|
||||
}
|
||||
original := strings.TrimSpace(rev[key])
|
||||
if original == "" || strings.EqualFold(original, requestedModel) {
|
||||
return ""
|
||||
}
|
||||
return original
|
||||
}
|
||||
|
||||
func globalModelMappingChannelForAuth(auth *Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
authKind := ""
|
||||
if auth.Attributes != nil {
|
||||
authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"]))
|
||||
}
|
||||
if authKind == "" {
|
||||
if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||
authKind = "apikey"
|
||||
}
|
||||
}
|
||||
return globalModelMappingChannel(provider, authKind)
|
||||
}
|
||||
|
||||
func globalModelMappingChannel(provider, authKind string) string {
|
||||
switch provider {
|
||||
case "gemini":
|
||||
if authKind == "apikey" {
|
||||
return "apikey-gemini"
|
||||
}
|
||||
return "gemini"
|
||||
case "codex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "codex"
|
||||
case "claude":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "claude"
|
||||
case "vertex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "vertex"
|
||||
case "antigravity", "qwen", "iflow":
|
||||
return provider
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
@@ -215,6 +215,7 @@ func (b *Builder) Build() (*Service, error) {
|
||||
}
|
||||
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
|
||||
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
|
||||
coreManager.SetGlobalModelNameMappings(b.cfg.ModelNameMappings)
|
||||
|
||||
service := &Service{
|
||||
cfg: b.cfg,
|
||||
|
||||
@@ -552,6 +552,9 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
s.cfgMu.Lock()
|
||||
s.cfg = newCfg
|
||||
s.cfgMu.Unlock()
|
||||
if s.coreManager != nil {
|
||||
s.coreManager.SetGlobalModelNameMappings(newCfg.ModelNameMappings)
|
||||
}
|
||||
s.rebindExecutors()
|
||||
}
|
||||
|
||||
@@ -677,6 +680,11 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
return
|
||||
}
|
||||
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
|
||||
if authKind == "" {
|
||||
if kind, _ := a.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||
authKind = "apikey"
|
||||
}
|
||||
}
|
||||
if a.Attributes != nil {
|
||||
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
@@ -836,6 +844,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
}
|
||||
}
|
||||
}
|
||||
models = applyGlobalModelNameMappings(s.cfg, provider, authKind, models)
|
||||
if len(models) > 0 {
|
||||
key := provider
|
||||
if key == "" {
|
||||
@@ -1145,6 +1154,124 @@ func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
|
||||
return out
|
||||
}
|
||||
|
||||
func globalModelMappingChannel(provider, authKind string) string {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||
switch provider {
|
||||
case "gemini":
|
||||
if authKind == "apikey" {
|
||||
return "apikey-gemini"
|
||||
}
|
||||
return "gemini"
|
||||
case "codex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "codex"
|
||||
case "claude":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "claude"
|
||||
case "vertex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "vertex"
|
||||
case "antigravity", "qwen", "iflow":
|
||||
return provider
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteModelInfoName(name, oldID, newID string) string {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return name
|
||||
}
|
||||
oldID = strings.TrimSpace(oldID)
|
||||
newID = strings.TrimSpace(newID)
|
||||
if oldID == "" || newID == "" {
|
||||
return name
|
||||
}
|
||||
if strings.EqualFold(oldID, newID) {
|
||||
return name
|
||||
}
|
||||
if strings.HasSuffix(trimmed, "/"+oldID) {
|
||||
prefix := strings.TrimSuffix(trimmed, oldID)
|
||||
return prefix + newID
|
||||
}
|
||||
if trimmed == "models/"+oldID {
|
||||
return "models/" + newID
|
||||
}
|
||||
return name
|
||||
}
|
||||
|
||||
func applyGlobalModelNameMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
|
||||
if cfg == nil || len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
channel := globalModelMappingChannel(provider, authKind)
|
||||
if channel == "" || len(cfg.ModelNameMappings) == 0 {
|
||||
return models
|
||||
}
|
||||
mappings := cfg.ModelNameMappings[channel]
|
||||
if len(mappings) == 0 {
|
||||
return models
|
||||
}
|
||||
forward := make(map[string]string, len(mappings))
|
||||
for i := range mappings {
|
||||
from := strings.TrimSpace(mappings[i].From)
|
||||
to := strings.TrimSpace(mappings[i].To)
|
||||
if from == "" || to == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(from, to) {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(from)
|
||||
if _, exists := forward[key]; exists {
|
||||
continue
|
||||
}
|
||||
forward[key] = to
|
||||
}
|
||||
if len(forward) == 0 {
|
||||
return models
|
||||
}
|
||||
out := make([]*ModelInfo, 0, len(models))
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
id := strings.TrimSpace(model.ID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
mappedID := id
|
||||
if to, ok := forward[strings.ToLower(id)]; ok && strings.TrimSpace(to) != "" {
|
||||
mappedID = strings.TrimSpace(to)
|
||||
}
|
||||
uniqueKey := strings.ToLower(mappedID)
|
||||
if _, exists := seen[uniqueKey]; exists {
|
||||
continue
|
||||
}
|
||||
seen[uniqueKey] = struct{}{}
|
||||
if mappedID == id {
|
||||
out = append(out, model)
|
||||
continue
|
||||
}
|
||||
clone := *model
|
||||
clone.ID = mappedID
|
||||
if clone.Name != "" {
|
||||
clone.Name = rewriteModelInfoName(clone.Name, id, mappedID)
|
||||
}
|
||||
out = append(out, &clone)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
|
||||
if entry == nil || len(entry.Models) == 0 {
|
||||
return nil
|
||||
|
||||
Reference in New Issue
Block a user