mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-10 00:10:51 +08:00
Compare commits
10 Commits
v6.8.1
...
feature/am
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c548c5d49f | ||
|
|
a424396a87 | ||
|
|
24b4bee500 | ||
|
|
9299897e04 | ||
|
|
527a269799 | ||
|
|
2fe0b6cd2d | ||
|
|
eeb1812d60 | ||
|
|
adedb16d35 | ||
|
|
89907231c1 | ||
|
|
09044e8ccc |
@@ -30,10 +30,6 @@ Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB
|
||||
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
|
||||
<td>Thanks to Cubence for sponsoring this project! Cubence is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. Cubence provides special discounts for our software users: register using <a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">this link</a> and enter the "CLIPROXYAPI" promo code during recharge to get 10% off.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
|
||||
<td>Thanks to AICodeMirror for sponsoring this project! AICodeMirror provides official high-stability relay services for Claude Code / Codex / Gemini CLI, with enterprise-grade concurrency, fast invoicing, and 24/7 dedicated technical support. Claude Code / Codex / Gemini official channels at 38% / 2% / 9% of original price, with extra discounts on top-ups! AICodeMirror offers special benefits for CLIProxyAPI users: register via <a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">this link</a> to enjoy 20% off your first top-up, and enterprise customers can get up to 25% off!</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
@@ -146,10 +142,6 @@ A lightweight web admin panel for CLIProxyAPI with health checks, resource monit
|
||||
|
||||
A Windows tray application implemented using PowerShell scripts, without relying on any third-party libraries. The main features include: automatic creation of shortcuts, silent running, password management, channel switching (Main / Plus), and automatic downloading and updating.
|
||||
|
||||
### [霖君](https://github.com/wangdabaoqq/LinJun)
|
||||
|
||||
霖君 is a cross-platform desktop application for managing AI programming assistants, supporting macOS, Windows, and Linux systems. Unified management of Claude Code, Gemini CLI, OpenAI Codex, Qwen Code, and other AI coding tools, with local proxy for multi-account quota tracking and one-click configuration.
|
||||
|
||||
> [!NOTE]
|
||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||
|
||||
|
||||
16
README_CN.md
16
README_CN.md
@@ -30,10 +30,6 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元
|
||||
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
|
||||
<td>感谢 Cubence 对本项目的赞助!Cubence 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。Cubence 为本软件用户提供了特别优惠:使用<a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">此链接</a>注册,并在充值时输入 "CLIPROXYAPI" 优惠码即可享受九折优惠。</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://www.aicodemirror.com/register?invitecode=TJNAIF"><img src="./assets/aicodemirror.png" alt="AICodeMirror" width="150"></a></td>
|
||||
<td>感谢 AICodeMirror 赞助了本项目!AICodeMirror 提供 Claude Code / Codex / Gemini CLI 官方高稳定中转服务,支持企业级高并发、极速开票、7×24 专属技术支持。 Claude Code / Codex / Gemini 官方渠道低至 3.8 / 0.2 / 0.9 折,充值更有折上折!AICodeMirror 为 CLIProxyAPI 的用户提供了特别福利,通过<a href="https://www.aicodemirror.com/register?invitecode=TJNAIF">此链接</a>注册的用户,可享受首充8折,企业客户最高可享 7.5 折!</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
@@ -141,14 +137,6 @@ Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI
|
||||
|
||||
面向 CLIProxyAPI 的 Web 管理面板,提供健康检查、资源监控、日志查看、自动更新、请求统计与定价展示,支持一键安装与 systemd 服务。
|
||||
|
||||
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||
|
||||
Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方库。主要功能包括:自动创建快捷方式、静默运行、密码管理、通道切换(Main / Plus)以及自动下载与更新。
|
||||
|
||||
### [霖君](https://github.com/wangdabaoqq/LinJun)
|
||||
|
||||
霖君是一款用于管理AI编程助手的跨平台桌面应用,支持macOS、Windows、Linux系统。统一管理Claude Code、Gemini CLI、OpenAI Codex、Qwen Code等AI编程工具,本地代理实现多账户配额跟踪和一键配置。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||
|
||||
@@ -160,6 +148,10 @@ Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方
|
||||
|
||||
基于 Next.js 的实现,灵感来自 CLIProxyAPI,易于安装使用;自研格式转换(OpenAI/Claude/Gemini/Ollama)、组合系统与自动回退、多账户管理(指数退避)、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
|
||||
|
||||
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||
|
||||
Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方库。主要功能包括:自动创建快捷方式、静默运行、密码管理、通道切换(Main / Plus)以及自动下载与更新。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
|
||||
|
||||
|
||||
Binary file not shown.
|
Before Width: | Height: | Size: 45 KiB |
@@ -63,7 +63,6 @@ func main() {
|
||||
var noBrowser bool
|
||||
var oauthCallbackPort int
|
||||
var antigravityLogin bool
|
||||
var kimiLogin bool
|
||||
var projectID string
|
||||
var vertexImport string
|
||||
var configPath string
|
||||
@@ -79,7 +78,6 @@ func main() {
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.IntVar(&oauthCallbackPort, "oauth-callback-port", 0, "Override OAuth callback port (defaults to provider-specific port)")
|
||||
flag.BoolVar(&antigravityLogin, "antigravity-login", false, "Login to Antigravity using OAuth")
|
||||
flag.BoolVar(&kimiLogin, "kimi-login", false, "Login to Kimi using OAuth")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
@@ -470,8 +468,6 @@ func main() {
|
||||
cmd.DoIFlowLogin(cfg, options)
|
||||
} else if iflowCookie {
|
||||
cmd.DoIFlowCookieAuth(cfg, options)
|
||||
} else if kimiLogin {
|
||||
cmd.DoKimiLogin(cfg, options)
|
||||
} else {
|
||||
// In cloud deploy mode without config file, just wait for shutdown signals
|
||||
if isCloudDeploy && !configFileExists {
|
||||
|
||||
@@ -40,11 +40,6 @@ api-keys:
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
|
||||
# Enable pprof HTTP debug server (host:port). Keep it bound to localhost for safety.
|
||||
pprof:
|
||||
enable: false
|
||||
addr: "127.0.0.1:8316"
|
||||
|
||||
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
||||
commercial-mode: false
|
||||
|
||||
@@ -221,7 +216,7 @@ nonstream-keepalive-interval: 0
|
||||
|
||||
# Global OAuth model name aliases (per channel)
|
||||
# These aliases rename model IDs for both model listing and request routing.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow, kimi.
|
||||
# Supported channels: gemini-cli, vertex, aistudio, antigravity, claude, codex, qwen, iflow.
|
||||
# NOTE: Aliases do not apply to gemini-api-key, codex-api-key, claude-api-key, openai-compatibility, vertex-api-key, or ampcode.
|
||||
# You can repeat the same name with different aliases to expose multiple client model names.
|
||||
oauth-model-alias:
|
||||
@@ -262,9 +257,6 @@ oauth-model-alias:
|
||||
# iflow:
|
||||
# - name: "glm-4.7"
|
||||
# alias: "glm-god"
|
||||
# kimi:
|
||||
# - name: "kimi-k2.5"
|
||||
# alias: "k2.5"
|
||||
|
||||
# OAuth provider excluded models
|
||||
# oauth-excluded-models:
|
||||
@@ -287,8 +279,6 @@ oauth-model-alias:
|
||||
# - "vision-model"
|
||||
# iflow:
|
||||
# - "tstars2.0"
|
||||
# kimi:
|
||||
# - "kimi-k2-thinking"
|
||||
|
||||
# Optional payload configuration
|
||||
# payload:
|
||||
|
||||
@@ -25,7 +25,6 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
geminiAuth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/gemini"
|
||||
iflowauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/iflow"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/qwen"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
@@ -1609,82 +1608,6 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
fmt.Println("Initializing Kimi authentication...")
|
||||
|
||||
state := fmt.Sprintf("kmi-%d", time.Now().UnixNano())
|
||||
// Initialize Kimi auth service
|
||||
kimiAuth := kimi.NewKimiAuth(h.cfg)
|
||||
|
||||
// Generate authorization URL
|
||||
deviceFlow, errStartDeviceFlow := kimiAuth.StartDeviceFlow(ctx)
|
||||
if errStartDeviceFlow != nil {
|
||||
log.Errorf("Failed to generate authorization URL: %v", errStartDeviceFlow)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
|
||||
return
|
||||
}
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
if authURL == "" {
|
||||
authURL = deviceFlow.VerificationURI
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "kimi")
|
||||
|
||||
go func() {
|
||||
fmt.Println("Waiting for authentication...")
|
||||
authBundle, errWaitForAuthorization := kimiAuth.WaitForAuthorization(ctx, deviceFlow)
|
||||
if errWaitForAuthorization != nil {
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errWaitForAuthorization)
|
||||
return
|
||||
}
|
||||
|
||||
// Create token storage
|
||||
tokenStorage := kimiAuth.CreateTokenStorage(authBundle)
|
||||
|
||||
metadata := map[string]any{
|
||||
"type": "kimi",
|
||||
"access_token": authBundle.TokenData.AccessToken,
|
||||
"refresh_token": authBundle.TokenData.RefreshToken,
|
||||
"token_type": authBundle.TokenData.TokenType,
|
||||
"scope": authBundle.TokenData.Scope,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
if authBundle.TokenData.ExpiresAt > 0 {
|
||||
expired := time.Unix(authBundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
metadata["expired"] = expired
|
||||
}
|
||||
if strings.TrimSpace(authBundle.DeviceID) != "" {
|
||||
metadata["device_id"] = strings.TrimSpace(authBundle.DeviceID)
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("kimi-%d.json", time.Now().UnixMilli())
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "kimi",
|
||||
FileName: fileName,
|
||||
Label: "Kimi User",
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use Kimi services through this CLI")
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("kimi")
|
||||
}()
|
||||
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
|
||||
|
||||
@@ -125,6 +125,8 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
m.registerOnce.Do(func() {
|
||||
// Initialize model mapper from config (for routing unavailable models to alternatives)
|
||||
m.modelMapper = NewModelMapper(settings.ModelMappings)
|
||||
// Load oauth-model-alias for provider lookup via aliases
|
||||
m.modelMapper.UpdateOAuthModelAlias(ctx.Config.OAuthModelAlias)
|
||||
|
||||
// Store initial config for partial reload comparison
|
||||
settingsCopy := settings
|
||||
@@ -212,6 +214,11 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
}
|
||||
}
|
||||
|
||||
// Always update oauth-model-alias for model mapper (used for provider lookup)
|
||||
if m.modelMapper != nil {
|
||||
m.modelMapper.UpdateOAuthModelAlias(cfg.OAuthModelAlias)
|
||||
}
|
||||
|
||||
if m.enabled {
|
||||
// Check upstream URL change - now supports hot-reload
|
||||
if newUpstreamURL == "" && oldUpstreamURL != "" {
|
||||
|
||||
@@ -2,12 +2,15 @@ package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -30,7 +33,13 @@ const (
|
||||
)
|
||||
|
||||
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
||||
const MappedModelContextKey = "mapped_model"
|
||||
// Deprecated: Use ctxkeys.MappedModel instead.
|
||||
const MappedModelContextKey = string(ctxkeys.MappedModel)
|
||||
|
||||
// FallbackModelsContextKey is the Gin context key for passing fallback model names.
|
||||
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
|
||||
// Deprecated: Use ctxkeys.FallbackModels instead.
|
||||
const FallbackModelsContextKey = string(ctxkeys.FallbackModels)
|
||||
|
||||
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
||||
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
||||
@@ -77,6 +86,10 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
||||
|
||||
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
||||
// when the model's provider is not available in CLIProxyAPI
|
||||
//
|
||||
// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper.
|
||||
// Use routing.NewModelRoutingWrapper() instead for unified routing logic.
|
||||
// This type is kept for backward compatibility and test purposes.
|
||||
type FallbackHandler struct {
|
||||
getProxy func() *httputil.ReverseProxy
|
||||
modelMapper ModelMapper
|
||||
@@ -85,6 +98,8 @@ type FallbackHandler struct {
|
||||
|
||||
// NewFallbackHandler creates a new fallback handler wrapper
|
||||
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
||||
//
|
||||
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
|
||||
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
||||
return &FallbackHandler{
|
||||
getProxy: getProxy,
|
||||
@@ -93,6 +108,8 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler
|
||||
}
|
||||
|
||||
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
||||
//
|
||||
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
|
||||
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
||||
if forceModelMappings == nil {
|
||||
forceModelMappings = func() bool { return false }
|
||||
@@ -113,6 +130,20 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
|
||||
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
||||
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces.
|
||||
// ReverseProxy raises this panic when the client connection is closed prematurely
|
||||
// (e.g., user cancels request, network disconnect) or when ServeHTTP is called
|
||||
// with a ResponseWriter that doesn't implement http.CloseNotifier.
|
||||
// This is an expected error condition, not a bug, so we handle it gracefully.
|
||||
defer func() {
|
||||
if rec := recover(); rec != nil {
|
||||
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||
return
|
||||
}
|
||||
panic(rec)
|
||||
}
|
||||
}()
|
||||
|
||||
requestPath := c.Request.URL.Path
|
||||
|
||||
// Read the request body to extract the model name
|
||||
@@ -142,36 +173,57 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
resolveMappedModel := func() (string, []string) {
|
||||
// resolveMappedModels returns all mapped models (primary + fallbacks) and providers for the first one.
|
||||
resolveMappedModels := func() ([]string, []string) {
|
||||
if fh.modelMapper == nil {
|
||||
return "", nil
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||
if mappedModel == "" {
|
||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||
}
|
||||
mappedModel = strings.TrimSpace(mappedModel)
|
||||
if mappedModel == "" {
|
||||
return "", nil
|
||||
mapper, ok := fh.modelMapper.(*DefaultModelMapper)
|
||||
if !ok {
|
||||
// Fallback to single model for non-DefaultModelMapper
|
||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||
if mappedModel == "" {
|
||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||
}
|
||||
if mappedModel == "" {
|
||||
return nil, nil
|
||||
}
|
||||
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||
if len(mappedProviders) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
return []string{mappedModel}, mappedProviders
|
||||
}
|
||||
|
||||
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
||||
// already specifies its own thinking suffix.
|
||||
if thinkingSuffix != "" {
|
||||
mappedSuffixResult := thinking.ParseSuffix(mappedModel)
|
||||
if !mappedSuffixResult.HasSuffix {
|
||||
mappedModel += thinkingSuffix
|
||||
// Use MapModelWithFallbacks for DefaultModelMapper
|
||||
mappedModels := mapper.MapModelWithFallbacks(modelName)
|
||||
if len(mappedModels) == 0 {
|
||||
mappedModels = mapper.MapModelWithFallbacks(normalizedModel)
|
||||
}
|
||||
if len(mappedModels) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// Apply thinking suffix if needed
|
||||
for i, model := range mappedModels {
|
||||
if thinkingSuffix != "" {
|
||||
suffixResult := thinking.ParseSuffix(model)
|
||||
if !suffixResult.HasSuffix {
|
||||
mappedModels[i] = model + thinkingSuffix
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
mappedBaseModel := thinking.ParseSuffix(mappedModel).ModelName
|
||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||
if len(mappedProviders) == 0 {
|
||||
return "", nil
|
||||
// Get providers for the first model
|
||||
firstBaseModel := thinking.ParseSuffix(mappedModels[0]).ModelName
|
||||
providers := util.GetProviderName(firstBaseModel)
|
||||
if len(providers) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
return mappedModel, mappedProviders
|
||||
return mappedModels, providers
|
||||
}
|
||||
|
||||
// Track resolved model for logging (may change if mapping is applied)
|
||||
@@ -179,21 +231,27 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
usedMapping := false
|
||||
var providers []string
|
||||
|
||||
// Helper to apply model mapping and update state
|
||||
applyMapping := func(mappedModels []string, mappedProviders []string) {
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
|
||||
if len(mappedModels) > 1 {
|
||||
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
|
||||
}
|
||||
resolvedModel = mappedModels[0]
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
|
||||
// Check if model mappings should be forced ahead of local API keys
|
||||
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
||||
|
||||
if forceMappings {
|
||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||
// This allows users to route Amp requests to their preferred OAuth providers
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||
applyMapping(mappedModels, mappedProviders)
|
||||
}
|
||||
|
||||
// If no mapping applied, check for local providers
|
||||
@@ -206,15 +264,8 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
|
||||
if len(providers) == 0 {
|
||||
// No providers configured - check if we have a model mapping
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
// Store mapped model in context for handlers that check it (like gemini bridge)
|
||||
c.Set(MappedModelContextKey, mappedModel)
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||
applyMapping(mappedModels, mappedProviders)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -0,0 +1,326 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/testutil"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Characterization tests for fallback_handlers.go using testutil recorders
|
||||
// These tests capture existing behavior before refactoring to routing layer
|
||||
|
||||
func TestCharacterization_LocalProvider(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the test model
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("char-test-local", "anthropic", []*registry.ModelInfo{
|
||||
{ID: "test-model-local"},
|
||||
})
|
||||
defer reg.UnregisterClient("char-test-local")
|
||||
|
||||
// Setup recorders
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create gin context
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
body := `{"model": "test-model-local", "messages": [{"role": "user", "content": "hello"}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Create fallback handler with proxy recorder
|
||||
// Create a test server to act as the proxy target
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
// Create a reverse proxy that forwards to our test server
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
})
|
||||
|
||||
// Execute
|
||||
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||
wrapped(c)
|
||||
|
||||
// Assert: proxy NOT called
|
||||
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
|
||||
|
||||
// Assert: local handler called once
|
||||
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
|
||||
|
||||
// Assert: request body model unchanged
|
||||
assert.Contains(t, string(handlerRecorder.RequestBody), "test-model-local", "request body model should be unchanged")
|
||||
}
|
||||
|
||||
func TestCharacterization_ModelMapping(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the TARGET model (the mapped-to model)
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("char-test-mapped", "openai", []*registry.ModelInfo{
|
||||
{ID: "gpt-4-local"},
|
||||
})
|
||||
defer reg.UnregisterClient("char-test-mapped")
|
||||
|
||||
// Setup recorders
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create model mapper with a mapping
|
||||
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||
{From: "gpt-4-turbo", To: "gpt-4-local"},
|
||||
})
|
||||
|
||||
// Create gin context
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Request with original model that gets mapped
|
||||
body := `{"model": "gpt-4-turbo", "messages": [{"role": "user", "content": "hello"}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Create fallback handler with mapper
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
fh := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
}, mapper, func() bool { return false })
|
||||
|
||||
// Execute - use handler that returns model in response for rewriter to work
|
||||
wrapped := fh.WrapHandler(handlerRecorder.GinHandlerWithModel())
|
||||
wrapped(c)
|
||||
|
||||
// Assert: proxy NOT called
|
||||
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for model mapping")
|
||||
|
||||
// Assert: local handler called once
|
||||
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
|
||||
|
||||
// Assert: request body model was rewritten to mapped model
|
||||
assert.Contains(t, string(handlerRecorder.RequestBody), "gpt-4-local", "request body model should be rewritten to mapped model")
|
||||
assert.NotContains(t, string(handlerRecorder.RequestBody), "gpt-4-turbo", "request body should NOT contain original model")
|
||||
|
||||
// Assert: context has mapped_model key set
|
||||
mappedModel, exists := handlerRecorder.GetContextKey("mapped_model")
|
||||
assert.True(t, exists, "context should have mapped_model key")
|
||||
assert.Equal(t, "gpt-4-local", mappedModel, "mapped_model should be the target model")
|
||||
|
||||
// Assert: response body model rewritten back to original
|
||||
// The response writer should rewrite model names in the response
|
||||
responseBody := w.Body.String()
|
||||
assert.Contains(t, responseBody, "gpt-4-turbo", "response should have original model name")
|
||||
}
|
||||
|
||||
func TestCharacterization_AmpCreditsProxy(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Setup recorders - NO local provider registered, NO mapping configured
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create gin context with CloseNotifier support (required for ReverseProxy)
|
||||
w := testutil.NewCloseNotifierRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Request with a model that has no local provider and no mapping
|
||||
body := `{"model": "unknown-model-no-provider", "messages": [{"role": "user", "content": "hello"}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Create fallback handler
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
})
|
||||
|
||||
// Execute
|
||||
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||
wrapped(c)
|
||||
|
||||
// Assert: proxy called once
|
||||
assert.True(t, proxyRecorder.Called, "proxy should be called when no local provider and no mapping")
|
||||
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
|
||||
|
||||
// Assert: local handler NOT called
|
||||
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called when falling back to proxy")
|
||||
|
||||
// Assert: body forwarded to proxy is original (no rewrite)
|
||||
assert.Contains(t, string(proxyRecorder.RequestBody), "unknown-model-no-provider", "request body model should be unchanged when proxying")
|
||||
}
|
||||
|
||||
func TestCharacterization_BodyRestore(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the test model
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("char-test-body", "anthropic", []*registry.ModelInfo{
|
||||
{ID: "test-model-body"},
|
||||
})
|
||||
defer reg.UnregisterClient("char-test-body")
|
||||
|
||||
// Setup recorders
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create gin context
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Create a complex request body that will be read by the wrapper for model extraction
|
||||
originalBody := `{"model": "test-model-body", "messages": [{"role": "user", "content": "hello"}], "temperature": 0.7, "stream": true}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(originalBody)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Create fallback handler with proxy recorder
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
})
|
||||
|
||||
// Execute
|
||||
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||
wrapped(c)
|
||||
|
||||
// Assert: local handler called (not proxy, since we have a local provider)
|
||||
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
|
||||
|
||||
// Assert: handler receives complete original body
|
||||
// This verifies that the body was properly restored after the wrapper read it for model extraction
|
||||
assert.Equal(t, originalBody, string(handlerRecorder.RequestBody), "handler should receive complete original body after wrapper reads it for model extraction")
|
||||
}
|
||||
|
||||
// TestCharacterization_GeminiV1Beta1_PostModels tests that POST requests with /models/ path use Gemini bridge handler
|
||||
// This is a characterization test for the route gating logic in routes.go
|
||||
func TestCharacterization_GeminiV1Beta1_PostModels(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the test model (Gemini format uses path-based model extraction)
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("char-test-gemini", "google", []*registry.ModelInfo{
|
||||
{ID: "gemini-pro"},
|
||||
})
|
||||
defer reg.UnregisterClient("char-test-gemini")
|
||||
|
||||
// Setup recorders
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create a test server for the proxy
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
// Create fallback handler
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
})
|
||||
|
||||
// Create the Gemini bridge handler (simulating what routes.go does)
|
||||
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
|
||||
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
|
||||
|
||||
// Create router with the same gating logic as routes.go
|
||||
r := gin.New()
|
||||
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||
if c.Request.Method == "POST" {
|
||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||
// POST with /models/ path -> use Gemini bridge with fallback handler
|
||||
geminiV1Beta1Handler(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
// Non-POST or no /models/ in path -> proxy upstream
|
||||
proxyRecorder.ServeHTTP(c.Writer, c.Request)
|
||||
})
|
||||
|
||||
// Execute: POST request with /models/ in path
|
||||
body := `{"contents": [{"role": "user", "parts": [{"text": "hello"}]}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro:generateContent", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Assert: local Gemini handler called
|
||||
assert.True(t, handlerRecorder.WasCalled(), "local Gemini handler should be called for POST /models/")
|
||||
|
||||
// Assert: proxy NOT called
|
||||
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for POST /models/ path")
|
||||
}
|
||||
|
||||
// TestCharacterization_GeminiV1Beta1_GetProxies tests that GET requests to Gemini v1beta1 always use proxy
|
||||
// This is a characterization test for the route gating logic in routes.go
|
||||
func TestCharacterization_GeminiV1Beta1_GetProxies(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Setup recorders
|
||||
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||
|
||||
// Create a test server for the proxy
|
||||
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||
defer proxyServer.Close()
|
||||
|
||||
// Create fallback handler
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
targetURL, _ := url.Parse(proxyServer.URL)
|
||||
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||
})
|
||||
|
||||
// Create the Gemini bridge handler
|
||||
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
|
||||
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
|
||||
|
||||
// Create router with the same gating logic as routes.go
|
||||
r := gin.New()
|
||||
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||
if c.Request.Method == "POST" {
|
||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||
geminiV1Beta1Handler(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
proxyRecorder.ServeHTTP(c.Writer, c.Request)
|
||||
})
|
||||
|
||||
// Execute: GET request (even with /models/ in path)
|
||||
req := httptest.NewRequest(http.MethodGet, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro", nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
// Assert: proxy called
|
||||
assert.True(t, proxyRecorder.Called, "proxy should be called for GET requests")
|
||||
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
|
||||
|
||||
// Assert: local handler NOT called
|
||||
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called for GET requests")
|
||||
}
|
||||
@@ -2,7 +2,7 @@ package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
@@ -11,63 +11,138 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
|
||||
// Characterization tests for fallback_handlers.go
|
||||
// These tests capture existing behavior before refactoring to routing layer
|
||||
|
||||
func TestFallbackHandler_WrapHandler_LocalProvider_NoMapping(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
|
||||
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
// Setup: model that has local providers (gemini-2.5-pro is registered)
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
body := `{"model": "gemini-2.5-pro", "messages": [{"role": "user", "content": "hello"}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Handler that should be called (not proxy)
|
||||
handlerCalled := false
|
||||
handler := func(c *gin.Context) {
|
||||
handlerCalled = true
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
// Create fallback handler
|
||||
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
return nil // no proxy
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-amp-fallback")
|
||||
|
||||
// Execute
|
||||
wrapped := fh.WrapHandler(handler)
|
||||
wrapped(c)
|
||||
|
||||
// Assert: handler should be called directly (no mapping needed)
|
||||
assert.True(t, handlerCalled, "handler should be called for local provider")
|
||||
assert.Equal(t, 200, w.Code)
|
||||
}
|
||||
|
||||
func TestFallbackHandler_WrapHandler_MappingApplied(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the target model
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client", "anthropic", []*registry.ModelInfo{
|
||||
{ID: "claude-opus-4-5-thinking"},
|
||||
})
|
||||
|
||||
// Setup: model that needs mapping
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
body := `{"model": "claude-opus-4-5-20251101", "messages": [{"role": "user", "content": "hello"}]}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
// Handler to capture rewritten body
|
||||
var capturedBody []byte
|
||||
handler := func(c *gin.Context) {
|
||||
capturedBody, _ = io.ReadAll(c.Request.Body)
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
// Create fallback handler with mapper
|
||||
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||
{From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"},
|
||||
})
|
||||
|
||||
fh := NewFallbackHandlerWithMapper(
|
||||
func() *httputil.ReverseProxy { return nil },
|
||||
mapper,
|
||||
func() bool { return false },
|
||||
)
|
||||
|
||||
// Execute
|
||||
wrapped := fh.WrapHandler(handler)
|
||||
wrapped(c)
|
||||
|
||||
// Assert: body should be rewritten
|
||||
assert.Contains(t, string(capturedBody), "claude-opus-4-5-thinking")
|
||||
|
||||
// Assert: context should have mapped model
|
||||
mappedModel, exists := c.Get(MappedModelContextKey)
|
||||
assert.True(t, exists, "MappedModelContextKey should be set")
|
||||
assert.NotEmpty(t, mappedModel)
|
||||
}
|
||||
|
||||
func TestFallbackHandler_WrapHandler_ThinkingSuffixPreserved(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
// Register a mock provider for the target model
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-2", "anthropic", []*registry.ModelInfo{
|
||||
{ID: "claude-opus-4-5-thinking"},
|
||||
})
|
||||
|
||||
w := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(w)
|
||||
|
||||
// Model with thinking suffix
|
||||
body := `{"model": "claude-opus-4-5-20251101(xhigh)", "messages": []}`
|
||||
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
c.Request = req
|
||||
|
||||
var capturedBody []byte
|
||||
handler := func(c *gin.Context) {
|
||||
capturedBody, _ = io.ReadAll(c.Request.Body)
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||
{From: "gpt-5.2", To: "test/gpt-5.2"},
|
||||
{From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"},
|
||||
})
|
||||
|
||||
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
|
||||
fh := NewFallbackHandlerWithMapper(
|
||||
func() *httputil.ReverseProxy { return nil },
|
||||
mapper,
|
||||
func() bool { return false },
|
||||
)
|
||||
|
||||
handler := func(c *gin.Context) {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
wrapped := fh.WrapHandler(handler)
|
||||
wrapped(c)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": req.Model,
|
||||
"seen_model": req.Model,
|
||||
})
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/chat/completions", fallback.WrapHandler(handler))
|
||||
|
||||
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Model string `json:"model"`
|
||||
SeenModel string `json:"seen_model"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Failed to parse response JSON: %v", err)
|
||||
}
|
||||
|
||||
if resp.Model != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
|
||||
}
|
||||
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
|
||||
}
|
||||
// Assert: thinking suffix should be preserved
|
||||
assert.Contains(t, string(capturedBody), "(xhigh)")
|
||||
}
|
||||
|
||||
func TestFallbackHandler_WrapHandler_NoProvider_NoMapping_ProxyEnabled(t *testing.T) {
|
||||
// Skip: httptest.ResponseRecorder doesn't implement http.CloseNotifier
|
||||
// which is required by httputil.ReverseProxy. This test requires a real
|
||||
// HTTP server and client to properly test proxy behavior.
|
||||
t.Skip("requires real HTTP server for proxy testing")
|
||||
}
|
||||
|
||||
@@ -30,18 +30,98 @@ type DefaultModelMapper struct {
|
||||
mu sync.RWMutex
|
||||
mappings map[string]string // exact: from -> to (normalized lowercase keys)
|
||||
regexps []regexMapping // regex rules evaluated in order
|
||||
|
||||
// oauthAliasForward maps channel -> name (lower) -> []alias for oauth-model-alias lookup.
|
||||
// This allows model-mappings targets to find providers via their aliases.
|
||||
oauthAliasForward map[string]map[string][]string
|
||||
}
|
||||
|
||||
// NewModelMapper creates a new model mapper with the given initial mappings.
|
||||
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
m := &DefaultModelMapper{
|
||||
mappings: make(map[string]string),
|
||||
regexps: nil,
|
||||
mappings: make(map[string]string),
|
||||
regexps: nil,
|
||||
oauthAliasForward: nil,
|
||||
}
|
||||
m.UpdateMappings(mappings)
|
||||
return m
|
||||
}
|
||||
|
||||
// UpdateOAuthModelAlias updates the oauth-model-alias lookup table.
|
||||
// This is called during initialization and on config hot-reload.
|
||||
func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.OAuthModelAlias) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
if len(aliases) == 0 {
|
||||
m.oauthAliasForward = nil
|
||||
return
|
||||
}
|
||||
|
||||
forward := make(map[string]map[string][]string, len(aliases))
|
||||
for rawChannel, entries := range aliases {
|
||||
channel := strings.ToLower(strings.TrimSpace(rawChannel))
|
||||
if channel == "" || len(entries) == 0 {
|
||||
continue
|
||||
}
|
||||
channelMap := make(map[string][]string)
|
||||
for _, entry := range entries {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if name == "" || alias == "" {
|
||||
continue
|
||||
}
|
||||
if strings.EqualFold(name, alias) {
|
||||
continue
|
||||
}
|
||||
nameKey := strings.ToLower(name)
|
||||
channelMap[nameKey] = append(channelMap[nameKey], alias)
|
||||
}
|
||||
if len(channelMap) > 0 {
|
||||
forward[channel] = channelMap
|
||||
}
|
||||
}
|
||||
if len(forward) == 0 {
|
||||
m.oauthAliasForward = nil
|
||||
return
|
||||
}
|
||||
m.oauthAliasForward = forward
|
||||
log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward))
|
||||
}
|
||||
|
||||
// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel
|
||||
// that have available providers. Useful for fallback when one alias is quota-exceeded.
|
||||
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
|
||||
if m.oauthAliasForward == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
targetKey := strings.ToLower(strings.TrimSpace(targetModel))
|
||||
if targetKey == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
var result []string
|
||||
seen := make(map[string]struct{})
|
||||
|
||||
// Check all channels for this model name
|
||||
for _, channelMap := range m.oauthAliasForward {
|
||||
aliases := channelMap[targetKey]
|
||||
for _, alias := range aliases {
|
||||
aliasLower := strings.ToLower(alias)
|
||||
if _, exists := seen[aliasLower]; exists {
|
||||
continue
|
||||
}
|
||||
providers := util.GetProviderName(alias)
|
||||
if len(providers) > 0 {
|
||||
result = append(result, alias)
|
||||
seen[aliasLower] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// MapModel checks if a mapping exists for the requested model and if the
|
||||
// target model has available local providers. Returns the mapped model name
|
||||
// or empty string if no valid mapping exists.
|
||||
@@ -51,9 +131,20 @@ func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
// However, if the mapping target already contains a suffix, the config suffix
|
||||
// takes priority over the user's suffix.
|
||||
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
if requestedModel == "" {
|
||||
models := m.MapModelWithFallbacks(requestedModel)
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
return models[0]
|
||||
}
|
||||
|
||||
// MapModelWithFallbacks returns all possible target models for the requested model,
|
||||
// including fallback aliases from oauth-model-alias. The first model is the primary target,
|
||||
// and subsequent models are fallbacks to try if the primary is unavailable (e.g., quota exceeded).
|
||||
func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []string {
|
||||
if requestedModel == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
@@ -78,34 +169,54 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// Check if target model already has a thinking suffix (config priority)
|
||||
targetResult := thinking.ParseSuffix(targetModel)
|
||||
targetBase := targetResult.ModelName
|
||||
|
||||
// Helper to apply suffix to a model
|
||||
applySuffix := func(model string) string {
|
||||
modelResult := thinking.ParseSuffix(model)
|
||||
if modelResult.HasSuffix {
|
||||
return model
|
||||
}
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return model + "(" + requestResult.RawSuffix + ")"
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// Verify target model has available providers (use base model for lookup)
|
||||
providers := util.GetProviderName(targetResult.ModelName)
|
||||
if len(providers) == 0 {
|
||||
providers := util.GetProviderName(targetBase)
|
||||
|
||||
// If direct provider available, return it as primary
|
||||
if len(providers) > 0 {
|
||||
return []string{applySuffix(targetModel)}
|
||||
}
|
||||
|
||||
// No direct providers - check oauth-model-alias for all aliases that have providers
|
||||
allAliases := m.findAllAliasesWithProviders(targetBase)
|
||||
if len(allAliases) == 0 {
|
||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||
return ""
|
||||
return nil
|
||||
}
|
||||
|
||||
// Suffix handling: config suffix takes priority, otherwise preserve user suffix
|
||||
if targetResult.HasSuffix {
|
||||
// Config's "to" already contains a suffix - use it as-is (config priority)
|
||||
return targetModel
|
||||
// Log resolution
|
||||
if len(allAliases) == 1 {
|
||||
log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
|
||||
} else {
|
||||
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases)-1)
|
||||
}
|
||||
|
||||
// Preserve user's thinking suffix on the mapped model
|
||||
// (skip empty suffixes to avoid returning "model()")
|
||||
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
||||
return targetModel + "(" + requestResult.RawSuffix + ")"
|
||||
// Apply suffix to all aliases
|
||||
result := make([]string, len(allAliases))
|
||||
for i, alias := range allAliases {
|
||||
result[i] = applySuffix(alias)
|
||||
}
|
||||
|
||||
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
||||
return targetModel
|
||||
return result
|
||||
}
|
||||
|
||||
// UpdateMappings refreshes the mapping configuration from config.
|
||||
@@ -165,6 +276,22 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
|
||||
return result
|
||||
}
|
||||
|
||||
// GetMappingsAsConfig returns the current model mappings as config.AmpModelMapping slice.
|
||||
// Safe for concurrent use.
|
||||
func (m *DefaultModelMapper) GetMappingsAsConfig() []config.AmpModelMapping {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make([]config.AmpModelMapping, 0, len(m.mappings))
|
||||
for from, to := range m.mappings {
|
||||
result = append(result, config.AmpModelMapping{
|
||||
From: from,
|
||||
To: to,
|
||||
})
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type regexMapping struct {
|
||||
re *regexp.Regexp
|
||||
to string
|
||||
|
||||
@@ -5,11 +5,12 @@ import (
|
||||
"errors"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
@@ -234,19 +235,20 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// If no local OAuth is available, falls back to ampcode.com proxy.
|
||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
|
||||
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
return m.getProxy()
|
||||
}, m.modelMapper, m.forceModelMappings)
|
||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
||||
|
||||
// Route POST model calls through Gemini bridge with FallbackHandler.
|
||||
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
|
||||
// T-025: Migrated Gemini v1beta1 bridge to use ModelRoutingWrapper
|
||||
// Create a dedicated routing wrapper for the Gemini bridge
|
||||
geminiBridgeWrapper := m.createModelRoutingWrapper()
|
||||
geminiV1Beta1Handler := geminiBridgeWrapper.Wrap(geminiBridge)
|
||||
|
||||
// Route POST model calls through Gemini bridge with ModelRoutingWrapper.
|
||||
// ModelRoutingWrapper checks provider -> mapping -> proxy fallback automatically.
|
||||
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
||||
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||
if c.Request.Method == "POST" {
|
||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||
// POST with /models/ path -> use Gemini bridge with fallback handler
|
||||
// FallbackHandler will check provider/mapping and proxy if needed
|
||||
// POST with /models/ path -> use Gemini bridge with unified routing wrapper
|
||||
// ModelRoutingWrapper will check provider/mapping and proxy if needed
|
||||
geminiV1Beta1Handler(c)
|
||||
return
|
||||
}
|
||||
@@ -256,6 +258,41 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
})
|
||||
}
|
||||
|
||||
// createModelRoutingWrapper creates a new ModelRoutingWrapper for unified routing.
|
||||
// This is used for testing the new routing implementation (T-021 onwards).
|
||||
func (m *AmpModule) createModelRoutingWrapper() *routing.ModelRoutingWrapper {
|
||||
// Create a registry - in production this would be populated with actual providers
|
||||
registry := routing.NewRegistry()
|
||||
|
||||
// Create a minimal config with just AmpCode settings
|
||||
// The Router only needs AmpCode.ModelMappings and OAuthModelAlias
|
||||
cfg := &config.Config{
|
||||
AmpCode: func() config.AmpCode {
|
||||
if m.modelMapper != nil {
|
||||
return config.AmpCode{
|
||||
ModelMappings: m.modelMapper.GetMappingsAsConfig(),
|
||||
}
|
||||
}
|
||||
return config.AmpCode{}
|
||||
}(),
|
||||
}
|
||||
|
||||
// Create router with registry and config
|
||||
router := routing.NewRouter(registry, cfg)
|
||||
|
||||
// Create wrapper with proxy function
|
||||
proxyFunc := func(c *gin.Context) {
|
||||
proxy := m.getProxy()
|
||||
if proxy != nil {
|
||||
proxy.ServeHTTP(c.Writer, c.Request)
|
||||
} else {
|
||||
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
|
||||
}
|
||||
}
|
||||
|
||||
return routing.NewModelRoutingWrapper(router, nil, nil, proxyFunc)
|
||||
}
|
||||
|
||||
// registerProviderAliases registers /api/provider/{provider}/... routes
|
||||
// These allow Amp CLI to route requests like:
|
||||
//
|
||||
@@ -269,12 +306,9 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
|
||||
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
||||
|
||||
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
||||
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
|
||||
// Also includes model mapping support for routing unavailable models to alternatives
|
||||
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
return m.getProxy()
|
||||
}, m.modelMapper, m.forceModelMappings)
|
||||
// Create unified routing wrapper (T-021 onwards)
|
||||
// Replaces FallbackHandler with Router-based unified routing
|
||||
routingWrapper := m.createModelRoutingWrapper()
|
||||
|
||||
// Provider-specific routes under /api/provider/:provider
|
||||
ampProviders := engine.Group("/api/provider")
|
||||
@@ -302,33 +336,36 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
}
|
||||
|
||||
// Root-level routes (for providers that omit /v1, like groq/cerebras)
|
||||
// Wrap handlers with fallback logic to forward to ampcode.com when provider not found
|
||||
// T-022: Migrated all OpenAI routes to use ModelRoutingWrapper for unified routing
|
||||
provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
|
||||
provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
||||
provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
||||
provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
||||
provider.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
|
||||
provider.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
|
||||
provider.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
|
||||
|
||||
// /v1 routes (OpenAI/Claude-compatible endpoints)
|
||||
v1Amp := provider.Group("/v1")
|
||||
{
|
||||
v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
|
||||
|
||||
// OpenAI-compatible endpoints with fallback
|
||||
v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
||||
v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
||||
v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
||||
// OpenAI-compatible endpoints with ModelRoutingWrapper
|
||||
// T-021, T-022: Migrated to unified routing wrapper
|
||||
v1Amp.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
|
||||
v1Amp.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
|
||||
v1Amp.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
|
||||
|
||||
// Claude/Anthropic-compatible endpoints with fallback
|
||||
v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages))
|
||||
v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens))
|
||||
// Claude/Anthropic-compatible endpoints with ModelRoutingWrapper
|
||||
// T-023: Migrated Claude routes to unified routing wrapper
|
||||
v1Amp.POST("/messages", routingWrapper.Wrap(claudeCodeHandlers.ClaudeMessages))
|
||||
v1Amp.POST("/messages/count_tokens", routingWrapper.Wrap(claudeCodeHandlers.ClaudeCountTokens))
|
||||
}
|
||||
|
||||
// /v1beta routes (Gemini native API)
|
||||
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
|
||||
// T-024: Migrated Gemini v1beta routes to unified routing wrapper
|
||||
v1betaAmp := provider.Group("/v1beta")
|
||||
{
|
||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.POST("/models/*action", routingWrapper.Wrap(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -623,7 +623,6 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/gemini-cli-auth-url", s.mgmt.RequestGeminiCLIToken)
|
||||
mgmt.GET("/antigravity-auth-url", s.mgmt.RequestAntigravityToken)
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/kimi-auth-url", s.mgmt.RequestKimiToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
@@ -961,8 +960,8 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||
}
|
||||
|
||||
// Notify Amp module only when Amp config has changed.
|
||||
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode)
|
||||
// Notify Amp module when Amp config or OAuth model aliases have changed.
|
||||
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode) || !reflect.DeepEqual(oldCfg.OAuthModelAlias, cfg.OAuthModelAlias)
|
||||
if ampConfigChanged {
|
||||
if s.ampModule != nil {
|
||||
log.Debugf("triggering amp module config update")
|
||||
|
||||
@@ -1,396 +0,0 @@
|
||||
// Package kimi provides authentication and token management for Kimi (Moonshot AI) API.
|
||||
// It handles the RFC 8628 OAuth2 Device Authorization Grant flow for secure authentication.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"os"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// kimiClientID is Kimi Code's OAuth client ID.
|
||||
kimiClientID = "17e5f671-d194-4dfb-9706-5516cb48c098"
|
||||
// kimiOAuthHost is the OAuth server endpoint.
|
||||
kimiOAuthHost = "https://auth.kimi.com"
|
||||
// kimiDeviceCodeURL is the endpoint for requesting device codes.
|
||||
kimiDeviceCodeURL = kimiOAuthHost + "/api/oauth/device_authorization"
|
||||
// kimiTokenURL is the endpoint for exchanging device codes for tokens.
|
||||
kimiTokenURL = kimiOAuthHost + "/api/oauth/token"
|
||||
// KimiAPIBaseURL is the base URL for Kimi API requests.
|
||||
KimiAPIBaseURL = "https://api.kimi.com/coding/v1"
|
||||
// defaultPollInterval is the default interval for polling token endpoint.
|
||||
defaultPollInterval = 5 * time.Second
|
||||
// maxPollDuration is the maximum time to wait for user authorization.
|
||||
maxPollDuration = 15 * time.Minute
|
||||
// refreshThresholdSeconds is when to refresh token before expiry (5 minutes).
|
||||
refreshThresholdSeconds = 300
|
||||
)
|
||||
|
||||
// KimiAuth handles Kimi authentication flow.
|
||||
type KimiAuth struct {
|
||||
deviceClient *DeviceFlowClient
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKimiAuth creates a new KimiAuth service instance.
|
||||
func NewKimiAuth(cfg *config.Config) *KimiAuth {
|
||||
return &KimiAuth{
|
||||
deviceClient: NewDeviceFlowClient(cfg),
|
||||
cfg: cfg,
|
||||
}
|
||||
}
|
||||
|
||||
// StartDeviceFlow initiates the device flow authentication.
|
||||
func (k *KimiAuth) StartDeviceFlow(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
return k.deviceClient.RequestDeviceCode(ctx)
|
||||
}
|
||||
|
||||
// WaitForAuthorization polls for user authorization and returns the auth bundle.
|
||||
func (k *KimiAuth) WaitForAuthorization(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiAuthBundle, error) {
|
||||
tokenData, err := k.deviceClient.PollForToken(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &KimiAuthBundle{
|
||||
TokenData: tokenData,
|
||||
DeviceID: k.deviceClient.deviceID,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CreateTokenStorage creates a new KimiTokenStorage from auth bundle.
|
||||
func (k *KimiAuth) CreateTokenStorage(bundle *KimiAuthBundle) *KimiTokenStorage {
|
||||
expired := ""
|
||||
if bundle.TokenData.ExpiresAt > 0 {
|
||||
expired = time.Unix(bundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
}
|
||||
return &KimiTokenStorage{
|
||||
AccessToken: bundle.TokenData.AccessToken,
|
||||
RefreshToken: bundle.TokenData.RefreshToken,
|
||||
TokenType: bundle.TokenData.TokenType,
|
||||
Scope: bundle.TokenData.Scope,
|
||||
DeviceID: strings.TrimSpace(bundle.DeviceID),
|
||||
Expired: expired,
|
||||
Type: "kimi",
|
||||
}
|
||||
}
|
||||
|
||||
// DeviceFlowClient handles the OAuth2 device flow for Kimi.
|
||||
type DeviceFlowClient struct {
|
||||
httpClient *http.Client
|
||||
cfg *config.Config
|
||||
deviceID string
|
||||
}
|
||||
|
||||
// NewDeviceFlowClient creates a new device flow client.
|
||||
func NewDeviceFlowClient(cfg *config.Config) *DeviceFlowClient {
|
||||
return NewDeviceFlowClientWithDeviceID(cfg, "")
|
||||
}
|
||||
|
||||
// NewDeviceFlowClientWithDeviceID creates a new device flow client with the specified device ID.
|
||||
func NewDeviceFlowClientWithDeviceID(cfg *config.Config, deviceID string) *DeviceFlowClient {
|
||||
client := &http.Client{Timeout: 30 * time.Second}
|
||||
if cfg != nil {
|
||||
client = util.SetProxy(&cfg.SDKConfig, client)
|
||||
}
|
||||
resolvedDeviceID := strings.TrimSpace(deviceID)
|
||||
if resolvedDeviceID == "" {
|
||||
resolvedDeviceID = getOrCreateDeviceID()
|
||||
}
|
||||
return &DeviceFlowClient{
|
||||
httpClient: client,
|
||||
cfg: cfg,
|
||||
deviceID: resolvedDeviceID,
|
||||
}
|
||||
}
|
||||
|
||||
// getOrCreateDeviceID returns an in-memory device ID for the current authentication flow.
|
||||
func getOrCreateDeviceID() string {
|
||||
return uuid.New().String()
|
||||
}
|
||||
|
||||
// getDeviceModel returns a device model string.
|
||||
func getDeviceModel() string {
|
||||
osName := runtime.GOOS
|
||||
arch := runtime.GOARCH
|
||||
|
||||
switch osName {
|
||||
case "darwin":
|
||||
return fmt.Sprintf("macOS %s", arch)
|
||||
case "windows":
|
||||
return fmt.Sprintf("Windows %s", arch)
|
||||
case "linux":
|
||||
return fmt.Sprintf("Linux %s", arch)
|
||||
default:
|
||||
return fmt.Sprintf("%s %s", osName, arch)
|
||||
}
|
||||
}
|
||||
|
||||
// getHostname returns the machine hostname.
|
||||
func getHostname() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// commonHeaders returns headers required for Kimi API requests.
|
||||
func (c *DeviceFlowClient) commonHeaders() map[string]string {
|
||||
return map[string]string{
|
||||
"X-Msh-Platform": "cli-proxy-api",
|
||||
"X-Msh-Version": "1.0.0",
|
||||
"X-Msh-Device-Name": getHostname(),
|
||||
"X-Msh-Device-Model": getDeviceModel(),
|
||||
"X-Msh-Device-Id": c.deviceID,
|
||||
}
|
||||
}
|
||||
|
||||
// RequestDeviceCode initiates the device flow by requesting a device code from Kimi.
|
||||
func (c *DeviceFlowClient) RequestDeviceCode(ctx context.Context) (*DeviceCodeResponse, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiDeviceCodeURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create device code request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: device code request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi device code: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read device code response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("kimi: device code request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var deviceCode DeviceCodeResponse
|
||||
if err = json.Unmarshal(bodyBytes, &deviceCode); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse device code response: %w", err)
|
||||
}
|
||||
|
||||
return &deviceCode, nil
|
||||
}
|
||||
|
||||
// PollForToken polls the token endpoint until the user authorizes or the device code expires.
|
||||
func (c *DeviceFlowClient) PollForToken(ctx context.Context, deviceCode *DeviceCodeResponse) (*KimiTokenData, error) {
|
||||
if deviceCode == nil {
|
||||
return nil, fmt.Errorf("kimi: device code is nil")
|
||||
}
|
||||
|
||||
interval := time.Duration(deviceCode.Interval) * time.Second
|
||||
if interval < defaultPollInterval {
|
||||
interval = defaultPollInterval
|
||||
}
|
||||
|
||||
deadline := time.Now().Add(maxPollDuration)
|
||||
if deviceCode.ExpiresIn > 0 {
|
||||
codeDeadline := time.Now().Add(time.Duration(deviceCode.ExpiresIn) * time.Second)
|
||||
if codeDeadline.Before(deadline) {
|
||||
deadline = codeDeadline
|
||||
}
|
||||
}
|
||||
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, fmt.Errorf("kimi: context cancelled: %w", ctx.Err())
|
||||
case <-ticker.C:
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("kimi: device code expired")
|
||||
}
|
||||
|
||||
token, pollErr, shouldContinue := c.exchangeDeviceCode(ctx, deviceCode.DeviceCode)
|
||||
if token != nil {
|
||||
return token, nil
|
||||
}
|
||||
if !shouldContinue {
|
||||
return nil, pollErr
|
||||
}
|
||||
// Continue polling
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// exchangeDeviceCode attempts to exchange the device code for an access token.
|
||||
// Returns (token, error, shouldContinue).
|
||||
func (c *DeviceFlowClient) exchangeDeviceCode(ctx context.Context, deviceCode string) (*KimiTokenData, error, bool) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
data.Set("device_code", deviceCode)
|
||||
data.Set("grant_type", "urn:ietf:params:oauth:grant-type:device_code")
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create token request: %w", err), false
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: token request failed: %w", err), false
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi token exchange: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read token response: %w", err), false
|
||||
}
|
||||
|
||||
// Parse response - Kimi returns 200 for both success and pending states
|
||||
var oauthResp struct {
|
||||
Error string `json:"error"`
|
||||
ErrorDescription string `json:"error_description"`
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn float64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(bodyBytes, &oauthResp); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse token response: %w", err), false
|
||||
}
|
||||
|
||||
if oauthResp.Error != "" {
|
||||
switch oauthResp.Error {
|
||||
case "authorization_pending":
|
||||
return nil, nil, true // Continue polling
|
||||
case "slow_down":
|
||||
return nil, nil, true // Continue polling (with increased interval handled by caller)
|
||||
case "expired_token":
|
||||
return nil, fmt.Errorf("kimi: device code expired"), false
|
||||
case "access_denied":
|
||||
return nil, fmt.Errorf("kimi: access denied by user"), false
|
||||
default:
|
||||
return nil, fmt.Errorf("kimi: OAuth error: %s - %s", oauthResp.Error, oauthResp.ErrorDescription), false
|
||||
}
|
||||
}
|
||||
|
||||
if oauthResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("kimi: empty access token in response"), false
|
||||
}
|
||||
|
||||
var expiresAt int64
|
||||
if oauthResp.ExpiresIn > 0 {
|
||||
expiresAt = time.Now().Unix() + int64(oauthResp.ExpiresIn)
|
||||
}
|
||||
|
||||
return &KimiTokenData{
|
||||
AccessToken: oauthResp.AccessToken,
|
||||
RefreshToken: oauthResp.RefreshToken,
|
||||
TokenType: oauthResp.TokenType,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: oauthResp.Scope,
|
||||
}, nil, false
|
||||
}
|
||||
|
||||
// RefreshToken exchanges a refresh token for a new access token.
|
||||
func (c *DeviceFlowClient) RefreshToken(ctx context.Context, refreshToken string) (*KimiTokenData, error) {
|
||||
data := url.Values{}
|
||||
data.Set("client_id", kimiClientID)
|
||||
data.Set("grant_type", "refresh_token")
|
||||
data.Set("refresh_token", refreshToken)
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, kimiTokenURL, strings.NewReader(data.Encode()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to create refresh request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
for k, v := range c.commonHeaders() {
|
||||
req.Header.Set(k, v)
|
||||
}
|
||||
|
||||
resp, err := c.httpClient.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: refresh request failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi refresh token: close body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
bodyBytes, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to read refresh response: %w", err)
|
||||
}
|
||||
|
||||
if resp.StatusCode == http.StatusUnauthorized || resp.StatusCode == http.StatusForbidden {
|
||||
return nil, fmt.Errorf("kimi: refresh token rejected (status %d)", resp.StatusCode)
|
||||
}
|
||||
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
return nil, fmt.Errorf("kimi: refresh failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
}
|
||||
|
||||
var tokenResp struct {
|
||||
AccessToken string `json:"access_token"`
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
TokenType string `json:"token_type"`
|
||||
ExpiresIn float64 `json:"expires_in"`
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
if err = json.Unmarshal(bodyBytes, &tokenResp); err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to parse refresh response: %w", err)
|
||||
}
|
||||
|
||||
if tokenResp.AccessToken == "" {
|
||||
return nil, fmt.Errorf("kimi: empty access token in refresh response")
|
||||
}
|
||||
|
||||
var expiresAt int64
|
||||
if tokenResp.ExpiresIn > 0 {
|
||||
expiresAt = time.Now().Unix() + int64(tokenResp.ExpiresIn)
|
||||
}
|
||||
|
||||
return &KimiTokenData{
|
||||
AccessToken: tokenResp.AccessToken,
|
||||
RefreshToken: tokenResp.RefreshToken,
|
||||
TokenType: tokenResp.TokenType,
|
||||
ExpiresAt: expiresAt,
|
||||
Scope: tokenResp.Scope,
|
||||
}, nil
|
||||
}
|
||||
@@ -1,116 +0,0 @@
|
||||
// Package kimi provides authentication and token management functionality
|
||||
// for Kimi (Moonshot AI) services. It handles OAuth2 device flow token storage,
|
||||
// serialization, and retrieval for maintaining authenticated sessions with the Kimi API.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
)
|
||||
|
||||
// KimiTokenStorage stores OAuth2 token information for Kimi API authentication.
|
||||
type KimiTokenStorage struct {
|
||||
// AccessToken is the OAuth2 access token used for authenticating API requests.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token used to obtain new access tokens.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// TokenType is the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope,omitempty"`
|
||||
// DeviceID is the OAuth device flow identifier used for Kimi requests.
|
||||
DeviceID string `json:"device_id,omitempty"`
|
||||
// Expired is the RFC3339 timestamp when the access token expires.
|
||||
Expired string `json:"expired,omitempty"`
|
||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||
type KimiTokenData struct {
|
||||
// AccessToken is the OAuth2 access token.
|
||||
AccessToken string `json:"access_token"`
|
||||
// RefreshToken is the OAuth2 refresh token.
|
||||
RefreshToken string `json:"refresh_token"`
|
||||
// TokenType is the type of token, typically "Bearer".
|
||||
TokenType string `json:"token_type"`
|
||||
// ExpiresAt is the Unix timestamp when the token expires.
|
||||
ExpiresAt int64 `json:"expires_at"`
|
||||
// Scope is the OAuth2 scope granted to the token.
|
||||
Scope string `json:"scope"`
|
||||
}
|
||||
|
||||
// KimiAuthBundle bundles authentication data for storage.
|
||||
type KimiAuthBundle struct {
|
||||
// TokenData contains the OAuth token information.
|
||||
TokenData *KimiTokenData
|
||||
// DeviceID is the device identifier used during OAuth device flow.
|
||||
DeviceID string
|
||||
}
|
||||
|
||||
// DeviceCodeResponse represents Kimi's device code response.
|
||||
type DeviceCodeResponse struct {
|
||||
// DeviceCode is the device verification code.
|
||||
DeviceCode string `json:"device_code"`
|
||||
// UserCode is the code the user must enter at the verification URI.
|
||||
UserCode string `json:"user_code"`
|
||||
// VerificationURI is the URL where the user should enter the code.
|
||||
VerificationURI string `json:"verification_uri,omitempty"`
|
||||
// VerificationURIComplete is the URL with the code pre-filled.
|
||||
VerificationURIComplete string `json:"verification_uri_complete"`
|
||||
// ExpiresIn is the number of seconds until the device code expires.
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
// Interval is the minimum number of seconds to wait between polling requests.
|
||||
Interval int `json:"interval"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Kimi token storage to a JSON file.
|
||||
func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "kimi"
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create token file: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
encoder := json.NewEncoder(f)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(ts); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// IsExpired checks if the token has expired.
|
||||
func (ts *KimiTokenStorage) IsExpired() bool {
|
||||
if ts.Expired == "" {
|
||||
return false // No expiry set, assume valid
|
||||
}
|
||||
t, err := time.Parse(time.RFC3339, ts.Expired)
|
||||
if err != nil {
|
||||
return true // Has expiry string but can't parse
|
||||
}
|
||||
// Consider expired if within refresh threshold
|
||||
return time.Now().Add(time.Duration(refreshThresholdSeconds) * time.Second).After(t)
|
||||
}
|
||||
|
||||
// NeedsRefresh checks if the token should be refreshed.
|
||||
func (ts *KimiTokenStorage) NeedsRefresh() bool {
|
||||
if ts.RefreshToken == "" {
|
||||
return false // Can't refresh without refresh token
|
||||
}
|
||||
return ts.IsExpired()
|
||||
}
|
||||
19
internal/cache/signature_cache.go
vendored
19
internal/cache/signature_cache.go
vendored
@@ -6,6 +6,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
// SignatureEntry holds a cached thinking signature with timestamp
|
||||
@@ -184,6 +186,7 @@ func HasValidSignature(modelName, signature string) bool {
|
||||
}
|
||||
|
||||
func GetModelGroup(modelName string) string {
|
||||
// Fast path: check model name patterns first
|
||||
if strings.Contains(modelName, "gpt") {
|
||||
return "gpt"
|
||||
} else if strings.Contains(modelName, "claude") {
|
||||
@@ -191,5 +194,21 @@ func GetModelGroup(modelName string) string {
|
||||
} else if strings.Contains(modelName, "gemini") {
|
||||
return "gemini"
|
||||
}
|
||||
|
||||
// Slow path: check registry for provider-based grouping
|
||||
// This handles models registered via claude-api-key, gemini-api-key, etc.
|
||||
// that don't have provider name in their model name (e.g., kimi-k2.5 via claude-api-key)
|
||||
if providers := registry.GetGlobalRegistry().GetModelProviders(modelName); len(providers) > 0 {
|
||||
provider := strings.ToLower(providers[0])
|
||||
switch provider {
|
||||
case "claude":
|
||||
return "claude"
|
||||
case "gemini", "gemini-cli", "aistudio", "vertex", "antigravity":
|
||||
return "gemini"
|
||||
case "codex":
|
||||
return "gpt"
|
||||
}
|
||||
}
|
||||
|
||||
return modelName
|
||||
}
|
||||
|
||||
81
internal/cache/signature_cache_test.go
vendored
81
internal/cache/signature_cache_test.go
vendored
@@ -208,3 +208,84 @@ func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
||||
// but the logic is verified by the implementation
|
||||
_ = time.Now() // Acknowledge we're not testing time passage
|
||||
}
|
||||
|
||||
// === GetModelGroup Tests ===
|
||||
// These tests verify that GetModelGroup correctly identifies model groups
|
||||
// both by name pattern (fast path) and by registry provider lookup (slow path).
|
||||
|
||||
func TestGetModelGroup_ByNamePattern(t *testing.T) {
|
||||
tests := []struct {
|
||||
modelName string
|
||||
expectedGroup string
|
||||
}{
|
||||
{"gpt-4o", "gpt"},
|
||||
{"gpt-4-turbo", "gpt"},
|
||||
{"claude-sonnet-4-20250514", "claude"},
|
||||
{"claude-opus-4-5-thinking", "claude"},
|
||||
{"gemini-2.5-pro", "gemini"},
|
||||
{"gemini-3-pro-preview", "gemini"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.modelName, func(t *testing.T) {
|
||||
result := GetModelGroup(tt.modelName)
|
||||
if result != tt.expectedGroup {
|
||||
t.Errorf("GetModelGroup(%q) = %q, expected %q", tt.modelName, result, tt.expectedGroup)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestGetModelGroup_UnknownModel(t *testing.T) {
|
||||
// For unknown models with no registry entry, should return the model name itself
|
||||
result := GetModelGroup("unknown-model-xyz")
|
||||
if result != "unknown-model-xyz" {
|
||||
t.Errorf("GetModelGroup for unknown model should return model name, got %q", result)
|
||||
}
|
||||
}
|
||||
|
||||
// TestGetModelGroup_RegistryFallback tests that models registered via
|
||||
// provider-specific API keys (e.g., kimi-k2.5 via claude-api-key) are
|
||||
// correctly grouped by their provider.
|
||||
// This test requires a populated global registry.
|
||||
func TestGetModelGroup_RegistryFallback(t *testing.T) {
|
||||
// This test only makes sense when the global registry is populated
|
||||
// In unit test context, skip if registry is empty
|
||||
|
||||
// Example: kimi-k2.5 registered via claude-api-key should group as "claude"
|
||||
// The model name doesn't contain "claude", so name pattern matching fails.
|
||||
// The registry should be checked to find the provider.
|
||||
|
||||
// Skip for now - this requires integration test setup
|
||||
t.Skip("Requires populated global registry - run as integration test")
|
||||
}
|
||||
|
||||
// === Cross-Model Signature Validation Tests ===
|
||||
// These tests verify that signatures cached under one model name can be
|
||||
// validated under mapped model names (same provider group).
|
||||
|
||||
func TestCacheSignature_CrossModelValidation(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// Original request uses "claude-opus-4-5-20251101"
|
||||
originalModel := "claude-opus-4-5-20251101"
|
||||
// Mapped model is "claude-opus-4-5-thinking"
|
||||
mappedModel := "claude-opus-4-5-thinking"
|
||||
|
||||
text := "Some thinking block content"
|
||||
sig := "validSignature123456789012345678901234567890123456789012"
|
||||
|
||||
// Cache signature under the original model
|
||||
CacheSignature(originalModel, text, sig)
|
||||
|
||||
// Both should return the same signature because they're in the same group
|
||||
retrieved1 := GetCachedSignature(originalModel, text)
|
||||
retrieved2 := GetCachedSignature(mappedModel, text)
|
||||
|
||||
if retrieved1 != sig {
|
||||
t.Errorf("Original model signature mismatch: got %q", retrieved1)
|
||||
}
|
||||
if retrieved2 != sig {
|
||||
t.Errorf("Mapped model signature mismatch: got %q", retrieved2)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -19,7 +19,6 @@ func newAuthManager() *sdkAuth.Manager {
|
||||
sdkAuth.NewQwenAuthenticator(),
|
||||
sdkAuth.NewIFlowAuthenticator(),
|
||||
sdkAuth.NewAntigravityAuthenticator(),
|
||||
sdkAuth.NewKimiAuthenticator(),
|
||||
)
|
||||
return manager
|
||||
}
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoKimiLogin triggers the OAuth device flow for Kimi (Moonshot AI) and saves tokens.
|
||||
// It initiates the device flow authentication, displays the verification URL for the user,
|
||||
// and waits for authorization before saving the tokens.
|
||||
//
|
||||
// Parameters:
|
||||
// - cfg: The application configuration containing proxy and auth directory settings
|
||||
// - options: Login options including browser behavior settings
|
||||
func DoKimiLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "kimi", cfg, authOpts)
|
||||
if err != nil {
|
||||
log.Errorf("Kimi authentication failed: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
if record != nil && record.Label != "" {
|
||||
fmt.Printf("Authenticated as %s\n", record.Label)
|
||||
}
|
||||
fmt.Println("Kimi authentication successful!")
|
||||
}
|
||||
@@ -18,10 +18,7 @@ import (
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
const (
|
||||
DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
DefaultPprofAddr = "127.0.0.1:8316"
|
||||
)
|
||||
const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy-API-Management-Center"
|
||||
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
@@ -44,9 +41,6 @@ type Config struct {
|
||||
// Debug enables or disables debug-level logging and other debug features.
|
||||
Debug bool `yaml:"debug" json:"debug"`
|
||||
|
||||
// Pprof config controls the optional pprof HTTP debug server.
|
||||
Pprof PprofConfig `yaml:"pprof" json:"pprof"`
|
||||
|
||||
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
|
||||
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
|
||||
|
||||
@@ -127,14 +121,6 @@ type TLSConfig struct {
|
||||
Key string `yaml:"key" json:"key"`
|
||||
}
|
||||
|
||||
// PprofConfig holds pprof HTTP server settings.
|
||||
type PprofConfig struct {
|
||||
// Enable toggles the pprof HTTP debug server.
|
||||
Enable bool `yaml:"enable" json:"enable"`
|
||||
// Addr is the host:port address for the pprof HTTP server.
|
||||
Addr string `yaml:"addr" json:"addr"`
|
||||
}
|
||||
|
||||
// RemoteManagement holds management API configuration under 'remote-management'.
|
||||
type RemoteManagement struct {
|
||||
// AllowRemote toggles remote (non-localhost) access to management API.
|
||||
@@ -493,15 +479,14 @@ func LoadConfig(configFile string) (*Config, error) {
|
||||
// If optional is true and the file is missing, it returns an empty Config.
|
||||
// If optional is true and the file is empty or invalid, it returns an empty Config.
|
||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// NOTE: Startup oauth-model-alias migration is intentionally disabled.
|
||||
// Reason: avoid mutating config.yaml during server startup.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||
// // Log warning but don't fail - config loading should still work
|
||||
// fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||
// } else if migrated {
|
||||
// fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||
// }
|
||||
// Perform oauth-model-alias migration before loading config.
|
||||
// This migrates oauth-model-mappings to oauth-model-alias if needed.
|
||||
if migrated, err := MigrateOAuthModelAlias(configFile); err != nil {
|
||||
// Log warning but don't fail - config loading should still work
|
||||
fmt.Printf("Warning: oauth-model-alias migration failed: %v\n", err)
|
||||
} else if migrated {
|
||||
fmt.Println("Migrated oauth-model-mappings to oauth-model-alias")
|
||||
}
|
||||
|
||||
// Read the entire configuration file into memory.
|
||||
data, err := os.ReadFile(configFile)
|
||||
@@ -529,8 +514,6 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.ErrorLogsMaxFiles = 10
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.Pprof.Enable = false
|
||||
cfg.Pprof.Addr = DefaultPprofAddr
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
if err = yaml.Unmarshal(data, &cfg); err != nil {
|
||||
@@ -541,21 +524,18 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
return nil, fmt.Errorf("failed to parse config file: %w", err)
|
||||
}
|
||||
|
||||
// NOTE: Startup legacy key migration is intentionally disabled.
|
||||
// Reason: avoid mutating config.yaml during server startup.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// var legacy legacyConfigData
|
||||
// if errLegacy := yaml.Unmarshal(data, &legacy); errLegacy == nil {
|
||||
// if cfg.migrateLegacyGeminiKeys(legacy.LegacyGeminiKeys) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// if cfg.migrateLegacyAmpConfig(&legacy) {
|
||||
// cfg.legacyMigrationPending = true
|
||||
// }
|
||||
// }
|
||||
var legacy legacyConfigData
|
||||
if errLegacy := yaml.Unmarshal(data, &legacy); errLegacy == nil {
|
||||
if cfg.migrateLegacyGeminiKeys(legacy.LegacyGeminiKeys) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
if cfg.migrateLegacyOpenAICompatibilityKeys(legacy.OpenAICompat) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
if cfg.migrateLegacyAmpConfig(&legacy) {
|
||||
cfg.legacyMigrationPending = true
|
||||
}
|
||||
}
|
||||
|
||||
// Hash remote management key if plaintext is detected (nested)
|
||||
// We consider a value to be already hashed if it looks like a bcrypt hash ($2a$, $2b$, or $2y$ prefix).
|
||||
@@ -576,11 +556,6 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
}
|
||||
|
||||
cfg.Pprof.Addr = strings.TrimSpace(cfg.Pprof.Addr)
|
||||
if cfg.Pprof.Addr == "" {
|
||||
cfg.Pprof.Addr = DefaultPprofAddr
|
||||
}
|
||||
|
||||
if cfg.LogsMaxTotalSizeMB < 0 {
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
}
|
||||
@@ -616,20 +591,17 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Validate raw payload rules and drop invalid entries.
|
||||
cfg.SanitizePayloadRules()
|
||||
|
||||
// NOTE: Legacy migration persistence is intentionally disabled together with
|
||||
// startup legacy migration to keep startup read-only for config.yaml.
|
||||
// Re-enable the block below if automatic startup migration is needed again.
|
||||
// if cfg.legacyMigrationPending {
|
||||
// fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
// if !optional && configFile != "" {
|
||||
// if err := SaveConfigPreserveComments(configFile, &cfg); err != nil {
|
||||
// return nil, fmt.Errorf("failed to persist migrated legacy config: %w", err)
|
||||
// }
|
||||
// fmt.Println("Legacy configuration normalized and persisted.")
|
||||
// } else {
|
||||
// fmt.Println("Legacy configuration normalized in memory; persistence skipped.")
|
||||
// }
|
||||
// }
|
||||
if cfg.legacyMigrationPending {
|
||||
fmt.Println("Detected legacy configuration keys, attempting to persist the normalized config...")
|
||||
if !optional && configFile != "" {
|
||||
if err := SaveConfigPreserveComments(configFile, &cfg); err != nil {
|
||||
return nil, fmt.Errorf("failed to persist migrated legacy config: %w", err)
|
||||
}
|
||||
fmt.Println("Legacy configuration normalized and persisted.")
|
||||
} else {
|
||||
fmt.Println("Legacy configuration normalized in memory; persistence skipped.")
|
||||
}
|
||||
}
|
||||
|
||||
// Return the populated configuration struct.
|
||||
return &cfg, nil
|
||||
|
||||
@@ -17,7 +17,6 @@ var antigravityModelConversionTable = map[string]string{
|
||||
"gemini-claude-sonnet-4-5": "claude-sonnet-4-5",
|
||||
"gemini-claude-sonnet-4-5-thinking": "claude-sonnet-4-5-thinking",
|
||||
"gemini-claude-opus-4-5-thinking": "claude-opus-4-5-thinking",
|
||||
"gemini-claude-opus-4-6-thinking": "claude-opus-4-6-thinking",
|
||||
}
|
||||
|
||||
// defaultAntigravityAliases returns the default oauth-model-alias configuration
|
||||
@@ -31,7 +30,6 @@ func defaultAntigravityAliases() []OAuthModelAlias {
|
||||
{Name: "claude-sonnet-4-5", Alias: "gemini-claude-sonnet-4-5"},
|
||||
{Name: "claude-sonnet-4-5-thinking", Alias: "gemini-claude-sonnet-4-5-thinking"},
|
||||
{Name: "claude-opus-4-5-thinking", Alias: "gemini-claude-opus-4-5-thinking"},
|
||||
{Name: "claude-opus-4-6-thinking", Alias: "gemini-claude-opus-4-6-thinking"},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -131,9 +131,6 @@ func TestMigrateOAuthModelAlias_ConvertsAntigravityModels(t *testing.T) {
|
||||
if !strings.Contains(content, "claude-opus-4-5-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-5-thinking to be added")
|
||||
}
|
||||
if !strings.Contains(content, "claude-opus-4-6-thinking") {
|
||||
t.Fatal("expected missing default alias claude-opus-4-6-thinking to be added")
|
||||
}
|
||||
}
|
||||
|
||||
func TestMigrateOAuthModelAlias_AddsDefaultIfNeitherExists(t *testing.T) {
|
||||
|
||||
@@ -131,10 +131,7 @@ func ResolveLogDirectory(cfg *config.Config) string {
|
||||
return logDir
|
||||
}
|
||||
if !isDirWritable(logDir) {
|
||||
authDir, err := util.ResolveAuthDir(cfg.AuthDir)
|
||||
if err != nil {
|
||||
log.Warnf("Failed to resolve auth-dir %q for log directory: %v", cfg.AuthDir, err)
|
||||
}
|
||||
authDir := strings.TrimSpace(cfg.AuthDir)
|
||||
if authDir != "" {
|
||||
logDir = filepath.Join(authDir, "logs")
|
||||
}
|
||||
|
||||
@@ -15,7 +15,7 @@ func GetClaudeModels() []*ModelInfo {
|
||||
DisplayName: "Claude 4.5 Haiku",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
// Thinking: not supported for Haiku models
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
@@ -28,18 +28,6 @@ func GetClaudeModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-6",
|
||||
Object: "model",
|
||||
Created: 1770318000, // 2026-02-05
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Opus",
|
||||
Description: "Premium model combining maximum intelligence with practical performance",
|
||||
ContextLength: 1000000,
|
||||
MaxCompletionTokens: 128000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-20251101",
|
||||
Object: "model",
|
||||
@@ -728,20 +716,6 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.3-codex",
|
||||
Object: "model",
|
||||
Created: 1770307200,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.3",
|
||||
DisplayName: "GPT 5.3 Codex",
|
||||
Description: "Stable version of GPT 5.3 Codex, The best model for coding and agentic tasks across domains.",
|
||||
ContextLength: 400000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -829,7 +803,6 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "iflow-rome-30ba3b", DisplayName: "iFlow-ROME", Description: "iFlow Rome 30BA3B model", Created: 1736899200},
|
||||
{ID: "kimi-k2.5", DisplayName: "Kimi-K2.5", Description: "Moonshot Kimi K2.5", Created: 1769443200, Thinking: iFlowThinkingSupport},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -866,50 +839,8 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 128000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
}
|
||||
}
|
||||
|
||||
// GetKimiModels returns the standard Kimi (Moonshot AI) model definitions
|
||||
func GetKimiModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
{
|
||||
ID: "kimi-k2",
|
||||
Object: "model",
|
||||
Created: 1752192000, // 2025-07-11
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2",
|
||||
Description: "Kimi K2 - Moonshot AI's flagship coding model",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
},
|
||||
{
|
||||
ID: "kimi-k2-thinking",
|
||||
Object: "model",
|
||||
Created: 1762387200, // 2025-11-06
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2 Thinking",
|
||||
Description: "Kimi K2 Thinking - Extended reasoning model",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "kimi-k2.5",
|
||||
Object: "model",
|
||||
Created: 1769472000, // 2026-01-26
|
||||
OwnedBy: "moonshot",
|
||||
Type: "kimi",
|
||||
DisplayName: "Kimi K2.5",
|
||||
Description: "Kimi K2.5 - Latest Moonshot AI coding model with improved capabilities",
|
||||
ContextLength: 131072,
|
||||
MaxCompletionTokens: 32768,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 32000, ZeroAllowed: true, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
39
internal/routing/adapter.go
Normal file
39
internal/routing/adapter.go
Normal file
@@ -0,0 +1,39 @@
|
||||
// Package routing provides adapter to integrate with existing codebase.
|
||||
package routing
|
||||
|
||||
import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// Adapter bridges the new routing layer with existing auth manager.
|
||||
type Adapter struct {
|
||||
router *Router
|
||||
exec *Executor
|
||||
}
|
||||
|
||||
// NewAdapter creates a new adapter with the given configuration and auth manager.
|
||||
func NewAdapter(cfg *config.Config, authManager *coreauth.Manager) *Adapter {
|
||||
registry := NewRegistry()
|
||||
|
||||
// TODO: Register OAuth providers from authManager
|
||||
// TODO: Register API key providers from cfg
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
exec := NewExecutor(router)
|
||||
|
||||
return &Adapter{
|
||||
router: router,
|
||||
exec: exec,
|
||||
}
|
||||
}
|
||||
|
||||
// Router returns the underlying router.
|
||||
func (a *Adapter) Router() *Router {
|
||||
return a.router
|
||||
}
|
||||
|
||||
// Executor returns the underlying executor.
|
||||
func (a *Adapter) Executor() *Executor {
|
||||
return a.exec
|
||||
}
|
||||
11
internal/routing/ctxkeys/keys.go
Normal file
11
internal/routing/ctxkeys/keys.go
Normal file
@@ -0,0 +1,11 @@
|
||||
package ctxkeys
|
||||
|
||||
type key string
|
||||
|
||||
const (
|
||||
MappedModel key = "mapped_model"
|
||||
FallbackModels key = "fallback_models"
|
||||
RouteCandidates key = "route_candidates"
|
||||
RoutingDecision key = "routing_decision"
|
||||
MappingApplied key = "mapping_applied"
|
||||
)
|
||||
111
internal/routing/executor.go
Normal file
111
internal/routing/executor.go
Normal file
@@ -0,0 +1,111 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Executor handles request execution with fallback support.
|
||||
type Executor struct {
|
||||
router *Router
|
||||
}
|
||||
|
||||
// NewExecutor creates a new executor with the given router.
|
||||
func NewExecutor(router *Router) *Executor {
|
||||
return &Executor{router: router}
|
||||
}
|
||||
|
||||
// Execute sends the request through the routing decision.
|
||||
func (e *Executor) Execute(ctx context.Context, req executor.Request) (executor.Response, error) {
|
||||
decision := e.router.Resolve(req.Model)
|
||||
|
||||
log.Debugf("routing: %s -> %s (%d candidates)",
|
||||
decision.RequestedModel,
|
||||
decision.ResolvedModel,
|
||||
len(decision.Candidates))
|
||||
|
||||
var lastErr error
|
||||
tried := make(map[string]struct{})
|
||||
|
||||
for i, candidate := range decision.Candidates {
|
||||
key := candidate.Provider.Name() + "/" + candidate.Model
|
||||
if _, ok := tried[key]; ok {
|
||||
continue
|
||||
}
|
||||
tried[key] = struct{}{}
|
||||
|
||||
log.Debugf("routing: trying candidate %d/%d: %s with model %s",
|
||||
i+1, len(decision.Candidates), candidate.Provider.Name(), candidate.Model)
|
||||
|
||||
req.Model = candidate.Model
|
||||
resp, err := candidate.Provider.Execute(ctx, candidate.Model, req)
|
||||
if err == nil {
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Debugf("routing: candidate failed: %v", err)
|
||||
|
||||
// Check if it's a fatal error (not retryable)
|
||||
if isFatalError(err) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return executor.Response{}, lastErr
|
||||
}
|
||||
return executor.Response{}, errors.New("no available providers")
|
||||
}
|
||||
|
||||
// ExecuteStream sends a streaming request through the routing decision.
|
||||
func (e *Executor) ExecuteStream(ctx context.Context, req executor.Request) (<-chan executor.StreamChunk, error) {
|
||||
decision := e.router.Resolve(req.Model)
|
||||
|
||||
log.Debugf("routing stream: %s -> %s (%d candidates)",
|
||||
decision.RequestedModel,
|
||||
decision.ResolvedModel,
|
||||
len(decision.Candidates))
|
||||
|
||||
var lastErr error
|
||||
tried := make(map[string]struct{})
|
||||
|
||||
for i, candidate := range decision.Candidates {
|
||||
key := candidate.Provider.Name() + "/" + candidate.Model
|
||||
if _, ok := tried[key]; ok {
|
||||
continue
|
||||
}
|
||||
tried[key] = struct{}{}
|
||||
|
||||
log.Debugf("routing stream: trying candidate %d/%d: %s with model %s",
|
||||
i+1, len(decision.Candidates), candidate.Provider.Name(), candidate.Model)
|
||||
|
||||
req.Model = candidate.Model
|
||||
chunks, err := candidate.Provider.ExecuteStream(ctx, candidate.Model, req)
|
||||
if err == nil {
|
||||
return chunks, nil
|
||||
}
|
||||
|
||||
lastErr = err
|
||||
log.Debugf("routing stream: candidate failed: %v", err)
|
||||
|
||||
if isFatalError(err) {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errors.New("no available providers")
|
||||
}
|
||||
|
||||
// isFatalError returns true if the error is not retryable.
|
||||
func isFatalError(err error) bool {
|
||||
// TODO: implement based on error type
|
||||
// For now, all errors are retryable
|
||||
return false
|
||||
}
|
||||
59
internal/routing/extractor.go
Normal file
59
internal/routing/extractor.go
Normal file
@@ -0,0 +1,59 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// ModelExtractor extracts model names from request data.
|
||||
type ModelExtractor interface {
|
||||
// Extract returns the model name from the request body and gin parameters.
|
||||
// The ginParams map contains route parameters like "action" and "path".
|
||||
Extract(body []byte, ginParams map[string]string) (string, error)
|
||||
}
|
||||
|
||||
// DefaultModelExtractor is the standard implementation of ModelExtractor.
|
||||
type DefaultModelExtractor struct{}
|
||||
|
||||
// NewModelExtractor creates a new DefaultModelExtractor.
|
||||
func NewModelExtractor() *DefaultModelExtractor {
|
||||
return &DefaultModelExtractor{}
|
||||
}
|
||||
|
||||
// Extract extracts the model name from the request.
|
||||
// It checks in order:
|
||||
// 1. JSON body "model" field (OpenAI, Claude format)
|
||||
// 2. "action" parameter for Gemini standard format (e.g., "gemini-pro:generateContent")
|
||||
// 3. "path" parameter for AMP CLI Gemini format (e.g., "/publishers/google/models/gemini-3-pro:streamGenerateContent")
|
||||
func (e *DefaultModelExtractor) Extract(body []byte, ginParams map[string]string) (string, error) {
|
||||
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
||||
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
|
||||
return result.String(), nil
|
||||
}
|
||||
|
||||
// For Gemini requests, model is in the URL path
|
||||
// Standard format: /models/{model}:generateContent -> :action parameter
|
||||
if action, ok := ginParams["action"]; ok && action != "" {
|
||||
// Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
|
||||
parts := strings.Split(action, ":")
|
||||
if len(parts) > 0 && parts[0] != "" {
|
||||
return parts[0], nil
|
||||
}
|
||||
}
|
||||
|
||||
// AMP CLI format: /publishers/google/models/{model}:method -> *path parameter
|
||||
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
||||
if path, ok := ginParams["path"]; ok && path != "" {
|
||||
// Look for /models/{model}:method pattern
|
||||
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
||||
modelPart := path[idx+8:] // Skip "/models/"
|
||||
// Split by colon to get model name
|
||||
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
||||
return modelPart[:colonIdx], nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", nil
|
||||
}
|
||||
214
internal/routing/extractor_test.go
Normal file
214
internal/routing/extractor_test.go
Normal file
@@ -0,0 +1,214 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestModelExtractor_ExtractFromJSONBody(t *testing.T) {
|
||||
extractor := NewModelExtractor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
want string
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "extract from JSON body with model field",
|
||||
body: []byte(`{"model":"gpt-4.1"}`),
|
||||
want: "gpt-4.1",
|
||||
},
|
||||
{
|
||||
name: "extract claude model from JSON body",
|
||||
body: []byte(`{"model":"claude-3-5-sonnet-20241022"}`),
|
||||
want: "claude-3-5-sonnet-20241022",
|
||||
},
|
||||
{
|
||||
name: "extract with additional fields",
|
||||
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`),
|
||||
want: "gpt-4",
|
||||
},
|
||||
{
|
||||
name: "empty body returns empty",
|
||||
body: []byte{},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "no model field returns empty",
|
||||
body: []byte(`{"messages":[]}`),
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "model is not string returns empty",
|
||||
body: []byte(`{"model":123}`),
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := extractor.Extract(tt.body, nil)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
return
|
||||
}
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelExtractor_ExtractFromGeminiActionParam(t *testing.T) {
|
||||
extractor := NewModelExtractor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
ginParams map[string]string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "extract from action parameter - gemini-pro",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"action": "gemini-pro:generateContent"},
|
||||
want: "gemini-pro",
|
||||
},
|
||||
{
|
||||
name: "extract from action parameter - gemini-ultra",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"action": "gemini-ultra:chat"},
|
||||
want: "gemini-ultra",
|
||||
},
|
||||
{
|
||||
name: "empty action returns empty",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"action": ""},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "action without colon returns full value",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"action": "gemini-model"},
|
||||
want: "gemini-model",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelExtractor_ExtractFromGeminiV1Beta1Path(t *testing.T) {
|
||||
extractor := NewModelExtractor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
ginParams map[string]string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "extract from v1beta1 path - gemini-3-pro",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro:streamGenerateContent"},
|
||||
want: "gemini-3-pro",
|
||||
},
|
||||
{
|
||||
name: "extract from v1beta1 path with preview",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro-preview:generateContent"},
|
||||
want: "gemini-3-pro-preview",
|
||||
},
|
||||
{
|
||||
name: "path without models segment returns empty",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"path": "/publishers/google/gemini-3-pro:streamGenerateContent"},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "empty path returns empty",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"path": ""},
|
||||
want: "",
|
||||
},
|
||||
{
|
||||
name: "path with /models/ but no colon returns empty",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro"},
|
||||
want: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, tt.want, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelExtractor_ExtractPriority(t *testing.T) {
|
||||
extractor := NewModelExtractor()
|
||||
|
||||
// JSON body takes priority over gin params
|
||||
t.Run("JSON body takes priority over action param", func(t *testing.T) {
|
||||
body := []byte(`{"model":"gpt-4"}`)
|
||||
params := map[string]string{"action": "gemini-pro:generateContent"}
|
||||
got, err := extractor.Extract(body, params)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "gpt-4", got)
|
||||
})
|
||||
|
||||
// Action param takes priority over path param
|
||||
t.Run("action param takes priority over path param", func(t *testing.T) {
|
||||
body := []byte(`{}`)
|
||||
params := map[string]string{
|
||||
"action": "gemini-action:generate",
|
||||
"path": "/publishers/google/models/gemini-path:streamGenerateContent",
|
||||
}
|
||||
got, err := extractor.Extract(body, params)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, "gemini-action", got)
|
||||
})
|
||||
}
|
||||
|
||||
func TestModelExtractor_NoModelFound(t *testing.T) {
|
||||
extractor := NewModelExtractor()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
ginParams map[string]string
|
||||
}{
|
||||
{
|
||||
name: "empty body and no params",
|
||||
body: []byte{},
|
||||
ginParams: nil,
|
||||
},
|
||||
{
|
||||
name: "body without model and no params",
|
||||
body: []byte(`{"messages":[]}`),
|
||||
ginParams: map[string]string{},
|
||||
},
|
||||
{
|
||||
name: "irrelevant params only",
|
||||
body: []byte(`{}`),
|
||||
ginParams: map[string]string{"other": "value"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||
assert.NoError(t, err)
|
||||
assert.Empty(t, got)
|
||||
})
|
||||
}
|
||||
}
|
||||
80
internal/routing/provider.go
Normal file
80
internal/routing/provider.go
Normal file
@@ -0,0 +1,80 @@
|
||||
// Package routing provides unified model routing for all provider types.
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// ProviderType indicates the type of provider.
|
||||
type ProviderType string
|
||||
|
||||
const (
|
||||
ProviderTypeOAuth ProviderType = "oauth"
|
||||
ProviderTypeAPIKey ProviderType = "api_key"
|
||||
ProviderTypeVertex ProviderType = "vertex"
|
||||
)
|
||||
|
||||
// Provider is the unified interface for all provider types (OAuth, API key, etc.).
|
||||
type Provider interface {
|
||||
// Name returns the unique provider identifier.
|
||||
Name() string
|
||||
|
||||
// Type returns the provider type.
|
||||
Type() ProviderType
|
||||
|
||||
// SupportsModel returns true if this provider can handle the given model.
|
||||
SupportsModel(model string) bool
|
||||
|
||||
// Available returns true if the provider is available for the model (not quota exceeded).
|
||||
Available(model string) bool
|
||||
|
||||
// Priority returns the priority for this provider (lower = tried first).
|
||||
Priority() int
|
||||
|
||||
// Execute sends the request to the provider.
|
||||
Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error)
|
||||
|
||||
// ExecuteStream sends a streaming request to the provider.
|
||||
ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error)
|
||||
}
|
||||
|
||||
// ProviderCandidate represents a provider + model combination to try.
|
||||
type ProviderCandidate struct {
|
||||
Provider Provider
|
||||
Model string // The actual model name to use (may be different from requested due to aliasing)
|
||||
}
|
||||
|
||||
// Registry manages all available providers.
|
||||
type Registry struct {
|
||||
providers []Provider
|
||||
}
|
||||
|
||||
// NewRegistry creates a new provider registry.
|
||||
func NewRegistry() *Registry {
|
||||
return &Registry{
|
||||
providers: make([]Provider, 0),
|
||||
}
|
||||
}
|
||||
|
||||
// Register adds a provider to the registry.
|
||||
func (r *Registry) Register(p Provider) {
|
||||
r.providers = append(r.providers, p)
|
||||
}
|
||||
|
||||
// FindProviders returns all providers that support the given model and are available.
|
||||
func (r *Registry) FindProviders(model string) []Provider {
|
||||
var result []Provider
|
||||
for _, p := range r.providers {
|
||||
if p.SupportsModel(model) && p.Available(model) {
|
||||
result = append(result, p)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// All returns all registered providers.
|
||||
func (r *Registry) All() []Provider {
|
||||
return r.providers
|
||||
}
|
||||
156
internal/routing/providers/apikey.go
Normal file
156
internal/routing/providers/apikey.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// APIKeyProvider wraps API key configs as routing.Provider.
|
||||
type APIKeyProvider struct {
|
||||
name string
|
||||
provider string // claude, gemini, codex, vertex
|
||||
keys []APIKeyEntry
|
||||
mu sync.RWMutex
|
||||
client HTTPClient
|
||||
}
|
||||
|
||||
// APIKeyEntry represents a single API key configuration.
|
||||
type APIKeyEntry struct {
|
||||
APIKey string
|
||||
BaseURL string
|
||||
Models []config.ClaudeModel // Using ClaudeModel as generic model alias
|
||||
}
|
||||
|
||||
// HTTPClient interface for making HTTP requests.
|
||||
type HTTPClient interface {
|
||||
Do(req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// NewAPIKeyProvider creates a new API key provider.
|
||||
func NewAPIKeyProvider(name, provider string, client HTTPClient) *APIKeyProvider {
|
||||
return &APIKeyProvider{
|
||||
name: name,
|
||||
provider: provider,
|
||||
keys: make([]APIKeyEntry, 0),
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider name.
|
||||
func (p *APIKeyProvider) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
// Type returns ProviderTypeAPIKey.
|
||||
func (p *APIKeyProvider) Type() routing.ProviderType {
|
||||
return routing.ProviderTypeAPIKey
|
||||
}
|
||||
|
||||
// SupportsModel checks if the model is supported by this provider.
|
||||
func (p *APIKeyProvider) SupportsModel(model string) bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
for _, key := range p.keys {
|
||||
for _, m := range key.Models {
|
||||
if strings.EqualFold(m.Alias, model) || strings.EqualFold(m.Name, model) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Available always returns true for API keys (unless explicitly disabled).
|
||||
func (p *APIKeyProvider) Available(model string) bool {
|
||||
return p.SupportsModel(model)
|
||||
}
|
||||
|
||||
// Priority returns the priority (API key is lower priority than OAuth).
|
||||
func (p *APIKeyProvider) Priority() int {
|
||||
return 20
|
||||
}
|
||||
|
||||
// Execute sends the request using the API key.
|
||||
func (p *APIKeyProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
|
||||
key := p.selectKey(model)
|
||||
if key == nil {
|
||||
return executor.Response{}, ErrNoMatchingAPIKey
|
||||
}
|
||||
|
||||
// Resolve the actual model name from alias
|
||||
actualModel := p.resolveModel(key, model)
|
||||
|
||||
// Execute via HTTP client
|
||||
return p.executeHTTP(ctx, key, actualModel, req)
|
||||
}
|
||||
|
||||
// ExecuteStream sends a streaming request.
|
||||
func (p *APIKeyProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (
|
||||
<-chan executor.StreamChunk, error) {
|
||||
key := p.selectKey(model)
|
||||
if key == nil {
|
||||
return nil, ErrNoMatchingAPIKey
|
||||
}
|
||||
|
||||
actualModel := p.resolveModel(key, model)
|
||||
return p.executeHTTPStream(ctx, key, actualModel, req)
|
||||
}
|
||||
|
||||
// AddKey adds an API key entry.
|
||||
func (p *APIKeyProvider) AddKey(entry APIKeyEntry) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.keys = append(p.keys, entry)
|
||||
}
|
||||
|
||||
// selectKey selects a key that supports the model.
|
||||
func (p *APIKeyProvider) selectKey(model string) *APIKeyEntry {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
for _, key := range p.keys {
|
||||
for _, m := range key.Models {
|
||||
if strings.EqualFold(m.Alias, model) || strings.EqualFold(m.Name, model) {
|
||||
return &key
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// resolveModel resolves alias to actual model name.
|
||||
func (p *APIKeyProvider) resolveModel(key *APIKeyEntry, requested string) string {
|
||||
for _, m := range key.Models {
|
||||
if strings.EqualFold(m.Alias, requested) {
|
||||
return m.Name
|
||||
}
|
||||
}
|
||||
return requested
|
||||
}
|
||||
|
||||
// executeHTTP makes the HTTP request.
|
||||
func (p *APIKeyProvider) executeHTTP(ctx context.Context, key *APIKeyEntry, model string, req executor.Request) (executor.Response, error) {
|
||||
// TODO: implement actual HTTP execution
|
||||
// This is a placeholder - actual implementation would build HTTP request
|
||||
return executor.Response{}, errors.New("not yet implemented")
|
||||
}
|
||||
|
||||
// executeHTTPStream makes a streaming HTTP request.
|
||||
func (p *APIKeyProvider) executeHTTPStream(ctx context.Context, key *APIKeyEntry, model string, req executor.Request) (
|
||||
<-chan executor.StreamChunk, error) {
|
||||
// TODO: implement actual HTTP streaming
|
||||
return nil, errors.New("not yet implemented")
|
||||
}
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrNoMatchingAPIKey = errors.New("no API key supports the requested model")
|
||||
)
|
||||
132
internal/routing/providers/oauth.go
Normal file
132
internal/routing/providers/oauth.go
Normal file
@@ -0,0 +1,132 @@
|
||||
package providers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"sync"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// OAuthProvider wraps OAuth-based auths as routing.Provider.
|
||||
type OAuthProvider struct {
|
||||
name string
|
||||
auths []*coreauth.Auth
|
||||
mu sync.RWMutex
|
||||
executor coreauth.ProviderExecutor
|
||||
}
|
||||
|
||||
// NewOAuthProvider creates a new OAuth provider.
|
||||
func NewOAuthProvider(name string, exec coreauth.ProviderExecutor) *OAuthProvider {
|
||||
return &OAuthProvider{
|
||||
name: name,
|
||||
auths: make([]*coreauth.Auth, 0),
|
||||
executor: exec,
|
||||
}
|
||||
}
|
||||
|
||||
// Name returns the provider name.
|
||||
func (p *OAuthProvider) Name() string {
|
||||
return p.name
|
||||
}
|
||||
|
||||
// Type returns ProviderTypeOAuth.
|
||||
func (p *OAuthProvider) Type() routing.ProviderType {
|
||||
return routing.ProviderTypeOAuth
|
||||
}
|
||||
|
||||
// SupportsModel checks if any auth supports the model.
|
||||
func (p *OAuthProvider) SupportsModel(model string) bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
// OAuth providers typically support models via oauth-model-alias
|
||||
// The actual model support is determined at execution time
|
||||
return true
|
||||
}
|
||||
|
||||
// Available checks if there's an available auth for the model.
|
||||
func (p *OAuthProvider) Available(model string) bool {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
for _, auth := range p.auths {
|
||||
if p.isAuthAvailable(auth, model) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// Priority returns the priority (OAuth is preferred over API key).
|
||||
func (p *OAuthProvider) Priority() int {
|
||||
return 10
|
||||
}
|
||||
|
||||
// Execute sends the request using an available OAuth auth.
|
||||
func (p *OAuthProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
|
||||
auth := p.selectAuth(model)
|
||||
if auth == nil {
|
||||
return executor.Response{}, ErrNoAvailableAuth
|
||||
}
|
||||
|
||||
return p.executor.Execute(ctx, auth, req, executor.Options{})
|
||||
}
|
||||
|
||||
// ExecuteStream sends a streaming request.
|
||||
func (p *OAuthProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
|
||||
auth := p.selectAuth(model)
|
||||
if auth == nil {
|
||||
return nil, ErrNoAvailableAuth
|
||||
}
|
||||
|
||||
return p.executor.ExecuteStream(ctx, auth, req, executor.Options{})
|
||||
}
|
||||
|
||||
// AddAuth adds an auth to this provider.
|
||||
func (p *OAuthProvider) AddAuth(auth *coreauth.Auth) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
p.auths = append(p.auths, auth)
|
||||
}
|
||||
|
||||
// RemoveAuth removes an auth from this provider.
|
||||
func (p *OAuthProvider) RemoveAuth(authID string) {
|
||||
p.mu.Lock()
|
||||
defer p.mu.Unlock()
|
||||
|
||||
filtered := make([]*coreauth.Auth, 0, len(p.auths))
|
||||
for _, auth := range p.auths {
|
||||
if auth.ID != authID {
|
||||
filtered = append(filtered, auth)
|
||||
}
|
||||
}
|
||||
p.auths = filtered
|
||||
}
|
||||
|
||||
// isAuthAvailable checks if an auth is available for the model.
|
||||
func (p *OAuthProvider) isAuthAvailable(auth *coreauth.Auth, model string) bool {
|
||||
// TODO: integrate with model_registry for quota checking
|
||||
// For now, just check if auth exists
|
||||
return auth != nil
|
||||
}
|
||||
|
||||
// selectAuth selects an available auth for the model.
|
||||
func (p *OAuthProvider) selectAuth(model string) *coreauth.Auth {
|
||||
p.mu.RLock()
|
||||
defer p.mu.RUnlock()
|
||||
|
||||
for _, auth := range p.auths {
|
||||
if p.isAuthAvailable(auth, model) {
|
||||
return auth
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Errors
|
||||
var (
|
||||
ErrNoAvailableAuth = errors.New("no available OAuth auth for model")
|
||||
)
|
||||
159
internal/routing/rewriter.go
Normal file
159
internal/routing/rewriter.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ModelRewriter handles model name rewriting in requests and responses.
|
||||
type ModelRewriter interface {
|
||||
// RewriteRequestBody rewrites the model field in a JSON request body.
|
||||
// Returns the modified body or the original if no rewrite was needed.
|
||||
RewriteRequestBody(body []byte, newModel string) ([]byte, error)
|
||||
|
||||
// WrapResponseWriter wraps an http.ResponseWriter to rewrite model names in the response.
|
||||
// Returns the wrapped writer and a cleanup function that must be called after the response is complete.
|
||||
WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func())
|
||||
}
|
||||
|
||||
// DefaultModelRewriter is the standard implementation of ModelRewriter.
|
||||
type DefaultModelRewriter struct{}
|
||||
|
||||
// NewModelRewriter creates a new DefaultModelRewriter.
|
||||
func NewModelRewriter() *DefaultModelRewriter {
|
||||
return &DefaultModelRewriter{}
|
||||
}
|
||||
|
||||
// RewriteRequestBody replaces the model name in a JSON request body.
|
||||
func (r *DefaultModelRewriter) RewriteRequestBody(body []byte, newModel string) ([]byte, error) {
|
||||
if !gjson.GetBytes(body, "model").Exists() {
|
||||
return body, nil
|
||||
}
|
||||
result, err := sjson.SetBytes(body, "model", newModel)
|
||||
if err != nil {
|
||||
return body, err
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// WrapResponseWriter wraps a response writer to rewrite model names.
|
||||
// The cleanup function must be called after the handler completes to flush any buffered data.
|
||||
func (r *DefaultModelRewriter) WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) {
|
||||
rw := &responseRewriter{
|
||||
ResponseWriter: w,
|
||||
body: &bytes.Buffer{},
|
||||
requestedModel: requestedModel,
|
||||
resolvedModel: resolvedModel,
|
||||
}
|
||||
return rw, func() { rw.flush() }
|
||||
}
|
||||
|
||||
// responseRewriter wraps http.ResponseWriter to intercept and modify the response body.
|
||||
type responseRewriter struct {
|
||||
http.ResponseWriter
|
||||
body *bytes.Buffer
|
||||
requestedModel string
|
||||
resolvedModel string
|
||||
isStreaming bool
|
||||
wroteHeader bool
|
||||
flushed bool
|
||||
}
|
||||
|
||||
// Write intercepts response writes and buffers them for model name replacement.
|
||||
func (rw *responseRewriter) Write(data []byte) (int, error) {
|
||||
// Ensure header is written
|
||||
if !rw.wroteHeader {
|
||||
rw.WriteHeader(http.StatusOK)
|
||||
}
|
||||
|
||||
// Detect streaming on first write
|
||||
if rw.body.Len() == 0 && !rw.isStreaming {
|
||||
contentType := rw.Header().Get("Content-Type")
|
||||
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
||||
strings.Contains(contentType, "stream")
|
||||
}
|
||||
|
||||
if rw.isStreaming {
|
||||
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||
if err == nil {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
return n, err
|
||||
}
|
||||
return rw.body.Write(data)
|
||||
}
|
||||
|
||||
// WriteHeader captures the status code and delegates to the underlying writer.
|
||||
func (rw *responseRewriter) WriteHeader(code int) {
|
||||
if !rw.wroteHeader {
|
||||
rw.wroteHeader = true
|
||||
rw.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
}
|
||||
|
||||
// flush writes the buffered response with model names rewritten.
|
||||
func (rw *responseRewriter) flush() {
|
||||
if rw.flushed {
|
||||
return
|
||||
}
|
||||
rw.flushed = true
|
||||
|
||||
if rw.isStreaming {
|
||||
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
return
|
||||
}
|
||||
if rw.body.Len() > 0 {
|
||||
data := rw.rewriteModelInResponse(rw.body.Bytes())
|
||||
if _, err := rw.ResponseWriter.Write(data); err != nil {
|
||||
log.Warnf("response rewriter: failed to write rewritten response: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// modelFieldPaths lists all JSON paths where model name may appear.
|
||||
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||
|
||||
// rewriteModelInResponse replaces all occurrences of the resolved model with the requested model.
|
||||
func (rw *responseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
||||
return data
|
||||
}
|
||||
|
||||
for _, path := range modelFieldPaths {
|
||||
if gjson.GetBytes(data, path).Exists() {
|
||||
data, _ = sjson.SetBytes(data, path, rw.requestedModel)
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// rewriteStreamChunk rewrites model names in SSE stream chunks.
|
||||
func (rw *responseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
||||
return chunk
|
||||
}
|
||||
|
||||
// SSE format: "data: {json}\n\n"
|
||||
lines := bytes.Split(chunk, []byte("\n"))
|
||||
for i, line := range lines {
|
||||
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
||||
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||
// Rewrite JSON in the data line
|
||||
rewritten := rw.rewriteModelInResponse(jsonData)
|
||||
lines[i] = append([]byte("data: "), rewritten...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return bytes.Join(lines, []byte("\n"))
|
||||
}
|
||||
342
internal/routing/rewriter_test.go
Normal file
342
internal/routing/rewriter_test.go
Normal file
@@ -0,0 +1,342 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestModelRewriter_RewriteRequestBody(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
body []byte
|
||||
newModel string
|
||||
wantModel string
|
||||
wantChange bool
|
||||
}{
|
||||
{
|
||||
name: "rewrites model field in JSON body",
|
||||
body: []byte(`{"model":"gpt-4.1","messages":[]}`),
|
||||
newModel: "claude-local",
|
||||
wantModel: "claude-local",
|
||||
wantChange: true,
|
||||
},
|
||||
{
|
||||
name: "rewrites with empty body returns empty",
|
||||
body: []byte{},
|
||||
newModel: "gpt-4",
|
||||
wantModel: "",
|
||||
wantChange: false,
|
||||
},
|
||||
{
|
||||
name: "handles missing model field gracefully",
|
||||
body: []byte(`{"messages":[{"role":"user"}]}`),
|
||||
newModel: "gpt-4",
|
||||
wantModel: "",
|
||||
wantChange: false,
|
||||
},
|
||||
{
|
||||
name: "preserves other fields when rewriting",
|
||||
body: []byte(`{"model":"old-model","temperature":0.7,"max_tokens":100}`),
|
||||
newModel: "new-model",
|
||||
wantModel: "new-model",
|
||||
wantChange: true,
|
||||
},
|
||||
{
|
||||
name: "handles nested JSON structure",
|
||||
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}],"stream":true}`),
|
||||
newModel: "claude-3-opus",
|
||||
wantModel: "claude-3-opus",
|
||||
wantChange: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := rewriter.RewriteRequestBody(tt.body, tt.newModel)
|
||||
require.NoError(t, err)
|
||||
|
||||
if tt.wantChange {
|
||||
assert.NotEqual(t, string(tt.body), string(result), "body should have been modified")
|
||||
}
|
||||
|
||||
if tt.wantModel != "" {
|
||||
// Parse result and check model field
|
||||
model, _ := NewModelExtractor().Extract(result, nil)
|
||||
assert.Equal(t, tt.wantModel, model)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelRewriter_WrapResponseWriter(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
|
||||
t.Run("response writer wraps without error", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
require.NotNil(t, wrapped)
|
||||
require.NotNil(t, cleanup)
|
||||
defer cleanup()
|
||||
})
|
||||
|
||||
t.Run("rewrites model in non-streaming response", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
// Write a response with the resolved model
|
||||
response := []byte(`{"model":"claude-local","content":"hello"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
_, err := wrapped.Write(response)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Cleanup triggers the rewrite
|
||||
cleanup()
|
||||
|
||||
// Check the response was rewritten to the requested model
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||
})
|
||||
|
||||
t.Run("no-op when requested equals resolved", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "gpt-4")
|
||||
|
||||
response := []byte(`{"model":"gpt-4","content":"hello"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
_, err := wrapped.Write(response)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup()
|
||||
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||
})
|
||||
|
||||
t.Run("rewrites modelVersion field", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
response := []byte(`{"modelVersion":"claude-local","content":"hello"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
_, err := wrapped.Write(response)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup()
|
||||
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Contains(t, string(body), `"modelVersion":"gpt-4"`)
|
||||
})
|
||||
|
||||
t.Run("handles streaming responses", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
// Set streaming content type
|
||||
wrapped.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
// Write SSE chunks with resolved model
|
||||
chunk1 := []byte("data: {\"model\":\"claude-local\",\"delta\":\"hello\"}\n\n")
|
||||
_, err := wrapped.Write(chunk1)
|
||||
require.NoError(t, err)
|
||||
|
||||
chunk2 := []byte("data: {\"model\":\"claude-local\",\"delta\":\" world\"}\n\n")
|
||||
_, err = wrapped.Write(chunk2)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup()
|
||||
|
||||
// For streaming, data is written immediately with rewrites
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||
})
|
||||
|
||||
t.Run("empty body handled gracefully", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
// Don't write anything
|
||||
|
||||
cleanup()
|
||||
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Empty(t, body)
|
||||
})
|
||||
|
||||
t.Run("preserves other JSON fields", func(t *testing.T) {
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
response := []byte(`{"model":"claude-local","temperature":0.7,"usage":{"prompt_tokens":10}}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
_, err := wrapped.Write(response)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup()
|
||||
|
||||
body := recorder.Body.Bytes()
|
||||
assert.Contains(t, string(body), `"temperature":0.7`)
|
||||
assert.Contains(t, string(body), `"prompt_tokens":10`)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResponseRewriter_ImplementsInterfaces(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
defer cleanup()
|
||||
|
||||
// Should implement http.ResponseWriter
|
||||
assert.Implements(t, (*http.ResponseWriter)(nil), wrapped)
|
||||
|
||||
// Should preserve header access
|
||||
wrapped.Header().Set("X-Custom", "value")
|
||||
assert.Equal(t, "value", recorder.Header().Get("X-Custom"))
|
||||
|
||||
// Should write status
|
||||
wrapped.WriteHeader(http.StatusCreated)
|
||||
assert.Equal(t, http.StatusCreated, recorder.Code)
|
||||
}
|
||||
|
||||
func TestResponseRewriter_Flush(t *testing.T) {
|
||||
t.Run("flush writes buffered content", func(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
response := []byte(`{"model":"claude-local","content":"test"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
wrapped.Write(response)
|
||||
|
||||
// Before cleanup, response should be empty (buffered)
|
||||
assert.Empty(t, recorder.Body.Bytes())
|
||||
|
||||
// After cleanup, response should be written
|
||||
cleanup()
|
||||
assert.NotEmpty(t, recorder.Body.Bytes())
|
||||
})
|
||||
|
||||
t.Run("multiple flush calls are safe", func(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
response := []byte(`{"model":"claude-local"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
wrapped.Write(response)
|
||||
|
||||
// First cleanup
|
||||
cleanup()
|
||||
firstBody := recorder.Body.Bytes()
|
||||
|
||||
// Second cleanup should not write again
|
||||
cleanup()
|
||||
secondBody := recorder.Body.Bytes()
|
||||
|
||||
assert.Equal(t, firstBody, secondBody)
|
||||
})
|
||||
}
|
||||
|
||||
func TestResponseRewriter_StreamingWithDataLines(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
wrapped.Header().Set("Content-Type", "text/event-stream")
|
||||
|
||||
// SSE format with multiple data lines
|
||||
chunk := []byte("data: {\"model\":\"claude-local\"}\n\ndata: {\"model\":\"claude-local\",\"done\":true}\n\n")
|
||||
wrapped.Write(chunk)
|
||||
|
||||
cleanup()
|
||||
|
||||
body := recorder.Body.Bytes()
|
||||
// Both data lines should have model rewritten
|
||||
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||
}
|
||||
|
||||
func TestModelRewriter_RoundTrip(t *testing.T) {
|
||||
// Simulate a full request -> response cycle with model rewriting
|
||||
rewriter := NewModelRewriter()
|
||||
|
||||
// Step 1: Rewrite request body
|
||||
originalRequest := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`)
|
||||
rewrittenRequest, err := rewriter.RewriteRequestBody(originalRequest, "claude-local")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify request was rewritten
|
||||
extractor := NewModelExtractor()
|
||||
requestModel, _ := extractor.Extract(rewrittenRequest, nil)
|
||||
assert.Equal(t, "claude-local", requestModel)
|
||||
|
||||
// Step 2: Simulate response with resolved model
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
response := []byte(`{"model":"claude-local","content":"Hello! How can I help?"}`)
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
wrapped.Write(response)
|
||||
cleanup()
|
||||
|
||||
// Verify response was rewritten back
|
||||
body, _ := io.ReadAll(recorder.Result().Body)
|
||||
responseModel, _ := extractor.Extract(body, nil)
|
||||
assert.Equal(t, "gpt-4", responseModel)
|
||||
}
|
||||
|
||||
func TestModelRewriter_NonJSONBody(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
|
||||
// Binary/non-JSON body should be returned unchanged
|
||||
body := []byte{0x00, 0x01, 0x02, 0x03}
|
||||
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestModelRewriter_InvalidJSON(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
|
||||
// Invalid JSON without model field should be returned unchanged
|
||||
body := []byte(`not valid json`)
|
||||
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, body, result)
|
||||
}
|
||||
|
||||
func TestResponseRewriter_StatusCodePreserved(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
wrapped.WriteHeader(http.StatusAccepted)
|
||||
wrapped.Write([]byte(`{"model":"claude-local"}`))
|
||||
cleanup()
|
||||
|
||||
assert.Equal(t, http.StatusAccepted, recorder.Code)
|
||||
}
|
||||
|
||||
func TestResponseRewriter_HeaderFlushed(t *testing.T) {
|
||||
rewriter := NewModelRewriter()
|
||||
recorder := httptest.NewRecorder()
|
||||
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||
|
||||
wrapped.Header().Set("Content-Type", "application/json")
|
||||
wrapped.Header().Set("X-Request-ID", "abc123")
|
||||
wrapped.Write([]byte(`{"model":"claude-local"}`))
|
||||
cleanup()
|
||||
|
||||
result := recorder.Result()
|
||||
assert.Equal(t, "application/json", result.Header.Get("Content-Type"))
|
||||
assert.Equal(t, "abc123", result.Header.Get("X-Request-ID"))
|
||||
}
|
||||
317
internal/routing/router.go
Normal file
317
internal/routing/router.go
Normal file
@@ -0,0 +1,317 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// Router resolves models to provider candidates.
|
||||
type Router struct {
|
||||
registry *Registry
|
||||
modelMappings map[string]string // normalized from -> to
|
||||
oauthAliases map[string][]string // normalized model -> []alias
|
||||
}
|
||||
|
||||
// NewRouter creates a new router with the given configuration.
|
||||
func NewRouter(registry *Registry, cfg *config.Config) *Router {
|
||||
r := &Router{
|
||||
registry: registry,
|
||||
modelMappings: make(map[string]string),
|
||||
oauthAliases: make(map[string][]string),
|
||||
}
|
||||
|
||||
if cfg != nil {
|
||||
r.loadModelMappings(cfg.AmpCode.ModelMappings)
|
||||
r.loadOAuthAliases(cfg.OAuthModelAlias)
|
||||
}
|
||||
|
||||
return r
|
||||
}
|
||||
|
||||
// LegacyRoutingDecision contains the resolved routing information.
|
||||
// Deprecated: Will be replaced by RoutingDecision from types.go in T-013.
|
||||
type LegacyRoutingDecision struct {
|
||||
RequestedModel string // Original model from request
|
||||
ResolvedModel string // After model-mappings
|
||||
Candidates []ProviderCandidate // Ordered list of providers to try
|
||||
}
|
||||
|
||||
// Resolve determines the routing decision for the requested model.
|
||||
// Deprecated: Will be updated to use RoutingRequest and return *RoutingDecision in T-013.
|
||||
func (r *Router) Resolve(requestedModel string) *LegacyRoutingDecision {
|
||||
// 1. Extract thinking suffix
|
||||
suffixResult := thinking.ParseSuffix(requestedModel)
|
||||
baseModel := suffixResult.ModelName
|
||||
|
||||
// 2. Apply model-mappings
|
||||
targetModel := r.applyMappings(baseModel)
|
||||
|
||||
// 3. Find primary providers
|
||||
candidates := r.findCandidates(targetModel, suffixResult)
|
||||
|
||||
// 4. Add fallback aliases
|
||||
for _, alias := range r.oauthAliases[strings.ToLower(targetModel)] {
|
||||
candidates = append(candidates, r.findCandidates(alias, suffixResult)...)
|
||||
}
|
||||
|
||||
// 5. Sort by priority
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
|
||||
})
|
||||
|
||||
return &LegacyRoutingDecision{
|
||||
RequestedModel: requestedModel,
|
||||
ResolvedModel: targetModel,
|
||||
Candidates: candidates,
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveV2 determines the routing decision for a routing request.
|
||||
// It uses the new RoutingRequest and RoutingDecision types.
|
||||
func (r *Router) ResolveV2(req RoutingRequest) *RoutingDecision {
|
||||
// 1. Extract thinking suffix
|
||||
suffixResult := thinking.ParseSuffix(req.RequestedModel)
|
||||
baseModel := suffixResult.ModelName
|
||||
thinkingSuffix := ""
|
||||
if suffixResult.HasSuffix {
|
||||
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
// 2. Check for local providers
|
||||
localCandidates := r.findLocalCandidates(baseModel, suffixResult)
|
||||
|
||||
// 3. Apply model-mappings if needed
|
||||
mappedModel := r.applyMappings(baseModel)
|
||||
mappingCandidates := r.findLocalCandidates(mappedModel, suffixResult)
|
||||
|
||||
// 4. Determine route type based on preferences and availability
|
||||
var decision *RoutingDecision
|
||||
|
||||
if req.ForceModelMapping && mappedModel != baseModel && len(mappingCandidates) > 0 {
|
||||
// FORCE MODE: Use mapping even if local provider exists
|
||||
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
|
||||
} else if req.PreferLocalProvider && len(localCandidates) > 0 {
|
||||
// DEFAULT MODE with local preference: Use local provider first
|
||||
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
|
||||
} else if len(localCandidates) > 0 {
|
||||
// DEFAULT MODE: Local provider available
|
||||
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
|
||||
} else if mappedModel != baseModel && len(mappingCandidates) > 0 {
|
||||
// DEFAULT MODE: No local provider, but mapping available
|
||||
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
|
||||
} else {
|
||||
// No local provider, no mapping - use amp credits proxy
|
||||
decision = &RoutingDecision{
|
||||
RouteType: RouteTypeAmpCredits,
|
||||
ResolvedModel: req.RequestedModel,
|
||||
ShouldProxy: true,
|
||||
}
|
||||
}
|
||||
|
||||
return decision
|
||||
}
|
||||
|
||||
// findLocalCandidates finds local provider candidates for a model.
|
||||
// If the internal registry is empty, it falls back to the global model registry.
|
||||
func (r *Router) findLocalCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate {
|
||||
var candidates []ProviderCandidate
|
||||
|
||||
// Check internal registry first
|
||||
registryProviders := r.registry.All()
|
||||
if len(registryProviders) > 0 {
|
||||
for _, p := range registryProviders {
|
||||
if !p.SupportsModel(model) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply thinking suffix if needed
|
||||
actualModel := model
|
||||
if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix {
|
||||
actualModel = model + "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
if p.Available(actualModel) {
|
||||
candidates = append(candidates, ProviderCandidate{
|
||||
Provider: p,
|
||||
Model: actualModel,
|
||||
})
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// Fallback to global model registry (same logic as FallbackHandler)
|
||||
// This ensures compatibility when the wrapper is initialized with an empty registry
|
||||
providers := registry.GetGlobalRegistry().GetModelProviders(model)
|
||||
if len(providers) > 0 {
|
||||
actualModel := model
|
||||
if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix {
|
||||
actualModel = model + "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
// Create a synthetic provider candidate for each provider
|
||||
for _, providerName := range providers {
|
||||
candidates = append(candidates, ProviderCandidate{
|
||||
Provider: &globalRegistryProvider{name: providerName, model: actualModel},
|
||||
Model: actualModel,
|
||||
})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Sort by priority
|
||||
sort.Slice(candidates, func(i, j int) bool {
|
||||
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
|
||||
})
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// globalRegistryProvider is a synthetic Provider implementation that wraps
|
||||
// a provider name from the global model registry. It is used only for routing
|
||||
// decisions when the internal registry is empty - actual execution goes through
|
||||
// the normal handler path, not through this provider's Execute methods.
|
||||
type globalRegistryProvider struct {
|
||||
name string
|
||||
model string
|
||||
}
|
||||
|
||||
func (p *globalRegistryProvider) Name() string { return p.name }
|
||||
func (p *globalRegistryProvider) Type() ProviderType { return ProviderTypeOAuth }
|
||||
func (p *globalRegistryProvider) Priority() int { return 0 }
|
||||
func (p *globalRegistryProvider) SupportsModel(string) bool { return true }
|
||||
func (p *globalRegistryProvider) Available(string) bool { return true }
|
||||
|
||||
// Execute is not used for globalRegistryProvider - routing wrapper calls the handler directly.
|
||||
func (p *globalRegistryProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
|
||||
return executor.Response{}, nil
|
||||
}
|
||||
|
||||
// ExecuteStream is not used for globalRegistryProvider - routing wrapper calls the handler directly.
|
||||
func (p *globalRegistryProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
// buildLocalProviderDecision creates a decision for local provider routing.
|
||||
func (r *Router) buildLocalProviderDecision(requestedModel string, candidates []ProviderCandidate, thinkingSuffix string) *RoutingDecision {
|
||||
resolvedModel := requestedModel
|
||||
if thinkingSuffix != "" {
|
||||
// Ensure thinking suffix is preserved
|
||||
sr := thinking.ParseSuffix(requestedModel)
|
||||
if !sr.HasSuffix {
|
||||
resolvedModel = requestedModel + thinkingSuffix
|
||||
}
|
||||
}
|
||||
|
||||
var fallbackModels []string
|
||||
if len(candidates) > 1 {
|
||||
for _, c := range candidates[1:] {
|
||||
fallbackModels = append(fallbackModels, c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
return &RoutingDecision{
|
||||
RouteType: RouteTypeLocalProvider,
|
||||
ResolvedModel: resolvedModel,
|
||||
ProviderName: candidates[0].Provider.Name(),
|
||||
FallbackModels: fallbackModels,
|
||||
ShouldProxy: false,
|
||||
}
|
||||
}
|
||||
|
||||
// buildMappingDecision creates a decision for model mapping routing.
|
||||
func (r *Router) buildMappingDecision(requestedModel, mappedModel string, candidates []ProviderCandidate, thinkingSuffix string, fallbackCandidates []ProviderCandidate) *RoutingDecision {
|
||||
// Apply thinking suffix to resolved model if needed
|
||||
resolvedModel := mappedModel
|
||||
if thinkingSuffix != "" {
|
||||
sr := thinking.ParseSuffix(mappedModel)
|
||||
if !sr.HasSuffix {
|
||||
resolvedModel = mappedModel + thinkingSuffix
|
||||
}
|
||||
}
|
||||
|
||||
var fallbackModels []string
|
||||
for _, c := range fallbackCandidates {
|
||||
fallbackModels = append(fallbackModels, c.Model)
|
||||
}
|
||||
|
||||
// Also add oauth aliases as fallbacks
|
||||
baseMapped := thinking.ParseSuffix(mappedModel).ModelName
|
||||
for _, alias := range r.oauthAliases[strings.ToLower(baseMapped)] {
|
||||
// Check if this alias has providers
|
||||
aliasCandidates := r.findLocalCandidates(alias, thinking.SuffixResult{ModelName: alias})
|
||||
for _, c := range aliasCandidates {
|
||||
fallbackModels = append(fallbackModels, c.Model)
|
||||
}
|
||||
}
|
||||
|
||||
return &RoutingDecision{
|
||||
RouteType: RouteTypeModelMapping,
|
||||
ResolvedModel: resolvedModel,
|
||||
ProviderName: candidates[0].Provider.Name(),
|
||||
FallbackModels: fallbackModels,
|
||||
ShouldProxy: false,
|
||||
}
|
||||
}
|
||||
|
||||
// applyMappings applies model-mappings configuration.
|
||||
func (r *Router) applyMappings(model string) string {
|
||||
key := strings.ToLower(strings.TrimSpace(model))
|
||||
if mapped, ok := r.modelMappings[key]; ok {
|
||||
return mapped
|
||||
}
|
||||
return model
|
||||
}
|
||||
|
||||
// findCandidates finds all provider candidates for a model.
|
||||
func (r *Router) findCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate {
|
||||
var candidates []ProviderCandidate
|
||||
|
||||
for _, p := range r.registry.All() {
|
||||
if !p.SupportsModel(model) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Apply thinking suffix if needed
|
||||
actualModel := model
|
||||
if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix {
|
||||
actualModel = model + "(" + suffixResult.RawSuffix + ")"
|
||||
}
|
||||
|
||||
if p.Available(actualModel) {
|
||||
candidates = append(candidates, ProviderCandidate{
|
||||
Provider: p,
|
||||
Model: actualModel,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
return candidates
|
||||
}
|
||||
|
||||
// loadModelMappings loads model-mappings from config.
|
||||
func (r *Router) loadModelMappings(mappings []config.AmpModelMapping) {
|
||||
for _, m := range mappings {
|
||||
from := strings.ToLower(strings.TrimSpace(m.From))
|
||||
to := strings.TrimSpace(m.To)
|
||||
if from != "" && to != "" {
|
||||
r.modelMappings[from] = to
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// loadOAuthAliases loads oauth-model-alias from config.
|
||||
func (r *Router) loadOAuthAliases(aliases map[string][]config.OAuthModelAlias) {
|
||||
for _, entries := range aliases {
|
||||
for _, entry := range entries {
|
||||
name := strings.ToLower(strings.TrimSpace(entry.Name))
|
||||
alias := strings.TrimSpace(entry.Alias)
|
||||
if name != "" && alias != "" && name != alias {
|
||||
r.oauthAliases[name] = append(r.oauthAliases[name], alias)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
202
internal/routing/router_test.go
Normal file
202
internal/routing/router_test.go
Normal file
@@ -0,0 +1,202 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
globalRegistry "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// mockProvider is a test double for Provider.
|
||||
type mockProvider struct {
|
||||
name string
|
||||
providerType ProviderType
|
||||
supportsModels map[string]bool
|
||||
available bool
|
||||
priority int
|
||||
}
|
||||
|
||||
func (m *mockProvider) Name() string { return m.name }
|
||||
func (m *mockProvider) Type() ProviderType { return m.providerType }
|
||||
func (m *mockProvider) SupportsModel(model string) bool { return m.supportsModels[model] }
|
||||
func (m *mockProvider) Available(model string) bool { return m.available }
|
||||
func (m *mockProvider) Priority() int { return m.priority }
|
||||
func (m *mockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
|
||||
return executor.Response{}, nil
|
||||
}
|
||||
func (m *mockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func TestRouter_Resolve_ModelMappings(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
|
||||
// Add a provider
|
||||
p := &mockProvider{
|
||||
name: "test-provider",
|
||||
providerType: ProviderTypeOAuth,
|
||||
supportsModels: map[string]bool{"target-model": true},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
registry.Register(p)
|
||||
|
||||
// Create router with model mapping
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "user-model", To: "target-model"},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Resolve
|
||||
decision := router.Resolve("user-model")
|
||||
|
||||
assert.Equal(t, "user-model", decision.RequestedModel)
|
||||
assert.Equal(t, "target-model", decision.ResolvedModel)
|
||||
assert.Len(t, decision.Candidates, 1)
|
||||
assert.Equal(t, "target-model", decision.Candidates[0].Model)
|
||||
}
|
||||
|
||||
func TestRouter_Resolve_OAuthAliases(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
|
||||
// Add providers
|
||||
p1 := &mockProvider{
|
||||
name: "oauth-1",
|
||||
providerType: ProviderTypeOAuth,
|
||||
supportsModels: map[string]bool{"primary-model": true},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
p2 := &mockProvider{
|
||||
name: "oauth-2",
|
||||
providerType: ProviderTypeOAuth,
|
||||
supportsModels: map[string]bool{"fallback-model": true},
|
||||
available: true,
|
||||
priority: 2,
|
||||
}
|
||||
registry.Register(p1)
|
||||
registry.Register(p2)
|
||||
|
||||
// Create router with oauth aliases
|
||||
cfg := &config.Config{
|
||||
OAuthModelAlias: map[string][]config.OAuthModelAlias{
|
||||
"test-channel": {
|
||||
{Name: "primary-model", Alias: "fallback-model"},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Resolve
|
||||
decision := router.Resolve("primary-model")
|
||||
|
||||
assert.Equal(t, "primary-model", decision.ResolvedModel)
|
||||
assert.Len(t, decision.Candidates, 2)
|
||||
// Primary should come first (lower priority value)
|
||||
assert.Equal(t, "primary-model", decision.Candidates[0].Model)
|
||||
assert.Equal(t, "fallback-model", decision.Candidates[1].Model)
|
||||
}
|
||||
|
||||
func TestRouter_Resolve_NoProviders(t *testing.T) {
|
||||
registry := NewRegistry()
|
||||
cfg := &config.Config{}
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
decision := router.Resolve("unknown-model")
|
||||
|
||||
assert.Equal(t, "unknown-model", decision.ResolvedModel)
|
||||
assert.Empty(t, decision.Candidates)
|
||||
}
|
||||
|
||||
// === Global Registry Fallback Tests (T-027) ===
|
||||
// These tests verify that when the internal registry is empty,
|
||||
// the router falls back to the global model registry.
|
||||
// This is the core fix for the thinking signature 400 error.
|
||||
|
||||
func TestRouter_GlobalRegistryFallback_LocalProvider(t *testing.T) {
|
||||
// This test requires registering a model in the global registry.
|
||||
// We use a model that's already registered via api-key config in production.
|
||||
// For isolated testing, we can skip if global registry is not populated.
|
||||
|
||||
globalReg := globalRegistry.GetGlobalRegistry()
|
||||
modelCount := globalReg.GetModelCount("claude-sonnet-4-20250514")
|
||||
|
||||
if modelCount == 0 {
|
||||
t.Skip("Global registry not populated - run with server context")
|
||||
}
|
||||
|
||||
// Empty internal registry
|
||||
emptyRegistry := NewRegistry()
|
||||
cfg := &config.Config{}
|
||||
router := NewRouter(emptyRegistry, cfg)
|
||||
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "claude-sonnet-4-20250514",
|
||||
PreferLocalProvider: true,
|
||||
}
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Should find provider from global registry
|
||||
assert.Equal(t, RouteTypeLocalProvider, decision.RouteType)
|
||||
assert.Equal(t, "claude-sonnet-4-20250514", decision.ResolvedModel)
|
||||
assert.False(t, decision.ShouldProxy)
|
||||
}
|
||||
|
||||
func TestRouter_GlobalRegistryFallback_ModelMapping(t *testing.T) {
|
||||
// This test verifies that model mapping works with global registry fallback.
|
||||
|
||||
globalReg := globalRegistry.GetGlobalRegistry()
|
||||
modelCount := globalReg.GetModelCount("claude-opus-4-5-thinking")
|
||||
|
||||
if modelCount == 0 {
|
||||
t.Skip("Global registry not populated - run with server context")
|
||||
}
|
||||
|
||||
// Empty internal registry
|
||||
emptyRegistry := NewRegistry()
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "claude-opus-4-5-20251101", To: "claude-opus-4-5-thinking"},
|
||||
},
|
||||
},
|
||||
}
|
||||
router := NewRouter(emptyRegistry, cfg)
|
||||
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "claude-opus-4-5-20251101",
|
||||
PreferLocalProvider: true,
|
||||
}
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Should find mapped model from global registry
|
||||
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||
assert.Equal(t, "claude-opus-4-5-thinking", decision.ResolvedModel)
|
||||
assert.False(t, decision.ShouldProxy)
|
||||
}
|
||||
|
||||
func TestRouter_GlobalRegistryFallback_AmpCreditsWhenNotFound(t *testing.T) {
|
||||
// Empty internal registry
|
||||
emptyRegistry := NewRegistry()
|
||||
cfg := &config.Config{}
|
||||
router := NewRouter(emptyRegistry, cfg)
|
||||
|
||||
// Use a model that definitely doesn't exist anywhere
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "nonexistent-model-12345",
|
||||
PreferLocalProvider: true,
|
||||
}
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Should fall back to AMP credits proxy
|
||||
assert.Equal(t, RouteTypeAmpCredits, decision.RouteType)
|
||||
assert.Equal(t, "nonexistent-model-12345", decision.ResolvedModel)
|
||||
assert.True(t, decision.ShouldProxy)
|
||||
}
|
||||
245
internal/routing/router_v2_test.go
Normal file
245
internal/routing/router_v2_test.go
Normal file
@@ -0,0 +1,245 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestRouter_DefaultMode_PrefersLocal(t *testing.T) {
|
||||
// Setup: Create a router with a mock provider that supports "gpt-4"
|
||||
registry := NewRegistry()
|
||||
mockProvider := &MockProvider{
|
||||
name: "openai",
|
||||
supportedModels: []string{"gpt-4"},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
registry.Register(mockProvider)
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "gpt-4", To: "claude-local"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Test: Request gpt-4 when local provider exists
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "gpt-4",
|
||||
PreferLocalProvider: true,
|
||||
ForceModelMapping: false,
|
||||
}
|
||||
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Assert: Should return LOCAL_PROVIDER, not MODEL_MAPPING
|
||||
assert.Equal(t, RouteTypeLocalProvider, decision.RouteType)
|
||||
assert.Equal(t, "gpt-4", decision.ResolvedModel)
|
||||
assert.Equal(t, "openai", decision.ProviderName)
|
||||
assert.False(t, decision.ShouldProxy)
|
||||
}
|
||||
|
||||
func TestRouter_DefaultMode_MapsWhenNoLocal(t *testing.T) {
|
||||
// Setup: Create a router with NO provider for "gpt-4" but a mapping to "claude-local"
|
||||
// which has a provider
|
||||
registry := NewRegistry()
|
||||
mockProvider := &MockProvider{
|
||||
name: "anthropic",
|
||||
supportedModels: []string{"claude-local"},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
registry.Register(mockProvider)
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "gpt-4", To: "claude-local"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Test: Request gpt-4 when no local provider exists, but mapping exists
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "gpt-4",
|
||||
PreferLocalProvider: true,
|
||||
ForceModelMapping: false,
|
||||
}
|
||||
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Assert: Should return MODEL_MAPPING
|
||||
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||
assert.Equal(t, "claude-local", decision.ResolvedModel)
|
||||
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||
assert.False(t, decision.ShouldProxy)
|
||||
}
|
||||
|
||||
func TestRouter_DefaultMode_AmpCreditsWhenNoLocalOrMapping(t *testing.T) {
|
||||
// Setup: Create a router with no providers and no mappings
|
||||
registry := NewRegistry()
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{},
|
||||
},
|
||||
}
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Test: Request a model with no local provider and no mapping
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "unknown-model",
|
||||
PreferLocalProvider: true,
|
||||
ForceModelMapping: false,
|
||||
}
|
||||
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Assert: Should return AMP_CREDITS with ShouldProxy=true
|
||||
assert.Equal(t, RouteTypeAmpCredits, decision.RouteType)
|
||||
assert.Equal(t, "unknown-model", decision.ResolvedModel)
|
||||
assert.True(t, decision.ShouldProxy)
|
||||
assert.Empty(t, decision.ProviderName)
|
||||
}
|
||||
|
||||
func TestRouter_ForceMode_MapsEvenWithLocal(t *testing.T) {
|
||||
// Setup: Create a router with BOTH a local provider for "gpt-4" AND a mapping from "gpt-4" to "claude-local"
|
||||
// The mapping target "claude-local" also has a provider
|
||||
registry := NewRegistry()
|
||||
|
||||
// Local provider for gpt-4
|
||||
openaiProvider := &MockProvider{
|
||||
name: "openai",
|
||||
supportedModels: []string{"gpt-4"},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
registry.Register(openaiProvider)
|
||||
|
||||
// Local provider for the mapped model
|
||||
anthropicProvider := &MockProvider{
|
||||
name: "anthropic",
|
||||
supportedModels: []string{"claude-local"},
|
||||
available: true,
|
||||
priority: 2,
|
||||
}
|
||||
registry.Register(anthropicProvider)
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "gpt-4", To: "claude-local"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Test: Request gpt-4 with ForceModelMapping=true
|
||||
// Even though gpt-4 has a local provider, mapping should take precedence
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "gpt-4",
|
||||
PreferLocalProvider: false,
|
||||
ForceModelMapping: true,
|
||||
}
|
||||
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Assert: Should return MODEL_MAPPING, not LOCAL_PROVIDER
|
||||
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||
assert.Equal(t, "claude-local", decision.ResolvedModel)
|
||||
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||
assert.False(t, decision.ShouldProxy)
|
||||
}
|
||||
|
||||
func TestRouter_ThinkingSuffix_Preserved(t *testing.T) {
|
||||
// Setup: Create a router with mapping and provider for mapped model
|
||||
registry := NewRegistry()
|
||||
|
||||
mockProvider := &MockProvider{
|
||||
name: "anthropic",
|
||||
supportedModels: []string{"claude-local"},
|
||||
available: true,
|
||||
priority: 1,
|
||||
}
|
||||
registry.Register(mockProvider)
|
||||
|
||||
cfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "claude-3-5-sonnet", To: "claude-local"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
router := NewRouter(registry, cfg)
|
||||
|
||||
// Test: Request claude-3-5-sonnet with thinking suffix
|
||||
req := RoutingRequest{
|
||||
RequestedModel: "claude-3-5-sonnet(thinking:foo)",
|
||||
PreferLocalProvider: true,
|
||||
ForceModelMapping: false,
|
||||
}
|
||||
|
||||
decision := router.ResolveV2(req)
|
||||
|
||||
// Assert: Thinking suffix should be preserved in resolved model
|
||||
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||
assert.Equal(t, "claude-local(thinking:foo)", decision.ResolvedModel)
|
||||
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||
}
|
||||
|
||||
// MockProvider is a mock implementation of Provider for testing
|
||||
type MockProvider struct {
|
||||
name string
|
||||
providerType ProviderType
|
||||
supportedModels []string
|
||||
available bool
|
||||
priority int
|
||||
}
|
||||
|
||||
func (m *MockProvider) Name() string {
|
||||
return m.name
|
||||
}
|
||||
|
||||
func (m *MockProvider) Type() ProviderType {
|
||||
if m.providerType == "" {
|
||||
return ProviderTypeOAuth
|
||||
}
|
||||
return m.providerType
|
||||
}
|
||||
|
||||
func (m *MockProvider) SupportsModel(model string) bool {
|
||||
for _, supported := range m.supportedModels {
|
||||
if supported == model {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *MockProvider) Available(model string) bool {
|
||||
return m.available
|
||||
}
|
||||
|
||||
func (m *MockProvider) Priority() int {
|
||||
return m.priority
|
||||
}
|
||||
|
||||
func (m *MockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
|
||||
return executor.Response{}, nil
|
||||
}
|
||||
|
||||
func (m *MockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
|
||||
return nil, nil
|
||||
}
|
||||
113
internal/routing/testutil/fake_handler.go
Normal file
113
internal/routing/testutil/fake_handler.go
Normal file
@@ -0,0 +1,113 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// FakeHandlerRecorder records handler invocations for testing.
|
||||
type FakeHandlerRecorder struct {
|
||||
Called bool
|
||||
CallCount int
|
||||
RequestBody []byte
|
||||
RequestHeader http.Header
|
||||
ContextKeys map[string]interface{}
|
||||
ResponseStatus int
|
||||
ResponseBody []byte
|
||||
}
|
||||
|
||||
// NewFakeHandlerRecorder creates a new fake handler recorder.
|
||||
func NewFakeHandlerRecorder() *FakeHandlerRecorder {
|
||||
return &FakeHandlerRecorder{
|
||||
ContextKeys: make(map[string]interface{}),
|
||||
ResponseStatus: http.StatusOK,
|
||||
ResponseBody: []byte(`{"status":"handled"}`),
|
||||
}
|
||||
}
|
||||
|
||||
// GinHandler returns a gin.HandlerFunc that records the invocation.
|
||||
func (f *FakeHandlerRecorder) GinHandler() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
f.record(c)
|
||||
c.Data(f.ResponseStatus, "application/json", f.ResponseBody)
|
||||
}
|
||||
}
|
||||
|
||||
// GinHandlerWithModel returns a gin.HandlerFunc that records the invocation and returns the model from context.
|
||||
// Useful for testing response rewriting in model mapping scenarios.
|
||||
func (f *FakeHandlerRecorder) GinHandlerWithModel() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
f.record(c)
|
||||
// Return a response with the model field that would be in the actual API response
|
||||
// If ResponseBody was explicitly set (not default), use that; otherwise generate from context
|
||||
var body []byte
|
||||
if mappedModel, exists := c.Get("mapped_model"); exists {
|
||||
body = []byte(`{"model":"` + mappedModel.(string) + `","status":"handled"}`)
|
||||
} else {
|
||||
body = f.ResponseBody
|
||||
}
|
||||
c.Data(f.ResponseStatus, "application/json", body)
|
||||
}
|
||||
}
|
||||
|
||||
// HTTPHandler returns an http.HandlerFunc that records the invocation.
|
||||
func (f *FakeHandlerRecorder) HTTPHandler() http.HandlerFunc {
|
||||
return func(w http.ResponseWriter, r *http.Request) {
|
||||
body, _ := io.ReadAll(r.Body)
|
||||
f.Called = true
|
||||
f.CallCount++
|
||||
f.RequestBody = body
|
||||
f.RequestHeader = r.Header.Clone()
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
w.WriteHeader(f.ResponseStatus)
|
||||
w.Write(f.ResponseBody)
|
||||
}
|
||||
}
|
||||
|
||||
// record captures the request details from gin context.
|
||||
func (f *FakeHandlerRecorder) record(c *gin.Context) {
|
||||
f.Called = true
|
||||
f.CallCount++
|
||||
|
||||
body, _ := io.ReadAll(c.Request.Body)
|
||||
f.RequestBody = body
|
||||
f.RequestHeader = c.Request.Header.Clone()
|
||||
|
||||
// Capture common context keys used by routing
|
||||
if val, exists := c.Get("mapped_model"); exists {
|
||||
f.ContextKeys["mapped_model"] = val
|
||||
}
|
||||
if val, exists := c.Get("fallback_models"); exists {
|
||||
f.ContextKeys["fallback_models"] = val
|
||||
}
|
||||
if val, exists := c.Get("route_type"); exists {
|
||||
f.ContextKeys["route_type"] = val
|
||||
}
|
||||
}
|
||||
|
||||
// Reset clears the recorder state.
|
||||
func (f *FakeHandlerRecorder) Reset() {
|
||||
f.Called = false
|
||||
f.CallCount = 0
|
||||
f.RequestBody = nil
|
||||
f.RequestHeader = nil
|
||||
f.ContextKeys = make(map[string]interface{})
|
||||
}
|
||||
|
||||
// GetContextKey returns a captured context key value.
|
||||
func (f *FakeHandlerRecorder) GetContextKey(key string) (interface{}, bool) {
|
||||
val, ok := f.ContextKeys[key]
|
||||
return val, ok
|
||||
}
|
||||
|
||||
// WasCalled returns true if the handler was called.
|
||||
func (f *FakeHandlerRecorder) WasCalled() bool {
|
||||
return f.Called
|
||||
}
|
||||
|
||||
// GetCallCount returns the number of times the handler was called.
|
||||
func (f *FakeHandlerRecorder) GetCallCount() int {
|
||||
return f.CallCount
|
||||
}
|
||||
83
internal/routing/testutil/fake_proxy.go
Normal file
83
internal/routing/testutil/fake_proxy.go
Normal file
@@ -0,0 +1,83 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
)
|
||||
|
||||
// CloseNotifierRecorder wraps httptest.ResponseRecorder with CloseNotify support.
|
||||
// This is needed because ReverseProxy requires http.CloseNotifier.
|
||||
type CloseNotifierRecorder struct {
|
||||
*httptest.ResponseRecorder
|
||||
closeChan chan bool
|
||||
}
|
||||
|
||||
// NewCloseNotifierRecorder creates a ResponseRecorder that implements CloseNotifier.
|
||||
func NewCloseNotifierRecorder() *CloseNotifierRecorder {
|
||||
return &CloseNotifierRecorder{
|
||||
ResponseRecorder: httptest.NewRecorder(),
|
||||
closeChan: make(chan bool, 1),
|
||||
}
|
||||
}
|
||||
|
||||
// CloseNotify implements http.CloseNotifier.
|
||||
func (c *CloseNotifierRecorder) CloseNotify() <-chan bool {
|
||||
return c.closeChan
|
||||
}
|
||||
|
||||
// FakeProxyRecorder records proxy invocations for testing.
|
||||
type FakeProxyRecorder struct {
|
||||
Called bool
|
||||
CallCount int
|
||||
RequestBody []byte
|
||||
RequestHeaders http.Header
|
||||
ResponseStatus int
|
||||
ResponseBody []byte
|
||||
}
|
||||
|
||||
// NewFakeProxyRecorder creates a new fake proxy recorder.
|
||||
func NewFakeProxyRecorder() *FakeProxyRecorder {
|
||||
return &FakeProxyRecorder{
|
||||
ResponseStatus: http.StatusOK,
|
||||
ResponseBody: []byte(`{"status":"proxied"}`),
|
||||
}
|
||||
}
|
||||
|
||||
// ServeHTTP implements http.Handler to act as a reverse proxy.
|
||||
func (f *FakeProxyRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||
f.Called = true
|
||||
f.CallCount++
|
||||
f.RequestHeaders = r.Header.Clone()
|
||||
|
||||
body, err := io.ReadAll(r.Body)
|
||||
if err == nil {
|
||||
f.RequestBody = body
|
||||
}
|
||||
|
||||
w.WriteHeader(f.ResponseStatus)
|
||||
w.Write(f.ResponseBody)
|
||||
}
|
||||
|
||||
// GetCallCount returns the number of times the proxy was called.
|
||||
func (f *FakeProxyRecorder) GetCallCount() int {
|
||||
return f.CallCount
|
||||
}
|
||||
|
||||
// Reset clears the recorder state.
|
||||
func (f *FakeProxyRecorder) Reset() {
|
||||
f.Called = false
|
||||
f.CallCount = 0
|
||||
f.RequestBody = nil
|
||||
f.RequestHeaders = nil
|
||||
}
|
||||
|
||||
// ToHandler returns the recorder as an http.Handler for use with httptest.
|
||||
func (f *FakeProxyRecorder) ToHandler() http.Handler {
|
||||
return http.HandlerFunc(f.ServeHTTP)
|
||||
}
|
||||
|
||||
// CreateTestServer creates an httptest server with this fake proxy.
|
||||
func (f *FakeProxyRecorder) CreateTestServer() *httptest.Server {
|
||||
return httptest.NewServer(f.ToHandler())
|
||||
}
|
||||
62
internal/routing/types.go
Normal file
62
internal/routing/types.go
Normal file
@@ -0,0 +1,62 @@
|
||||
package routing
|
||||
|
||||
// RouteType represents the type of routing decision made for a request.
|
||||
type RouteType string
|
||||
|
||||
const (
|
||||
// RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free).
|
||||
RouteTypeLocalProvider RouteType = "LOCAL_PROVIDER"
|
||||
// RouteTypeModelMapping indicates the request was remapped to another available model (free).
|
||||
RouteTypeModelMapping RouteType = "MODEL_MAPPING"
|
||||
// RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits).
|
||||
RouteTypeAmpCredits RouteType = "AMP_CREDITS"
|
||||
// RouteTypeNoProvider indicates no provider or fallback available.
|
||||
RouteTypeNoProvider RouteType = "NO_PROVIDER"
|
||||
)
|
||||
|
||||
// RoutingRequest contains the information needed to make a routing decision.
|
||||
type RoutingRequest struct {
|
||||
// RequestedModel is the model name from the incoming request.
|
||||
RequestedModel string
|
||||
// PreferLocalProvider indicates whether to prefer local providers over mappings.
|
||||
// When true, check local providers first before applying model mappings.
|
||||
PreferLocalProvider bool
|
||||
// ForceModelMapping indicates whether to force model mapping even if local provider exists.
|
||||
// When true, apply model mappings first and skip local provider checks.
|
||||
ForceModelMapping bool
|
||||
}
|
||||
|
||||
// RoutingDecision contains the result of a routing decision.
|
||||
type RoutingDecision struct {
|
||||
// RouteType indicates the type of routing decision.
|
||||
RouteType RouteType
|
||||
// ResolvedModel is the final model name after any mappings.
|
||||
ResolvedModel string
|
||||
// ProviderName is the name of the selected provider (if any).
|
||||
ProviderName string
|
||||
// FallbackModels is a list of alternative models to try if the primary fails.
|
||||
FallbackModels []string
|
||||
// ShouldProxy indicates whether the request should be proxied to ampcode.com.
|
||||
ShouldProxy bool
|
||||
}
|
||||
|
||||
// NewRoutingDecision creates a new RoutingDecision with the given parameters.
|
||||
func NewRoutingDecision(routeType RouteType, resolvedModel, providerName string, fallbackModels []string, shouldProxy bool) *RoutingDecision {
|
||||
return &RoutingDecision{
|
||||
RouteType: routeType,
|
||||
ResolvedModel: resolvedModel,
|
||||
ProviderName: providerName,
|
||||
FallbackModels: fallbackModels,
|
||||
ShouldProxy: shouldProxy,
|
||||
}
|
||||
}
|
||||
|
||||
// IsLocal returns true if the decision routes to a local provider.
|
||||
func (d *RoutingDecision) IsLocal() bool {
|
||||
return d.RouteType == RouteTypeLocalProvider || d.RouteType == RouteTypeModelMapping
|
||||
}
|
||||
|
||||
// HasFallbacks returns true if there are fallback models available.
|
||||
func (d *RoutingDecision) HasFallbacks() bool {
|
||||
return len(d.FallbackModels) > 0
|
||||
}
|
||||
270
internal/routing/wrapper.go
Normal file
270
internal/routing/wrapper.go
Normal file
@@ -0,0 +1,270 @@
|
||||
package routing
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"io"
|
||||
"net"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
|
||||
"github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ProxyFunc is the function type for proxying requests.
|
||||
type ProxyFunc func(c *gin.Context)
|
||||
|
||||
// ModelRoutingWrapper wraps HTTP handlers with unified model routing logic.
|
||||
// It replaces the FallbackHandler logic with a Router-based approach.
|
||||
type ModelRoutingWrapper struct {
|
||||
router *Router
|
||||
extractor ModelExtractor
|
||||
rewriter ModelRewriter
|
||||
proxyFunc ProxyFunc
|
||||
logger *logrus.Logger
|
||||
}
|
||||
|
||||
// NewModelRoutingWrapper creates a new ModelRoutingWrapper with the given dependencies.
|
||||
// If extractor is nil, a DefaultModelExtractor is used.
|
||||
// If rewriter is nil, a DefaultModelRewriter is used.
|
||||
// proxyFunc is called for AMP_CREDITS route type; if nil, the handler will be called instead.
|
||||
func NewModelRoutingWrapper(router *Router, extractor ModelExtractor, rewriter ModelRewriter, proxyFunc ProxyFunc) *ModelRoutingWrapper {
|
||||
if extractor == nil {
|
||||
extractor = NewModelExtractor()
|
||||
}
|
||||
if rewriter == nil {
|
||||
rewriter = NewModelRewriter()
|
||||
}
|
||||
return &ModelRoutingWrapper{
|
||||
router: router,
|
||||
extractor: extractor,
|
||||
rewriter: rewriter,
|
||||
proxyFunc: proxyFunc,
|
||||
logger: logrus.New(),
|
||||
}
|
||||
}
|
||||
|
||||
// SetLogger sets the logger for the wrapper.
|
||||
func (w *ModelRoutingWrapper) SetLogger(logger *logrus.Logger) {
|
||||
w.logger = logger
|
||||
}
|
||||
|
||||
// Wrap wraps a gin.HandlerFunc with model routing logic.
|
||||
// The returned handler will:
|
||||
// 1. Extract the model from the request
|
||||
// 2. Get a routing decision from the Router
|
||||
// 3. Handle the request according to the decision type (LOCAL_PROVIDER, MODEL_MAPPING, AMP_CREDITS)
|
||||
func (w *ModelRoutingWrapper) Wrap(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
// Read request body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
w.logger.Errorf("routing wrapper: failed to read request body: %v", err)
|
||||
handler(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Extract model from request
|
||||
ginParams := map[string]string{
|
||||
"action": c.Param("action"),
|
||||
"path": c.Param("path"),
|
||||
}
|
||||
modelName, err := w.extractor.Extract(bodyBytes, ginParams)
|
||||
if err != nil {
|
||||
w.logger.Warnf("routing wrapper: failed to extract model: %v", err)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
return
|
||||
}
|
||||
|
||||
if modelName == "" {
|
||||
// No model found, proceed with original handler
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
return
|
||||
}
|
||||
|
||||
// Get routing decision
|
||||
req := RoutingRequest{
|
||||
RequestedModel: modelName,
|
||||
PreferLocalProvider: true,
|
||||
ForceModelMapping: false, // TODO: Get from config
|
||||
}
|
||||
decision := w.router.ResolveV2(req)
|
||||
|
||||
// Store decision in context for downstream handlers
|
||||
c.Set(string(ctxkeys.RoutingDecision), decision)
|
||||
|
||||
// Handle based on route type
|
||||
switch decision.RouteType {
|
||||
case RouteTypeLocalProvider:
|
||||
w.handleLocalProvider(c, handler, bodyBytes, decision)
|
||||
case RouteTypeModelMapping:
|
||||
w.handleModelMapping(c, handler, bodyBytes, decision)
|
||||
case RouteTypeAmpCredits:
|
||||
w.handleAmpCredits(c, handler, bodyBytes)
|
||||
default:
|
||||
// No provider available
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleLocalProvider handles the LOCAL_PROVIDER route type.
|
||||
func (w *ModelRoutingWrapper) handleLocalProvider(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
|
||||
// Filter Anthropic-Beta header for local provider
|
||||
filterAnthropicBetaHeader(c)
|
||||
|
||||
// Restore body with original content
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
// Call handler
|
||||
handler(c)
|
||||
}
|
||||
|
||||
// handleModelMapping handles the MODEL_MAPPING route type.
|
||||
func (w *ModelRoutingWrapper) handleModelMapping(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
|
||||
// Rewrite request body with mapped model
|
||||
rewrittenBody, err := w.rewriter.RewriteRequestBody(bodyBytes, decision.ResolvedModel)
|
||||
if err != nil {
|
||||
w.logger.Warnf("routing wrapper: failed to rewrite request body: %v", err)
|
||||
rewrittenBody = bodyBytes
|
||||
}
|
||||
_ = rewrittenBody
|
||||
|
||||
// Store mapped model in context
|
||||
c.Set(string(ctxkeys.MappedModel), decision.ResolvedModel)
|
||||
|
||||
// Store fallback models in context if present
|
||||
if len(decision.FallbackModels) > 0 {
|
||||
c.Set(string(ctxkeys.FallbackModels), decision.FallbackModels)
|
||||
}
|
||||
|
||||
// Filter Anthropic-Beta header for local provider
|
||||
filterAnthropicBetaHeader(c)
|
||||
|
||||
// Restore body with rewritten content
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(rewrittenBody))
|
||||
|
||||
// Wrap response writer to rewrite model back
|
||||
wrappedWriter, cleanup := w.rewriter.WrapResponseWriter(c.Writer, decision.ResolvedModel, decision.ResolvedModel)
|
||||
c.Writer = &ginResponseWriterAdapter{ResponseWriter: wrappedWriter, original: c.Writer}
|
||||
|
||||
// Call handler
|
||||
handler(c)
|
||||
|
||||
// Cleanup (flush response rewriting)
|
||||
cleanup()
|
||||
}
|
||||
|
||||
// handleAmpCredits handles the AMP_CREDITS route type.
|
||||
// It calls the proxy function directly if available, otherwise passes to handler.
|
||||
// Does NOT filter headers or rewrite body - proxy handles everything.
|
||||
func (w *ModelRoutingWrapper) handleAmpCredits(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte) {
|
||||
// Restore body with original content (no rewriting for proxy)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
|
||||
// Call proxy function if available, otherwise fall back to handler
|
||||
if w.proxyFunc != nil {
|
||||
w.proxyFunc(c)
|
||||
} else {
|
||||
handler(c)
|
||||
}
|
||||
}
|
||||
|
||||
// filterAnthropicBetaHeader filters Anthropic-Beta header for local providers.
|
||||
func filterAnthropicBetaHeader(c *gin.Context) {
|
||||
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
||||
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
|
||||
if filtered != "" {
|
||||
c.Request.Header.Set("Anthropic-Beta", filtered)
|
||||
} else {
|
||||
c.Request.Header.Del("Anthropic-Beta")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// filterBetaFeatures removes specified beta features from the header.
|
||||
func filterBetaFeatures(betaHeader, featureToRemove string) string {
|
||||
// Simple implementation - can be enhanced
|
||||
if betaHeader == featureToRemove {
|
||||
return ""
|
||||
}
|
||||
return betaHeader
|
||||
}
|
||||
|
||||
// ginResponseWriterAdapter adapts http.ResponseWriter to gin.ResponseWriter.
|
||||
type ginResponseWriterAdapter struct {
|
||||
http.ResponseWriter
|
||||
original gin.ResponseWriter
|
||||
}
|
||||
|
||||
func (a *ginResponseWriterAdapter) WriteHeader(code int) {
|
||||
a.ResponseWriter.WriteHeader(code)
|
||||
}
|
||||
|
||||
func (a *ginResponseWriterAdapter) Write(data []byte) (int, error) {
|
||||
return a.ResponseWriter.Write(data)
|
||||
}
|
||||
|
||||
func (a *ginResponseWriterAdapter) Header() http.Header {
|
||||
return a.ResponseWriter.Header()
|
||||
}
|
||||
|
||||
// CloseNotify implements http.CloseNotifier.
|
||||
func (a *ginResponseWriterAdapter) CloseNotify() <-chan bool {
|
||||
if notifier, ok := a.ResponseWriter.(http.CloseNotifier); ok {
|
||||
return notifier.CloseNotify()
|
||||
}
|
||||
return a.original.CloseNotify()
|
||||
}
|
||||
|
||||
// Flush implements http.Flusher.
|
||||
func (a *ginResponseWriterAdapter) Flush() {
|
||||
if flusher, ok := a.ResponseWriter.(http.Flusher); ok {
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
|
||||
// Hijack implements http.Hijacker.
|
||||
func (a *ginResponseWriterAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||
if hijacker, ok := a.ResponseWriter.(http.Hijacker); ok {
|
||||
return hijacker.Hijack()
|
||||
}
|
||||
return a.original.Hijack()
|
||||
}
|
||||
|
||||
// Status returns the HTTP status code.
|
||||
func (a *ginResponseWriterAdapter) Status() int {
|
||||
return a.original.Status()
|
||||
}
|
||||
|
||||
// Size returns the number of bytes already written into the response http body.
|
||||
func (a *ginResponseWriterAdapter) Size() int {
|
||||
return a.original.Size()
|
||||
}
|
||||
|
||||
// Written returns whether or not the response for this context has been written.
|
||||
func (a *ginResponseWriterAdapter) Written() bool {
|
||||
return a.original.Written()
|
||||
}
|
||||
|
||||
// WriteHeaderNow forces WriteHeader to be called.
|
||||
func (a *ginResponseWriterAdapter) WriteHeaderNow() {
|
||||
a.original.WriteHeaderNow()
|
||||
}
|
||||
|
||||
// WriteString writes the given string into the response body.
|
||||
func (a *ginResponseWriterAdapter) WriteString(s string) (int, error) {
|
||||
return a.Write([]byte(s))
|
||||
}
|
||||
|
||||
// Pusher returns the http.Pusher for server push.
|
||||
func (a *ginResponseWriterAdapter) Pusher() http.Pusher {
|
||||
if pusher, ok := a.ResponseWriter.(http.Pusher); ok {
|
||||
return pusher
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -141,7 +141,7 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: body.payload,
|
||||
Body: bytes.Clone(body.payload),
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -156,14 +156,14 @@ func (e *AIStudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
|
||||
if len(wsResp.Body) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, wsResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body))
|
||||
}
|
||||
if wsResp.Status < 200 || wsResp.Status >= 300 {
|
||||
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
|
||||
}
|
||||
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, wsResp.Body, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: ensureColonSpacedJSON([]byte(out))}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -199,7 +199,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: body.payload,
|
||||
Body: bytes.Clone(body.payload),
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -225,7 +225,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
var body bytes.Buffer
|
||||
if len(firstEvent.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, firstEvent.Payload)
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(firstEvent.Payload))
|
||||
body.Write(firstEvent.Payload)
|
||||
}
|
||||
if firstEvent.Type == wsrelay.MessageTypeStreamEnd {
|
||||
@@ -244,7 +244,7 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
metadataLogged = true
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
body.Write(event.Payload)
|
||||
}
|
||||
if event.Type == wsrelay.MessageTypeStreamEnd {
|
||||
@@ -274,12 +274,12 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
}
|
||||
case wsrelay.MessageTypeStreamChunk:
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
filtered := FilterSSEUsageMetadata(event.Payload)
|
||||
if detail, ok := parseGeminiStreamUsage(filtered); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, filtered, ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(filtered), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
}
|
||||
@@ -293,9 +293,9 @@ func (e *AIStudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth
|
||||
metadataLogged = true
|
||||
}
|
||||
if len(event.Payload) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, event.Payload)
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, opts.OriginalRequest, translatedReq, event.Payload, ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: ensureColonSpacedJSON([]byte(lines[i]))}
|
||||
}
|
||||
@@ -350,7 +350,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
URL: endpoint,
|
||||
Method: http.MethodPost,
|
||||
Headers: wsReq.Headers.Clone(),
|
||||
Body: body.payload,
|
||||
Body: bytes.Clone(body.payload),
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -364,7 +364,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
|
||||
if len(resp.Body) > 0 {
|
||||
appendAPIResponseChunk(ctx, e.cfg, resp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
|
||||
}
|
||||
if resp.Status < 200 || resp.Status >= 300 {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
|
||||
@@ -373,7 +373,7 @@ func (e *AIStudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.A
|
||||
if totalTokens <= 0 {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
|
||||
}
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, resp.Body)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body))
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
@@ -393,13 +393,12 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
return nil, translatedPayload{}, err
|
||||
|
||||
@@ -133,13 +133,12 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -231,7 +230,7 @@ attemptLoop:
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bodyBytes, ¶m)
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
@@ -275,13 +274,12 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -435,7 +433,7 @@ attemptLoop:
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, resp.Payload, ¶m)
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
@@ -667,13 +665,12 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated, err = thinking.ApplyThinking(translated, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -803,12 +800,12 @@ attemptLoop:
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(payload), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, []byte("[DONE]"), ¶m)
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
||||
for i := range tail {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||
}
|
||||
@@ -875,7 +872,7 @@ func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyaut
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
// Prepare payload once (doesn't depend on baseURL)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
payload, err := thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -1283,40 +1280,51 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
payload = geminiToAntigravity(modelName, payload, projectID)
|
||||
payload, _ = sjson.SetBytes(payload, "model", modelName)
|
||||
|
||||
useAntigravitySchema := strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high")
|
||||
payloadStr := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(payloadStr), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
payloadStr, _ = util.RenameKey(payloadStr, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
||||
strJSON := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
|
||||
if useAntigravitySchema {
|
||||
payloadStr = util.CleanJSONSchemaForAntigravity(payloadStr)
|
||||
// Use the centralized schema cleaner to handle unsupported keywords,
|
||||
// const->enum conversion, and flattening of types/anyOf.
|
||||
strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
|
||||
payload = []byte(strJSON)
|
||||
} else {
|
||||
payloadStr = util.CleanJSONSchemaForGemini(payloadStr)
|
||||
strJSON := string(payload)
|
||||
paths := make([]string, 0)
|
||||
util.Walk(gjson.Parse(strJSON), "", "parametersJsonSchema", &paths)
|
||||
for _, p := range paths {
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
// Clean tool schemas for Gemini to remove unsupported JSON Schema keywords
|
||||
// without adding empty-schema placeholders.
|
||||
strJSON = util.CleanJSONSchemaForGemini(strJSON)
|
||||
payload = []byte(strJSON)
|
||||
}
|
||||
|
||||
if useAntigravitySchema {
|
||||
systemInstructionPartsResult := gjson.Get(payloadStr, "request.systemInstruction.parts")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.role", "user")
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
if strings.Contains(modelName, "claude") || strings.Contains(modelName, "gemini-3-pro-high") {
|
||||
systemInstructionPartsResult := gjson.GetBytes(payload, "request.systemInstruction.parts")
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.role", "user")
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.0.text", systemInstruction)
|
||||
payload, _ = sjson.SetBytes(payload, "request.systemInstruction.parts.1.text", fmt.Sprintf("Please ignore following [ignore]%s[/ignore]", systemInstruction))
|
||||
|
||||
if systemInstructionPartsResult.Exists() && systemInstructionPartsResult.IsArray() {
|
||||
for _, partResult := range systemInstructionPartsResult.Array() {
|
||||
payloadStr, _ = sjson.SetRaw(payloadStr, "request.systemInstruction.parts.-1", partResult.Raw)
|
||||
payload, _ = sjson.SetRawBytes(payload, "request.systemInstruction.parts.-1", []byte(partResult.Raw))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.Contains(modelName, "claude") {
|
||||
payloadStr, _ = sjson.Set(payloadStr, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
payload, _ = sjson.SetBytes(payload, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
} else {
|
||||
payloadStr, _ = sjson.Delete(payloadStr, "request.generationConfig.maxOutputTokens")
|
||||
payload, _ = sjson.DeleteBytes(payload, "request.generationConfig.maxOutputTokens")
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), strings.NewReader(payloadStr))
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
}
|
||||
@@ -1338,15 +1346,11 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
var payloadLog []byte
|
||||
if e.cfg != nil && e.cfg.RequestLog {
|
||||
payloadLog = []byte(payloadStr)
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: requestURL.String(),
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: payloadLog,
|
||||
Body: payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
|
||||
@@ -100,13 +100,12 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -217,7 +216,7 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
to,
|
||||
from,
|
||||
req.Model,
|
||||
opts.OriginalRequest,
|
||||
bytes.Clone(opts.OriginalRequest),
|
||||
bodyForTranslation,
|
||||
data,
|
||||
¶m,
|
||||
@@ -241,13 +240,12 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("claude")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -383,7 +381,7 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
to,
|
||||
from,
|
||||
req.Model,
|
||||
opts.OriginalRequest,
|
||||
bytes.Clone(opts.OriginalRequest),
|
||||
bodyForTranslation,
|
||||
bytes.Clone(line),
|
||||
¶m,
|
||||
@@ -413,7 +411,7 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
to := sdktranslator.FromString("claude")
|
||||
// Use streaming translation to preserve function calling, except for claude.
|
||||
stream := from != to
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, stream)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), stream)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
if !strings.HasPrefix(baseModel, "claude-3-5-haiku") {
|
||||
|
||||
@@ -27,11 +27,6 @@ import (
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
const (
|
||||
codexClientVersion = "0.98.0"
|
||||
codexUserAgent = "codex_cli_rs/0.98.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464"
|
||||
)
|
||||
|
||||
var dataTag = []byte("data:")
|
||||
|
||||
// CodexExecutor is a stateless executor for Codex (OpenAI Responses API entrypoint).
|
||||
@@ -93,13 +88,12 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -182,7 +176,7 @@ func (e *CodexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
}
|
||||
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, line, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, line, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -203,13 +197,12 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai-response")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -272,7 +265,7 @@ func (e *CodexExecutor) executeCompact(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
reporter.ensurePublished(ctx)
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -293,13 +286,12 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -386,7 +378,7 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
}
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, originalPayload, body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(originalPayload), body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -405,7 +397,7 @@ func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("codex")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err := thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -642,10 +634,10 @@ func applyCodexHeaders(r *http.Request, auth *cliproxyauth.Auth, token string, s
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", codexClientVersion)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Version", "0.21.0")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Openai-Beta", "responses=experimental")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Session_id", uuid.NewString())
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", codexUserAgent)
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "User-Agent", "codex_cli_rs/0.50.0 (Mac OS 26.0.1; arm64) Apple_Terminal/464")
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
|
||||
@@ -119,13 +119,12 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -224,7 +223,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
if httpResp.StatusCode >= 200 && httpResp.StatusCode < 300 {
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, opts.OriginalRequest, payload, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), payload, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -273,13 +272,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
basePayload, err = thinking.ApplyThinking(basePayload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -401,14 +399,14 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
if bytes.HasPrefix(line, dataTag) {
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, bytes.Clone(line), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone(line), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -430,12 +428,12 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiCLIUsage(data))
|
||||
var param any
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, data, ¶m)
|
||||
segments := sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, data, ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, opts.OriginalRequest, reqBody, []byte("[DONE]"), ¶m)
|
||||
segments = sdktranslator.TranslateStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), reqBody, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
for i := range segments {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(segments[i])}
|
||||
}
|
||||
@@ -487,7 +485,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
// The loop variable attemptModel is only used as the concrete model id sent to the upstream
|
||||
// Gemini CLI endpoint when iterating fallback variants.
|
||||
for range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
payload := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
payload, err = thinking.ApplyThinking(payload, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
|
||||
@@ -116,13 +116,12 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
// Official Gemini API via API key or OAuth bearer
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -204,7 +203,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -223,13 +222,12 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -320,12 +318,12 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if detail, ok := parseGeminiStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(payload), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(payload), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -346,7 +344,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
|
||||
@@ -318,13 +318,12 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body = sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -418,7 +417,7 @@ func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, au
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -433,13 +432,12 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -523,7 +521,7 @@ func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *clip
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -538,13 +536,12 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -635,12 +632,12 @@ func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Conte
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -663,13 +660,12 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -760,12 +756,12 @@ func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
@@ -785,7 +781,7 @@ func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -869,7 +865,7 @@ func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
translatedReq, err := thinking.ApplyThinking(translatedReq, req.Model, from.String(), to.String(), e.Identifier())
|
||||
if err != nil {
|
||||
@@ -1007,8 +1003,6 @@ func vertexBaseURL(location string) string {
|
||||
loc := strings.TrimSpace(location)
|
||||
if loc == "" {
|
||||
loc = "us-central1"
|
||||
} else if loc == "global" {
|
||||
return "https://aiplatform.googleapis.com"
|
||||
}
|
||||
return fmt.Sprintf("https://%s-aiplatform.googleapis.com", loc)
|
||||
}
|
||||
|
||||
@@ -87,13 +87,12 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
@@ -164,7 +163,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -190,13 +189,12 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "iflow", e.Identifier())
|
||||
@@ -276,7 +274,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -298,7 +296,7 @@ func (e *IFlowExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
enc, err := tokenizerForModel(baseModel)
|
||||
if err != nil {
|
||||
|
||||
@@ -1,471 +0,0 @@
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
kimiauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// KimiExecutor is a stateless executor for Kimi API using OpenAI-compatible chat completions.
|
||||
type KimiExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewKimiExecutor creates a new Kimi executor.
|
||||
func NewKimiExecutor(cfg *config.Config) *KimiExecutor { return &KimiExecutor{cfg: cfg} }
|
||||
|
||||
// Identifier returns the executor identifier.
|
||||
func (e *KimiExecutor) Identifier() string { return "kimi" }
|
||||
|
||||
// PrepareRequest injects Kimi credentials into the outgoing HTTP request.
|
||||
func (e *KimiExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error {
|
||||
if req == nil {
|
||||
return nil
|
||||
}
|
||||
token := kimiCreds(auth)
|
||||
if strings.TrimSpace(token) != "" {
|
||||
req.Header.Set("Authorization", "Bearer "+token)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// HttpRequest injects Kimi credentials into the request and executes it.
|
||||
func (e *KimiExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("kimi executor: request is nil")
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = req.Context()
|
||||
}
|
||||
httpReq := req.WithContext(ctx)
|
||||
if err := e.PrepareRequest(httpReq, auth); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
return httpClient.Do(httpReq)
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
token := kimiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
// Strip kimi- prefix for upstream API
|
||||
upstreamModel := stripKimiPrefix(baseModel)
|
||||
body, err = sjson.SetBytes(body, "model", upstreamModel)
|
||||
if err != nil {
|
||||
return resp, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
|
||||
}
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
url := kimiauth.KimiAPIBaseURL + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, false, auth)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
data, err := io.ReadAll(httpResp.Body)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseOpenAIUsage(data))
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming chat completion request to Kimi.
|
||||
func (e *KimiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
token := kimiCreds(auth)
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
|
||||
// Strip kimi- prefix for upstream API
|
||||
upstreamModel := stripKimiPrefix(baseModel)
|
||||
body, err = sjson.SetBytes(body, "model", upstreamModel)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi executor: failed to set model in payload: %w", err)
|
||||
}
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), "kimi", e.Identifier())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
body, err = sjson.SetBytes(body, "stream_options.include_usage", true)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi executor: failed to set stream_options in payload: %w", err)
|
||||
}
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel)
|
||||
|
||||
url := kimiauth.KimiAPIBaseURL + "/chat/completions"
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyKimiHeadersWithAuth(httpReq, token, true, auth)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, err := httpClient.Do(httpReq)
|
||||
if err != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, err)
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
logWithRequestID(ctx).Debugf("request error, error status: %d, error message: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("kimi executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 1_048_576) // 1MB
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// CountTokens estimates token count for Kimi requests.
|
||||
func (e *KimiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
baseModel := thinking.ParseSuffix(req.Model).ModelName
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
// Use a generic tokenizer for estimation
|
||||
enc, err := tokenizerForModel("gpt-4")
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("kimi executor: tokenizer init failed: %w", err)
|
||||
}
|
||||
|
||||
count, err := countOpenAIChatTokens(enc, body)
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, fmt.Errorf("kimi executor: token counting failed: %w", err)
|
||||
}
|
||||
|
||||
usageJSON := buildOpenAIUsageJSON(count)
|
||||
translated := sdktranslator.TranslateTokenCount(ctx, to, from, count, usageJSON)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
// Refresh refreshes the Kimi token using the refresh token.
|
||||
func (e *KimiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
log.Debugf("kimi executor: refresh called")
|
||||
if auth == nil {
|
||||
return nil, fmt.Errorf("kimi executor: auth is nil")
|
||||
}
|
||||
// Expect refresh_token in metadata for OAuth-based accounts
|
||||
var refreshToken string
|
||||
if auth.Metadata != nil {
|
||||
if v, ok := auth.Metadata["refresh_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
refreshToken = v
|
||||
}
|
||||
}
|
||||
if strings.TrimSpace(refreshToken) == "" {
|
||||
// Nothing to refresh
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
client := kimiauth.NewDeviceFlowClientWithDeviceID(e.cfg, resolveKimiDeviceID(auth))
|
||||
td, err := client.RefreshToken(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
auth.Metadata["access_token"] = td.AccessToken
|
||||
if td.RefreshToken != "" {
|
||||
auth.Metadata["refresh_token"] = td.RefreshToken
|
||||
}
|
||||
if td.ExpiresAt > 0 {
|
||||
exp := time.Unix(td.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
auth.Metadata["expired"] = exp
|
||||
}
|
||||
auth.Metadata["type"] = "kimi"
|
||||
now := time.Now().Format(time.RFC3339)
|
||||
auth.Metadata["last_refresh"] = now
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// applyKimiHeaders sets required headers for Kimi API requests.
|
||||
// Headers match kimi-cli client for compatibility.
|
||||
func applyKimiHeaders(r *http.Request, token string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
// Match kimi-cli headers exactly
|
||||
r.Header.Set("User-Agent", "KimiCLI/1.10.6")
|
||||
r.Header.Set("X-Msh-Platform", "kimi_cli")
|
||||
r.Header.Set("X-Msh-Version", "1.10.6")
|
||||
r.Header.Set("X-Msh-Device-Name", getKimiHostname())
|
||||
r.Header.Set("X-Msh-Device-Model", getKimiDeviceModel())
|
||||
r.Header.Set("X-Msh-Device-Id", getKimiDeviceID())
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
return
|
||||
}
|
||||
r.Header.Set("Accept", "application/json")
|
||||
}
|
||||
|
||||
func resolveKimiDeviceIDFromAuth(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
deviceIDRaw, ok := auth.Metadata["device_id"]
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
deviceID, ok := deviceIDRaw.(string)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(deviceID)
|
||||
}
|
||||
|
||||
func resolveKimiDeviceIDFromStorage(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
storage, ok := auth.Storage.(*kimiauth.KimiTokenStorage)
|
||||
if !ok || storage == nil {
|
||||
return ""
|
||||
}
|
||||
|
||||
return strings.TrimSpace(storage.DeviceID)
|
||||
}
|
||||
|
||||
func resolveKimiDeviceID(auth *cliproxyauth.Auth) string {
|
||||
deviceID := resolveKimiDeviceIDFromAuth(auth)
|
||||
if deviceID != "" {
|
||||
return deviceID
|
||||
}
|
||||
return resolveKimiDeviceIDFromStorage(auth)
|
||||
}
|
||||
|
||||
func applyKimiHeadersWithAuth(r *http.Request, token string, stream bool, auth *cliproxyauth.Auth) {
|
||||
applyKimiHeaders(r, token, stream)
|
||||
|
||||
if deviceID := resolveKimiDeviceID(auth); deviceID != "" {
|
||||
r.Header.Set("X-Msh-Device-Id", deviceID)
|
||||
}
|
||||
}
|
||||
|
||||
// getKimiHostname returns the machine hostname.
|
||||
func getKimiHostname() string {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
return "unknown"
|
||||
}
|
||||
return hostname
|
||||
}
|
||||
|
||||
// getKimiDeviceModel returns a device model string matching kimi-cli format.
|
||||
func getKimiDeviceModel() string {
|
||||
return fmt.Sprintf("%s %s", runtime.GOOS, runtime.GOARCH)
|
||||
}
|
||||
|
||||
// getKimiDeviceID returns a stable device ID, matching kimi-cli storage location.
|
||||
func getKimiDeviceID() string {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return "cli-proxy-api-device"
|
||||
}
|
||||
// Check kimi-cli's device_id location first (platform-specific)
|
||||
var kimiShareDir string
|
||||
switch runtime.GOOS {
|
||||
case "darwin":
|
||||
kimiShareDir = filepath.Join(homeDir, "Library", "Application Support", "kimi")
|
||||
case "windows":
|
||||
appData := os.Getenv("APPDATA")
|
||||
if appData == "" {
|
||||
appData = filepath.Join(homeDir, "AppData", "Roaming")
|
||||
}
|
||||
kimiShareDir = filepath.Join(appData, "kimi")
|
||||
default: // linux and other unix-like
|
||||
kimiShareDir = filepath.Join(homeDir, ".local", "share", "kimi")
|
||||
}
|
||||
deviceIDPath := filepath.Join(kimiShareDir, "device_id")
|
||||
if data, err := os.ReadFile(deviceIDPath); err == nil {
|
||||
return strings.TrimSpace(string(data))
|
||||
}
|
||||
return "cli-proxy-api-device"
|
||||
}
|
||||
|
||||
// kimiCreds extracts the access token from auth.
|
||||
func kimiCreds(a *cliproxyauth.Auth) (token string) {
|
||||
if a == nil {
|
||||
return ""
|
||||
}
|
||||
// Check metadata first (OAuth flow stores tokens here)
|
||||
if a.Metadata != nil {
|
||||
if v, ok := a.Metadata["access_token"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
// Fallback to attributes (API key style)
|
||||
if a.Attributes != nil {
|
||||
if v := a.Attributes["access_token"]; v != "" {
|
||||
return v
|
||||
}
|
||||
if v := a.Attributes["api_key"]; v != "" {
|
||||
return v
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// stripKimiPrefix removes the "kimi-" prefix from model names for the upstream API.
|
||||
func stripKimiPrefix(model string) string {
|
||||
model = strings.TrimSpace(model)
|
||||
if strings.HasPrefix(strings.ToLower(model), "kimi-") {
|
||||
return model[5:]
|
||||
}
|
||||
return model
|
||||
}
|
||||
@@ -80,7 +80,7 @@ func recordAPIRequest(ctx context.Context, cfg *config.Config, info upstreamRequ
|
||||
writeHeaders(builder, info.Headers)
|
||||
builder.WriteString("\nBody:\n")
|
||||
if len(info.Body) > 0 {
|
||||
builder.WriteString(string(info.Body))
|
||||
builder.WriteString(string(bytes.Clone(info.Body)))
|
||||
} else {
|
||||
builder.WriteString("<empty>")
|
||||
}
|
||||
@@ -152,7 +152,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
|
||||
if cfg == nil || !cfg.RequestLog {
|
||||
return
|
||||
}
|
||||
data := bytes.TrimSpace(chunk)
|
||||
data := bytes.TrimSpace(bytes.Clone(chunk))
|
||||
if len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -88,13 +88,12 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
to = sdktranslator.FromString("openai-response")
|
||||
endpoint = "/responses/compact"
|
||||
}
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, opts.Stream)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), opts.Stream)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
if opts.Alt == "responses/compact" {
|
||||
@@ -171,7 +170,7 @@ func (e *OpenAICompatExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
reporter.ensurePublished(ctx)
|
||||
// Translate response back to source format when needed
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, body, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, body, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -190,13 +189,12 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
requestedModel := payloadRequestedModel(opts, req.Model)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", translated, originalTranslated, requestedModel)
|
||||
|
||||
@@ -285,7 +283,7 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
|
||||
// OpenAI-compatible streams are SSE: lines typically prefixed with "data: ".
|
||||
// Pass through translator; it yields one or more chunks for the target schema.
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, translated, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
@@ -306,7 +304,7 @@ func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
translated := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
modelForCounting := baseModel
|
||||
|
||||
|
||||
@@ -81,13 +81,12 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -151,7 +150,7 @@ func (e *QwenExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req
|
||||
var param any
|
||||
// Note: TranslateNonStream uses req.Model (original with suffix) to preserve
|
||||
// the original model name in the response for client compatibility.
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, opts.OriginalRequest, body, data, ¶m)
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
@@ -172,13 +171,12 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
originalPayloadSource := req.Payload
|
||||
originalPayload := bytes.Clone(req.Payload)
|
||||
if len(opts.OriginalRequest) > 0 {
|
||||
originalPayloadSource = opts.OriginalRequest
|
||||
originalPayload = bytes.Clone(opts.OriginalRequest)
|
||||
}
|
||||
originalPayload := originalPayloadSource
|
||||
originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, true)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), true)
|
||||
body, _ = sjson.SetBytes(body, "model", baseModel)
|
||||
|
||||
body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier())
|
||||
@@ -255,12 +253,12 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut
|
||||
if detail, ok := parseOpenAIStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, bytes.Clone(line), ¶m)
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, opts.OriginalRequest, body, []byte("[DONE]"), ¶m)
|
||||
doneChunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone([]byte("[DONE]")), ¶m)
|
||||
for i := range doneChunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(doneChunks[i])}
|
||||
}
|
||||
@@ -278,7 +276,7 @@ func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth,
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("openai")
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false)
|
||||
body := sdktranslator.TranslateRequest(from, to, baseModel, bytes.Clone(req.Payload), false)
|
||||
|
||||
modelName := gjson.GetBytes(body, "model").String()
|
||||
if strings.TrimSpace(modelName) == "" {
|
||||
|
||||
@@ -7,6 +7,5 @@ import (
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/gemini"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/geminicli"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/iflow"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/kimi"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking/provider/openai"
|
||||
)
|
||||
|
||||
@@ -18,7 +18,6 @@ var providerAppliers = map[string]ProviderApplier{
|
||||
"codex": nil,
|
||||
"iflow": nil,
|
||||
"antigravity": nil,
|
||||
"kimi": nil,
|
||||
}
|
||||
|
||||
// GetProviderApplier returns the ProviderApplier for the given provider name.
|
||||
@@ -327,9 +326,6 @@ func extractThinkingConfig(body []byte, provider string) ThinkingConfig {
|
||||
return config
|
||||
}
|
||||
return extractOpenAIConfig(body)
|
||||
case "kimi":
|
||||
// Kimi uses OpenAI-compatible reasoning_effort format
|
||||
return extractOpenAIConfig(body)
|
||||
default:
|
||||
return ThinkingConfig{}
|
||||
}
|
||||
@@ -392,12 +388,7 @@ func extractGeminiConfig(body []byte, provider string) ThinkingConfig {
|
||||
}
|
||||
|
||||
// Check thinkingLevel first (Gemini 3 format takes precedence)
|
||||
level := gjson.GetBytes(body, prefix+".thinkingLevel")
|
||||
if !level.Exists() {
|
||||
// Google official Gemini Python SDK sends snake_case field names
|
||||
level = gjson.GetBytes(body, prefix+".thinking_level")
|
||||
}
|
||||
if level.Exists() {
|
||||
if level := gjson.GetBytes(body, prefix+".thinkingLevel"); level.Exists() {
|
||||
value := level.String()
|
||||
switch value {
|
||||
case "none":
|
||||
@@ -410,12 +401,7 @@ func extractGeminiConfig(body []byte, provider string) ThinkingConfig {
|
||||
}
|
||||
|
||||
// Check thinkingBudget (Gemini 2.5 format)
|
||||
budget := gjson.GetBytes(body, prefix+".thinkingBudget")
|
||||
if !budget.Exists() {
|
||||
// Google official Gemini Python SDK sends snake_case field names
|
||||
budget = gjson.GetBytes(body, prefix+".thinking_budget")
|
||||
}
|
||||
if budget.Exists() {
|
||||
if budget := gjson.GetBytes(body, prefix+".thinkingBudget"); budget.Exists() {
|
||||
value := int(budget.Int())
|
||||
switch value {
|
||||
case 0:
|
||||
|
||||
@@ -94,10 +94,8 @@ func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig, m
|
||||
}
|
||||
|
||||
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
@@ -116,30 +114,28 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
|
||||
level := string(config.Level)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level)
|
||||
|
||||
// Respect user's explicit includeThoughts setting from original body; default to true if not set
|
||||
// Support both camelCase and snake_case variants
|
||||
includeThoughts := true
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
}
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo, isClaude bool) ([]byte, error) {
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
budget := config.Budget
|
||||
includeThoughts := false
|
||||
switch config.Mode {
|
||||
case thinking.ModeNone:
|
||||
includeThoughts = false
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
|
||||
// Apply Claude-specific constraints first to get the final budget value
|
||||
// Apply Claude-specific constraints
|
||||
if isClaude && modelInfo != nil {
|
||||
budget, result = a.normalizeClaudeBudget(budget, result, modelInfo)
|
||||
// Check if budget was removed entirely
|
||||
@@ -148,37 +144,6 @@ func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig,
|
||||
}
|
||||
}
|
||||
|
||||
// For ModeNone, always set includeThoughts to false regardless of user setting.
|
||||
// This ensures that when user requests budget=0 (disable thinking output),
|
||||
// the includeThoughts is correctly set to false even if budget is clamped to min.
|
||||
if config.Mode == thinking.ModeNone {
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Determine includeThoughts: respect user's explicit setting from original body if provided
|
||||
// Support both camelCase and snake_case variants
|
||||
var includeThoughts bool
|
||||
var userSetIncludeThoughts bool
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
}
|
||||
|
||||
if !userSetIncludeThoughts {
|
||||
// No explicit setting, use default logic based on mode
|
||||
switch config.Mode {
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
}
|
||||
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
return result, nil
|
||||
|
||||
@@ -83,6 +83,10 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens (Anthropic API constraint)
|
||||
result = a.normalizeClaudeBudget(result, config.Budget, modelInfo)
|
||||
|
||||
// When thinking is enabled, Claude API requires assistant messages with tool_use
|
||||
// to have a thinking block. Inject empty thinking block if missing.
|
||||
result = injectThinkingBlockForToolUse(result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -149,18 +153,85 @@ func applyCompatibleClaude(body []byte, config thinking.ThinkingConfig) ([]byte,
|
||||
body = []byte(`{}`)
|
||||
}
|
||||
|
||||
var result []byte
|
||||
switch config.Mode {
|
||||
case thinking.ModeNone:
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "disabled")
|
||||
result, _ = sjson.SetBytes(body, "thinking.type", "disabled")
|
||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||
return result, nil
|
||||
case thinking.ModeAuto:
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
result, _ = sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
result, _ = sjson.DeleteBytes(result, "thinking.budget_tokens")
|
||||
return result, nil
|
||||
default:
|
||||
result, _ := sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
result, _ = sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
result, _ = sjson.SetBytes(result, "thinking.budget_tokens", config.Budget)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// When thinking is enabled, Claude API requires assistant messages with tool_use
|
||||
// to have a thinking block. Inject empty thinking block if missing.
|
||||
result = injectThinkingBlockForToolUse(result)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// injectThinkingBlockForToolUse adds empty thinking block to assistant messages
|
||||
// that have tool_use but no thinking block. This is required by Claude API when
|
||||
// thinking is enabled.
|
||||
func injectThinkingBlockForToolUse(body []byte) []byte {
|
||||
messages := gjson.GetBytes(body, "messages")
|
||||
if !messages.IsArray() {
|
||||
return body
|
||||
}
|
||||
|
||||
messageArray := messages.Array()
|
||||
modified := false
|
||||
newMessages := "[]"
|
||||
|
||||
for _, msg := range messageArray {
|
||||
role := msg.Get("role").String()
|
||||
if role != "assistant" {
|
||||
newMessages, _ = sjson.SetRaw(newMessages, "-1", msg.Raw)
|
||||
continue
|
||||
}
|
||||
|
||||
content := msg.Get("content")
|
||||
if !content.IsArray() {
|
||||
newMessages, _ = sjson.SetRaw(newMessages, "-1", msg.Raw)
|
||||
continue
|
||||
}
|
||||
|
||||
contentArray := content.Array()
|
||||
hasToolUse := false
|
||||
hasThinking := false
|
||||
|
||||
for _, part := range contentArray {
|
||||
partType := part.Get("type").String()
|
||||
if partType == "tool_use" {
|
||||
hasToolUse = true
|
||||
}
|
||||
if partType == "thinking" {
|
||||
hasThinking = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasToolUse && !hasThinking {
|
||||
// Inject empty thinking block at the beginning of content
|
||||
newContent := "[]"
|
||||
newContent, _ = sjson.SetRaw(newContent, "-1", `{"type":"thinking","thinking":""}`)
|
||||
for _, part := range contentArray {
|
||||
newContent, _ = sjson.SetRaw(newContent, "-1", part.Raw)
|
||||
}
|
||||
msgJSON := msg.Raw
|
||||
msgJSON, _ = sjson.SetRaw(msgJSON, "content", newContent)
|
||||
newMessages, _ = sjson.SetRaw(newMessages, "-1", msgJSON)
|
||||
modified = true
|
||||
continue
|
||||
}
|
||||
|
||||
newMessages, _ = sjson.SetRaw(newMessages, "-1", msg.Raw)
|
||||
}
|
||||
|
||||
if modified {
|
||||
body, _ = sjson.SetRawBytes(body, "messages", []byte(newMessages))
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
187
internal/thinking/provider/claude/apply_test.go
Normal file
187
internal/thinking/provider/claude/apply_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestInjectThinkingBlockForToolUse(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "assistant with tool_use but no thinking - should inject thinking",
|
||||
input: `{
|
||||
"model": "kimi-k2.5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Let me use a tool"},
|
||||
{"type": "tool_use", "id": "tool_1", "name": "test_tool", "input": {}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expected: "thinking",
|
||||
},
|
||||
{
|
||||
name: "assistant with tool_use and thinking - should not modify",
|
||||
input: `{
|
||||
"model": "kimi-k2.5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "I need to use a tool"},
|
||||
{"type": "tool_use", "id": "tool_1", "name": "test_tool", "input": {}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expected: "thinking",
|
||||
},
|
||||
{
|
||||
name: "user message with tool_use - should not modify",
|
||||
input: `{
|
||||
"model": "kimi-k2.5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "tool_result", "tool_use_id": "tool_1", "content": "result"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "assistant without tool_use - should not modify",
|
||||
input: `{
|
||||
"model": "kimi-k2.5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello!"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
expected: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := injectThinkingBlockForToolUse([]byte(tt.input))
|
||||
|
||||
// Check if thinking block exists in assistant messages with tool_use
|
||||
messages := gjson.GetBytes(result, "messages")
|
||||
if !messages.IsArray() {
|
||||
t.Fatal("messages is not an array")
|
||||
}
|
||||
|
||||
for _, msg := range messages.Array() {
|
||||
if msg.Get("role").String() == "assistant" {
|
||||
content := msg.Get("content")
|
||||
if !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
|
||||
hasToolUse := false
|
||||
hasThinking := false
|
||||
for _, part := range content.Array() {
|
||||
partType := part.Get("type").String()
|
||||
if partType == "tool_use" {
|
||||
hasToolUse = true
|
||||
}
|
||||
if partType == "thinking" {
|
||||
hasThinking = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasToolUse && tt.expected == "thinking" && !hasThinking {
|
||||
t.Errorf("Expected thinking block in assistant message with tool_use, but not found")
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyCompatibleClaude(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input string
|
||||
config thinking.ThinkingConfig
|
||||
expectThinking bool
|
||||
}{
|
||||
{
|
||||
name: "thinking enabled with tool_use - should inject thinking block",
|
||||
input: `{
|
||||
"model": "kimi-k2.5",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "tool_use", "id": "tool_1", "name": "test_tool", "input": {}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`,
|
||||
config: thinking.ThinkingConfig{
|
||||
Mode: thinking.ModeBudget,
|
||||
Budget: 4000,
|
||||
},
|
||||
expectThinking: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := applyCompatibleClaude([]byte(tt.input), tt.config)
|
||||
if err != nil {
|
||||
t.Fatalf("applyCompatibleClaude failed: %v", err)
|
||||
}
|
||||
|
||||
// Check if thinking.type is enabled
|
||||
thinkingType := gjson.GetBytes(result, "thinking.type").String()
|
||||
if thinkingType != "enabled" {
|
||||
t.Errorf("Expected thinking.type=enabled, got %s", thinkingType)
|
||||
}
|
||||
|
||||
// Check if thinking block is injected
|
||||
messages := gjson.GetBytes(result, "messages")
|
||||
if !messages.IsArray() {
|
||||
t.Fatal("messages is not an array")
|
||||
}
|
||||
|
||||
for _, msg := range messages.Array() {
|
||||
if msg.Get("role").String() == "assistant" {
|
||||
content := msg.Get("content")
|
||||
if !content.IsArray() {
|
||||
continue
|
||||
}
|
||||
|
||||
hasThinking := false
|
||||
for _, part := range content.Array() {
|
||||
if part.Get("type").String() == "thinking" {
|
||||
hasThinking = true
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if tt.expectThinking && !hasThinking {
|
||||
t.Errorf("Expected thinking block in assistant message, but not found. Result: %s", string(result))
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -118,10 +118,8 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
// - ModeNone + Budget>0: forced to think but hide output (includeThoughts=false)
|
||||
// ValidateConfig sets config.Level to the lowest level when ModeNone + Budget > 0.
|
||||
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "generationConfig.thinkingConfig.thinkingBudget")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_budget")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_level")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
@@ -140,58 +138,29 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
|
||||
level := string(config.Level)
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingLevel", level)
|
||||
|
||||
// Respect user's explicit includeThoughts setting from original body; default to true if not set
|
||||
// Support both camelCase and snake_case variants
|
||||
includeThoughts := true
|
||||
if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
} else if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
}
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "generationConfig.thinkingConfig.thinkingLevel")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_level")
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.thinking_budget")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
budget := config.Budget
|
||||
|
||||
// For ModeNone, always set includeThoughts to false regardless of user setting.
|
||||
// This ensures that when user requests budget=0 (disable thinking output),
|
||||
// the includeThoughts is correctly set to false even if budget is clamped to min.
|
||||
if config.Mode == thinking.ModeNone {
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.includeThoughts", false)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Determine includeThoughts: respect user's explicit setting from original body if provided
|
||||
// Support both camelCase and snake_case variants
|
||||
var includeThoughts bool
|
||||
var userSetIncludeThoughts bool
|
||||
if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
} else if inc := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
}
|
||||
|
||||
if !userSetIncludeThoughts {
|
||||
// No explicit setting, use default logic based on mode
|
||||
switch config.Mode {
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
// ModeNone semantics:
|
||||
// - ModeNone + Budget=0: completely disable thinking
|
||||
// - ModeNone + Budget>0: forced to think but hide output (includeThoughts=false)
|
||||
// When ZeroAllowed=false, ValidateConfig clamps Budget to Min while preserving ModeNone.
|
||||
includeThoughts := false
|
||||
switch config.Mode {
|
||||
case thinking.ModeNone:
|
||||
includeThoughts = false
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
|
||||
result, _ = sjson.SetBytes(result, "generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
|
||||
@@ -79,10 +79,8 @@ func (a *Applier) applyCompatible(body []byte, config thinking.ThinkingConfig) (
|
||||
}
|
||||
|
||||
func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
@@ -101,58 +99,25 @@ func (a *Applier) applyLevelFormat(body []byte, config thinking.ThinkingConfig)
|
||||
|
||||
level := string(config.Level)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel", level)
|
||||
|
||||
// Respect user's explicit includeThoughts setting from original body; default to true if not set
|
||||
// Support both camelCase and snake_case variants
|
||||
includeThoughts := true
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
}
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", includeThoughts)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (a *Applier) applyBudgetFormat(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
// Remove conflicting fields to avoid both thinkingLevel and thinkingBudget in output
|
||||
// Remove conflicting field to avoid both thinkingLevel and thinkingBudget in output
|
||||
result, _ := sjson.DeleteBytes(body, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_level")
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.thinking_budget")
|
||||
// Normalize includeThoughts field name to avoid oneof conflicts in upstream JSON parsing.
|
||||
result, _ = sjson.DeleteBytes(result, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
|
||||
budget := config.Budget
|
||||
|
||||
// For ModeNone, always set includeThoughts to false regardless of user setting.
|
||||
// This ensures that when user requests budget=0 (disable thinking output),
|
||||
// the includeThoughts is correctly set to false even if budget is clamped to min.
|
||||
if config.Mode == thinking.ModeNone {
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts", false)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// Determine includeThoughts: respect user's explicit setting from original body if provided
|
||||
// Support both camelCase and snake_case variants
|
||||
var includeThoughts bool
|
||||
var userSetIncludeThoughts bool
|
||||
if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
} else if inc := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); inc.Exists() {
|
||||
includeThoughts = inc.Bool()
|
||||
userSetIncludeThoughts = true
|
||||
}
|
||||
|
||||
if !userSetIncludeThoughts {
|
||||
// No explicit setting, use default logic based on mode
|
||||
switch config.Mode {
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
includeThoughts := false
|
||||
switch config.Mode {
|
||||
case thinking.ModeNone:
|
||||
includeThoughts = false
|
||||
case thinking.ModeAuto:
|
||||
includeThoughts = true
|
||||
default:
|
||||
includeThoughts = budget > 0
|
||||
}
|
||||
|
||||
result, _ = sjson.SetBytes(result, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
|
||||
|
||||
@@ -1,126 +0,0 @@
|
||||
// Package kimi implements thinking configuration for Kimi (Moonshot AI) models.
|
||||
//
|
||||
// Kimi models use the OpenAI-compatible reasoning_effort format with discrete levels
|
||||
// (low/medium/high). The provider strips any existing thinking config and applies
|
||||
// the unified ThinkingConfig in OpenAI format.
|
||||
package kimi
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// Applier implements thinking.ProviderApplier for Kimi models.
|
||||
//
|
||||
// Kimi-specific behavior:
|
||||
// - Output format: reasoning_effort (string: low/medium/high)
|
||||
// - Uses OpenAI-compatible format
|
||||
// - Supports budget-to-level conversion
|
||||
type Applier struct{}
|
||||
|
||||
var _ thinking.ProviderApplier = (*Applier)(nil)
|
||||
|
||||
// NewApplier creates a new Kimi thinking applier.
|
||||
func NewApplier() *Applier {
|
||||
return &Applier{}
|
||||
}
|
||||
|
||||
func init() {
|
||||
thinking.RegisterProvider("kimi", NewApplier())
|
||||
}
|
||||
|
||||
// Apply applies thinking configuration to Kimi request body.
|
||||
//
|
||||
// Expected output format:
|
||||
//
|
||||
// {
|
||||
// "reasoning_effort": "high"
|
||||
// }
|
||||
func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *registry.ModelInfo) ([]byte, error) {
|
||||
if thinking.IsUserDefinedModel(modelInfo) {
|
||||
return applyCompatibleKimi(body, config)
|
||||
}
|
||||
if modelInfo.Thinking == nil {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
body = []byte(`{}`)
|
||||
}
|
||||
|
||||
var effort string
|
||||
switch config.Mode {
|
||||
case thinking.ModeLevel:
|
||||
if config.Level == "" {
|
||||
return body, nil
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
// Kimi uses "none" to disable thinking
|
||||
effort = string(thinking.LevelNone)
|
||||
case thinking.ModeBudget:
|
||||
// Convert budget to level using threshold mapping
|
||||
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
effort = level
|
||||
case thinking.ModeAuto:
|
||||
// Auto mode maps to "auto" effort
|
||||
effort = string(thinking.LevelAuto)
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
|
||||
if effort == "" {
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// applyCompatibleKimi applies thinking config for user-defined Kimi models.
|
||||
func applyCompatibleKimi(body []byte, config thinking.ThinkingConfig) ([]byte, error) {
|
||||
if len(body) == 0 || !gjson.ValidBytes(body) {
|
||||
body = []byte(`{}`)
|
||||
}
|
||||
|
||||
var effort string
|
||||
switch config.Mode {
|
||||
case thinking.ModeLevel:
|
||||
if config.Level == "" {
|
||||
return body, nil
|
||||
}
|
||||
effort = string(config.Level)
|
||||
case thinking.ModeNone:
|
||||
effort = string(thinking.LevelNone)
|
||||
if config.Level != "" {
|
||||
effort = string(config.Level)
|
||||
}
|
||||
case thinking.ModeAuto:
|
||||
effort = string(thinking.LevelAuto)
|
||||
case thinking.ModeBudget:
|
||||
// Convert budget to level
|
||||
level, ok := thinking.ConvertBudgetToLevel(config.Budget)
|
||||
if !ok {
|
||||
return body, nil
|
||||
}
|
||||
effort = level
|
||||
default:
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, err := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
if err != nil {
|
||||
return body, fmt.Errorf("kimi thinking: failed to set reasoning_effort: %w", err)
|
||||
}
|
||||
return result, nil
|
||||
}
|
||||
@@ -6,6 +6,7 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
@@ -36,7 +37,7 @@ import (
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
enableThoughtTranslate := true
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
|
||||
// system instruction
|
||||
systemInstructionJSON := ""
|
||||
@@ -114,6 +115,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
||||
arrayClientSignatures := strings.SplitN(signatureResult.String(), "#", 2)
|
||||
if len(arrayClientSignatures) == 2 {
|
||||
// Compare using model group to handle model mapping
|
||||
// e.g., claude-opus-4-5-thinking -> "claude" group should match "claude#signature"
|
||||
if cache.GetModelGroup(modelName) == arrayClientSignatures[0] {
|
||||
clientSignature = arrayClientSignatures[1]
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -33,7 +34,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -27,7 +28,7 @@ const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
// Base envelope (no default thinkingConfig)
|
||||
out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)
|
||||
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/antigravity/gemini"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToAntigravity(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON = ConvertOpenAIResponsesRequestToGemini(modelName, rawJSON, stream)
|
||||
return ConvertGeminiRequestToAntigravity(modelName, rawJSON, stream)
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/claude/gemini"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -28,7 +30,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertGeminiCLIRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
// Extract the inner request object and promote it to the top level
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -45,7 +46,7 @@ var (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
@@ -115,11 +116,7 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
// Include thoughts configuration for reasoning process visibility
|
||||
// Translator only does format conversion, ApplyThinking handles model capability validation.
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
||||
if !thinkingLevel.Exists() {
|
||||
thinkingLevel = thinkingConfig.Get("thinking_level")
|
||||
}
|
||||
if thinkingLevel.Exists() {
|
||||
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
||||
level := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
switch level {
|
||||
case "":
|
||||
@@ -135,29 +132,23 @@ func ConvertGeminiRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
}
|
||||
} else {
|
||||
thinkingBudget := thinkingConfig.Get("thinkingBudget")
|
||||
if !thinkingBudget.Exists() {
|
||||
thinkingBudget = thinkingConfig.Get("thinking_budget")
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
} else if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
switch budget {
|
||||
case 0:
|
||||
out, _ = sjson.Set(out, "thinking.type", "disabled")
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
case -1:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Delete(out, "thinking.budget_tokens")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
out, _ = sjson.Set(out, "thinking.budget_tokens", budget)
|
||||
}
|
||||
} else if includeThoughts := thinkingConfig.Get("includeThoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
} else if includeThoughts := thinkingConfig.Get("include_thoughts"); includeThoughts.Exists() && includeThoughts.Type == gjson.True {
|
||||
out, _ = sjson.Set(out, "thinking.type", "enabled")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -43,7 +44,7 @@ var (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Claude Code API format
|
||||
func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
@@ -31,7 +32,7 @@ var (
|
||||
// - max_output_tokens -> max_tokens
|
||||
// - stream passthrough via parameter
|
||||
func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
|
||||
if account == "" {
|
||||
u, _ := uuid.NewRandom()
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
@@ -34,7 +35,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in internal client format
|
||||
func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
|
||||
template := `{"model":"","instructions":"","input":[]}`
|
||||
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/codex/gemini"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -28,7 +30,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Codex API format
|
||||
func ConvertGeminiCLIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@@ -36,7 +37,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Codex API format
|
||||
func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
// Base template
|
||||
out := `{"model":"","instructions":"","input":[]}`
|
||||
|
||||
@@ -242,30 +243,19 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
out, _ = sjson.Set(out, "parallel_tool_calls", true)
|
||||
|
||||
// Convert Gemini thinkingConfig to Codex reasoning.effort.
|
||||
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
|
||||
effortSet := false
|
||||
if genConfig := root.Get("generationConfig"); genConfig.Exists() {
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
||||
if !thinkingLevel.Exists() {
|
||||
thinkingLevel = thinkingConfig.Get("thinking_level")
|
||||
}
|
||||
if thinkingLevel.Exists() {
|
||||
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
||||
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
}
|
||||
} else {
|
||||
thinkingBudget := thinkingConfig.Get("thinkingBudget")
|
||||
if !thinkingBudget.Exists() {
|
||||
thinkingBudget = thinkingConfig.Get("thinking_budget")
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
}
|
||||
} else if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning.effort", effort)
|
||||
effortSet = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -7,6 +7,8 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
@@ -27,7 +29,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in OpenAI Responses API format
|
||||
func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
// Start with empty JSON object
|
||||
out := `{"instructions":""}`
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -8,13 +9,7 @@ import (
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToCodex(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
|
||||
inputResult := gjson.GetBytes(rawJSON, "input")
|
||||
if inputResult.Type == gjson.String {
|
||||
input, _ := sjson.Set(`[{"type":"message","role":"user","content":[{"type":"input_text","text":""}]}]`, "0.content.0.text", inputResult.String())
|
||||
rawJSON, _ = sjson.SetRawBytes(rawJSON, "input", []byte(input))
|
||||
}
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "stream", true)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "store", false)
|
||||
|
||||
@@ -35,7 +35,7 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
@@ -116,19 +116,6 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
part, _ = sjson.Set(part, "functionResponse.name", funcName)
|
||||
part, _ = sjson.Set(part, "functionResponse.response.result", responseData)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
|
||||
case "image":
|
||||
source := contentResult.Get("source")
|
||||
if source.Get("type").String() == "base64" {
|
||||
mimeType := source.Get("media_type").String()
|
||||
data := source.Get("data").String()
|
||||
if mimeType != "" && data != "" {
|
||||
part := `{"inlineData":{"mime_type":"","data":""}}`
|
||||
part, _ = sjson.Set(part, "inlineData.mime_type", mimeType)
|
||||
part, _ = sjson.Set(part, "inlineData.data", data)
|
||||
contentJSON, _ = sjson.SetRaw(contentJSON, "parts.-1", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -32,7 +33,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
template := ""
|
||||
template = `{"project":"","request":{},"model":""}`
|
||||
template, _ = sjson.SetRaw(template, "request", string(rawJSON))
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -27,7 +28,7 @@ const geminiCLIFunctionThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
// Base envelope (no default thinkingConfig)
|
||||
out := []byte(`{"project":"","request":{"contents":[]},"model":"gemini-2.5-pro"}`)
|
||||
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -86,7 +85,6 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
@@ -99,14 +97,6 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
|
||||
@@ -1,12 +1,14 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini-cli/gemini"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/responses"
|
||||
)
|
||||
|
||||
func ConvertOpenAIResponsesRequestToGeminiCLI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON = ConvertOpenAIResponsesRequestToGemini(modelName, rawJSON, stream)
|
||||
return ConvertGeminiRequestToGeminiCLI(modelName, rawJSON, stream)
|
||||
}
|
||||
|
||||
@@ -28,7 +28,7 @@ const geminiClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request in Gemini CLI format.
|
||||
func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -18,7 +19,7 @@ import (
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the internal client.
|
||||
func ConvertGeminiCLIRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
modelResult := gjson.GetBytes(rawJSON, "model")
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelResult.String())
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -18,7 +19,7 @@ import (
|
||||
//
|
||||
// It keeps the payload otherwise unchanged.
|
||||
func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
// Fast path: if no contents field, only attach safety settings
|
||||
contents := gjson.GetBytes(rawJSON, "contents")
|
||||
if !contents.Exists() {
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
@@ -27,7 +28,7 @@ const geminiFunctionThoughtSignature = "skip_thought_signature_validator"
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in Gemini API format
|
||||
func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
// Base envelope (no default thinkingConfig)
|
||||
out := []byte(`{"contents":[]}`)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
@@ -11,7 +12,7 @@ import (
|
||||
const geminiResponsesThoughtSignature = "skip_thought_signature_validator"
|
||||
|
||||
func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
|
||||
// Note: modelName and stream parameters are part of the fixed method signature
|
||||
_ = modelName // Unused but required by interface
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
@@ -17,7 +18,7 @@ import (
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the OpenAI API.
|
||||
func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
// Base OpenAI Chat Completions API template
|
||||
out := `{"model":"","messages":[]}`
|
||||
|
||||
@@ -60,10 +61,13 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
out, _ = sjson.Set(out, "stream", stream)
|
||||
|
||||
// Thinking: Convert Claude thinking.budget_tokens to OpenAI reasoning_effort
|
||||
// Also track if thinking is enabled to ensure reasoning_content is added for tool_calls
|
||||
thinkingEnabled := false
|
||||
if thinkingConfig := root.Get("thinking"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingType := thinkingConfig.Get("type"); thinkingType.Exists() {
|
||||
switch thinkingType.String() {
|
||||
case "enabled":
|
||||
thinkingEnabled = true
|
||||
if budgetTokens := thinkingConfig.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(budget); ok && effort != "" {
|
||||
@@ -216,6 +220,10 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
// Add reasoning_content if present
|
||||
if hasReasoning {
|
||||
msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", reasoningContent)
|
||||
} else if thinkingEnabled && hasToolCalls {
|
||||
// Claude API requires reasoning_content in assistant messages with tool_calls
|
||||
// when thinking mode is enabled, even if empty
|
||||
msgJSON, _ = sjson.Set(msgJSON, "reasoning_content", "")
|
||||
}
|
||||
|
||||
// Add tool_calls if present (in same message as content)
|
||||
|
||||
@@ -588,3 +588,124 @@ func TestConvertClaudeRequestToOpenAI_AssistantThinkingToolUseThinkingSplit(t *t
|
||||
t.Fatalf("Expected reasoning_content %q, got %q", "t1\n\nt2", got)
|
||||
}
|
||||
}
|
||||
|
||||
// TestConvertClaudeRequestToOpenAI_ThinkingEnabledToolCallsNoReasoning tests that
|
||||
// when thinking mode is enabled and assistant message has tool_calls but no thinking content,
|
||||
// an empty reasoning_content is added to satisfy Claude API requirements.
|
||||
func TestConvertClaudeRequestToOpenAI_ThinkingEnabledToolCallsNoReasoning(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
inputJSON string
|
||||
wantHasReasoningContent bool
|
||||
wantReasoningContent string
|
||||
}{
|
||||
{
|
||||
name: "thinking enabled with tool_calls but no thinking content adds empty reasoning_content",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"thinking": {"type": "enabled", "budget_tokens": 4000},
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "I will help you."},
|
||||
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantHasReasoningContent: true,
|
||||
wantReasoningContent: "",
|
||||
},
|
||||
{
|
||||
name: "thinking enabled with tool_calls and thinking content uses actual reasoning",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"thinking": {"type": "enabled", "budget_tokens": 4000},
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me analyze this..."},
|
||||
{"type": "text", "text": "I will help you."},
|
||||
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantHasReasoningContent: true,
|
||||
wantReasoningContent: "Let me analyze this...",
|
||||
},
|
||||
{
|
||||
name: "thinking disabled with tool_calls does not add reasoning_content",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"thinking": {"type": "disabled"},
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "I will help you."},
|
||||
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantHasReasoningContent: false,
|
||||
wantReasoningContent: "",
|
||||
},
|
||||
{
|
||||
name: "no thinking config with tool_calls does not add reasoning_content",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "I will help you."},
|
||||
{"type": "tool_use", "id": "tool_1", "name": "read_file", "input": {"path": "/test.txt"}}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantHasReasoningContent: false,
|
||||
wantReasoningContent: "",
|
||||
},
|
||||
{
|
||||
name: "thinking enabled without tool_calls and no thinking content does not add reasoning_content",
|
||||
inputJSON: `{
|
||||
"model": "claude-3-opus",
|
||||
"thinking": {"type": "enabled", "budget_tokens": 4000},
|
||||
"messages": [{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Simple response without tools."}
|
||||
]
|
||||
}]
|
||||
}`,
|
||||
wantHasReasoningContent: false,
|
||||
wantReasoningContent: "",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertClaudeRequestToOpenAI("test-model", []byte(tt.inputJSON), false)
|
||||
resultJSON := gjson.ParseBytes(result)
|
||||
|
||||
messages := resultJSON.Get("messages").Array()
|
||||
if len(messages) == 0 {
|
||||
t.Fatal("Expected at least one message")
|
||||
}
|
||||
|
||||
assistantMsg := messages[0]
|
||||
if assistantMsg.Get("role").String() != "assistant" {
|
||||
t.Fatalf("Expected assistant message, got %s", assistantMsg.Get("role").String())
|
||||
}
|
||||
|
||||
hasReasoningContent := assistantMsg.Get("reasoning_content").Exists()
|
||||
if hasReasoningContent != tt.wantHasReasoningContent {
|
||||
t.Errorf("reasoning_content existence = %v, want %v", hasReasoningContent, tt.wantHasReasoningContent)
|
||||
}
|
||||
|
||||
if hasReasoningContent {
|
||||
gotReasoningContent := assistantMsg.Get("reasoning_content").String()
|
||||
if gotReasoningContent != tt.wantReasoningContent {
|
||||
t.Errorf("reasoning_content = %q, want %q", gotReasoningContent, tt.wantReasoningContent)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
@@ -6,6 +6,8 @@
|
||||
package geminiCLI
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/openai/gemini"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -15,7 +17,7 @@ import (
|
||||
// It extracts the model name, generation config, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the OpenAI API.
|
||||
func ConvertGeminiCLIRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON = []byte(gjson.GetBytes(rawJSON, "request").Raw)
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, "model", modelName)
|
||||
if gjson.GetBytes(rawJSON, "systemInstruction").Exists() {
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"math/big"
|
||||
@@ -20,7 +21,7 @@ import (
|
||||
// It extracts the model name, generation config, message contents, and tool declarations
|
||||
// from the raw JSON request and returns them in the format expected by the OpenAI API.
|
||||
func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
// Base OpenAI Chat Completions API template
|
||||
out := `{"model":"","messages":[]}`
|
||||
|
||||
@@ -82,27 +83,16 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
}
|
||||
|
||||
// Map Gemini thinkingConfig to OpenAI reasoning_effort.
|
||||
// Always perform conversion to support allowCompat models that may not be in registry.
|
||||
// Note: Google official Python SDK sends snake_case fields (thinking_level/thinking_budget).
|
||||
// Always perform conversion to support allowCompat models that may not be in registry
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
thinkingLevel := thinkingConfig.Get("thinkingLevel")
|
||||
if !thinkingLevel.Exists() {
|
||||
thinkingLevel = thinkingConfig.Get("thinking_level")
|
||||
}
|
||||
if thinkingLevel.Exists() {
|
||||
if thinkingLevel := thinkingConfig.Get("thinkingLevel"); thinkingLevel.Exists() {
|
||||
effort := strings.ToLower(strings.TrimSpace(thinkingLevel.String()))
|
||||
if effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
} else {
|
||||
thinkingBudget := thinkingConfig.Get("thinkingBudget")
|
||||
if !thinkingBudget.Exists() {
|
||||
thinkingBudget = thinkingConfig.Get("thinking_budget")
|
||||
}
|
||||
if thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
} else if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
if effort, ok := thinking.ConvertBudgetToLevel(int(thinkingBudget.Int())); ok {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package chat_completions
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
@@ -24,7 +25,7 @@ func ConvertOpenAIRequestToOpenAI(modelName string, inputRawJSON []byte, _ bool)
|
||||
// If there's an error, return the original JSON or handle the error appropriately.
|
||||
// For now, we'll return the original, but in a real scenario, logging or a more robust error
|
||||
// handling mechanism would be needed.
|
||||
return inputRawJSON
|
||||
return bytes.Clone(inputRawJSON)
|
||||
}
|
||||
return updatedJSON
|
||||
}
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package responses
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -27,7 +28,7 @@ import (
|
||||
// Returns:
|
||||
// - []byte: The transformed request data in OpenAI chat completions format
|
||||
func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inputRawJSON []byte, stream bool) []byte {
|
||||
rawJSON := inputRawJSON
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
// Base OpenAI chat completions template with default values
|
||||
out := `{"model":"","messages":[],"stream":false}`
|
||||
|
||||
@@ -67,9 +68,6 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
case "message", "":
|
||||
// Handle regular message conversion
|
||||
role := item.Get("role").String()
|
||||
if role == "developer" {
|
||||
role = "user"
|
||||
}
|
||||
message := `{"role":"","content":""}`
|
||||
message, _ = sjson.Set(message, "role", role)
|
||||
|
||||
@@ -169,8 +167,7 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
// Only function tools need structural conversion because Chat Completions nests details under "function".
|
||||
toolType := tool.Get("type").String()
|
||||
if toolType != "" && toolType != "function" && tool.IsObject() {
|
||||
// Almost all providers lack built-in tools, so we just ignore them.
|
||||
// chatCompletionsTools = append(chatCompletionsTools, tool.Value())
|
||||
chatCompletionsTools = append(chatCompletionsTools, tool.Value())
|
||||
return true
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ func TestIsClaudeThinkingModel(t *testing.T) {
|
||||
// Claude thinking models - should return true
|
||||
{"claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
|
||||
{"claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
|
||||
{"claude-opus-4-6-thinking", "claude-opus-4-6-thinking", true},
|
||||
{"Claude-Sonnet-Thinking uppercase", "Claude-Sonnet-4-5-Thinking", true},
|
||||
{"claude thinking mixed case", "Claude-THINKING-Model", true},
|
||||
|
||||
|
||||
@@ -61,20 +61,14 @@ func cleanJSONSchema(jsonStr string, addPlaceholder bool) string {
|
||||
|
||||
// removeKeywords removes all occurrences of specified keywords from the JSON schema.
|
||||
func removeKeywords(jsonStr string, keywords []string) string {
|
||||
deletePaths := make([]string, 0)
|
||||
pathsByField := findPathsByFields(jsonStr, keywords)
|
||||
for _, key := range keywords {
|
||||
for _, p := range pathsByField[key] {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||
continue
|
||||
}
|
||||
deletePaths = append(deletePaths, p)
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
}
|
||||
sortByDepth(deletePaths)
|
||||
for _, p := range deletePaths {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
@@ -241,9 +235,8 @@ var unsupportedConstraints = []string{
|
||||
}
|
||||
|
||||
func moveConstraintsToDescription(jsonStr string) string {
|
||||
pathsByField := findPathsByFields(jsonStr, unsupportedConstraints)
|
||||
for _, key := range unsupportedConstraints {
|
||||
for _, p := range pathsByField[key] {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
val := gjson.Get(jsonStr, p)
|
||||
if !val.Exists() || val.IsObject() || val.IsArray() {
|
||||
continue
|
||||
@@ -431,21 +424,14 @@ func removeUnsupportedKeywords(jsonStr string) string {
|
||||
"$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
|
||||
"propertyNames", // Gemini doesn't support property name validation
|
||||
)
|
||||
|
||||
deletePaths := make([]string, 0)
|
||||
pathsByField := findPathsByFields(jsonStr, keywords)
|
||||
for _, key := range keywords {
|
||||
for _, p := range pathsByField[key] {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||
continue
|
||||
}
|
||||
deletePaths = append(deletePaths, p)
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
}
|
||||
sortByDepth(deletePaths)
|
||||
for _, p := range deletePaths {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
// Remove x-* extension fields (e.g., x-google-enum-descriptions) that are not supported by Gemini API
|
||||
jsonStr = removeExtensionFields(jsonStr)
|
||||
return jsonStr
|
||||
@@ -595,42 +581,6 @@ func findPaths(jsonStr, field string) []string {
|
||||
return paths
|
||||
}
|
||||
|
||||
func findPathsByFields(jsonStr string, fields []string) map[string][]string {
|
||||
set := make(map[string]struct{}, len(fields))
|
||||
for _, field := range fields {
|
||||
set[field] = struct{}{}
|
||||
}
|
||||
paths := make(map[string][]string, len(set))
|
||||
walkForFields(gjson.Parse(jsonStr), "", set, paths)
|
||||
return paths
|
||||
}
|
||||
|
||||
func walkForFields(value gjson.Result, path string, fields map[string]struct{}, paths map[string][]string) {
|
||||
switch value.Type {
|
||||
case gjson.JSON:
|
||||
value.ForEach(func(key, val gjson.Result) bool {
|
||||
keyStr := key.String()
|
||||
safeKey := escapeGJSONPathKey(keyStr)
|
||||
|
||||
var childPath string
|
||||
if path == "" {
|
||||
childPath = safeKey
|
||||
} else {
|
||||
childPath = path + "." + safeKey
|
||||
}
|
||||
|
||||
if _, ok := fields[keyStr]; ok {
|
||||
paths[keyStr] = append(paths[keyStr], childPath)
|
||||
}
|
||||
|
||||
walkForFields(val, childPath, fields, paths)
|
||||
return true
|
||||
})
|
||||
case gjson.String, gjson.Number, gjson.True, gjson.False, gjson.Null:
|
||||
// Terminal types - no further traversal needed
|
||||
}
|
||||
}
|
||||
|
||||
func sortByDepth(paths []string) {
|
||||
sort.Slice(paths, func(i, j int) bool { return len(paths[i]) > len(paths[j]) })
|
||||
}
|
||||
@@ -717,9 +667,6 @@ func orDefault(val, def string) string {
|
||||
}
|
||||
|
||||
func escapeGJSONPathKey(key string) string {
|
||||
if strings.IndexAny(key, ".*?") == -1 {
|
||||
return key
|
||||
}
|
||||
return gjsonPathKeyReplacer.Replace(key)
|
||||
}
|
||||
|
||||
|
||||
@@ -6,6 +6,7 @@ package util
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -32,15 +33,15 @@ func Walk(value gjson.Result, path, field string, paths *[]string) {
|
||||
// . -> \.
|
||||
// * -> \*
|
||||
// ? -> \?
|
||||
keyStr := key.String()
|
||||
safeKey := escapeGJSONPathKey(keyStr)
|
||||
var keyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||
safeKey := keyReplacer.Replace(key.String())
|
||||
|
||||
if path == "" {
|
||||
childPath = safeKey
|
||||
} else {
|
||||
childPath = path + "." + safeKey
|
||||
}
|
||||
if keyStr == field {
|
||||
if key.String() == field {
|
||||
*paths = append(*paths, childPath)
|
||||
}
|
||||
Walk(val, childPath, field, paths)
|
||||
@@ -86,6 +87,15 @@ func RenameKey(jsonStr, oldKeyPath, newKeyPath string) (string, error) {
|
||||
return finalJson, nil
|
||||
}
|
||||
|
||||
func DeleteKey(jsonStr, keyName string) string {
|
||||
paths := make([]string, 0)
|
||||
Walk(gjson.Parse(jsonStr), "", keyName, &paths)
|
||||
for _, p := range paths {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// FixJSON converts non-standard JSON that uses single quotes for strings into
|
||||
// RFC 8259-compliant JSON by converting those single-quoted strings to
|
||||
// double-quoted strings with proper escaping.
|
||||
|
||||
@@ -6,7 +6,6 @@ import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
@@ -16,7 +15,6 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
@@ -74,7 +72,6 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
w.lastAuthHashes = make(map[string]string)
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
||||
} else if resolvedAuthDir != "" {
|
||||
@@ -87,11 +84,6 @@ func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string
|
||||
sum := sha256.Sum256(data)
|
||||
normalizedPath := w.normalizeAuthPath(path)
|
||||
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
|
||||
// Parse and cache auth content for future diff comparisons
|
||||
var auth coreauth.Auth
|
||||
if errParse := json.Unmarshal(data, &auth); errParse == nil {
|
||||
w.lastAuthContents[normalizedPath] = &auth
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
@@ -135,13 +127,6 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
curHash := hex.EncodeToString(sum[:])
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
|
||||
// Parse new auth content for diff comparison
|
||||
var newAuth coreauth.Auth
|
||||
if errParse := json.Unmarshal(data, &newAuth); errParse != nil {
|
||||
log.Errorf("failed to parse auth file %s: %v", filepath.Base(path), errParse)
|
||||
return
|
||||
}
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
@@ -156,26 +141,7 @@ func (w *Watcher) addOrUpdateClient(path string) {
|
||||
return
|
||||
}
|
||||
|
||||
// Get old auth for diff comparison
|
||||
var oldAuth *coreauth.Auth
|
||||
if w.lastAuthContents != nil {
|
||||
oldAuth = w.lastAuthContents[normalized]
|
||||
}
|
||||
|
||||
// Compute and log field changes
|
||||
if changes := diff.BuildAuthChangeDetails(oldAuth, &newAuth); len(changes) > 0 {
|
||||
log.Debugf("auth field changes for %s:", filepath.Base(path))
|
||||
for _, c := range changes {
|
||||
log.Debugf(" %s", c)
|
||||
}
|
||||
}
|
||||
|
||||
// Update caches
|
||||
w.lastAuthHashes[normalized] = curHash
|
||||
if w.lastAuthContents == nil {
|
||||
w.lastAuthContents = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
w.lastAuthContents[normalized] = &newAuth
|
||||
|
||||
w.clientsMutex.Unlock() // Unlock before the callback
|
||||
|
||||
@@ -194,7 +160,6 @@ func (w *Watcher) removeClient(path string) {
|
||||
|
||||
cfg := w.config
|
||||
delete(w.lastAuthHashes, normalized)
|
||||
delete(w.lastAuthContents, normalized)
|
||||
|
||||
w.clientsMutex.Unlock() // Release the lock before the callback
|
||||
|
||||
|
||||
@@ -1,44 +0,0 @@
|
||||
// auth_diff.go computes human-readable diffs for auth file field changes.
|
||||
package diff
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// BuildAuthChangeDetails computes a redacted, human-readable list of auth field changes.
|
||||
// Only prefix, proxy_url, and disabled fields are tracked; sensitive data is never printed.
|
||||
func BuildAuthChangeDetails(oldAuth, newAuth *coreauth.Auth) []string {
|
||||
changes := make([]string, 0, 3)
|
||||
|
||||
// Handle nil cases by using empty Auth as default
|
||||
if oldAuth == nil {
|
||||
oldAuth = &coreauth.Auth{}
|
||||
}
|
||||
if newAuth == nil {
|
||||
return changes
|
||||
}
|
||||
|
||||
// Compare prefix
|
||||
oldPrefix := strings.TrimSpace(oldAuth.Prefix)
|
||||
newPrefix := strings.TrimSpace(newAuth.Prefix)
|
||||
if oldPrefix != newPrefix {
|
||||
changes = append(changes, fmt.Sprintf("prefix: %s -> %s", oldPrefix, newPrefix))
|
||||
}
|
||||
|
||||
// Compare proxy_url (redacted)
|
||||
oldProxy := strings.TrimSpace(oldAuth.ProxyURL)
|
||||
newProxy := strings.TrimSpace(newAuth.ProxyURL)
|
||||
if oldProxy != newProxy {
|
||||
changes = append(changes, fmt.Sprintf("proxy_url: %s -> %s", formatProxyURL(oldProxy), formatProxyURL(newProxy)))
|
||||
}
|
||||
|
||||
// Compare disabled
|
||||
if oldAuth.Disabled != newAuth.Disabled {
|
||||
changes = append(changes, fmt.Sprintf("disabled: %t -> %t", oldAuth.Disabled, newAuth.Disabled))
|
||||
}
|
||||
|
||||
return changes
|
||||
}
|
||||
@@ -27,12 +27,6 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.Debug != newCfg.Debug {
|
||||
changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug))
|
||||
}
|
||||
if oldCfg.Pprof.Enable != newCfg.Pprof.Enable {
|
||||
changes = append(changes, fmt.Sprintf("pprof.enable: %t -> %t", oldCfg.Pprof.Enable, newCfg.Pprof.Enable))
|
||||
}
|
||||
if strings.TrimSpace(oldCfg.Pprof.Addr) != strings.TrimSpace(newCfg.Pprof.Addr) {
|
||||
changes = append(changes, fmt.Sprintf("pprof.addr: %s -> %s", strings.TrimSpace(oldCfg.Pprof.Addr), strings.TrimSpace(newCfg.Pprof.Addr)))
|
||||
}
|
||||
if oldCfg.LoggingToFile != newCfg.LoggingToFile {
|
||||
changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile))
|
||||
}
|
||||
|
||||
@@ -38,7 +38,6 @@ type Watcher struct {
|
||||
reloadCallback func(*config.Config)
|
||||
watcher *fsnotify.Watcher
|
||||
lastAuthHashes map[string]string
|
||||
lastAuthContents map[string]*coreauth.Auth
|
||||
lastRemoveTimes map[string]time.Time
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
|
||||
@@ -155,6 +155,20 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
return map[string]any{idempotencyKeyMetadataKey: key}
|
||||
}
|
||||
|
||||
func mergeMetadata(base, overlay map[string]any) map[string]any {
|
||||
if len(base) == 0 && len(overlay) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(base)+len(overlay))
|
||||
for k, v := range base {
|
||||
out[k] = v
|
||||
}
|
||||
for k, v := range overlay {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// BaseAPIHandler contains the handlers for API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service and manages
|
||||
// load balancing, client selection, and configuration.
|
||||
@@ -241,16 +255,15 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
parentCtx = logging.WithRequestID(parentCtx, requestID)
|
||||
}
|
||||
}
|
||||
newCtx, cancel := context.WithCancel(parentCtx)
|
||||
if requestCtx != nil && requestCtx != parentCtx {
|
||||
go func() {
|
||||
select {
|
||||
case <-requestCtx.Done():
|
||||
cancel()
|
||||
case <-newCtx.Done():
|
||||
}
|
||||
}()
|
||||
|
||||
// Use requestCtx as base if available to preserve amp context values (fallback_models, etc.)
|
||||
// Falls back to parentCtx if no request context
|
||||
baseCtx := parentCtx
|
||||
if requestCtx != nil {
|
||||
baseCtx = requestCtx
|
||||
}
|
||||
|
||||
newCtx, cancel := context.WithCancel(baseCtx)
|
||||
newCtx = context.WithValue(newCtx, "gin", c)
|
||||
newCtx = context.WithValue(newCtx, "handler", handler)
|
||||
return newCtx, func(params ...interface{}) {
|
||||
@@ -377,18 +390,14 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
payload := rawJSON
|
||||
if len(payload) == 0 {
|
||||
payload = nil
|
||||
}
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: payload,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: false,
|
||||
Alt: alt,
|
||||
OriginalRequest: rawJSON,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = reqMeta
|
||||
@@ -408,7 +417,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
}
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
}
|
||||
return resp.Payload, nil
|
||||
return cloneBytes(resp.Payload), nil
|
||||
}
|
||||
|
||||
// ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
@@ -420,18 +429,14 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
payload := rawJSON
|
||||
if len(payload) == 0 {
|
||||
payload = nil
|
||||
}
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: payload,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: false,
|
||||
Alt: alt,
|
||||
OriginalRequest: rawJSON,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = reqMeta
|
||||
@@ -451,7 +456,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
}
|
||||
return nil, &interfaces.ErrorMessage{StatusCode: status, Error: err, Addon: addon}
|
||||
}
|
||||
return resp.Payload, nil
|
||||
return cloneBytes(resp.Payload), nil
|
||||
}
|
||||
|
||||
// ExecuteStreamWithAuthManager executes a streaming request via the core auth manager.
|
||||
@@ -466,18 +471,14 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
reqMeta[coreexecutor.RequestedModelMetadataKey] = normalizedModel
|
||||
payload := rawJSON
|
||||
if len(payload) == 0 {
|
||||
payload = nil
|
||||
}
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: payload,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
}
|
||||
opts := coreexecutor.Options{
|
||||
Stream: true,
|
||||
Alt: alt,
|
||||
OriginalRequest: rawJSON,
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
opts.Metadata = reqMeta
|
||||
@@ -666,6 +667,17 @@ func cloneBytes(src []byte) []byte {
|
||||
return dst
|
||||
}
|
||||
|
||||
func cloneMetadata(src map[string]any) map[string]any {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
}
|
||||
dst := make(map[string]any, len(src))
|
||||
for k, v := range src {
|
||||
dst[k] = v
|
||||
}
|
||||
return dst
|
||||
}
|
||||
|
||||
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message.
|
||||
func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -696,7 +708,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
|
||||
var previous []byte
|
||||
if existing, exists := c.Get("API_RESPONSE"); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||
previous = existingBytes
|
||||
previous = bytes.Clone(existingBytes)
|
||||
}
|
||||
}
|
||||
appendAPIResponse(c, body)
|
||||
|
||||
@@ -18,7 +18,6 @@ type ManagementTokenRequester interface {
|
||||
RequestCodexToken(*gin.Context)
|
||||
RequestAntigravityToken(*gin.Context)
|
||||
RequestQwenToken(*gin.Context)
|
||||
RequestKimiToken(*gin.Context)
|
||||
RequestIFlowToken(*gin.Context)
|
||||
RequestIFlowCookieToken(*gin.Context)
|
||||
GetAuthStatus(c *gin.Context)
|
||||
@@ -56,10 +55,6 @@ func (m *managementTokenRequester) RequestQwenToken(c *gin.Context) {
|
||||
m.handler.RequestQwenToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestKimiToken(c *gin.Context) {
|
||||
m.handler.RequestKimiToken(c)
|
||||
}
|
||||
|
||||
func (m *managementTokenRequester) RequestIFlowToken(c *gin.Context) {
|
||||
m.handler.RequestIFlowToken(c)
|
||||
}
|
||||
|
||||
123
sdk/auth/kimi.go
123
sdk/auth/kimi.go
@@ -1,123 +0,0 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/kimi"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// kimiRefreshLead is the duration before token expiry when refresh should occur.
|
||||
var kimiRefreshLead = 5 * time.Minute
|
||||
|
||||
// KimiAuthenticator implements the OAuth device flow login for Kimi (Moonshot AI).
|
||||
type KimiAuthenticator struct{}
|
||||
|
||||
// NewKimiAuthenticator constructs a new Kimi authenticator.
|
||||
func NewKimiAuthenticator() Authenticator {
|
||||
return &KimiAuthenticator{}
|
||||
}
|
||||
|
||||
// Provider returns the provider key for kimi.
|
||||
func (KimiAuthenticator) Provider() string {
|
||||
return "kimi"
|
||||
}
|
||||
|
||||
// RefreshLead returns the duration before token expiry when refresh should occur.
|
||||
// Kimi tokens expire and need to be refreshed before expiry.
|
||||
func (KimiAuthenticator) RefreshLead() *time.Duration {
|
||||
return &kimiRefreshLead
|
||||
}
|
||||
|
||||
// Login initiates the Kimi device flow authentication.
|
||||
func (a KimiAuthenticator) Login(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if cfg == nil {
|
||||
return nil, fmt.Errorf("cliproxy auth: configuration is required")
|
||||
}
|
||||
if opts == nil {
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
authSvc := kimi.NewKimiAuth(cfg)
|
||||
|
||||
// Start the device flow
|
||||
fmt.Println("Starting Kimi authentication...")
|
||||
deviceCode, err := authSvc.StartDeviceFlow(ctx)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: failed to start device flow: %w", err)
|
||||
}
|
||||
|
||||
// Display the verification URL
|
||||
verificationURL := deviceCode.VerificationURIComplete
|
||||
if verificationURL == "" {
|
||||
verificationURL = deviceCode.VerificationURI
|
||||
}
|
||||
|
||||
fmt.Printf("\nTo authenticate, please visit:\n%s\n\n", verificationURL)
|
||||
if deviceCode.UserCode != "" {
|
||||
fmt.Printf("User code: %s\n\n", deviceCode.UserCode)
|
||||
}
|
||||
|
||||
// Try to open the browser automatically
|
||||
if !opts.NoBrowser {
|
||||
if browser.IsAvailable() {
|
||||
if errOpen := browser.OpenURL(verificationURL); errOpen != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||
} else {
|
||||
fmt.Println("Browser opened automatically.")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fmt.Println("Waiting for authorization...")
|
||||
if deviceCode.ExpiresIn > 0 {
|
||||
fmt.Printf("(This will timeout in %d seconds if not authorized)\n", deviceCode.ExpiresIn)
|
||||
}
|
||||
|
||||
// Wait for user authorization
|
||||
authBundle, err := authSvc.WaitForAuthorization(ctx, deviceCode)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("kimi: %w", err)
|
||||
}
|
||||
|
||||
// Create the token storage
|
||||
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||
|
||||
// Build metadata with token information
|
||||
metadata := map[string]any{
|
||||
"type": "kimi",
|
||||
"access_token": authBundle.TokenData.AccessToken,
|
||||
"refresh_token": authBundle.TokenData.RefreshToken,
|
||||
"token_type": authBundle.TokenData.TokenType,
|
||||
"scope": authBundle.TokenData.Scope,
|
||||
"timestamp": time.Now().UnixMilli(),
|
||||
}
|
||||
|
||||
if authBundle.TokenData.ExpiresAt > 0 {
|
||||
exp := time.Unix(authBundle.TokenData.ExpiresAt, 0).UTC().Format(time.RFC3339)
|
||||
metadata["expired"] = exp
|
||||
}
|
||||
if strings.TrimSpace(authBundle.DeviceID) != "" {
|
||||
metadata["device_id"] = strings.TrimSpace(authBundle.DeviceID)
|
||||
}
|
||||
|
||||
// Generate a unique filename
|
||||
fileName := fmt.Sprintf("kimi-%d.json", time.Now().UnixMilli())
|
||||
|
||||
fmt.Println("\nKimi authentication successful!")
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Label: "Kimi User",
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user