Compare commits

..

18 Commits

Author SHA1 Message Date
Luis Pater
4874253d1e Merge pull request #1425 from router-for-me/auth
fix(cliproxy): update auth before model registration
2026-02-04 15:01:01 +08:00
Luis Pater
b72250349f Merge pull request #1423 from router-for-me/watcher
feat(watcher): log auth field changes on reload
2026-02-04 15:00:38 +08:00
hkfires
116573311f fix(cliproxy): update auth before model registration 2026-02-04 14:03:15 +08:00
hkfires
4af712544d feat(watcher): log auth field changes on reload
Cache parsed auth contents and compute redacted diffs for prefix, proxy_url,
and disabled when auth files are added or updated.
2026-02-04 12:29:56 +08:00
Luis Pater
1548c567ab feat(pprof): add support for configurable pprof HTTP debug server
- Introduced a new `pprof` server to enable/debug HTTP profiling.
- Added configuration options for enabling/disabling and specifying the server address.
- Integrated pprof server lifecycle management with `Service`.

#1287
2026-02-04 02:39:26 +08:00
Luis Pater
5b23fc570c Merge pull request #1396 from Xm798/fix/log-dir-tilde-expansion
fix(logging): expand tilde in auth-dir path for log directory
2026-02-04 02:00:13 +08:00
Luis Pater
04e1c7a05a docs: reorganize and update README entries for CLIProxyAPI projects 2026-02-04 01:49:27 +08:00
Luis Pater
9181e72204 Merge pull request #1409 from wangdabaoqq/main
docs: Add a new client application - Lin Jun
2026-02-04 01:47:31 +08:00
宝宝宝
4939865f6d Add a new client application - Lin Jun 2026-02-03 23:55:24 +08:00
宝宝宝
3da7f7482e Add a new client application - Lin Jun 2026-02-03 23:36:34 +08:00
宝宝宝
9072b029b2 Add a new client application - Lin Jun 2026-02-03 23:35:53 +08:00
宝宝宝
c296cfb8c0 docs: Add a new client application - Lin Jun 2026-02-03 23:32:50 +08:00
Luis Pater
2707377fcb docs: add AICodeMirror sponsorship details to README files 2026-02-03 22:34:50 +08:00
Luis Pater
259f586ff7 Fixed: #1398
fix(translator): use model group caching for client signature validation
2026-02-03 22:04:52 +08:00
Luis Pater
d885b81f23 Fixed: #1403
fix(translator): handle "input" field transformation for OpenAI responses
2026-02-03 21:49:30 +08:00
Luis Pater
fe6bffd080 fixed: #1407
fix(translator): adjust "developer" role to "user" and ignore unsupported tool types
2026-02-03 21:41:17 +08:00
Luis Pater
250f212fa3 fix(executor): handle "global" location in AI platform URL generation 2026-02-03 01:39:57 +08:00
Cyrus
a275db3fdb fix(logging): expand tilde in auth-dir and log resolution errors
- Use util.ResolveAuthDir to properly expand ~ to user home directory
- Fixes issue where logs were created in literal "~/.cli-proxy-api" folder
- Add warning log when auth-dir resolution fails for debugging

Bug introduced in 62e2b67 (refactor(logging): centralize log directory
resolution logic), where strings.TrimSpace was used instead of
util.ResolveAuthDir to process auth-dir path.
2026-02-03 00:02:54 +08:00
48 changed files with 664 additions and 4032 deletions

View File

@@ -30,6 +30,10 @@ Get 10% OFF GLM CODING PLANhttps://z.ai/subscribe?ic=8JVLJQFSKB
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
<td>Thanks to Cubence for sponsoring this project! Cubence is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. Cubence provides special discounts for our software users: register using <a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">this link</a> and enter the "CLIPROXYAPI" promo code during recharge to get 10% off.</td>
</tr>
<tr>
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via <a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">this link</a> to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
</tr>
</tbody>
</table>
@@ -142,6 +146,10 @@ A lightweight web admin panel for CLIProxyAPI with health checks, resource monit
A Windows tray application implemented using PowerShell scripts, without relying on any third-party libraries. The main features include: automatic creation of shortcuts, silent running, password management, channel switching (Main / Plus), and automatic downloading and updating.
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君 is a cross-platform desktop application for managing AI programming assistants, supporting macOS, Windows, and Linux systems. Unified management of Claude Code, Gemini CLI, OpenAI Codex, Qwen Code, and other AI coding tools, with local proxy for multi-account quota tracking and one-click configuration.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.

View File

@@ -30,6 +30,10 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐每月最低仅需20元
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
<td>感谢 Cubence 对本项目的赞助Cubence 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。Cubence 为本软件用户提供了特别优惠:使用<a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">此链接</a>注册,并在充值时输入 "CLIPROXYAPI" 优惠码即可享受九折优惠。</td>
</tr>
<tr>
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
<td>感谢 AICodeMirror 赞助了本项目AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折充值更有折上折AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">此链接</a>注册的用户可享受首充8折企业客户最高可享 7.5 折!</td>
</tr>
</tbody>
</table>
@@ -137,6 +141,14 @@ Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI
面向 CLIProxyAPI 的 Web 管理面板,提供健康检查、资源监控、日志查看、自动更新、请求统计与定价展示,支持一键安装与 systemd 服务。
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
Windows 托盘应用,基于 PowerShell 脚本实现不依赖任何第三方库。主要功能包括自动创建快捷方式、静默运行、密码管理、通道切换Main / Plus以及自动下载与更新。
### [霖君](https://github.com/wangdabaoqq/LinJun)
霖君是一款用于管理AI编程助手的跨平台桌面应用支持macOS、Windows、Linux系统。统一管理Claude Code、Gemini CLI、OpenAI Codex、Qwen Code等AI编程工具本地代理实现多账户配额跟踪和一键配置。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR拉取请求将其添加到此列表中。
@@ -148,10 +160,6 @@ Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI
基于 Next.js 的实现,灵感来自 CLIProxyAPI易于安装使用自研格式转换OpenAI/Claude/Gemini/Ollama、组合系统与自动回退、多账户管理指数退避、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
Windows 托盘应用,基于 PowerShell 脚本实现不依赖任何第三方库。主要功能包括自动创建快捷方式、静默运行、密码管理、通道切换Main / Plus以及自动下载与更新。
> [!NOTE]
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。

BIN
assets/aicodemirror.png Normal file

Binary file not shown.

After

Width:  |  Height:  |  Size: 45 KiB

View File

@@ -40,6 +40,11 @@ api-keys:
# Enable debug logging
debug: false
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
pprof:
enable: false
addr: "127.0.0.1:8316"
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false

View File

@@ -125,8 +125,6 @@ func (m *AmpModule) Register(ctx modules.Context) error {
m.registerOnce.Do(func() {
// Initialize model mapper from config (for routing unavailable models to alternatives)
m.modelMapper = NewModelMapper(settings.ModelMappings)
// Load oauth-model-alias for provider lookup via aliases
m.modelMapper.UpdateOAuthModelAlias(ctx.Config.OAuthModelAlias)
// Store initial config for partial reload comparison
settingsCopy := settings
@@ -214,11 +212,6 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
}
}
// Always update oauth-model-alias for model mapper (used for provider lookup)
if m.modelMapper != nil {
m.modelMapper.UpdateOAuthModelAlias(cfg.OAuthModelAlias)
}
if m.enabled {
// Check upstream URL change - now supports hot-reload
if newUpstreamURL == "" && oldUpstreamURL != "" {

View File

@@ -2,15 +2,12 @@ package amp
import (
"bytes"
"errors"
"io"
"net/http"
"net/http/httputil"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
@@ -33,13 +30,7 @@ const (
)
// MappedModelContextKey is the Gin context key for passing mapped model names.
// Deprecated: Use ctxkeys.MappedModel instead.
const MappedModelContextKey = string(ctxkeys.MappedModel)
// FallbackModelsContextKey is the Gin context key for passing fallback model names.
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
// Deprecated: Use ctxkeys.FallbackModels instead.
const FallbackModelsContextKey = string(ctxkeys.FallbackModels)
const MappedModelContextKey = "mapped_model"
// logAmpRouting logs the routing decision for an Amp request with structured fields
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
@@ -86,10 +77,6 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI
//
// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper.
// Use routing.NewModelRoutingWrapper() instead for unified routing logic.
// This type is kept for backward compatibility and test purposes.
type FallbackHandler struct {
getProxy func() *httputil.ReverseProxy
modelMapper ModelMapper
@@ -98,8 +85,6 @@ type FallbackHandler struct {
// NewFallbackHandler creates a new fallback handler wrapper
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
//
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
return &FallbackHandler{
getProxy: getProxy,
@@ -108,8 +93,6 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler
}
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
//
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
if forceModelMappings == nil {
forceModelMappings = func() bool { return false }
@@ -130,20 +113,6 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
// Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces.
// ReverseProxy raises this panic when the client connection is closed prematurely
// (e.g., user cancels request, network disconnect) or when ServeHTTP is called
// with a ResponseWriter that doesn't implement http.CloseNotifier.
// This is an expected error condition, not a bug, so we handle it gracefully.
defer func() {
if rec := recover(); rec != nil {
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
return
}
panic(rec)
}
}()
requestPath := c.Request.URL.Path
// Read the request body to extract the model name
@@ -173,57 +142,36 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
}
// resolveMappedModels returns all mapped models (primary + fallbacks) and providers for the first one.
resolveMappedModels := func() ([]string, []string) {
resolveMappedModel := func() (string, []string) {
if fh.modelMapper == nil {
return nil, nil
return "", nil
}
mapper, ok := fh.modelMapper.(*DefaultModelMapper)
if !ok {
// Fallback to single model for non-DefaultModelMapper
mappedModel := fh.modelMapper.MapModel(modelName)
if mappedModel == "" {
mappedModel = fh.modelMapper.MapModel(normalizedModel)
}
if mappedModel == "" {
return nil, nil
}
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
mappedProviders := util.GetProviderName(mappedBaseModel)
if len(mappedProviders) == 0 {
return nil, nil
}
return []string{mappedModel}, mappedProviders
mappedModel := fh.modelMapper.MapModel(modelName)
if mappedModel == "" {
mappedModel = fh.modelMapper.MapModel(normalizedModel)
}
mappedModel = strings.TrimSpace(mappedModel)
if mappedModel == "" {
return "", nil
}
// Use MapModelWithFallbacks for DefaultModelMapper
mappedModels := mapper.MapModelWithFallbacks(modelName)
if len(mappedModels) == 0 {
mappedModels = mapper.MapModelWithFallbacks(normalizedModel)
}
if len(mappedModels) == 0 {
return nil, nil
}
// Apply thinking suffix if needed
for i, model := range mappedModels {
if thinkingSuffix != "" {
suffixResult := thinking.ParseSuffix(model)
if !suffixResult.HasSuffix {
mappedModels[i] = model + thinkingSuffix
}
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
// already specifies its own thinking suffix.
if thinkingSuffix != "" {
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
if !mappedSuffixResult.HasSuffix {
mappedModel += thinkingSuffix
}
}
// Get providers for the first model
firstBaseModel := thinking.ParseSuffix(mappedModels[0]).ModelName
providers := util.GetProviderName(firstBaseModel)
if len(providers) == 0 {
return nil, nil
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
mappedProviders := util.GetProviderName(mappedBaseModel)
if len(mappedProviders) == 0 {
return "", nil
}
return mappedModels, providers
return mappedModel, mappedProviders
}
// Track resolved model for logging (may change if mapping is applied)
@@ -231,27 +179,21 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
usedMapping := false
var providers []string
// Helper to apply model mapping and update state
applyMapping := func(mappedModels []string, mappedProviders []string) {
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
if len(mappedModels) > 1 {
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
}
resolvedModel = mappedModels[0]
usedMapping = true
providers = mappedProviders
}
// Check if model mappings should be forced ahead of local API keys
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
if forceMappings {
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
// This allows users to route Amp requests to their preferred OAuth providers
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
applyMapping(mappedModels, mappedProviders)
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
// Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge)
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true
providers = mappedProviders
}
// If no mapping applied, check for local providers
@@ -264,8 +206,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
if len(providers) == 0 {
// No providers configured - check if we have a model mapping
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
applyMapping(mappedModels, mappedProviders)
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
// Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge)
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true
providers = mappedProviders
}
}
}

View File

@@ -1,326 +0,0 @@
package amp
import (
"bytes"
"net/http"
"net/http/httptest"
"net/http/httputil"
"net/url"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/testutil"
"github.com/stretchr/testify/assert"
)
// Characterization tests for fallback_handlers.go using testutil recorders
// These tests capture existing behavior before refactoring to routing layer
func TestCharacterization_LocalProvider(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the test model
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-local", "anthropic", []*registry.ModelInfo{
{ID: "test-model-local"},
})
defer reg.UnregisterClient("char-test-local")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create gin context
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `{"model": "test-model-local", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler with proxy recorder
// Create a test server to act as the proxy target
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
// Create a reverse proxy that forwards to our test server
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Execute
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
wrapped(c)
// Assert: proxy NOT called
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
// Assert: local handler called once
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
// Assert: request body model unchanged
assert.Contains(t, string(handlerRecorder.RequestBody), "test-model-local", "request body model should be unchanged")
}
func TestCharacterization_ModelMapping(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the TARGET model (the mapped-to model)
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-mapped", "openai", []*registry.ModelInfo{
{ID: "gpt-4-local"},
})
defer reg.UnregisterClient("char-test-mapped")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create model mapper with a mapping
mapper := NewModelMapper([]config.AmpModelMapping{
{From: "gpt-4-turbo", To: "gpt-4-local"},
})
// Create gin context
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Request with original model that gets mapped
body := `{"model": "gpt-4-turbo", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler with mapper
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
}, mapper, func() bool { return false })
// Execute - use handler that returns model in response for rewriter to work
wrapped := fh.WrapHandler(handlerRecorder.GinHandlerWithModel())
wrapped(c)
// Assert: proxy NOT called
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for model mapping")
// Assert: local handler called once
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
// Assert: request body model was rewritten to mapped model
assert.Contains(t, string(handlerRecorder.RequestBody), "gpt-4-local", "request body model should be rewritten to mapped model")
assert.NotContains(t, string(handlerRecorder.RequestBody), "gpt-4-turbo", "request body should NOT contain original model")
// Assert: context has mapped_model key set
mappedModel, exists := handlerRecorder.GetContextKey("mapped_model")
assert.True(t, exists, "context should have mapped_model key")
assert.Equal(t, "gpt-4-local", mappedModel, "mapped_model should be the target model")
// Assert: response body model rewritten back to original
// The response writer should rewrite model names in the response
responseBody := w.Body.String()
assert.Contains(t, responseBody, "gpt-4-turbo", "response should have original model name")
}
func TestCharacterization_AmpCreditsProxy(t *testing.T) {
gin.SetMode(gin.TestMode)
// Setup recorders - NO local provider registered, NO mapping configured
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create gin context with CloseNotifier support (required for ReverseProxy)
w := testutil.NewCloseNotifierRecorder()
c, _ := gin.CreateTestContext(w)
// Request with a model that has no local provider and no mapping
body := `{"model": "unknown-model-no-provider", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Execute
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
wrapped(c)
// Assert: proxy called once
assert.True(t, proxyRecorder.Called, "proxy should be called when no local provider and no mapping")
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
// Assert: local handler NOT called
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called when falling back to proxy")
// Assert: body forwarded to proxy is original (no rewrite)
assert.Contains(t, string(proxyRecorder.RequestBody), "unknown-model-no-provider", "request body model should be unchanged when proxying")
}
func TestCharacterization_BodyRestore(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the test model
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-body", "anthropic", []*registry.ModelInfo{
{ID: "test-model-body"},
})
defer reg.UnregisterClient("char-test-body")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create gin context
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Create a complex request body that will be read by the wrapper for model extraction
originalBody := `{"model": "test-model-body", "messages": [{"role": "user", "content": "hello"}], "temperature": 0.7, "stream": true}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(originalBody)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Create fallback handler with proxy recorder
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Execute
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
wrapped(c)
// Assert: local handler called (not proxy, since we have a local provider)
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
// Assert: handler receives complete original body
// This verifies that the body was properly restored after the wrapper read it for model extraction
assert.Equal(t, originalBody, string(handlerRecorder.RequestBody), "handler should receive complete original body after wrapper reads it for model extraction")
}
// TestCharacterization_GeminiV1Beta1_PostModels tests that POST requests with /models/ path use Gemini bridge handler
// This is a characterization test for the route gating logic in routes.go
func TestCharacterization_GeminiV1Beta1_PostModels(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the test model (Gemini format uses path-based model extraction)
reg := registry.GetGlobalRegistry()
reg.RegisterClient("char-test-gemini", "google", []*registry.ModelInfo{
{ID: "gemini-pro"},
})
defer reg.UnregisterClient("char-test-gemini")
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create a test server for the proxy
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
// Create fallback handler
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Create the Gemini bridge handler (simulating what routes.go does)
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
// Create router with the same gating logic as routes.go
r := gin.New()
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
if c.Request.Method == "POST" {
if path := c.Param("path"); strings.Contains(path, "/models/") {
// POST with /models/ path -> use Gemini bridge with fallback handler
geminiV1Beta1Handler(c)
return
}
}
// Non-POST or no /models/ in path -> proxy upstream
proxyRecorder.ServeHTTP(c.Writer, c.Request)
})
// Execute: POST request with /models/ in path
body := `{"contents": [{"role": "user", "parts": [{"text": "hello"}]}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro:generateContent", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
// Assert: local Gemini handler called
assert.True(t, handlerRecorder.WasCalled(), "local Gemini handler should be called for POST /models/")
// Assert: proxy NOT called
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for POST /models/ path")
}
// TestCharacterization_GeminiV1Beta1_GetProxies tests that GET requests to Gemini v1beta1 always use proxy
// This is a characterization test for the route gating logic in routes.go
func TestCharacterization_GeminiV1Beta1_GetProxies(t *testing.T) {
gin.SetMode(gin.TestMode)
// Setup recorders
proxyRecorder := testutil.NewFakeProxyRecorder()
handlerRecorder := testutil.NewFakeHandlerRecorder()
// Create a test server for the proxy
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
defer proxyServer.Close()
// Create fallback handler
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
targetURL, _ := url.Parse(proxyServer.URL)
return httputil.NewSingleHostReverseProxy(targetURL)
})
// Create the Gemini bridge handler
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
// Create router with the same gating logic as routes.go
r := gin.New()
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
if c.Request.Method == "POST" {
if path := c.Param("path"); strings.Contains(path, "/models/") {
geminiV1Beta1Handler(c)
return
}
}
proxyRecorder.ServeHTTP(c.Writer, c.Request)
})
// Execute: GET request (even with /models/ in path)
req := httptest.NewRequest(http.MethodGet, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
// Assert: proxy called
assert.True(t, proxyRecorder.Called, "proxy should be called for GET requests")
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
// Assert: local handler NOT called
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called for GET requests")
}

View File

@@ -2,7 +2,7 @@ package amp
import (
"bytes"
"io"
"encoding/json"
"net/http"
"net/http/httptest"
"net/http/httputil"
@@ -11,138 +11,63 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/stretchr/testify/assert"
)
// Characterization tests for fallback_handlers.go
// These tests capture existing behavior before refactoring to routing layer
func TestFallbackHandler_WrapHandler_LocalProvider_NoMapping(t *testing.T) {
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
// Setup: model that has local providers (gemini-2.5-pro is registered)
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `{"model": "gemini-2.5-pro", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Handler that should be called (not proxy)
handlerCalled := false
handler := func(c *gin.Context) {
handlerCalled = true
c.JSON(200, gin.H{"status": "ok"})
}
// Create fallback handler
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
return nil // no proxy
})
// Execute
wrapped := fh.WrapHandler(handler)
wrapped(c)
// Assert: handler should be called directly (no mapping needed)
assert.True(t, handlerCalled, "handler should be called for local provider")
assert.Equal(t, 200, w.Code)
}
func TestFallbackHandler_WrapHandler_MappingApplied(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the target model
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client", "anthropic", []*registry.ModelInfo{
{ID: "claude-opus-4-5-thinking"},
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
})
// Setup: model that needs mapping
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
body := `{"model": "claude-opus-4-5-20251101", "messages": [{"role": "user", "content": "hello"}]}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
// Handler to capture rewritten body
var capturedBody []byte
handler := func(c *gin.Context) {
capturedBody, _ = io.ReadAll(c.Request.Body)
c.JSON(200, gin.H{"status": "ok"})
}
// Create fallback handler with mapper
mapper := NewModelMapper([]config.AmpModelMapping{
{From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"},
})
fh := NewFallbackHandlerWithMapper(
func() *httputil.ReverseProxy { return nil },
mapper,
func() bool { return false },
)
// Execute
wrapped := fh.WrapHandler(handler)
wrapped(c)
// Assert: body should be rewritten
assert.Contains(t, string(capturedBody), "claude-opus-4-5-thinking")
// Assert: context should have mapped model
mappedModel, exists := c.Get(MappedModelContextKey)
assert.True(t, exists, "MappedModelContextKey should be set")
assert.NotEmpty(t, mappedModel)
}
func TestFallbackHandler_WrapHandler_ThinkingSuffixPreserved(t *testing.T) {
gin.SetMode(gin.TestMode)
// Register a mock provider for the target model
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-2", "anthropic", []*registry.ModelInfo{
{ID: "claude-opus-4-5-thinking"},
})
w := httptest.NewRecorder()
c, _ := gin.CreateTestContext(w)
// Model with thinking suffix
body := `{"model": "claude-opus-4-5-20251101(xhigh)", "messages": []}`
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
req.Header.Set("Content-Type", "application/json")
c.Request = req
var capturedBody []byte
handler := func(c *gin.Context) {
capturedBody, _ = io.ReadAll(c.Request.Body)
c.JSON(200, gin.H{"status": "ok"})
}
defer reg.UnregisterClient("test-client-amp-fallback")
mapper := NewModelMapper([]config.AmpModelMapping{
{From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"},
{From: "gpt-5.2", To: "test/gpt-5.2"},
})
fh := NewFallbackHandlerWithMapper(
func() *httputil.ReverseProxy { return nil },
mapper,
func() bool { return false },
)
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
wrapped := fh.WrapHandler(handler)
wrapped(c)
handler := func(c *gin.Context) {
var req struct {
Model string `json:"model"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Assert: thinking suffix should be preserved
assert.Contains(t, string(capturedBody), "(xhigh)")
}
func TestFallbackHandler_WrapHandler_NoProvider_NoMapping_ProxyEnabled(t *testing.T) {
// Skip: httptest.ResponseRecorder doesn't implement http.CloseNotifier
// which is required by httputil.ReverseProxy. This test requires a real
// HTTP server and client to properly test proxy behavior.
t.Skip("requires real HTTP server for proxy testing")
c.JSON(http.StatusOK, gin.H{
"model": req.Model,
"seen_model": req.Model,
})
}
r := gin.New()
r.POST("/chat/completions", fallback.WrapHandler(handler))
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected status 200, got %d", w.Code)
}
var resp struct {
Model string `json:"model"`
SeenModel string `json:"seen_model"`
}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("Failed to parse response JSON: %v", err)
}
if resp.Model != "gpt-5.2(xhigh)" {
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
}
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
}
}

