refactor(thinking): remove legacy utilities and simplify model mapping

This commit is contained in:
hkfires
2026-01-14 19:11:04 +08:00
parent 33d66959e9
commit 2262479365
9 changed files with 43 additions and 865 deletions

View File

@@ -149,53 +149,32 @@ func TestApplyAPIKeyModelMapping(t *testing.T) {
_, _ = mgr.Register(ctx, apiKeyAuth)
tests := []struct {
name string
auth *Auth
inputModel string
wantModel string
wantOriginal string
expectMapping bool
name string
auth *Auth
inputModel string
wantModel string
}{
{
name: "api_key auth with alias",
auth: apiKeyAuth,
inputModel: "g25p(8192)",
wantModel: "gemini-2.5-pro-exp-03-25(8192)",
wantOriginal: "g25p(8192)",
expectMapping: true,
name: "api_key auth with alias",
auth: apiKeyAuth,
inputModel: "g25p(8192)",
wantModel: "gemini-2.5-pro-exp-03-25(8192)",
},
{
name: "oauth auth passthrough",
auth: oauthAuth,
inputModel: "some-model",
wantModel: "some-model",
expectMapping: false,
name: "oauth auth passthrough",
auth: oauthAuth,
inputModel: "some-model",
wantModel: "some-model",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
metadata := map[string]any{"existing": "value"}
resolvedModel, resultMeta := mgr.applyAPIKeyModelMapping(tt.auth, tt.inputModel, metadata)
resolvedModel := mgr.applyAPIKeyModelMapping(tt.auth, tt.inputModel)
if resolvedModel != tt.wantModel {
t.Errorf("model = %q, want %q", resolvedModel, tt.wantModel)
}
if resultMeta["existing"] != "value" {
t.Error("existing metadata not preserved")
}
original, hasOriginal := resultMeta["model_mapping_original_model"].(string)
if tt.expectMapping {
if !hasOriginal || original != tt.wantOriginal {
t.Errorf("original model = %q, want %q", original, tt.wantOriginal)
}
} else {
if hasOriginal {
t.Error("should not set model_mapping_original_model for non-api_key auth")
}
}
})
}
}

View File

@@ -752,9 +752,9 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
execReq.Model, execReq.Metadata = m.applyAPIKeyModelMapping(auth, execReq.Model, execReq.Metadata)
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelMapping(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelMapping(auth, execReq.Model)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -801,9 +801,9 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
execReq.Model, execReq.Metadata = m.applyAPIKeyModelMapping(auth, execReq.Model, execReq.Metadata)
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelMapping(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelMapping(auth, execReq.Model)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
@@ -850,9 +850,9 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
execReq.Model, execReq.Metadata = m.applyAPIKeyModelMapping(auth, execReq.Model, execReq.Metadata)
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelMapping(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelMapping(auth, execReq.Model)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
rerr := &Error{Message: errStream.Error()}
@@ -890,72 +890,39 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
}
}
func rewriteModelForAuth(model string, metadata map[string]any, auth *Auth) (string, map[string]any) {
func rewriteModelForAuth(model string, auth *Auth) string {
if auth == nil || model == "" {
return model, metadata
return model
}
prefix := strings.TrimSpace(auth.Prefix)
if prefix == "" {
return model, metadata
return model
}
needle := prefix + "/"
if !strings.HasPrefix(model, needle) {
return model, metadata
return model
}
rewritten := strings.TrimPrefix(model, needle)
return rewritten, stripPrefixFromMetadata(metadata, needle)
return strings.TrimPrefix(model, needle)
}
func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]any {
if len(metadata) == 0 || needle == "" {
return metadata
}
keys := []string{
util.GeminiOriginalModelMetadataKey,
util.ModelMappingOriginalModelMetadataKey,
}
var out map[string]any
for _, key := range keys {
raw, ok := metadata[key]
if !ok {
continue
}
value, okStr := raw.(string)
if !okStr || !strings.HasPrefix(value, needle) {
continue
}
if out == nil {
out = make(map[string]any, len(metadata))
for k, v := range metadata {
out[k] = v
}
}
out[key] = strings.TrimPrefix(value, needle)
}
if out == nil {
return metadata
}
return out
}
func (m *Manager) applyAPIKeyModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
func (m *Manager) applyAPIKeyModelMapping(auth *Auth, requestedModel string) string {
if m == nil || auth == nil {
return requestedModel, metadata
return requestedModel
}
kind, _ := auth.AccountInfo()
if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
return requestedModel, metadata
return requestedModel
}
requestedModel = strings.TrimSpace(requestedModel)
if requestedModel == "" {
return requestedModel, metadata
return requestedModel
}
// Fast path: lookup per-auth mapping table (keyed by auth.ID).
if resolved := m.lookupAPIKeyUpstreamModel(auth.ID, requestedModel); resolved != "" {
return applyUpstreamModelOverride(requestedModel, resolved, metadata)
return resolved
}
// Slow path: scan config for the matching credential entry and resolve alias.
@@ -980,8 +947,11 @@ func (m *Manager) applyAPIKeyModelMapping(auth *Auth, requestedModel string, met
upstreamModel = resolveUpstreamModelForOpenAICompatAPIKey(cfg, auth, requestedModel)
}
// applyUpstreamModelOverride lives in model_name_mappings.go.
return applyUpstreamModelOverride(requestedModel, upstreamModel, metadata)
// Return upstream model if found, otherwise return requested model.
if upstreamModel != "" {
return upstreamModel
}
return requestedModel
}
// APIKeyConfigEntry is a generic interface for API key configurations.

