mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
Compare commits
20 Commits
adedb16d35
...
9299897e04
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
9299897e04 | ||
|
|
527a269799 | ||
|
|
2fe0b6cd2d | ||
|
|
eeb1812d60 | ||
|
|
c82d8e250a | ||
|
|
73db4e64f6 | ||
|
|
69ca0a8fac | ||
|
|
3b04e11544 | ||
|
|
e0927afa40 | ||
|
|
f97d9f3e11 | ||
|
|
6d8609e457 | ||
|
|
d216adeffc | ||
|
|
bb09708c02 | ||
|
|
1150d972a1 | ||
|
|
13bb7cf704 | ||
|
|
8bce696a7c | ||
|
|
6db8d2a28e | ||
|
|
6da7ed53f2 | ||
|
|
fe6043aec7 | ||
|
|
bc32096e9c |
@@ -138,6 +138,10 @@ Windows desktop app built with Tauri + React for monitoring AI coding assistant
|
|||||||
|
|
||||||
A lightweight web admin panel for CLIProxyAPI with health checks, resource monitoring, real-time logs, auto-update, request statistics and pricing display. Supports one-click installation and systemd service.
|
A lightweight web admin panel for CLIProxyAPI with health checks, resource monitoring, real-time logs, auto-update, request statistics and pricing display. Supports one-click installation and systemd service.
|
||||||
|
|
||||||
|
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||||
|
|
||||||
|
A Windows tray application implemented using PowerShell scripts, without relying on any third-party libraries. The main features include: automatic creation of shortcuts, silent running, password management, channel switching (Main / Plus), and automatic downloading and updating.
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||||
|
|
||||||
|
|||||||
@@ -148,6 +148,10 @@ Windows 桌面应用,基于 Tauri + React 构建,用于通过 CLIProxyAPI
|
|||||||
|
|
||||||
基于 Next.js 的实现,灵感来自 CLIProxyAPI,易于安装使用;自研格式转换(OpenAI/Claude/Gemini/Ollama)、组合系统与自动回退、多账户管理(指数退避)、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
|
基于 Next.js 的实现,灵感来自 CLIProxyAPI,易于安装使用;自研格式转换(OpenAI/Claude/Gemini/Ollama)、组合系统与自动回退、多账户管理(指数退避)、Next.js Web 控制台,并支持 Cursor、Claude Code、Cline、RooCode 等 CLI 工具,无需 API 密钥。
|
||||||
|
|
||||||
|
### [CLIProxyAPI Tray](https://github.com/kitephp/CLIProxyAPI_Tray)
|
||||||
|
|
||||||
|
Windows 托盘应用,基于 PowerShell 脚本实现,不依赖任何第三方库。主要功能包括:自动创建快捷方式、静默运行、密码管理、通道切换(Main / Plus)以及自动下载与更新。
|
||||||
|
|
||||||
> [!NOTE]
|
> [!NOTE]
|
||||||
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
|
> 如果你开发了 CLIProxyAPI 的移植或衍生项目,请提交 PR 将其添加到此列表中。
|
||||||
|
|
||||||
|
|||||||
@@ -50,6 +50,10 @@ logging-to-file: false
|
|||||||
# files are deleted until within the limit. Set to 0 to disable.
|
# files are deleted until within the limit. Set to 0 to disable.
|
||||||
logs-max-total-size-mb: 0
|
logs-max-total-size-mb: 0
|
||||||
|
|
||||||
|
# Maximum number of error log files retained when request logging is disabled.
|
||||||
|
# When exceeded, the oldest error log files are deleted. Default is 10. Set to 0 to disable cleanup.
|
||||||
|
error-logs-max-files: 10
|
||||||
|
|
||||||
# When false, disable in-memory usage statistics aggregation
|
# When false, disable in-memory usage statistics aggregation
|
||||||
usage-statistics-enabled: false
|
usage-statistics-enabled: false
|
||||||
|
|
||||||
@@ -285,24 +289,31 @@ oauth-model-alias:
|
|||||||
# default: # Default rules only set parameters when they are missing in the payload.
|
# default: # Default rules only set parameters when they are missing in the payload.
|
||||||
# - models:
|
# - models:
|
||||||
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
||||||
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||||
# params: # JSON path (gjson/sjson syntax) -> value
|
# params: # JSON path (gjson/sjson syntax) -> value
|
||||||
# "generationConfig.thinkingConfig.thinkingBudget": 32768
|
# "generationConfig.thinkingConfig.thinkingBudget": 32768
|
||||||
# default-raw: # Default raw rules set parameters using raw JSON when missing (must be valid JSON).
|
# default-raw: # Default raw rules set parameters using raw JSON when missing (must be valid JSON).
|
||||||
# - models:
|
# - models:
|
||||||
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
||||||
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||||
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
|
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
|
||||||
# "generationConfig.responseJsonSchema": "{\"type\":\"object\",\"properties\":{\"answer\":{\"type\":\"string\"}}}"
|
# "generationConfig.responseJsonSchema": "{\"type\":\"object\",\"properties\":{\"answer\":{\"type\":\"string\"}}}"
|
||||||
# override: # Override rules always set parameters, overwriting any existing values.
|
# override: # Override rules always set parameters, overwriting any existing values.
|
||||||
# - models:
|
# - models:
|
||||||
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
||||||
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||||
# params: # JSON path (gjson/sjson syntax) -> value
|
# params: # JSON path (gjson/sjson syntax) -> value
|
||||||
# "reasoning.effort": "high"
|
# "reasoning.effort": "high"
|
||||||
# override-raw: # Override raw rules always set parameters using raw JSON (must be valid JSON).
|
# override-raw: # Override raw rules always set parameters using raw JSON (must be valid JSON).
|
||||||
# - models:
|
# - models:
|
||||||
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
# - name: "gpt-*" # Supports wildcards (e.g., "gpt-*")
|
||||||
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||||
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
|
# params: # JSON path (gjson/sjson syntax) -> raw JSON value (strings are used as-is, must be valid JSON)
|
||||||
# "response_format": "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"answer\",\"schema\":{\"type\":\"object\"}}}"
|
# "response_format": "{\"type\":\"json_schema\",\"json_schema\":{\"name\":\"answer\",\"schema\":{\"type\":\"object\"}}}"
|
||||||
|
# filter: # Filter rules remove specified parameters from the payload.
|
||||||
|
# - models:
|
||||||
|
# - name: "gemini-2.5-pro" # Supports wildcards (e.g., "gemini-*")
|
||||||
|
# protocol: "gemini" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex, antigravity
|
||||||
|
# params: # JSON paths (gjson/sjson syntax) to remove from the payload
|
||||||
|
# - "generationConfig.thinkingConfig.thinkingBudget"
|
||||||
|
# - "generationConfig.responseJsonSchema"
|
||||||
|
|||||||
@@ -205,7 +205,7 @@ func main() {
|
|||||||
// Optional: add a simple middleware + custom request logger
|
// Optional: add a simple middleware + custom request logger
|
||||||
api.WithMiddleware(func(c *gin.Context) { c.Header("X-Example", "custom-provider"); c.Next() }),
|
api.WithMiddleware(func(c *gin.Context) { c.Header("X-Example", "custom-provider"); c.Next() }),
|
||||||
api.WithRequestLoggerFactory(func(cfg *config.Config, cfgPath string) logging.RequestLogger {
|
api.WithRequestLoggerFactory(func(cfg *config.Config, cfgPath string) logging.RequestLogger {
|
||||||
return logging.NewFileRequestLogger(true, "logs", filepath.Dir(cfgPath))
|
return logging.NewFileRequestLoggerWithOptions(true, "logs", filepath.Dir(cfgPath), cfg.ErrorLogsMaxFiles)
|
||||||
}),
|
}),
|
||||||
).
|
).
|
||||||
WithHooks(hooks).
|
WithHooks(hooks).
|
||||||
|
|||||||
@@ -222,6 +222,26 @@ func (h *Handler) PutLogsMaxTotalSizeMB(c *gin.Context) {
|
|||||||
h.persist(c)
|
h.persist(c)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ErrorLogsMaxFiles
|
||||||
|
func (h *Handler) GetErrorLogsMaxFiles(c *gin.Context) {
|
||||||
|
c.JSON(200, gin.H{"error-logs-max-files": h.cfg.ErrorLogsMaxFiles})
|
||||||
|
}
|
||||||
|
func (h *Handler) PutErrorLogsMaxFiles(c *gin.Context) {
|
||||||
|
var body struct {
|
||||||
|
Value *int `json:"value"`
|
||||||
|
}
|
||||||
|
if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil {
|
||||||
|
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
value := *body.Value
|
||||||
|
if value < 0 {
|
||||||
|
value = 10
|
||||||
|
}
|
||||||
|
h.cfg.ErrorLogsMaxFiles = value
|
||||||
|
h.persist(c)
|
||||||
|
}
|
||||||
|
|
||||||
// Request log
|
// Request log
|
||||||
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
|
func (h *Handler) GetRequestLog(c *gin.Context) { c.JSON(200, gin.H{"request-log": h.cfg.RequestLog}) }
|
||||||
func (h *Handler) PutRequestLog(c *gin.Context) {
|
func (h *Handler) PutRequestLog(c *gin.Context) {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -32,11 +33,13 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
// MappedModelContextKey is the Gin context key for passing mapped model names.
|
||||||
const MappedModelContextKey = "mapped_model"
|
// Deprecated: Use ctxkeys.MappedModel instead.
|
||||||
|
const MappedModelContextKey = string(ctxkeys.MappedModel)
|
||||||
|
|
||||||
// FallbackModelsContextKey is the Gin context key for passing fallback model names.
|
// FallbackModelsContextKey is the Gin context key for passing fallback model names.
|
||||||
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
|
// When the primary mapped model fails (e.g., quota exceeded), these models can be tried.
|
||||||
const FallbackModelsContextKey = "fallback_models"
|
// Deprecated: Use ctxkeys.FallbackModels instead.
|
||||||
|
const FallbackModelsContextKey = string(ctxkeys.FallbackModels)
|
||||||
|
|
||||||
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
||||||
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
||||||
@@ -83,6 +86,10 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
|
|||||||
|
|
||||||
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
||||||
// when the model's provider is not available in CLIProxyAPI
|
// when the model's provider is not available in CLIProxyAPI
|
||||||
|
//
|
||||||
|
// Deprecated: FallbackHandler is deprecated in favor of routing.ModelRoutingWrapper.
|
||||||
|
// Use routing.NewModelRoutingWrapper() instead for unified routing logic.
|
||||||
|
// This type is kept for backward compatibility and test purposes.
|
||||||
type FallbackHandler struct {
|
type FallbackHandler struct {
|
||||||
getProxy func() *httputil.ReverseProxy
|
getProxy func() *httputil.ReverseProxy
|
||||||
modelMapper ModelMapper
|
modelMapper ModelMapper
|
||||||
@@ -91,6 +98,8 @@ type FallbackHandler struct {
|
|||||||
|
|
||||||
// NewFallbackHandler creates a new fallback handler wrapper
|
// NewFallbackHandler creates a new fallback handler wrapper
|
||||||
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
|
||||||
|
//
|
||||||
|
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
|
||||||
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
|
||||||
return &FallbackHandler{
|
return &FallbackHandler{
|
||||||
getProxy: getProxy,
|
getProxy: getProxy,
|
||||||
@@ -99,6 +108,8 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler
|
|||||||
}
|
}
|
||||||
|
|
||||||
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
||||||
|
//
|
||||||
|
// Deprecated: Use routing.NewModelRoutingWrapper() instead.
|
||||||
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
|
||||||
if forceModelMappings == nil {
|
if forceModelMappings == nil {
|
||||||
forceModelMappings = func() bool { return false }
|
forceModelMappings = func() bool { return false }
|
||||||
@@ -119,7 +130,11 @@ func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
|
|||||||
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
||||||
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
// Swallow ErrAbortHandler panics from ReverseProxy copyResponse to avoid noisy stack traces
|
// Swallow ErrAbortHandler panics from ReverseProxy to avoid noisy stack traces.
|
||||||
|
// ReverseProxy raises this panic when the client connection is closed prematurely
|
||||||
|
// (e.g., user cancels request, network disconnect) or when ServeHTTP is called
|
||||||
|
// with a ResponseWriter that doesn't implement http.CloseNotifier.
|
||||||
|
// This is an expected error condition, not a bug, so we handle it gracefully.
|
||||||
defer func() {
|
defer func() {
|
||||||
if rec := recover(); rec != nil {
|
if rec := recover(); rec != nil {
|
||||||
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
if err, ok := rec.(error); ok && errors.Is(err, http.ErrAbortHandler) {
|
||||||
@@ -216,6 +231,19 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
usedMapping := false
|
usedMapping := false
|
||||||
var providers []string
|
var providers []string
|
||||||
|
|
||||||
|
// Helper to apply model mapping and update state
|
||||||
|
applyMapping := func(mappedModels []string, mappedProviders []string) {
|
||||||
|
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
c.Set(string(ctxkeys.MappedModel), mappedModels[0])
|
||||||
|
if len(mappedModels) > 1 {
|
||||||
|
c.Set(string(ctxkeys.FallbackModels), mappedModels[1:])
|
||||||
|
}
|
||||||
|
resolvedModel = mappedModels[0]
|
||||||
|
usedMapping = true
|
||||||
|
providers = mappedProviders
|
||||||
|
}
|
||||||
|
|
||||||
// Check if model mappings should be forced ahead of local API keys
|
// Check if model mappings should be forced ahead of local API keys
|
||||||
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
|
||||||
|
|
||||||
@@ -223,17 +251,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||||
// This allows users to route Amp requests to their preferred OAuth providers
|
// This allows users to route Amp requests to their preferred OAuth providers
|
||||||
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||||
// Mapping found and provider available - rewrite the model in request body
|
applyMapping(mappedModels, mappedProviders)
|
||||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
// Store mapped model and fallbacks in context for handlers
|
|
||||||
c.Set(MappedModelContextKey, mappedModels[0])
|
|
||||||
if len(mappedModels) > 1 {
|
|
||||||
c.Set(FallbackModelsContextKey, mappedModels[1:])
|
|
||||||
}
|
|
||||||
resolvedModel = mappedModels[0]
|
|
||||||
usedMapping = true
|
|
||||||
providers = mappedProviders
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// If no mapping applied, check for local providers
|
// If no mapping applied, check for local providers
|
||||||
@@ -247,17 +265,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
|||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
// No providers configured - check if we have a model mapping
|
// No providers configured - check if we have a model mapping
|
||||||
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
if mappedModels, mappedProviders := resolveMappedModels(); len(mappedModels) > 0 {
|
||||||
// Mapping found and provider available - rewrite the model in request body
|
applyMapping(mappedModels, mappedProviders)
|
||||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModels[0])
|
|
||||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
|
||||||
// Store mapped model and fallbacks in context for handlers
|
|
||||||
c.Set(MappedModelContextKey, mappedModels[0])
|
|
||||||
if len(mappedModels) > 1 {
|
|
||||||
c.Set(FallbackModelsContextKey, mappedModels[1:])
|
|
||||||
}
|
|
||||||
resolvedModel = mappedModels[0]
|
|
||||||
usedMapping = true
|
|
||||||
providers = mappedProviders
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -0,0 +1,326 @@
|
|||||||
|
package amp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"net/http/httputil"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/testutil"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Characterization tests for fallback_handlers.go using testutil recorders
|
||||||
|
// These tests capture existing behavior before refactoring to routing layer
|
||||||
|
|
||||||
|
func TestCharacterization_LocalProvider(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Register a mock provider for the test model
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("char-test-local", "anthropic", []*registry.ModelInfo{
|
||||||
|
{ID: "test-model-local"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("char-test-local")
|
||||||
|
|
||||||
|
// Setup recorders
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create gin context
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
body := `{"model": "test-model-local", "messages": [{"role": "user", "content": "hello"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Create fallback handler with proxy recorder
|
||||||
|
// Create a test server to act as the proxy target
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||||
|
// Create a reverse proxy that forwards to our test server
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||||
|
wrapped(c)
|
||||||
|
|
||||||
|
// Assert: proxy NOT called
|
||||||
|
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
|
||||||
|
|
||||||
|
// Assert: local handler called once
|
||||||
|
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||||
|
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
|
||||||
|
|
||||||
|
// Assert: request body model unchanged
|
||||||
|
assert.Contains(t, string(handlerRecorder.RequestBody), "test-model-local", "request body model should be unchanged")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCharacterization_ModelMapping(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Register a mock provider for the TARGET model (the mapped-to model)
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("char-test-mapped", "openai", []*registry.ModelInfo{
|
||||||
|
{ID: "gpt-4-local"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("char-test-mapped")
|
||||||
|
|
||||||
|
// Setup recorders
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create model mapper with a mapping
|
||||||
|
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||||
|
{From: "gpt-4-turbo", To: "gpt-4-local"},
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create gin context
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
// Request with original model that gets mapped
|
||||||
|
body := `{"model": "gpt-4-turbo", "messages": [{"role": "user", "content": "hello"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Create fallback handler with mapper
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
fh := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
}, mapper, func() bool { return false })
|
||||||
|
|
||||||
|
// Execute - use handler that returns model in response for rewriter to work
|
||||||
|
wrapped := fh.WrapHandler(handlerRecorder.GinHandlerWithModel())
|
||||||
|
wrapped(c)
|
||||||
|
|
||||||
|
// Assert: proxy NOT called
|
||||||
|
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for model mapping")
|
||||||
|
|
||||||
|
// Assert: local handler called once
|
||||||
|
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||||
|
assert.Equal(t, 1, handlerRecorder.GetCallCount(), "local handler should be called exactly once")
|
||||||
|
|
||||||
|
// Assert: request body model was rewritten to mapped model
|
||||||
|
assert.Contains(t, string(handlerRecorder.RequestBody), "gpt-4-local", "request body model should be rewritten to mapped model")
|
||||||
|
assert.NotContains(t, string(handlerRecorder.RequestBody), "gpt-4-turbo", "request body should NOT contain original model")
|
||||||
|
|
||||||
|
// Assert: context has mapped_model key set
|
||||||
|
mappedModel, exists := handlerRecorder.GetContextKey("mapped_model")
|
||||||
|
assert.True(t, exists, "context should have mapped_model key")
|
||||||
|
assert.Equal(t, "gpt-4-local", mappedModel, "mapped_model should be the target model")
|
||||||
|
|
||||||
|
// Assert: response body model rewritten back to original
|
||||||
|
// The response writer should rewrite model names in the response
|
||||||
|
responseBody := w.Body.String()
|
||||||
|
assert.Contains(t, responseBody, "gpt-4-turbo", "response should have original model name")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCharacterization_AmpCreditsProxy(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Setup recorders - NO local provider registered, NO mapping configured
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create gin context with CloseNotifier support (required for ReverseProxy)
|
||||||
|
w := testutil.NewCloseNotifierRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
// Request with a model that has no local provider and no mapping
|
||||||
|
body := `{"model": "unknown-model-no-provider", "messages": [{"role": "user", "content": "hello"}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/openai/v1/chat/completions", bytes.NewReader([]byte(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Create fallback handler
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||||
|
wrapped(c)
|
||||||
|
|
||||||
|
// Assert: proxy called once
|
||||||
|
assert.True(t, proxyRecorder.Called, "proxy should be called when no local provider and no mapping")
|
||||||
|
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
|
||||||
|
|
||||||
|
// Assert: local handler NOT called
|
||||||
|
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called when falling back to proxy")
|
||||||
|
|
||||||
|
// Assert: body forwarded to proxy is original (no rewrite)
|
||||||
|
assert.Contains(t, string(proxyRecorder.RequestBody), "unknown-model-no-provider", "request body model should be unchanged when proxying")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCharacterization_BodyRestore(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Register a mock provider for the test model
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("char-test-body", "anthropic", []*registry.ModelInfo{
|
||||||
|
{ID: "test-model-body"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("char-test-body")
|
||||||
|
|
||||||
|
// Setup recorders
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create gin context
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(w)
|
||||||
|
|
||||||
|
// Create a complex request body that will be read by the wrapper for model extraction
|
||||||
|
originalBody := `{"model": "test-model-body", "messages": [{"role": "user", "content": "hello"}], "temperature": 0.7, "stream": true}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/anthropic/v1/messages", bytes.NewReader([]byte(originalBody)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
c.Request = req
|
||||||
|
|
||||||
|
// Create fallback handler with proxy recorder
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute
|
||||||
|
wrapped := fh.WrapHandler(handlerRecorder.GinHandler())
|
||||||
|
wrapped(c)
|
||||||
|
|
||||||
|
// Assert: local handler called (not proxy, since we have a local provider)
|
||||||
|
assert.True(t, handlerRecorder.WasCalled(), "local handler should be called")
|
||||||
|
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for local provider")
|
||||||
|
|
||||||
|
// Assert: handler receives complete original body
|
||||||
|
// This verifies that the body was properly restored after the wrapper read it for model extraction
|
||||||
|
assert.Equal(t, originalBody, string(handlerRecorder.RequestBody), "handler should receive complete original body after wrapper reads it for model extraction")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCharacterization_GeminiV1Beta1_PostModels tests that POST requests with /models/ path use Gemini bridge handler
|
||||||
|
// This is a characterization test for the route gating logic in routes.go
|
||||||
|
func TestCharacterization_GeminiV1Beta1_PostModels(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Register a mock provider for the test model (Gemini format uses path-based model extraction)
|
||||||
|
reg := registry.GetGlobalRegistry()
|
||||||
|
reg.RegisterClient("char-test-gemini", "google", []*registry.ModelInfo{
|
||||||
|
{ID: "gemini-pro"},
|
||||||
|
})
|
||||||
|
defer reg.UnregisterClient("char-test-gemini")
|
||||||
|
|
||||||
|
// Setup recorders
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create a test server for the proxy
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
// Create fallback handler
|
||||||
|
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create the Gemini bridge handler (simulating what routes.go does)
|
||||||
|
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
|
||||||
|
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
|
||||||
|
|
||||||
|
// Create router with the same gating logic as routes.go
|
||||||
|
r := gin.New()
|
||||||
|
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||||
|
if c.Request.Method == "POST" {
|
||||||
|
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||||
|
// POST with /models/ path -> use Gemini bridge with fallback handler
|
||||||
|
geminiV1Beta1Handler(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
// Non-POST or no /models/ in path -> proxy upstream
|
||||||
|
proxyRecorder.ServeHTTP(c.Writer, c.Request)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute: POST request with /models/ in path
|
||||||
|
body := `{"contents": [{"role": "user", "parts": [{"text": "hello"}]}]}`
|
||||||
|
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro:generateContent", bytes.NewReader([]byte(body)))
|
||||||
|
req.Header.Set("Content-Type", "application/json")
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Assert: local Gemini handler called
|
||||||
|
assert.True(t, handlerRecorder.WasCalled(), "local Gemini handler should be called for POST /models/")
|
||||||
|
|
||||||
|
// Assert: proxy NOT called
|
||||||
|
assert.False(t, proxyRecorder.Called, "proxy should NOT be called for POST /models/ path")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestCharacterization_GeminiV1Beta1_GetProxies tests that GET requests to Gemini v1beta1 always use proxy
|
||||||
|
// This is a characterization test for the route gating logic in routes.go
|
||||||
|
func TestCharacterization_GeminiV1Beta1_GetProxies(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
|
||||||
|
// Setup recorders
|
||||||
|
proxyRecorder := testutil.NewFakeProxyRecorder()
|
||||||
|
handlerRecorder := testutil.NewFakeHandlerRecorder()
|
||||||
|
|
||||||
|
// Create a test server for the proxy
|
||||||
|
proxyServer := httptest.NewServer(proxyRecorder.ToHandler())
|
||||||
|
defer proxyServer.Close()
|
||||||
|
|
||||||
|
// Create fallback handler
|
||||||
|
fh := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||||
|
targetURL, _ := url.Parse(proxyServer.URL)
|
||||||
|
return httputil.NewSingleHostReverseProxy(targetURL)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Create the Gemini bridge handler
|
||||||
|
geminiBridge := createGeminiBridgeHandler(handlerRecorder.GinHandler())
|
||||||
|
geminiV1Beta1Handler := fh.WrapHandler(geminiBridge)
|
||||||
|
|
||||||
|
// Create router with the same gating logic as routes.go
|
||||||
|
r := gin.New()
|
||||||
|
r.Any("/api/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||||
|
if c.Request.Method == "POST" {
|
||||||
|
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||||
|
geminiV1Beta1Handler(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
proxyRecorder.ServeHTTP(c.Writer, c.Request)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Execute: GET request (even with /models/ in path)
|
||||||
|
req := httptest.NewRequest(http.MethodGet, "/api/provider/google/v1beta1/publishers/google/models/gemini-pro", nil)
|
||||||
|
w := httptest.NewRecorder()
|
||||||
|
r.ServeHTTP(w, req)
|
||||||
|
|
||||||
|
// Assert: proxy called
|
||||||
|
assert.True(t, proxyRecorder.Called, "proxy should be called for GET requests")
|
||||||
|
assert.Equal(t, 1, proxyRecorder.GetCallCount(), "proxy should be called exactly once")
|
||||||
|
|
||||||
|
// Assert: local handler NOT called
|
||||||
|
assert.False(t, handlerRecorder.WasCalled(), "local handler should NOT be called for GET requests")
|
||||||
|
}
|
||||||
@@ -89,20 +89,6 @@ func (m *DefaultModelMapper) UpdateOAuthModelAlias(aliases map[string][]config.O
|
|||||||
log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward))
|
log.Debugf("amp model mapping: loaded oauth-model-alias for %d channel(s)", len(forward))
|
||||||
}
|
}
|
||||||
|
|
||||||
// findProviderViaOAuthAlias checks if targetModel is an oauth-model-alias name
|
|
||||||
// and returns all aliases that have available providers.
|
|
||||||
// Returns the first alias and its providers for backward compatibility,
|
|
||||||
// and also populates allAliases with all available alias models.
|
|
||||||
func (m *DefaultModelMapper) findProviderViaOAuthAlias(targetModel string) (aliasModel string, providers []string) {
|
|
||||||
aliases := m.findAllAliasesWithProviders(targetModel)
|
|
||||||
if len(aliases) == 0 {
|
|
||||||
return "", nil
|
|
||||||
}
|
|
||||||
// Return first one for backward compatibility
|
|
||||||
first := aliases[0]
|
|
||||||
return first, util.GetProviderName(first)
|
|
||||||
}
|
|
||||||
|
|
||||||
// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel
|
// findAllAliasesWithProviders returns all oauth-model-alias aliases for targetModel
|
||||||
// that have available providers. Useful for fallback when one alias is quota-exceeded.
|
// that have available providers. Useful for fallback when one alias is quota-exceeded.
|
||||||
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
|
func (m *DefaultModelMapper) findAllAliasesWithProviders(targetModel string) []string {
|
||||||
@@ -222,7 +208,7 @@ func (m *DefaultModelMapper) MapModelWithFallbacks(requestedModel string) []stri
|
|||||||
if len(allAliases) == 1 {
|
if len(allAliases) == 1 {
|
||||||
log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
|
log.Debugf("amp model mapping: resolved %s -> %s via oauth-model-alias", targetModel, allAliases[0])
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases))
|
log.Debugf("amp model mapping: resolved %s -> %v via oauth-model-alias (%d fallbacks)", targetModel, allAliases, len(allAliases)-1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply suffix to all aliases
|
// Apply suffix to all aliases
|
||||||
@@ -290,6 +276,22 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
|
|||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetMappingsAsConfig returns the current model mappings as config.AmpModelMapping slice.
|
||||||
|
// Safe for concurrent use.
|
||||||
|
func (m *DefaultModelMapper) GetMappingsAsConfig() []config.AmpModelMapping {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
result := make([]config.AmpModelMapping, 0, len(m.mappings))
|
||||||
|
for from, to := range m.mappings {
|
||||||
|
result = append(result, config.AmpModelMapping{
|
||||||
|
From: from,
|
||||||
|
To: to,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
type regexMapping struct {
|
type regexMapping struct {
|
||||||
re *regexp.Regexp
|
re *regexp.Regexp
|
||||||
to string
|
to string
|
||||||
|
|||||||
@@ -5,11 +5,12 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/http"
|
"net/http"
|
||||||
"net/http/httputil"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||||
@@ -234,19 +235,20 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
// If no local OAuth is available, falls back to ampcode.com proxy.
|
// If no local OAuth is available, falls back to ampcode.com proxy.
|
||||||
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
|
||||||
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
|
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
|
||||||
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
|
||||||
return m.getProxy()
|
|
||||||
}, m.modelMapper, m.forceModelMappings)
|
|
||||||
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
|
|
||||||
|
|
||||||
// Route POST model calls through Gemini bridge with FallbackHandler.
|
// T-025: Migrated Gemini v1beta1 bridge to use ModelRoutingWrapper
|
||||||
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
|
// Create a dedicated routing wrapper for the Gemini bridge
|
||||||
|
geminiBridgeWrapper := m.createModelRoutingWrapper()
|
||||||
|
geminiV1Beta1Handler := geminiBridgeWrapper.Wrap(geminiBridge)
|
||||||
|
|
||||||
|
// Route POST model calls through Gemini bridge with ModelRoutingWrapper.
|
||||||
|
// ModelRoutingWrapper checks provider -> mapping -> proxy fallback automatically.
|
||||||
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
|
||||||
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
|
||||||
if c.Request.Method == "POST" {
|
if c.Request.Method == "POST" {
|
||||||
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
if path := c.Param("path"); strings.Contains(path, "/models/") {
|
||||||
// POST with /models/ path -> use Gemini bridge with fallback handler
|
// POST with /models/ path -> use Gemini bridge with unified routing wrapper
|
||||||
// FallbackHandler will check provider/mapping and proxy if needed
|
// ModelRoutingWrapper will check provider/mapping and proxy if needed
|
||||||
geminiV1Beta1Handler(c)
|
geminiV1Beta1Handler(c)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -256,6 +258,41 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createModelRoutingWrapper creates a new ModelRoutingWrapper for unified routing.
|
||||||
|
// This is used for testing the new routing implementation (T-021 onwards).
|
||||||
|
func (m *AmpModule) createModelRoutingWrapper() *routing.ModelRoutingWrapper {
|
||||||
|
// Create a registry - in production this would be populated with actual providers
|
||||||
|
registry := routing.NewRegistry()
|
||||||
|
|
||||||
|
// Create a minimal config with just AmpCode settings
|
||||||
|
// The Router only needs AmpCode.ModelMappings and OAuthModelAlias
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: func() config.AmpCode {
|
||||||
|
if m.modelMapper != nil {
|
||||||
|
return config.AmpCode{
|
||||||
|
ModelMappings: m.modelMapper.GetMappingsAsConfig(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return config.AmpCode{}
|
||||||
|
}(),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create router with registry and config
|
||||||
|
router := routing.NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Create wrapper with proxy function
|
||||||
|
proxyFunc := func(c *gin.Context) {
|
||||||
|
proxy := m.getProxy()
|
||||||
|
if proxy != nil {
|
||||||
|
proxy.ServeHTTP(c.Writer, c.Request)
|
||||||
|
} else {
|
||||||
|
c.JSON(503, gin.H{"error": "amp upstream proxy not available"})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return routing.NewModelRoutingWrapper(router, nil, nil, proxyFunc)
|
||||||
|
}
|
||||||
|
|
||||||
// registerProviderAliases registers /api/provider/{provider}/... routes
|
// registerProviderAliases registers /api/provider/{provider}/... routes
|
||||||
// These allow Amp CLI to route requests like:
|
// These allow Amp CLI to route requests like:
|
||||||
//
|
//
|
||||||
@@ -269,12 +306,9 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
|||||||
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
|
claudeCodeHandlers := claude.NewClaudeCodeAPIHandler(baseHandler)
|
||||||
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
openaiResponsesHandlers := openai.NewOpenAIResponsesAPIHandler(baseHandler)
|
||||||
|
|
||||||
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
// Create unified routing wrapper (T-021 onwards)
|
||||||
// Uses m.getProxy() for hot-reload support (proxy can be updated at runtime)
|
// Replaces FallbackHandler with Router-based unified routing
|
||||||
// Also includes model mapping support for routing unavailable models to alternatives
|
routingWrapper := m.createModelRoutingWrapper()
|
||||||
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
|
||||||
return m.getProxy()
|
|
||||||
}, m.modelMapper, m.forceModelMappings)
|
|
||||||
|
|
||||||
// Provider-specific routes under /api/provider/:provider
|
// Provider-specific routes under /api/provider/:provider
|
||||||
ampProviders := engine.Group("/api/provider")
|
ampProviders := engine.Group("/api/provider")
|
||||||
@@ -302,33 +336,36 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Root-level routes (for providers that omit /v1, like groq/cerebras)
|
// Root-level routes (for providers that omit /v1, like groq/cerebras)
|
||||||
// Wrap handlers with fallback logic to forward to ampcode.com when provider not found
|
// T-022: Migrated all OpenAI routes to use ModelRoutingWrapper for unified routing
|
||||||
provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
|
provider.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback (no body to check)
|
||||||
provider.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
provider.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
|
||||||
provider.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
provider.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
|
||||||
provider.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
provider.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
|
||||||
|
|
||||||
// /v1 routes (OpenAI/Claude-compatible endpoints)
|
// /v1 routes (OpenAI/Claude-compatible endpoints)
|
||||||
v1Amp := provider.Group("/v1")
|
v1Amp := provider.Group("/v1")
|
||||||
{
|
{
|
||||||
v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
|
v1Amp.GET("/models", ampModelsHandler) // Models endpoint doesn't need fallback
|
||||||
|
|
||||||
// OpenAI-compatible endpoints with fallback
|
// OpenAI-compatible endpoints with ModelRoutingWrapper
|
||||||
v1Amp.POST("/chat/completions", fallbackHandler.WrapHandler(openaiHandlers.ChatCompletions))
|
// T-021, T-022: Migrated to unified routing wrapper
|
||||||
v1Amp.POST("/completions", fallbackHandler.WrapHandler(openaiHandlers.Completions))
|
v1Amp.POST("/chat/completions", routingWrapper.Wrap(openaiHandlers.ChatCompletions))
|
||||||
v1Amp.POST("/responses", fallbackHandler.WrapHandler(openaiResponsesHandlers.Responses))
|
v1Amp.POST("/completions", routingWrapper.Wrap(openaiHandlers.Completions))
|
||||||
|
v1Amp.POST("/responses", routingWrapper.Wrap(openaiResponsesHandlers.Responses))
|
||||||
|
|
||||||
// Claude/Anthropic-compatible endpoints with fallback
|
// Claude/Anthropic-compatible endpoints with ModelRoutingWrapper
|
||||||
v1Amp.POST("/messages", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeMessages))
|
// T-023: Migrated Claude routes to unified routing wrapper
|
||||||
v1Amp.POST("/messages/count_tokens", fallbackHandler.WrapHandler(claudeCodeHandlers.ClaudeCountTokens))
|
v1Amp.POST("/messages", routingWrapper.Wrap(claudeCodeHandlers.ClaudeMessages))
|
||||||
|
v1Amp.POST("/messages/count_tokens", routingWrapper.Wrap(claudeCodeHandlers.ClaudeCountTokens))
|
||||||
}
|
}
|
||||||
|
|
||||||
// /v1beta routes (Gemini native API)
|
// /v1beta routes (Gemini native API)
|
||||||
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
|
// Note: Gemini handler extracts model from URL path, so fallback logic needs special handling
|
||||||
|
// T-024: Migrated Gemini v1beta routes to unified routing wrapper
|
||||||
v1betaAmp := provider.Group("/v1beta")
|
v1betaAmp := provider.Group("/v1beta")
|
||||||
{
|
{
|
||||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||||
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
v1betaAmp.POST("/models/*action", routingWrapper.Wrap(geminiHandlers.GeminiHandler))
|
||||||
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -60,9 +60,9 @@ type ServerOption func(*serverOptionConfig)
|
|||||||
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
func defaultRequestLoggerFactory(cfg *config.Config, configPath string) logging.RequestLogger {
|
||||||
configDir := filepath.Dir(configPath)
|
configDir := filepath.Dir(configPath)
|
||||||
if base := util.WritablePath(); base != "" {
|
if base := util.WritablePath(); base != "" {
|
||||||
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir)
|
return logging.NewFileRequestLogger(cfg.RequestLog, filepath.Join(base, "logs"), configDir, cfg.ErrorLogsMaxFiles)
|
||||||
}
|
}
|
||||||
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir)
|
return logging.NewFileRequestLogger(cfg.RequestLog, "logs", configDir, cfg.ErrorLogsMaxFiles)
|
||||||
}
|
}
|
||||||
|
|
||||||
// WithMiddleware appends additional Gin middleware during server construction.
|
// WithMiddleware appends additional Gin middleware during server construction.
|
||||||
@@ -497,6 +497,10 @@ func (s *Server) registerManagementRoutes() {
|
|||||||
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
mgmt.PUT("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
||||||
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
mgmt.PATCH("/logs-max-total-size-mb", s.mgmt.PutLogsMaxTotalSizeMB)
|
||||||
|
|
||||||
|
mgmt.GET("/error-logs-max-files", s.mgmt.GetErrorLogsMaxFiles)
|
||||||
|
mgmt.PUT("/error-logs-max-files", s.mgmt.PutErrorLogsMaxFiles)
|
||||||
|
mgmt.PATCH("/error-logs-max-files", s.mgmt.PutErrorLogsMaxFiles)
|
||||||
|
|
||||||
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
|
mgmt.GET("/usage-statistics-enabled", s.mgmt.GetUsageStatisticsEnabled)
|
||||||
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
mgmt.PUT("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
||||||
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
mgmt.PATCH("/usage-statistics-enabled", s.mgmt.PutUsageStatisticsEnabled)
|
||||||
@@ -907,6 +911,15 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if s.requestLogger != nil && (oldCfg == nil || oldCfg.ErrorLogsMaxFiles != cfg.ErrorLogsMaxFiles) {
|
||||||
|
if setter, ok := s.requestLogger.(interface{ SetErrorLogsMaxFiles(int) }); ok {
|
||||||
|
setter.SetErrorLogsMaxFiles(cfg.ErrorLogsMaxFiles)
|
||||||
|
}
|
||||||
|
if oldCfg != nil {
|
||||||
|
log.Debugf("error_logs_max_files updated from %d to %d", oldCfg.ErrorLogsMaxFiles, cfg.ErrorLogsMaxFiles)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if oldCfg == nil || oldCfg.DisableCooling != cfg.DisableCooling {
|
if oldCfg == nil || oldCfg.DisableCooling != cfg.DisableCooling {
|
||||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||||
if oldCfg != nil {
|
if oldCfg != nil {
|
||||||
|
|||||||
@@ -51,6 +51,10 @@ type Config struct {
|
|||||||
// When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable.
|
// When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable.
|
||||||
LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"`
|
LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"`
|
||||||
|
|
||||||
|
// ErrorLogsMaxFiles limits the number of error log files retained when request logging is disabled.
|
||||||
|
// When exceeded, the oldest error log files are deleted. Default is 10. Set to 0 to disable cleanup.
|
||||||
|
ErrorLogsMaxFiles int `yaml:"error-logs-max-files" json:"error-logs-max-files"`
|
||||||
|
|
||||||
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
|
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
|
||||||
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
|
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
|
||||||
|
|
||||||
@@ -229,6 +233,16 @@ type PayloadConfig struct {
|
|||||||
Override []PayloadRule `yaml:"override" json:"override"`
|
Override []PayloadRule `yaml:"override" json:"override"`
|
||||||
// OverrideRaw defines rules that always set raw JSON values, overwriting any existing values.
|
// OverrideRaw defines rules that always set raw JSON values, overwriting any existing values.
|
||||||
OverrideRaw []PayloadRule `yaml:"override-raw" json:"override-raw"`
|
OverrideRaw []PayloadRule `yaml:"override-raw" json:"override-raw"`
|
||||||
|
// Filter defines rules that remove parameters from the payload by JSON path.
|
||||||
|
Filter []PayloadFilterRule `yaml:"filter" json:"filter"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// PayloadFilterRule describes a rule to remove specific JSON paths from matching model payloads.
|
||||||
|
type PayloadFilterRule struct {
|
||||||
|
// Models lists model entries with name pattern and protocol constraint.
|
||||||
|
Models []PayloadModelRule `yaml:"models" json:"models"`
|
||||||
|
// Params lists JSON paths (gjson/sjson syntax) to remove from the payload.
|
||||||
|
Params []string `yaml:"params" json:"params"`
|
||||||
}
|
}
|
||||||
|
|
||||||
// PayloadRule describes a single rule targeting a list of models with parameter updates.
|
// PayloadRule describes a single rule targeting a list of models with parameter updates.
|
||||||
@@ -502,6 +516,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
|
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
|
||||||
cfg.LoggingToFile = false
|
cfg.LoggingToFile = false
|
||||||
cfg.LogsMaxTotalSizeMB = 0
|
cfg.LogsMaxTotalSizeMB = 0
|
||||||
|
cfg.ErrorLogsMaxFiles = 10
|
||||||
cfg.UsageStatisticsEnabled = false
|
cfg.UsageStatisticsEnabled = false
|
||||||
cfg.DisableCooling = false
|
cfg.DisableCooling = false
|
||||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||||
@@ -550,6 +565,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
|||||||
cfg.LogsMaxTotalSizeMB = 0
|
cfg.LogsMaxTotalSizeMB = 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cfg.ErrorLogsMaxFiles < 0 {
|
||||||
|
cfg.ErrorLogsMaxFiles = 10
|
||||||
|
}
|
||||||
|
|
||||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||||
syncInlineAccessProvider(&cfg)
|
syncInlineAccessProvider(&cfg)
|
||||||
|
|
||||||
|
|||||||
@@ -132,6 +132,9 @@ type FileRequestLogger struct {
|
|||||||
|
|
||||||
// logsDir is the directory where log files are stored.
|
// logsDir is the directory where log files are stored.
|
||||||
logsDir string
|
logsDir string
|
||||||
|
|
||||||
|
// errorLogsMaxFiles limits the number of error log files retained.
|
||||||
|
errorLogsMaxFiles int
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewFileRequestLogger creates a new file-based request logger.
|
// NewFileRequestLogger creates a new file-based request logger.
|
||||||
@@ -141,10 +144,11 @@ type FileRequestLogger struct {
|
|||||||
// - logsDir: The directory where log files should be stored (can be relative)
|
// - logsDir: The directory where log files should be stored (can be relative)
|
||||||
// - configDir: The directory of the configuration file; when logsDir is
|
// - configDir: The directory of the configuration file; when logsDir is
|
||||||
// relative, it will be resolved relative to this directory
|
// relative, it will be resolved relative to this directory
|
||||||
|
// - errorLogsMaxFiles: Maximum number of error log files to retain (0 = no cleanup)
|
||||||
//
|
//
|
||||||
// Returns:
|
// Returns:
|
||||||
// - *FileRequestLogger: A new file-based request logger instance
|
// - *FileRequestLogger: A new file-based request logger instance
|
||||||
func NewFileRequestLogger(enabled bool, logsDir string, configDir string) *FileRequestLogger {
|
func NewFileRequestLogger(enabled bool, logsDir string, configDir string, errorLogsMaxFiles int) *FileRequestLogger {
|
||||||
// Resolve logsDir relative to the configuration file directory when it's not absolute.
|
// Resolve logsDir relative to the configuration file directory when it's not absolute.
|
||||||
if !filepath.IsAbs(logsDir) {
|
if !filepath.IsAbs(logsDir) {
|
||||||
// If configDir is provided, resolve logsDir relative to it.
|
// If configDir is provided, resolve logsDir relative to it.
|
||||||
@@ -155,6 +159,7 @@ func NewFileRequestLogger(enabled bool, logsDir string, configDir string) *FileR
|
|||||||
return &FileRequestLogger{
|
return &FileRequestLogger{
|
||||||
enabled: enabled,
|
enabled: enabled,
|
||||||
logsDir: logsDir,
|
logsDir: logsDir,
|
||||||
|
errorLogsMaxFiles: errorLogsMaxFiles,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,6 +180,11 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
|
|||||||
l.enabled = enabled
|
l.enabled = enabled
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetErrorLogsMaxFiles updates the maximum number of error log files to retain.
|
||||||
|
func (l *FileRequestLogger) SetErrorLogsMaxFiles(maxFiles int) {
|
||||||
|
l.errorLogsMaxFiles = maxFiles
|
||||||
|
}
|
||||||
|
|
||||||
// LogRequest logs a complete non-streaming request/response cycle to a file.
|
// LogRequest logs a complete non-streaming request/response cycle to a file.
|
||||||
//
|
//
|
||||||
// Parameters:
|
// Parameters:
|
||||||
@@ -433,8 +443,12 @@ func (l *FileRequestLogger) sanitizeForFilename(path string) string {
|
|||||||
return sanitized
|
return sanitized
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupOldErrorLogs keeps only the newest 10 forced error log files.
|
// cleanupOldErrorLogs keeps only the newest errorLogsMaxFiles forced error log files.
|
||||||
func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
||||||
|
if l.errorLogsMaxFiles <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
entries, errRead := os.ReadDir(l.logsDir)
|
entries, errRead := os.ReadDir(l.logsDir)
|
||||||
if errRead != nil {
|
if errRead != nil {
|
||||||
return errRead
|
return errRead
|
||||||
@@ -462,7 +476,7 @@ func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
|||||||
files = append(files, logFile{name: name, modTime: info.ModTime()})
|
files = append(files, logFile{name: name, modTime: info.ModTime()})
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(files) <= 10 {
|
if len(files) <= l.errorLogsMaxFiles {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -470,7 +484,7 @@ func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
|||||||
return files[i].modTime.After(files[j].modTime)
|
return files[i].modTime.After(files[j].modTime)
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, file := range files[10:] {
|
for _, file := range files[l.errorLogsMaxFiles:] {
|
||||||
if errRemove := os.Remove(filepath.Join(l.logsDir, file.name)); errRemove != nil {
|
if errRemove := os.Remove(filepath.Join(l.logsDir, file.name)); errRemove != nil {
|
||||||
log.WithError(errRemove).Warnf("failed to remove old error log: %s", file.name)
|
log.WithError(errRemove).Warnf("failed to remove old error log: %s", file.name)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,318 +1,79 @@
|
|||||||
You are a coding agent running in the opencode, a terminal-based coding assistant. opencode is an open source project. You are expected to be precise, safe, and helpful.
|
You are OpenCode, the best coding agent on the planet.
|
||||||
|
|
||||||
Your capabilities:
|
You are an interactive CLI tool that helps users with software engineering tasks. Use the instructions below and the tools available to you to assist the user.
|
||||||
|
|
||||||
- Receive user prompts and other context provided by the harness, such as files in the workspace.
|
## Editing constraints
|
||||||
- Communicate with the user by streaming thinking & responses, and by making & updating plans.
|
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||||
- Emit function calls to run terminal commands and apply edits. Depending on how this specific run is configured, you can request that these function calls be escalated to the user for approval before running. More on this in the "Sandbox and approvals" section.
|
- Only add comments if they are necessary to make a non-obvious block easier to understand.
|
||||||
|
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||||
Within this context, Codex refers to the open-source agentic coding interface (not the old Codex language model built by OpenAI).
|
|
||||||
|
## Tool usage
|
||||||
# How you work
|
- Prefer specialized tools over shell for file operations:
|
||||||
|
- Use Read to view files, Edit to modify files, and Write only when needed.
|
||||||
## Personality
|
- Use Glob to find files by name and Grep to search file contents.
|
||||||
|
- Use Bash for terminal operations (git, bun, builds, tests, running scripts).
|
||||||
Your default personality and tone is concise, direct, and friendly. You communicate efficiently, always keeping the user clearly informed about ongoing actions without unnecessary detail. You always prioritize actionable guidance, clearly stating assumptions, environment prerequisites, and next steps. Unless explicitly asked, you avoid excessively verbose explanations about your work.
|
- Run tool calls in parallel when neither call needs the other’s output; otherwise run sequentially.
|
||||||
|
|
||||||
# AGENTS.md spec
|
## Git and workspace hygiene
|
||||||
- Repos often contain AGENTS.md files. These files can appear anywhere within the repository.
|
- You may be in a dirty git worktree.
|
||||||
- These files are a way for humans to give you (the agent) instructions or tips for working within the container.
|
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||||
- Some examples might be: coding conventions, info about how code is organized, or instructions for how to run or test code.
|
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||||
- Instructions in AGENTS.md files:
|
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||||
- The scope of an AGENTS.md file is the entire directory tree rooted at the folder that contains it.
|
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||||
- For every file you touch in the final patch, you must obey instructions in any AGENTS.md file whose scope includes that file.
|
- Do not amend commits unless explicitly requested.
|
||||||
- Instructions about code style, structure, naming, etc. apply only to code within the AGENTS.md file's scope, unless the file states otherwise.
|
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||||
- More-deeply-nested AGENTS.md files take precedence in the case of conflicting instructions.
|
|
||||||
- Direct system/developer/user instructions (as part of a prompt) take precedence over AGENTS.md instructions.
|
## Frontend tasks
|
||||||
- The contents of the AGENTS.md file at the root of the repo and any directories from the CWD up to the root are included with the developer message and don't need to be re-read. When working in a subdirectory of CWD, or a directory outside the CWD, check for any AGENTS.md files that may be applicable.
|
When doing frontend design tasks, avoid collapsing into bland, generic layouts.
|
||||||
|
Aim for interfaces that feel intentional and deliberate.
|
||||||
## Responsiveness
|
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||||
|
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||||
### Preamble messages
|
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||||
|
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||||
Before making tool calls, send a brief preamble to the user explaining what you’re about to do. When sending preamble messages, follow these principles and examples:
|
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||||
|
- Ensure the page loads properly on both desktop and mobile.
|
||||||
- **Logically group related actions**: if you’re about to run several related commands, describe them together in one preamble rather than sending a separate note for each.
|
|
||||||
- **Keep it concise**: be no more than 1-2 sentences, focused on immediate, tangible next steps. (8–12 words for quick updates).
|
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||||
- **Build on prior context**: if this is not your first tool call, use the preamble message to connect the dots with what’s been done so far and create a sense of momentum and clarity for the user to understand your next actions.
|
|
||||||
- **Keep your tone light, friendly and curious**: add small touches of personality in preambles feel collaborative and engaging.
|
|
||||||
- **Exception**: Avoid adding a preamble for every trivial read (e.g., `cat` a single file) unless it’s part of a larger grouped action.
|
|
||||||
|
|
||||||
**Examples:**
|
|
||||||
|
|
||||||
- “I’ve explored the repo; now checking the API route definitions.”
|
|
||||||
- “Next, I’ll patch the config and update the related tests.”
|
|
||||||
- “I’m about to scaffold the CLI commands and helper functions.”
|
|
||||||
- “Ok cool, so I’ve wrapped my head around the repo. Now digging into the API routes.”
|
|
||||||
- “Config’s looking tidy. Next up is editing helpers to keep things in sync.”
|
|
||||||
- “Finished poking at the DB gateway. I will now chase down error handling.”
|
|
||||||
- “Alright, build pipeline order is interesting. Checking how it reports failures.”
|
|
||||||
- “Spotted a clever caching util; now hunting where it gets used.”
|
|
||||||
|
|
||||||
## Planning
|
|
||||||
|
|
||||||
You have access to an `todowrite` tool which tracks steps and progress and renders them to the user. Using the tool helps demonstrate that you've understood the task and convey how you're approaching it. Plans can help to make complex, ambiguous, or multi-phase work clearer and more collaborative for the user. A good plan should break the task into meaningful, logically ordered steps that are easy to verify as you go.
|
|
||||||
|
|
||||||
Note that plans are not for padding out simple work with filler steps or stating the obvious. The content of your plan should not involve doing anything that you aren't capable of doing (i.e. don't try to test things that you can't test). Do not use plans for simple or single-step queries that you can just do or answer immediately.
|
|
||||||
|
|
||||||
Do not repeat the full contents of the plan after an `todowrite` call — the harness already displays it. Instead, summarize the change made and highlight any important context or next step.
|
|
||||||
|
|
||||||
Before running a command, consider whether or not you have completed the
|
|
||||||
previous step, and make sure to mark it as completed before moving on to the
|
|
||||||
next step. It may be the case that you complete all steps in your plan after a
|
|
||||||
single pass of implementation. If this is the case, you can simply mark all the
|
|
||||||
planned steps as completed. Sometimes, you may need to change plans in the
|
|
||||||
middle of a task: call `todowrite` with the updated plan and make sure to provide an `explanation` of the rationale when doing so.
|
|
||||||
|
|
||||||
Use a plan when:
|
|
||||||
|
|
||||||
- The task is non-trivial and will require multiple actions over a long time horizon.
|
|
||||||
- There are logical phases or dependencies where sequencing matters.
|
|
||||||
- The work has ambiguity that benefits from outlining high-level goals.
|
|
||||||
- You want intermediate checkpoints for feedback and validation.
|
|
||||||
- When the user asked you to do more than one thing in a single prompt
|
|
||||||
- The user has asked you to use the plan tool (aka "TODOs")
|
|
||||||
- You generate additional steps while working, and plan to do them before yielding to the user
|
|
||||||
|
|
||||||
### Examples
|
|
||||||
|
|
||||||
**High-quality plans**
|
|
||||||
|
|
||||||
Example 1:
|
|
||||||
|
|
||||||
1. Add CLI entry with file args
|
|
||||||
2. Parse Markdown via CommonMark library
|
|
||||||
3. Apply semantic HTML template
|
|
||||||
4. Handle code blocks, images, links
|
|
||||||
5. Add error handling for invalid files
|
|
||||||
|
|
||||||
Example 2:
|
|
||||||
|
|
||||||
1. Define CSS variables for colors
|
|
||||||
2. Add toggle with localStorage state
|
|
||||||
3. Refactor components to use variables
|
|
||||||
4. Verify all views for readability
|
|
||||||
5. Add smooth theme-change transition
|
|
||||||
|
|
||||||
Example 3:
|
|
||||||
|
|
||||||
1. Set up Node.js + WebSocket server
|
|
||||||
2. Add join/leave broadcast events
|
|
||||||
3. Implement messaging with timestamps
|
|
||||||
4. Add usernames + mention highlighting
|
|
||||||
5. Persist messages in lightweight DB
|
|
||||||
6. Add typing indicators + unread count
|
|
||||||
|
|
||||||
**Low-quality plans**
|
|
||||||
|
|
||||||
Example 1:
|
|
||||||
|
|
||||||
1. Create CLI tool
|
|
||||||
2. Add Markdown parser
|
|
||||||
3. Convert to HTML
|
|
||||||
|
|
||||||
Example 2:
|
|
||||||
|
|
||||||
1. Add dark mode toggle
|
|
||||||
2. Save preference
|
|
||||||
3. Make styles look good
|
|
||||||
|
|
||||||
Example 3:
|
|
||||||
|
|
||||||
1. Create single-file HTML game
|
|
||||||
2. Run quick sanity check
|
|
||||||
3. Summarize usage instructions
|
|
||||||
|
|
||||||
If you need to write a plan, only write high quality plans, not low quality ones.
|
|
||||||
|
|
||||||
## Task execution
|
|
||||||
|
|
||||||
You are a coding agent. Please keep going until the query is completely resolved, before ending your turn and yielding back to the user. Only terminate your turn when you are sure that the problem is solved. Autonomously resolve the query to the best of your ability, using the tools available to you, before coming back to the user. Do NOT guess or make up an answer.
|
|
||||||
|
|
||||||
You MUST adhere to the following criteria when solving queries:
|
|
||||||
|
|
||||||
- Working on the repo(s) in the current environment is allowed, even if they are proprietary.
|
|
||||||
- Analyzing code for vulnerabilities is allowed.
|
|
||||||
- Showing user code and tool call details is allowed.
|
|
||||||
- Use the `edit` tool to edit files
|
|
||||||
|
|
||||||
If completing the user's task requires writing or modifying files, your code and final answer should follow these coding guidelines, though user instructions (i.e. AGENTS.md) may override these guidelines:
|
|
||||||
|
|
||||||
- Fix the problem at the root cause rather than applying surface-level patches, when possible.
|
|
||||||
- Avoid unneeded complexity in your solution.
|
|
||||||
- Do not attempt to fix unrelated bugs or broken tests. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
|
|
||||||
- Update documentation as necessary.
|
|
||||||
- Keep changes consistent with the style of the existing codebase. Changes should be minimal and focused on the task.
|
|
||||||
- Use `git log` and `git blame` to search the history of the codebase if additional context is required.
|
|
||||||
- NEVER add copyright or license headers unless specifically requested.
|
|
||||||
- Do not waste tokens by re-reading files after calling `edit` on them. The tool call will fail if it didn't work. The same goes for making folders, deleting folders, etc.
|
|
||||||
- Do not `git commit` your changes or create new git branches unless explicitly requested.
|
|
||||||
- Do not add inline comments within code unless explicitly requested.
|
|
||||||
- Do not use one-letter variable names unless explicitly requested.
|
|
||||||
- NEVER output inline citations like "【F:README.md†L5-L14】" in your outputs. The CLI is not able to render these so they will just be broken in the UI. Instead, if you output valid filepaths, users will be able to click on them to open the files in their editor.
|
|
||||||
|
|
||||||
## Sandbox and approvals
|
|
||||||
|
|
||||||
The Codex CLI harness supports several different sandboxing, and approval configurations that the user can choose from.
|
|
||||||
|
|
||||||
Filesystem sandboxing prevents you from editing files without user approval. The options are:
|
|
||||||
|
|
||||||
- **read-only**: You can only read files.
|
|
||||||
- **workspace-write**: You can read files. You can write to files in your workspace folder, but not outside it.
|
|
||||||
- **danger-full-access**: No filesystem sandboxing.
|
|
||||||
|
|
||||||
Network sandboxing prevents you from accessing network without approval. Options are
|
|
||||||
|
|
||||||
- **restricted**
|
|
||||||
- **enabled**
|
|
||||||
|
|
||||||
Approvals are your mechanism to get user consent to perform more privileged actions. Although they introduce friction to the user because your work is paused until the user responds, you should leverage them to accomplish your important work. Do not let these settings or the sandbox deter you from attempting to accomplish the user's task. Approval options are
|
|
||||||
|
|
||||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands.
|
|
||||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
|
||||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
|
||||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is pared with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
|
||||||
|
|
||||||
When you are running with approvals `on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
|
||||||
|
|
||||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /tmp)
|
|
||||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
|
||||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
|
||||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval.
|
|
||||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
|
||||||
- (For all of these, you should weigh alternative paths that do not require approval.)
|
|
||||||
|
|
||||||
Note that when sandboxing is set to read-only, you'll need to request approval for any command that isn't a read.
|
|
||||||
|
|
||||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing ON, and approval on-failure.
|
|
||||||
|
|
||||||
## Validating your work
|
|
||||||
|
|
||||||
If the codebase has tests or the ability to build or run, consider using them to verify that your work is complete.
|
|
||||||
|
|
||||||
When testing, your philosophy should be to start as specific as possible to the code you changed so that you can catch issues efficiently, then make your way to broader tests as you build confidence. If there's no test for the code you changed, and if the adjacent patterns in the codebases show that there's a logical place for you to add a test, you may do so. However, do not add tests to codebases with no tests.
|
|
||||||
|
|
||||||
Similarly, once you're confident in correctness, you can suggest or use formatting commands to ensure that your code is well formatted. If there are issues you can iterate up to 3 times to get formatting right, but if you still can't manage it's better to save the user time and present them a correct solution where you call out the formatting in your final message. If the codebase does not have a formatter configured, do not add one.
|
|
||||||
|
|
||||||
For all of testing, running, building, and formatting, do not attempt to fix unrelated bugs. It is not your responsibility to fix them. (You may mention them to the user in your final message though.)
|
|
||||||
|
|
||||||
Be mindful of whether to run validation commands proactively. In the absence of behavioral guidance:
|
|
||||||
|
|
||||||
- When running in non-interactive approval modes like **never** or **on-failure**, proactively run tests, lint and do whatever you need to ensure you've completed the task.
|
|
||||||
- When working in interactive approval modes like **untrusted**, or **on-request**, hold off on running tests or lint commands until the user is ready for you to finalize your output, because these commands take time to run and slow down iteration. Instead suggest what you want to do next, and let the user confirm first.
|
|
||||||
- When working on test-related tasks, such as adding tests, fixing tests, or reproducing a bug to verify behavior, you may proactively run tests regardless of approval mode. Use your judgement to decide whether this is a test-related task.
|
|
||||||
|
|
||||||
## Ambition vs. precision
|
|
||||||
|
|
||||||
For tasks that have no prior context (i.e. the user is starting something brand new), you should feel free to be ambitious and demonstrate creativity with your implementation.
|
|
||||||
|
|
||||||
If you're operating in an existing codebase, you should make sure you do exactly what the user asks with surgical precision. Treat the surrounding codebase with respect, and don't overstep (i.e. changing filenames or variables unnecessarily). You should balance being sufficiently ambitious and proactive when completing tasks of this nature.
|
|
||||||
|
|
||||||
You should use judicious initiative to decide on the right level of detail and complexity to deliver based on the user's needs. This means showing good judgment that you're capable of doing the right extras without gold-plating. This might be demonstrated by high-value, creative touches when scope of the task is vague; while being surgical and targeted when scope is tightly specified.
|
|
||||||
|
|
||||||
## Sharing progress updates
|
|
||||||
|
|
||||||
For especially longer tasks that you work on (i.e. requiring many tool calls, or a plan with multiple steps), you should provide progress updates back to the user at reasonable intervals. These updates should be structured as a concise sentence or two (no more than 8-10 words long) recapping progress so far in plain language: this update demonstrates your understanding of what needs to be done, progress so far (i.e. files explores, subtasks complete), and where you're going next.
|
|
||||||
|
|
||||||
Before doing large chunks of work that may incur latency as experienced by the user (i.e. writing a new file), you should send a concise message to the user with an update indicating what you're about to do to ensure they know what you're spending time on. Don't start editing or writing large files before informing the user what you are doing and why.
|
|
||||||
|
|
||||||
The messages you send before tool calls should describe what is immediately about to be done next in very concise language. If there was previous work done, this preamble message should also include a note about the work done so far to bring the user along.
|
|
||||||
|
|
||||||
## Presenting your work and final message
|
## Presenting your work and final message
|
||||||
|
|
||||||
Your final message should read naturally, like an update from a concise teammate. For casual conversation, brainstorming tasks, or quick questions from the user, respond in a friendly, conversational tone. You should ask questions, suggest ideas, and adapt to the user’s style. If you've finished a large amount of work, when describing what you've done to the user, you should follow the final answer formatting guidelines to communicate substantive changes. You don't need to add structured formatting for one-word answers, greetings, or purely conversational exchanges.
|
|
||||||
|
|
||||||
You can skip heavy formatting for single, simple actions or confirmations. In these cases, respond in plain sentences with any relevant next step or quick option. Reserve multisection structured responses for results that need grouping or explanation.
|
|
||||||
|
|
||||||
The user is working on the same computer as you, and has access to your work. As such there's no need to show the full contents of large files you have already written unless the user explicitly asks for them. Similarly, if you've created or modified files using `edit`, there's no need to tell users to "save the file" or "copy the code into a file"—just reference the file path.
|
|
||||||
|
|
||||||
If there's something that you think you could help with as a logical next step, concisely ask the user if they want you to do so. Good examples of this are running tests, committing changes, or building out the next logical component. If there’s something that you couldn't do (even with approval) but that the user might want to do (such as verifying changes by running the app), include those instructions succinctly.
|
|
||||||
|
|
||||||
Brevity is very important as a default. You should be very concise (i.e. no more than 10 lines), but can relax this requirement for tasks where additional detail and comprehensiveness is important for the user's understanding.
|
|
||||||
|
|
||||||
### Final answer structure and style guidelines
|
|
||||||
|
|
||||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||||
|
|
||||||
**Section Headers**
|
- Default: be very concise; friendly coding teammate tone.
|
||||||
|
- Default: do the work without asking questions. Treat short tasks as sufficient direction; infer missing details by reading the codebase and following existing conventions.
|
||||||
|
- Questions: only ask when you are truly blocked after checking relevant context AND you cannot safely pick a reasonable default. This usually means one of:
|
||||||
|
* The request is ambiguous in a way that materially changes the result and you cannot disambiguate by reading the repo.
|
||||||
|
* The action is destructive/irreversible, touches production, or changes billing/security posture.
|
||||||
|
* You need a secret/credential/value that cannot be inferred (API key, account id, etc.).
|
||||||
|
- If you must ask: do all non-blocked work first, then ask exactly one targeted question, include your recommended default, and state what would change based on the answer.
|
||||||
|
- Never ask permission questions like "Should I proceed?" or "Do you want me to run tests?"; proceed with the most reasonable option and mention what you did.
|
||||||
|
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||||
|
- Skip heavy formatting for simple confirmations.
|
||||||
|
- Don't dump large files you've written; reference paths only.
|
||||||
|
- No "save/copy this file" - User is on the same machine.
|
||||||
|
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||||
|
- For code changes:
|
||||||
|
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in.
|
||||||
|
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||||
|
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||||
|
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||||
|
|
||||||
- Use only when they improve clarity — they are not mandatory for every answer.
|
## Final answer structure and style guidelines
|
||||||
- Choose descriptive names that fit the content
|
|
||||||
- Keep headers short (1–3 words) and in `**Title Case**`. Always start headers with `**` and end with `**`
|
|
||||||
- Leave no blank line before the first bullet under a header.
|
|
||||||
- Section headers should only be used where they genuinely improve scannability; avoid fragmenting the answer.
|
|
||||||
|
|
||||||
**Bullets**
|
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||||
|
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||||
- Use `-` followed by a space for every bullet.
|
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||||
- Merge related points when possible; avoid a bullet for every trivial detail.
|
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||||
- Keep bullets to one line unless breaking for clarity is unavoidable.
|
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||||
- Group into short lists (4–6 bullets) ordered by importance.
|
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||||
- Use consistent keyword phrasing and formatting across sections.
|
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no "above/below"; parallel wording.
|
||||||
|
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||||
**Monospace**
|
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||||
|
- File References: When referencing files in your response follow the below rules:
|
||||||
- Wrap all commands, file paths, env vars, and code identifiers in backticks (`` `...` ``).
|
|
||||||
- Apply to inline examples and to bullet keywords if the keyword itself is a literal file/command.
|
|
||||||
- Never mix monospace and bold markers; choose one based on whether it’s a keyword (`**`) or inline code/path (`` ` ``).
|
|
||||||
|
|
||||||
**File References**
|
|
||||||
When referencing files in your response, make sure to include the relevant start line and always follow the below rules:
|
|
||||||
* Use inline code to make file paths clickable.
|
* Use inline code to make file paths clickable.
|
||||||
* Each reference should have a standalone path. Even if it's the same file.
|
* Each reference should have a stand alone path. Even if it's the same file.
|
||||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||||
* Line/column (1‑based, optional): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||||
* Do not use URIs like file://, vscode://, or https://.
|
* Do not use URIs like file://, vscode://, or https://.
|
||||||
* Do not provide range of lines
|
* Do not provide range of lines
|
||||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
||||||
|
|
||||||
**Structure**
|
|
||||||
|
|
||||||
- Place related bullets together; don’t mix unrelated concepts in the same section.
|
|
||||||
- Order sections from general → specific → supporting info.
|
|
||||||
- For subsections (e.g., “Binaries” under “Rust Workspace”), introduce with a bolded keyword bullet, then list items under it.
|
|
||||||
- Match structure to complexity:
|
|
||||||
- Multi-part or detailed results → use clear headers and grouped bullets.
|
|
||||||
- Simple results → minimal headers, possibly just a short list or paragraph.
|
|
||||||
|
|
||||||
**Tone**
|
|
||||||
|
|
||||||
- Keep the voice collaborative and natural, like a coding partner handing off work.
|
|
||||||
- Be concise and factual — no filler or conversational commentary and avoid unnecessary repetition
|
|
||||||
- Use present tense and active voice (e.g., “Runs tests” not “This will run tests”).
|
|
||||||
- Keep descriptions self-contained; don’t refer to “above” or “below”.
|
|
||||||
- Use parallel structure in lists for consistency.
|
|
||||||
|
|
||||||
**Don’t**
|
|
||||||
|
|
||||||
- Don’t use literal words “bold” or “monospace” in the content.
|
|
||||||
- Don’t nest bullets or create deep hierarchies.
|
|
||||||
- Don’t output ANSI escape codes directly — the CLI renderer applies them.
|
|
||||||
- Don’t cram unrelated keywords into a single bullet; split for clarity.
|
|
||||||
- Don’t let keyword lists run long — wrap or reformat for scannability.
|
|
||||||
|
|
||||||
Generally, ensure your final answers adapt their shape and depth to the request. For example, answers to code explanations should have a precise, structured explanation with code references that answer the question directly. For tasks with a simple implementation, lead with the outcome and supplement only with what’s needed for clarity. Larger changes can be presented as a logical walkthrough of your approach, grouping related steps, explaining rationale where it adds value, and highlighting next actions to accelerate the user. Your answers should provide the right level of detail while being easily scannable.
|
|
||||||
|
|
||||||
For casual greetings, acknowledgements, or other one-off conversational messages that are not delivering substantive information or structured results, respond naturally without section headers or bullet formatting.
|
|
||||||
|
|
||||||
# Tool Guidelines
|
|
||||||
|
|
||||||
## Shell commands
|
|
||||||
|
|
||||||
When using the shell, you must adhere to the following guidelines:
|
|
||||||
|
|
||||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
|
||||||
- Read files in chunks with a max chunk size of 250 lines. Do not use python scripts to attempt to output larger chunks of a file. Command line output will be truncated after 10 kilobytes or 256 lines of output, regardless of the command used.
|
|
||||||
|
|
||||||
## `todowrite`
|
|
||||||
|
|
||||||
A tool named `todowrite` is available to you. You can use it to keep an up‑to‑date, step‑by‑step plan for the task.
|
|
||||||
|
|
||||||
To create a new plan, call `todowrite` with a short list of 1‑sentence steps (no more than 5-7 words each) with a `status` for each step (`pending`, `in_progress`, or `completed`).
|
|
||||||
|
|
||||||
When steps have been completed, use `todowrite` to mark each finished step as
|
|
||||||
`completed` and the next step you are working on as `in_progress`. There should
|
|
||||||
always be exactly one `in_progress` step until everything is done. You can mark
|
|
||||||
multiple items as complete in a single `todowrite` call.
|
|
||||||
|
|
||||||
If all steps are complete, ensure you call `todowrite` to mark all steps as `completed`.
|
|
||||||
|
|||||||
59
internal/routing/extractor.go
Normal file
59
internal/routing/extractor.go
Normal file
@@ -0,0 +1,59 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelExtractor extracts model names from request data.
|
||||||
|
type ModelExtractor interface {
|
||||||
|
// Extract returns the model name from the request body and gin parameters.
|
||||||
|
// The ginParams map contains route parameters like "action" and "path".
|
||||||
|
Extract(body []byte, ginParams map[string]string) (string, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultModelExtractor is the standard implementation of ModelExtractor.
|
||||||
|
type DefaultModelExtractor struct{}
|
||||||
|
|
||||||
|
// NewModelExtractor creates a new DefaultModelExtractor.
|
||||||
|
func NewModelExtractor() *DefaultModelExtractor {
|
||||||
|
return &DefaultModelExtractor{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract extracts the model name from the request.
|
||||||
|
// It checks in order:
|
||||||
|
// 1. JSON body "model" field (OpenAI, Claude format)
|
||||||
|
// 2. "action" parameter for Gemini standard format (e.g., "gemini-pro:generateContent")
|
||||||
|
// 3. "path" parameter for AMP CLI Gemini format (e.g., "/publishers/google/models/gemini-3-pro:streamGenerateContent")
|
||||||
|
func (e *DefaultModelExtractor) Extract(body []byte, ginParams map[string]string) (string, error) {
|
||||||
|
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
||||||
|
if result := gjson.GetBytes(body, "model"); result.Exists() && result.Type == gjson.String {
|
||||||
|
return result.String(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// For Gemini requests, model is in the URL path
|
||||||
|
// Standard format: /models/{model}:generateContent -> :action parameter
|
||||||
|
if action, ok := ginParams["action"]; ok && action != "" {
|
||||||
|
// Split by colon to get model name (e.g., "gemini-pro:generateContent" -> "gemini-pro")
|
||||||
|
parts := strings.Split(action, ":")
|
||||||
|
if len(parts) > 0 && parts[0] != "" {
|
||||||
|
return parts[0], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// AMP CLI format: /publishers/google/models/{model}:method -> *path parameter
|
||||||
|
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
|
||||||
|
if path, ok := ginParams["path"]; ok && path != "" {
|
||||||
|
// Look for /models/{model}:method pattern
|
||||||
|
if idx := strings.Index(path, "/models/"); idx >= 0 {
|
||||||
|
modelPart := path[idx+8:] // Skip "/models/"
|
||||||
|
// Split by colon to get model name
|
||||||
|
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
|
||||||
|
return modelPart[:colonIdx], nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return "", nil
|
||||||
|
}
|
||||||
214
internal/routing/extractor_test.go
Normal file
214
internal/routing/extractor_test.go
Normal file
@@ -0,0 +1,214 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestModelExtractor_ExtractFromJSONBody(t *testing.T) {
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
want string
|
||||||
|
wantErr bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extract from JSON body with model field",
|
||||||
|
body: []byte(`{"model":"gpt-4.1"}`),
|
||||||
|
want: "gpt-4.1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extract claude model from JSON body",
|
||||||
|
body: []byte(`{"model":"claude-3-5-sonnet-20241022"}`),
|
||||||
|
want: "claude-3-5-sonnet-20241022",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extract with additional fields",
|
||||||
|
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`),
|
||||||
|
want: "gpt-4",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty body returns empty",
|
||||||
|
body: []byte{},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "no model field returns empty",
|
||||||
|
body: []byte(`{"messages":[]}`),
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "model is not string returns empty",
|
||||||
|
body: []byte(`{"model":123}`),
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := extractor.Extract(tt.body, nil)
|
||||||
|
if tt.wantErr {
|
||||||
|
assert.Error(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelExtractor_ExtractFromGeminiActionParam(t *testing.T) {
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
ginParams map[string]string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extract from action parameter - gemini-pro",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"action": "gemini-pro:generateContent"},
|
||||||
|
want: "gemini-pro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extract from action parameter - gemini-ultra",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"action": "gemini-ultra:chat"},
|
||||||
|
want: "gemini-ultra",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty action returns empty",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"action": ""},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "action without colon returns full value",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"action": "gemini-model"},
|
||||||
|
want: "gemini-model",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelExtractor_ExtractFromGeminiV1Beta1Path(t *testing.T) {
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
ginParams map[string]string
|
||||||
|
want string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "extract from v1beta1 path - gemini-3-pro",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro:streamGenerateContent"},
|
||||||
|
want: "gemini-3-pro",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "extract from v1beta1 path with preview",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro-preview:generateContent"},
|
||||||
|
want: "gemini-3-pro-preview",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "path without models segment returns empty",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"path": "/publishers/google/gemini-3-pro:streamGenerateContent"},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "empty path returns empty",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"path": ""},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "path with /models/ but no colon returns empty",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"path": "/publishers/google/models/gemini-3-pro"},
|
||||||
|
want: "",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelExtractor_ExtractPriority(t *testing.T) {
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
|
||||||
|
// JSON body takes priority over gin params
|
||||||
|
t.Run("JSON body takes priority over action param", func(t *testing.T) {
|
||||||
|
body := []byte(`{"model":"gpt-4"}`)
|
||||||
|
params := map[string]string{"action": "gemini-pro:generateContent"}
|
||||||
|
got, err := extractor.Extract(body, params)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "gpt-4", got)
|
||||||
|
})
|
||||||
|
|
||||||
|
// Action param takes priority over path param
|
||||||
|
t.Run("action param takes priority over path param", func(t *testing.T) {
|
||||||
|
body := []byte(`{}`)
|
||||||
|
params := map[string]string{
|
||||||
|
"action": "gemini-action:generate",
|
||||||
|
"path": "/publishers/google/models/gemini-path:streamGenerateContent",
|
||||||
|
}
|
||||||
|
got, err := extractor.Extract(body, params)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Equal(t, "gemini-action", got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelExtractor_NoModelFound(t *testing.T) {
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
ginParams map[string]string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty body and no params",
|
||||||
|
body: []byte{},
|
||||||
|
ginParams: nil,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "body without model and no params",
|
||||||
|
body: []byte(`{"messages":[]}`),
|
||||||
|
ginParams: map[string]string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "irrelevant params only",
|
||||||
|
body: []byte(`{}`),
|
||||||
|
ginParams: map[string]string{"other": "value"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
got, err := extractor.Extract(tt.body, tt.ginParams)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Empty(t, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
159
internal/routing/rewriter.go
Normal file
159
internal/routing/rewriter.go
Normal file
@@ -0,0 +1,159 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ModelRewriter handles model name rewriting in requests and responses.
|
||||||
|
type ModelRewriter interface {
|
||||||
|
// RewriteRequestBody rewrites the model field in a JSON request body.
|
||||||
|
// Returns the modified body or the original if no rewrite was needed.
|
||||||
|
RewriteRequestBody(body []byte, newModel string) ([]byte, error)
|
||||||
|
|
||||||
|
// WrapResponseWriter wraps an http.ResponseWriter to rewrite model names in the response.
|
||||||
|
// Returns the wrapped writer and a cleanup function that must be called after the response is complete.
|
||||||
|
WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func())
|
||||||
|
}
|
||||||
|
|
||||||
|
// DefaultModelRewriter is the standard implementation of ModelRewriter.
|
||||||
|
type DefaultModelRewriter struct{}
|
||||||
|
|
||||||
|
// NewModelRewriter creates a new DefaultModelRewriter.
|
||||||
|
func NewModelRewriter() *DefaultModelRewriter {
|
||||||
|
return &DefaultModelRewriter{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RewriteRequestBody replaces the model name in a JSON request body.
|
||||||
|
func (r *DefaultModelRewriter) RewriteRequestBody(body []byte, newModel string) ([]byte, error) {
|
||||||
|
if !gjson.GetBytes(body, "model").Exists() {
|
||||||
|
return body, nil
|
||||||
|
}
|
||||||
|
result, err := sjson.SetBytes(body, "model", newModel)
|
||||||
|
if err != nil {
|
||||||
|
return body, err
|
||||||
|
}
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WrapResponseWriter wraps a response writer to rewrite model names.
|
||||||
|
// The cleanup function must be called after the handler completes to flush any buffered data.
|
||||||
|
func (r *DefaultModelRewriter) WrapResponseWriter(w http.ResponseWriter, requestedModel, resolvedModel string) (http.ResponseWriter, func()) {
|
||||||
|
rw := &responseRewriter{
|
||||||
|
ResponseWriter: w,
|
||||||
|
body: &bytes.Buffer{},
|
||||||
|
requestedModel: requestedModel,
|
||||||
|
resolvedModel: resolvedModel,
|
||||||
|
}
|
||||||
|
return rw, func() { rw.flush() }
|
||||||
|
}
|
||||||
|
|
||||||
|
// responseRewriter wraps http.ResponseWriter to intercept and modify the response body.
|
||||||
|
type responseRewriter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
body *bytes.Buffer
|
||||||
|
requestedModel string
|
||||||
|
resolvedModel string
|
||||||
|
isStreaming bool
|
||||||
|
wroteHeader bool
|
||||||
|
flushed bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Write intercepts response writes and buffers them for model name replacement.
|
||||||
|
func (rw *responseRewriter) Write(data []byte) (int, error) {
|
||||||
|
// Ensure header is written
|
||||||
|
if !rw.wroteHeader {
|
||||||
|
rw.WriteHeader(http.StatusOK)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Detect streaming on first write
|
||||||
|
if rw.body.Len() == 0 && !rw.isStreaming {
|
||||||
|
contentType := rw.Header().Get("Content-Type")
|
||||||
|
rw.isStreaming = strings.Contains(contentType, "text/event-stream") ||
|
||||||
|
strings.Contains(contentType, "stream")
|
||||||
|
}
|
||||||
|
|
||||||
|
if rw.isStreaming {
|
||||||
|
n, err := rw.ResponseWriter.Write(rw.rewriteStreamChunk(data))
|
||||||
|
if err == nil {
|
||||||
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return n, err
|
||||||
|
}
|
||||||
|
return rw.body.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeader captures the status code and delegates to the underlying writer.
|
||||||
|
func (rw *responseRewriter) WriteHeader(code int) {
|
||||||
|
if !rw.wroteHeader {
|
||||||
|
rw.wroteHeader = true
|
||||||
|
rw.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// flush writes the buffered response with model names rewritten.
|
||||||
|
func (rw *responseRewriter) flush() {
|
||||||
|
if rw.flushed {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
rw.flushed = true
|
||||||
|
|
||||||
|
if rw.isStreaming {
|
||||||
|
if flusher, ok := rw.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if rw.body.Len() > 0 {
|
||||||
|
data := rw.rewriteModelInResponse(rw.body.Bytes())
|
||||||
|
if _, err := rw.ResponseWriter.Write(data); err != nil {
|
||||||
|
log.Warnf("response rewriter: failed to write rewritten response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// modelFieldPaths lists all JSON paths where model name may appear.
|
||||||
|
var modelFieldPaths = []string{"model", "modelVersion", "response.modelVersion", "message.model"}
|
||||||
|
|
||||||
|
// rewriteModelInResponse replaces all occurrences of the resolved model with the requested model.
|
||||||
|
func (rw *responseRewriter) rewriteModelInResponse(data []byte) []byte {
|
||||||
|
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, path := range modelFieldPaths {
|
||||||
|
if gjson.GetBytes(data, path).Exists() {
|
||||||
|
data, _ = sjson.SetBytes(data, path, rw.requestedModel)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewriteStreamChunk rewrites model names in SSE stream chunks.
|
||||||
|
func (rw *responseRewriter) rewriteStreamChunk(chunk []byte) []byte {
|
||||||
|
if rw.requestedModel == "" || rw.resolvedModel == "" || rw.requestedModel == rw.resolvedModel {
|
||||||
|
return chunk
|
||||||
|
}
|
||||||
|
|
||||||
|
// SSE format: "data: {json}\n\n"
|
||||||
|
lines := bytes.Split(chunk, []byte("\n"))
|
||||||
|
for i, line := range lines {
|
||||||
|
if bytes.HasPrefix(line, []byte("data: ")) {
|
||||||
|
jsonData := bytes.TrimPrefix(line, []byte("data: "))
|
||||||
|
if len(jsonData) > 0 && jsonData[0] == '{' {
|
||||||
|
// Rewrite JSON in the data line
|
||||||
|
rewritten := rw.rewriteModelInResponse(jsonData)
|
||||||
|
lines[i] = append([]byte("data: "), rewritten...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return bytes.Join(lines, []byte("\n"))
|
||||||
|
}
|
||||||
342
internal/routing/rewriter_test.go
Normal file
342
internal/routing/rewriter_test.go
Normal file
@@ -0,0 +1,342 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestModelRewriter_RewriteRequestBody(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
body []byte
|
||||||
|
newModel string
|
||||||
|
wantModel string
|
||||||
|
wantChange bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "rewrites model field in JSON body",
|
||||||
|
body: []byte(`{"model":"gpt-4.1","messages":[]}`),
|
||||||
|
newModel: "claude-local",
|
||||||
|
wantModel: "claude-local",
|
||||||
|
wantChange: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rewrites with empty body returns empty",
|
||||||
|
body: []byte{},
|
||||||
|
newModel: "gpt-4",
|
||||||
|
wantModel: "",
|
||||||
|
wantChange: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "handles missing model field gracefully",
|
||||||
|
body: []byte(`{"messages":[{"role":"user"}]}`),
|
||||||
|
newModel: "gpt-4",
|
||||||
|
wantModel: "",
|
||||||
|
wantChange: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "preserves other fields when rewriting",
|
||||||
|
body: []byte(`{"model":"old-model","temperature":0.7,"max_tokens":100}`),
|
||||||
|
newModel: "new-model",
|
||||||
|
wantModel: "new-model",
|
||||||
|
wantChange: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "handles nested JSON structure",
|
||||||
|
body: []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}],"stream":true}`),
|
||||||
|
newModel: "claude-3-opus",
|
||||||
|
wantModel: "claude-3-opus",
|
||||||
|
wantChange: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result, err := rewriter.RewriteRequestBody(tt.body, tt.newModel)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if tt.wantChange {
|
||||||
|
assert.NotEqual(t, string(tt.body), string(result), "body should have been modified")
|
||||||
|
}
|
||||||
|
|
||||||
|
if tt.wantModel != "" {
|
||||||
|
// Parse result and check model field
|
||||||
|
model, _ := NewModelExtractor().Extract(result, nil)
|
||||||
|
assert.Equal(t, tt.wantModel, model)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRewriter_WrapResponseWriter(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
|
||||||
|
t.Run("response writer wraps without error", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
require.NotNil(t, wrapped)
|
||||||
|
require.NotNil(t, cleanup)
|
||||||
|
defer cleanup()
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rewrites model in non-streaming response", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
// Write a response with the resolved model
|
||||||
|
response := []byte(`{"model":"claude-local","content":"hello"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
_, err := wrapped.Write(response)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Cleanup triggers the rewrite
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
// Check the response was rewritten to the requested model
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||||
|
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("no-op when requested equals resolved", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "gpt-4")
|
||||||
|
|
||||||
|
response := []byte(`{"model":"gpt-4","content":"hello"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
_, err := wrapped.Write(response)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rewrites modelVersion field", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
response := []byte(`{"modelVersion":"claude-local","content":"hello"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
_, err := wrapped.Write(response)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Contains(t, string(body), `"modelVersion":"gpt-4"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("handles streaming responses", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
// Set streaming content type
|
||||||
|
wrapped.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
|
// Write SSE chunks with resolved model
|
||||||
|
chunk1 := []byte("data: {\"model\":\"claude-local\",\"delta\":\"hello\"}\n\n")
|
||||||
|
_, err := wrapped.Write(chunk1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
chunk2 := []byte("data: {\"model\":\"claude-local\",\"delta\":\" world\"}\n\n")
|
||||||
|
_, err = wrapped.Write(chunk2)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
// For streaming, data is written immediately with rewrites
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||||
|
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("empty body handled gracefully", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
// Don't write anything
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Empty(t, body)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("preserves other JSON fields", func(t *testing.T) {
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
response := []byte(`{"model":"claude-local","temperature":0.7,"usage":{"prompt_tokens":10}}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
_, err := wrapped.Write(response)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
assert.Contains(t, string(body), `"temperature":0.7`)
|
||||||
|
assert.Contains(t, string(body), `"prompt_tokens":10`)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseRewriter_ImplementsInterfaces(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
defer cleanup()
|
||||||
|
|
||||||
|
// Should implement http.ResponseWriter
|
||||||
|
assert.Implements(t, (*http.ResponseWriter)(nil), wrapped)
|
||||||
|
|
||||||
|
// Should preserve header access
|
||||||
|
wrapped.Header().Set("X-Custom", "value")
|
||||||
|
assert.Equal(t, "value", recorder.Header().Get("X-Custom"))
|
||||||
|
|
||||||
|
// Should write status
|
||||||
|
wrapped.WriteHeader(http.StatusCreated)
|
||||||
|
assert.Equal(t, http.StatusCreated, recorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseRewriter_Flush(t *testing.T) {
|
||||||
|
t.Run("flush writes buffered content", func(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
response := []byte(`{"model":"claude-local","content":"test"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
wrapped.Write(response)
|
||||||
|
|
||||||
|
// Before cleanup, response should be empty (buffered)
|
||||||
|
assert.Empty(t, recorder.Body.Bytes())
|
||||||
|
|
||||||
|
// After cleanup, response should be written
|
||||||
|
cleanup()
|
||||||
|
assert.NotEmpty(t, recorder.Body.Bytes())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("multiple flush calls are safe", func(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
response := []byte(`{"model":"claude-local"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
wrapped.Write(response)
|
||||||
|
|
||||||
|
// First cleanup
|
||||||
|
cleanup()
|
||||||
|
firstBody := recorder.Body.Bytes()
|
||||||
|
|
||||||
|
// Second cleanup should not write again
|
||||||
|
cleanup()
|
||||||
|
secondBody := recorder.Body.Bytes()
|
||||||
|
|
||||||
|
assert.Equal(t, firstBody, secondBody)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseRewriter_StreamingWithDataLines(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
wrapped.Header().Set("Content-Type", "text/event-stream")
|
||||||
|
|
||||||
|
// SSE format with multiple data lines
|
||||||
|
chunk := []byte("data: {\"model\":\"claude-local\"}\n\ndata: {\"model\":\"claude-local\",\"done\":true}\n\n")
|
||||||
|
wrapped.Write(chunk)
|
||||||
|
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
body := recorder.Body.Bytes()
|
||||||
|
// Both data lines should have model rewritten
|
||||||
|
assert.Contains(t, string(body), `"model":"gpt-4"`)
|
||||||
|
assert.NotContains(t, string(body), `"model":"claude-local"`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRewriter_RoundTrip(t *testing.T) {
|
||||||
|
// Simulate a full request -> response cycle with model rewriting
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
|
||||||
|
// Step 1: Rewrite request body
|
||||||
|
originalRequest := []byte(`{"model":"gpt-4","messages":[{"role":"user","content":"hello"}]}`)
|
||||||
|
rewrittenRequest, err := rewriter.RewriteRequestBody(originalRequest, "claude-local")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify request was rewritten
|
||||||
|
extractor := NewModelExtractor()
|
||||||
|
requestModel, _ := extractor.Extract(rewrittenRequest, nil)
|
||||||
|
assert.Equal(t, "claude-local", requestModel)
|
||||||
|
|
||||||
|
// Step 2: Simulate response with resolved model
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
response := []byte(`{"model":"claude-local","content":"Hello! How can I help?"}`)
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
wrapped.Write(response)
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
// Verify response was rewritten back
|
||||||
|
body, _ := io.ReadAll(recorder.Result().Body)
|
||||||
|
responseModel, _ := extractor.Extract(body, nil)
|
||||||
|
assert.Equal(t, "gpt-4", responseModel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRewriter_NonJSONBody(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
|
||||||
|
// Binary/non-JSON body should be returned unchanged
|
||||||
|
body := []byte{0x00, 0x01, 0x02, 0x03}
|
||||||
|
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, body, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestModelRewriter_InvalidJSON(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
|
||||||
|
// Invalid JSON without model field should be returned unchanged
|
||||||
|
body := []byte(`not valid json`)
|
||||||
|
result, err := rewriter.RewriteRequestBody(body, "gpt-4")
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, body, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseRewriter_StatusCodePreserved(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
wrapped.WriteHeader(http.StatusAccepted)
|
||||||
|
wrapped.Write([]byte(`{"model":"claude-local"}`))
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
assert.Equal(t, http.StatusAccepted, recorder.Code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseRewriter_HeaderFlushed(t *testing.T) {
|
||||||
|
rewriter := NewModelRewriter()
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
wrapped, cleanup := rewriter.WrapResponseWriter(recorder, "gpt-4", "claude-local")
|
||||||
|
|
||||||
|
wrapped.Header().Set("Content-Type", "application/json")
|
||||||
|
wrapped.Header().Set("X-Request-ID", "abc123")
|
||||||
|
wrapped.Write([]byte(`{"model":"claude-local"}`))
|
||||||
|
cleanup()
|
||||||
|
|
||||||
|
result := recorder.Result()
|
||||||
|
assert.Equal(t, "application/json", result.Header.Get("Content-Type"))
|
||||||
|
assert.Equal(t, "abc123", result.Header.Get("X-Request-ID"))
|
||||||
|
}
|
||||||
@@ -31,15 +31,17 @@ func NewRouter(registry *Registry, cfg *config.Config) *Router {
|
|||||||
return r
|
return r
|
||||||
}
|
}
|
||||||
|
|
||||||
// RoutingDecision contains the resolved routing information.
|
// LegacyRoutingDecision contains the resolved routing information.
|
||||||
type RoutingDecision struct {
|
// Deprecated: Will be replaced by RoutingDecision from types.go in T-013.
|
||||||
|
type LegacyRoutingDecision struct {
|
||||||
RequestedModel string // Original model from request
|
RequestedModel string // Original model from request
|
||||||
ResolvedModel string // After model-mappings
|
ResolvedModel string // After model-mappings
|
||||||
Candidates []ProviderCandidate // Ordered list of providers to try
|
Candidates []ProviderCandidate // Ordered list of providers to try
|
||||||
}
|
}
|
||||||
|
|
||||||
// Resolve determines the routing decision for the requested model.
|
// Resolve determines the routing decision for the requested model.
|
||||||
func (r *Router) Resolve(requestedModel string) *RoutingDecision {
|
// Deprecated: Will be updated to use RoutingRequest and return *RoutingDecision in T-013.
|
||||||
|
func (r *Router) Resolve(requestedModel string) *LegacyRoutingDecision {
|
||||||
// 1. Extract thinking suffix
|
// 1. Extract thinking suffix
|
||||||
suffixResult := thinking.ParseSuffix(requestedModel)
|
suffixResult := thinking.ParseSuffix(requestedModel)
|
||||||
baseModel := suffixResult.ModelName
|
baseModel := suffixResult.ModelName
|
||||||
@@ -60,13 +62,151 @@ func (r *Router) Resolve(requestedModel string) *RoutingDecision {
|
|||||||
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
|
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
|
||||||
})
|
})
|
||||||
|
|
||||||
return &RoutingDecision{
|
return &LegacyRoutingDecision{
|
||||||
RequestedModel: requestedModel,
|
RequestedModel: requestedModel,
|
||||||
ResolvedModel: targetModel,
|
ResolvedModel: targetModel,
|
||||||
Candidates: candidates,
|
Candidates: candidates,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ResolveV2 determines the routing decision for a routing request.
|
||||||
|
// It uses the new RoutingRequest and RoutingDecision types.
|
||||||
|
func (r *Router) ResolveV2(req RoutingRequest) *RoutingDecision {
|
||||||
|
// 1. Extract thinking suffix
|
||||||
|
suffixResult := thinking.ParseSuffix(req.RequestedModel)
|
||||||
|
baseModel := suffixResult.ModelName
|
||||||
|
thinkingSuffix := ""
|
||||||
|
if suffixResult.HasSuffix {
|
||||||
|
thinkingSuffix = "(" + suffixResult.RawSuffix + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
// 2. Check for local providers
|
||||||
|
localCandidates := r.findLocalCandidates(baseModel, suffixResult)
|
||||||
|
|
||||||
|
// 3. Apply model-mappings if needed
|
||||||
|
mappedModel := r.applyMappings(baseModel)
|
||||||
|
mappingCandidates := r.findLocalCandidates(mappedModel, suffixResult)
|
||||||
|
|
||||||
|
// 4. Determine route type based on preferences and availability
|
||||||
|
var decision *RoutingDecision
|
||||||
|
|
||||||
|
if req.ForceModelMapping && mappedModel != baseModel && len(mappingCandidates) > 0 {
|
||||||
|
// FORCE MODE: Use mapping even if local provider exists
|
||||||
|
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
|
||||||
|
} else if req.PreferLocalProvider && len(localCandidates) > 0 {
|
||||||
|
// DEFAULT MODE with local preference: Use local provider first
|
||||||
|
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
|
||||||
|
} else if len(localCandidates) > 0 {
|
||||||
|
// DEFAULT MODE: Local provider available
|
||||||
|
decision = r.buildLocalProviderDecision(req.RequestedModel, localCandidates, thinkingSuffix)
|
||||||
|
} else if mappedModel != baseModel && len(mappingCandidates) > 0 {
|
||||||
|
// DEFAULT MODE: No local provider, but mapping available
|
||||||
|
decision = r.buildMappingDecision(req.RequestedModel, mappedModel, mappingCandidates, thinkingSuffix, mappingCandidates[1:])
|
||||||
|
} else {
|
||||||
|
// No local provider, no mapping - use amp credits proxy
|
||||||
|
decision = &RoutingDecision{
|
||||||
|
RouteType: RouteTypeAmpCredits,
|
||||||
|
ResolvedModel: req.RequestedModel,
|
||||||
|
ShouldProxy: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return decision
|
||||||
|
}
|
||||||
|
|
||||||
|
// findLocalCandidates finds local provider candidates for a model.
|
||||||
|
func (r *Router) findLocalCandidates(model string, suffixResult thinking.SuffixResult) []ProviderCandidate {
|
||||||
|
var candidates []ProviderCandidate
|
||||||
|
|
||||||
|
for _, p := range r.registry.All() {
|
||||||
|
if !p.SupportsModel(model) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply thinking suffix if needed
|
||||||
|
actualModel := model
|
||||||
|
if suffixResult.HasSuffix && !thinking.ParseSuffix(model).HasSuffix {
|
||||||
|
actualModel = model + "(" + suffixResult.RawSuffix + ")"
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Available(actualModel) {
|
||||||
|
candidates = append(candidates, ProviderCandidate{
|
||||||
|
Provider: p,
|
||||||
|
Model: actualModel,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Sort by priority
|
||||||
|
sort.Slice(candidates, func(i, j int) bool {
|
||||||
|
return candidates[i].Provider.Priority() < candidates[j].Provider.Priority()
|
||||||
|
})
|
||||||
|
|
||||||
|
return candidates
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildLocalProviderDecision creates a decision for local provider routing.
|
||||||
|
func (r *Router) buildLocalProviderDecision(requestedModel string, candidates []ProviderCandidate, thinkingSuffix string) *RoutingDecision {
|
||||||
|
resolvedModel := requestedModel
|
||||||
|
if thinkingSuffix != "" {
|
||||||
|
// Ensure thinking suffix is preserved
|
||||||
|
sr := thinking.ParseSuffix(requestedModel)
|
||||||
|
if !sr.HasSuffix {
|
||||||
|
resolvedModel = requestedModel + thinkingSuffix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var fallbackModels []string
|
||||||
|
if len(candidates) > 1 {
|
||||||
|
for _, c := range candidates[1:] {
|
||||||
|
fallbackModels = append(fallbackModels, c.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RoutingDecision{
|
||||||
|
RouteType: RouteTypeLocalProvider,
|
||||||
|
ResolvedModel: resolvedModel,
|
||||||
|
ProviderName: candidates[0].Provider.Name(),
|
||||||
|
FallbackModels: fallbackModels,
|
||||||
|
ShouldProxy: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildMappingDecision creates a decision for model mapping routing.
|
||||||
|
func (r *Router) buildMappingDecision(requestedModel, mappedModel string, candidates []ProviderCandidate, thinkingSuffix string, fallbackCandidates []ProviderCandidate) *RoutingDecision {
|
||||||
|
// Apply thinking suffix to resolved model if needed
|
||||||
|
resolvedModel := mappedModel
|
||||||
|
if thinkingSuffix != "" {
|
||||||
|
sr := thinking.ParseSuffix(mappedModel)
|
||||||
|
if !sr.HasSuffix {
|
||||||
|
resolvedModel = mappedModel + thinkingSuffix
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var fallbackModels []string
|
||||||
|
for _, c := range fallbackCandidates {
|
||||||
|
fallbackModels = append(fallbackModels, c.Model)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Also add oauth aliases as fallbacks
|
||||||
|
baseMapped := thinking.ParseSuffix(mappedModel).ModelName
|
||||||
|
for _, alias := range r.oauthAliases[strings.ToLower(baseMapped)] {
|
||||||
|
// Check if this alias has providers
|
||||||
|
aliasCandidates := r.findLocalCandidates(alias, thinking.SuffixResult{ModelName: alias})
|
||||||
|
for _, c := range aliasCandidates {
|
||||||
|
fallbackModels = append(fallbackModels, c.Model)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return &RoutingDecision{
|
||||||
|
RouteType: RouteTypeModelMapping,
|
||||||
|
ResolvedModel: resolvedModel,
|
||||||
|
ProviderName: candidates[0].Provider.Name(),
|
||||||
|
FallbackModels: fallbackModels,
|
||||||
|
ShouldProxy: false,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// applyMappings applies model-mappings configuration.
|
// applyMappings applies model-mappings configuration.
|
||||||
func (r *Router) applyMappings(model string) string {
|
func (r *Router) applyMappings(model string) string {
|
||||||
key := strings.ToLower(strings.TrimSpace(model))
|
key := strings.ToLower(strings.TrimSpace(model))
|
||||||
|
|||||||
245
internal/routing/router_v2_test.go
Normal file
245
internal/routing/router_v2_test.go
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestRouter_DefaultMode_PrefersLocal(t *testing.T) {
|
||||||
|
// Setup: Create a router with a mock provider that supports "gpt-4"
|
||||||
|
registry := NewRegistry()
|
||||||
|
mockProvider := &MockProvider{
|
||||||
|
name: "openai",
|
||||||
|
supportedModels: []string{"gpt-4"},
|
||||||
|
available: true,
|
||||||
|
priority: 1,
|
||||||
|
}
|
||||||
|
registry.Register(mockProvider)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: config.AmpCode{
|
||||||
|
ModelMappings: []config.AmpModelMapping{
|
||||||
|
{From: "gpt-4", To: "claude-local"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Test: Request gpt-4 when local provider exists
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: "gpt-4",
|
||||||
|
PreferLocalProvider: true,
|
||||||
|
ForceModelMapping: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision := router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Assert: Should return LOCAL_PROVIDER, not MODEL_MAPPING
|
||||||
|
assert.Equal(t, RouteTypeLocalProvider, decision.RouteType)
|
||||||
|
assert.Equal(t, "gpt-4", decision.ResolvedModel)
|
||||||
|
assert.Equal(t, "openai", decision.ProviderName)
|
||||||
|
assert.False(t, decision.ShouldProxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_DefaultMode_MapsWhenNoLocal(t *testing.T) {
|
||||||
|
// Setup: Create a router with NO provider for "gpt-4" but a mapping to "claude-local"
|
||||||
|
// which has a provider
|
||||||
|
registry := NewRegistry()
|
||||||
|
mockProvider := &MockProvider{
|
||||||
|
name: "anthropic",
|
||||||
|
supportedModels: []string{"claude-local"},
|
||||||
|
available: true,
|
||||||
|
priority: 1,
|
||||||
|
}
|
||||||
|
registry.Register(mockProvider)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: config.AmpCode{
|
||||||
|
ModelMappings: []config.AmpModelMapping{
|
||||||
|
{From: "gpt-4", To: "claude-local"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Test: Request gpt-4 when no local provider exists, but mapping exists
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: "gpt-4",
|
||||||
|
PreferLocalProvider: true,
|
||||||
|
ForceModelMapping: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision := router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Assert: Should return MODEL_MAPPING
|
||||||
|
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||||
|
assert.Equal(t, "claude-local", decision.ResolvedModel)
|
||||||
|
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||||
|
assert.False(t, decision.ShouldProxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_DefaultMode_AmpCreditsWhenNoLocalOrMapping(t *testing.T) {
|
||||||
|
// Setup: Create a router with no providers and no mappings
|
||||||
|
registry := NewRegistry()
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: config.AmpCode{
|
||||||
|
ModelMappings: []config.AmpModelMapping{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Test: Request a model with no local provider and no mapping
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: "unknown-model",
|
||||||
|
PreferLocalProvider: true,
|
||||||
|
ForceModelMapping: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision := router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Assert: Should return AMP_CREDITS with ShouldProxy=true
|
||||||
|
assert.Equal(t, RouteTypeAmpCredits, decision.RouteType)
|
||||||
|
assert.Equal(t, "unknown-model", decision.ResolvedModel)
|
||||||
|
assert.True(t, decision.ShouldProxy)
|
||||||
|
assert.Empty(t, decision.ProviderName)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_ForceMode_MapsEvenWithLocal(t *testing.T) {
|
||||||
|
// Setup: Create a router with BOTH a local provider for "gpt-4" AND a mapping from "gpt-4" to "claude-local"
|
||||||
|
// The mapping target "claude-local" also has a provider
|
||||||
|
registry := NewRegistry()
|
||||||
|
|
||||||
|
// Local provider for gpt-4
|
||||||
|
openaiProvider := &MockProvider{
|
||||||
|
name: "openai",
|
||||||
|
supportedModels: []string{"gpt-4"},
|
||||||
|
available: true,
|
||||||
|
priority: 1,
|
||||||
|
}
|
||||||
|
registry.Register(openaiProvider)
|
||||||
|
|
||||||
|
// Local provider for the mapped model
|
||||||
|
anthropicProvider := &MockProvider{
|
||||||
|
name: "anthropic",
|
||||||
|
supportedModels: []string{"claude-local"},
|
||||||
|
available: true,
|
||||||
|
priority: 2,
|
||||||
|
}
|
||||||
|
registry.Register(anthropicProvider)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: config.AmpCode{
|
||||||
|
ModelMappings: []config.AmpModelMapping{
|
||||||
|
{From: "gpt-4", To: "claude-local"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Test: Request gpt-4 with ForceModelMapping=true
|
||||||
|
// Even though gpt-4 has a local provider, mapping should take precedence
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: "gpt-4",
|
||||||
|
PreferLocalProvider: false,
|
||||||
|
ForceModelMapping: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision := router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Assert: Should return MODEL_MAPPING, not LOCAL_PROVIDER
|
||||||
|
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||||
|
assert.Equal(t, "claude-local", decision.ResolvedModel)
|
||||||
|
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||||
|
assert.False(t, decision.ShouldProxy)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_ThinkingSuffix_Preserved(t *testing.T) {
|
||||||
|
// Setup: Create a router with mapping and provider for mapped model
|
||||||
|
registry := NewRegistry()
|
||||||
|
|
||||||
|
mockProvider := &MockProvider{
|
||||||
|
name: "anthropic",
|
||||||
|
supportedModels: []string{"claude-local"},
|
||||||
|
available: true,
|
||||||
|
priority: 1,
|
||||||
|
}
|
||||||
|
registry.Register(mockProvider)
|
||||||
|
|
||||||
|
cfg := &config.Config{
|
||||||
|
AmpCode: config.AmpCode{
|
||||||
|
ModelMappings: []config.AmpModelMapping{
|
||||||
|
{From: "claude-3-5-sonnet", To: "claude-local"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
router := NewRouter(registry, cfg)
|
||||||
|
|
||||||
|
// Test: Request claude-3-5-sonnet with thinking suffix
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: "claude-3-5-sonnet(thinking:foo)",
|
||||||
|
PreferLocalProvider: true,
|
||||||
|
ForceModelMapping: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
decision := router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Assert: Thinking suffix should be preserved in resolved model
|
||||||
|
assert.Equal(t, RouteTypeModelMapping, decision.RouteType)
|
||||||
|
assert.Equal(t, "claude-local(thinking:foo)", decision.ResolvedModel)
|
||||||
|
assert.Equal(t, "anthropic", decision.ProviderName)
|
||||||
|
}
|
||||||
|
|
||||||
|
// MockProvider is a mock implementation of Provider for testing
|
||||||
|
type MockProvider struct {
|
||||||
|
name string
|
||||||
|
providerType ProviderType
|
||||||
|
supportedModels []string
|
||||||
|
available bool
|
||||||
|
priority int
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) Name() string {
|
||||||
|
return m.name
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) Type() ProviderType {
|
||||||
|
if m.providerType == "" {
|
||||||
|
return ProviderTypeOAuth
|
||||||
|
}
|
||||||
|
return m.providerType
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) SupportsModel(model string) bool {
|
||||||
|
for _, supported := range m.supportedModels {
|
||||||
|
if supported == model {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) Available(model string) bool {
|
||||||
|
return m.available
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) Priority() int {
|
||||||
|
return m.priority
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) Execute(ctx context.Context, model string, req executor.Request) (executor.Response, error) {
|
||||||
|
return executor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *MockProvider) ExecuteStream(ctx context.Context, model string, req executor.Request) (<-chan executor.StreamChunk, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
113
internal/routing/testutil/fake_handler.go
Normal file
113
internal/routing/testutil/fake_handler.go
Normal file
@@ -0,0 +1,113 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
// FakeHandlerRecorder records handler invocations for testing.
|
||||||
|
type FakeHandlerRecorder struct {
|
||||||
|
Called bool
|
||||||
|
CallCount int
|
||||||
|
RequestBody []byte
|
||||||
|
RequestHeader http.Header
|
||||||
|
ContextKeys map[string]interface{}
|
||||||
|
ResponseStatus int
|
||||||
|
ResponseBody []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFakeHandlerRecorder creates a new fake handler recorder.
|
||||||
|
func NewFakeHandlerRecorder() *FakeHandlerRecorder {
|
||||||
|
return &FakeHandlerRecorder{
|
||||||
|
ContextKeys: make(map[string]interface{}),
|
||||||
|
ResponseStatus: http.StatusOK,
|
||||||
|
ResponseBody: []byte(`{"status":"handled"}`),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GinHandler returns a gin.HandlerFunc that records the invocation.
|
||||||
|
func (f *FakeHandlerRecorder) GinHandler() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
f.record(c)
|
||||||
|
c.Data(f.ResponseStatus, "application/json", f.ResponseBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GinHandlerWithModel returns a gin.HandlerFunc that records the invocation and returns the model from context.
|
||||||
|
// Useful for testing response rewriting in model mapping scenarios.
|
||||||
|
func (f *FakeHandlerRecorder) GinHandlerWithModel() gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
f.record(c)
|
||||||
|
// Return a response with the model field that would be in the actual API response
|
||||||
|
// If ResponseBody was explicitly set (not default), use that; otherwise generate from context
|
||||||
|
var body []byte
|
||||||
|
if mappedModel, exists := c.Get("mapped_model"); exists {
|
||||||
|
body = []byte(`{"model":"` + mappedModel.(string) + `","status":"handled"}`)
|
||||||
|
} else {
|
||||||
|
body = f.ResponseBody
|
||||||
|
}
|
||||||
|
c.Data(f.ResponseStatus, "application/json", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPHandler returns an http.HandlerFunc that records the invocation.
|
||||||
|
func (f *FakeHandlerRecorder) HTTPHandler() http.HandlerFunc {
|
||||||
|
return func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
body, _ := io.ReadAll(r.Body)
|
||||||
|
f.Called = true
|
||||||
|
f.CallCount++
|
||||||
|
f.RequestBody = body
|
||||||
|
f.RequestHeader = r.Header.Clone()
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
w.WriteHeader(f.ResponseStatus)
|
||||||
|
w.Write(f.ResponseBody)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// record captures the request details from gin context.
|
||||||
|
func (f *FakeHandlerRecorder) record(c *gin.Context) {
|
||||||
|
f.Called = true
|
||||||
|
f.CallCount++
|
||||||
|
|
||||||
|
body, _ := io.ReadAll(c.Request.Body)
|
||||||
|
f.RequestBody = body
|
||||||
|
f.RequestHeader = c.Request.Header.Clone()
|
||||||
|
|
||||||
|
// Capture common context keys used by routing
|
||||||
|
if val, exists := c.Get("mapped_model"); exists {
|
||||||
|
f.ContextKeys["mapped_model"] = val
|
||||||
|
}
|
||||||
|
if val, exists := c.Get("fallback_models"); exists {
|
||||||
|
f.ContextKeys["fallback_models"] = val
|
||||||
|
}
|
||||||
|
if val, exists := c.Get("route_type"); exists {
|
||||||
|
f.ContextKeys["route_type"] = val
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset clears the recorder state.
|
||||||
|
func (f *FakeHandlerRecorder) Reset() {
|
||||||
|
f.Called = false
|
||||||
|
f.CallCount = 0
|
||||||
|
f.RequestBody = nil
|
||||||
|
f.RequestHeader = nil
|
||||||
|
f.ContextKeys = make(map[string]interface{})
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetContextKey returns a captured context key value.
|
||||||
|
func (f *FakeHandlerRecorder) GetContextKey(key string) (interface{}, bool) {
|
||||||
|
val, ok := f.ContextKeys[key]
|
||||||
|
return val, ok
|
||||||
|
}
|
||||||
|
|
||||||
|
// WasCalled returns true if the handler was called.
|
||||||
|
func (f *FakeHandlerRecorder) WasCalled() bool {
|
||||||
|
return f.Called
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCallCount returns the number of times the handler was called.
|
||||||
|
func (f *FakeHandlerRecorder) GetCallCount() int {
|
||||||
|
return f.CallCount
|
||||||
|
}
|
||||||
83
internal/routing/testutil/fake_proxy.go
Normal file
83
internal/routing/testutil/fake_proxy.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
)
|
||||||
|
|
||||||
|
// CloseNotifierRecorder wraps httptest.ResponseRecorder with CloseNotify support.
|
||||||
|
// This is needed because ReverseProxy requires http.CloseNotifier.
|
||||||
|
type CloseNotifierRecorder struct {
|
||||||
|
*httptest.ResponseRecorder
|
||||||
|
closeChan chan bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewCloseNotifierRecorder creates a ResponseRecorder that implements CloseNotifier.
|
||||||
|
func NewCloseNotifierRecorder() *CloseNotifierRecorder {
|
||||||
|
return &CloseNotifierRecorder{
|
||||||
|
ResponseRecorder: httptest.NewRecorder(),
|
||||||
|
closeChan: make(chan bool, 1),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseNotify implements http.CloseNotifier.
|
||||||
|
func (c *CloseNotifierRecorder) CloseNotify() <-chan bool {
|
||||||
|
return c.closeChan
|
||||||
|
}
|
||||||
|
|
||||||
|
// FakeProxyRecorder records proxy invocations for testing.
|
||||||
|
type FakeProxyRecorder struct {
|
||||||
|
Called bool
|
||||||
|
CallCount int
|
||||||
|
RequestBody []byte
|
||||||
|
RequestHeaders http.Header
|
||||||
|
ResponseStatus int
|
||||||
|
ResponseBody []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFakeProxyRecorder creates a new fake proxy recorder.
|
||||||
|
func NewFakeProxyRecorder() *FakeProxyRecorder {
|
||||||
|
return &FakeProxyRecorder{
|
||||||
|
ResponseStatus: http.StatusOK,
|
||||||
|
ResponseBody: []byte(`{"status":"proxied"}`),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// ServeHTTP implements http.Handler to act as a reverse proxy.
|
||||||
|
func (f *FakeProxyRecorder) ServeHTTP(w http.ResponseWriter, r *http.Request) {
|
||||||
|
f.Called = true
|
||||||
|
f.CallCount++
|
||||||
|
f.RequestHeaders = r.Header.Clone()
|
||||||
|
|
||||||
|
body, err := io.ReadAll(r.Body)
|
||||||
|
if err == nil {
|
||||||
|
f.RequestBody = body
|
||||||
|
}
|
||||||
|
|
||||||
|
w.WriteHeader(f.ResponseStatus)
|
||||||
|
w.Write(f.ResponseBody)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetCallCount returns the number of times the proxy was called.
|
||||||
|
func (f *FakeProxyRecorder) GetCallCount() int {
|
||||||
|
return f.CallCount
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset clears the recorder state.
|
||||||
|
func (f *FakeProxyRecorder) Reset() {
|
||||||
|
f.Called = false
|
||||||
|
f.CallCount = 0
|
||||||
|
f.RequestBody = nil
|
||||||
|
f.RequestHeaders = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// ToHandler returns the recorder as an http.Handler for use with httptest.
|
||||||
|
func (f *FakeProxyRecorder) ToHandler() http.Handler {
|
||||||
|
return http.HandlerFunc(f.ServeHTTP)
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateTestServer creates an httptest server with this fake proxy.
|
||||||
|
func (f *FakeProxyRecorder) CreateTestServer() *httptest.Server {
|
||||||
|
return httptest.NewServer(f.ToHandler())
|
||||||
|
}
|
||||||
62
internal/routing/types.go
Normal file
62
internal/routing/types.go
Normal file
@@ -0,0 +1,62 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
// RouteType represents the type of routing decision made for a request.
|
||||||
|
type RouteType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free).
|
||||||
|
RouteTypeLocalProvider RouteType = "LOCAL_PROVIDER"
|
||||||
|
// RouteTypeModelMapping indicates the request was remapped to another available model (free).
|
||||||
|
RouteTypeModelMapping RouteType = "MODEL_MAPPING"
|
||||||
|
// RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits).
|
||||||
|
RouteTypeAmpCredits RouteType = "AMP_CREDITS"
|
||||||
|
// RouteTypeNoProvider indicates no provider or fallback available.
|
||||||
|
RouteTypeNoProvider RouteType = "NO_PROVIDER"
|
||||||
|
)
|
||||||
|
|
||||||
|
// RoutingRequest contains the information needed to make a routing decision.
|
||||||
|
type RoutingRequest struct {
|
||||||
|
// RequestedModel is the model name from the incoming request.
|
||||||
|
RequestedModel string
|
||||||
|
// PreferLocalProvider indicates whether to prefer local providers over mappings.
|
||||||
|
// When true, check local providers first before applying model mappings.
|
||||||
|
PreferLocalProvider bool
|
||||||
|
// ForceModelMapping indicates whether to force model mapping even if local provider exists.
|
||||||
|
// When true, apply model mappings first and skip local provider checks.
|
||||||
|
ForceModelMapping bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoutingDecision contains the result of a routing decision.
|
||||||
|
type RoutingDecision struct {
|
||||||
|
// RouteType indicates the type of routing decision.
|
||||||
|
RouteType RouteType
|
||||||
|
// ResolvedModel is the final model name after any mappings.
|
||||||
|
ResolvedModel string
|
||||||
|
// ProviderName is the name of the selected provider (if any).
|
||||||
|
ProviderName string
|
||||||
|
// FallbackModels is a list of alternative models to try if the primary fails.
|
||||||
|
FallbackModels []string
|
||||||
|
// ShouldProxy indicates whether the request should be proxied to ampcode.com.
|
||||||
|
ShouldProxy bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewRoutingDecision creates a new RoutingDecision with the given parameters.
|
||||||
|
func NewRoutingDecision(routeType RouteType, resolvedModel, providerName string, fallbackModels []string, shouldProxy bool) *RoutingDecision {
|
||||||
|
return &RoutingDecision{
|
||||||
|
RouteType: routeType,
|
||||||
|
ResolvedModel: resolvedModel,
|
||||||
|
ProviderName: providerName,
|
||||||
|
FallbackModels: fallbackModels,
|
||||||
|
ShouldProxy: shouldProxy,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsLocal returns true if the decision routes to a local provider.
|
||||||
|
func (d *RoutingDecision) IsLocal() bool {
|
||||||
|
return d.RouteType == RouteTypeLocalProvider || d.RouteType == RouteTypeModelMapping
|
||||||
|
}
|
||||||
|
|
||||||
|
// HasFallbacks returns true if there are fallback models available.
|
||||||
|
func (d *RoutingDecision) HasFallbacks() bool {
|
||||||
|
return len(d.FallbackModels) > 0
|
||||||
|
}
|
||||||
270
internal/routing/wrapper.go
Normal file
270
internal/routing/wrapper.go
Normal file
@@ -0,0 +1,270 @@
|
|||||||
|
package routing
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProxyFunc is the function type for proxying requests.
|
||||||
|
type ProxyFunc func(c *gin.Context)
|
||||||
|
|
||||||
|
// ModelRoutingWrapper wraps HTTP handlers with unified model routing logic.
|
||||||
|
// It replaces the FallbackHandler logic with a Router-based approach.
|
||||||
|
type ModelRoutingWrapper struct {
|
||||||
|
router *Router
|
||||||
|
extractor ModelExtractor
|
||||||
|
rewriter ModelRewriter
|
||||||
|
proxyFunc ProxyFunc
|
||||||
|
logger *logrus.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewModelRoutingWrapper creates a new ModelRoutingWrapper with the given dependencies.
|
||||||
|
// If extractor is nil, a DefaultModelExtractor is used.
|
||||||
|
// If rewriter is nil, a DefaultModelRewriter is used.
|
||||||
|
// proxyFunc is called for AMP_CREDITS route type; if nil, the handler will be called instead.
|
||||||
|
func NewModelRoutingWrapper(router *Router, extractor ModelExtractor, rewriter ModelRewriter, proxyFunc ProxyFunc) *ModelRoutingWrapper {
|
||||||
|
if extractor == nil {
|
||||||
|
extractor = NewModelExtractor()
|
||||||
|
}
|
||||||
|
if rewriter == nil {
|
||||||
|
rewriter = NewModelRewriter()
|
||||||
|
}
|
||||||
|
return &ModelRoutingWrapper{
|
||||||
|
router: router,
|
||||||
|
extractor: extractor,
|
||||||
|
rewriter: rewriter,
|
||||||
|
proxyFunc: proxyFunc,
|
||||||
|
logger: logrus.New(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLogger sets the logger for the wrapper.
|
||||||
|
func (w *ModelRoutingWrapper) SetLogger(logger *logrus.Logger) {
|
||||||
|
w.logger = logger
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wrap wraps a gin.HandlerFunc with model routing logic.
|
||||||
|
// The returned handler will:
|
||||||
|
// 1. Extract the model from the request
|
||||||
|
// 2. Get a routing decision from the Router
|
||||||
|
// 3. Handle the request according to the decision type (LOCAL_PROVIDER, MODEL_MAPPING, AMP_CREDITS)
|
||||||
|
func (w *ModelRoutingWrapper) Wrap(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||||
|
return func(c *gin.Context) {
|
||||||
|
// Read request body
|
||||||
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
|
if err != nil {
|
||||||
|
w.logger.Errorf("routing wrapper: failed to read request body: %v", err)
|
||||||
|
handler(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract model from request
|
||||||
|
ginParams := map[string]string{
|
||||||
|
"action": c.Param("action"),
|
||||||
|
"path": c.Param("path"),
|
||||||
|
}
|
||||||
|
modelName, err := w.extractor.Extract(bodyBytes, ginParams)
|
||||||
|
if err != nil {
|
||||||
|
w.logger.Warnf("routing wrapper: failed to extract model: %v", err)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
handler(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if modelName == "" {
|
||||||
|
// No model found, proceed with original handler
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
handler(c)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Get routing decision
|
||||||
|
req := RoutingRequest{
|
||||||
|
RequestedModel: modelName,
|
||||||
|
PreferLocalProvider: true,
|
||||||
|
ForceModelMapping: false, // TODO: Get from config
|
||||||
|
}
|
||||||
|
decision := w.router.ResolveV2(req)
|
||||||
|
|
||||||
|
// Store decision in context for downstream handlers
|
||||||
|
c.Set(string(ctxkeys.RoutingDecision), decision)
|
||||||
|
|
||||||
|
// Handle based on route type
|
||||||
|
switch decision.RouteType {
|
||||||
|
case RouteTypeLocalProvider:
|
||||||
|
w.handleLocalProvider(c, handler, bodyBytes, decision)
|
||||||
|
case RouteTypeModelMapping:
|
||||||
|
w.handleModelMapping(c, handler, bodyBytes, decision)
|
||||||
|
case RouteTypeAmpCredits:
|
||||||
|
w.handleAmpCredits(c, handler, bodyBytes)
|
||||||
|
default:
|
||||||
|
// No provider available
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
handler(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleLocalProvider handles the LOCAL_PROVIDER route type.
|
||||||
|
func (w *ModelRoutingWrapper) handleLocalProvider(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
|
||||||
|
// Filter Anthropic-Beta header for local provider
|
||||||
|
filterAnthropicBetaHeader(c)
|
||||||
|
|
||||||
|
// Restore body with original content
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
|
||||||
|
// Call handler
|
||||||
|
handler(c)
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleModelMapping handles the MODEL_MAPPING route type.
|
||||||
|
func (w *ModelRoutingWrapper) handleModelMapping(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte, decision *RoutingDecision) {
|
||||||
|
// Rewrite request body with mapped model
|
||||||
|
rewrittenBody, err := w.rewriter.RewriteRequestBody(bodyBytes, decision.ResolvedModel)
|
||||||
|
if err != nil {
|
||||||
|
w.logger.Warnf("routing wrapper: failed to rewrite request body: %v", err)
|
||||||
|
rewrittenBody = bodyBytes
|
||||||
|
}
|
||||||
|
_ = rewrittenBody
|
||||||
|
|
||||||
|
// Store mapped model in context
|
||||||
|
c.Set(string(ctxkeys.MappedModel), decision.ResolvedModel)
|
||||||
|
|
||||||
|
// Store fallback models in context if present
|
||||||
|
if len(decision.FallbackModels) > 0 {
|
||||||
|
c.Set(string(ctxkeys.FallbackModels), decision.FallbackModels)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Filter Anthropic-Beta header for local provider
|
||||||
|
filterAnthropicBetaHeader(c)
|
||||||
|
|
||||||
|
// Restore body with rewritten content
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(rewrittenBody))
|
||||||
|
|
||||||
|
// Wrap response writer to rewrite model back
|
||||||
|
wrappedWriter, cleanup := w.rewriter.WrapResponseWriter(c.Writer, decision.ResolvedModel, decision.ResolvedModel)
|
||||||
|
c.Writer = &ginResponseWriterAdapter{ResponseWriter: wrappedWriter, original: c.Writer}
|
||||||
|
|
||||||
|
// Call handler
|
||||||
|
handler(c)
|
||||||
|
|
||||||
|
// Cleanup (flush response rewriting)
|
||||||
|
cleanup()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleAmpCredits handles the AMP_CREDITS route type.
|
||||||
|
// It calls the proxy function directly if available, otherwise passes to handler.
|
||||||
|
// Does NOT filter headers or rewrite body - proxy handles everything.
|
||||||
|
func (w *ModelRoutingWrapper) handleAmpCredits(c *gin.Context, handler gin.HandlerFunc, bodyBytes []byte) {
|
||||||
|
// Restore body with original content (no rewriting for proxy)
|
||||||
|
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||||
|
|
||||||
|
// Call proxy function if available, otherwise fall back to handler
|
||||||
|
if w.proxyFunc != nil {
|
||||||
|
w.proxyFunc(c)
|
||||||
|
} else {
|
||||||
|
handler(c)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterAnthropicBetaHeader filters Anthropic-Beta header for local providers.
|
||||||
|
func filterAnthropicBetaHeader(c *gin.Context) {
|
||||||
|
if betaHeader := c.Request.Header.Get("Anthropic-Beta"); betaHeader != "" {
|
||||||
|
filtered := filterBetaFeatures(betaHeader, "context-1m-2025-08-07")
|
||||||
|
if filtered != "" {
|
||||||
|
c.Request.Header.Set("Anthropic-Beta", filtered)
|
||||||
|
} else {
|
||||||
|
c.Request.Header.Del("Anthropic-Beta")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterBetaFeatures removes specified beta features from the header.
|
||||||
|
func filterBetaFeatures(betaHeader, featureToRemove string) string {
|
||||||
|
// Simple implementation - can be enhanced
|
||||||
|
if betaHeader == featureToRemove {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
return betaHeader
|
||||||
|
}
|
||||||
|
|
||||||
|
// ginResponseWriterAdapter adapts http.ResponseWriter to gin.ResponseWriter.
|
||||||
|
type ginResponseWriterAdapter struct {
|
||||||
|
http.ResponseWriter
|
||||||
|
original gin.ResponseWriter
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ginResponseWriterAdapter) WriteHeader(code int) {
|
||||||
|
a.ResponseWriter.WriteHeader(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ginResponseWriterAdapter) Write(data []byte) (int, error) {
|
||||||
|
return a.ResponseWriter.Write(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (a *ginResponseWriterAdapter) Header() http.Header {
|
||||||
|
return a.ResponseWriter.Header()
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseNotify implements http.CloseNotifier.
|
||||||
|
func (a *ginResponseWriterAdapter) CloseNotify() <-chan bool {
|
||||||
|
if notifier, ok := a.ResponseWriter.(http.CloseNotifier); ok {
|
||||||
|
return notifier.CloseNotify()
|
||||||
|
}
|
||||||
|
return a.original.CloseNotify()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Flush implements http.Flusher.
|
||||||
|
func (a *ginResponseWriterAdapter) Flush() {
|
||||||
|
if flusher, ok := a.ResponseWriter.(http.Flusher); ok {
|
||||||
|
flusher.Flush()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Hijack implements http.Hijacker.
|
||||||
|
func (a *ginResponseWriterAdapter) Hijack() (net.Conn, *bufio.ReadWriter, error) {
|
||||||
|
if hijacker, ok := a.ResponseWriter.(http.Hijacker); ok {
|
||||||
|
return hijacker.Hijack()
|
||||||
|
}
|
||||||
|
return a.original.Hijack()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Status returns the HTTP status code.
|
||||||
|
func (a *ginResponseWriterAdapter) Status() int {
|
||||||
|
return a.original.Status()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Size returns the number of bytes already written into the response http body.
|
||||||
|
func (a *ginResponseWriterAdapter) Size() int {
|
||||||
|
return a.original.Size()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Written returns whether or not the response for this context has been written.
|
||||||
|
func (a *ginResponseWriterAdapter) Written() bool {
|
||||||
|
return a.original.Written()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteHeaderNow forces WriteHeader to be called.
|
||||||
|
func (a *ginResponseWriterAdapter) WriteHeaderNow() {
|
||||||
|
a.original.WriteHeaderNow()
|
||||||
|
}
|
||||||
|
|
||||||
|
// WriteString writes the given string into the response body.
|
||||||
|
func (a *ginResponseWriterAdapter) WriteString(s string) (int, error) {
|
||||||
|
return a.Write([]byte(s))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Pusher returns the http.Pusher for server push.
|
||||||
|
func (a *ginResponseWriterAdapter) Pusher() http.Pusher {
|
||||||
|
if pusher, ok := a.ResponseWriter.(http.Pusher); ok {
|
||||||
|
return pusher
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -124,7 +124,9 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
|||||||
body = disableThinkingIfToolChoiceForced(body)
|
body = disableThinkingIfToolChoiceForced(body)
|
||||||
|
|
||||||
// Auto-inject cache_control if missing (optimization for ClawdBot/clients without caching support)
|
// Auto-inject cache_control if missing (optimization for ClawdBot/clients without caching support)
|
||||||
|
if countCacheControls(body) == 0 {
|
||||||
body = ensureCacheControl(body)
|
body = ensureCacheControl(body)
|
||||||
|
}
|
||||||
|
|
||||||
// Extract betas from body and convert to header
|
// Extract betas from body and convert to header
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
@@ -262,7 +264,9 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
|||||||
body = disableThinkingIfToolChoiceForced(body)
|
body = disableThinkingIfToolChoiceForced(body)
|
||||||
|
|
||||||
// Auto-inject cache_control if missing (optimization for ClawdBot/clients without caching support)
|
// Auto-inject cache_control if missing (optimization for ClawdBot/clients without caching support)
|
||||||
|
if countCacheControls(body) == 0 {
|
||||||
body = ensureCacheControl(body)
|
body = ensureCacheControl(body)
|
||||||
|
}
|
||||||
|
|
||||||
// Extract betas from body and convert to header
|
// Extract betas from body and convert to header
|
||||||
var extraBetas []string
|
var extraBetas []string
|
||||||
@@ -1033,6 +1037,51 @@ func ensureCacheControl(payload []byte) []byte {
|
|||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func countCacheControls(payload []byte) int {
|
||||||
|
count := 0
|
||||||
|
|
||||||
|
// Check system
|
||||||
|
system := gjson.GetBytes(payload, "system")
|
||||||
|
if system.IsArray() {
|
||||||
|
system.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
if item.Get("cache_control").Exists() {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check tools
|
||||||
|
tools := gjson.GetBytes(payload, "tools")
|
||||||
|
if tools.IsArray() {
|
||||||
|
tools.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
if item.Get("cache_control").Exists() {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check messages
|
||||||
|
messages := gjson.GetBytes(payload, "messages")
|
||||||
|
if messages.IsArray() {
|
||||||
|
messages.ForEach(func(_, msg gjson.Result) bool {
|
||||||
|
content := msg.Get("content")
|
||||||
|
if content.IsArray() {
|
||||||
|
content.ForEach(func(_, item gjson.Result) bool {
|
||||||
|
if item.Get("cache_control").Exists() {
|
||||||
|
count++
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
return count
|
||||||
|
}
|
||||||
|
|
||||||
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.
|
// injectMessagesCacheControl adds cache_control to the second-to-last user turn for multi-turn caching.
|
||||||
// Per Anthropic docs: "Place cache_control on the second-to-last User message to let the model reuse the earlier cache."
|
// Per Anthropic docs: "Place cache_control on the second-to-last User message to let the model reuse the earlier cache."
|
||||||
// This enables caching of conversation history, which is especially beneficial for long multi-turn conversations.
|
// This enables caching of conversation history, which is especially beneficial for long multi-turn conversations.
|
||||||
|
|||||||
@@ -21,7 +21,7 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
|||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
rules := cfg.Payload
|
rules := cfg.Payload
|
||||||
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 {
|
if len(rules.Default) == 0 && len(rules.DefaultRaw) == 0 && len(rules.Override) == 0 && len(rules.OverrideRaw) == 0 && len(rules.Filter) == 0 {
|
||||||
return payload
|
return payload
|
||||||
}
|
}
|
||||||
model = strings.TrimSpace(model)
|
model = strings.TrimSpace(model)
|
||||||
@@ -39,7 +39,7 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
|||||||
// Apply default rules: first write wins per field across all matching rules.
|
// Apply default rules: first write wins per field across all matching rules.
|
||||||
for i := range rules.Default {
|
for i := range rules.Default {
|
||||||
rule := &rules.Default[i]
|
rule := &rules.Default[i]
|
||||||
if !payloadRuleMatchesModels(rule, protocol, candidates) {
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for path, value := range rule.Params {
|
for path, value := range rule.Params {
|
||||||
@@ -64,7 +64,7 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
|||||||
// Apply default raw rules: first write wins per field across all matching rules.
|
// Apply default raw rules: first write wins per field across all matching rules.
|
||||||
for i := range rules.DefaultRaw {
|
for i := range rules.DefaultRaw {
|
||||||
rule := &rules.DefaultRaw[i]
|
rule := &rules.DefaultRaw[i]
|
||||||
if !payloadRuleMatchesModels(rule, protocol, candidates) {
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for path, value := range rule.Params {
|
for path, value := range rule.Params {
|
||||||
@@ -93,7 +93,7 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
|||||||
// Apply override rules: last write wins per field across all matching rules.
|
// Apply override rules: last write wins per field across all matching rules.
|
||||||
for i := range rules.Override {
|
for i := range rules.Override {
|
||||||
rule := &rules.Override[i]
|
rule := &rules.Override[i]
|
||||||
if !payloadRuleMatchesModels(rule, protocol, candidates) {
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for path, value := range rule.Params {
|
for path, value := range rule.Params {
|
||||||
@@ -111,7 +111,7 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
|||||||
// Apply override raw rules: last write wins per field across all matching rules.
|
// Apply override raw rules: last write wins per field across all matching rules.
|
||||||
for i := range rules.OverrideRaw {
|
for i := range rules.OverrideRaw {
|
||||||
rule := &rules.OverrideRaw[i]
|
rule := &rules.OverrideRaw[i]
|
||||||
if !payloadRuleMatchesModels(rule, protocol, candidates) {
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
for path, value := range rule.Params {
|
for path, value := range rule.Params {
|
||||||
@@ -130,29 +130,33 @@ func applyPayloadConfigWithRoot(cfg *config.Config, model, protocol, root string
|
|||||||
out = updated
|
out = updated
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
// Apply filter rules: remove matching paths from payload.
|
||||||
|
for i := range rules.Filter {
|
||||||
|
rule := &rules.Filter[i]
|
||||||
|
if !payloadModelRulesMatch(rule.Models, protocol, candidates) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
for _, path := range rule.Params {
|
||||||
|
fullPath := buildPayloadPath(root, path)
|
||||||
|
if fullPath == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
updated, errDel := sjson.DeleteBytes(out, fullPath)
|
||||||
|
if errDel != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out = updated
|
||||||
|
}
|
||||||
|
}
|
||||||
return out
|
return out
|
||||||
}
|
}
|
||||||
|
|
||||||
func payloadRuleMatchesModels(rule *config.PayloadRule, protocol string, models []string) bool {
|
func payloadModelRulesMatch(rules []config.PayloadModelRule, protocol string, models []string) bool {
|
||||||
if rule == nil || len(models) == 0 {
|
if len(rules) == 0 || len(models) == 0 {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
for _, model := range models {
|
for _, model := range models {
|
||||||
if payloadRuleMatchesModel(rule, model, protocol) {
|
for _, entry := range rules {
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func payloadRuleMatchesModel(rule *config.PayloadRule, model, protocol string) bool {
|
|
||||||
if rule == nil {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
if len(rule.Models) == 0 {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
for _, entry := range rule.Models {
|
|
||||||
name := strings.TrimSpace(entry.Name)
|
name := strings.TrimSpace(entry.Name)
|
||||||
if name == "" {
|
if name == "" {
|
||||||
continue
|
continue
|
||||||
@@ -164,6 +168,7 @@ func payloadRuleMatchesModel(rule *config.PayloadRule, model, protocol string) b
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -386,11 +386,12 @@ func (s *ObjectTokenStore) syncConfigFromBucket(ctx context.Context, example str
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *ObjectTokenStore) syncAuthFromBucket(ctx context.Context) error {
|
func (s *ObjectTokenStore) syncAuthFromBucket(ctx context.Context) error {
|
||||||
if err := os.RemoveAll(s.authDir); err != nil {
|
// NOTE: We intentionally do NOT use os.RemoveAll here.
|
||||||
return fmt.Errorf("object store: reset auth directory: %w", err)
|
// Wiping the directory triggers file watcher delete events, which then
|
||||||
}
|
// propagate deletions to the remote object store (race condition).
|
||||||
|
// Instead, we just ensure the directory exists and overwrite files incrementally.
|
||||||
if err := os.MkdirAll(s.authDir, 0o700); err != nil {
|
if err := os.MkdirAll(s.authDir, 0o700); err != nil {
|
||||||
return fmt.Errorf("object store: recreate auth directory: %w", err)
|
return fmt.Errorf("object store: create auth directory: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
prefix := s.prefixedKey(objectStoreAuthPrefix + "/")
|
prefix := s.prefixedKey(objectStoreAuthPrefix + "/")
|
||||||
|
|||||||
@@ -24,6 +24,8 @@ import (
|
|||||||
type convertCliResponseToOpenAIChatParams struct {
|
type convertCliResponseToOpenAIChatParams struct {
|
||||||
UnixTimestamp int64
|
UnixTimestamp int64
|
||||||
FunctionIndex int
|
FunctionIndex int
|
||||||
|
SawToolCall bool // Tracks if any tool call was seen in the entire stream
|
||||||
|
UpstreamFinishReason string // Caches the upstream finish reason for final chunk
|
||||||
}
|
}
|
||||||
|
|
||||||
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
// functionCallIDCounter provides a process-wide unique counter for function call identifiers.
|
||||||
@@ -79,10 +81,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
template, _ = sjson.Set(template, "id", responseIDResult.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set the finish reason.
|
// Cache the finish reason - do NOT set it in output yet (will be set on final chunk)
|
||||||
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
|
(*param).(*convertCliResponseToOpenAIChatParams).UpstreamFinishReason = strings.ToUpper(finishReasonResult.String())
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Extract and set usage metadata (token counts).
|
// Extract and set usage metadata (token counts).
|
||||||
@@ -112,7 +113,6 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
|
|
||||||
// Process the main content part of the response.
|
// Process the main content part of the response.
|
||||||
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
|
partsResult := gjson.GetBytes(rawJSON, "response.candidates.0.content.parts")
|
||||||
hasFunctionCall := false
|
|
||||||
if partsResult.IsArray() {
|
if partsResult.IsArray() {
|
||||||
partResults := partsResult.Array()
|
partResults := partsResult.Array()
|
||||||
for i := 0; i < len(partResults); i++ {
|
for i := 0; i < len(partResults); i++ {
|
||||||
@@ -148,7 +148,7 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||||
} else if functionCallResult.Exists() {
|
} else if functionCallResult.Exists() {
|
||||||
// Handle function call content.
|
// Handle function call content.
|
||||||
hasFunctionCall = true
|
(*param).(*convertCliResponseToOpenAIChatParams).SawToolCall = true // Persist across chunks
|
||||||
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
toolCallsResult := gjson.Get(template, "choices.0.delta.tool_calls")
|
||||||
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
|
functionCallIndex := (*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex
|
||||||
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
|
(*param).(*convertCliResponseToOpenAIChatParams).FunctionIndex++
|
||||||
@@ -195,9 +195,25 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if hasFunctionCall {
|
// Determine finish_reason only on the final chunk (has both finishReason and usage metadata)
|
||||||
template, _ = sjson.Set(template, "choices.0.finish_reason", "tool_calls")
|
params := (*param).(*convertCliResponseToOpenAIChatParams)
|
||||||
template, _ = sjson.Set(template, "choices.0.native_finish_reason", "tool_calls")
|
upstreamFinishReason := params.UpstreamFinishReason
|
||||||
|
sawToolCall := params.SawToolCall
|
||||||
|
|
||||||
|
usageExists := gjson.GetBytes(rawJSON, "response.usageMetadata").Exists()
|
||||||
|
isFinalChunk := upstreamFinishReason != "" && usageExists
|
||||||
|
|
||||||
|
if isFinalChunk {
|
||||||
|
var finishReason string
|
||||||
|
if sawToolCall {
|
||||||
|
finishReason = "tool_calls"
|
||||||
|
} else if upstreamFinishReason == "MAX_TOKENS" {
|
||||||
|
finishReason = "max_tokens"
|
||||||
|
} else {
|
||||||
|
finishReason = "stop"
|
||||||
|
}
|
||||||
|
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReason)
|
||||||
|
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(upstreamFinishReason))
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{template}
|
return []string{template}
|
||||||
|
|||||||
@@ -0,0 +1,128 @@
|
|||||||
|
package chat_completions
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestFinishReasonToolCallsNotOverwritten(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
// Chunk 1: Contains functionCall - should set SawToolCall = true
|
||||||
|
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"list_files","args":{"path":"."}}}]}}]}}`)
|
||||||
|
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||||
|
|
||||||
|
// Verify chunk1 has no finish_reason (null)
|
||||||
|
if len(result1) != 1 {
|
||||||
|
t.Fatalf("Expected 1 result from chunk1, got %d", len(result1))
|
||||||
|
}
|
||||||
|
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
||||||
|
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||||
|
t.Errorf("Expected finish_reason to be null in chunk1, got: %v", fr1.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chunk 2: Contains finishReason STOP + usage (final chunk, no functionCall)
|
||||||
|
// This simulates what the upstream sends AFTER the tool call chunk
|
||||||
|
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":20,"totalTokenCount":30}}}`)
|
||||||
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
|
// Verify chunk2 has finish_reason: "tool_calls" (not "stop")
|
||||||
|
if len(result2) != 1 {
|
||||||
|
t.Fatalf("Expected 1 result from chunk2, got %d", len(result2))
|
||||||
|
}
|
||||||
|
fr2 := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||||
|
if fr2 != "tool_calls" {
|
||||||
|
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr2)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify native_finish_reason is lowercase upstream value
|
||||||
|
nfr2 := gjson.Get(result2[0], "choices.0.native_finish_reason").String()
|
||||||
|
if nfr2 != "stop" {
|
||||||
|
t.Errorf("Expected native_finish_reason 'stop', got: %s", nfr2)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinishReasonStopForNormalText(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
// Chunk 1: Text content only
|
||||||
|
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello world"}]}}]}}`)
|
||||||
|
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||||
|
|
||||||
|
// Chunk 2: Final chunk with STOP
|
||||||
|
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"STOP"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":5,"totalTokenCount":15}}}`)
|
||||||
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
|
// Verify finish_reason is "stop" (no tool calls were made)
|
||||||
|
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||||
|
if fr != "stop" {
|
||||||
|
t.Errorf("Expected finish_reason 'stop', got: %s", fr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFinishReasonMaxTokens(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
// Chunk 1: Text content
|
||||||
|
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}}`)
|
||||||
|
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||||
|
|
||||||
|
// Chunk 2: Final chunk with MAX_TOKENS
|
||||||
|
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"MAX_TOKENS"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":100,"totalTokenCount":110}}}`)
|
||||||
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
|
// Verify finish_reason is "max_tokens"
|
||||||
|
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||||
|
if fr != "max_tokens" {
|
||||||
|
t.Errorf("Expected finish_reason 'max_tokens', got: %s", fr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestToolCallTakesPriorityOverMaxTokens(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
// Chunk 1: Contains functionCall
|
||||||
|
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"functionCall":{"name":"test","args":{}}}]}}]}}`)
|
||||||
|
ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||||
|
|
||||||
|
// Chunk 2: Final chunk with MAX_TOKENS (but we had a tool call, so tool_calls should win)
|
||||||
|
chunk2 := []byte(`{"response":{"candidates":[{"finishReason":"MAX_TOKENS"}],"usageMetadata":{"promptTokenCount":10,"candidatesTokenCount":100,"totalTokenCount":110}}}`)
|
||||||
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
|
// Verify finish_reason is "tool_calls" (takes priority over max_tokens)
|
||||||
|
fr := gjson.Get(result2[0], "choices.0.finish_reason").String()
|
||||||
|
if fr != "tool_calls" {
|
||||||
|
t.Errorf("Expected finish_reason 'tool_calls', got: %s", fr)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNoFinishReasonOnIntermediateChunks(t *testing.T) {
|
||||||
|
ctx := context.Background()
|
||||||
|
var param any
|
||||||
|
|
||||||
|
// Chunk 1: Text content (no finish reason, no usage)
|
||||||
|
chunk1 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":"Hello"}]}}]}}`)
|
||||||
|
result1 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk1, ¶m)
|
||||||
|
|
||||||
|
// Verify no finish_reason on intermediate chunk
|
||||||
|
fr1 := gjson.Get(result1[0], "choices.0.finish_reason")
|
||||||
|
if fr1.Exists() && fr1.String() != "" && fr1.Type.String() != "Null" {
|
||||||
|
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr1)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Chunk 2: More text (no finish reason, no usage)
|
||||||
|
chunk2 := []byte(`{"response":{"candidates":[{"content":{"parts":[{"text":" world"}]}}]}}`)
|
||||||
|
result2 := ConvertAntigravityResponseToOpenAI(ctx, "model", nil, nil, chunk2, ¶m)
|
||||||
|
|
||||||
|
// Verify no finish_reason on intermediate chunk
|
||||||
|
fr2 := gjson.Get(result2[0], "choices.0.finish_reason")
|
||||||
|
if fr2.Exists() && fr2.String() != "" && fr2.Type.String() != "Null" {
|
||||||
|
t.Errorf("Expected no finish_reason on intermediate chunk, got: %v", fr2)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -57,6 +57,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if oldCfg.NonStreamKeepAliveInterval != newCfg.NonStreamKeepAliveInterval {
|
if oldCfg.NonStreamKeepAliveInterval != newCfg.NonStreamKeepAliveInterval {
|
||||||
changes = append(changes, fmt.Sprintf("nonstream-keepalive-interval: %d -> %d", oldCfg.NonStreamKeepAliveInterval, newCfg.NonStreamKeepAliveInterval))
|
changes = append(changes, fmt.Sprintf("nonstream-keepalive-interval: %d -> %d", oldCfg.NonStreamKeepAliveInterval, newCfg.NonStreamKeepAliveInterval))
|
||||||
}
|
}
|
||||||
|
if oldCfg.CodexInstructionsEnabled != newCfg.CodexInstructionsEnabled {
|
||||||
|
changes = append(changes, fmt.Sprintf("codex-instructions-enabled: %t -> %t", oldCfg.CodexInstructionsEnabled, newCfg.CodexInstructionsEnabled))
|
||||||
|
}
|
||||||
|
|
||||||
// Quota-exceeded behavior
|
// Quota-exceeded behavior
|
||||||
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
||||||
|
|||||||
@@ -650,7 +650,7 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
|||||||
}
|
}
|
||||||
|
|
||||||
if len(providers) == 0 {
|
if len(providers) == 0 {
|
||||||
return nil, "", &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
return nil, "", &interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: fmt.Errorf("unknown provider for model %s", modelName)}
|
||||||
}
|
}
|
||||||
|
|
||||||
// The thinking suffix is preserved in the model name itself, so no
|
// The thinking suffix is preserved in the model name itself, so no
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/routing/ctxkeys"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
@@ -577,7 +578,7 @@ func (m *Manager) executeWithFallback(
|
|||||||
|
|
||||||
// Track fallback models from context (provided by Amp module fallback_models key)
|
// Track fallback models from context (provided by Amp module fallback_models key)
|
||||||
var fallbacks []string
|
var fallbacks []string
|
||||||
if v := ctx.Value("fallback_models"); v != nil {
|
if v := ctx.Value(ctxkeys.FallbackModels); v != nil {
|
||||||
if fs, ok := v.([]string); ok {
|
if fs, ok := v.([]string); ok {
|
||||||
fallbacks = fs
|
fallbacks = fs
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -19,6 +19,7 @@ type AmpCode = internalconfig.AmpCode
|
|||||||
type OAuthModelAlias = internalconfig.OAuthModelAlias
|
type OAuthModelAlias = internalconfig.OAuthModelAlias
|
||||||
type PayloadConfig = internalconfig.PayloadConfig
|
type PayloadConfig = internalconfig.PayloadConfig
|
||||||
type PayloadRule = internalconfig.PayloadRule
|
type PayloadRule = internalconfig.PayloadRule
|
||||||
|
type PayloadFilterRule = internalconfig.PayloadFilterRule
|
||||||
type PayloadModelRule = internalconfig.PayloadModelRule
|
type PayloadModelRule = internalconfig.PayloadModelRule
|
||||||
|
|
||||||
type GeminiKey = internalconfig.GeminiKey
|
type GeminiKey = internalconfig.GeminiKey
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package logging
|
|||||||
|
|
||||||
import internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
import internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
|
|
||||||
|
const defaultErrorLogsMaxFiles = 10
|
||||||
|
|
||||||
// RequestLogger defines the interface for logging HTTP requests and responses.
|
// RequestLogger defines the interface for logging HTTP requests and responses.
|
||||||
type RequestLogger = internallogging.RequestLogger
|
type RequestLogger = internallogging.RequestLogger
|
||||||
|
|
||||||
@@ -12,7 +14,12 @@ type StreamingLogWriter = internallogging.StreamingLogWriter
|
|||||||
// FileRequestLogger implements RequestLogger using file-based storage.
|
// FileRequestLogger implements RequestLogger using file-based storage.
|
||||||
type FileRequestLogger = internallogging.FileRequestLogger
|
type FileRequestLogger = internallogging.FileRequestLogger
|
||||||
|
|
||||||
// NewFileRequestLogger creates a new file-based request logger.
|
// NewFileRequestLogger creates a new file-based request logger with default error log retention (10 files).
|
||||||
func NewFileRequestLogger(enabled bool, logsDir string, configDir string) *FileRequestLogger {
|
func NewFileRequestLogger(enabled bool, logsDir string, configDir string) *FileRequestLogger {
|
||||||
return internallogging.NewFileRequestLogger(enabled, logsDir, configDir)
|
return internallogging.NewFileRequestLogger(enabled, logsDir, configDir, defaultErrorLogsMaxFiles)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFileRequestLoggerWithOptions creates a new file-based request logger with configurable error log retention.
|
||||||
|
func NewFileRequestLoggerWithOptions(enabled bool, logsDir string, configDir string, errorLogsMaxFiles int) *FileRequestLogger {
|
||||||
|
return internallogging.NewFileRequestLogger(enabled, logsDir, configDir, errorLogsMaxFiles)
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user