View File

@@ -30,98 +30,18 @@ type DefaultModelMapper struct {
mu sync.RWMutex
mappings map[string]string // exact: from -> to (normalized lowercase keys)
regexps []regexMapping // regex rules evaluated in order
// oauthAliasForward maps channel -> name (lower) -> []alias for oauth-model-alias lookup.
// This allows model-mappings targets to find providers via their aliases.
oauthAliasForward map[string]map[string][]string
}
// NewModelMapper creates a new model mapper with the given initial mappings.
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
m := &DefaultModelMapper{
mappings: make(map[string]string),
regexps: nil,
oauthAliasForward: nil,
mappings: make(map[string]string),
regexps: nil,
}
m.UpdateMappings(mappings)
return m
}
// UpdateOAuthModelAlias updates the oauth-model-alias lookup table.
// This is called during initialization and on config hot-reload.
func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.OAuthModelAlias) {
m.mu.Lock()
defer m.mu.Unlock()
if len(aliases) == 0 {
m.oauthAliasForward = nil
return
}
forward := make(map[string]map[string][]string, len(aliases))
for rawChannel, entries := range aliases {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" || len(entries) == 0 {
continue
}
channelMap := make(map[string][]string)
for _, entry := range entries {
name := strings.TrimSpace(entry.Name)
alias := strings.TrimSpace(entry.Alias)
if name == "" || alias == "" {
continue
}
if strings.EqualFold(name, alias) {
continue
}
nameKey := strings.ToLower(name)
channelMap[nameKey] = append(channelMap[nameKey], alias)
}
if len(channelMap) > 0 {
forward[channel] = channelMap
}
}
if len(forward) == 0 {
m.oauthAliasForward = nil
return
}
m.oauthAliasForward = forward
log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward))
}
// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel
// that have available providers. Useful for fallback when one alias is quota-exceeded.
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
if m.oauthAliasForward == nil {
return nil
}
targetKey := strings.ToLower(strings.TrimSpace(targetModel))
if targetKey == "" {
return nil
}
var result []string
seen := make(map[string]struct{})
// Check all channels for this model name
for _, channelMap := range m.oauthAliasForward {
aliases := channelMap[targetKey]
for _, alias := range aliases {
aliasLower := strings.ToLower(alias)
if _, exists := seen[aliasLower]; exists {
continue
}
providers := util.GetProviderName(alias)
if len(providers) > 0 {
result = append(result, alias)
seen[aliasLower] = struct{}{}
}
}
}
return result
}
// MapModel checks if a mapping exists for the requested model and if the
// target model has available local providers. Returns the mapped model name
// or empty string if no valid mapping exists.
@@ -131,19 +51,8 @@ func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []s
// However, if the mapping target already contains a suffix, the config suffix
// takes priority over the user's suffix.
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
models := m.MapModelWithFallbacks(requestedModel)
if len(models) == 0 {
return ""
}
return models[0]
}
// MapModelWithFallbacks returns all possible target models for the requested model,
// including fallback aliases from oauth-model-alias. The first model is the primary target,
// and subsequent models are fallbacks to try if the primary is unavailable (e.g., quota exceeded).
func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []string {
if requestedModel == "" {
return nil
return ""
}
m.mu.RLock()
@@ -169,54 +78,34 @@ func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []stri
}
}
if !exists {
return nil
return ""
}
}
// Check if target model already has a thinking suffix (config priority)
targetResult := thinking.ParseSuffix(targetModel)
targetBase := targetResult.ModelName
// Helper to apply suffix to a model
applySuffix := func(model string) string {
modelResult := thinking.ParseSuffix(model)
if modelResult.HasSuffix {
return model
}
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return model + "(" + requestResult.RawSuffix + ")"
}
return model
}
// Verify target model has available providers (use base model for lookup)
providers := util.GetProviderName(targetBase)
// If direct provider available, return it as primary
if len(providers) > 0 {
return []string{applySuffix(targetModel)}
}
// No direct providers - check oauth-model-alias for all aliases that have providers
allAliases := m.findAllAliasesWithProviders(targetBase)
if len(allAliases) == 0 {
providers := util.GetProviderName(targetResult.ModelName)
if len(providers) == 0 {
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
return nil
return ""
}
// Log resolution
if len(allAliases) == 1 {
log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
} else {
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases)-1)
// Suffix handling: config suffix takes priority, otherwise preserve user suffix
if targetResult.HasSuffix {
// Config's "to" already contains a suffix - use it as-is (config priority)
return targetModel
}
// Apply suffix to all aliases
result := make([]string, len(allAliases))
for i, alias := range allAliases {
result[i] = applySuffix(alias)
// Preserve user's thinking suffix on the mapped model
// (skip empty suffixes to avoid returning "model()")
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
return targetModel + "(" + requestResult.RawSuffix + ")"
}
return result
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
return targetModel
}
// UpdateMappings refreshes the mapping configuration from config.
@@ -276,22 +165,6 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
return result
}
// GetMappingsAsConfig returns the current model mappings as config.AmpModelMapping slice.
// Safe for concurrent use.
func (m *DefaultModelMapper) GetMappingsAsConfig() []config.AmpModelMapping {
m.mu.RLock()
defer m.mu.RUnlock()
result := make([]config.AmpModelMapping, 0, len(m.mappings))
for from, to := range m.mappings {
result = append(result, config.AmpModelMapping{
From: from,
To: to,
})
}
return result
}
type regexMapping struct {
re *regexp.Regexp
to string

View File

@@ -5,12 +5,11 @@ import (
"errors"
"net"
"net/http"
"net/http/httputil"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
@@ -235,20 +234,19 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
// If no local OAuth is available, falls back to ampcode.com proxy.
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy()
}, m.modelMapper, m.forceModelMappings)
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
// T-025: Migrated Gemini v1beta1 bridge to use ModelRoutingWrapper
// Create a dedicated routing wrapper for the Gemini bridge
geminiBridgeWrapper := m.createModelRoutingWrapper()
geminiV1Beta1Handler := geminiBridgeWrapper.Wrap(geminiBridge)
// Route POST model calls through Gemini bridge with ModelRoutingWrapper.
// ModelRoutingWrapper checks provider -> mapping -> proxy fallback automatically.
// Route POST model calls through Gemini bridge with FallbackHandler.
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
if c.Request.Method == "POST" {
if path := c.Param("path"); strings.Contains(path, "/models/") {
// POST with /models/ path -> use Gemini bridge with unified routing wrapper
// ModelRoutingWrapper will check provider/mapping and proxy if needed
// POST with /models/ path -> use Gemini bridge with fallback handler
// FallbackHandler will check provider/mapping and proxy if needed
geminiV1Beta1Handler(c)
return
}
@@ -258,41 +256,6 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
})
}
// createModelRoutingWrapper creates a new ModelRoutingWrapper for unified routing.
// This is used for testing the new routing implementation (T-021 onwards).
func (m *AmpModule) createModelRoutingWrapper() *routing.ModelRoutingWrapper {
// Create a registry - in production this would be populated with actual providers
registry := routing.NewRegistry()
// Create a minimal config with just AmpCode settings
// The Router only needs AmpCode.ModelMappings and OAuthModelAlias
cfg := &config.Config{
AmpCode: func() config.AmpCode {
if m.modelMapper != nil {
return config.AmpCode{
ModelMappings: m.modelMapper.GetMappingsAsConfig(),
}
}
return config.AmpCode{}
}(),
}
// Create router with registry and config
router := routing.NewRouter(registry, cfg)
// Create wrapper with proxy function
proxyFunc := func(c *gin.Context) {
proxy := m.getProxy()
if proxy != nil {
proxy.ServeHTTP(c.Writer, c.Request)
} else {
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
}
}
return routing.NewModelRoutingWrapper(router, nil, nil, proxyFunc)
}
// registerProviderAliases registers /api/provider/{provider}/... routes
// These allow Amp CLI to route requests like:
//
@@ -306,9 +269,12 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
// Create unified routing wrapper (T-021 onwards)
// Replaces FallbackHandler with Router-based unified routing
routingWrapper := m.createModelRoutingWrapper()
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
// Also includes model mapping support for routing unavailable models to alternatives
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy()
}, m.modelMapper, m.forceModelMappings)
// Provider-specific routes under /api/provider/:provider
ampProviders := engine.Group("/api/provider")
@@ -336,36 +302,33 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
}
// Root-level routes (for providers that omit /v1, like groq/cerebras)
// T-022: Migrated all OpenAI routes to use ModelRoutingWrapper for unified routing
// Wrap handlers with fallback logic to forward to ampcode.com when provider not found
provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
provider.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
provider.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
provider.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
// /v1 routes (OpenAI/Claude-compatible endpoints)
v1Amp := provider.Group("/v1")
{
v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
// OpenAI-compatible endpoints with ModelRoutingWrapper
// T-021, T-022: Migrated to unified routing wrapper
v1Amp.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
v1Amp.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
v1Amp.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
// OpenAI-compatible endpoints with fallback
v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
// Claude/Anthropic-compatible endpoints with ModelRoutingWrapper
// T-023: Migrated Claude routes to unified routing wrapper
v1Amp.POST("/messages", routingWrapper.Wrap(claudeCodeHandlers.ClaudeMessages))
v1Amp.POST("/messages/count_tokens", routingWrapper.Wrap(claudeCodeHandlers.ClaudeCountTokens))
// Claude/Anthropic-compatible endpoints with fallback
v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages))
v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens))
}
// /v1beta routes (Gemini native API)
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
// T-024: Migrated Gemini v1beta routes to unified routing wrapper
v1betaAmp := provider.Group("/v1beta")
{
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
v1betaAmp.POST("/models/*action", routingWrapper.Wrap(geminiHandlers.GeminiHandler))
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
}
}

View File

