diff --git a/internal/api/modules/amp/fallback_handlers_test.go b/internal/api/modules/amp/fallback_handlers_test.go index a687fd11..eca53a64 100644 --- a/internal/api/modules/amp/fallback_handlers_test.go +++ b/internal/api/modules/amp/fallback_handlers_test.go @@ -2,7 +2,7 @@ package amp import ( "bytes" - "encoding/json" + "io" "net/http" "net/http/httptest" "net/http/httputil" @@ -10,64 +10,152 @@ import ( "github.com/gin-gonic/gin" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" + "github.com/stretchr/testify/assert" ) -func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) { +// Characterization tests for fallback_handlers.go +// These tests capture existing behavior before refactoring to routing layer + +func TestFallbackHandler_WrapHandler_LocalProvider_NoMapping(t *testing.T) { gin.SetMode(gin.TestMode) - reg := registry.GetGlobalRegistry() - reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{ - {ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"}, + // Setup: model that has local providers + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body := `{"model": "gemini-2.5-pro", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Handler that should be called (not proxy) + handlerCalled := false + handler := func(c *gin.Context) { + handlerCalled = true + c.JSON(200, gin.H{"status": "ok"}) + } + + // Create fallback handler + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + return nil // no proxy }) - defer reg.UnregisterClient("test-client-amp-fallback") + + // Execute + wrapped := fh.WrapHandler(handler) + wrapped(c) + + // Assert: handler should be called directly (no mapping needed) + assert.True(t, handlerCalled, "handler should be called for local provider") + assert.Equal(t, 200, w.Code) +} + +func TestFallbackHandler_WrapHandler_MappingApplied(t *testing.T) { + gin.SetMode(gin.TestMode) + + // Setup: model that needs mapping + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body := `{"model": "claude-opus-4-5-20251101", "messages": [{"role": "user", "content": "hello"}]}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Handler to capture rewritten body + var capturedBody []byte + handler := func(c *gin.Context) { + capturedBody, _ = io.ReadAll(c.Request.Body) + c.JSON(200, gin.H{"status": "ok"}) + } + + // Create fallback handler with mapper + mapper := NewModelMapper([]config.AmpModelMapping{ + {From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"}, + }) + // TODO: Setup oauth aliases for testing + + fh := NewFallbackHandlerWithMapper( + func() *httputil.ReverseProxy { return nil }, + mapper, + func() bool { return false }, + ) + + // Execute + wrapped := fh.WrapHandler(handler) + wrapped(c) + + // Assert: body should be rewritten + assert.Contains(t, string(capturedBody), "claude-opus-4-5-thinking") + + // Assert: context should have mapped model + mappedModel, exists := c.Get(MappedModelContextKey) + assert.True(t, exists, "MappedModelContextKey should be set") + assert.NotEmpty(t, mappedModel) +} + +func TestFallbackHandler_WrapHandler_ThinkingSuffixPreserved(t *testing.T) { + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + // Model with thinking suffix + body := `{"model": "claude-opus-4-5-20251101(xhigh)", "messages": []}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + var capturedBody []byte + handler := func(c *gin.Context) { + capturedBody, _ = io.ReadAll(c.Request.Body) + c.JSON(200, gin.H{"status": "ok"}) + } mapper := NewModelMapper([]config.AmpModelMapping{ - {From: "gpt-5.2", To: "test/gpt-5.2"}, + {From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"}, }) + + fh := NewFallbackHandlerWithMapper( + func() *httputil.ReverseProxy { return nil }, + mapper, + func() bool { return false }, + ) - fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil) + wrapped := fh.WrapHandler(handler) + wrapped(c) + + // Assert: thinking suffix should be preserved + assert.Contains(t, string(capturedBody), "(xhigh)") +} + +func TestFallbackHandler_WrapHandler_NoProvider_NoMapping_ProxyEnabled(t *testing.T) { + gin.SetMode(gin.TestMode) + + gin.SetMode(gin.TestMode) + + w := httptest.NewRecorder() + c, _ := gin.CreateTestContext(w) + + body := `{"model": "unknown-model", "messages": []}` + req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body))) + req.Header.Set("Content-Type", "application/json") + c.Request = req + + // Note: Proxy test needs proper setup with reverse proxy handler := func(c *gin.Context) { - var req struct { - Model string `json:"model"` - } - if err := c.ShouldBindJSON(&req); err != nil { - c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) - return - } - - c.JSON(http.StatusOK, gin.H{ - "model": req.Model, - "seen_model": req.Model, - }) + t.Error("handler should not be called when proxy is available") } - r := gin.New() - r.POST("/chat/completions", fallback.WrapHandler(handler)) + // TODO: Setup proxy properly + fh := NewFallbackHandler(func() *httputil.ReverseProxy { + // Return mock proxy + return nil + }) - reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`) - req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody)) - req.Header.Set("Content-Type", "application/json") - w := httptest.NewRecorder() - r.ServeHTTP(w, req) + wrapped := fh.WrapHandler(handler) + wrapped(c) - if w.Code != http.StatusOK { - t.Fatalf("Expected status 200, got %d", w.Code) - } - - var resp struct { - Model string `json:"model"` - SeenModel string `json:"seen_model"` - } - if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { - t.Fatalf("Failed to parse response JSON: %v", err) - } - - if resp.Model != "gpt-5.2(xhigh)" { - t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model) - } - if resp.SeenModel != "test/gpt-5.2(xhigh)" { - t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel) - } + // Assert: proxy should be called when no local provider + // Note: This test needs proxy setup to work properly } diff --git a/internal/routing/adapter.go b/internal/routing/adapter.go new file mode 100644 index 00000000..1d90b0fe --- /dev/null +++ b/internal/routing/adapter.go @@ -0,0 +1,39 @@ +// Package routing provides adapter to integrate with existing codebase. +package routing + +import ( + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +// Adapter bridges the new routing layer with existing auth manager. +type Adapter struct { + router *Router + exec *Executor +} + +// NewAdapter creates a new adapter with the given configuration and auth manager. +func NewAdapter(cfg *config.Config, authManager *coreauth.Manager) *Adapter { + registry := NewRegistry() + + // TODO: Register OAuth providers from authManager + // TODO: Register API key providers from cfg + + router := NewRouter(registry, cfg) + exec := NewExecutor(router) + + return &Adapter{ + router: router, + exec: exec, + } +} + +// Router returns the underlying router. +func (a *Adapter) Router() *Router { + return a.router +} + +// Executor returns the underlying executor. +func (a *Adapter) Executor() *Executor { + return a.exec +} diff --git a/internal/routing/ctxkeys/keys.go b/internal/routing/ctxkeys/keys.go new file mode 100644 index 00000000..5838d54d --- /dev/null +++ b/internal/routing/ctxkeys/keys.go @@ -0,0 +1,11 @@ +package ctxkeys + +type key string + +const ( + MappedModel key = "mapped_model" + FallbackModels key = "fallback_models" + RouteCandidates key = "route_candidates" + RoutingDecision key = "routing_decision" + MappingApplied key = "mapping_applied" +) diff --git a/internal/routing/executor.go b/internal/routing/executor.go new file mode 100644 index 00000000..30b5750b --- /dev/null +++ b/internal/routing/executor.go @@ -0,0 +1,111 @@ +package routing + +import ( + "context" + "errors" + + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + log "github.com/sirupsen/logrus" +) + +// Executor handles request execution with fallback support. +type Executor struct { + router *Router +} + +// NewExecutor creates a new executor with the given router. +func NewExecutor(router *Router) *Executor { + return &Executor{router: router} +} + +// Execute sends the request through the routing decision. +func (e *Executor) Execute(ctx context.Context, req executor.Request) (executor.Response, error) { + decision := e.router.Resolve(req.Model) + + log.Debugf("routing: %s -> %s (%d candidates)", + decision.RequestedModel, + decision.ResolvedModel, + len(decision.Candidates)) + + var lastErr error + tried := make(map[string]struct{}) + + for i, candidate := range decision.Candidates { + key := candidate.Provider.Name() + "/" + candidate.Model + if _, ok := tried[key]; ok { + continue + } + tried[key] = struct{}{} + + log.Debugf("routing: trying candidate %d/%d: %s with model %s", + i+1, len(decision.Candidates), candidate.Provider.Name(), candidate.Model) + + req.Model = candidate.Model + resp, err := candidate.Provider.Execute(ctx, candidate.Model, req) + if err == nil { + return resp, nil + } + + lastErr = err + log.Debugf("routing: candidate failed: %v", err) + + // Check if it's a fatal error (not retryable) + if isFatalError(err) { + break + } + } + + if lastErr != nil { + return executor.Response{}, lastErr + } + return executor.Response{}, errors.New("no available providers") +} + +// ExecuteStream sends a streaming request through the routing decision. +func (e *Executor) ExecuteStream(ctx context.Context, req executor.Request) (<-chan executor.StreamChunk, error) { + decision := e.router.Resolve(req.Model) + + log.Debugf("routing stream: %s -> %s (%d candidates)", + decision.RequestedModel, + decision.ResolvedModel, + len(decision.Candidates)) + + var lastErr error + tried := make(map[string]struct{}) + + for i, candidate := range decision.Candidates { + key := candidate.Provider.Name() + "/" + candidate.Model + if _, ok := tried[key]; ok { + continue + } + tried[key] = struct{}{} + + log.Debugf("routing stream: trying candidate %d/%d: %s with model %s", + i+1, len(decision.Candidates), candidate.Provider.Name(), candidate.Model) + + req.Model = candidate.Model + chunks, err := candidate.Provider.ExecuteStream(ctx, candidate.Model, req) + if err == nil { + return chunks, nil + } + + lastErr = err + log.Debugf("routing stream: candidate failed: %v", err) + + if isFatalError(err) { + break + } + } + + if lastErr != nil { + return nil, lastErr + } + return nil, errors.New("no available providers") +} + +// isFatalError returns true if the error is not retryable. +func isFatalError(err error) bool { + // TODO: implement based on error type + // For now, all errors are retryable + return false +} diff --git a/internal/routing/provider.go b/internal/routing/provider.go new file mode 100644 index 00000000..8e1606c8 --- /dev/null +++ b/internal/routing/provider.go @@ -0,0 +1,80 @@ +// Package routing provides unified model routing for all provider types. +package routing + +import ( + "context" + + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +// ProviderType indicates the type of provider. +type ProviderType string + +const ( + ProviderTypeOAuth ProviderType = "oauth" + ProviderTypeAPIKey ProviderType = "api_key" + ProviderTypeVertex ProviderType = "vertex" +) + +// Provider is the unified interface for all provider types (OAuth, API key, etc.). +type Provider interface { + // Name returns the unique provider identifier. + Name() string + + // Type returns the provider type. + Type() ProviderType + + // SupportsModel returns true if this provider can handle the given model. + SupportsModel(model string) bool + + // Available returns true if the provider is available for the model (not quota exceeded). + Available(model string) bool + + // Priority returns the priority for this provider (lower = tried first). + Priority() int + + // Execute sends the request to the provider. + Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) + + // ExecuteStream sends a streaming request to the provider. + ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) +} + +// ProviderCandidate represents a provider + model combination to try. +type ProviderCandidate struct { + Provider Provider + Model string // The actual model name to use (may be different from requested due to aliasing) +} + +// Registry manages all available providers. +type Registry struct { + providers []Provider +} + +// NewRegistry creates a new provider registry. +func NewRegistry() *Registry { + return &Registry{ + providers: make([]Provider, 0), + } +} + +// Register adds a provider to the registry. +func (r *Registry) Register(p Provider) { + r.providers = append(r.providers, p) +} + +// FindProviders returns all providers that support the given model and are available. +func (r *Registry) FindProviders(model string) []Provider { + var result []Provider + for _, p := range r.providers { + if p.SupportsModel(model) && p.Available(model) { + result = append(result, p) + } + } + return result +} + +// All returns all registered providers. +func (r *Registry) All() []Provider { + return r.providers +} diff --git a/internal/routing/providers/apikey.go b/internal/routing/providers/apikey.go new file mode 100644 index 00000000..4603702d --- /dev/null +++ b/internal/routing/providers/apikey.go @@ -0,0 +1,156 @@ +package providers + +import ( + "context" + "errors" + "net/http" + "strings" + "sync" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +// APIKeyProvider wraps API key configs as routing.Provider. +type APIKeyProvider struct { + name string + provider string // claude, gemini, codex, vertex + keys []APIKeyEntry + mu sync.RWMutex + client HTTPClient +} + +// APIKeyEntry represents a single API key configuration. +type APIKeyEntry struct { + APIKey string + BaseURL string + Models []config.ClaudeModel // Using ClaudeModel as generic model alias +} + +// HTTPClient interface for making HTTP requests. +type HTTPClient interface { + Do(req *http.Request) (*http.Response, error) +} + +// NewAPIKeyProvider creates a new API key provider. +func NewAPIKeyProvider(name, provider string, client HTTPClient) *APIKeyProvider { + return &APIKeyProvider{ + name: name, + provider: provider, + keys: make([]APIKeyEntry, 0), + client: client, + } +} + +// Name returns the provider name. +func (p *APIKeyProvider) Name() string { + return p.name +} + +// Type returns ProviderTypeAPIKey. +func (p *APIKeyProvider) Type() routing.ProviderType { + return routing.ProviderTypeAPIKey +} + +// SupportsModel checks if the model is supported by this provider. +func (p *APIKeyProvider) SupportsModel(model string) bool { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, key := range p.keys { + for _, m := range key.Models { + if strings.EqualFold(m.Alias, model) || strings.EqualFold(m.Name, model) { + return true + } + } + } + return false +} + +// Available always returns true for API keys (unless explicitly disabled). +func (p *APIKeyProvider) Available(model string) bool { + return p.SupportsModel(model) +} + +// Priority returns the priority (API key is lower priority than OAuth). +func (p *APIKeyProvider) Priority() int { + return 20 +} + +// Execute sends the request using the API key. +func (p *APIKeyProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) { + key := p.selectKey(model) + if key == nil { + return executor.Response{}, ErrNoMatchingAPIKey + } + + // Resolve the actual model name from alias + actualModel := p.resolveModel(key, model) + + // Execute via HTTP client + return p.executeHTTP(ctx, key, actualModel, req) +} + +// ExecuteStream sends a streaming request. +func (p *APIKeyProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) ( + <-chan executor.StreamChunk, error) { + key := p.selectKey(model) + if key == nil { + return nil, ErrNoMatchingAPIKey + } + + actualModel := p.resolveModel(key, model) + return p.executeHTTPStream(ctx, key, actualModel, req) +} + +// AddKey adds an API key entry. +func (p *APIKeyProvider) AddKey(entry APIKeyEntry) { + p.mu.Lock() + defer p.mu.Unlock() + p.keys = append(p.keys, entry) +} + +// selectKey selects a key that supports the model. +func (p *APIKeyProvider) selectKey(model string) *APIKeyEntry { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, key := range p.keys { + for _, m := range key.Models { + if strings.EqualFold(m.Alias, model) || strings.EqualFold(m.Name, model) { + return &key + } + } + } + return nil +} + +// resolveModel resolves alias to actual model name. +func (p *APIKeyProvider) resolveModel(key *APIKeyEntry, requested string) string { + for _, m := range key.Models { + if strings.EqualFold(m.Alias, requested) { + return m.Name + } + } + return requested +} + +// executeHTTP makes the HTTP request. +func (p *APIKeyProvider) executeHTTP(ctx context.Context, key *APIKeyEntry, model string, req executor.Request) (executor.Response, error) { + // TODO: implement actual HTTP execution + // This is a placeholder - actual implementation would build HTTP request + return executor.Response{}, errors.New("not yet implemented") +} + +// executeHTTPStream makes a streaming HTTP request. +func (p *APIKeyProvider) executeHTTPStream(ctx context.Context, key *APIKeyEntry, model string, req executor.Request) ( + <-chan executor.StreamChunk, error) { + // TODO: implement actual HTTP streaming + return nil, errors.New("not yet implemented") +} + +// Errors +var ( + ErrNoMatchingAPIKey = errors.New("no API key supports the requested model") +) diff --git a/internal/routing/providers/oauth.go b/internal/routing/providers/oauth.go new file mode 100644 index 00000000..ae0c09e2 --- /dev/null +++ b/internal/routing/providers/oauth.go @@ -0,0 +1,132 @@ +package providers + +import ( + "context" + "errors" + "sync" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/routing" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +// OAuthProvider wraps OAuth-based auths as routing.Provider. +type OAuthProvider struct { + name string + auths []*coreauth.Auth + mu sync.RWMutex + executor coreauth.ProviderExecutor +} + +// NewOAuthProvider creates a new OAuth provider. +func NewOAuthProvider(name string, exec coreauth.ProviderExecutor) *OAuthProvider { + return &OAuthProvider{ + name: name, + auths: make([]*coreauth.Auth, 0), + executor: exec, + } +} + +// Name returns the provider name. +func (p *OAuthProvider) Name() string { + return p.name +} + +// Type returns ProviderTypeOAuth. +func (p *OAuthProvider) Type() routing.ProviderType { + return routing.ProviderTypeOAuth +} + +// SupportsModel checks if any auth supports the model. +func (p *OAuthProvider) SupportsModel(model string) bool { + p.mu.RLock() + defer p.mu.RUnlock() + + // OAuth providers typically support models via oauth-model-alias + // The actual model support is determined at execution time + return true +} + +// Available checks if there's an available auth for the model. +func (p *OAuthProvider) Available(model string) bool { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, auth := range p.auths { + if p.isAuthAvailable(auth, model) { + return true + } + } + return false +} + +// Priority returns the priority (OAuth is preferred over API key). +func (p *OAuthProvider) Priority() int { + return 10 +} + +// Execute sends the request using an available OAuth auth. +func (p *OAuthProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) { + auth := p.selectAuth(model) + if auth == nil { + return executor.Response{}, ErrNoAvailableAuth + } + + return p.executor.Execute(ctx, auth, req, executor.Options{}) +} + +// ExecuteStream sends a streaming request. +func (p *OAuthProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) { + auth := p.selectAuth(model) + if auth == nil { + return nil, ErrNoAvailableAuth + } + + return p.executor.ExecuteStream(ctx, auth, req, executor.Options{}) +} + +// AddAuth adds an auth to this provider. +func (p *OAuthProvider) AddAuth(auth *coreauth.Auth) { + p.mu.Lock() + defer p.mu.Unlock() + p.auths = append(p.auths, auth) +} + +// RemoveAuth removes an auth from this provider. +func (p *OAuthProvider) RemoveAuth(authID string) { + p.mu.Lock() + defer p.mu.Unlock() + + filtered := make([]*coreauth.Auth, 0, len(p.auths)) + for _, auth := range p.auths { + if auth.ID != authID { + filtered = append(filtered, auth) + } + } + p.auths = filtered +} + +// isAuthAvailable checks if an auth is available for the model. +func (p *OAuthProvider) isAuthAvailable(auth *coreauth.Auth, model string) bool { + // TODO: integrate with model_registry for quota checking + // For now, just check if auth exists + return auth != nil +} + +// selectAuth selects an available auth for the model. +func (p *OAuthProvider) selectAuth(model string) *coreauth.Auth { + p.mu.RLock() + defer p.mu.RUnlock() + + for _, auth := range p.auths { + if p.isAuthAvailable(auth, model) { + return auth + } + } + return nil +} + +// Errors +var ( + ErrNoAvailableAuth = errors.New("no available OAuth auth for model") +) diff --git a/internal/routing/router.go b/internal/routing/router.go new file mode 100644 index 00000000..db74ef3c --- /dev/null +++ b/internal/routing/router.go @@ -0,0 +1,127 @@ +package routing + +import ( + "sort" + "strings" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" +) + +// Router resolves models to provider candidates. +type Router struct { + registry *Registry + modelMappings map[string]string // normalized from -> to + oauthAliases map[string][]string // normalized model -> []alias +} + +// NewRouter creates a new router with the given configuration. +func NewRouter(registry *Registry, cfg *config.Config) *Router { + r := &Router{ + registry: registry, + modelMappings: make(map[string]string), + oauthAliases: make(map[string][]string), + } + + if cfg != nil { + r.loadModelMappings(cfg.AmpCode.ModelMappings) + r.loadOAuthAliases(cfg.OAuthModelAlias) + } + + return r +} + +// RoutingDecision contains the resolved routing information. +type RoutingDecision struct { + RequestedModel string // Original model from request + ResolvedModel string // After model-mappings + Candidates []ProviderCandidate // Ordered list of providers to try +} + +// Resolve determines the routing decision for the requested model. +func (r *Router) Resolve(requestedModel string) *RoutingDecision { + // 1. Extract thinking suffix + suffixResult := thinking.ParseSuffix(requestedModel) + baseModel := suffixResult.ModelName + + // 2. Apply model-mappings + targetModel := r.applyMappings(baseModel) + + // 3. Find primary providers + candidates := r.findCandidates(targetModel, suffixResult) + + // 4. Add fallback aliases + for _, alias := range r.oauthAliases[strings.ToLower(targetModel)] { + candidates = append(candidates, r.findCandidates(alias, suffixResult)...) + } + + // 5. Sort by priority + sort.Slice(candidates, func(i, j int) bool { + return candidates[i].Provider.Priority() < candidates[j].Provider.Priority() + }) + + return &RoutingDecision{ + RequestedModel: requestedModel, + ResolvedModel: targetModel, + Candidates: candidates, + } +} + +// applyMappings applies model-mappings configuration. +func (r *Router) applyMappings(model string) string { + key := strings.ToLower(strings.TrimSpace(model)) + if mapped, ok := r.modelMappings[key]; ok { + return mapped + } + return model +} + +// findCandidates finds all provider candidates for a model. +func (r *Router) findCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate { + var candidates []ProviderCandidate + + for _, p := range r.registry.All() { + if !p.SupportsModel(model) { + continue + } + + // Apply thinking suffix if needed + actualModel := model + if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix { + actualModel = model + "(" + suffixResult.RawSuffix + ")" + } + + if p.Available(actualModel) { + candidates = append(candidates, ProviderCandidate{ + Provider: p, + Model: actualModel, + }) + } + } + + return candidates +} + +// loadModelMappings loads model-mappings from config. +func (r *Router) loadModelMappings(mappings []config.AmpModelMapping) { + for _, m := range mappings { + from := strings.ToLower(strings.TrimSpace(m.From)) + to := strings.TrimSpace(m.To) + if from != "" && to != "" { + r.modelMappings[from] = to + } + } +} + +// loadOAuthAliases loads oauth-model-alias from config. +func (r *Router) loadOAuthAliases(aliases map[string][]config.OAuthModelAlias) { + for _, entries := range aliases { + for _, entry := range entries { + name := strings.ToLower(strings.TrimSpace(entry.Name)) + alias := strings.TrimSpace(entry.Alias) + if name != "" && alias != "" && name != alias { + r.oauthAliases[name] = append(r.oauthAliases[name], alias) + } + } + } +} diff --git a/internal/routing/router_test.go b/internal/routing/router_test.go new file mode 100644 index 00000000..ffa01ef9 --- /dev/null +++ b/internal/routing/router_test.go @@ -0,0 +1,115 @@ +package routing + +import ( + "context" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/stretchr/testify/assert" +) + +// mockProvider is a test double for Provider. +type mockProvider struct { + name string + providerType ProviderType + supportsModels map[string]bool + available bool + priority int +} + +func (m *mockProvider) Name() string { return m.name } +func (m *mockProvider) Type() ProviderType { return m.providerType } +func (m *mockProvider) SupportsModel(model string) bool { return m.supportsModels[model] } +func (m *mockProvider) Available(model string) bool { return m.available } +func (m *mockProvider) Priority() int { return m.priority } +func (m *mockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) { + return executor.Response{}, nil +} +func (m *mockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) { + return nil, nil +} + +func TestRouter_Resolve_ModelMappings(t *testing.T) { + registry := NewRegistry() + + // Add a provider + p := &mockProvider{ + name: "test-provider", + providerType: ProviderTypeOAuth, + supportsModels: map[string]bool{"target-model": true}, + available: true, + priority: 1, + } + registry.Register(p) + + // Create router with model mapping + cfg := &config.Config{ + AmpCode: config.AmpCode{ + ModelMappings: []config.AmpModelMapping{ + {From: "user-model", To: "target-model"}, + }, + }, + } + router := NewRouter(registry, cfg) + + // Resolve + decision := router.Resolve("user-model") + + assert.Equal(t, "user-model", decision.RequestedModel) + assert.Equal(t, "target-model", decision.ResolvedModel) + assert.Len(t, decision.Candidates, 1) + assert.Equal(t, "target-model", decision.Candidates[0].Model) +} + +func TestRouter_Resolve_OAuthAliases(t *testing.T) { + registry := NewRegistry() + + // Add providers + p1 := &mockProvider{ + name: "oauth-1", + providerType: ProviderTypeOAuth, + supportsModels: map[string]bool{"primary-model": true}, + available: true, + priority: 1, + } + p2 := &mockProvider{ + name: "oauth-2", + providerType: ProviderTypeOAuth, + supportsModels: map[string]bool{"fallback-model": true}, + available: true, + priority: 2, + } + registry.Register(p1) + registry.Register(p2) + + // Create router with oauth aliases + cfg := &config.Config{ + OAuthModelAlias: map[string][]config.OAuthModelAlias{ + "test-channel": { + {Name: "primary-model", Alias: "fallback-model"}, + }, + }, + } + router := NewRouter(registry, cfg) + + // Resolve + decision := router.Resolve("primary-model") + + assert.Equal(t, "primary-model", decision.ResolvedModel) + assert.Len(t, decision.Candidates, 2) + // Primary should come first (lower priority value) + assert.Equal(t, "primary-model", decision.Candidates[0].Model) + assert.Equal(t, "fallback-model", decision.Candidates[1].Model) +} + +func TestRouter_Resolve_NoProviders(t *testing.T) { + registry := NewRegistry() + cfg := &config.Config{} + router := NewRouter(registry, cfg) + + decision := router.Resolve("unknown-model") + + assert.Equal(t, "unknown-model", decision.ResolvedModel) + assert.Empty(t, decision.Candidates) +} diff --git a/internal/translator/openai/claude/openai_claude_request.go b/internal/translator/openai/claude/openai_claude_request.go index dc832e9c..8fac14ec 100644 --- a/internal/translator/openai/claude/openai_claude_request.go +++ b/internal/translator/openai/claude/openai_claude_request.go @@ -61,10 +61,13 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream out, _ = sjson.Set(out, "stream", stream) // Thinking: Convert Claude thinking.budget_tokens to OpenAI reasoning_effort + // Also track if thinking is enabled to ensure reasoning_content is added for tool_calls + thinkingEnabled := false if thinkingConfig := root.Get("thinking"); thinkingConfig.Exists() && thinkingConfig.IsObject() { if thinkingType := thinkingConfig.Get("type"); thinkingType.Exists() { switch thinkingType.String() { case "enabled": + thinkingEnabled = true if budgetTokens := thinkingConfig.Get("budget_tokens"); budgetTokens.Exists() { budget := int(budgetTokens.Int()) if effort, ok := thinking.ConvertBudgetToLevel(budget); ok && effort != "" { @@ -217,6 +220,10 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream // Add reasoning_content if present if hasReasoning { msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", reasoningContent) + } else if thinkingEnabled && hasToolCalls { + // Claude API requires reasoning_content in assistant messages with tool_calls + // when thinking mode is enabled, even if empty + msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", "") } // Add tool_calls if present (in same message as content) diff --git a/internal/translator/openai/claude/openai_claude_request_test.go b/internal/translator/openai/claude/openai_claude_request_test.go index d08de1b2..3e7fe8fd 100644 --- a/internal/translator/openai/claude/openai_claude_request_test.go +++ b/internal/translator/openai/claude/openai_claude_request_test.go @@ -588,3 +588,124 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got) } } + +// TestConvertClaudeRequestToOpenAI_ThinkingEnabledToolCallsNoReasoning tests that +// when thinking mode is enabled and assistant message has tool_calls but no thinking content, +// an empty reasoning_content is added to satisfy Claude API requirements. +func TestConvertClaudeRequestToOpenAI_ThinkingEnabledToolCallsNoReasoning(t *testing.T) { + tests := []struct { + name string + inputJSON string + wantHasReasoningContent bool + wantReasoningContent string + }{ + { + name: "thinking enabled with tool_calls but no thinking content adds empty reasoning_content", + inputJSON: `{ + "model": "claude-3-opus", + "thinking": {"type": "enabled", "budget_tokens": 4000}, + "messages": [{ + "role": "assistant", + "content": [ + {"type": "text", "text": "I will help you."}, + {"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}} + ] + }] + }`, + wantHasReasoningContent: true, + wantReasoningContent: "", + }, + { + name: "thinking enabled with tool_calls and thinking content uses actual reasoning", + inputJSON: `{ + "model": "claude-3-opus", + "thinking": {"type": "enabled", "budget_tokens": 4000}, + "messages": [{ + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "Let me analyze this..."}, + {"type": "text", "text": "I will help you."}, + {"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}} + ] + }] + }`, + wantHasReasoningContent: true, + wantReasoningContent: "Let me analyze this...", + }, + { + name: "thinking disabled with tool_calls does not add reasoning_content", + inputJSON: `{ + "model": "claude-3-opus", + "thinking": {"type": "disabled"}, + "messages": [{ + "role": "assistant", + "content": [ + {"type": "text", "text": "I will help you."}, + {"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}} + ] + }] + }`, + wantHasReasoningContent: false, + wantReasoningContent: "", + }, + { + name: "no thinking config with tool_calls does not add reasoning_content", + inputJSON: `{ + "model": "claude-3-opus", + "messages": [{ + "role": "assistant", + "content": [ + {"type": "text", "text": "I will help you."}, + {"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}} + ] + }] + }`, + wantHasReasoningContent: false, + wantReasoningContent: "", + }, + { + name: "thinking enabled without tool_calls and no thinking content does not add reasoning_content", + inputJSON: `{ + "model": "claude-3-opus", + "thinking": {"type": "enabled", "budget_tokens": 4000}, + "messages": [{ + "role": "assistant", + "content": [ + {"type": "text", "text": "Simple response without tools."} + ] + }] + }`, + wantHasReasoningContent: false, + wantReasoningContent: "", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false) + resultJSON := gjson.ParseBytes(result) + + messages := resultJSON.Get("messages").Array() + if len(messages) == 0 { + t.Fatal("Expected at least one message") + } + + assistantMsg := messages[0] + if assistantMsg.Get("role").String() != "assistant" { + t.Fatalf("Expected assistant message, got %s", assistantMsg.Get("role").String()) + } + + hasReasoningContent := assistantMsg.Get("reasoning_content").Exists() + if hasReasoningContent != tt.wantHasReasoningContent { + t.Errorf("reasoning_content existence = %v, want %v", hasReasoningContent, tt.wantHasReasoningContent) + } + + if hasReasoningContent { + gotReasoningContent := assistantMsg.Get("reasoning_content").String() + if gotReasoningContent != tt.wantReasoningContent { + t.Errorf("reasoning_content = %q, want %q", gotReasoningContent, tt.wantReasoningContent) + } + } + }) + } +} diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 85657e12..ac76dd9b 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -255,16 +255,15 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c * parentCtx = logging.WithRequestID(parentCtx, requestID) } } - newCtx, cancel := context.WithCancel(parentCtx) - if requestCtx != nil && requestCtx != parentCtx { - go func() { - select { - case <-requestCtx.Done(): - cancel() - case <-newCtx.Done(): - } - }() + + // Use requestCtx as base if available to preserve amp context values (fallback_models, etc.) + // Falls back to parentCtx if no request context + baseCtx := parentCtx + if requestCtx != nil { + baseCtx = requestCtx } + + newCtx, cancel := context.WithCancel(baseCtx) newCtx = context.WithValue(newCtx, "gin", c) newCtx = context.WithValue(newCtx, "handler", handler) return newCtx, func(params ...interface{}) {