View File

@@ -5,7 +5,6 @@ import (
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
type modelMappingEntry interface {
@@ -71,31 +70,14 @@ func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.Mod
m.modelNameMappings.Store(table)
}
// applyOAuthModelMapping resolves the upstream model from OAuth model mappings
// and returns the resolved model along with updated metadata. If a mapping exists,
// the returned model is the upstream model and metadata contains the original
// requested model for response translation.
func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
// applyOAuthModelMapping resolves the upstream model from OAuth model mappings.
// If a mapping exists, the returned model is the upstream model.
func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string) string {
upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel)
return applyUpstreamModelOverride(requestedModel, upstreamModel, metadata)
}
func applyUpstreamModelOverride(requestedModel, upstreamModel string, metadata map[string]any) (string, map[string]any) {
if upstreamModel == "" {
return requestedModel, metadata
return requestedModel
}
out := make(map[string]any, 1)
if len(metadata) > 0 {
out = make(map[string]any, len(metadata)+1)
for k, v := range metadata {
out[k] = v
}
}
// Preserve the original client model string (including any suffix) for downstream.
out[util.ModelMappingOriginalModelMetadataKey] = requestedModel
return upstreamModel, out
return upstreamModel
}
func resolveModelAliasFromConfigModels(requestedModel string, models []modelMappingEntry) string {

View File

@@ -169,19 +169,9 @@ func TestApplyOAuthModelMapping_SuffixPreservation(t *testing.T) {
mgr.SetOAuthModelMappings(mappings)
auth := &Auth{ID: "test-auth-id", Provider: "gemini-cli"}
metadata := map[string]any{"existing": "value"}
resolvedModel, resultMeta := mgr.applyOAuthModelMapping(auth, "gemini-2.5-pro(8192)", metadata)
resolvedModel := mgr.applyOAuthModelMapping(auth, "gemini-2.5-pro(8192)")
if resolvedModel != "gemini-2.5-pro-exp-03-25(8192)" {
t.Errorf("applyOAuthModelMapping() model = %q, want %q", resolvedModel, "gemini-2.5-pro-exp-03-25(8192)")
}
originalModel, ok := resultMeta["model_mapping_original_model"].(string)
if !ok || originalModel != "gemini-2.5-pro(8192)" {
t.Errorf("applyOAuthModelMapping() metadata[model_mapping_original_model] = %v, want %q", resultMeta["model_mapping_original_model"], "gemini-2.5-pro(8192)")
}
if resultMeta["existing"] != "value" {
t.Errorf("applyOAuthModelMapping() metadata[existing] = %v, want %q", resultMeta["existing"], "value")
}
}