@@ -960,8 +960,8 @@ func (s *Server) UpdateClients(cfg *config.Config) {
s.mgmt.SetAuthManager(s.handlers.AuthManager)
}
// Notify Amp module when Amp config or OAuth model aliases have changed.
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode) || !reflect.DeepEqual(oldCfg.OAuthModelAlias, cfg.OAuthModelAlias)
// Notify Amp module only when Amp config has changed.
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode)
if ampConfigChanged {
if s.ampModule != nil {
log.Debugf("triggering amp module config update")

View File

@@ -6,8 +6,6 @@ import (
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
// SignatureEntry holds a cached thinking signature with timestamp
@@ -186,7 +184,6 @@ func HasValidSignature(modelName, signature string) bool {
}
func GetModelGroup(modelName string) string {
// Fast path: check model name patterns first
if strings.Contains(modelName, "gpt") {
return "gpt"
} else if strings.Contains(modelName, "claude") {
@@ -194,21 +191,5 @@ func GetModelGroup(modelName string) string {
} else if strings.Contains(modelName, "gemini") {
return "gemini"
}
// Slow path: check registry for provider-based grouping
// This handles models registered via claude-api-key, gemini-api-key, etc.
// that don't have provider name in their model name (e.g., kimi-k2.5 via claude-api-key)
if providers := registry.GetGlobalRegistry().GetModelProviders(modelName); len(providers) > 0 {
provider := strings.ToLower(providers[0])
switch provider {
case "claude":
return "claude"
case "gemini", "gemini-cli", "aistudio", "vertex", "antigravity":
return "gemini"
case "codex":
return "gpt"
}
}
return modelName
}

View File

@@ -208,84 +208,3 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
// but the logic is verified by the implementation
_ = time.Now() // Acknowledge we're not testing time passage
}
// === GetModelGroup Tests ===
// These tests verify that GetModelGroup correctly identifies model groups
// both by name pattern (fast path) and by registry provider lookup (slow path).
func TestGetModelGroup_ByNamePattern(t *testing.T) {
tests := []struct {
modelName string
expectedGroup string
}{
{"gpt-4o", "gpt"},
{"gpt-4-turbo", "gpt"},
{"claude-sonnet-4-20250514", "claude"},
{"claude-opus-4-5-thinking", "claude"},
{"gemini-2.5-pro", "gemini"},
{"gemini-3-pro-preview", "gemini"},
}
for _, tt := range tests {
t.Run(tt.modelName, func(t *testing.T) {
result := GetModelGroup(tt.modelName)
if result != tt.expectedGroup {
t.Errorf("GetModelGroup(%q) = %q, expected %q", tt.modelName, result, tt.expectedGroup)
}
})
}
}
func TestGetModelGroup_UnknownModel(t *testing.T) {
// For unknown models with no registry entry, should return the model name itself
result := GetModelGroup("unknown-model-xyz")
if result != "unknown-model-xyz" {
t.Errorf("GetModelGroup for unknown model should return model name, got %q", result)
}
}
// TestGetModelGroup_RegistryFallback tests that models registered via
// provider-specific API keys (e.g., kimi-k2.5 via claude-api-key) are
// correctly grouped by their provider.
// This test requires a populated global registry.
func TestGetModelGroup_RegistryFallback(t *testing.T) {
// This test only makes sense when the global registry is populated
// In unit test context, skip if registry is empty
// Example: kimi-k2.5 registered via claude-api-key should group as "claude"
// The model name doesn't contain "claude", so name pattern matching fails.
// The registry should be checked to find the provider.
// Skip for now - this requires integration test setup
t.Skip("Requires populated global registry - run as integration test")
}
// === Cross-Model Signature Validation Tests ===
// These tests verify that signatures cached under one model name can be
// validated under mapped model names (same provider group).
func TestCacheSignature_CrossModelValidation(t *testing.T) {
ClearSignatureCache("")
// Original request uses "claude-opus-4-5-20251101"
originalModel := "claude-opus-4-5-20251101"
// Mapped model is "claude-opus-4-5-thinking"
mappedModel := "claude-opus-4-5-thinking"
text := "Some thinking block content"
sig := "validSignature123456789012345678901234567890123456789012"
// Cache signature under the original model
CacheSignature(originalModel, text, sig)
// Both should return the same signature because they're in the same group
retrieved1 := GetCachedSignature(originalModel, text)
retrieved2 := GetCachedSignature(mappedModel, text)
if retrieved1 != sig {
t.Errorf("Original model signature mismatch: got %q", retrieved1)
}
if retrieved2 != sig {
t.Errorf("Mapped model signature mismatch: got %q", retrieved2)
}
}

View File

@@ -18,7 +18,10 @@ import (
"gopkg.in/yaml.v3"
)
const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
const (
DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
DefaultPprofAddr = "127.0.0.1:8316"
)
// Config represents the application's configuration, loaded from a YAML file.
type Config struct {
@@ -41,6 +44,9 @@ type Config struct {
// Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug" json:"debug"`
// Pprof config controls the optional pprof HTTP debug server.
Pprof PprofConfig `yaml:"pprof" json:"pprof"`
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
@@ -121,6 +127,14 @@ type TLSConfig struct {
Key string `yaml:"key" json:"key"`
}
// PprofConfig holds pprof HTTP server settings.
type PprofConfig struct {
// Enable toggles the pprof HTTP debug server.
Enable bool `yaml:"enable" json:"enable"`
// Addr is the host:port address for the pprof HTTP server.
Addr string `yaml:"addr" json:"addr"`
}
// RemoteManagement holds management API configuration under 'remote-management'.
type RemoteManagement struct {
// AllowRemote toggles remote (non-localhost) access to management API.
@@ -514,6 +528,8 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.ErrorLogsMaxFiles = 10
cfg.UsageStatisticsEnabled = false
cfg.DisableCooling = false
cfg.Pprof.Enable = false
cfg.Pprof.Addr = DefaultPprofAddr
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
if err = yaml.Unmarshal(data, &cfg); err != nil {
@@ -556,6 +572,11 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
}
cfg.Pprof.Addr = strings.TrimSpace(cfg.Pprof.Addr)
if cfg.Pprof.Addr == "" {
cfg.Pprof.Addr = DefaultPprofAddr
}
if cfg.LogsMaxTotalSizeMB < 0 {
cfg.LogsMaxTotalSizeMB = 0
}

View File

@@ -131,7 +131,10 @@ func ResolveLogDirectory(cfg *config.Config) string {
return logDir
}
if !isDirWritable(logDir) {
authDir := strings.TrimSpace(cfg.AuthDir)
authDir, err := util.ResolveAuthDir(cfg.AuthDir)
if err != nil {
log.Warnf("Failed to resolve auth-dir %q for log directory: %v", cfg.AuthDir, err)
}
if authDir != "" {
logDir = filepath.Join(authDir, "logs")
}

View File

@@ -1,39 +0,0 @@
// Package routing provides adapter to integrate with existing codebase.
package routing
import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// Adapter bridges the new routing layer with existing auth manager.
type Adapter struct {
router *Router
exec *Executor
}
// NewAdapter creates a new adapter with the given configuration and auth manager.
func NewAdapter(cfg *config.Config, authManager *coreauth.Manager) *Adapter {
registry := NewRegistry()
// TODO: Register OAuth providers from authManager
// TODO: Register API key providers from cfg
router := NewRouter(registry, cfg)
exec := NewExecutor(router)
return &Adapter{
router: router,
exec: exec,
}
}
// Router returns the underlying router.
func (a *Adapter) Router() *Router {
return a.router
}
// Executor returns the underlying executor.
func (a *Adapter) Executor() *Executor {
return a.exec
}

View File

@@ -1,11 +0,0 @@
package ctxkeys
type key string
const (
MappedModel key = "mapped_model"
FallbackModels key = "fallback_models"
RouteCandidates key = "route_candidates"
RoutingDecision key = "routing_decision"
MappingApplied key = "mapping_applied"
)

View File

@@ -1,111 +0,0 @@
package routing
import (
"context"
"errors"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
)
// Executor handles request execution with fallback support.
type Executor struct {
router *Router
}
// NewExecutor creates a new executor with the given router.
func NewExecutor(router *Router) *Executor {
return &Executor{router: router}
}
// Execute sends the request through the routing decision.
func (e *Executor) Execute(ctx context.Context, req executor.Request) (executor.Response, error) {
decision := e.router.Resolve(req.Model)
log.Debugf("routing: %s -> %s (%d candidates)",
decision.RequestedModel,
decision.ResolvedModel,
len(decision.Candidates))
var lastErr error
tried := make(map[string]struct{})
for i, candidate := range decision.Candidates {
key := candidate.Provider.Name() + "/" + candidate.Model
if _, ok := tried[key]; ok {
continue
}
tried[key] = struct{}{}
log.Debugf("routing: trying candidate %d/%d: %s with model %s",
i+1, len(decision.Candidates), candidate.Provider.Name(), candidate.Model)
req.Model = candidate.Model
resp, err := candidate.Provider.Execute(ctx, candidate.Model, req)
if err == nil {
return resp, nil
}
lastErr = err
log.Debugf("routing: candidate failed: %v", err)
// Check if it's a fatal error (not retryable)
if isFatalError(err) {
break
}
}
if lastErr != nil {
return executor.Response{}, lastErr
}
return executor.Response{}, errors.New("no available providers")
}
// ExecuteStream sends a streaming request through the routing decision.
func (e *Executor) ExecuteStream(ctx context.Context, req executor.Request) (<-chan executor.StreamChunk, error) {
decision := e.router.Resolve(req.Model)
log.Debugf("routing stream: %s -> %s (%d candidates)",
decision.RequestedModel,
decision.ResolvedModel,
len(decision.Candidates))
var lastErr error
tried := make(map[string]struct{})
for i, candidate := range decision.Candidates {
key := candidate.Provider.Name() + "/" + candidate.Model
if _, ok := tried[key]; ok {
continue
}
tried[key] = struct{}{}
log.Debugf("routing stream: trying candidate %d/%d: %s with model %s",
i+1, len(decision.Candidates), candidate.Provider.Name(), candidate.Model)
req.Model = candidate.Model
chunks, err := candidate.Provider.ExecuteStream(ctx, candidate.Model, req)
if err == nil {
return chunks, nil
}
lastErr = err
log.Debugf("routing stream: candidate failed: %v", err)
if isFatalError(err) {
break
}
}
if lastErr != nil {
return nil, lastErr
}
return nil, errors.New("no available providers")
}
// isFatalError returns true if the error is not retryable.
func isFatalError(err error) bool {
// TODO: implement based on error type
// For now, all errors are retryable
return false
}

View File

@@ -1,59 +0,0 @@
package routing
import (
"strings"
"github.com/tidwall/gjson"
)
// ModelExtractor extracts model names from request data.
type ModelExtractor interface {
// Extract returns the model name from the request body and gin parameters.
// The ginParams map contains route parameters like "action" and "path".
Extract(body []byte, ginParams map[string]string) (string, error)
}
// DefaultModelExtractor is the standard implementation of ModelExtractor.
type DefaultModelExtractor struct{}
// NewModelExtractor creates a new DefaultModelExtractor.
func NewModelExtractor() *DefaultModelExtractor {
return &DefaultModelExtractor{}
}
// Extract extracts the model name from the request.
// It checks in order:
// 1. JSON body "model" field (OpenAI, Claude format)
// 2. "action" parameter for Gemini standard format (e.g., "gemini-pro:generateContent")
// 3. "path" parameter for AMP CLI Gemini format (e.g., "/publishers/google/models/gemini-3-pro:streamGenerateContent")
func (e *DefaultModelExtractor) Extract(body []byte, ginParams map[string]string) (string, error) {
// First try to parse from JSON body (OpenAI, Claude, etc.)
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
return result.String(), nil
}
// For Gemini requests, model is in the URL path
// Standard format: /models/{model}:generateContent -> :action parameter
if action, ok := ginParams["action"]; ok && action != "" {
// Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
parts := strings.Split(action, ":")
if len(parts) > 0 && parts[0] != "" {
return parts[0], nil
}
}
// AMP CLI format: /publishers/google/models/{model}:method -> *path parameter
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
if path, ok := ginParams["path"]; ok && path != "" {
// Look for /models/{model}:method pattern
if idx := strings.Index(path, "/models/"); idx >= 0 {
modelPart := path[idx+8:] // Skip "/models/"
// Split by colon to get model name
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
return modelPart[:colonIdx], nil
}
}
}
return "", nil
}

View File

@@ -1,214 +0,0 @@
package routing
import (
"testing"
"github.com/stretchr/testify/assert"
)
func TestModelExtractor_ExtractFromJSONBody(t *testing.T) {
extractor := NewModelExtractor()
tests := []struct {
name string
body []byte
want string
wantErr bool
}{
{
name: "extract from JSON body with model field",
body: []byte(`{"model":"gpt-4.1"}`),
want: "gpt-4.1",
},
{
name: "extract claude model from JSON body",
body: []byte(`{"model":"claude-3-5-sonnet-20241022"}`),
want: "claude-3-5-sonnet-20241022",
},
{
name: "extract with additional fields",
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`),
want: "gpt-4",
},
{
name: "empty body returns empty",
body: []byte{},
want: "",
},
{
name: "no model field returns empty",
body: []byte(`{"messages":[]}`),
want: "",
},
{
name: "model is not string returns empty",
body: []byte(`{"model":123}`),
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := extractor.Extract(tt.body, nil)
if tt.wantErr {
assert.Error(t, err)
return
}
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
func TestModelExtractor_ExtractFromGeminiActionParam(t *testing.T) {
extractor := NewModelExtractor()
tests := []struct {
name string
body []byte
ginParams map[string]string
want string
}{
{
name: "extract from action parameter - gemini-pro",
body: []byte(`{}`),
ginParams: map[string]string{"action": "gemini-pro:generateContent"},
want: "gemini-pro",
},
{
name: "extract from action parameter - gemini-ultra",
body: []byte(`{}`),
ginParams: map[string]string{"action": "gemini-ultra:chat"},
want: "gemini-ultra",
},
{
name: "empty action returns empty",
body: []byte(`{}`),
ginParams: map[string]string{"action": ""},
want: "",
},
{
name: "action without colon returns full value",
body: []byte(`{}`),
ginParams: map[string]string{"action": "gemini-model"},
want: "gemini-model",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := extractor.Extract(tt.body, tt.ginParams)
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
func TestModelExtractor_ExtractFromGeminiV1Beta1Path(t *testing.T) {
extractor := NewModelExtractor()
tests := []struct {
name string
body []byte
ginParams map[string]string
want string
}{
{
name: "extract from v1beta1 path - gemini-3-pro",
body: []byte(`{}`),
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro:streamGenerateContent"},
want: "gemini-3-pro",
},
{
name: "extract from v1beta1 path with preview",
body: []byte(`{}`),
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro-preview:generateContent"},
want: "gemini-3-pro-preview",
},
{
name: "path without models segment returns empty",
body: []byte(`{}`),
ginParams: map[string]string{"path": "/publishers/google/gemini-3-pro:streamGenerateContent"},
want: "",
},
{
name: "empty path returns empty",
body: []byte(`{}`),
ginParams: map[string]string{"path": ""},
want: "",
},
{
name: "path with /models/ but no colon returns empty",
body: []byte(`{}`),
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro"},
want: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := extractor.Extract(tt.body, tt.ginParams)
assert.NoError(t, err)
assert.Equal(t, tt.want, got)
})
}
}
func TestModelExtractor_ExtractPriority(t *testing.T) {
extractor := NewModelExtractor()
// JSON body takes priority over gin params
t.Run("JSON body takes priority over action param", func(t *testing.T) {
body := []byte(`{"model":"gpt-4"}`)
params := map[string]string{"action": "gemini-pro:generateContent"}
got, err := extractor.Extract(body, params)
assert.NoError(t, err)
assert.Equal(t, "gpt-4", got)
})
// Action param takes priority over path param
t.Run("action param takes priority over path param", func(t *testing.T) {
body := []byte(`{}`)
params := map[string]string{
"action": "gemini-action:generate",
"path": "/publishers/google/models/gemini-path:streamGenerateContent",
}
got, err := extractor.Extract(body, params)
assert.NoError(t, err)
assert.Equal(t, "gemini-action", got)
})
}
func TestModelExtractor_NoModelFound(t *testing.T) {
extractor := NewModelExtractor()
tests := []struct {
name string
body []byte
ginParams map[string]string
}{
{
name: "empty body and no params",
body: []byte{},
ginParams: nil,
},
{
name: "body without model and no params",
body: []byte(`{"messages":[]}`),
ginParams: map[string]string{},
},
{
name: "irrelevant params only",
body: []byte(`{}`),
ginParams: map[string]string{"other": "value"},
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
got, err := extractor.Extract(tt.body, tt.ginParams)
assert.NoError(t, err)
assert.Empty(t, got)
})
}
}

View File

@@ -1,80 +0,0 @@
// Package routing provides unified model routing for all provider types.
package routing
import (
"context"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// ProviderType indicates the type of provider.
type ProviderType string
const (
ProviderTypeOAuth ProviderType = "oauth"
ProviderTypeAPIKey ProviderType = "api_key"
ProviderTypeVertex ProviderType = "vertex"
)
// Provider is the unified interface for all provider types (OAuth, API key, etc.).
type Provider interface {
// Name returns the unique provider identifier.
Name() string
// Type returns the provider type.
Type() ProviderType
// SupportsModel returns true if this provider can handle the given model.
SupportsModel(model string) bool
// Available returns true if the provider is available for the model (not quota exceeded).
Available(model string) bool
// Priority returns the priority for this provider (lower = tried first).
Priority() int
// Execute sends the request to the provider.
Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error)
// ExecuteStream sends a streaming request to the provider.
ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error)
}
// ProviderCandidate represents a provider + model combination to try.
type ProviderCandidate struct {
Provider Provider
Model string // The actual model name to use (may be different from requested due to aliasing)
}
// Registry manages all available providers.
type Registry struct {
providers []Provider
}
// NewRegistry creates a new provider registry.
func NewRegistry() *Registry {
return &Registry{
providers: make([]Provider, 0),
}
}
// Register adds a provider to the registry.
func (r *Registry) Register(p Provider) {
r.providers = append(r.providers, p)
}
// FindProviders returns all providers that support the given model and are available.
func (r *Registry) FindProviders(model string) []Provider {
var result []Provider
for _, p := range r.providers {
if p.SupportsModel(model) && p.Available(model) {
result = append(result, p)
}
}
return result
}
// All returns all registered providers.
func (r *Registry) All() []Provider {
return r.providers
}

View File

@@ -1,156 +0,0 @@
package providers
import (
"context"
"errors"
"net/http"
"strings"
"sync"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// APIKeyProvider wraps API key configs as routing.Provider.
type APIKeyProvider struct {
name string
provider string // claude, gemini, codex, vertex
keys []APIKeyEntry
mu sync.RWMutex
client HTTPClient
}
// APIKeyEntry represents a single API key configuration.
type APIKeyEntry struct {
APIKey string
BaseURL string
Models []config.ClaudeModel // Using ClaudeModel as generic model alias
}
// HTTPClient interface for making HTTP requests.
type HTTPClient interface {
Do(req *http.Request) (*http.Response, error)
}
// NewAPIKeyProvider creates a new API key provider.
func NewAPIKeyProvider(name, provider string, client HTTPClient) *APIKeyProvider {
return &APIKeyProvider{
name: name,
provider: provider,
keys: make([]APIKeyEntry, 0),
client: client,
}
}
// Name returns the provider name.
func (p *APIKeyProvider) Name() string {
return p.name
}
// Type returns ProviderTypeAPIKey.
func (p *APIKeyProvider) Type() routing.ProviderType {
return routing.ProviderTypeAPIKey
}
// SupportsModel checks if the model is supported by this provider.
func (p *APIKeyProvider) SupportsModel(model string) bool {
p.mu.RLock()
defer p.mu.RUnlock()
for _, key := range p.keys {
for _, m := range key.Models {
if strings.EqualFold(m.Alias, model) || strings.EqualFold(m.Name, model) {
return true
}
}
}
return false
}
// Available always returns true for API keys (unless explicitly disabled).
func (p *APIKeyProvider) Available(model string) bool {
return p.SupportsModel(model)
}
// Priority returns the priority (API key is lower priority than OAuth).
func (p *APIKeyProvider) Priority() int {
return 20
}
// Execute sends the request using the API key.
func (p *APIKeyProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
key := p.selectKey(model)
if key == nil {
return executor.Response{}, ErrNoMatchingAPIKey
}
// Resolve the actual model name from alias
actualModel := p.resolveModel(key, model)
// Execute via HTTP client
return p.executeHTTP(ctx, key, actualModel, req)
}
// ExecuteStream sends a streaming request.
func (p *APIKeyProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (
<-chan executor.StreamChunk, error) {
key := p.selectKey(model)
if key == nil {
return nil, ErrNoMatchingAPIKey
}
actualModel := p.resolveModel(key, model)
return p.executeHTTPStream(ctx, key, actualModel, req)
}
// AddKey adds an API key entry.
func (p *APIKeyProvider) AddKey(entry APIKeyEntry) {
p.mu.Lock()
defer p.mu.Unlock()
p.keys = append(p.keys, entry)
}
// selectKey selects a key that supports the model.
func (p *APIKeyProvider) selectKey(model string) *APIKeyEntry {
p.mu.RLock()
defer p.mu.RUnlock()
for _, key := range p.keys {
for _, m := range key.Models {
if strings.EqualFold(m.Alias, model) || strings.EqualFold(m.Name, model) {
return &key
}
}
}
return nil
}
// resolveModel resolves alias to actual model name.
func (p *APIKeyProvider) resolveModel(key *APIKeyEntry, requested string) string {
for _, m := range key.Models {
if strings.EqualFold(m.Alias, requested) {
return m.Name
}
}
return requested
}
// executeHTTP makes the HTTP request.
func (p *APIKeyProvider) executeHTTP(ctx context.Context, key *APIKeyEntry, model string, req executor.Request) (executor.Response, error) {
// TODO: implement actual HTTP execution
// This is a placeholder - actual implementation would build HTTP request
return executor.Response{}, errors.New("not yet implemented")
}
// executeHTTPStream makes a streaming HTTP request.
func (p *APIKeyProvider) executeHTTPStream(ctx context.Context, key *APIKeyEntry, model string, req executor.Request) (
<-chan executor.StreamChunk, error) {
// TODO: implement actual HTTP streaming
return nil, errors.New("not yet implemented")
}
// Errors
var (
ErrNoMatchingAPIKey = errors.New("no API key supports the requested model")
)

View File

@@ -1,132 +0,0 @@
package providers
import (
"context"
"errors"
"sync"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// OAuthProvider wraps OAuth-based auths as routing.Provider.
type OAuthProvider struct {
name string
auths []*coreauth.Auth
mu sync.RWMutex
executor coreauth.ProviderExecutor
}
// NewOAuthProvider creates a new OAuth provider.
func NewOAuthProvider(name string, exec coreauth.ProviderExecutor) *OAuthProvider {
return &OAuthProvider{
name: name,
auths: make([]*coreauth.Auth, 0),
executor: exec,
}
}
// Name returns the provider name.
func (p *OAuthProvider) Name() string {
return p.name
}
// Type returns ProviderTypeOAuth.
func (p *OAuthProvider) Type() routing.ProviderType {
return routing.ProviderTypeOAuth
}
// SupportsModel checks if any auth supports the model.
func (p *OAuthProvider) SupportsModel(model string) bool {
p.mu.RLock()
defer p.mu.RUnlock()
// OAuth providers typically support models via oauth-model-alias
// The actual model support is determined at execution time
return true
}
// Available checks if there's an available auth for the model.
func (p *OAuthProvider) Available(model string) bool {
p.mu.RLock()
defer p.mu.RUnlock()
for _, auth := range p.auths {
if p.isAuthAvailable(auth, model) {
return true
}
}
return false
}
// Priority returns the priority (OAuth is preferred over API key).
func (p *OAuthProvider) Priority() int {
return 10
}
// Execute sends the request using an available OAuth auth.
func (p *OAuthProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
auth := p.selectAuth(model)
if auth == nil {
return executor.Response{}, ErrNoAvailableAuth
}
return p.executor.Execute(ctx, auth, req, executor.Options{})
}
// ExecuteStream sends a streaming request.
func (p *OAuthProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
auth := p.selectAuth(model)
if auth == nil {
return nil, ErrNoAvailableAuth
}
return p.executor.ExecuteStream(ctx, auth, req, executor.Options{})
}
// AddAuth adds an auth to this provider.
func (p *OAuthProvider) AddAuth(auth *coreauth.Auth) {
p.mu.Lock()
defer p.mu.Unlock()
p.auths = append(p.auths, auth)
}
// RemoveAuth removes an auth from this provider.
func (p *OAuthProvider) RemoveAuth(authID string) {
p.mu.Lock()
defer p.mu.Unlock()
filtered := make([]*coreauth.Auth, 0, len(p.auths))
for _, auth := range p.auths {
if auth.ID != authID {
filtered = append(filtered, auth)
}
}
p.auths = filtered
}
// isAuthAvailable checks if an auth is available for the model.
func (p *OAuthProvider) isAuthAvailable(auth *coreauth.Auth, model string) bool {
// TODO: integrate with model_registry for quota checking
// For now, just check if auth exists
return auth != nil
}
// selectAuth selects an available auth for the model.
func (p *OAuthProvider) selectAuth(model string) *coreauth.Auth {
p.mu.RLock()
defer p.mu.RUnlock()
for _, auth := range p.auths {
if p.isAuthAvailable(auth, model) {
return auth
}
}
return nil
}
// Errors
var (
ErrNoAvailableAuth = errors.New("no available OAuth auth for model")
)

View File

@@ -1,159 +0,0 @@
package routing
import (
"bytes"
"net/http"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
log "github.com/sirupsen/logrus"
)
// ModelRewriter handles model name rewriting in requests and responses.
type ModelRewriter interface {
// RewriteRequestBody rewrites the model field in a JSON request body.
// Returns the modified body or the original if no rewrite was needed.
RewriteRequestBody(body []byte, newModel string) ([]byte, error)
// WrapResponseWriter wraps an http.ResponseWriter to rewrite model names in the response.
// Returns the wrapped writer and a cleanup function that must be called after the response is complete.
WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func())
}
// DefaultModelRewriter is the standard implementation of ModelRewriter.
type DefaultModelRewriter struct{}
// NewModelRewriter creates a new DefaultModelRewriter.
func NewModelRewriter() *DefaultModelRewriter {
return &DefaultModelRewriter{}
}
// RewriteRequestBody replaces the model name in a JSON request body.
func (r *DefaultModelRewriter) RewriteRequestBody(body []byte, newModel string) ([]byte, error) {
if !gjson.GetBytes(body, "model").Exists() {
return body, nil
}
result, err := sjson.SetBytes(body, "model", newModel)
if err != nil {
return body, err
}
return result, nil
}
// WrapResponseWriter wraps a response writer to rewrite model names.
// The cleanup function must be called after the handler completes to flush any buffered data.
func (r *DefaultModelRewriter) WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) {
rw := &responseRewriter{
ResponseWriter: w,
body: &bytes.Buffer{},
requestedModel: requestedModel,
resolvedModel: resolvedModel,
}
return rw, func() { rw.flush() }
}
// responseRewriter wraps http.ResponseWriter to intercept and modify the response body.
type responseRewriter struct {
http.ResponseWriter
body *bytes.Buffer
requestedModel string
resolvedModel string
isStreaming bool
wroteHeader bool
flushed bool
}
// Write intercepts response writes and buffers them for model name replacement.
func (rw *responseRewriter) Write(data []byte) (int, error) {
// Ensure header is written
if !rw.wroteHeader {
rw.WriteHeader(http.StatusOK)
}
// Detect streaming on first write
if rw.body.Len() == 0 && !rw.isStreaming {
contentType := rw.Header().Get("Content-Type")
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
strings.Contains(contentType, "stream")
}
if rw.isStreaming {
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
if err == nil {
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
return n, err
}
return rw.body.Write(data)
}
// WriteHeader captures the status code and delegates to the underlying writer.
func (rw *responseRewriter) WriteHeader(code int) {
if !rw.wroteHeader {
rw.wroteHeader = true
rw.ResponseWriter.WriteHeader(code)
}
}
// flush writes the buffered response with model names rewritten.
func (rw *responseRewriter) flush() {
if rw.flushed {
return
}
rw.flushed = true
if rw.isStreaming {
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
return
}
if rw.body.Len() > 0 {
data := rw.rewriteModelInResponse(rw.body.Bytes())
if _, err := rw.ResponseWriter.Write(data); err != nil {
log.Warnf("response rewriter: failed to write rewritten response: %v", err)
}
}
}
// modelFieldPaths lists all JSON paths where model name may appear.
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
// rewriteModelInResponse replaces all occurrences of the resolved model with the requested model.
func (rw *responseRewriter) rewriteModelInResponse(data []byte) []byte {
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
return data
}
for _, path := range modelFieldPaths {
if gjson.GetBytes(data, path).Exists() {
data, _ = sjson.SetBytes(data, path, rw.requestedModel)
}
}
return data
}
// rewriteStreamChunk rewrites model names in SSE stream chunks.
func (rw *responseRewriter) rewriteStreamChunk(chunk []byte) []byte {
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
return chunk
}
// SSE format: "data: {json}\n\n"
lines := bytes.Split(chunk, []byte("\n"))
for i, line := range lines {
if bytes.HasPrefix(line, []byte("data: ")) {
jsonData := bytes.TrimPrefix(line, []byte("data: "))
if len(jsonData) > 0 && jsonData[0] == '{' {
// Rewrite JSON in the data line
rewritten := rw.rewriteModelInResponse(jsonData)
lines[i] = append([]byte("data: "), rewritten...)
}
}
}
return bytes.Join(lines, []byte("\n"))
}

View File

@@ -1,342 +0,0 @@
package routing
import (
"io"
"net/http"
"net/http/httptest"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestModelRewriter_RewriteRequestBody(t *testing.T) {
rewriter := NewModelRewriter()
tests := []struct {
name string
body []byte
newModel string
wantModel string
wantChange bool
}{
{
name: "rewrites model field in JSON body",
body: []byte(`{"model":"gpt-4.1","messages":[]}`),
newModel: "claude-local",
wantModel: "claude-local",
wantChange: true,
},
{
name: "rewrites with empty body returns empty",
body: []byte{},
newModel: "gpt-4",
wantModel: "",
wantChange: false,
},
{
name: "handles missing model field gracefully",
body: []byte(`{"messages":[{"role":"user"}]}`),
newModel: "gpt-4",
wantModel: "",
wantChange: false,
},
{
name: "preserves other fields when rewriting",
body: []byte(`{"model":"old-model","temperature":0.7,"max_tokens":100}`),
newModel: "new-model",
wantModel: "new-model",
wantChange: true,
},
{
name: "handles nested JSON structure",
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}],"stream":true}`),
newModel: "claude-3-opus",
wantModel: "claude-3-opus",
wantChange: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := rewriter.RewriteRequestBody(tt.body, tt.newModel)
require.NoError(t, err)
if tt.wantChange {
assert.NotEqual(t, string(tt.body), string(result), "body should have been modified")
}
if tt.wantModel != "" {
// Parse result and check model field
model, _ := NewModelExtractor().Extract(result, nil)
assert.Equal(t, tt.wantModel, model)
}
})
}
}
func TestModelRewriter_WrapResponseWriter(t *testing.T) {
rewriter := NewModelRewriter()
t.Run("response writer wraps without error", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
require.NotNil(t, wrapped)
require.NotNil(t, cleanup)
defer cleanup()
})
t.Run("rewrites model in non-streaming response", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
// Write a response with the resolved model
response := []byte(`{"model":"claude-local","content":"hello"}`)
wrapped.Header().Set("Content-Type", "application/json")
_, err := wrapped.Write(response)
require.NoError(t, err)
// Cleanup triggers the rewrite
cleanup()
// Check the response was rewritten to the requested model
body := recorder.Body.Bytes()
assert.Contains(t, string(body), `"model":"gpt-4"`)
assert.NotContains(t, string(body), `"model":"claude-local"`)
})
t.Run("no-op when requested equals resolved", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "gpt-4")
response := []byte(`{"model":"gpt-4","content":"hello"}`)
wrapped.Header().Set("Content-Type", "application/json")
_, err := wrapped.Write(response)
require.NoError(t, err)
cleanup()
body := recorder.Body.Bytes()
assert.Contains(t, string(body), `"model":"gpt-4"`)
})
t.Run("rewrites modelVersion field", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
response := []byte(`{"modelVersion":"claude-local","content":"hello"}`)
wrapped.Header().Set("Content-Type", "application/json")
_, err := wrapped.Write(response)
require.NoError(t, err)
cleanup()
body := recorder.Body.Bytes()
assert.Contains(t, string(body), `"modelVersion":"gpt-4"`)
})
t.Run("handles streaming responses", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
// Set streaming content type
wrapped.Header().Set("Content-Type", "text/event-stream")
// Write SSE chunks with resolved model
chunk1 := []byte("data: {\"model\":\"claude-local\",\"delta\":\"hello\"}\n\n")
_, err := wrapped.Write(chunk1)
require.NoError(t, err)
chunk2 := []byte("data: {\"model\":\"claude-local\",\"delta\":\" world\"}\n\n")
_, err = wrapped.Write(chunk2)
require.NoError(t, err)
cleanup()
// For streaming, data is written immediately with rewrites
body := recorder.Body.Bytes()
assert.Contains(t, string(body), `"model":"gpt-4"`)
assert.NotContains(t, string(body), `"model":"claude-local"`)
})
t.Run("empty body handled gracefully", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
wrapped.Header().Set("Content-Type", "application/json")
// Don't write anything
cleanup()
body := recorder.Body.Bytes()
assert.Empty(t, body)
})
t.Run("preserves other JSON fields", func(t *testing.T) {
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
response := []byte(`{"model":"claude-local","temperature":0.7,"usage":{"prompt_tokens":10}}`)
wrapped.Header().Set("Content-Type", "application/json")
_, err := wrapped.Write(response)
require.NoError(t, err)
cleanup()
body := recorder.Body.Bytes()
assert.Contains(t, string(body), `"temperature":0.7`)
assert.Contains(t, string(body), `"prompt_tokens":10`)
})
}
func TestResponseRewriter_ImplementsInterfaces(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
defer cleanup()
// Should implement http.ResponseWriter
assert.Implements(t, (*http.ResponseWriter)(nil), wrapped)
// Should preserve header access
wrapped.Header().Set("X-Custom", "value")
assert.Equal(t, "value", recorder.Header().Get("X-Custom"))
// Should write status
wrapped.WriteHeader(http.StatusCreated)
assert.Equal(t, http.StatusCreated, recorder.Code)
}
func TestResponseRewriter_Flush(t *testing.T) {
t.Run("flush writes buffered content", func(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
response := []byte(`{"model":"claude-local","content":"test"}`)
wrapped.Header().Set("Content-Type", "application/json")
wrapped.Write(response)
// Before cleanup, response should be empty (buffered)
assert.Empty(t, recorder.Body.Bytes())
// After cleanup, response should be written
cleanup()
assert.NotEmpty(t, recorder.Body.Bytes())
})
t.Run("multiple flush calls are safe", func(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
response := []byte(`{"model":"claude-local"}`)
wrapped.Header().Set("Content-Type", "application/json")
wrapped.Write(response)
// First cleanup
cleanup()
firstBody := recorder.Body.Bytes()
// Second cleanup should not write again
cleanup()
secondBody := recorder.Body.Bytes()
assert.Equal(t, firstBody, secondBody)
})
}
func TestResponseRewriter_StreamingWithDataLines(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
wrapped.Header().Set("Content-Type", "text/event-stream")
// SSE format with multiple data lines
chunk := []byte("data: {\"model\":\"claude-local\"}\n\ndata: {\"model\":\"claude-local\",\"done\":true}\n\n")
wrapped.Write(chunk)
cleanup()
body := recorder.Body.Bytes()
// Both data lines should have model rewritten
assert.Contains(t, string(body), `"model":"gpt-4"`)
assert.NotContains(t, string(body), `"model":"claude-local"`)
}
func TestModelRewriter_RoundTrip(t *testing.T) {
// Simulate a full request -> response cycle with model rewriting
rewriter := NewModelRewriter()
// Step 1: Rewrite request body
originalRequest := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`)
rewrittenRequest, err := rewriter.RewriteRequestBody(originalRequest, "claude-local")
require.NoError(t, err)
// Verify request was rewritten
extractor := NewModelExtractor()
requestModel, _ := extractor.Extract(rewrittenRequest, nil)
assert.Equal(t, "claude-local", requestModel)
// Step 2: Simulate response with resolved model
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
response := []byte(`{"model":"claude-local","content":"Hello! How can I help?"}`)
wrapped.Header().Set("Content-Type", "application/json")
wrapped.Write(response)
cleanup()
// Verify response was rewritten back
body, _ := io.ReadAll(recorder.Result().Body)
responseModel, _ := extractor.Extract(body, nil)
assert.Equal(t, "gpt-4", responseModel)
}
func TestModelRewriter_NonJSONBody(t *testing.T) {
rewriter := NewModelRewriter()
// Binary/non-JSON body should be returned unchanged
body := []byte{0x00, 0x01, 0x02, 0x03}
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
require.NoError(t, err)
assert.Equal(t, body, result)
}
func TestModelRewriter_InvalidJSON(t *testing.T) {
rewriter := NewModelRewriter()
// Invalid JSON without model field should be returned unchanged
body := []byte(`not valid json`)
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
require.NoError(t, err)
assert.Equal(t, body, result)
}
func TestResponseRewriter_StatusCodePreserved(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
wrapped.WriteHeader(http.StatusAccepted)
wrapped.Write([]byte(`{"model":"claude-local"}`))
cleanup()
assert.Equal(t, http.StatusAccepted, recorder.Code)
}
func TestResponseRewriter_HeaderFlushed(t *testing.T) {
rewriter := NewModelRewriter()
recorder := httptest.NewRecorder()
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
wrapped.Header().Set("Content-Type", "application/json")
wrapped.Header().Set("X-Request-ID", "abc123")
wrapped.Write([]byte(`{"model":"claude-local"}`))
cleanup()
result := recorder.Result()
assert.Equal(t, "application/json", result.Header.Get("Content-Type"))
assert.Equal(t, "abc123", result.Header.Get("X-Request-ID"))
}

View File

@@ -1,317 +0,0 @@
package routing
import (
"context"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
// Router resolves models to provider candidates.
type Router struct {
registry *Registry
modelMappings map[string]string // normalized from -> to
oauthAliases map[string][]string // normalized model -> []alias
}
// NewRouter creates a new router with the given configuration.
func NewRouter(registry *Registry, cfg *config.Config) *Router {
r := &Router{
registry: registry,
modelMappings: make(map[string]string),
oauthAliases: make(map[string][]string),
}
if cfg != nil {
r.loadModelMappings(cfg.AmpCode.ModelMappings)
r.loadOAuthAliases(cfg.OAuthModelAlias)
}
return r
}
// LegacyRoutingDecision contains the resolved routing information.
// Deprecated: Will be replaced by RoutingDecision from types.go in T-013.
type LegacyRoutingDecision struct {
RequestedModel string // Original model from request
ResolvedModel string // After model-mappings
Candidates []ProviderCandidate // Ordered list of providers to try
}
// Resolve determines the routing decision for the requested model.
// Deprecated: Will be updated to use RoutingRequest and return *RoutingDecision in T-013.
func (r *Router) Resolve(requestedModel string) *LegacyRoutingDecision {
// 1. Extract thinking suffix
suffixResult := thinking.ParseSuffix(requestedModel)
baseModel := suffixResult.ModelName
// 2. Apply model-mappings
targetModel := r.applyMappings(baseModel)
// 3. Find primary providers
candidates := r.findCandidates(targetModel, suffixResult)
// 4. Add fallback aliases
for _, alias := range r.oauthAliases[strings.ToLower(targetModel)] {
candidates = append(candidates, r.findCandidates(alias, suffixResult)...)
}
// 5. Sort by priority
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
})
return &LegacyRoutingDecision{
RequestedModel: requestedModel,
ResolvedModel: targetModel,
Candidates: candidates,
}
}
// ResolveV2 determines the routing decision for a routing request.
// It uses the new RoutingRequest and RoutingDecision types.
func (r *Router) ResolveV2(req RoutingRequest) *RoutingDecision {
// 1. Extract thinking suffix
suffixResult := thinking.ParseSuffix(req.RequestedModel)
baseModel := suffixResult.ModelName
thinkingSuffix := ""
if suffixResult.HasSuffix {
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
}
// 2. Check for local providers
localCandidates := r.findLocalCandidates(baseModel, suffixResult)
// 3. Apply model-mappings if needed
mappedModel := r.applyMappings(baseModel)
mappingCandidates := r.findLocalCandidates(mappedModel, suffixResult)
// 4. Determine route type based on preferences and availability
var decision *RoutingDecision
if req.ForceModelMapping && mappedModel != baseModel && len(mappingCandidates) > 0 {
// FORCE MODE: Use mapping even if local provider exists
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
} else if req.PreferLocalProvider && len(localCandidates) > 0 {
// DEFAULT MODE with local preference: Use local provider first
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
} else if len(localCandidates) > 0 {
// DEFAULT MODE: Local provider available
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
} else if mappedModel != baseModel && len(mappingCandidates) > 0 {
// DEFAULT MODE: No local provider, but mapping available
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
} else {
// No local provider, no mapping - use amp credits proxy
decision = &RoutingDecision{
RouteType: RouteTypeAmpCredits,
ResolvedModel: req.RequestedModel,
ShouldProxy: true,
}
}
return decision
}
// findLocalCandidates finds local provider candidates for a model.
// If the internal registry is empty, it falls back to the global model registry.
func (r *Router) findLocalCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate {
var candidates []ProviderCandidate
// Check internal registry first
registryProviders := r.registry.All()
if len(registryProviders) > 0 {
for _, p := range registryProviders {
if !p.SupportsModel(model) {
continue
}
// Apply thinking suffix if needed
actualModel := model
if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix {
actualModel = model + "(" + suffixResult.RawSuffix + ")"
}
if p.Available(actualModel) {
candidates = append(candidates, ProviderCandidate{
Provider: p,
Model: actualModel,
})
}
}
} else {
// Fallback to global model registry (same logic as FallbackHandler)
// This ensures compatibility when the wrapper is initialized with an empty registry
providers := registry.GetGlobalRegistry().GetModelProviders(model)
if len(providers) > 0 {
actualModel := model
if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix {
actualModel = model + "(" + suffixResult.RawSuffix + ")"
}
// Create a synthetic provider candidate for each provider
for _, providerName := range providers {
candidates = append(candidates, ProviderCandidate{
Provider: &globalRegistryProvider{name: providerName, model: actualModel},
Model: actualModel,
})
}
}
}
// Sort by priority
sort.Slice(candidates, func(i, j int) bool {
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
})
return candidates
}
// globalRegistryProvider is a synthetic Provider implementation that wraps
// a provider name from the global model registry. It is used only for routing
// decisions when the internal registry is empty - actual execution goes through
// the normal handler path, not through this provider's Execute methods.
type globalRegistryProvider struct {
name string
model string
}
func (p *globalRegistryProvider) Name() string { return p.name }
func (p *globalRegistryProvider) Type() ProviderType { return ProviderTypeOAuth }
func (p *globalRegistryProvider) Priority() int { return 0 }
func (p *globalRegistryProvider) SupportsModel(string) bool { return true }
func (p *globalRegistryProvider) Available(string) bool { return true }
// Execute is not used for globalRegistryProvider - routing wrapper calls the handler directly.
func (p *globalRegistryProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
return executor.Response{}, nil
}
// ExecuteStream is not used for globalRegistryProvider - routing wrapper calls the handler directly.
func (p *globalRegistryProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
return nil, nil
}
// buildLocalProviderDecision creates a decision for local provider routing.
func (r *Router) buildLocalProviderDecision(requestedModel string, candidates []ProviderCandidate, thinkingSuffix string) *RoutingDecision {
resolvedModel := requestedModel
if thinkingSuffix != "" {
// Ensure thinking suffix is preserved
sr := thinking.ParseSuffix(requestedModel)
if !sr.HasSuffix {
resolvedModel = requestedModel + thinkingSuffix
}
}
var fallbackModels []string
if len(candidates) > 1 {
for _, c := range candidates[1:] {
fallbackModels = append(fallbackModels, c.Model)
}
}
return &RoutingDecision{
RouteType: RouteTypeLocalProvider,
ResolvedModel: resolvedModel,
ProviderName: candidates[0].Provider.Name(),
FallbackModels: fallbackModels,
ShouldProxy: false,
}
}
// buildMappingDecision creates a decision for model mapping routing.
func (r *Router) buildMappingDecision(requestedModel, mappedModel string, candidates []ProviderCandidate, thinkingSuffix string, fallbackCandidates []ProviderCandidate) *RoutingDecision {
// Apply thinking suffix to resolved model if needed
resolvedModel := mappedModel
if thinkingSuffix != "" {
sr := thinking.ParseSuffix(mappedModel)
if !sr.HasSuffix {
resolvedModel = mappedModel + thinkingSuffix
}
}
var fallbackModels []string
for _, c := range fallbackCandidates {
fallbackModels = append(fallbackModels, c.Model)
}
// Also add oauth aliases as fallbacks
baseMapped := thinking.ParseSuffix(mappedModel).ModelName
for _, alias := range r.oauthAliases[strings.ToLower(baseMapped)] {
// Check if this alias has providers
aliasCandidates := r.findLocalCandidates(alias, thinking.SuffixResult{ModelName: alias})
for _, c := range aliasCandidates {
fallbackModels = append(fallbackModels, c.Model)
}
}
return &RoutingDecision{
RouteType: RouteTypeModelMapping,
ResolvedModel: resolvedModel,
ProviderName: candidates[0].Provider.Name(),
FallbackModels: fallbackModels,
ShouldProxy: false,
}
}
// applyMappings applies model-mappings configuration.
func (r *Router) applyMappings(model string) string {
key := strings.ToLower(strings.TrimSpace(model))
if mapped, ok := r.modelMappings[key]; ok {
return mapped
}
return model
}
// findCandidates finds all provider candidates for a model.
func (r *Router) findCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate {
var candidates []ProviderCandidate
for _, p := range r.registry.All() {
if !p.SupportsModel(model) {
continue
}
// Apply thinking suffix if needed
actualModel := model
if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix {
actualModel = model + "(" + suffixResult.RawSuffix + ")"
}
if p.Available(actualModel) {
candidates = append(candidates, ProviderCandidate{
Provider: p,
Model: actualModel,
})
}
}
return candidates
}
// loadModelMappings loads model-mappings from config.
func (r *Router) loadModelMappings(mappings []config.AmpModelMapping) {
for _, m := range mappings {
from := strings.ToLower(strings.TrimSpace(m.From))
to := strings.TrimSpace(m.To)
if from != "" && to != "" {
r.modelMappings[from] = to
}
}
}
// loadOAuthAliases loads oauth-model-alias from config.
func (r *Router) loadOAuthAliases(aliases map[string][]config.OAuthModelAlias) {
for _, entries := range aliases {
for _, entry := range entries {
name := strings.ToLower(strings.TrimSpace(entry.Name))
alias := strings.TrimSpace(entry.Alias)
if name != "" && alias != "" && name != alias {
r.oauthAliases[name] = append(r.oauthAliases[name], alias)
}
}
}
}

View File

@@ -1,202 +0,0 @@
package routing
import (
"context"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
globalRegistry "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/stretchr/testify/assert"
)
// mockProvider is a test double for Provider.
type mockProvider struct {
name string
providerType ProviderType
supportsModels map[string]bool
available bool
priority int
}
func (m *mockProvider) Name() string { return m.name }
func (m *mockProvider) Type() ProviderType { return m.providerType }
func (m *mockProvider) SupportsModel(model string) bool { return m.supportsModels[model] }
func (m *mockProvider) Available(model string) bool { return m.available }
func (m *mockProvider) Priority() int { return m.priority }
func (m *mockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
return executor.Response{}, nil
}
func (m *mockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
return nil, nil
}
func TestRouter_Resolve_ModelMappings(t *testing.T) {
registry := NewRegistry()
// Add a provider
p := &mockProvider{
name: "test-provider",
providerType: ProviderTypeOAuth,
supportsModels: map[string]bool{"target-model": true},
available: true,
priority: 1,
}
registry.Register(p)
// Create router with model mapping
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{
{From: "user-model", To: "target-model"},
},
},
}
router := NewRouter(registry, cfg)
// Resolve
decision := router.Resolve("user-model")
assert.Equal(t, "user-model", decision.RequestedModel)
assert.Equal(t, "target-model", decision.ResolvedModel)
assert.Len(t, decision.Candidates, 1)
assert.Equal(t, "target-model", decision.Candidates[0].Model)
}
func TestRouter_Resolve_OAuthAliases(t *testing.T) {
registry := NewRegistry()
// Add providers
p1 := &mockProvider{
name: "oauth-1",
providerType: ProviderTypeOAuth,
supportsModels: map[string]bool{"primary-model": true},
available: true,
priority: 1,
}
p2 := &mockProvider{
name: "oauth-2",
providerType: ProviderTypeOAuth,
supportsModels: map[string]bool{"fallback-model": true},
available: true,
priority: 2,
}
registry.Register(p1)
registry.Register(p2)
// Create router with oauth aliases
cfg := &config.Config{
OAuthModelAlias: map[string][]config.OAuthModelAlias{
"test-channel": {
{Name: "primary-model", Alias: "fallback-model"},
},
},
}
router := NewRouter(registry, cfg)
// Resolve
decision := router.Resolve("primary-model")
assert.Equal(t, "primary-model", decision.ResolvedModel)
assert.Len(t, decision.Candidates, 2)
// Primary should come first (lower priority value)
assert.Equal(t, "primary-model", decision.Candidates[0].Model)
assert.Equal(t, "fallback-model", decision.Candidates[1].Model)
}
func TestRouter_Resolve_NoProviders(t *testing.T) {
registry := NewRegistry()
cfg := &config.Config{}
router := NewRouter(registry, cfg)
decision := router.Resolve("unknown-model")
assert.Equal(t, "unknown-model", decision.ResolvedModel)
assert.Empty(t, decision.Candidates)
}
// === Global Registry Fallback Tests (T-027) ===
// These tests verify that when the internal registry is empty,
// the router falls back to the global model registry.
// This is the core fix for the thinking signature 400 error.
func TestRouter_GlobalRegistryFallback_LocalProvider(t *testing.T) {
// This test requires registering a model in the global registry.
// We use a model that's already registered via api-key config in production.
// For isolated testing, we can skip if global registry is not populated.
globalReg := globalRegistry.GetGlobalRegistry()
modelCount := globalReg.GetModelCount("claude-sonnet-4-20250514")
if modelCount == 0 {
t.Skip("Global registry not populated - run with server context")
}
// Empty internal registry
emptyRegistry := NewRegistry()
cfg := &config.Config{}
router := NewRouter(emptyRegistry, cfg)
req := RoutingRequest{
RequestedModel: "claude-sonnet-4-20250514",
PreferLocalProvider: true,
}
decision := router.ResolveV2(req)
// Should find provider from global registry
assert.Equal(t, RouteTypeLocalProvider, decision.RouteType)
assert.Equal(t, "claude-sonnet-4-20250514", decision.ResolvedModel)
assert.False(t, decision.ShouldProxy)
}
func TestRouter_GlobalRegistryFallback_ModelMapping(t *testing.T) {
// This test verifies that model mapping works with global registry fallback.
globalReg := globalRegistry.GetGlobalRegistry()
modelCount := globalReg.GetModelCount("claude-opus-4-5-thinking")
if modelCount == 0 {
t.Skip("Global registry not populated - run with server context")
}
// Empty internal registry
emptyRegistry := NewRegistry()
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{
{From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"},
},
},
}
router := NewRouter(emptyRegistry, cfg)
req := RoutingRequest{
RequestedModel: "claude-opus-4-5-20251101",
PreferLocalProvider: true,
}
decision := router.ResolveV2(req)
// Should find mapped model from global registry
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
assert.Equal(t, "claude-opus-4-5-thinking", decision.ResolvedModel)
assert.False(t, decision.ShouldProxy)
}
func TestRouter_GlobalRegistryFallback_AmpCreditsWhenNotFound(t *testing.T) {
// Empty internal registry
emptyRegistry := NewRegistry()
cfg := &config.Config{}
router := NewRouter(emptyRegistry, cfg)
// Use a model that definitely doesn't exist anywhere
req := RoutingRequest{
RequestedModel: "nonexistent-model-12345",
PreferLocalProvider: true,
}
decision := router.ResolveV2(req)
// Should fall back to AMP credits proxy
assert.Equal(t, RouteTypeAmpCredits, decision.RouteType)
assert.Equal(t, "nonexistent-model-12345", decision.ResolvedModel)
assert.True(t, decision.ShouldProxy)
}

View File

@@ -1,245 +0,0 @@
package routing
import (
"context"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/stretchr/testify/assert"
)
func TestRouter_DefaultMode_PrefersLocal(t *testing.T) {
// Setup: Create a router with a mock provider that supports "gpt-4"
registry := NewRegistry()
mockProvider := &MockProvider{
name: "openai",
supportedModels: []string{"gpt-4"},
available: true,
priority: 1,
}
registry.Register(mockProvider)
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{
{From: "gpt-4", To: "claude-local"},
},
},
}
router := NewRouter(registry, cfg)
// Test: Request gpt-4 when local provider exists
req := RoutingRequest{
RequestedModel: "gpt-4",
PreferLocalProvider: true,
ForceModelMapping: false,
}
decision := router.ResolveV2(req)
// Assert: Should return LOCAL_PROVIDER, not MODEL_MAPPING
assert.Equal(t, RouteTypeLocalProvider, decision.RouteType)
assert.Equal(t, "gpt-4", decision.ResolvedModel)
assert.Equal(t, "openai", decision.ProviderName)
assert.False(t, decision.ShouldProxy)
}
func TestRouter_DefaultMode_MapsWhenNoLocal(t *testing.T) {
// Setup: Create a router with NO provider for "gpt-4" but a mapping to "claude-local"
// which has a provider
registry := NewRegistry()
mockProvider := &MockProvider{
name: "anthropic",
supportedModels: []string{"claude-local"},
available: true,
priority: 1,
}
registry.Register(mockProvider)
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{
{From: "gpt-4", To: "claude-local"},
},
},
}
router := NewRouter(registry, cfg)
// Test: Request gpt-4 when no local provider exists, but mapping exists
req := RoutingRequest{
RequestedModel: "gpt-4",
PreferLocalProvider: true,
ForceModelMapping: false,
}
decision := router.ResolveV2(req)
// Assert: Should return MODEL_MAPPING
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
assert.Equal(t, "claude-local", decision.ResolvedModel)
assert.Equal(t, "anthropic", decision.ProviderName)
assert.False(t, decision.ShouldProxy)
}
func TestRouter_DefaultMode_AmpCreditsWhenNoLocalOrMapping(t *testing.T) {
// Setup: Create a router with no providers and no mappings
registry := NewRegistry()
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{},
},
}
router := NewRouter(registry, cfg)
// Test: Request a model with no local provider and no mapping
req := RoutingRequest{
RequestedModel: "unknown-model",
PreferLocalProvider: true,
ForceModelMapping: false,
}
decision := router.ResolveV2(req)
// Assert: Should return AMP_CREDITS with ShouldProxy=true
assert.Equal(t, RouteTypeAmpCredits, decision.RouteType)
assert.Equal(t, "unknown-model", decision.ResolvedModel)
assert.True(t, decision.ShouldProxy)
assert.Empty(t, decision.ProviderName)
}
func TestRouter_ForceMode_MapsEvenWithLocal(t *testing.T) {
// Setup: Create a router with BOTH a local provider for "gpt-4" AND a mapping from "gpt-4" to "claude-local"
// The mapping target "claude-local" also has a provider
registry := NewRegistry()
// Local provider for gpt-4
openaiProvider := &MockProvider{
name: "openai",
supportedModels: []string{"gpt-4"},
available: true,
priority: 1,
}
registry.Register(openaiProvider)
// Local provider for the mapped model
anthropicProvider := &MockProvider{
name: "anthropic",
supportedModels: []string{"claude-local"},
available: true,
priority: 2,
}
registry.Register(anthropicProvider)
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{
{From: "gpt-4", To: "claude-local"},
},
},
}
router := NewRouter(registry, cfg)
// Test: Request gpt-4 with ForceModelMapping=true
// Even though gpt-4 has a local provider, mapping should take precedence
req := RoutingRequest{
RequestedModel: "gpt-4",
PreferLocalProvider: false,
ForceModelMapping: true,
}
decision := router.ResolveV2(req)
// Assert: Should return MODEL_MAPPING, not LOCAL_PROVIDER
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
assert.Equal(t, "claude-local", decision.ResolvedModel)
assert.Equal(t, "anthropic", decision.ProviderName)
assert.False(t, decision.ShouldProxy)
}
func TestRouter_ThinkingSuffix_Preserved(t *testing.T) {
// Setup: Create a router with mapping and provider for mapped model
registry := NewRegistry()
mockProvider := &MockProvider{
name: "anthropic",
supportedModels: []string{"claude-local"},
available: true,
priority: 1,
}
registry.Register(mockProvider)
cfg := &config.Config{
AmpCode: config.AmpCode{
ModelMappings: []config.AmpModelMapping{
{From: "claude-3-5-sonnet", To: "claude-local"},
},
},
}
router := NewRouter(registry, cfg)
// Test: Request claude-3-5-sonnet with thinking suffix
req := RoutingRequest{
RequestedModel: "claude-3-5-sonnet(thinking:foo)",
PreferLocalProvider: true,
ForceModelMapping: false,
}
decision := router.ResolveV2(req)
// Assert: Thinking suffix should be preserved in resolved model
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
assert.Equal(t, "claude-local(thinking:foo)", decision.ResolvedModel)
assert.Equal(t, "anthropic", decision.ProviderName)
}
// MockProvider is a mock implementation of Provider for testing
type MockProvider struct {
name string
providerType ProviderType
supportedModels []string
available bool
priority int
}
func (m *MockProvider) Name() string {
return m.name
}
func (m *MockProvider) Type() ProviderType {
if m.providerType == "" {
return ProviderTypeOAuth
}
return m.providerType
}
func (m *MockProvider) SupportsModel(model string) bool {
for _, supported := range m.supportedModels {
if supported == model {
return true
}
}
return false
}
func (m *MockProvider) Available(model string) bool {
return m.available
}
func (m *MockProvider) Priority() int {
return m.priority
}
func (m *MockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
return executor.Response{}, nil
}
func (m *MockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
return nil, nil
}

View File

@@ -1,113 +0,0 @@
package testutil
import (
"io"
"net/http"
"github.com/gin-gonic/gin"
)
// FakeHandlerRecorder records handler invocations for testing.
type FakeHandlerRecorder struct {
Called bool
CallCount int
RequestBody []byte
RequestHeader http.Header
ContextKeys map[string]interface{}
ResponseStatus int
ResponseBody []byte
}
// NewFakeHandlerRecorder creates a new fake handler recorder.
func NewFakeHandlerRecorder() *FakeHandlerRecorder {
return &FakeHandlerRecorder{
ContextKeys: make(map[string]interface{}),
ResponseStatus: http.StatusOK,
ResponseBody: []byte(`{"status":"handled"}`),
}
}
// GinHandler returns a gin.HandlerFunc that records the invocation.
func (f *FakeHandlerRecorder) GinHandler() gin.HandlerFunc {
return func(c *gin.Context) {
f.record(c)
c.Data(f.ResponseStatus, "application/json", f.ResponseBody)
}
}
// GinHandlerWithModel returns a gin.HandlerFunc that records the invocation and returns the model from context.
// Useful for testing response rewriting in model mapping scenarios.
func (f *FakeHandlerRecorder) GinHandlerWithModel() gin.HandlerFunc {
return func(c *gin.Context) {
f.record(c)
// Return a response with the model field that would be in the actual API response
// If ResponseBody was explicitly set (not default), use that; otherwise generate from context
var body []byte
if mappedModel, exists := c.Get("mapped_model"); exists {
body = []byte(`{"model":"` + mappedModel.(string) + `","status":"handled"}`)
} else {
body = f.ResponseBody
}
c.Data(f.ResponseStatus, "application/json", body)
}
}
// HTTPHandler returns an http.HandlerFunc that records the invocation.
func (f *FakeHandlerRecorder) HTTPHandler() http.HandlerFunc {
return func(w http.ResponseWriter, r *http.Request) {
body, _ := io.ReadAll(r.Body)
f.Called = true
f.CallCount++
f.RequestBody = body
f.RequestHeader = r.Header.Clone()
w.Header().Set("Content-Type", "application/json")
w.WriteHeader(f.ResponseStatus)
w.Write(f.ResponseBody)
}
}
// record captures the request details from gin context.
func (f *FakeHandlerRecorder) record(c *gin.Context) {
f.Called = true
f.CallCount++
body, _ := io.ReadAll(c.Request.Body)
f.RequestBody = body
f.RequestHeader = c.Request.Header.Clone()
// Capture common context keys used by routing
if val, exists := c.Get("mapped_model"); exists {
f.ContextKeys["mapped_model"] = val
}
if val, exists := c.Get("fallback_models"); exists {
f.ContextKeys["fallback_models"] = val
}
if val, exists := c.Get("route_type"); exists {
f.ContextKeys["route_type"] = val
}
}
// Reset clears the recorder state.
func (f *FakeHandlerRecorder) Reset() {
f.Called = false
f.CallCount = 0
f.RequestBody = nil
f.RequestHeader = nil
f.ContextKeys = make(map[string]interface{})
}
// GetContextKey returns a captured context key value.
func (f *FakeHandlerRecorder) GetContextKey(key string) (interface{}, bool) {
val, ok := f.ContextKeys[key]
return val, ok
}
// WasCalled returns true if the handler was called.
func (f *FakeHandlerRecorder) WasCalled() bool {
return f.Called
}
// GetCallCount returns the number of times the handler was called.
func (f *FakeHandlerRecorder) GetCallCount() int {
return f.CallCount
}

View File

@@ -1,83 +0,0 @@
package testutil
import (
"io"
"net/http"
"net/http/httptest"
)
// CloseNotifierRecorder wraps httptest.ResponseRecorder with CloseNotify support.
// This is needed because ReverseProxy requires http.CloseNotifier.
type CloseNotifierRecorder struct {
*httptest.ResponseRecorder
closeChan chan bool
}
// NewCloseNotifierRecorder creates a ResponseRecorder that implements CloseNotifier.
func NewCloseNotifierRecorder() *CloseNotifierRecorder {
return &CloseNotifierRecorder{
ResponseRecorder: httptest.NewRecorder(),
closeChan: make(chan bool, 1),
}
}
// CloseNotify implements http.CloseNotifier.
func (c *CloseNotifierRecorder) CloseNotify() <-chan bool {
return c.closeChan
}
// FakeProxyRecorder records proxy invocations for testing.
type FakeProxyRecorder struct {
Called bool
CallCount int
RequestBody []byte
RequestHeaders http.Header
ResponseStatus int
ResponseBody []byte
}
// NewFakeProxyRecorder creates a new fake proxy recorder.
func NewFakeProxyRecorder() *FakeProxyRecorder {
return &FakeProxyRecorder{
ResponseStatus: http.StatusOK,
ResponseBody: []byte(`{"status":"proxied"}`),
}
}
// ServeHTTP implements http.Handler to act as a reverse proxy.
func (f *FakeProxyRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
f.Called = true
f.CallCount++
f.RequestHeaders = r.Header.Clone()
body, err := io.ReadAll(r.Body)
if err == nil {
f.RequestBody = body
}
w.WriteHeader(f.ResponseStatus)
w.Write(f.ResponseBody)
}
// GetCallCount returns the number of times the proxy was called.
func (f *FakeProxyRecorder) GetCallCount() int {
return f.CallCount
}
// Reset clears the recorder state.
func (f *FakeProxyRecorder) Reset() {
f.Called = false
f.CallCount = 0
f.RequestBody = nil
f.RequestHeaders = nil
}
// ToHandler returns the recorder as an http.Handler for use with httptest.
func (f *FakeProxyRecorder) ToHandler() http.Handler {
return http.HandlerFunc(f.ServeHTTP)
}
// CreateTestServer creates an httptest server with this fake proxy.
func (f *FakeProxyRecorder) CreateTestServer() *httptest.Server {
return httptest.NewServer(f.ToHandler())
}

View File

@@ -1,62 +0,0 @@
package routing
// RouteType represents the type of routing decision made for a request.
type RouteType string
const (
// RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free).
RouteTypeLocalProvider RouteType = "LOCAL_PROVIDER"
// RouteTypeModelMapping indicates the request was remapped to another available model (free).
RouteTypeModelMapping RouteType = "MODEL_MAPPING"
// RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits).
RouteTypeAmpCredits RouteType = "AMP_CREDITS"
// RouteTypeNoProvider indicates no provider or fallback available.
RouteTypeNoProvider RouteType = "NO_PROVIDER"
)
// RoutingRequest contains the information needed to make a routing decision.
type RoutingRequest struct {
// RequestedModel is the model name from the incoming request.
RequestedModel string
// PreferLocalProvider indicates whether to prefer local providers over mappings.
// When true, check local providers first before applying model mappings.
PreferLocalProvider bool
// ForceModelMapping indicates whether to force model mapping even if local provider exists.
// When true, apply model mappings first and skip local provider checks.
ForceModelMapping bool
}
// RoutingDecision contains the result of a routing decision.
type RoutingDecision struct {
// RouteType indicates the type of routing decision.
RouteType RouteType
// ResolvedModel is the final model name after any mappings.
ResolvedModel string
// ProviderName is the name of the selected provider (if any).
ProviderName string
// FallbackModels is a list of alternative models to try if the primary fails.
FallbackModels []string
// ShouldProxy indicates whether the request should be proxied to ampcode.com.
ShouldProxy bool
}
// NewRoutingDecision creates a new RoutingDecision with the given parameters.
func NewRoutingDecision(routeType RouteType, resolvedModel, providerName string, fallbackModels []string, shouldProxy bool) *RoutingDecision {
return &RoutingDecision{
RouteType: routeType,
ResolvedModel: resolvedModel,
ProviderName: providerName,
FallbackModels: fallbackModels,
ShouldProxy: shouldProxy,
}
}
// IsLocal returns true if the decision routes to a local provider.
func (d *RoutingDecision) IsLocal() bool {
return d.RouteType == RouteTypeLocalProvider || d.RouteType == RouteTypeModelMapping
}
// HasFallbacks returns true if there are fallback models available.
func (d *RoutingDecision) HasFallbacks() bool {
return len(d.FallbackModels) > 0
}

View File

@@ -1,270 +0,0 @@
package routing
import (
"bufio"
"bytes"
"io"
"net"
"net/http"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
"github.com/sirupsen/logrus"
)
// ProxyFunc is the function type for proxying requests.
type ProxyFunc func(c *gin.Context)
// ModelRoutingWrapper wraps HTTP handlers with unified model routing logic.
// It replaces the FallbackHandler logic with a Router-based approach.
type ModelRoutingWrapper struct {
router *Router
extractor ModelExtractor
rewriter ModelRewriter
proxyFunc ProxyFunc
logger *logrus.Logger
}
// NewModelRoutingWrapper creates a new ModelRoutingWrapper with the given dependencies.
// If extractor is nil, a DefaultModelExtractor is used.
// If rewriter is nil, a DefaultModelRewriter is used.
// proxyFunc is called for AMP_CREDITS route type; if nil, the handler will be called instead.
func NewModelRoutingWrapper(router *Router, extractor ModelExtractor, rewriter ModelRewriter, proxyFunc ProxyFunc) *ModelRoutingWrapper {
if extractor == nil {
extractor = NewModelExtractor()
}
if rewriter == nil {
rewriter = NewModelRewriter()
}
return &ModelRoutingWrapper{
router: router,
extractor: extractor,
rewriter: rewriter,
proxyFunc: proxyFunc,
logger: logrus.New(),
}
}
// SetLogger sets the logger for the wrapper.
func (w *ModelRoutingWrapper) SetLogger(logger *logrus.Logger) {
w.logger = logger
}
// Wrap wraps a gin.HandlerFunc with model routing logic.
// The returned handler will:
// 1. Extract the model from the request
// 2. Get a routing decision from the Router
// 3. Handle the request according to the decision type (LOCAL_PROVIDER, MODEL_MAPPING, AMP_CREDITS)
func (w *ModelRoutingWrapper) Wrap(handler gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
// Read request body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {
w.logger.Errorf("routing wrapper: failed to read request body: %v", err)
handler(c)
return
}
// Extract model from request
ginParams := map[string]string{
"action": c.Param("action"),
"path": c.Param("path"),
}
modelName, err := w.extractor.Extract(bodyBytes, ginParams)
if err != nil {
w.logger.Warnf("routing wrapper: failed to extract model: %v", err)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
return
}
if modelName == "" {
// No model found, proceed with original handler
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
return
}
// Get routing decision
req := RoutingRequest{
RequestedModel: modelName,
PreferLocalProvider: true,
ForceModelMapping: false, // TODO: Get from config
}
decision := w.router.ResolveV2(req)
// Store decision in context for downstream handlers
c.Set(string(ctxkeys.RoutingDecision), decision)
// Handle based on route type
switch decision.RouteType {
case RouteTypeLocalProvider:
w.handleLocalProvider(c, handler, bodyBytes, decision)
case RouteTypeModelMapping:
w.handleModelMapping(c, handler, bodyBytes, decision)
case RouteTypeAmpCredits:
w.handleAmpCredits(c, handler, bodyBytes)
default:
// No provider available
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c)
}
}
}
// handleLocalProvider handles the LOCAL_PROVIDER route type.
func (w *ModelRoutingWrapper) handleLocalProvider(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
// Filter Anthropic-Beta header for local provider
filterAnthropicBetaHeader(c)
// Restore body with original content
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Call handler
handler(c)
}
// handleModelMapping handles the MODEL_MAPPING route type.
func (w *ModelRoutingWrapper) handleModelMapping(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
// Rewrite request body with mapped model
rewrittenBody, err := w.rewriter.RewriteRequestBody(bodyBytes, decision.ResolvedModel)
if err != nil {
w.logger.Warnf("routing wrapper: failed to rewrite request body: %v", err)
rewrittenBody = bodyBytes
}
_ = rewrittenBody
// Store mapped model in context
c.Set(string(ctxkeys.MappedModel), decision.ResolvedModel)
// Store fallback models in context if present
if len(decision.FallbackModels) > 0 {
c.Set(string(ctxkeys.FallbackModels), decision.FallbackModels)
}
// Filter Anthropic-Beta header for local provider
filterAnthropicBetaHeader(c)
// Restore body with rewritten content
c.Request.Body = io.NopCloser(bytes.NewReader(rewrittenBody))
// Wrap response writer to rewrite model back
wrappedWriter, cleanup := w.rewriter.WrapResponseWriter(c.Writer, decision.ResolvedModel, decision.ResolvedModel)
c.Writer = &ginResponseWriterAdapter{ResponseWriter: wrappedWriter, original: c.Writer}
// Call handler
handler(c)
// Cleanup (flush response rewriting)
cleanup()
}
// handleAmpCredits handles the AMP_CREDITS route type.
// It calls the proxy function directly if available, otherwise passes to handler.
// Does NOT filter headers or rewrite body - proxy handles everything.
func (w *ModelRoutingWrapper) handleAmpCredits(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte) {
// Restore body with original content (no rewriting for proxy)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Call proxy function if available, otherwise fall back to handler
if w.proxyFunc != nil {
w.proxyFunc(c)
} else {
handler(c)
}
}
// filterAnthropicBetaHeader filters Anthropic-Beta header for local providers.
func filterAnthropicBetaHeader(c *gin.Context) {
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
if filtered != "" {
c.Request.Header.Set("Anthropic-Beta", filtered)
} else {
c.Request.Header.Del("Anthropic-Beta")
}
}
}
// filterBetaFeatures removes specified beta features from the header.
func filterBetaFeatures(betaHeader, featureToRemove string) string {
// Simple implementation - can be enhanced
if betaHeader == featureToRemove {
return ""
}
return betaHeader
}
// ginResponseWriterAdapter adapts http.ResponseWriter to gin.ResponseWriter.
type ginResponseWriterAdapter struct {
http.ResponseWriter
original gin.ResponseWriter
}
func (a *ginResponseWriterAdapter) WriteHeader(code int) {
a.ResponseWriter.WriteHeader(code)
}
func (a *ginResponseWriterAdapter) Write(data []byte) (int, error) {
return a.ResponseWriter.Write(data)
}
func (a *ginResponseWriterAdapter) Header() http.Header {
return a.ResponseWriter.Header()
}
// CloseNotify implements http.CloseNotifier.
func (a *ginResponseWriterAdapter) CloseNotify() <-chan bool {
if notifier, ok := a.ResponseWriter.(http.CloseNotifier); ok {
return notifier.CloseNotify()
}
return a.original.CloseNotify()
}
// Flush implements http.Flusher.
func (a *ginResponseWriterAdapter) Flush() {
if flusher, ok := a.ResponseWriter.(http.Flusher); ok {
flusher.Flush()
}
}
// Hijack implements http.Hijacker.
func (a *ginResponseWriterAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
if hijacker, ok := a.ResponseWriter.(http.Hijacker); ok {
return hijacker.Hijack()
}
return a.original.Hijack()
}
// Status returns the HTTP status code.
func (a *ginResponseWriterAdapter) Status() int {
return a.original.Status()
}
// Size returns the number of bytes already written into the response http body.
func (a *ginResponseWriterAdapter) Size() int {
return a.original.Size()
}
// Written returns whether or not the response for this context has been written.
func (a *ginResponseWriterAdapter) Written() bool {
return a.original.Written()
}
// WriteHeaderNow forces WriteHeader to be called.
func (a *ginResponseWriterAdapter) WriteHeaderNow() {
a.original.WriteHeaderNow()
}
// WriteString writes the given string into the response body.
func (a *ginResponseWriterAdapter) WriteString(s string) (int, error) {
return a.Write([]byte(s))
}
// Pusher returns the http.Pusher for server push.
func (a *ginResponseWriterAdapter) Pusher() http.Pusher {
if pusher, ok := a.ResponseWriter.(http.Pusher); ok {
return pusher
}
return nil
}

View File

@@ -1003,6 +1003,8 @@ func vertexBaseURL(location string) string {
loc := strings.TrimSpace(location)
if loc == "" {
loc = "us-central1"
} else if loc == "global" {
return "https://aiplatform.googleapis.com"
}
return fmt.Sprintf("https://%s-aiplatform.googleapis.com", loc)
}

View File

@@ -83,10 +83,6 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint)
result = a.normalizeClaudeBudget(result, config.Budget, modelInfo)
// When thinking is enabled, Claude API requires assistant messages with tool_use
// to have a thinking block. Inject empty thinking block if missing.
result = injectThinkingBlockForToolUse(result)
return result, nil
}
@@ -153,85 +149,18 @@ func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte,
body = []byte(`{}`)
}
var result []byte
switch config.Mode {
case thinking.ModeNone:
result, _ = sjson.SetBytes(body, "thinking.type", "disabled")
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
return result, nil
case thinking.ModeAuto:
result, _ = sjson.SetBytes(body, "thinking.type", "enabled")
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
return result, nil
default:
result, _ = sjson.SetBytes(body, "thinking.type", "enabled")
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
return result, nil
}
// When thinking is enabled, Claude API requires assistant messages with tool_use
// to have a thinking block. Inject empty thinking block if missing.
result = injectThinkingBlockForToolUse(result)
return result, nil
}
// injectThinkingBlockForToolUse adds empty thinking block to assistant messages
// that have tool_use but no thinking block. This is required by Claude API when
// thinking is enabled.
func injectThinkingBlockForToolUse(body []byte) []byte {
messages := gjson.GetBytes(body, "messages")
if !messages.IsArray() {
return body
}
messageArray := messages.Array()
modified := false
newMessages := "[]"
for _, msg := range messageArray {
role := msg.Get("role").String()
if role != "assistant" {
newMessages, _ = sjson.SetRaw(newMessages, "-1", msg.Raw)
continue
}
content := msg.Get("content")
if !content.IsArray() {
newMessages, _ = sjson.SetRaw(newMessages, "-1", msg.Raw)
continue
}
contentArray := content.Array()
hasToolUse := false
hasThinking := false
for _, part := range contentArray {
partType := part.Get("type").String()
if partType == "tool_use" {
hasToolUse = true
}
if partType == "thinking" {
hasThinking = true
}
}
if hasToolUse && !hasThinking {
// Inject empty thinking block at the beginning of content
newContent := "[]"
newContent, _ = sjson.SetRaw(newContent, "-1", `{"type":"thinking","thinking":""}`)
for _, part := range contentArray {
newContent, _ = sjson.SetRaw(newContent, "-1", part.Raw)
}
msgJSON := msg.Raw
msgJSON, _ = sjson.SetRaw(msgJSON, "content", newContent)
newMessages, _ = sjson.SetRaw(newMessages, "-1", msgJSON)
modified = true
continue
}
newMessages, _ = sjson.SetRaw(newMessages, "-1", msg.Raw)
}
if modified {
body, _ = sjson.SetRawBytes(body, "messages", []byte(newMessages))
}
return body
}

View File

@@ -1,187 +0,0 @@
package claude
import (
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/tidwall/gjson"
)
func TestInjectThinkingBlockForToolUse(t *testing.T) {
tests := []struct {
name string
input string
expected string
}{
{
name: "assistant with tool_use but no thinking - should inject thinking",
input: `{
"model": "kimi-k2.5",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Let me use a tool"},
{"type": "tool_use", "id": "tool_1", "name": "test_tool", "input": {}}
]
}
]
}`,
expected: "thinking",
},
{
name: "assistant with tool_use and thinking - should not modify",
input: `{
"model": "kimi-k2.5",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "I need to use a tool"},
{"type": "tool_use", "id": "tool_1", "name": "test_tool", "input": {}}
]
}
]
}`,
expected: "thinking",
},
{
name: "user message with tool_use - should not modify",
input: `{
"model": "kimi-k2.5",
"messages": [
{
"role": "user",
"content": [
{"type": "tool_result", "tool_use_id": "tool_1", "content": "result"}
]
}
]
}`,
expected: "",
},
{
name: "assistant without tool_use - should not modify",
input: `{
"model": "kimi-k2.5",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Hello!"}
]
}
]
}`,
expected: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := injectThinkingBlockForToolUse([]byte(tt.input))
// Check if thinking block exists in assistant messages with tool_use
messages := gjson.GetBytes(result, "messages")
if !messages.IsArray() {
t.Fatal("messages is not an array")
}
for _, msg := range messages.Array() {
if msg.Get("role").String() == "assistant" {
content := msg.Get("content")
if !content.IsArray() {
continue
}
hasToolUse := false
hasThinking := false
for _, part := range content.Array() {
partType := part.Get("type").String()
if partType == "tool_use" {
hasToolUse = true
}
if partType == "thinking" {
hasThinking = true
}
}
if hasToolUse && tt.expected == "thinking" && !hasThinking {
t.Errorf("Expected thinking block in assistant message with tool_use, but not found")
}
}
}
})
}
}
func TestApplyCompatibleClaude(t *testing.T) {
tests := []struct {
name string
input string
config thinking.ThinkingConfig
expectThinking bool
}{
{
name: "thinking enabled with tool_use - should inject thinking block",
input: `{
"model": "kimi-k2.5",
"messages": [
{
"role": "assistant",
"content": [
{"type": "tool_use", "id": "tool_1", "name": "test_tool", "input": {}}
]
}
]
}`,
config: thinking.ThinkingConfig{
Mode: thinking.ModeBudget,
Budget: 4000,
},
expectThinking: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result, err := applyCompatibleClaude([]byte(tt.input), tt.config)
if err != nil {
t.Fatalf("applyCompatibleClaude failed: %v", err)
}
// Check if thinking.type is enabled
thinkingType := gjson.GetBytes(result, "thinking.type").String()
if thinkingType != "enabled" {
t.Errorf("Expected thinking.type=enabled, got %s", thinkingType)
}
// Check if thinking block is injected
messages := gjson.GetBytes(result, "messages")
if !messages.IsArray() {
t.Fatal("messages is not an array")
}
for _, msg := range messages.Array() {
if msg.Get("role").String() == "assistant" {
content := msg.Get("content")
if !content.IsArray() {
continue
}
hasThinking := false
for _, part := range content.Array() {
if part.Get("type").String() == "thinking" {
hasThinking = true
break
}
}
if tt.expectThinking && !hasThinking {
t.Errorf("Expected thinking block in assistant message, but not found. Result: %s", string(result))
}
}
}
})
}
}

View File

@@ -115,8 +115,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if signatureResult.Exists() && signatureResult.String() != "" {
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
if len(arrayClientSignatures) == 2 {
// Compare using model group to handle model mapping
// e.g., claude-opus-4-5-thinking -> "claude" group should match "claude#signature"
if cache.GetModelGroup(modelName) == arrayClientSignatures[0] {
clientSignature = arrayClientSignatures[1]
}

View File

@@ -11,6 +11,12 @@ import (
func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := bytes.Clone(inputRawJSON)
inputResult := gjson.GetBytes(rawJSON, "input")
if inputResult.Type == gjson.String {
input, _ := sjson.Set(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`, "0.content.0.text", inputResult.String())
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", []byte(input))
}
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)
rawJSON, _ = sjson.SetBytes(rawJSON, "store", false)
rawJSON, _ = sjson.SetBytes(rawJSON, "parallel_tool_calls", true)

