mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 04:20:50 +08:00
Merge pull request #1033 from router-for-me/reasoning
Refactor thinking
This commit is contained in:
@@ -16,6 +16,7 @@ import (
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -379,7 +380,7 @@ func appendAPIResponse(c *gin.Context, data []byte) {
|
||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
@@ -388,16 +389,13 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
req.Metadata = cloned
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: false,
|
||||
Alt: alt,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
opts.Metadata = reqMeta
|
||||
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -420,7 +418,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
@@ -429,16 +427,13 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
req.Metadata = cloned
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: false,
|
||||
Alt: alt,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
opts.Metadata = reqMeta
|
||||
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -461,7 +456,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) {
|
||||
providers, normalizedModel, metadata, errMsg := h.getRequestDetails(modelName)
|
||||
providers, normalizedModel, errMsg := h.getRequestDetails(modelName)
|
||||
if errMsg != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
errChan <- errMsg
|
||||
@@ -473,16 +468,13 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
req.Metadata = cloned
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: true,
|
||||
Alt: alt,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
opts.Metadata = reqMeta
|
||||
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
@@ -595,38 +587,40 @@ func statusFromError(err error) int {
|
||||
return 0
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
|
||||
// Normalize the model name to handle dynamic thinking suffixes before determining the provider.
|
||||
normalizedModel, metadata = normalizeModelMetadata(resolvedModelName)
|
||||
|
||||
// Use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
}
|
||||
}
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, err *interfaces.ErrorMessage) {
|
||||
resolvedModelName := modelName
|
||||
initialSuffix := thinking.ParseSuffix(modelName)
|
||||
if initialSuffix.ModelName == "auto" {
|
||||
resolvedBase := util.ResolveAutoModel(initialSuffix.ModelName)
|
||||
if initialSuffix.HasSuffix {
|
||||
resolvedModelName = fmt.Sprintf("%s(%s)", resolvedBase, initialSuffix.RawSuffix)
|
||||
} else {
|
||||
resolvedModelName = resolvedBase
|
||||
}
|
||||
} else {
|
||||
resolvedModelName = util.ResolveAutoModel(modelName)
|
||||
}
|
||||
|
||||
parsed := thinking.ParseSuffix(resolvedModelName)
|
||||
baseModel := strings.TrimSpace(parsed.ModelName)
|
||||
|
||||
providers = util.GetProviderName(baseModel)
|
||||
// Fallback: if baseModel has no provider but differs from resolvedModelName,
|
||||
// try using the full model name. This handles edge cases where custom models
|
||||
// may be registered with their full suffixed name (e.g., "my-model(8192)").
|
||||
// Evaluated in Story 11.8: This fallback is intentionally preserved to support
|
||||
// custom model registrations that include thinking suffixes.
|
||||
if len(providers) == 0 && baseModel != resolvedModelName {
|
||||
providers = util.GetProviderName(resolvedModelName)
|
||||
}
|
||||
|
||||
if len(providers) == 0 {
|
||||
return nil, "", nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
return nil, "", &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||
}
|
||||
|
||||
// If it's a dynamic model, the normalizedModel was already set to extractedModelName.
|
||||
// If it's a non-dynamic model, normalizedModel was set by normalizeModelMetadata.
|
||||
// So, normalizedModel is already correctly set at this point.
|
||||
|
||||
return providers, normalizedModel, metadata, nil
|
||||
// The thinking suffix is preserved in the model name itself, so no
|
||||
// metadata-based configuration passing is needed.
|
||||
return providers, resolvedModelName, nil
|
||||
}
|
||||
|
||||
func cloneBytes(src []byte) []byte {
|
||||
@@ -638,10 +632,6 @@ func cloneBytes(src []byte) []byte {
|
||||
return dst
|
||||
}
|
||||
|
||||
func normalizeModelMetadata(modelName string) (string, map[string]any) {
|
||||
return util.NormalizeThinkingModel(modelName)
|
||||
}
|
||||
|
||||
func cloneMetadata(src map[string]any) map[string]any {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
|
||||
118
sdk/api/handlers/handlers_request_details_test.go
Normal file
118
sdk/api/handlers/handlers_request_details_test.go
Normal file
@@ -0,0 +1,118 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestGetRequestDetails_PreservesSuffix(t *testing.T) {
|
||||
modelRegistry := registry.GetGlobalRegistry()
|
||||
now := time.Now().Unix()
|
||||
|
||||
modelRegistry.RegisterClient("test-request-details-gemini", "gemini", []*registry.ModelInfo{
|
||||
{ID: "gemini-2.5-pro", Created: now + 30},
|
||||
{ID: "gemini-2.5-flash", Created: now + 25},
|
||||
})
|
||||
modelRegistry.RegisterClient("test-request-details-openai", "openai", []*registry.ModelInfo{
|
||||
{ID: "gpt-5.2", Created: now + 20},
|
||||
})
|
||||
modelRegistry.RegisterClient("test-request-details-claude", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4-5", Created: now + 5},
|
||||
})
|
||||
|
||||
// Ensure cleanup of all test registrations.
|
||||
clientIDs := []string{
|
||||
"test-request-details-gemini",
|
||||
"test-request-details-openai",
|
||||
"test-request-details-claude",
|
||||
}
|
||||
for _, clientID := range clientIDs {
|
||||
id := clientID
|
||||
t.Cleanup(func() {
|
||||
modelRegistry.UnregisterClient(id)
|
||||
})
|
||||
}
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, coreauth.NewManager(nil, nil, nil))
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
inputModel string
|
||||
wantProviders []string
|
||||
wantModel string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "numeric suffix preserved",
|
||||
inputModel: "gemini-2.5-pro(8192)",
|
||||
wantProviders: []string{"gemini"},
|
||||
wantModel: "gemini-2.5-pro(8192)",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "level suffix preserved",
|
||||
inputModel: "gpt-5.2(high)",
|
||||
wantProviders: []string{"openai"},
|
||||
wantModel: "gpt-5.2(high)",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "no suffix unchanged",
|
||||
inputModel: "claude-sonnet-4-5",
|
||||
wantProviders: []string{"claude"},
|
||||
wantModel: "claude-sonnet-4-5",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "unknown model with suffix",
|
||||
inputModel: "unknown-model(8192)",
|
||||
wantProviders: nil,
|
||||
wantModel: "",
|
||||
wantErr: true,
|
||||
},
|
||||
{
|
||||
name: "auto suffix resolved",
|
||||
inputModel: "auto(high)",
|
||||
wantProviders: []string{"gemini"},
|
||||
wantModel: "gemini-2.5-pro(high)",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "special suffix none preserved",
|
||||
inputModel: "gemini-2.5-flash(none)",
|
||||
wantProviders: []string{"gemini"},
|
||||
wantModel: "gemini-2.5-flash(none)",
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "special suffix auto preserved",
|
||||
inputModel: "claude-sonnet-4-5(auto)",
|
||||
wantProviders: []string{"claude"},
|
||||
wantModel: "claude-sonnet-4-5(auto)",
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
providers, model, errMsg := handler.getRequestDetails(tt.inputModel)
|
||||
if (errMsg != nil) != tt.wantErr {
|
||||
t.Fatalf("getRequestDetails() error = %v, wantErr %v", errMsg, tt.wantErr)
|
||||
}
|
||||
if errMsg != nil {
|
||||
return
|
||||
}
|
||||
if !reflect.DeepEqual(providers, tt.wantProviders) {
|
||||
t.Fatalf("getRequestDetails() providers = %v, want %v", providers, tt.wantProviders)
|
||||
}
|
||||
if model != tt.wantModel {
|
||||
t.Fatalf("getRequestDetails() model = %v, want %v", model, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
180
sdk/cliproxy/auth/api_key_model_alias_test.go
Normal file
180
sdk/cliproxy/auth/api_key_model_alias_test.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestLookupAPIKeyUpstreamModel(t *testing.T) {
|
||||
cfg := &internalconfig.Config{
|
||||
GeminiKey: []internalconfig.GeminiKey{
|
||||
{
|
||||
APIKey: "k",
|
||||
BaseURL: "https://example.com",
|
||||
Models: []internalconfig.GeminiModel{
|
||||
{Name: "gemini-2.5-pro-exp-03-25", Alias: "g25p"},
|
||||
{Name: "gemini-2.5-flash(low)", Alias: "g25f"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr := NewManager(nil, nil, nil)
|
||||
mgr.SetConfig(cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
_, _ = mgr.Register(ctx, &Auth{ID: "a1", Provider: "gemini", Attributes: map[string]string{"api_key": "k", "base_url": "https://example.com"}})
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
authID string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
// Fast path + suffix preservation
|
||||
{"alias with suffix", "a1", "g25p(8192)", "gemini-2.5-pro-exp-03-25(8192)"},
|
||||
{"alias without suffix", "a1", "g25p", "gemini-2.5-pro-exp-03-25"},
|
||||
|
||||
// Config suffix takes priority
|
||||
{"config suffix priority", "a1", "g25f(high)", "gemini-2.5-flash(low)"},
|
||||
{"config suffix no user suffix", "a1", "g25f", "gemini-2.5-flash(low)"},
|
||||
|
||||
// Case insensitive
|
||||
{"uppercase alias", "a1", "G25P", "gemini-2.5-pro-exp-03-25"},
|
||||
{"mixed case with suffix", "a1", "G25p(4096)", "gemini-2.5-pro-exp-03-25(4096)"},
|
||||
|
||||
// Direct name lookup
|
||||
{"upstream name direct", "a1", "gemini-2.5-pro-exp-03-25", "gemini-2.5-pro-exp-03-25"},
|
||||
{"upstream name with suffix", "a1", "gemini-2.5-pro-exp-03-25(8192)", "gemini-2.5-pro-exp-03-25(8192)"},
|
||||
|
||||
// Cache miss scenarios
|
||||
{"non-existent auth", "non-existent", "g25p", ""},
|
||||
{"unknown alias", "a1", "unknown-alias", ""},
|
||||
{"empty auth ID", "", "g25p", ""},
|
||||
{"empty model", "a1", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resolved := mgr.lookupAPIKeyUpstreamModel(tt.authID, tt.input)
|
||||
if resolved != tt.want {
|
||||
t.Errorf("lookupAPIKeyUpstreamModel(%q, %q) = %q, want %q", tt.authID, tt.input, resolved, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyModelAlias_ConfigHotReload(t *testing.T) {
|
||||
cfg := &internalconfig.Config{
|
||||
GeminiKey: []internalconfig.GeminiKey{
|
||||
{
|
||||
APIKey: "k",
|
||||
Models: []internalconfig.GeminiModel{{Name: "gemini-2.5-pro-exp-03-25", Alias: "g25p"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
mgr := NewManager(nil, nil, nil)
|
||||
mgr.SetConfig(cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
_, _ = mgr.Register(ctx, &Auth{ID: "a1", Provider: "gemini", Attributes: map[string]string{"api_key": "k"}})
|
||||
|
||||
// Initial alias
|
||||
if resolved := mgr.lookupAPIKeyUpstreamModel("a1", "g25p"); resolved != "gemini-2.5-pro-exp-03-25" {
|
||||
t.Fatalf("before reload: got %q, want %q", resolved, "gemini-2.5-pro-exp-03-25")
|
||||
}
|
||||
|
||||
// Hot reload with new alias
|
||||
mgr.SetConfig(&internalconfig.Config{
|
||||
GeminiKey: []internalconfig.GeminiKey{
|
||||
{
|
||||
APIKey: "k",
|
||||
Models: []internalconfig.GeminiModel{{Name: "gemini-2.5-flash", Alias: "g25p"}},
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
// New alias should take effect
|
||||
if resolved := mgr.lookupAPIKeyUpstreamModel("a1", "g25p"); resolved != "gemini-2.5-flash" {
|
||||
t.Fatalf("after reload: got %q, want %q", resolved, "gemini-2.5-flash")
|
||||
}
|
||||
}
|
||||
|
||||
func TestAPIKeyModelAlias_MultipleProviders(t *testing.T) {
|
||||
cfg := &internalconfig.Config{
|
||||
GeminiKey: []internalconfig.GeminiKey{{APIKey: "gemini-key", Models: []internalconfig.GeminiModel{{Name: "gemini-2.5-pro", Alias: "gp"}}}},
|
||||
ClaudeKey: []internalconfig.ClaudeKey{{APIKey: "claude-key", Models: []internalconfig.ClaudeModel{{Name: "claude-sonnet-4", Alias: "cs4"}}}},
|
||||
CodexKey: []internalconfig.CodexKey{{APIKey: "codex-key", Models: []internalconfig.CodexModel{{Name: "o3", Alias: "o"}}}},
|
||||
}
|
||||
|
||||
mgr := NewManager(nil, nil, nil)
|
||||
mgr.SetConfig(cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
_, _ = mgr.Register(ctx, &Auth{ID: "gemini-auth", Provider: "gemini", Attributes: map[string]string{"api_key": "gemini-key"}})
|
||||
_, _ = mgr.Register(ctx, &Auth{ID: "claude-auth", Provider: "claude", Attributes: map[string]string{"api_key": "claude-key"}})
|
||||
_, _ = mgr.Register(ctx, &Auth{ID: "codex-auth", Provider: "codex", Attributes: map[string]string{"api_key": "codex-key"}})
|
||||
|
||||
tests := []struct {
|
||||
authID, input, want string
|
||||
}{
|
||||
{"gemini-auth", "gp", "gemini-2.5-pro"},
|
||||
{"claude-auth", "cs4", "claude-sonnet-4"},
|
||||
{"codex-auth", "o", "o3"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
if resolved := mgr.lookupAPIKeyUpstreamModel(tt.authID, tt.input); resolved != tt.want {
|
||||
t.Errorf("lookupAPIKeyUpstreamModel(%q, %q) = %q, want %q", tt.authID, tt.input, resolved, tt.want)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAPIKeyModelAlias(t *testing.T) {
|
||||
cfg := &internalconfig.Config{
|
||||
GeminiKey: []internalconfig.GeminiKey{
|
||||
{APIKey: "k", Models: []internalconfig.GeminiModel{{Name: "gemini-2.5-pro-exp-03-25", Alias: "g25p"}}},
|
||||
},
|
||||
}
|
||||
|
||||
mgr := NewManager(nil, nil, nil)
|
||||
mgr.SetConfig(cfg)
|
||||
|
||||
ctx := context.Background()
|
||||
apiKeyAuth := &Auth{ID: "a1", Provider: "gemini", Attributes: map[string]string{"api_key": "k"}}
|
||||
oauthAuth := &Auth{ID: "oauth-auth", Provider: "gemini", Attributes: map[string]string{"auth_kind": "oauth"}}
|
||||
_, _ = mgr.Register(ctx, apiKeyAuth)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
auth *Auth
|
||||
inputModel string
|
||||
wantModel string
|
||||
}{
|
||||
{
|
||||
name: "api_key auth with alias",
|
||||
auth: apiKeyAuth,
|
||||
inputModel: "g25p(8192)",
|
||||
wantModel: "gemini-2.5-pro-exp-03-25(8192)",
|
||||
},
|
||||
{
|
||||
name: "oauth auth passthrough",
|
||||
auth: oauthAuth,
|
||||
inputModel: "some-model",
|
||||
wantModel: "some-model",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
resolvedModel := mgr.applyAPIKeyModelAlias(tt.auth, tt.inputModel)
|
||||
|
||||
if resolvedModel != tt.wantModel {
|
||||
t.Errorf("model = %q, want %q", resolvedModel, tt.wantModel)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -15,8 +15,10 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -117,8 +119,16 @@ type Manager struct {
|
||||
requestRetry atomic.Int32
|
||||
maxRetryInterval atomic.Int64
|
||||
|
||||
// modelNameMappings stores global model name alias mappings (alias -> upstream name) keyed by channel.
|
||||
modelNameMappings atomic.Value
|
||||
// oauthModelAlias stores global OAuth model alias mappings (alias -> upstream name) keyed by channel.
|
||||
oauthModelAlias atomic.Value
|
||||
|
||||
// apiKeyModelAlias caches resolved model alias mappings for API-key auths.
|
||||
// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
|
||||
apiKeyModelAlias atomic.Value
|
||||
|
||||
// runtimeConfig stores the latest application config for request-time decisions.
|
||||
// It is initialized in NewManager; never Load() before first Store().
|
||||
runtimeConfig atomic.Value
|
||||
|
||||
// Optional HTTP RoundTripper provider injected by host.
|
||||
rtProvider RoundTripperProvider
|
||||
@@ -135,7 +145,7 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
||||
if hook == nil {
|
||||
hook = NoopHook{}
|
||||
}
|
||||
return &Manager{
|
||||
manager := &Manager{
|
||||
store: store,
|
||||
executors: make(map[string]ProviderExecutor),
|
||||
selector: selector,
|
||||
@@ -143,6 +153,10 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
||||
auths: make(map[string]*Auth),
|
||||
providerOffsets: make(map[string]int),
|
||||
}
|
||||
// atomic.Value requires non-nil initial value.
|
||||
manager.runtimeConfig.Store(&internalconfig.Config{})
|
||||
manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))
|
||||
return manager
|
||||
}
|
||||
|
||||
func (m *Manager) SetSelector(selector Selector) {
|
||||
@@ -171,6 +185,181 @@ func (m *Manager) SetRoundTripperProvider(p RoundTripperProvider) {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetConfig updates the runtime config snapshot used by request-time helpers.
|
||||
// Callers should provide the latest config on reload so per-credential alias mapping stays in sync.
|
||||
func (m *Manager) SetConfig(cfg *internalconfig.Config) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = &internalconfig.Config{}
|
||||
}
|
||||
m.runtimeConfig.Store(cfg)
|
||||
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
||||
}
|
||||
|
||||
func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) string {
|
||||
if m == nil {
|
||||
return ""
|
||||
}
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return ""
|
||||
}
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if requestedModel == "" {
|
||||
return ""
|
||||
}
|
||||
table, _ := m.apiKeyModelAlias.Load().(apiKeyModelAliasTable)
|
||||
if table == nil {
|
||||
return ""
|
||||
}
|
||||
byAlias := table[authID]
|
||||
if len(byAlias) == 0 {
|
||||
return ""
|
||||
}
|
||||
key := strings.ToLower(thinking.ParseSuffix(requestedModel).ModelName)
|
||||
if key == "" {
|
||||
key = strings.ToLower(requestedModel)
|
||||
}
|
||||
resolved := strings.TrimSpace(byAlias[key])
|
||||
if resolved == "" {
|
||||
return ""
|
||||
}
|
||||
// Preserve thinking suffix from the client's requested model unless config already has one.
|
||||
requestResult := thinking.ParseSuffix(requestedModel)
|
||||
if thinking.ParseSuffix(resolved).HasSuffix {
|
||||
return resolved
|
||||
}
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return resolved + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
return resolved
|
||||
|
||||
}
|
||||
|
||||
func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
|
||||
if cfg == nil {
|
||||
cfg = &internalconfig.Config{}
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.rebuildAPIKeyModelAliasLocked(cfg)
|
||||
}
|
||||
|
||||
func (m *Manager) rebuildAPIKeyModelAliasLocked(cfg *internalconfig.Config) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if cfg == nil {
|
||||
cfg = &internalconfig.Config{}
|
||||
}
|
||||
|
||||
out := make(apiKeyModelAliasTable)
|
||||
for _, auth := range m.auths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
if strings.TrimSpace(auth.ID) == "" {
|
||||
continue
|
||||
}
|
||||
kind, _ := auth.AccountInfo()
|
||||
if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
|
||||
continue
|
||||
}
|
||||
|
||||
byAlias := make(map[string]string)
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
switch provider {
|
||||
case "gemini":
|
||||
if entry := resolveGeminiAPIKeyConfig(cfg, auth); entry != nil {
|
||||
compileAPIKeyModelAliasForModels(byAlias, entry.Models)
|
||||
}
|
||||
case "claude":
|
||||
if entry := resolveClaudeAPIKeyConfig(cfg, auth); entry != nil {
|
||||
compileAPIKeyModelAliasForModels(byAlias, entry.Models)
|
||||
}
|
||||
case "codex":
|
||||
if entry := resolveCodexAPIKeyConfig(cfg, auth); entry != nil {
|
||||
compileAPIKeyModelAliasForModels(byAlias, entry.Models)
|
||||
}
|
||||
case "vertex":
|
||||
if entry := resolveVertexAPIKeyConfig(cfg, auth); entry != nil {
|
||||
compileAPIKeyModelAliasForModels(byAlias, entry.Models)
|
||||
}
|
||||
default:
|
||||
// OpenAI-compat uses config selection from auth.Attributes.
|
||||
providerKey := ""
|
||||
compatName := ""
|
||||
if auth.Attributes != nil {
|
||||
providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
|
||||
compatName = strings.TrimSpace(auth.Attributes["compat_name"])
|
||||
}
|
||||
if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
|
||||
if entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider); entry != nil {
|
||||
compileAPIKeyModelAliasForModels(byAlias, entry.Models)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if len(byAlias) > 0 {
|
||||
out[auth.ID] = byAlias
|
||||
}
|
||||
}
|
||||
|
||||
m.apiKeyModelAlias.Store(out)
|
||||
}
|
||||
|
||||
func compileAPIKeyModelAliasForModels[T interface {
|
||||
GetName() string
|
||||
GetAlias() string
|
||||
}](out map[string]string, models []T) {
|
||||
if out == nil {
|
||||
return
|
||||
}
|
||||
for i := range models {
|
||||
alias := strings.TrimSpace(models[i].GetAlias())
|
||||
name := strings.TrimSpace(models[i].GetName())
|
||||
if alias == "" || name == "" {
|
||||
continue
|
||||
}
|
||||
aliasKey := strings.ToLower(thinking.ParseSuffix(alias).ModelName)
|
||||
if aliasKey == "" {
|
||||
aliasKey = strings.ToLower(alias)
|
||||
}
|
||||
// Config priority: first alias wins.
|
||||
if _, exists := out[aliasKey]; exists {
|
||||
continue
|
||||
}
|
||||
out[aliasKey] = name
|
||||
// Also allow direct lookup by upstream name (case-insensitive), so lookups on already-upstream
|
||||
// models remain a cheap no-op.
|
||||
nameKey := strings.ToLower(thinking.ParseSuffix(name).ModelName)
|
||||
if nameKey == "" {
|
||||
nameKey = strings.ToLower(name)
|
||||
}
|
||||
if nameKey != "" {
|
||||
if _, exists := out[nameKey]; !exists {
|
||||
out[nameKey] = name
|
||||
}
|
||||
}
|
||||
// Preserve config suffix priority by seeding a base-name lookup when name already has suffix.
|
||||
nameResult := thinking.ParseSuffix(name)
|
||||
if nameResult.HasSuffix {
|
||||
baseKey := strings.ToLower(strings.TrimSpace(nameResult.ModelName))
|
||||
if baseKey != "" {
|
||||
if _, exists := out[baseKey]; !exists {
|
||||
out[baseKey] = name
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// SetRetryConfig updates retry attempts and cooldown wait interval.
|
||||
func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration) {
|
||||
if m == nil {
|
||||
@@ -219,6 +408,7 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
m.mu.Lock()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
m.mu.Unlock()
|
||||
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthRegistered(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
@@ -237,6 +427,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
auth.EnsureIndex()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
m.mu.Unlock()
|
||||
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthUpdated(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
@@ -261,6 +452,11 @@ func (m *Manager) Load(ctx context.Context) error {
|
||||
auth.EnsureIndex()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
}
|
||||
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
|
||||
if cfg == nil {
|
||||
cfg = &internalconfig.Config{}
|
||||
}
|
||||
m.rebuildAPIKeyModelAliasLocked(cfg)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -395,8 +591,9 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
@@ -443,8 +640,9 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
@@ -491,8 +689,9 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
@@ -556,8 +755,9 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
@@ -604,8 +804,9 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
@@ -652,8 +853,9 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
||||
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
||||
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
@@ -691,51 +893,229 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteModelForAuth(model string, metadata map[string]any, auth *Auth) (string, map[string]any) {
|
||||
func rewriteModelForAuth(model string, auth *Auth) string {
|
||||
if auth == nil || model == "" {
|
||||
return model, metadata
|
||||
return model
|
||||
}
|
||||
prefix := strings.TrimSpace(auth.Prefix)
|
||||
if prefix == "" {
|
||||
return model, metadata
|
||||
return model
|
||||
}
|
||||
needle := prefix + "/"
|
||||
if !strings.HasPrefix(model, needle) {
|
||||
return model, metadata
|
||||
return model
|
||||
}
|
||||
rewritten := strings.TrimPrefix(model, needle)
|
||||
return rewritten, stripPrefixFromMetadata(metadata, needle)
|
||||
return strings.TrimPrefix(model, needle)
|
||||
}
|
||||
|
||||
func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]any {
|
||||
if len(metadata) == 0 || needle == "" {
|
||||
return metadata
|
||||
func (m *Manager) applyAPIKeyModelAlias(auth *Auth, requestedModel string) string {
|
||||
if m == nil || auth == nil {
|
||||
return requestedModel
|
||||
}
|
||||
keys := []string{
|
||||
util.ThinkingOriginalModelMetadataKey,
|
||||
util.GeminiOriginalModelMetadataKey,
|
||||
util.ModelMappingOriginalModelMetadataKey,
|
||||
|
||||
kind, _ := auth.AccountInfo()
|
||||
if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
|
||||
return requestedModel
|
||||
}
|
||||
var out map[string]any
|
||||
for _, key := range keys {
|
||||
raw, ok := metadata[key]
|
||||
if !ok {
|
||||
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if requestedModel == "" {
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// Fast path: lookup per-auth mapping table (keyed by auth.ID).
|
||||
if resolved := m.lookupAPIKeyUpstreamModel(auth.ID, requestedModel); resolved != "" {
|
||||
return resolved
|
||||
}
|
||||
|
||||
// Slow path: scan config for the matching credential entry and resolve alias.
|
||||
// This acts as a safety net if mappings are stale or auth.ID is missing.
|
||||
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
|
||||
if cfg == nil {
|
||||
cfg = &internalconfig.Config{}
|
||||
}
|
||||
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
upstreamModel := ""
|
||||
switch provider {
|
||||
case "gemini":
|
||||
upstreamModel = resolveUpstreamModelForGeminiAPIKey(cfg, auth, requestedModel)
|
||||
case "claude":
|
||||
upstreamModel = resolveUpstreamModelForClaudeAPIKey(cfg, auth, requestedModel)
|
||||
case "codex":
|
||||
upstreamModel = resolveUpstreamModelForCodexAPIKey(cfg, auth, requestedModel)
|
||||
case "vertex":
|
||||
upstreamModel = resolveUpstreamModelForVertexAPIKey(cfg, auth, requestedModel)
|
||||
default:
|
||||
upstreamModel = resolveUpstreamModelForOpenAICompatAPIKey(cfg, auth, requestedModel)
|
||||
}
|
||||
|
||||
// Return upstream model if found, otherwise return requested model.
|
||||
if upstreamModel != "" {
|
||||
return upstreamModel
|
||||
}
|
||||
return requestedModel
|
||||
}
|
||||
|
||||
// APIKeyConfigEntry is a generic interface for API key configurations.
|
||||
type APIKeyConfigEntry interface {
|
||||
GetAPIKey() string
|
||||
GetBaseURL() string
|
||||
}
|
||||
|
||||
func resolveAPIKeyConfig[T APIKeyConfigEntry](entries []T, auth *Auth) *T {
|
||||
if auth == nil || len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
attrKey, attrBase := "", ""
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range entries {
|
||||
entry := &entries[i]
|
||||
cfgKey := strings.TrimSpace((*entry).GetAPIKey())
|
||||
cfgBase := strings.TrimSpace((*entry).GetBaseURL())
|
||||
if attrKey != "" && attrBase != "" {
|
||||
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
value, okStr := raw.(string)
|
||||
if !okStr || !strings.HasPrefix(value, needle) {
|
||||
continue
|
||||
}
|
||||
if out == nil {
|
||||
out = make(map[string]any, len(metadata))
|
||||
for k, v := range metadata {
|
||||
out[k] = v
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
out[key] = strings.TrimPrefix(value, needle)
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if out == nil {
|
||||
return metadata
|
||||
if attrKey != "" {
|
||||
for i := range entries {
|
||||
entry := &entries[i]
|
||||
if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func resolveGeminiAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.GeminiKey {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return resolveAPIKeyConfig(cfg.GeminiKey, auth)
|
||||
}
|
||||
|
||||
func resolveClaudeAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.ClaudeKey {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return resolveAPIKeyConfig(cfg.ClaudeKey, auth)
|
||||
}
|
||||
|
||||
func resolveCodexAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.CodexKey {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return resolveAPIKeyConfig(cfg.CodexKey, auth)
|
||||
}
|
||||
|
||||
func resolveVertexAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.VertexCompatKey {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
return resolveAPIKeyConfig(cfg.VertexCompatAPIKey, auth)
|
||||
}
|
||||
|
||||
func resolveUpstreamModelForGeminiAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
|
||||
entry := resolveGeminiAPIKeyConfig(cfg, auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
||||
}
|
||||
|
||||
func resolveUpstreamModelForClaudeAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
|
||||
entry := resolveClaudeAPIKeyConfig(cfg, auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
||||
}
|
||||
|
||||
func resolveUpstreamModelForCodexAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
|
||||
entry := resolveCodexAPIKeyConfig(cfg, auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
||||
}
|
||||
|
||||
func resolveUpstreamModelForVertexAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
|
||||
entry := resolveVertexAPIKeyConfig(cfg, auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
||||
}
|
||||
|
||||
func resolveUpstreamModelForOpenAICompatAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
|
||||
providerKey := ""
|
||||
compatName := ""
|
||||
if auth != nil && len(auth.Attributes) > 0 {
|
||||
providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
|
||||
compatName = strings.TrimSpace(auth.Attributes["compat_name"])
|
||||
}
|
||||
if compatName == "" && !strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
|
||||
return ""
|
||||
}
|
||||
entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
|
||||
if entry == nil {
|
||||
return ""
|
||||
}
|
||||
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
||||
}
|
||||
|
||||
type apiKeyModelAliasTable map[string]map[string]string
|
||||
|
||||
func resolveOpenAICompatConfig(cfg *internalconfig.Config, providerKey, compatName, authProvider string) *internalconfig.OpenAICompatibility {
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
candidates := make([]string, 0, 3)
|
||||
if v := strings.TrimSpace(compatName); v != "" {
|
||||
candidates = append(candidates, v)
|
||||
}
|
||||
if v := strings.TrimSpace(providerKey); v != "" {
|
||||
candidates = append(candidates, v)
|
||||
}
|
||||
if v := strings.TrimSpace(authProvider); v != "" {
|
||||
candidates = append(candidates, v)
|
||||
}
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
compat := &cfg.OpenAICompatibility[i]
|
||||
for _, candidate := range candidates {
|
||||
if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
|
||||
return compat
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func asModelAliasEntries[T interface {
|
||||
GetName() string
|
||||
GetAlias() string
|
||||
}](models []T) []modelAliasEntry {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]modelAliasEntry, 0, len(models))
|
||||
for i := range models {
|
||||
out = append(out, models[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
@@ -1304,6 +1684,13 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
||||
}
|
||||
candidates := make([]*Auth, 0, len(m.auths))
|
||||
modelKey := strings.TrimSpace(model)
|
||||
// Always use base model name (without thinking suffix) for auth matching.
|
||||
if modelKey != "" {
|
||||
parsed := thinking.ParseSuffix(modelKey)
|
||||
if parsed.ModelName != "" {
|
||||
modelKey = strings.TrimSpace(parsed.ModelName)
|
||||
}
|
||||
}
|
||||
registryRef := registry.GetGlobalRegistry()
|
||||
for _, candidate := range m.auths {
|
||||
if candidate.Provider != provider || candidate.Disabled {
|
||||
@@ -1359,6 +1746,13 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
|
||||
m.mu.RLock()
|
||||
candidates := make([]*Auth, 0, len(m.auths))
|
||||
modelKey := strings.TrimSpace(model)
|
||||
// Always use base model name (without thinking suffix) for auth matching.
|
||||
if modelKey != "" {
|
||||
parsed := thinking.ParseSuffix(modelKey)
|
||||
if parsed.ModelName != "" {
|
||||
modelKey = strings.TrimSpace(parsed.ModelName)
|
||||
}
|
||||
}
|
||||
registryRef := registry.GetGlobalRegistry()
|
||||
for _, candidate := range m.auths {
|
||||
if candidate == nil || candidate.Disabled {
|
||||
|
||||
@@ -1,171 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
type modelNameMappingTable struct {
|
||||
// reverse maps channel -> alias (lower) -> original upstream model name.
|
||||
reverse map[string]map[string]string
|
||||
}
|
||||
|
||||
func compileModelNameMappingTable(mappings map[string][]internalconfig.ModelNameMapping) *modelNameMappingTable {
|
||||
if len(mappings) == 0 {
|
||||
return &modelNameMappingTable{}
|
||||
}
|
||||
out := &modelNameMappingTable{
|
||||
reverse: make(map[string]map[string]string, len(mappings)),
|
||||
}
|
||||
for rawChannel, entries := range mappings {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
rev := make(map[string]string, len(entries))
|
||||
for _, entry := range entries {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
aliasKey := strings.ToLower(alias)
|
||||
if _, exists := rev[aliasKey]; exists {
|
||||
continue
|
||||
}
|
||||
rev[aliasKey] = name
|
||||
}
|
||||
if len(rev) > 0 {
|
||||
out.reverse[channel] = rev
|
||||
}
|
||||
}
|
||||
if len(out.reverse) == 0 {
|
||||
out.reverse = nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SetOAuthModelMappings updates the OAuth model name mapping table used during execution.
|
||||
// The mapping is applied per-auth channel to resolve the upstream model name while keeping the
|
||||
// client-visible model name unchanged for translation/response formatting.
|
||||
func (m *Manager) SetOAuthModelMappings(mappings map[string][]internalconfig.ModelNameMapping) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
table := compileModelNameMappingTable(mappings)
|
||||
// atomic.Value requires non-nil store values.
|
||||
if table == nil {
|
||||
table = &modelNameMappingTable{}
|
||||
}
|
||||
m.modelNameMappings.Store(table)
|
||||
}
|
||||
|
||||
// applyOAuthModelMapping resolves the upstream model from OAuth model mappings
|
||||
// and returns the resolved model along with updated metadata. If a mapping exists,
|
||||
// the returned model is the upstream model and metadata contains the original
|
||||
// requested model for response translation.
|
||||
func (m *Manager) applyOAuthModelMapping(auth *Auth, requestedModel string, metadata map[string]any) (string, map[string]any) {
|
||||
upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel)
|
||||
if upstreamModel == "" {
|
||||
return requestedModel, metadata
|
||||
}
|
||||
out := make(map[string]any, 1)
|
||||
if len(metadata) > 0 {
|
||||
out = make(map[string]any, len(metadata)+1)
|
||||
for k, v := range metadata {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
// Store the requested alias (e.g., "gp") so downstream can use it to look up
|
||||
// model metadata from the global registry where it was registered under this alias.
|
||||
out[util.ModelMappingOriginalModelMetadataKey] = requestedModel
|
||||
return upstreamModel, out
|
||||
}
|
||||
|
||||
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
|
||||
if m == nil || auth == nil {
|
||||
return ""
|
||||
}
|
||||
channel := modelMappingChannel(auth)
|
||||
if channel == "" {
|
||||
return ""
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||
if key == "" {
|
||||
return ""
|
||||
}
|
||||
raw := m.modelNameMappings.Load()
|
||||
table, _ := raw.(*modelNameMappingTable)
|
||||
if table == nil || table.reverse == nil {
|
||||
return ""
|
||||
}
|
||||
rev := table.reverse[channel]
|
||||
if rev == nil {
|
||||
return ""
|
||||
}
|
||||
original := strings.TrimSpace(rev[key])
|
||||
if original == "" || strings.EqualFold(original, requestedModel) {
|
||||
return ""
|
||||
}
|
||||
return original
|
||||
}
|
||||
|
||||
// modelMappingChannel extracts the OAuth model mapping channel from an Auth object.
|
||||
// It determines the provider and auth kind from the Auth's attributes and delegates
|
||||
// to OAuthModelMappingChannel for the actual channel resolution.
|
||||
func modelMappingChannel(auth *Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
authKind := ""
|
||||
if auth.Attributes != nil {
|
||||
authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"]))
|
||||
}
|
||||
if authKind == "" {
|
||||
if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||
authKind = "apikey"
|
||||
}
|
||||
}
|
||||
return OAuthModelMappingChannel(provider, authKind)
|
||||
}
|
||||
|
||||
// OAuthModelMappingChannel returns the OAuth model mapping channel name for a given provider
|
||||
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
||||
// OAuth model mappings (e.g., API key authentication).
|
||||
//
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
func OAuthModelMappingChannel(provider, authKind string) string {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||
switch provider {
|
||||
case "gemini":
|
||||
// gemini provider uses gemini-api-key config, not oauth-model-mappings.
|
||||
// OAuth-based gemini auth is converted to "gemini-cli" by the synthesizer.
|
||||
return ""
|
||||
case "vertex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "vertex"
|
||||
case "claude":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "claude"
|
||||
case "codex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "codex"
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
|
||||
return provider
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
253
sdk/cliproxy/auth/oauth_model_alias.go
Normal file
253
sdk/cliproxy/auth/oauth_model_alias.go
Normal file
@@ -0,0 +1,253 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
)
|
||||
|
||||
type modelAliasEntry interface {
|
||||
GetName() string
|
||||
GetAlias() string
|
||||
}
|
||||
|
||||
type oauthModelAliasTable struct {
|
||||
// reverse maps channel -> alias (lower) -> original upstream model name.
|
||||
reverse map[string]map[string]string
|
||||
}
|
||||
|
||||
func compileOAuthModelAliasTable(aliases map[string][]internalconfig.OAuthModelAlias) *oauthModelAliasTable {
|
||||
if len(aliases) == 0 {
|
||||
return &oauthModelAliasTable{}
|
||||
}
|
||||
out := &oauthModelAliasTable{
|
||||
reverse: make(map[string]map[string]string, len(aliases)),
|
||||
}
|
||||
for rawChannel, entries := range aliases {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
rev := make(map[string]string, len(entries))
|
||||
for _, entry := range entries {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
aliasKey := strings.ToLower(alias)
|
||||
if _, exists := rev[aliasKey]; exists {
|
||||
continue
|
||||
}
|
||||
rev[aliasKey] = name
|
||||
}
|
||||
if len(rev) > 0 {
|
||||
out.reverse[channel] = rev
|
||||
}
|
||||
}
|
||||
if len(out.reverse) == 0 {
|
||||
out.reverse = nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// SetOAuthModelAlias updates the OAuth model name alias table used during execution.
|
||||
// The alias is applied per-auth channel to resolve the upstream model name while keeping the
|
||||
// client-visible model name unchanged for translation/response formatting.
|
||||
func (m *Manager) SetOAuthModelAlias(aliases map[string][]internalconfig.OAuthModelAlias) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
table := compileOAuthModelAliasTable(aliases)
|
||||
// atomic.Value requires non-nil store values.
|
||||
if table == nil {
|
||||
table = &oauthModelAliasTable{}
|
||||
}
|
||||
m.oauthModelAlias.Store(table)
|
||||
}
|
||||
|
||||
// applyOAuthModelAlias resolves the upstream model from OAuth model alias.
|
||||
// If an alias exists, the returned model is the upstream model.
|
||||
func (m *Manager) applyOAuthModelAlias(auth *Auth, requestedModel string) string {
|
||||
upstreamModel := m.resolveOAuthUpstreamModel(auth, requestedModel)
|
||||
if upstreamModel == "" {
|
||||
return requestedModel
|
||||
}
|
||||
return upstreamModel
|
||||
}
|
||||
|
||||
func resolveModelAliasFromConfigModels(requestedModel string, models []modelAliasEntry) string {
|
||||
requestedModel = strings.TrimSpace(requestedModel)
|
||||
if requestedModel == "" {
|
||||
return ""
|
||||
}
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
requestResult := thinking.ParseSuffix(requestedModel)
|
||||
base := requestResult.ModelName
|
||||
candidates := []string{base}
|
||||
if base != requestedModel {
|
||||
candidates = append(candidates, requestedModel)
|
||||
}
|
||||
|
||||
preserveSuffix := func(resolved string) string {
|
||||
resolved = strings.TrimSpace(resolved)
|
||||
if resolved == "" {
|
||||
return ""
|
||||
}
|
||||
if thinking.ParseSuffix(resolved).HasSuffix {
|
||||
return resolved
|
||||
}
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return resolved + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
return resolved
|
||||
}
|
||||
|
||||
for i := range models {
|
||||
name := strings.TrimSpace(models[i].GetName())
|
||||
alias := strings.TrimSpace(models[i].GetAlias())
|
||||
for _, candidate := range candidates {
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if alias != "" && strings.EqualFold(alias, candidate) {
|
||||
if name != "" {
|
||||
return preserveSuffix(name)
|
||||
}
|
||||
return preserveSuffix(candidate)
|
||||
}
|
||||
if name != "" && strings.EqualFold(name, candidate) {
|
||||
return preserveSuffix(name)
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// resolveOAuthUpstreamModel resolves the upstream model name from OAuth model alias.
|
||||
// If an alias exists, returns the original (upstream) model name that corresponds
|
||||
// to the requested alias.
|
||||
//
|
||||
// If the requested model contains a thinking suffix (e.g., "gemini-2.5-pro(8192)"),
|
||||
// the suffix is preserved in the returned model name. However, if the alias's
|
||||
// original name already contains a suffix, the config suffix takes priority.
|
||||
func (m *Manager) resolveOAuthUpstreamModel(auth *Auth, requestedModel string) string {
|
||||
return resolveUpstreamModelFromAliasTable(m, auth, requestedModel, modelAliasChannel(auth))
|
||||
}
|
||||
|
||||
func resolveUpstreamModelFromAliasTable(m *Manager, auth *Auth, requestedModel, channel string) string {
|
||||
if m == nil || auth == nil {
|
||||
return ""
|
||||
}
|
||||
if channel == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Extract thinking suffix from requested model using ParseSuffix
|
||||
requestResult := thinking.ParseSuffix(requestedModel)
|
||||
baseModel := requestResult.ModelName
|
||||
|
||||
// Candidate keys to match: base model and raw input (handles suffix-parsing edge cases).
|
||||
candidates := []string{baseModel}
|
||||
if baseModel != requestedModel {
|
||||
candidates = append(candidates, requestedModel)
|
||||
}
|
||||
|
||||
raw := m.oauthModelAlias.Load()
|
||||
table, _ := raw.(*oauthModelAliasTable)
|
||||
if table == nil || table.reverse == nil {
|
||||
return ""
|
||||
}
|
||||
rev := table.reverse[channel]
|
||||
if rev == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
for _, candidate := range candidates {
|
||||
key := strings.ToLower(strings.TrimSpace(candidate))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
original := strings.TrimSpace(rev[key])
|
||||
if original == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(original, baseModel) {
|
||||
return ""
|
||||
}
|
||||
|
||||
// If config already has suffix, it takes priority.
|
||||
if thinking.ParseSuffix(original).HasSuffix {
|
||||
return original
|
||||
}
|
||||
// Preserve user's thinking suffix on the resolved model.
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return original + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
return original
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// modelAliasChannel extracts the OAuth model alias channel from an Auth object.
|
||||
// It determines the provider and auth kind from the Auth's attributes and delegates
|
||||
// to OAuthModelAliasChannel for the actual channel resolution.
|
||||
func modelAliasChannel(auth *Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
authKind := ""
|
||||
if auth.Attributes != nil {
|
||||
authKind = strings.ToLower(strings.TrimSpace(auth.Attributes["auth_kind"]))
|
||||
}
|
||||
if authKind == "" {
|
||||
if kind, _ := auth.AccountInfo(); strings.EqualFold(kind, "api_key") {
|
||||
authKind = "apikey"
|
||||
}
|
||||
}
|
||||
return OAuthModelAliasChannel(provider, authKind)
|
||||
}
|
||||
|
||||
// OAuthModelAliasChannel returns the OAuth model alias channel name for a given provider
|
||||
// and auth kind. Returns empty string if the provider/authKind combination doesn't support
|
||||
// OAuth model alias (e.g., API key authentication).
|
||||
//
|
||||
// Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
func OAuthModelAliasChannel(provider, authKind string) string {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
authKind = strings.ToLower(strings.TrimSpace(authKind))
|
||||
switch provider {
|
||||
case "gemini":
|
||||
// gemini provider uses gemini-api-key config, not oauth-model-alias.
|
||||
// OAuth-based gemini auth is converted to "gemini-cli" by the synthesizer.
|
||||
return ""
|
||||
case "vertex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "vertex"
|
||||
case "claude":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "claude"
|
||||
case "codex":
|
||||
if authKind == "apikey" {
|
||||
return ""
|
||||
}
|
||||
return "codex"
|
||||
case "gemini-cli", "aistudio", "antigravity", "qwen", "iflow":
|
||||
return provider
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
177
sdk/cliproxy/auth/oauth_model_alias_test.go
Normal file
177
sdk/cliproxy/auth/oauth_model_alias_test.go
Normal file
@@ -0,0 +1,177 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestResolveOAuthUpstreamModel_SuffixPreservation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
aliases map[string][]internalconfig.OAuthModelAlias
|
||||
channel string
|
||||
input string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "numeric suffix preserved",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
|
||||
},
|
||||
channel: "gemini-cli",
|
||||
input: "gemini-2.5-pro(8192)",
|
||||
want: "gemini-2.5-pro-exp-03-25(8192)",
|
||||
},
|
||||
{
|
||||
name: "level suffix preserved",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"claude": {{Name: "claude-sonnet-4-5-20250514", Alias: "claude-sonnet-4-5"}},
|
||||
},
|
||||
channel: "claude",
|
||||
input: "claude-sonnet-4-5(high)",
|
||||
want: "claude-sonnet-4-5-20250514(high)",
|
||||
},
|
||||
{
|
||||
name: "no suffix unchanged",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
|
||||
},
|
||||
channel: "gemini-cli",
|
||||
input: "gemini-2.5-pro",
|
||||
want: "gemini-2.5-pro-exp-03-25",
|
||||
},
|
||||
{
|
||||
name: "config suffix takes priority",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"claude": {{Name: "claude-sonnet-4-5-20250514(low)", Alias: "claude-sonnet-4-5"}},
|
||||
},
|
||||
channel: "claude",
|
||||
input: "claude-sonnet-4-5(high)",
|
||||
want: "claude-sonnet-4-5-20250514(low)",
|
||||
},
|
||||
{
|
||||
name: "auto suffix preserved",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
|
||||
},
|
||||
channel: "gemini-cli",
|
||||
input: "gemini-2.5-pro(auto)",
|
||||
want: "gemini-2.5-pro-exp-03-25(auto)",
|
||||
},
|
||||
{
|
||||
name: "none suffix preserved",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
|
||||
},
|
||||
channel: "gemini-cli",
|
||||
input: "gemini-2.5-pro(none)",
|
||||
want: "gemini-2.5-pro-exp-03-25(none)",
|
||||
},
|
||||
{
|
||||
name: "case insensitive alias lookup with suffix",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "Gemini-2.5-Pro"}},
|
||||
},
|
||||
channel: "gemini-cli",
|
||||
input: "gemini-2.5-pro(high)",
|
||||
want: "gemini-2.5-pro-exp-03-25(high)",
|
||||
},
|
||||
{
|
||||
name: "no alias returns empty",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
|
||||
},
|
||||
channel: "gemini-cli",
|
||||
input: "unknown-model(high)",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "wrong channel returns empty",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
|
||||
},
|
||||
channel: "claude",
|
||||
input: "gemini-2.5-pro(high)",
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty suffix filtered out",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
|
||||
},
|
||||
channel: "gemini-cli",
|
||||
input: "gemini-2.5-pro()",
|
||||
want: "gemini-2.5-pro-exp-03-25",
|
||||
},
|
||||
{
|
||||
name: "incomplete suffix treated as no suffix",
|
||||
aliases: map[string][]internalconfig.OAuthModelAlias{
|
||||
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro(high"}},
|
||||
},
|
||||
channel: "gemini-cli",
|
||||
input: "gemini-2.5-pro(high",
|
||||
want: "gemini-2.5-pro-exp-03-25",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
mgr := NewManager(nil, nil, nil)
|
||||
mgr.SetConfig(&internalconfig.Config{})
|
||||
mgr.SetOAuthModelAlias(tt.aliases)
|
||||
|
||||
auth := createAuthForChannel(tt.channel)
|
||||
got := mgr.resolveOAuthUpstreamModel(auth, tt.input)
|
||||
if got != tt.want {
|
||||
t.Errorf("resolveOAuthUpstreamModel(%q) = %q, want %q", tt.input, got, tt.want)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func createAuthForChannel(channel string) *Auth {
|
||||
switch channel {
|
||||
case "gemini-cli":
|
||||
return &Auth{Provider: "gemini-cli"}
|
||||
case "claude":
|
||||
return &Auth{Provider: "claude", Attributes: map[string]string{"auth_kind": "oauth"}}
|
||||
case "vertex":
|
||||
return &Auth{Provider: "vertex", Attributes: map[string]string{"auth_kind": "oauth"}}
|
||||
case "codex":
|
||||
return &Auth{Provider: "codex", Attributes: map[string]string{"auth_kind": "oauth"}}
|
||||
case "aistudio":
|
||||
return &Auth{Provider: "aistudio"}
|
||||
case "antigravity":
|
||||
return &Auth{Provider: "antigravity"}
|
||||
case "qwen":
|
||||
return &Auth{Provider: "qwen"}
|
||||
case "iflow":
|
||||
return &Auth{Provider: "iflow"}
|
||||
default:
|
||||
return &Auth{Provider: channel}
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOAuthModelAlias_SuffixPreservation(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
aliases := map[string][]internalconfig.OAuthModelAlias{
|
||||
"gemini-cli": {{Name: "gemini-2.5-pro-exp-03-25", Alias: "gemini-2.5-pro"}},
|
||||
}
|
||||
|
||||
mgr := NewManager(nil, nil, nil)
|
||||
mgr.SetConfig(&internalconfig.Config{})
|
||||
mgr.SetOAuthModelAlias(aliases)
|
||||
|
||||
auth := &Auth{ID: "test-auth-id", Provider: "gemini-cli"}
|
||||
|
||||
resolvedModel := mgr.applyOAuthModelAlias(auth, "gemini-2.5-pro(8192)")
|
||||
if resolvedModel != "gemini-2.5-pro-exp-03-25(8192)" {
|
||||
t.Errorf("applyOAuthModelAlias() model = %q, want %q", resolvedModel, "gemini-2.5-pro-exp-03-25(8192)")
|
||||
}
|
||||
}
|
||||
@@ -215,7 +215,8 @@ func (b *Builder) Build() (*Service, error) {
|
||||
}
|
||||
// Attach a default RoundTripper provider so providers can opt-in per-auth transports.
|
||||
coreManager.SetRoundTripperProvider(newDefaultRoundTripperProvider())
|
||||
coreManager.SetOAuthModelMappings(b.cfg.OAuthModelMappings)
|
||||
coreManager.SetConfig(b.cfg)
|
||||
coreManager.SetOAuthModelAlias(b.cfg.OAuthModelAlias)
|
||||
|
||||
service := &Service{
|
||||
cfg: b.cfg,
|
||||
|
||||
@@ -553,7 +553,8 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
s.cfg = newCfg
|
||||
s.cfgMu.Unlock()
|
||||
if s.coreManager != nil {
|
||||
s.coreManager.SetOAuthModelMappings(newCfg.OAuthModelMappings)
|
||||
s.coreManager.SetConfig(newCfg)
|
||||
s.coreManager.SetOAuthModelAlias(newCfg.OAuthModelAlias)
|
||||
}
|
||||
s.rebindExecutors()
|
||||
}
|
||||
@@ -825,6 +826,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
OwnedBy: compat.Name,
|
||||
Type: "openai-compatibility",
|
||||
DisplayName: modelID,
|
||||
UserDefined: true,
|
||||
})
|
||||
}
|
||||
// Register and return
|
||||
@@ -847,7 +849,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
}
|
||||
}
|
||||
}
|
||||
models = applyOAuthModelMappings(s.cfg, provider, authKind, models)
|
||||
models = applyOAuthModelAlias(s.cfg, provider, authKind, models)
|
||||
if len(models) > 0 {
|
||||
key := provider
|
||||
if key == "" {
|
||||
@@ -1157,6 +1159,7 @@ func buildConfigModels[T modelEntry](models []T, ownedBy, modelType string) []*M
|
||||
OwnedBy: ownedBy,
|
||||
Type: modelType,
|
||||
DisplayName: display,
|
||||
UserDefined: true,
|
||||
}
|
||||
if name != "" {
|
||||
if upstream := registry.LookupStaticModelInfo(name); upstream != nil && upstream.Thinking != nil {
|
||||
@@ -1219,28 +1222,28 @@ func rewriteModelInfoName(name, oldID, newID string) string {
|
||||
return name
|
||||
}
|
||||
|
||||
func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
|
||||
func applyOAuthModelAlias(cfg *config.Config, provider, authKind string, models []*ModelInfo) []*ModelInfo {
|
||||
if cfg == nil || len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
channel := coreauth.OAuthModelMappingChannel(provider, authKind)
|
||||
if channel == "" || len(cfg.OAuthModelMappings) == 0 {
|
||||
channel := coreauth.OAuthModelAliasChannel(provider, authKind)
|
||||
if channel == "" || len(cfg.OAuthModelAlias) == 0 {
|
||||
return models
|
||||
}
|
||||
mappings := cfg.OAuthModelMappings[channel]
|
||||
if len(mappings) == 0 {
|
||||
aliases := cfg.OAuthModelAlias[channel]
|
||||
if len(aliases) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
type mappingEntry struct {
|
||||
type aliasEntry struct {
|
||||
alias string
|
||||
fork bool
|
||||
}
|
||||
|
||||
forward := make(map[string][]mappingEntry, len(mappings))
|
||||
for i := range mappings {
|
||||
name := strings.TrimSpace(mappings[i].Name)
|
||||
alias := strings.TrimSpace(mappings[i].Alias)
|
||||
forward := make(map[string][]aliasEntry, len(aliases))
|
||||
for i := range aliases {
|
||||
name := strings.TrimSpace(aliases[i].Name)
|
||||
alias := strings.TrimSpace(aliases[i].Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
@@ -1248,7 +1251,7 @@ func applyOAuthModelMappings(cfg *config.Config, provider, authKind string, mode
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(name)
|
||||
forward[key] = append(forward[key], mappingEntry{alias: alias, fork: mappings[i].Fork})
|
||||
forward[key] = append(forward[key], aliasEntry{alias: alias, fork: aliases[i].Fork})
|
||||
}
|
||||
if len(forward) == 0 {
|
||||
return models
|
||||
|
||||
@@ -6,9 +6,9 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestApplyOAuthModelMappings_Rename(t *testing.T) {
|
||||
func TestApplyOAuthModelAlias_Rename(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OAuthModelMappings: map[string][]config.ModelNameMapping{
|
||||
OAuthModelAlias: map[string][]config.OAuthModelAlias{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5"},
|
||||
},
|
||||
@@ -18,7 +18,7 @@ func TestApplyOAuthModelMappings_Rename(t *testing.T) {
|
||||
{ID: "gpt-5", Name: "models/gpt-5"},
|
||||
}
|
||||
|
||||
out := applyOAuthModelMappings(cfg, "codex", "oauth", models)
|
||||
out := applyOAuthModelAlias(cfg, "codex", "oauth", models)
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected 1 model, got %d", len(out))
|
||||
}
|
||||
@@ -30,9 +30,9 @@ func TestApplyOAuthModelMappings_Rename(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOAuthModelMappings_ForkAddsAlias(t *testing.T) {
|
||||
func TestApplyOAuthModelAlias_ForkAddsAlias(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OAuthModelMappings: map[string][]config.ModelNameMapping{
|
||||
OAuthModelAlias: map[string][]config.OAuthModelAlias{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5", Fork: true},
|
||||
},
|
||||
@@ -42,7 +42,7 @@ func TestApplyOAuthModelMappings_ForkAddsAlias(t *testing.T) {
|
||||
{ID: "gpt-5", Name: "models/gpt-5"},
|
||||
}
|
||||
|
||||
out := applyOAuthModelMappings(cfg, "codex", "oauth", models)
|
||||
out := applyOAuthModelAlias(cfg, "codex", "oauth", models)
|
||||
if len(out) != 2 {
|
||||
t.Fatalf("expected 2 models, got %d", len(out))
|
||||
}
|
||||
@@ -57,9 +57,9 @@ func TestApplyOAuthModelMappings_ForkAddsAlias(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyOAuthModelMappings_ForkAddsMultipleAliases(t *testing.T) {
|
||||
func TestApplyOAuthModelAlias_ForkAddsMultipleAliases(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
OAuthModelMappings: map[string][]config.ModelNameMapping{
|
||||
OAuthModelAlias: map[string][]config.OAuthModelAlias{
|
||||
"codex": {
|
||||
{Name: "gpt-5", Alias: "g5", Fork: true},
|
||||
{Name: "gpt-5", Alias: "g5-2", Fork: true},
|
||||
@@ -70,7 +70,7 @@ func TestApplyOAuthModelMappings_ForkAddsMultipleAliases(t *testing.T) {
|
||||
{ID: "gpt-5", Name: "models/gpt-5"},
|
||||
}
|
||||
|
||||
out := applyOAuthModelMappings(cfg, "codex", "oauth", models)
|
||||
out := applyOAuthModelAlias(cfg, "codex", "oauth", models)
|
||||
if len(out) != 3 {
|
||||
t.Fatalf("expected 3 models, got %d", len(out))
|
||||
}
|
||||
@@ -16,7 +16,7 @@ type StreamingConfig = internalconfig.StreamingConfig
|
||||
type TLSConfig = internalconfig.TLSConfig
|
||||
type RemoteManagement = internalconfig.RemoteManagement
|
||||
type AmpCode = internalconfig.AmpCode
|
||||
type ModelNameMapping = internalconfig.ModelNameMapping
|
||||
type OAuthModelAlias = internalconfig.OAuthModelAlias
|
||||
type PayloadConfig = internalconfig.PayloadConfig
|
||||
type PayloadRule = internalconfig.PayloadRule
|
||||
type PayloadModelRule = internalconfig.PayloadModelRule
|
||||
|
||||
Reference in New Issue
Block a user