View File

@@ -61,13 +61,10 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
out, _ = sjson.Set(out, "stream", stream)
// Thinking: Convert Claude thinking.budget_tokens to OpenAI reasoning_effort
// Also track if thinking is enabled to ensure reasoning_content is added for tool_calls
thinkingEnabled := false
if thinkingConfig := root.Get("thinking"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
if thinkingType := thinkingConfig.Get("type"); thinkingType.Exists() {
switch thinkingType.String() {
case "enabled":
thinkingEnabled = true
if budgetTokens := thinkingConfig.Get("budget_tokens"); budgetTokens.Exists() {
budget := int(budgetTokens.Int())
if effort, ok := thinking.ConvertBudgetToLevel(budget); ok && effort != "" {
@@ -220,10 +217,6 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
// Add reasoning_content if present
if hasReasoning {
msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", reasoningContent)
} else if thinkingEnabled && hasToolCalls {
// Claude API requires reasoning_content in assistant messages with tool_calls
// when thinking mode is enabled, even if empty
msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", "")
}
// Add tool_calls if present (in same message as content)

View File

@@ -588,124 +588,3 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t
t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got)
}
}
// TestConvertClaudeRequestToOpenAI_ThinkingEnabledToolCallsNoReasoning tests that
// when thinking mode is enabled and assistant message has tool_calls but no thinking content,
// an empty reasoning_content is added to satisfy Claude API requirements.
func TestConvertClaudeRequestToOpenAI_ThinkingEnabledToolCallsNoReasoning(t *testing.T) {
tests := []struct {
name string
inputJSON string
wantHasReasoningContent bool
wantReasoningContent string
}{
{
name: "thinking enabled with tool_calls but no thinking content adds empty reasoning_content",
inputJSON: `{
"model": "claude-3-opus",
"thinking": {"type": "enabled", "budget_tokens": 4000},
"messages": [{
"role": "assistant",
"content": [
{"type": "text", "text": "I will help you."},
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
]
}]
}`,
wantHasReasoningContent: true,
wantReasoningContent: "",
},
{
name: "thinking enabled with tool_calls and thinking content uses actual reasoning",
inputJSON: `{
"model": "claude-3-opus",
"thinking": {"type": "enabled", "budget_tokens": 4000},
"messages": [{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me analyze this..."},
{"type": "text", "text": "I will help you."},
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
]
}]
}`,
wantHasReasoningContent: true,
wantReasoningContent: "Let me analyze this...",
},
{
name: "thinking disabled with tool_calls does not add reasoning_content",
inputJSON: `{
"model": "claude-3-opus",
"thinking": {"type": "disabled"},
"messages": [{
"role": "assistant",
"content": [
{"type": "text", "text": "I will help you."},
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
]
}]
}`,
wantHasReasoningContent: false,
wantReasoningContent: "",
},
{
name: "no thinking config with tool_calls does not add reasoning_content",
inputJSON: `{
"model": "claude-3-opus",
"messages": [{
"role": "assistant",
"content": [
{"type": "text", "text": "I will help you."},
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
]
}]
}`,
wantHasReasoningContent: false,
wantReasoningContent: "",
},
{
name: "thinking enabled without tool_calls and no thinking content does not add reasoning_content",
inputJSON: `{
"model": "claude-3-opus",
"thinking": {"type": "enabled", "budget_tokens": 4000},
"messages": [{
"role": "assistant",
"content": [
{"type": "text", "text": "Simple response without tools."}
]
}]
}`,
wantHasReasoningContent: false,
wantReasoningContent: "",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
resultJSON := gjson.ParseBytes(result)
messages := resultJSON.Get("messages").Array()
if len(messages) == 0 {
t.Fatal("Expected at least one message")
}
assistantMsg := messages[0]
if assistantMsg.Get("role").String() != "assistant" {
t.Fatalf("Expected assistant message, got %s", assistantMsg.Get("role").String())
}
hasReasoningContent := assistantMsg.Get("reasoning_content").Exists()
if hasReasoningContent != tt.wantHasReasoningContent {
t.Errorf("reasoning_content existence = %v, want %v", hasReasoningContent, tt.wantHasReasoningContent)
}
if hasReasoningContent {
gotReasoningContent := assistantMsg.Get("reasoning_content").String()
if gotReasoningContent != tt.wantReasoningContent {
t.Errorf("reasoning_content = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
}
}
})
}
}

View File

@@ -68,6 +68,9 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
case "message", "":
// Handle regular message conversion
role := item.Get("role").String()
if role == "developer" {
role = "user"
}
message := `{"role":"","content":""}`
message, _ = sjson.Set(message, "role", role)
@@ -167,7 +170,8 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
// Only function tools need structural conversion because Chat Completions nests details under "function".
toolType := tool.Get("type").String()
if toolType != "" && toolType != "function" && tool.IsObject() {
chatCompletionsTools = append(chatCompletionsTools, tool.Value())
// Almost all providers lack built-in tools, so we just ignore them.
// chatCompletionsTools = append(chatCompletionsTools, tool.Value())
return true
}

View File

@@ -6,6 +6,7 @@ import (
"context"
"crypto/sha256"
"encoding/hex"
"encoding/json"
"fmt"
"io/fs"
"os"
@@ -15,6 +16,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
@@ -72,6 +74,7 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
w.clientsMutex.Lock()
w.lastAuthHashes = make(map[string]string)
w.lastAuthContents = make(map[string]*coreauth.Auth)
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
} else if resolvedAuthDir != "" {
@@ -84,6 +87,11 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
sum := sha256.Sum256(data)
normalizedPath := w.normalizeAuthPath(path)
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
// Parse and cache auth content for future diff comparisons
var auth coreauth.Auth
if errParse := json.Unmarshal(data, &auth); errParse == nil {
w.lastAuthContents[normalizedPath] = &auth
}
}
}
return nil
@@ -127,6 +135,13 @@ func (w *Watcher) addOrUpdateClient(path string) {
curHash := hex.EncodeToString(sum[:])
normalized := w.normalizeAuthPath(path)
// Parse new auth content for diff comparison
var newAuth coreauth.Auth
if errParse := json.Unmarshal(data, &newAuth); errParse != nil {
log.Errorf("failed to parse auth file %s: %v", filepath.Base(path), errParse)
return
}
w.clientsMutex.Lock()
cfg := w.config
@@ -141,7 +156,26 @@ func (w *Watcher) addOrUpdateClient(path string) {
return
}
// Get old auth for diff comparison
var oldAuth *coreauth.Auth
if w.lastAuthContents != nil {
oldAuth = w.lastAuthContents[normalized]
}
// Compute and log field changes
if changes := diff.BuildAuthChangeDetails(oldAuth, &newAuth); len(changes) > 0 {
log.Debugf("auth field changes for %s:", filepath.Base(path))
for _, c := range changes {
log.Debugf(" %s", c)
}
}
// Update caches
w.lastAuthHashes[normalized] = curHash
if w.lastAuthContents == nil {
w.lastAuthContents = make(map[string]*coreauth.Auth)
}
w.lastAuthContents[normalized] = &newAuth
w.clientsMutex.Unlock() // Unlock before the callback
@@ -160,6 +194,7 @@ func (w *Watcher) removeClient(path string) {
cfg := w.config
delete(w.lastAuthHashes, normalized)
delete(w.lastAuthContents, normalized)
w.clientsMutex.Unlock() // Release the lock before the callback

View File

@@ -0,0 +1,44 @@
// auth_diff.go computes human-readable diffs for auth file field changes.
package diff
import (
"fmt"
"strings"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// BuildAuthChangeDetails computes a redacted, human-readable list of auth field changes.
// Only prefix, proxy_url, and disabled fields are tracked; sensitive data is never printed.
func BuildAuthChangeDetails(oldAuth, newAuth *coreauth.Auth) []string {
changes := make([]string, 0, 3)
// Handle nil cases by using empty Auth as default
if oldAuth == nil {
oldAuth = &coreauth.Auth{}
}
if newAuth == nil {
return changes
}
// Compare prefix
oldPrefix := strings.TrimSpace(oldAuth.Prefix)
newPrefix := strings.TrimSpace(newAuth.Prefix)
if oldPrefix != newPrefix {
changes = append(changes, fmt.Sprintf("prefix: %s -> %s", oldPrefix, newPrefix))
}
// Compare proxy_url (redacted)
oldProxy := strings.TrimSpace(oldAuth.ProxyURL)
newProxy := strings.TrimSpace(newAuth.ProxyURL)
if oldProxy != newProxy {
changes = append(changes, fmt.Sprintf("proxy_url: %s -> %s", formatProxyURL(oldProxy), formatProxyURL(newProxy)))
}
// Compare disabled
if oldAuth.Disabled != newAuth.Disabled {
changes = append(changes, fmt.Sprintf("disabled: %t -> %t", oldAuth.Disabled, newAuth.Disabled))
}
return changes
}

View File

@@ -27,6 +27,12 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if oldCfg.Debug != newCfg.Debug {
changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug))
}
if oldCfg.Pprof.Enable != newCfg.Pprof.Enable {
changes = append(changes, fmt.Sprintf("pprof.enable: %t -> %t", oldCfg.Pprof.Enable, newCfg.Pprof.Enable))
}
if strings.TrimSpace(oldCfg.Pprof.Addr) != strings.TrimSpace(newCfg.Pprof.Addr) {
changes = append(changes, fmt.Sprintf("pprof.addr: %s -> %s", strings.TrimSpace(oldCfg.Pprof.Addr), strings.TrimSpace(newCfg.Pprof.Addr)))
}
if oldCfg.LoggingToFile != newCfg.LoggingToFile {
changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile))
}

View File

@@ -38,6 +38,7 @@ type Watcher struct {
reloadCallback func(*config.Config)
watcher *fsnotify.Watcher
lastAuthHashes map[string]string
lastAuthContents map[string]*coreauth.Auth
lastRemoveTimes map[string]time.Time
lastConfigHash string
authQueue chan<- AuthUpdate

View File

@@ -255,15 +255,16 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
parentCtx = logging.WithRequestID(parentCtx, requestID)
}
}
// Use requestCtx as base if available to preserve amp context values (fallback_models, etc.)
// Falls back to parentCtx if no request context
baseCtx := parentCtx
if requestCtx != nil {
baseCtx = requestCtx
newCtx, cancel := context.WithCancel(parentCtx)
if requestCtx != nil && requestCtx != parentCtx {
go func() {
select {
case <-requestCtx.Done():
cancel()
case <-newCtx.Done():
}
}()
}
newCtx, cancel := context.WithCancel(baseCtx)
newCtx = context.WithValue(newCtx, "gin", c)
newCtx = context.WithValue(newCtx, "handler", handler)
return newCtx, func(params ...interface{}) {

View File

@@ -18,7 +18,6 @@ import (
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -563,188 +562,192 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
}
func (m *Manager) executeWithFallback(
ctx context.Context,
initialProviders []string,
req cliproxyexecutor.Request,
opts cliproxyexecutor.Options,
exec func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error,
) error {
routeModel := req.Model
providers := initialProviders
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
// Track fallback models from context (provided by Amp module fallback_models key)
var fallbacks []string
if v := ctx.Value(ctxkeys.FallbackModels); v != nil {
if fs, ok := v.([]string); ok {
fallbacks = fs
}
}
fallbackIdx := -1
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
// No more auths for current model. Try next fallback model if available.
if fallbackIdx+1 < len(fallbacks) {
fallbackIdx++
routeModel = fallbacks[fallbackIdx]
log.Debugf("no more auths for current model, trying fallback model: %s (fallback %d/%d)", routeModel, fallbackIdx+1, len(fallbacks))
// Reset tried set for the new model and find its providers
tried = make(map[string]struct{})
providers = util.GetProviderName(thinking.ParseSuffix(routeModel).ModelName)
// Reset opts for the new model
opts = ensureRequestedModelMetadata(opts, routeModel)
if len(providers) == 0 {
log.Debugf("fallback model %s has no providers, skipping", routeModel)
continue // Try next fallback if this one has no providers
}
continue
}
if lastErr != nil {
return lastErr
}
return errPick
}
tried[auth.ID] = struct{}{}
if err := exec(ctx, executor, auth, provider, routeModel); err != nil {
if errCtx := ctx.Err(); errCtx != nil {
return errCtx
}
lastErr = err
continue
}
return nil
}
}
func (m *Manager) executeMixedAttempt(
ctx context.Context,
auth *Auth,
provider, routeModel string,
req cliproxyexecutor.Request,
opts cliproxyexecutor.Options,
exec func(ctx context.Context, execReq cliproxyexecutor.Request) error,
) error {
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
err := exec(execCtx, execReq)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: err == nil}
if err != nil {
result.Error = &Error{Message: err.Error()}
var se cliproxyexecutor.StatusError
if errors.As(err, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(err); ra != nil {
result.RetryAfter = ra
}
}
m.MarkResult(execCtx, result)
return err
}
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
var resp cliproxyexecutor.Response
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
var errExec error
resp, errExec = executor.Execute(execCtx, auth, execReq, opts)
return errExec
})
})
return resp, err
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
lastErr = errExec
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
}
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
if len(providers) == 0 {
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return cliproxyexecutor.Response{}, lastErr
}
return cliproxyexecutor.Response{}, errPick
}
var resp cliproxyexecutor.Response
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
var errExec error
resp, errExec = executor.CountTokens(execCtx, auth, execReq, opts)
return errExec
})
})
return resp, err
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
if errExec != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return cliproxyexecutor.Response{}, errCtx
}
result.Error = &Error{Message: errExec.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errExec, &se) && se != nil {
result.Error.HTTPStatus = se.StatusCode()
}
if ra := retryAfterFromError(errExec); ra != nil {
result.RetryAfter = ra
}
m.MarkResult(execCtx, result)
lastErr = errExec
continue
}
m.MarkResult(execCtx, result)
return resp, nil
}
}
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
if len(providers) == 0 {
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
var chunks <-chan cliproxyexecutor.StreamChunk
err := m.executeWithFallback(ctx, providers, req, opts, func(ctx context.Context, executor ProviderExecutor, auth *Auth, provider, routeModel string) error {
return m.executeMixedAttempt(ctx, auth, provider, routeModel, req, opts, func(execCtx context.Context, execReq cliproxyexecutor.Request) error {
var errExec error
chunks, errExec = executor.ExecuteStream(execCtx, auth, execReq, opts)
if errExec != nil {
return errExec
routeModel := req.Model
opts = ensureRequestedModelMetadata(opts, routeModel)
tried := make(map[string]struct{})
var lastErr error
for {
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
if errPick != nil {
if lastErr != nil {
return nil, lastErr
}
return nil, errPick
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out)
var failed bool
forward := true
for chunk := range streamChunks {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
var se cliproxyexecutor.StatusError
if errors.As(chunk.Err, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
if !forward {
continue
}
if streamCtx == nil {
out <- chunk
continue
}
select {
case <-streamCtx.Done():
forward = false
case out <- chunk:
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
tried[auth.ID] = struct{}{}
execCtx := ctx
if rt := m.roundTripperFor(auth); rt != nil {
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
}
execReq := req
execReq.Model = rewriteModelForAuth(routeModel, auth)
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
if errStream != nil {
if errCtx := execCtx.Err(); errCtx != nil {
return nil, errCtx
}
rerr := &Error{Message: errStream.Error()}
var se cliproxyexecutor.StatusError
if errors.As(errStream, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
result.RetryAfter = retryAfterFromError(errStream)
m.MarkResult(execCtx, result)
lastErr = errStream
continue
}
out := make(chan cliproxyexecutor.StreamChunk)
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
defer close(out)
var failed bool
forward := true
for chunk := range streamChunks {
if chunk.Err != nil && !failed {
failed = true
rerr := &Error{Message: chunk.Err.Error()}
var se cliproxyexecutor.StatusError
if errors.As(chunk.Err, &se) && se != nil {
rerr.HTTPStatus = se.StatusCode()
}
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
}
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
if !forward {
continue
}
}(execCtx, auth.Clone(), provider, chunks)
chunks = out
return nil
})
})
return chunks, err
if streamCtx == nil {
out <- chunk
continue
}
select {
case <-streamCtx.Done():
forward = false
case out <- chunk:
}
}
if !failed {
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
}
}(execCtx, auth.Clone(), provider, chunks)
return out, nil
}
}
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {

View File

@@ -0,0 +1,163 @@
package cliproxy
import (
"context"
"errors"
"net/http"
"net/http/pprof"
"strings"
"sync"
"time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
log "github.com/sirupsen/logrus"
)
type pprofServer struct {
mu sync.Mutex
server *http.Server
addr string
enabled bool
}
func newPprofServer() *pprofServer {
return &pprofServer{}
}
func (s *Service) applyPprofConfig(cfg *config.Config) {
if s == nil || cfg == nil {
return
}
if s.pprofServer == nil {
s.pprofServer = newPprofServer()
}
s.pprofServer.Apply(cfg)
}
func (s *Service) shutdownPprof(ctx context.Context) error {
if s == nil || s.pprofServer == nil {
return nil
}
return s.pprofServer.Shutdown(ctx)
}
func (p *pprofServer) Apply(cfg *config.Config) {
if p == nil || cfg == nil {
return
}
addr := strings.TrimSpace(cfg.Pprof.Addr)
if addr == "" {
addr = config.DefaultPprofAddr
}
enabled := cfg.Pprof.Enable
p.mu.Lock()
currentServer := p.server
currentAddr := p.addr
p.addr = addr
p.enabled = enabled
if !enabled {
p.server = nil
p.mu.Unlock()
if currentServer != nil {
p.stopServer(currentServer, currentAddr, "disabled")
}
return
}
if currentServer != nil && currentAddr == addr {
p.mu.Unlock()
return
}
p.server = nil
p.mu.Unlock()
if currentServer != nil {
p.stopServer(currentServer, currentAddr, "restarted")
}
p.startServer(addr)
}
func (p *pprofServer) Shutdown(ctx context.Context) error {
if p == nil {
return nil
}
p.mu.Lock()
currentServer := p.server
currentAddr := p.addr
p.server = nil
p.enabled = false
p.mu.Unlock()
if currentServer == nil {
return nil
}
return p.stopServerWithContext(ctx, currentServer, currentAddr, "shutdown")
}
func (p *pprofServer) startServer(addr string) {
mux := newPprofMux()
server := &http.Server{
Addr: addr,
Handler: mux,
ReadHeaderTimeout: 5 * time.Second,
}
p.mu.Lock()
if !p.enabled || p.addr != addr || p.server != nil {
p.mu.Unlock()
return
}
p.server = server
p.mu.Unlock()
log.Infof("pprof server starting on %s", addr)
go func() {
if errServe := server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
log.Errorf("pprof server failed on %s: %v", addr, errServe)
p.mu.Lock()
if p.server == server {
p.server = nil
}
p.mu.Unlock()
}
}()
}
func (p *pprofServer) stopServer(server *http.Server, addr string, reason string) {
_ = p.stopServerWithContext(context.Background(), server, addr, reason)
}
func (p *pprofServer) stopServerWithContext(ctx context.Context, server *http.Server, addr string, reason string) error {
if server == nil {
return nil
}
stopCtx := ctx
if stopCtx == nil {
stopCtx = context.Background()
}
stopCtx, cancel := context.WithTimeout(stopCtx, 5*time.Second)
defer cancel()
if errStop := server.Shutdown(stopCtx); errStop != nil {
log.Errorf("pprof server stop failed on %s: %v", addr, errStop)
return errStop
}
log.Infof("pprof server stopped on %s (%s)", addr, reason)
return nil
}
func newPprofMux() *http.ServeMux {
mux := http.NewServeMux()
mux.HandleFunc("/debug/pprof/", pprof.Index)
mux.HandleFunc("/debug/pprof/cmdline", pprof.Cmdline)
mux.HandleFunc("/debug/pprof/profile", pprof.Profile)
mux.HandleFunc("/debug/pprof/symbol", pprof.Symbol)
mux.HandleFunc("/debug/pprof/trace", pprof.Trace)
mux.Handle("/debug/pprof/allocs", pprof.Handler("allocs"))
mux.Handle("/debug/pprof/block", pprof.Handler("block"))
mux.Handle("/debug/pprof/goroutine", pprof.Handler("goroutine"))
mux.Handle("/debug/pprof/heap", pprof.Handler("heap"))
mux.Handle("/debug/pprof/mutex", pprof.Handler("mutex"))
mux.Handle("/debug/pprof/threadcreate", pprof.Handler("threadcreate"))
return mux
}

View File

@@ -57,6 +57,9 @@ type Service struct {
// server is the HTTP API server instance.
server *api.Server
// pprofServer manages the optional pprof HTTP debug server.
pprofServer *pprofServer
// serverErr channel for server startup/shutdown errors.
serverErr chan error
@@ -270,27 +273,42 @@ func (s *Service) wsOnDisconnected(channelID string, reason error) {
}
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
if s == nil || auth == nil || auth.ID == "" {
return
}
if s.coreManager == nil {
if s == nil || s.coreManager == nil || auth == nil || auth.ID == "" {
return
}
auth = auth.Clone()
s.ensureExecutorsForAuth(auth)
s.registerModelsForAuth(auth)
if existing, ok := s.coreManager.GetByID(auth.ID); ok && existing != nil {
// IMPORTANT: Update coreManager FIRST, before model registration.
// This ensures that configuration changes (proxy_url, prefix, etc.) take effect
// immediately for API calls, rather than waiting for model registration to complete.
// Model registration may involve network calls (e.g., FetchAntigravityModels) that
// could timeout if the new proxy_url is unreachable.
op := "register"
var err error
if existing, ok := s.coreManager.GetByID(auth.ID); ok {
auth.CreatedAt = existing.CreatedAt
auth.LastRefreshedAt = existing.LastRefreshedAt
auth.NextRefreshAfter = existing.NextRefreshAfter
if _, err := s.coreManager.Update(ctx, auth); err != nil {
log.Errorf("failed to update auth %s: %v", auth.ID, err)
op = "update"
_, err = s.coreManager.Update(ctx, auth)
} else {
_, err = s.coreManager.Register(ctx, auth)
}
if err != nil {
log.Errorf("failed to %s auth %s: %v", op, auth.ID, err)
current, ok := s.coreManager.GetByID(auth.ID)
if !ok || current.Disabled {
GlobalModelRegistry().UnregisterClient(auth.ID)
return
}
return
}
if _, err := s.coreManager.Register(ctx, auth); err != nil {
log.Errorf("failed to register auth %s: %v", auth.ID, err)
auth = current
}
// Register models after auth is updated in coreManager.
// This operation may block on network calls, but the auth configuration
// is already effective at this point.
s.registerModelsForAuth(auth)
}
func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
@@ -501,6 +519,8 @@ func (s *Service) Run(ctx context.Context) error {
time.Sleep(100 * time.Millisecond)
fmt.Printf("API server started successfully on: %s:%d\n", s.cfg.Host, s.cfg.Port)
s.applyPprofConfig(s.cfg)
if s.hooks.OnAfterStart != nil {
s.hooks.OnAfterStart(s)
}
@@ -546,6 +566,7 @@ func (s *Service) Run(ctx context.Context) error {
}
s.applyRetryConfig(newCfg)
s.applyPprofConfig(newCfg)
if s.server != nil {
s.server.UpdateClients(newCfg)
}
@@ -639,6 +660,13 @@ func (s *Service) Shutdown(ctx context.Context) error {
s.authQueueStop = nil
}
if errShutdownPprof := s.shutdownPprof(ctx); errShutdownPprof != nil {
log.Errorf("failed to stop pprof server: %v", errShutdownPprof)
if shutdownErr == nil {
shutdownErr = errShutdownPprof
}
}
// no legacy clients to persist
if s.server != nil {