Compare commits

..

44 Commits

Author SHA1 Message Date
Luis Pater
39b6b3b289 Fixed: #463
fix(antigravity): remove `$ref` and `$defs` from JSON during key deletion
2025-12-09 17:32:17 +08:00
Luis Pater
c600519fa4 refactor(logging): replace log.Fatalf with log.Errorf and add error handling paths 2025-12-09 17:16:30 +08:00
Luis Pater
92df0cada9 Merge pull request #461 from router-for-me/aistudio
feat(aistudio): normalize thinking budget in request translation
2025-12-09 08:41:46 +08:00
hkfires
96b55acff8 feat(aistudio): normalize thinking budget in request translation 2025-12-09 08:27:44 +08:00
Luis Pater
bb45fee1cf Merge remote-tracking branch 'origin/dev' into dev 2025-12-08 23:28:22 +08:00
Luis Pater
af00304b0c fix(antigravity): remove exclusiveMaximum from JSON during key deletion 2025-12-08 23:28:01 +08:00
vuonglv(Andy)
5c3a013cd1 feat(config): add configurable host binding for server (#454)
* feat(config): add configurable host binding for server
2025-12-08 23:16:39 +08:00
Luis Pater
6ad188921c refactor(logging): remove unused variable in ensureAttempt and redundant function call 2025-12-08 22:25:58 +08:00
Luis Pater
15ed98d6a9 Merge pull request #458 from router-for-me/agry
feat(antigravity): enforce thinking budget limits for Claude models
2025-12-08 20:55:52 +08:00
hkfires
a283545b6b feat(antigravity): enforce thinking budget limits for Claude models 2025-12-08 20:36:17 +08:00
Luis Pater
3efbd865a8 Merge pull request #457 from router-for-me/requestlog
style(logging): remove redundant separator line from response section
2025-12-08 18:21:24 +08:00
hkfires
aee659fb66 style(logging): remove redundant separator line from response section 2025-12-08 18:18:33 +08:00
Luis Pater
5aa386d8b9 Merge pull request #453 from router-for-me/amp
add ampcode management api
2025-12-08 17:42:13 +08:00
Luis Pater
0adc0ee6aa Merge pull request #455 from router-for-me/requestlog
feat(logging): add upstream API request/response capture to streaming logs
2025-12-08 17:40:10 +08:00
hkfires
92f13fc316 feat(logging): add upstream API request/response capture to streaming logs 2025-12-08 17:21:58 +08:00
hkfires
05cfa16e5f refactor(api): simplify request body parsing in ampcode handlers 2025-12-08 14:45:35 +08:00
hkfires
93a6e2d920 feat(api): add comprehensive ampcode management endpoints
Add new REST API endpoints under /v0/management/ampcode for managing
ampcode configuration including upstream URL, API key, localhost
restriction, model mappings, and force model mappings settings.

- Move force-model-mappings from config_basic to config_lists
- Add GET/PUT/PATCH/DELETE endpoints for all ampcode settings
- Support model mapping CRUD with upsert (PATCH) capability
- Add comprehensive test coverage for all ampcode endpoints
2025-12-08 12:03:00 +08:00
Luis Pater
de77903915 Merge pull request #450 from router-for-me/amp
refactor(config): rename prioritize-model-mappings to force-model-mappings
2025-12-08 10:51:32 +08:00
hkfires
56ed0d8d90 refactor(config): rename prioritize-model-mappings to force-model-mappings 2025-12-08 10:44:39 +08:00
Luis Pater
42e818ce05 Merge pull request #435 from heyhuynhgiabuu/fix/amp-model-mapping-priority
fix: prioritize model mappings over local providers for Amp CLI
2025-12-08 10:17:19 +08:00
Luis Pater
2d4c54ba54 Merge pull request #448 from router-for-me/iflow
Iflow
2025-12-08 09:50:05 +08:00
hkfires
e9eb4db8bb feat(auth): refresh API key during cookie authentication 2025-12-08 09:48:31 +08:00
Luis Pater
d26ed069fa Merge pull request #441 from huynguyen03dev/fix/claude-to-openai-whitespace-text
fix: filter whitespace-only text in Claude to OpenAI translation
2025-12-08 09:43:44 +08:00
huynhgiabuu
afcab5efda feat: add prioritize-model-mappings config option
Add a configuration option to control whether model mappings take
precedence over local API keys for Amp CLI requests.

- Add PrioritizeModelMappings field to AmpCode config struct
- When false (default): Local API keys take precedence (original behavior)
- When true: Model mappings take precedence over local API keys
- Add management API endpoints GET/PUT /prioritize-model-mappings

This allows users who want mapping priority to enable it explicitly
while preserving backward compatibility.

Config example:
  ampcode:
    model-mappings:
      - from: claude-opus-4-5-20251101
        to: gemini-claude-opus-4-5-thinking
    prioritize-model-mappings: true
2025-12-07 22:47:43 +07:00
Luis Pater
6cf1d8a947 Merge pull request #444 from router-for-me/agry
feat(registry): add explicit thinking support config for antigravity models
2025-12-07 19:38:43 +08:00
hkfires
a174d015f2 feat(openai): handle thinking.budget_tokens from Anthropic-style requests 2025-12-07 19:14:05 +08:00
hkfires
9c09128e00 feat(registry): add explicit thinking support config for antigravity models 2025-12-07 19:12:55 +08:00
huynguyen03.dev
549c0c2c5a fix: filter whitespace-only text content in Claude to OpenAI translation
Remove redundant existence check since TrimSpace handles empty strings
2025-12-07 16:08:12 +07:00
huynguyen03.dev
f092801b61 fix: filter whitespace-only text in Claude to OpenAI translation
Skip text content blocks that are empty or contain only whitespace
when translating Claude messages to OpenAI format. This fixes GLM-4.6
and other strict OpenAI-compatible providers that reject empty text
with error 'text cannot be empty'.
2025-12-07 15:39:58 +07:00
Luis Pater
1b638b3629 Merge pull request #432 from huynguyen03dev/fix/amp-gemini-model-mapping
fix(amp): pass mapped model to gemini bridge via context
2025-12-07 13:33:28 +08:00
Luis Pater
6f5f81753d Merge pull request #439 from router-for-me/log
feat(logging): add version info to request log output
2025-12-07 13:31:06 +08:00
Luis Pater
76af454034 **feat(antigravity): enhance handling of "thinking" content and refine Claude model response processing** 2025-12-07 13:19:12 +08:00
hkfires
e54d2f6b2a feat(logging): add version info to request log output 2025-12-07 12:49:14 +08:00
huynguyen03.dev
bfc738b76a refactor: remove duplicate provider check in gemini v1beta1 route
Simplifies routing logic by delegating all provider/mapping/proxy
decisions to FallbackHandler. Previously, the route checked for
provider/mapping availability before calling the handler, then
FallbackHandler performed the same checks again.

Changes:
- Remove model extraction and provider checking from route (lines 182-201)
- Route now only checks if request is POST with /models/ path
- FallbackHandler handles provider -> mapping -> proxy fallback
- Remove unused internal/util import

Benefits:
- Eliminates duplicate checks (addresses PR review feedback #2)
- Centralizes all provider/mapping logic in FallbackHandler
- Reduces routing code by ~20 lines
- Aligns with how other /api/provider routes work

Performance: No impact (checks still happen once in FallbackHandler)
2025-12-07 10:54:58 +07:00
huynguyen03.dev
396899a530 refactor: improve gemini bridge testability and code quality
- Change createGeminiBridgeHandler to accept gin.HandlerFunc instead of *gemini.GeminiAPIHandler
  This allows tests to inject mock handlers instead of duplicating bridge logic
- Replace magic number 8 with len(modelsPrefix) for better maintainability
- Remove redundant test case that doesn't test edge case in production
- Update routes.go to pass geminiHandlers.GeminiHandler directly

Addresses PR review feedback on test architecture and code clarity.

Amp-Thread-ID: https://ampcode.com/threads/T-1ae2c691-e434-4b99-a49a-10cabd3544db
2025-12-07 10:15:42 +07:00
Luis Pater
f383840cf9 fix(antigravity): update toolNode role from "tool" to "user" in chat completions 2025-12-07 02:37:46 +08:00
Luis Pater
fd29ab418a Fixed: #424
**feat(antigravity): add support for maxOutputTokens and refine Claude model handling**
2025-12-07 01:55:57 +08:00
Luis Pater
7a628426dc Fixed: #433
refactor(translator): normalize finish reason casing across all OpenAI response handlers
2025-12-07 01:48:24 +08:00
Luis Pater
56b4d7a76e docs(readme): add ProxyPal CLIProxyAPI GUI to project list 2025-12-07 01:13:30 +08:00
Luis Pater
b211c3546d Merge pull request #429 from heyhuynhgiabuu/feature/add-proxypal
docs: add ProxyPal to 'Who is with us?' section
2025-12-07 01:10:44 +08:00
huynguyen03.dev
edc654edf9 refactor: simplify provider check logic in amp routes
Amp-Thread-ID: https://ampcode.com/threads/T-a18fd71c-32ce-4c29-93d7-09f082740e51
2025-12-06 22:07:40 +07:00
huynguyen03.dev
08586334af fix(amp): pass mapped model to gemini bridge via context
Gemini handler extracts model from URL path, not JSON body, so
rewriting the request body alone wasn't sufficient for model mapping.

- Add MappedModelContextKey constant for context passing
- Update routes.go to use NewFallbackHandlerWithMapper
- Add check for valid mapping before routing to local handler
- Add tests for gemini bridge model mapping
2025-12-06 18:59:44 +07:00
Huynh Gia Buu
c04c3832a4 Update README.md
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
2025-12-06 13:48:08 +07:00
huynhgiabuu
5ffbd54755 docs: add ProxyPal to 'Who is with us?' section 2025-12-06 13:45:49 +07:00
35 changed files with 1673 additions and 199 deletions

View File

@@ -95,6 +95,10 @@ Browser-based tool to translate SRT subtitles using your Gemini subscription via
CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.

View File

@@ -93,6 +93,10 @@ CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支
CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户和替代模型Gemini, Codex, Antigravity无需 API 密钥。
### [ProxyPal](https://github.com/heyhuynhgiabuu/proxypal)
基于 macOS 平台的原生 CLIProxyAPI GUI配置供应商、模型映射以及OAuth端点无需 API 密钥。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR拉取请求将其添加到此列表中。

View File

@@ -139,7 +139,8 @@ func main() {
wd, err := os.Getwd()
if err != nil {
log.Fatalf("failed to get working directory: %v", err)
log.Errorf("failed to get working directory: %v", err)
return
}
// Load environment variables from .env if present.
@@ -233,13 +234,15 @@ func main() {
})
cancel()
if err != nil {
log.Fatalf("failed to initialize postgres token store: %v", err)
log.Errorf("failed to initialize postgres token store: %v", err)
return
}
examplePath := filepath.Join(wd, "config.example.yaml")
ctx, cancel = context.WithTimeout(context.Background(), 30*time.Second)
if errBootstrap := pgStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
cancel()
log.Fatalf("failed to bootstrap postgres-backed config: %v", errBootstrap)
log.Errorf("failed to bootstrap postgres-backed config: %v", errBootstrap)
return
}
cancel()
configFilePath = pgStoreInst.ConfigPath()
@@ -262,7 +265,8 @@ func main() {
if strings.Contains(resolvedEndpoint, "://") {
parsed, errParse := url.Parse(resolvedEndpoint)
if errParse != nil {
log.Fatalf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
log.Errorf("failed to parse object store endpoint %q: %v", objectStoreEndpoint, errParse)
return
}
switch strings.ToLower(parsed.Scheme) {
case "http":
@@ -270,10 +274,12 @@ func main() {
case "https":
useSSL = true
default:
log.Fatalf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
log.Errorf("unsupported object store scheme %q (only http and https are allowed)", parsed.Scheme)
return
}
if parsed.Host == "" {
log.Fatalf("object store endpoint %q is missing host information", objectStoreEndpoint)
log.Errorf("object store endpoint %q is missing host information", objectStoreEndpoint)
return
}
resolvedEndpoint = parsed.Host
if parsed.Path != "" && parsed.Path != "/" {
@@ -292,13 +298,15 @@ func main() {
}
objectStoreInst, err = store.NewObjectTokenStore(objCfg)
if err != nil {
log.Fatalf("failed to initialize object token store: %v", err)
log.Errorf("failed to initialize object token store: %v", err)
return
}
examplePath := filepath.Join(wd, "config.example.yaml")
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
if errBootstrap := objectStoreInst.Bootstrap(ctx, examplePath); errBootstrap != nil {
cancel()
log.Fatalf("failed to bootstrap object-backed config: %v", errBootstrap)
log.Errorf("failed to bootstrap object-backed config: %v", errBootstrap)
return
}
cancel()
configFilePath = objectStoreInst.ConfigPath()
@@ -323,7 +331,8 @@ func main() {
gitStoreInst = store.NewGitTokenStore(gitStoreRemoteURL, gitStoreUser, gitStorePassword)
gitStoreInst.SetBaseDir(authDir)
if errRepo := gitStoreInst.EnsureRepository(); errRepo != nil {
log.Fatalf("failed to prepare git token store: %v", errRepo)
log.Errorf("failed to prepare git token store: %v", errRepo)
return
}
configFilePath = gitStoreInst.ConfigPath()
if configFilePath == "" {
@@ -332,17 +341,21 @@ func main() {
if _, statErr := os.Stat(configFilePath); errors.Is(statErr, fs.ErrNotExist) {
examplePath := filepath.Join(wd, "config.example.yaml")
if _, errExample := os.Stat(examplePath); errExample != nil {
log.Fatalf("failed to find template config file: %v", errExample)
log.Errorf("failed to find template config file: %v", errExample)
return
}
if errCopy := misc.CopyConfigTemplate(examplePath, configFilePath); errCopy != nil {
log.Fatalf("failed to bootstrap git-backed config: %v", errCopy)
log.Errorf("failed to bootstrap git-backed config: %v", errCopy)
return
}
if errCommit := gitStoreInst.PersistConfig(context.Background()); errCommit != nil {
log.Fatalf("failed to commit initial git-backed config: %v", errCommit)
log.Errorf("failed to commit initial git-backed config: %v", errCommit)
return
}
log.Infof("git-backed config initialized from template: %s", configFilePath)
} else if statErr != nil {
log.Fatalf("failed to inspect git-backed config: %v", statErr)
log.Errorf("failed to inspect git-backed config: %v", statErr)
return
}
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
if err == nil {
@@ -355,13 +368,15 @@ func main() {
} else {
wd, err = os.Getwd()
if err != nil {
log.Fatalf("failed to get working directory: %v", err)
log.Errorf("failed to get working directory: %v", err)
return
}
configFilePath = filepath.Join(wd, "config.yaml")
cfg, err = config.LoadConfigOptional(configFilePath, isCloudDeploy)
}
if err != nil {
log.Fatalf("failed to load config: %v", err)
log.Errorf("failed to load config: %v", err)
return
}
if cfg == nil {
cfg = &config.Config{}
@@ -391,7 +406,8 @@ func main() {
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
log.Fatalf("failed to configure log output: %v", err)
log.Errorf("failed to configure log output: %v", err)
return
}
log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
@@ -400,7 +416,8 @@ func main() {
util.SetLogLevel(cfg)
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
log.Fatalf("failed to resolve auth directory: %v", errResolveAuthDir)
log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
return
} else {
cfg.AuthDir = resolvedAuthDir
}

View File

@@ -1,3 +1,7 @@
# Server host/interface to bind to. Default is empty ("") to bind all interfaces (IPv4 + IPv6).
# Use "127.0.0.1" or "localhost" to restrict access to local machine only.
host: ""
# Server port
port: 8317
@@ -134,6 +138,8 @@ ws-auth: false
# upstream-api-key: ""
# # Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
# restrict-management-to-localhost: true
# # Force model mappings to run before checking local API keys (default: false)
# force-model-mappings: false
# # Amp Model Mappings
# # Route unavailable Amp models to alternative models available in your local proxy.
# # Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)

View File

@@ -713,14 +713,16 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
// Generate PKCE codes
pkceCodes, err := claude.GeneratePKCECodes()
if err != nil {
log.Fatalf("Failed to generate PKCE codes: %v", err)
log.Errorf("Failed to generate PKCE codes: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
return
}
// Generate random state parameter
state, err := misc.GenerateRandomState()
if err != nil {
log.Fatalf("Failed to generate state parameter: %v", err)
log.Errorf("Failed to generate state parameter: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
@@ -730,7 +732,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
// Generate authorization URL (then override redirect_uri to reuse server port)
authURL, state, err := anthropicAuth.GenerateAuthURL(state, pkceCodes)
if err != nil {
log.Fatalf("Failed to generate authorization URL: %v", err)
log.Errorf("Failed to generate authorization URL: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
@@ -872,7 +875,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save authentication tokens: %v", errSave)
log.Errorf("Failed to save authentication tokens: %v", errSave)
oauthStatus[state] = "Failed to save authentication tokens"
return
}
@@ -1045,7 +1048,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
gemAuth := geminiAuth.NewGeminiAuth()
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
if errGetClient != nil {
log.Fatalf("failed to get authenticated client: %v", errGetClient)
log.Errorf("failed to get authenticated client: %v", errGetClient)
oauthStatus[state] = "Failed to get authenticated client"
return
}
@@ -1110,7 +1113,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save token to file: %v", errSave)
log.Errorf("Failed to save token to file: %v", errSave)
oauthStatus[state] = "Failed to save token to file"
return
}
@@ -1131,14 +1134,16 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
// Generate PKCE codes
pkceCodes, err := codex.GeneratePKCECodes()
if err != nil {
log.Fatalf("Failed to generate PKCE codes: %v", err)
log.Errorf("Failed to generate PKCE codes: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate PKCE codes"})
return
}
// Generate random state parameter
state, err := misc.GenerateRandomState()
if err != nil {
log.Fatalf("Failed to generate state parameter: %v", err)
log.Errorf("Failed to generate state parameter: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
@@ -1148,7 +1153,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
// Generate authorization URL
authURL, err := openaiAuth.GenerateAuthURL(state, pkceCodes)
if err != nil {
log.Fatalf("Failed to generate authorization URL: %v", err)
log.Errorf("Failed to generate authorization URL: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
@@ -1283,7 +1289,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
oauthStatus[state] = "Failed to save authentication tokens"
log.Fatalf("Failed to save authentication tokens: %v", errSave)
log.Errorf("Failed to save authentication tokens: %v", errSave)
return
}
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
@@ -1318,7 +1324,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
state, errState := misc.GenerateRandomState()
if errState != nil {
log.Fatalf("Failed to generate state parameter: %v", errState)
log.Errorf("Failed to generate state parameter: %v", errState)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate state parameter"})
return
}
@@ -1514,7 +1521,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save token to file: %v", errSave)
log.Errorf("Failed to save token to file: %v", errSave)
oauthStatus[state] = "Failed to save token to file"
return
}
@@ -1543,7 +1550,8 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
// Generate authorization URL
deviceFlow, err := qwenAuth.InitiateDeviceFlow(ctx)
if err != nil {
log.Fatalf("Failed to generate authorization URL: %v", err)
log.Errorf("Failed to generate authorization URL: %v", err)
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to generate authorization url"})
return
}
authURL := deviceFlow.VerificationURIComplete
@@ -1570,7 +1578,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
}
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save authentication tokens: %v", errSave)
log.Errorf("Failed to save authentication tokens: %v", errSave)
oauthStatus[state] = "Failed to save authentication tokens"
return
}
@@ -1674,7 +1682,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil {
oauthStatus[state] = "Failed to save authentication tokens"
log.Fatalf("Failed to save authentication tokens: %v", errSave)
log.Errorf("Failed to save authentication tokens: %v", errSave)
return
}
@@ -2103,6 +2111,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
continue
}
}
_ = resp.Body.Close()
return false, fmt.Errorf("project activation required: %s", errMessage)
}
return true, nil

View File

@@ -706,3 +706,155 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
}
entry.Models = normalized
}
// GetAmpCode returns the complete ampcode configuration.
func (h *Handler) GetAmpCode(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"ampcode": config.AmpCode{}})
return
}
c.JSON(200, gin.H{"ampcode": h.cfg.AmpCode})
}
// GetAmpUpstreamURL returns the ampcode upstream URL.
func (h *Handler) GetAmpUpstreamURL(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"upstream-url": ""})
return
}
c.JSON(200, gin.H{"upstream-url": h.cfg.AmpCode.UpstreamURL})
}
// PutAmpUpstreamURL updates the ampcode upstream URL.
func (h *Handler) PutAmpUpstreamURL(c *gin.Context) {
h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamURL = strings.TrimSpace(v) })
}
// DeleteAmpUpstreamURL clears the ampcode upstream URL.
func (h *Handler) DeleteAmpUpstreamURL(c *gin.Context) {
h.cfg.AmpCode.UpstreamURL = ""
h.persist(c)
}
// GetAmpUpstreamAPIKey returns the ampcode upstream API key.
func (h *Handler) GetAmpUpstreamAPIKey(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"upstream-api-key": ""})
return
}
c.JSON(200, gin.H{"upstream-api-key": h.cfg.AmpCode.UpstreamAPIKey})
}
// PutAmpUpstreamAPIKey updates the ampcode upstream API key.
func (h *Handler) PutAmpUpstreamAPIKey(c *gin.Context) {
h.updateStringField(c, func(v string) { h.cfg.AmpCode.UpstreamAPIKey = strings.TrimSpace(v) })
}
// DeleteAmpUpstreamAPIKey clears the ampcode upstream API key.
func (h *Handler) DeleteAmpUpstreamAPIKey(c *gin.Context) {
h.cfg.AmpCode.UpstreamAPIKey = ""
h.persist(c)
}
// GetAmpRestrictManagementToLocalhost returns the localhost restriction setting.
func (h *Handler) GetAmpRestrictManagementToLocalhost(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"restrict-management-to-localhost": true})
return
}
c.JSON(200, gin.H{"restrict-management-to-localhost": h.cfg.AmpCode.RestrictManagementToLocalhost})
}
// PutAmpRestrictManagementToLocalhost updates the localhost restriction setting.
func (h *Handler) PutAmpRestrictManagementToLocalhost(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.RestrictManagementToLocalhost = v })
}
// GetAmpModelMappings returns the ampcode model mappings.
func (h *Handler) GetAmpModelMappings(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"model-mappings": []config.AmpModelMapping{}})
return
}
c.JSON(200, gin.H{"model-mappings": h.cfg.AmpCode.ModelMappings})
}
// PutAmpModelMappings replaces all ampcode model mappings.
func (h *Handler) PutAmpModelMappings(c *gin.Context) {
var body struct {
Value []config.AmpModelMapping `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
h.cfg.AmpCode.ModelMappings = body.Value
h.persist(c)
}
// PatchAmpModelMappings adds or updates model mappings.
func (h *Handler) PatchAmpModelMappings(c *gin.Context) {
var body struct {
Value []config.AmpModelMapping `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil {
c.JSON(400, gin.H{"error": "invalid body"})
return
}
existing := make(map[string]int)
for i, m := range h.cfg.AmpCode.ModelMappings {
existing[strings.TrimSpace(m.From)] = i
}
for _, newMapping := range body.Value {
from := strings.TrimSpace(newMapping.From)
if idx, ok := existing[from]; ok {
h.cfg.AmpCode.ModelMappings[idx] = newMapping
} else {
h.cfg.AmpCode.ModelMappings = append(h.cfg.AmpCode.ModelMappings, newMapping)
existing[from] = len(h.cfg.AmpCode.ModelMappings) - 1
}
}
h.persist(c)
}
// DeleteAmpModelMappings removes specified model mappings by "from" field.
func (h *Handler) DeleteAmpModelMappings(c *gin.Context) {
var body struct {
Value []string `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil || len(body.Value) == 0 {
h.cfg.AmpCode.ModelMappings = nil
h.persist(c)
return
}
toRemove := make(map[string]bool)
for _, from := range body.Value {
toRemove[strings.TrimSpace(from)] = true
}
newMappings := make([]config.AmpModelMapping, 0, len(h.cfg.AmpCode.ModelMappings))
for _, m := range h.cfg.AmpCode.ModelMappings {
if !toRemove[strings.TrimSpace(m.From)] {
newMappings = append(newMappings, m)
}
}
h.cfg.AmpCode.ModelMappings = newMappings
h.persist(c)
}
// GetAmpForceModelMappings returns whether model mappings are forced.
func (h *Handler) GetAmpForceModelMappings(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(200, gin.H{"force-model-mappings": false})
return
}
c.JSON(200, gin.H{"force-model-mappings": h.cfg.AmpCode.ForceModelMappings})
}
// PutAmpForceModelMappings updates the force model mappings setting.
func (h *Handler) PutAmpForceModelMappings(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.AmpCode.ForceModelMappings = v })
}

View File

@@ -240,16 +240,6 @@ func (h *Handler) updateBoolField(c *gin.Context, set func(bool)) {
Value *bool `json:"value"`
}
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
var m map[string]any
if err2 := c.ShouldBindJSON(&m); err2 == nil {
for _, v := range m {
if b, ok := v.(bool); ok {
set(b)
h.persist(c)
return
}
}
}
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid body"})
return
}

View File

@@ -232,7 +232,16 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
w.streamDone = nil
}
// Write API Request and Response to the streaming log before closing
if w.streamWriter != nil {
apiRequest := w.extractAPIRequest(c)
if len(apiRequest) > 0 {
_ = w.streamWriter.WriteAPIRequest(apiRequest)
}
apiResponse := w.extractAPIResponse(c)
if len(apiResponse) > 0 {
_ = w.streamWriter.WriteAPIResponse(apiResponse)
}
if err := w.streamWriter.Close(); err != nil {
w.streamWriter = nil
return err

View File

@@ -100,6 +100,16 @@ func (m *AmpModule) Name() string {
return "amp-routing"
}
// forceModelMappings returns whether model mappings should take precedence over local API keys
func (m *AmpModule) forceModelMappings() bool {
m.configMu.RLock()
defer m.configMu.RUnlock()
if m.lastConfig == nil {
return false
}
return m.lastConfig.ForceModelMappings
}
// Register sets up Amp routes if configured.
// This implements the RouteModuleV2 interface with Context.
// Routes are registered only once via sync.Once for idempotent behavior.

View File

@@ -28,6 +28,9 @@ const (
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
)
// MappedModelContextKey is the Gin context key for passing mapped model names.
const MappedModelContextKey = "mapped_model"
// logAmpRouting logs the routing decision for an Amp request with structured fields
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
fields := log.Fields{
@@ -74,23 +77,29 @@ func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provid
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
// when the model's provider is not available in CLIProxyAPI
type FallbackHandler struct {
getProxy func() *httputil.ReverseProxy
modelMapper ModelMapper
getProxy func() *httputil.ReverseProxy
modelMapper ModelMapper
forceModelMappings func() bool
}
// NewFallbackHandler creates a new fallback handler wrapper
// The getProxy function allows lazy evaluation of the proxy (useful when proxy is created after routes)
func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler {
return &FallbackHandler{
getProxy: getProxy,
getProxy: getProxy,
forceModelMappings: func() bool { return false },
}
}
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler {
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper, forceModelMappings func() bool) *FallbackHandler {
if forceModelMappings == nil {
forceModelMappings = func() bool { return false }
}
return &FallbackHandler{
getProxy: getProxy,
modelMapper: mapper,
getProxy: getProxy,
modelMapper: mapper,
forceModelMappings: forceModelMappings,
}
}
@@ -127,32 +136,65 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
// Normalize model (handles Gemini thinking suffixes)
normalizedModel, _ := util.NormalizeGeminiThinkingModel(modelName)
// Check if we have providers for this model
providers := util.GetProviderName(normalizedModel)
// Track resolved model for logging (may change if mapping is applied)
resolvedModel := normalizedModel
usedMapping := false
var providers []string
if len(providers) == 0 {
// No providers configured - check if we have a model mapping
// Check if model mappings should be forced ahead of local API keys
forceMappings := fh.forceModelMappings != nil && fh.forceModelMappings()
if forceMappings {
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
// This allows users to route Amp requests to their preferred OAuth providers
if fh.modelMapper != nil {
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
// Mapping found - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
resolvedModel = mappedModel
usedMapping = true
// Get providers for the mapped model
providers = util.GetProviderName(mappedModel)
// Continue to handler with remapped model
goto handleRequest
// Mapping found - check if we have a provider for the mapped model
mappedProviders := util.GetProviderName(mappedModel)
if len(mappedProviders) > 0 {
// Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge)
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true
providers = mappedProviders
}
}
}
// No mapping found - check if we have a proxy for fallback
// If no mapping applied, check for local providers
if !usedMapping {
providers = util.GetProviderName(normalizedModel)
}
} else {
// DEFAULT MODE: Check local providers first, then mappings as fallback
providers = util.GetProviderName(normalizedModel)
if len(providers) == 0 {
// No providers configured - check if we have a model mapping
if fh.modelMapper != nil {
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
// Mapping found - check if we have a provider for the mapped model
mappedProviders := util.GetProviderName(mappedModel)
if len(mappedProviders) > 0 {
// Mapping found and provider available - rewrite the model in request body
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
// Store mapped model in context for handlers that check it (like gemini bridge)
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true
providers = mappedProviders
}
}
}
}
}
// If no providers available, fallback to ampcode.com
if len(providers) == 0 {
proxy := fh.getProxy()
if proxy != nil {
// Log: Forwarding to ampcode.com (uses Amp credits)
@@ -170,8 +212,6 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
}
handleRequest:
// Log the routing decision
providerName := ""
if len(providers) > 0 {

View File

@@ -4,7 +4,6 @@ import (
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
)
// createGeminiBridgeHandler creates a handler that bridges AMP CLI's non-standard Gemini paths
@@ -15,16 +14,31 @@ import (
//
// This extracts the model+method from the AMP path and sets it as the :action parameter
// so the standard Gemini handler can process it.
func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.HandlerFunc {
//
// The handler parameter should be a Gemini-compatible handler that expects the :action param.
func createGeminiBridgeHandler(handler gin.HandlerFunc) gin.HandlerFunc {
return func(c *gin.Context) {
// Get the full path from the catch-all parameter
path := c.Param("path")
// Extract model:method from AMP CLI path format
// Example: /publishers/google/models/gemini-3-pro-preview:streamGenerateContent
if idx := strings.Index(path, "/models/"); idx >= 0 {
// Extract everything after "/models/"
actionPart := path[idx+8:] // Skip "/models/"
const modelsPrefix = "/models/"
if idx := strings.Index(path, modelsPrefix); idx >= 0 {
// Extract everything after modelsPrefix
actionPart := path[idx+len(modelsPrefix):]
// Check if model was mapped by FallbackHandler
if mappedModel, exists := c.Get(MappedModelContextKey); exists {
if strModel, ok := mappedModel.(string); ok && strModel != "" {
// Replace the model part in the action
// actionPart is like "model-name:method"
if colonIdx := strings.Index(actionPart, ":"); colonIdx > 0 {
method := actionPart[colonIdx:] // ":method"
actionPart = strModel + method
}
}
}
// Set this as the :action parameter that the Gemini handler expects
c.Params = append(c.Params, gin.Param{
@@ -32,8 +46,8 @@ func createGeminiBridgeHandler(geminiHandler *gemini.GeminiAPIHandler) gin.Handl
Value: actionPart,
})
// Call the standard Gemini handler
geminiHandler.GeminiHandler(c)
// Call the handler
handler(c)
return
}

View File

@@ -0,0 +1,93 @@
package amp
import (
"net/http"
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestCreateGeminiBridgeHandler_ActionParameterExtraction(t *testing.T) {
gin.SetMode(gin.TestMode)
tests := []struct {
name string
path string
mappedModel string // empty string means no mapping
expectedAction string
}{
{
name: "no_mapping_uses_url_model",
path: "/publishers/google/models/gemini-pro:generateContent",
mappedModel: "",
expectedAction: "gemini-pro:generateContent",
},
{
name: "mapped_model_replaces_url_model",
path: "/publishers/google/models/gemini-exp:generateContent",
mappedModel: "gemini-2.0-flash",
expectedAction: "gemini-2.0-flash:generateContent",
},
{
name: "mapping_preserves_method",
path: "/publishers/google/models/gemini-2.5-preview:streamGenerateContent",
mappedModel: "gemini-flash",
expectedAction: "gemini-flash:streamGenerateContent",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
var capturedAction string
mockGeminiHandler := func(c *gin.Context) {
capturedAction = c.Param("action")
c.JSON(http.StatusOK, gin.H{"captured": capturedAction})
}
// Use the actual createGeminiBridgeHandler function
bridgeHandler := createGeminiBridgeHandler(mockGeminiHandler)
r := gin.New()
if tt.mappedModel != "" {
r.Use(func(c *gin.Context) {
c.Set(MappedModelContextKey, tt.mappedModel)
c.Next()
})
}
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1"+tt.path, nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected status 200, got %d", w.Code)
}
if capturedAction != tt.expectedAction {
t.Errorf("Expected action '%s', got '%s'", tt.expectedAction, capturedAction)
}
})
}
}
func TestCreateGeminiBridgeHandler_InvalidPath(t *testing.T) {
gin.SetMode(gin.TestMode)
mockHandler := func(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"ok": true})
}
bridgeHandler := createGeminiBridgeHandler(mockHandler)
r := gin.New()
r.POST("/api/provider/google/v1beta1/*path", bridgeHandler)
req := httptest.NewRequest(http.MethodPost, "/api/provider/google/v1beta1/invalid/path", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Errorf("Expected status 400 for invalid path, got %d", w.Code)
}
}

View File

@@ -9,7 +9,6 @@ import (
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
@@ -169,30 +168,22 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
// We bridge these to our standard Gemini handler to enable local OAuth.
// If no local OAuth is available, falls back to ampcode.com proxy.
geminiHandlers := gemini.NewGeminiAPIHandler(baseHandler)
geminiBridge := createGeminiBridgeHandler(geminiHandlers)
geminiV1Beta1Fallback := NewFallbackHandler(func() *httputil.ReverseProxy {
geminiBridge := createGeminiBridgeHandler(geminiHandlers.GeminiHandler)
geminiV1Beta1Fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy()
})
}, m.modelMapper, m.forceModelMappings)
geminiV1Beta1Handler := geminiV1Beta1Fallback.WrapHandler(geminiBridge)
// Route POST model calls through Gemini bridge when a local provider exists, otherwise proxy.
// Route POST model calls through Gemini bridge with FallbackHandler.
// FallbackHandler checks provider -> mapping -> proxy fallback automatically.
// All other methods (e.g., GET model listing) always proxy to upstream to preserve Amp CLI behavior.
ampAPI.Any("/provider/google/v1beta1/*path", func(c *gin.Context) {
if c.Request.Method == "POST" {
// Attempt to extract the model name from the AMP-style path
if path := c.Param("path"); strings.Contains(path, "/models/") {
modelPart := path[strings.Index(path, "/models/")+len("/models/"):]
if colonIdx := strings.Index(modelPart, ":"); colonIdx > 0 {
modelPart = modelPart[:colonIdx]
}
if modelPart != "" {
normalized, _ := util.NormalizeGeminiThinkingModel(modelPart)
// Only handle locally when we have a provider; otherwise fall back to proxy
if providers := util.GetProviderName(normalized); len(providers) > 0 {
geminiV1Beta1Handler(c)
return
}
}
// POST with /models/ path -> use Gemini bridge with fallback handler
// FallbackHandler will check provider/mapping and proxy if needed
geminiV1Beta1Handler(c)
return
}
}
// Non-POST or no local provider available -> proxy upstream
@@ -218,7 +209,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
// Also includes model mapping support for routing unavailable models to alternatives
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
return m.getProxy()
}, m.modelMapper)
}, m.modelMapper, m.forceModelMappings)
// Provider-specific routes under /api/provider/:provider
ampProviders := engine.Group("/api/provider")

View File

@@ -300,7 +300,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// Create HTTP server
s.server = &http.Server{
Addr: fmt.Sprintf(":%d", cfg.Port),
Addr: fmt.Sprintf("%s:%d", cfg.Host, cfg.Port),
Handler: engine,
}
@@ -520,6 +520,26 @@ func (s *Server) registerManagementRoutes() {
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
mgmt.GET("/ampcode", s.mgmt.GetAmpCode)
mgmt.GET("/ampcode/upstream-url", s.mgmt.GetAmpUpstreamURL)
mgmt.PUT("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
mgmt.PATCH("/ampcode/upstream-url", s.mgmt.PutAmpUpstreamURL)
mgmt.DELETE("/ampcode/upstream-url", s.mgmt.DeleteAmpUpstreamURL)
mgmt.GET("/ampcode/upstream-api-key", s.mgmt.GetAmpUpstreamAPIKey)
mgmt.PUT("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
mgmt.PATCH("/ampcode/upstream-api-key", s.mgmt.PutAmpUpstreamAPIKey)
mgmt.DELETE("/ampcode/upstream-api-key", s.mgmt.DeleteAmpUpstreamAPIKey)
mgmt.GET("/ampcode/restrict-management-to-localhost", s.mgmt.GetAmpRestrictManagementToLocalhost)
mgmt.PUT("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
mgmt.PATCH("/ampcode/restrict-management-to-localhost", s.mgmt.PutAmpRestrictManagementToLocalhost)
mgmt.GET("/ampcode/model-mappings", s.mgmt.GetAmpModelMappings)
mgmt.PUT("/ampcode/model-mappings", s.mgmt.PutAmpModelMappings)
mgmt.PATCH("/ampcode/model-mappings", s.mgmt.PatchAmpModelMappings)
mgmt.DELETE("/ampcode/model-mappings", s.mgmt.DeleteAmpModelMappings)
mgmt.GET("/ampcode/force-model-mappings", s.mgmt.GetAmpForceModelMappings)
mgmt.PUT("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.PATCH("/ampcode/force-model-mappings", s.mgmt.PutAmpForceModelMappings)
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)

View File

@@ -76,7 +76,8 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
auth := &proxy.Auth{User: username, Password: password}
dialer, errSOCKS5 := proxy.SOCKS5("tcp", proxyURL.Host, auth, proxy.Direct)
if errSOCKS5 != nil {
log.Fatalf("create SOCKS5 dialer failed: %v", errSOCKS5)
log.Errorf("create SOCKS5 dialer failed: %v", errSOCKS5)
return nil, fmt.Errorf("create SOCKS5 dialer failed: %w", errSOCKS5)
}
transport = &http.Transport{
DialContext: func(ctx context.Context, network, addr string) (net.Conn, error) {
@@ -238,7 +239,11 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Start the server in a goroutine.
go func() {
if err := server.ListenAndServe(); !errors.Is(err, http.ErrServerClosed) {
log.Fatalf("ListenAndServe(): %v", err)
log.Errorf("ListenAndServe(): %v", err)
select {
case errChan <- err:
default:
}
}
}()

View File

@@ -309,17 +309,23 @@ func (ia *IFlowAuth) AuthenticateWithCookie(ctx context.Context, cookie string)
return nil, fmt.Errorf("iflow cookie authentication: cookie is empty")
}
// First, get initial API key information using GET request
// First, get initial API key information using GET request to obtain the name
keyInfo, err := ia.fetchAPIKeyInfo(ctx, cookie)
if err != nil {
return nil, fmt.Errorf("iflow cookie authentication: fetch initial API key info failed: %w", err)
}
// Convert to token data format
// Refresh the API key using POST request
refreshedKeyInfo, err := ia.RefreshAPIKey(ctx, cookie, keyInfo.Name)
if err != nil {
return nil, fmt.Errorf("iflow cookie authentication: refresh API key failed: %w", err)
}
// Convert to token data format using refreshed key
data := &IFlowTokenData{
APIKey: keyInfo.APIKey,
Expire: keyInfo.ExpireTime,
Email: keyInfo.Name,
APIKey: refreshedKeyInfo.APIKey,
Expire: refreshedKeyInfo.ExpireTime,
Email: refreshedKeyInfo.Name,
Cookie: cookie,
}

View File

@@ -65,20 +65,20 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
authenticator := sdkAuth.NewGeminiAuthenticator()
record, errLogin := authenticator.Login(ctx, cfg, loginOpts)
if errLogin != nil {
log.Fatalf("Gemini authentication failed: %v", errLogin)
log.Errorf("Gemini authentication failed: %v", errLogin)
return
}
storage, okStorage := record.Storage.(*gemini.GeminiTokenStorage)
if !okStorage || storage == nil {
log.Fatal("Gemini authentication failed: unsupported token storage")
log.Error("Gemini authentication failed: unsupported token storage")
return
}
geminiAuth := gemini.NewGeminiAuth()
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser)
if errClient != nil {
log.Fatalf("Gemini authentication failed: %v", errClient)
log.Errorf("Gemini authentication failed: %v", errClient)
return
}
@@ -86,7 +86,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
projects, errProjects := fetchGCPProjects(ctx, httpClient)
if errProjects != nil {
log.Fatalf("Failed to get project list: %v", errProjects)
log.Errorf("Failed to get project list: %v", errProjects)
return
}
@@ -98,11 +98,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn)
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
if errSelection != nil {
log.Fatalf("Invalid project selection: %v", errSelection)
log.Errorf("Invalid project selection: %v", errSelection)
return
}
if len(projectSelections) == 0 {
log.Fatal("No project selected; aborting login.")
log.Error("No project selected; aborting login.")
return
}
@@ -116,7 +116,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
showProjectSelectionHelp(storage.Email, projects)
return
}
log.Fatalf("Failed to complete user setup: %v", errSetup)
log.Errorf("Failed to complete user setup: %v", errSetup)
return
}
finalID := strings.TrimSpace(storage.ProjectID)
@@ -133,11 +133,11 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
for _, pid := range activatedProjects {
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, pid)
if errCheck != nil {
log.Fatalf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck)
log.Errorf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck)
return
}
if !isChecked {
log.Fatalf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid)
log.Errorf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid)
return
}
}
@@ -153,7 +153,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
savedPath, errSave := store.Save(ctx, record)
if errSave != nil {
log.Fatalf("Failed to save token to file: %v", errSave)
log.Errorf("Failed to save token to file: %v", errSave)
return
}
@@ -555,6 +555,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
continue
}
}
_ = resp.Body.Close()
return false, fmt.Errorf("project activation required: %s", errMessage)
}
return true, nil

View File

@@ -45,12 +45,13 @@ func StartService(cfg *config.Config, configPath string, localPassword string) {
service, err := builder.Build()
if err != nil {
log.Fatalf("failed to build proxy service: %v", err)
log.Errorf("failed to build proxy service: %v", err)
return
}
err = service.Run(runCtx)
if err != nil && !errors.Is(err, context.Canceled) {
log.Fatalf("proxy service exited with error: %v", err)
log.Errorf("proxy service exited with error: %v", err)
}
}

View File

@@ -29,30 +29,30 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
}
rawPath := strings.TrimSpace(keyPath)
if rawPath == "" {
log.Fatalf("vertex-import: missing service account key path")
log.Errorf("vertex-import: missing service account key path")
return
}
data, errRead := os.ReadFile(rawPath)
if errRead != nil {
log.Fatalf("vertex-import: read file failed: %v", errRead)
log.Errorf("vertex-import: read file failed: %v", errRead)
return
}
var sa map[string]any
if errUnmarshal := json.Unmarshal(data, &sa); errUnmarshal != nil {
log.Fatalf("vertex-import: invalid service account json: %v", errUnmarshal)
log.Errorf("vertex-import: invalid service account json: %v", errUnmarshal)
return
}
// Validate and normalize private_key before saving
normalizedSA, errFix := vertex.NormalizeServiceAccountMap(sa)
if errFix != nil {
log.Fatalf("vertex-import: %v", errFix)
log.Errorf("vertex-import: %v", errFix)
return
}
sa = normalizedSA
email, _ := sa["client_email"].(string)
projectID, _ := sa["project_id"].(string)
if strings.TrimSpace(projectID) == "" {
log.Fatalf("vertex-import: project_id missing in service account json")
log.Errorf("vertex-import: project_id missing in service account json")
return
}
if strings.TrimSpace(email) == "" {
@@ -92,7 +92,7 @@ func DoVertexImport(cfg *config.Config, keyPath string) {
}
path, errSave := store.Save(context.Background(), record)
if errSave != nil {
log.Fatalf("vertex-import: save credential failed: %v", errSave)
log.Errorf("vertex-import: save credential failed: %v", errSave)
return
}
fmt.Printf("Vertex credentials imported: %s\n", path)

View File

@@ -20,6 +20,9 @@ import (
// Config represents the application's configuration, loaded from a YAML file.
type Config struct {
config.SDKConfig `yaml:",inline"`
// Host is the network host/interface on which the API server will bind.
// Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access.
Host string `yaml:"host" json:"-"`
// Port is the network port on which the API server will listen.
Port int `yaml:"port" json:"-"`
@@ -143,6 +146,10 @@ type AmpCode struct {
// When Amp requests a model that isn't available locally, these mappings
// allow routing to an alternative model that IS available.
ModelMappings []AmpModelMapping `yaml:"model-mappings" json:"model-mappings"`
// ForceModelMappings when true, model mappings take precedence over local API keys.
// When false (default), local API keys are used first if available.
ForceModelMappings bool `yaml:"force-model-mappings" json:"force-model-mappings"`
}
// PayloadConfig defines default and override parameter rules applied to provider payloads.
@@ -316,6 +323,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Unmarshal the YAML data into the Config struct.
var cfg Config
// Set defaults before unmarshal so that absent keys keep defaults.
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
cfg.LoggingToFile = false
cfg.UsageStatisticsEnabled = false
cfg.DisableCooling = false

View File

@@ -56,6 +56,8 @@ type Content struct {
// Part represents a distinct piece of content within a message.
// A part can be text, inline data (like an image), a function call, or a function response.
type Part struct {
Thought bool `json:"thought,omitempty"`
// Text contains plain text content.
Text string `json:"text,omitempty"`

View File

@@ -20,6 +20,7 @@ import (
"github.com/klauspost/compress/zstd"
log "github.com/sirupsen/logrus"
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
@@ -83,6 +84,26 @@ type StreamingLogWriter interface {
// - error: An error if writing fails, nil otherwise
WriteStatus(status int, headers map[string][]string) error
// WriteAPIRequest writes the upstream API request details to the log.
// This should be called before WriteStatus to maintain proper log ordering.
//
// Parameters:
// - apiRequest: The API request data (typically includes URL, headers, body sent upstream)
//
// Returns:
// - error: An error if writing fails, nil otherwise
WriteAPIRequest(apiRequest []byte) error
// WriteAPIResponse writes the upstream API response details to the log.
// This should be called after the streaming response is complete.
//
// Parameters:
// - apiResponse: The API response data
//
// Returns:
// - error: An error if writing fails, nil otherwise
WriteAPIResponse(apiResponse []byte) error
// Close finalizes the log file and cleans up resources.
//
// Returns:
@@ -247,10 +268,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
// Create streaming writer
writer := &FileStreamingLogWriter{
file: file,
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
closeChan: make(chan struct{}),
errorChan: make(chan error, 1),
file: file,
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
closeChan: make(chan struct{}),
errorChan: make(chan error, 1),
bufferedChunks: &bytes.Buffer{},
}
// Start async writer goroutine
@@ -603,6 +625,7 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
var content strings.Builder
content.WriteString("=== REQUEST INFO ===\n")
content.WriteString(fmt.Sprintf("Version: %s\n", buildinfo.Version))
content.WriteString(fmt.Sprintf("URL: %s\n", url))
content.WriteString(fmt.Sprintf("Method: %s\n", method))
content.WriteString(fmt.Sprintf("Timestamp: %s\n", time.Now().Format(time.RFC3339Nano)))
@@ -626,11 +649,12 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
// It handles asynchronous writing of streaming response chunks to a file.
// All data is buffered and written in the correct order when Close is called.
type FileStreamingLogWriter struct {
// file is the file where log data is written.
file *os.File
// chunkChan is a channel for receiving response chunks to write.
// chunkChan is a channel for receiving response chunks to buffer.
chunkChan chan []byte
// closeChan is a channel for signaling when the writer is closed.
@@ -639,8 +663,23 @@ type FileStreamingLogWriter struct {
// errorChan is a channel for reporting errors during writing.
errorChan chan error
// statusWritten indicates whether the response status has been written.
// bufferedChunks stores the response chunks in order.
bufferedChunks *bytes.Buffer
// responseStatus stores the HTTP status code.
responseStatus int
// statusWritten indicates whether a non-zero status was recorded.
statusWritten bool
// responseHeaders stores the response headers.
responseHeaders map[string][]string
// apiRequest stores the upstream API request data.
apiRequest []byte
// apiResponse stores the upstream API response data.
apiResponse []byte
}
// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
@@ -664,39 +703,65 @@ func (w *FileStreamingLogWriter) WriteChunkAsync(chunk []byte) {
}
}
// WriteStatus writes the response status and headers to the log.
// WriteStatus buffers the response status and headers for later writing.
//
// Parameters:
// - status: The response status code
// - headers: The response headers
//
// Returns:
// - error: An error if writing fails, nil otherwise
// - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteStatus(status int, headers map[string][]string) error {
if w.file == nil || w.statusWritten {
if status == 0 {
return nil
}
var content strings.Builder
content.WriteString("========================================\n")
content.WriteString("=== RESPONSE ===\n")
content.WriteString(fmt.Sprintf("Status: %d\n", status))
for key, values := range headers {
for _, value := range values {
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
w.responseStatus = status
if headers != nil {
w.responseHeaders = make(map[string][]string, len(headers))
for key, values := range headers {
headerValues := make([]string, len(values))
copy(headerValues, values)
w.responseHeaders[key] = headerValues
}
}
content.WriteString("\n")
w.statusWritten = true
return nil
}
_, err := w.file.WriteString(content.String())
if err == nil {
w.statusWritten = true
// WriteAPIRequest buffers the upstream API request details for later writing.
//
// Parameters:
// - apiRequest: The API request data (typically includes URL, headers, body sent upstream)
//
// Returns:
// - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteAPIRequest(apiRequest []byte) error {
if len(apiRequest) == 0 {
return nil
}
return err
w.apiRequest = bytes.Clone(apiRequest)
return nil
}
// WriteAPIResponse buffers the upstream API response details for later writing.
//
// Parameters:
// - apiResponse: The API response data
//
// Returns:
// - error: Always returns nil (buffering cannot fail)
func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
if len(apiResponse) == 0 {
return nil
}
w.apiResponse = bytes.Clone(apiResponse)
return nil
}
// Close finalizes the log file and cleans up resources.
// It writes all buffered data to the file in the correct order:
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
//
// Returns:
// - error: An error if closing fails, nil otherwise
@@ -705,27 +770,84 @@ func (w *FileStreamingLogWriter) Close() error {
close(w.chunkChan)
}
// Wait for async writer to finish
// Wait for async writer to finish buffering chunks
if w.closeChan != nil {
<-w.closeChan
w.chunkChan = nil
}
if w.file != nil {
return w.file.Close()
if w.file == nil {
return nil
}
return nil
// Write all content in the correct order
var content strings.Builder
// 1. Write API REQUEST section
if len(w.apiRequest) > 0 {
if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) {
content.Write(w.apiRequest)
if !bytes.HasSuffix(w.apiRequest, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== API REQUEST ===\n")
content.Write(w.apiRequest)
content.WriteString("\n")
}
content.WriteString("\n")
}
// 2. Write API RESPONSE section
if len(w.apiResponse) > 0 {
if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) {
content.Write(w.apiResponse)
if !bytes.HasSuffix(w.apiResponse, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== API RESPONSE ===\n")
content.Write(w.apiResponse)
content.WriteString("\n")
}
content.WriteString("\n")
}
// 3. Write RESPONSE section (status, headers, buffered chunks)
content.WriteString("=== RESPONSE ===\n")
if w.statusWritten {
content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus))
}
for key, values := range w.responseHeaders {
for _, value := range values {
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
}
}
content.WriteString("\n")
// Write buffered response body chunks
if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 {
content.Write(w.bufferedChunks.Bytes())
}
// Write the complete content to file
if _, err := w.file.WriteString(content.String()); err != nil {
_ = w.file.Close()
return err
}
return w.file.Close()
}
// asyncWriter runs in a goroutine to handle async chunk writing.
// It continuously reads chunks from the channel and writes them to the file.
// asyncWriter runs in a goroutine to buffer chunks from the channel.
// It continuously reads chunks from the channel and buffers them for later writing.
func (w *FileStreamingLogWriter) asyncWriter() {
defer close(w.closeChan)
for chunk := range w.chunkChan {
if w.file != nil {
_, _ = w.file.Write(chunk)
if w.bufferedChunks != nil {
w.bufferedChunks.Write(chunk)
}
}
}
@@ -752,6 +874,28 @@ func (w *NoOpStreamingLogWriter) WriteStatus(_ int, _ map[string][]string) error
return nil
}
// WriteAPIRequest is a no-op implementation that does nothing and always returns nil.
//
// Parameters:
// - apiRequest: The API request data (ignored)
//
// Returns:
// - error: Always returns nil
func (w *NoOpStreamingLogWriter) WriteAPIRequest(_ []byte) error {
return nil
}
// WriteAPIResponse is a no-op implementation that does nothing and always returns nil.
//
// Parameters:
// - apiResponse: The API response data (ignored)
//
// Returns:
// - error: Always returns nil
func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
return nil
}
// Close is a no-op implementation that does nothing and always returns nil.
//
// Returns:

View File

@@ -944,7 +944,6 @@ func GetQwenModels() []*ModelInfo {
}
// GetIFlowModels returns supported models for iFlow OAuth accounts.
func GetIFlowModels() []*ModelInfo {
entries := []struct {
ID string
@@ -986,3 +985,22 @@ func GetIFlowModels() []*ModelInfo {
}
return models
}
// AntigravityModelConfig captures static antigravity model overrides, including
// Thinking budget limits and provider max completion tokens.
type AntigravityModelConfig struct {
Thinking *ThinkingSupport
MaxCompletionTokens int
}
// GetAntigravityModelConfig returns static configuration for antigravity models.
// Keys use the ALIASED model names (after modelName2Alias conversion) for direct lookup.
func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
return map[string]*AntigravityModelConfig{
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}},
"gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
}
}

View File

@@ -310,6 +310,10 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
payload = applyThinkingMetadata(payload, req.Metadata, req.Model)
payload = util.ConvertThinkingLevelToBudget(payload)
if budget := gjson.GetBytes(payload, "generationConfig.thinkingConfig.thinkingBudget"); budget.Exists() {
normalized := util.NormalizeThinkingBudget(req.Model, int(budget.Int()))
payload, _ = sjson.SetBytes(payload, "generationConfig.thinkingConfig.thinkingBudget", normalized)
}
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiImageAspectRatio(req.Model, payload)
payload = applyPayloadConfig(e.cfg, req.Model, payload)

View File

@@ -77,6 +77,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = normalizeAntigravityThinking(req.Model, translated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -170,6 +171,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
translated = normalizeAntigravityThinking(req.Model, translated)
baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -366,28 +368,29 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
}
now := time.Now().Unix()
modelConfig := registry.GetAntigravityModelConfig()
models := make([]*registry.ModelInfo, 0, len(result.Map()))
for id := range result.Map() {
id = modelName2Alias(id)
if id != "" {
for originalName := range result.Map() {
aliasName := modelName2Alias(originalName)
if aliasName != "" {
modelInfo := &registry.ModelInfo{
ID: id,
Name: id,
Description: id,
DisplayName: id,
Version: id,
ID: aliasName,
Name: aliasName,
Description: aliasName,
DisplayName: aliasName,
Version: aliasName,
Object: "model",
Created: now,
OwnedBy: antigravityAuthType,
Type: antigravityAuthType,
}
// Add Thinking support for thinking models
if strings.HasSuffix(id, "-thinking") || strings.Contains(id, "-thinking-") {
modelInfo.Thinking = &registry.ThinkingSupport{
Min: 1024,
Max: 100000,
ZeroAllowed: false,
DynamicAllowed: true,
// Look up Thinking support from static config using alias name
if cfg, ok := modelConfig[aliasName]; ok {
if cfg.Thinking != nil {
modelInfo.Thinking = cfg.Thinking
}
if cfg.MaxCompletionTokens > 0 {
modelInfo.MaxCompletionTokens = cfg.MaxCompletionTokens
}
}
models = append(models, modelInfo)
@@ -533,6 +536,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
strJSON = util.DeleteKey(strJSON, "minLength")
strJSON = util.DeleteKey(strJSON, "maxLength")
strJSON = util.DeleteKey(strJSON, "exclusiveMinimum")
strJSON = util.DeleteKey(strJSON, "exclusiveMaximum")
strJSON = util.DeleteKey(strJSON, "$ref")
strJSON = util.DeleteKey(strJSON, "$defs")
paths = make([]string, 0)
util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths)
@@ -724,7 +730,7 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
template, _ = sjson.Delete(template, "request.safetySettings")
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
if !strings.HasPrefix(modelName, "gemini-3-") {
if thinkingLevel := gjson.Get(template, "request.generationConfig.thinkingConfig.thinkingLevel"); thinkingLevel.Exists() {
template, _ = sjson.Delete(template, "request.generationConfig.thinkingConfig.thinkingLevel")
@@ -732,7 +738,7 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
}
}
if strings.HasPrefix(modelName, "claude-sonnet-") {
if strings.Contains(modelName, "claude") {
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
if funcDecl.Get("parametersJsonSchema").Exists() {
@@ -744,6 +750,8 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
})
return true
})
} else {
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
}
return []byte(template)
@@ -806,3 +814,53 @@ func alias2ModelName(modelName string) string {
return modelName
}
}
// normalizeAntigravityThinking clamps or removes thinking config based on model support.
// For Claude models, it additionally ensures thinking budget < max_tokens.
func normalizeAntigravityThinking(model string, payload []byte) []byte {
payload = util.StripThinkingConfigIfUnsupported(model, payload)
if !util.ModelSupportsThinking(model) {
return payload
}
budget := gjson.GetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget")
if !budget.Exists() {
return payload
}
raw := int(budget.Int())
normalized := util.NormalizeThinkingBudget(model, raw)
isClaude := strings.Contains(strings.ToLower(model), "claude")
if isClaude {
effectiveMax, setDefaultMax := antigravityEffectiveMaxTokens(model, payload)
if effectiveMax > 0 && normalized >= effectiveMax {
normalized = effectiveMax - 1
if normalized < 1 {
normalized = 1
}
}
if setDefaultMax {
if res, errSet := sjson.SetBytes(payload, "request.generationConfig.maxOutputTokens", effectiveMax); errSet == nil {
payload = res
}
}
}
updated, err := sjson.SetBytes(payload, "request.generationConfig.thinkingConfig.thinkingBudget", normalized)
if err != nil {
return payload
}
return updated
}
// antigravityEffectiveMaxTokens returns the max tokens to cap thinking:
// prefer request-provided maxOutputTokens; otherwise fall back to model default.
// The boolean indicates whether the value came from the model default (and thus should be written back).
func antigravityEffectiveMaxTokens(model string, payload []byte) (max int, fromModel bool) {
if maxTok := gjson.GetBytes(payload, "request.generationConfig.maxOutputTokens"); maxTok.Exists() && maxTok.Int() > 0 {
return int(maxTok.Int()), false
}
if modelInfo := registry.GetGlobalRegistry().GetModelInfo(model); modelInfo != nil && modelInfo.MaxCompletionTokens > 0 {
return modelInfo.MaxCompletionTokens, true
}
return 0, false
}

View File

@@ -157,7 +157,7 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
if ginCtx == nil {
return
}
attempts, attempt := ensureAttempt(ginCtx)
_, attempt := ensureAttempt(ginCtx)
ensureResponseIntro(attempt)
if !attempt.headersWritten {
@@ -175,8 +175,6 @@ func appendAPIResponseChunk(ctx context.Context, cfg *config.Config, chunk []byt
}
attempt.response.WriteString(string(data))
attempt.bodyHasContent = true
updateAggregatedResponse(ginCtx, attempts)
}
func ginContextFrom(ctx context.Context) *gin.Context {

View File

@@ -83,7 +83,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
for j := 0; j < len(contentResults); j++ {
contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
prompt := contentResult.Get("thinking").String()
signatureResult := contentResult.Get("signature")
signature := geminiCLIClaudeThoughtSignature
if signatureResult.Exists() {
signature = signatureResult.String()
}
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt, Thought: true, ThoughtSignature: signature})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
prompt := contentResult.Get("text").String()
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
@@ -92,10 +100,16 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
functionID := contentResult.Get("id").String()
var args map[string]any
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
ThoughtSignature: geminiCLIClaudeThoughtSignature,
})
if strings.Contains(modelName, "claude") {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
})
} else {
clientContent.Parts = append(clientContent.Parts, client.Part{
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
ThoughtSignature: geminiCLIClaudeThoughtSignature,
})
}
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String()
@@ -181,6 +195,9 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
if v := gjson.GetBytes(rawJSON, "top_k"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.topK", v.Num)
}
if v := gjson.GetBytes(rawJSON, "max_tokens"); v.Exists() && v.Type == gjson.Number {
out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num)
}
outBytes := []byte(out)
outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings")

View File

@@ -111,8 +111,11 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
if partTextResult.Exists() {
// Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() {
// Continue existing thinking block if already in thinking state
if params.ResponseType == 2 {
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -163,15 +166,16 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
output = output + "\n\n\n"
params.ResponseIndex++
}
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 1 // Set state to content
if partTextResult.String() != "" {
// Start a new text content block
output = output + "event: content_block_start\n"
output = output + fmt.Sprintf(`data: {"type":"content_block_start","index":%d,"content_block":{"type":"text","text":""}}`, params.ResponseIndex)
output = output + "\n\n\n"
output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"text_delta","text":""}}`, params.ResponseIndex), "delta.text", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 1 // Set state to content
}
}
}
}

View File

@@ -88,6 +88,20 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
}
}
// Claude/Anthropic API format: thinking.type == "enabled" with budget_tokens
// This allows Claude Code and other Claude API clients to pass thinking configuration
if !gjson.GetBytes(out, "request.generationConfig.thinkingConfig").Exists() && util.ModelSupportsThinking(modelName) {
if t := gjson.GetBytes(rawJSON, "thinking"); t.Exists() && t.IsObject() {
if t.Get("type").String() == "enabled" {
if b := t.Get("budget_tokens"); b.Exists() && b.Type == gjson.Number {
budget := util.NormalizeThinkingBudget(modelName, int(b.Int()))
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.thinkingBudget", budget)
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
}
}
}
// For gemini-3-pro-preview, always send default thinkingConfig when none specified.
// This matches the official Gemini CLI behavior which always sends:
// { thinkingBudget: -1, includeThoughts: true }
@@ -97,7 +111,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
out, _ = sjson.SetBytes(out, "request.generationConfig.thinkingConfig.include_thoughts", true)
}
// Temperature/top_p/top_k
// Temperature/top_p/top_k/max_tokens
if tr := gjson.GetBytes(rawJSON, "temperature"); tr.Exists() && tr.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "request.generationConfig.temperature", tr.Num)
}
@@ -107,6 +121,9 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if tkr := gjson.GetBytes(rawJSON, "top_k"); tkr.Exists() && tkr.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "request.generationConfig.topK", tkr.Num)
}
if maxTok := gjson.GetBytes(rawJSON, "max_tokens"); maxTok.Exists() && maxTok.Type == gjson.Number {
out, _ = sjson.SetBytes(out, "request.generationConfig.maxOutputTokens", maxTok.Num)
}
// Map OpenAI modalities -> Gemini CLI request.generationConfig.responseModalities
// e.g. "modalities": ["image", "text"] -> ["IMAGE", "TEXT"]
@@ -263,7 +280,7 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
// Append a single tool content combining name + response per function
toolNode := []byte(`{"role":"tool","parts":[]}`)
toolNode := []byte(`{"role":"user","parts":[]}`)
pp := 0
for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok {

View File

@@ -10,6 +10,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
@@ -75,8 +76,8 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
}
// Extract and set usage metadata (token counts).

View File

@@ -10,6 +10,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
@@ -75,8 +76,8 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
}
// Extract and set usage metadata (token counts).

View File

@@ -10,6 +10,7 @@ import (
"context"
"encoding/json"
"fmt"
"strings"
"time"
"github.com/tidwall/gjson"
@@ -78,8 +79,8 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
// Extract and set the finish reason.
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
}
// Extract and set usage metadata (token counts).
@@ -230,8 +231,8 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
}
if finishReasonResult := gjson.GetBytes(rawJSON, "candidates.0.finishReason"); finishReasonResult.Exists() {
template, _ = sjson.Set(template, "choices.0.finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.native_finish_reason", finishReasonResult.String())
template, _ = sjson.Set(template, "choices.0.finish_reason", strings.ToLower(finishReasonResult.String()))
template, _ = sjson.Set(template, "choices.0.native_finish_reason", strings.ToLower(finishReasonResult.String()))
}
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {

View File

@@ -8,6 +8,7 @@ package claude
import (
"bytes"
"encoding/json"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -242,11 +243,12 @@ func convertClaudeContentPart(part gjson.Result) (string, bool) {
switch partType {
case "text":
if !part.Get("text").Exists() {
text := part.Get("text").String()
if strings.TrimSpace(text) == "" {
return "", false
}
textContent := `{"type":"text","text":""}`
textContent, _ = sjson.Set(textContent, "text", part.Get("text").String())
textContent, _ = sjson.Set(textContent, "text", text)
return textContent, true
case "image":

View File

@@ -498,7 +498,7 @@ func (s *Service) Run(ctx context.Context) error {
}()
time.Sleep(100 * time.Millisecond)
fmt.Printf("API server started successfully on: %d\n", s.cfg.Port)
fmt.Printf("API server started successfully on: %s:%d\n", s.cfg.Host, s.cfg.Port)
if s.hooks.OnAfterStart != nil {
s.hooks.OnAfterStart(s)

827
test/amp_management_test.go Normal file
View File

@@ -0,0 +1,827 @@
package test
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"path/filepath"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api/handlers/management"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
func init() {
gin.SetMode(gin.TestMode)
}
// newAmpTestHandler creates a test handler with default ampcode configuration.
func newAmpTestHandler(t *testing.T) (*management.Handler, string) {
t.Helper()
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
cfg := &config.Config{
AmpCode: config.AmpCode{
UpstreamURL: "https://example.com",
UpstreamAPIKey: "test-api-key-12345",
RestrictManagementToLocalhost: true,
ForceModelMappings: false,
ModelMappings: []config.AmpModelMapping{
{From: "gpt-4", To: "gemini-pro"},
},
},
}
if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil {
t.Fatalf("failed to write config file: %v", err)
}
h := management.NewHandler(cfg, configPath, nil)
return h, configPath
}
// setupAmpRouter creates a test router with all ampcode management endpoints.
func setupAmpRouter(h *management.Handler) *gin.Engine {
r := gin.New()
mgmt := r.Group("/v0/management")
{
mgmt.GET("/ampcode", h.GetAmpCode)
mgmt.GET("/ampcode/upstream-url", h.GetAmpUpstreamURL)
mgmt.PUT("/ampcode/upstream-url", h.PutAmpUpstreamURL)
mgmt.DELETE("/ampcode/upstream-url", h.DeleteAmpUpstreamURL)
mgmt.GET("/ampcode/upstream-api-key", h.GetAmpUpstreamAPIKey)
mgmt.PUT("/ampcode/upstream-api-key", h.PutAmpUpstreamAPIKey)
mgmt.DELETE("/ampcode/upstream-api-key", h.DeleteAmpUpstreamAPIKey)
mgmt.GET("/ampcode/restrict-management-to-localhost", h.GetAmpRestrictManagementToLocalhost)
mgmt.PUT("/ampcode/restrict-management-to-localhost", h.PutAmpRestrictManagementToLocalhost)
mgmt.GET("/ampcode/model-mappings", h.GetAmpModelMappings)
mgmt.PUT("/ampcode/model-mappings", h.PutAmpModelMappings)
mgmt.PATCH("/ampcode/model-mappings", h.PatchAmpModelMappings)
mgmt.DELETE("/ampcode/model-mappings", h.DeleteAmpModelMappings)
mgmt.GET("/ampcode/force-model-mappings", h.GetAmpForceModelMappings)
mgmt.PUT("/ampcode/force-model-mappings", h.PutAmpForceModelMappings)
}
return r
}
// TestGetAmpCode verifies GET /v0/management/ampcode returns full ampcode config.
func TestGetAmpCode(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]config.AmpCode
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
ampcode := resp["ampcode"]
if ampcode.UpstreamURL != "https://example.com" {
t.Errorf("expected upstream-url %q, got %q", "https://example.com", ampcode.UpstreamURL)
}
if len(ampcode.ModelMappings) != 1 {
t.Errorf("expected 1 model mapping, got %d", len(ampcode.ModelMappings))
}
}
// TestGetAmpUpstreamURL verifies GET /v0/management/ampcode/upstream-url returns the upstream URL.
func TestGetAmpUpstreamURL(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["upstream-url"] != "https://example.com" {
t.Errorf("expected %q, got %q", "https://example.com", resp["upstream-url"])
}
}
// TestPutAmpUpstreamURL verifies PUT /v0/management/ampcode/upstream-url updates the upstream URL.
func TestPutAmpUpstreamURL(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": "https://new-upstream.com"}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
}
}
// TestDeleteAmpUpstreamURL verifies DELETE /v0/management/ampcode/upstream-url clears the upstream URL.
func TestDeleteAmpUpstreamURL(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestGetAmpUpstreamAPIKey verifies GET /v0/management/ampcode/upstream-api-key returns the API key.
func TestGetAmpUpstreamAPIKey(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]any
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
key := resp["upstream-api-key"].(string)
if key != "test-api-key-12345" {
t.Errorf("expected key %q, got %q", "test-api-key-12345", key)
}
}
// TestPutAmpUpstreamAPIKey verifies PUT /v0/management/ampcode/upstream-api-key updates the API key.
func TestPutAmpUpstreamAPIKey(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": "new-secret-key"}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestDeleteAmpUpstreamAPIKey verifies DELETE /v0/management/ampcode/upstream-api-key clears the API key.
func TestDeleteAmpUpstreamAPIKey(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestGetAmpRestrictManagementToLocalhost verifies GET returns the localhost restriction setting.
func TestGetAmpRestrictManagementToLocalhost(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]bool
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["restrict-management-to-localhost"] != true {
t.Error("expected restrict-management-to-localhost to be true")
}
}
// TestPutAmpRestrictManagementToLocalhost verifies PUT updates the localhost restriction setting.
func TestPutAmpRestrictManagementToLocalhost(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": false}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestGetAmpModelMappings verifies GET /v0/management/ampcode/model-mappings returns all mappings.
func TestGetAmpModelMappings(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
mappings := resp["model-mappings"]
if len(mappings) != 1 {
t.Fatalf("expected 1 mapping, got %d", len(mappings))
}
if mappings[0].From != "gpt-4" || mappings[0].To != "gemini-pro" {
t.Errorf("unexpected mapping: %+v", mappings[0])
}
}
// TestPutAmpModelMappings verifies PUT /v0/management/ampcode/model-mappings replaces all mappings.
func TestPutAmpModelMappings(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": [{"from": "claude-3", "to": "gpt-4o"}, {"from": "gemini", "to": "claude"}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
}
}
// TestPatchAmpModelMappings verifies PATCH updates existing mappings and adds new ones.
func TestPatchAmpModelMappings(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": [{"from": "gpt-4", "to": "updated-model"}, {"from": "new-model", "to": "target"}]}`
req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d: %s", http.StatusOK, w.Code, w.Body.String())
}
}
// TestDeleteAmpModelMappings_Specific verifies DELETE removes specified mappings by "from" field.
func TestDeleteAmpModelMappings_Specific(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": ["gpt-4"]}`
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestDeleteAmpModelMappings_All verifies DELETE with empty body removes all mappings.
func TestDeleteAmpModelMappings_All(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestGetAmpForceModelMappings verifies GET returns the force-model-mappings setting.
func TestGetAmpForceModelMappings(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string]bool
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal response: %v", err)
}
if resp["force-model-mappings"] != false {
t.Error("expected force-model-mappings to be false")
}
}
// TestPutAmpForceModelMappings verifies PUT updates the force-model-mappings setting.
func TestPutAmpForceModelMappings(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": true}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestPutAmpModelMappings_VerifyState verifies PUT replaces mappings and state is persisted.
func TestPutAmpModelMappings_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": [{"from": "model-a", "to": "model-b"}, {"from": "model-c", "to": "model-d"}, {"from": "model-e", "to": "model-f"}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PUT failed: status %d, body: %s", w.Code, w.Body.String())
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
mappings := resp["model-mappings"]
if len(mappings) != 3 {
t.Fatalf("expected 3 mappings, got %d", len(mappings))
}
expected := map[string]string{"model-a": "model-b", "model-c": "model-d", "model-e": "model-f"}
for _, m := range mappings {
if expected[m.From] != m.To {
t.Errorf("mapping %q -> expected %q, got %q", m.From, expected[m.From], m.To)
}
}
}
// TestPatchAmpModelMappings_VerifyState verifies PATCH merges mappings correctly.
func TestPatchAmpModelMappings_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": [{"from": "gpt-4", "to": "updated-target"}, {"from": "new-model", "to": "new-target"}]}`
req := httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PATCH failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
mappings := resp["model-mappings"]
if len(mappings) != 2 {
t.Fatalf("expected 2 mappings (1 updated + 1 new), got %d", len(mappings))
}
found := make(map[string]string)
for _, m := range mappings {
found[m.From] = m.To
}
if found["gpt-4"] != "updated-target" {
t.Errorf("gpt-4 should map to updated-target, got %q", found["gpt-4"])
}
if found["new-model"] != "new-target" {
t.Errorf("new-model should map to new-target, got %q", found["new-model"])
}
}
// TestDeleteAmpModelMappings_VerifyState verifies DELETE removes specific mappings and keeps others.
func TestDeleteAmpModelMappings_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
putBody := `{"value": [{"from": "a", "to": "1"}, {"from": "b", "to": "2"}, {"from": "c", "to": "3"}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
delBody := `{"value": ["a", "c"]}`
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("DELETE failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
mappings := resp["model-mappings"]
if len(mappings) != 1 {
t.Fatalf("expected 1 mapping remaining, got %d", len(mappings))
}
if mappings[0].From != "b" || mappings[0].To != "2" {
t.Errorf("expected b->2, got %s->%s", mappings[0].From, mappings[0].To)
}
}
// TestDeleteAmpModelMappings_NonExistent verifies DELETE with non-existent mapping doesn't affect existing ones.
func TestDeleteAmpModelMappings_NonExistent(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
delBody := `{"value": ["non-existent-model"]}`
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if len(resp["model-mappings"]) != 1 {
t.Errorf("original mapping should remain, got %d mappings", len(resp["model-mappings"]))
}
}
// TestPutAmpModelMappings_Empty verifies PUT with empty array clears all mappings.
func TestPutAmpModelMappings_Empty(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": []}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if len(resp["model-mappings"]) != 0 {
t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"]))
}
}
// TestPutAmpUpstreamURL_VerifyState verifies PUT updates upstream URL and persists state.
func TestPutAmpUpstreamURL_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": "https://new-api.example.com"}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-url", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PUT failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["upstream-url"] != "https://new-api.example.com" {
t.Errorf("expected %q, got %q", "https://new-api.example.com", resp["upstream-url"])
}
}
// TestDeleteAmpUpstreamURL_VerifyState verifies DELETE clears upstream URL.
func TestDeleteAmpUpstreamURL_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-url", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("DELETE failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-url", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["upstream-url"] != "" {
t.Errorf("expected empty string, got %q", resp["upstream-url"])
}
}
// TestPutAmpUpstreamAPIKey_VerifyState verifies PUT updates API key and persists state.
func TestPutAmpUpstreamAPIKey_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": "new-secret-api-key-xyz"}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/upstream-api-key", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PUT failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["upstream-api-key"] != "new-secret-api-key-xyz" {
t.Errorf("expected %q, got %q", "new-secret-api-key-xyz", resp["upstream-api-key"])
}
}
// TestDeleteAmpUpstreamAPIKey_VerifyState verifies DELETE clears API key.
func TestDeleteAmpUpstreamAPIKey_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/upstream-api-key", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("DELETE failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/upstream-api-key", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]string
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["upstream-api-key"] != "" {
t.Errorf("expected empty string, got %q", resp["upstream-api-key"])
}
}
// TestPutAmpRestrictManagementToLocalhost_VerifyState verifies PUT updates localhost restriction.
func TestPutAmpRestrictManagementToLocalhost_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": false}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/restrict-management-to-localhost", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PUT failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/restrict-management-to-localhost", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]bool
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["restrict-management-to-localhost"] != false {
t.Error("expected false after update")
}
}
// TestPutAmpForceModelMappings_VerifyState verifies PUT updates force-model-mappings setting.
func TestPutAmpForceModelMappings_VerifyState(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{"value": true}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("PUT failed: status %d", w.Code)
}
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/force-model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string]bool
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if resp["force-model-mappings"] != true {
t.Error("expected true after update")
}
}
// TestPutBoolField_EmptyObject verifies PUT with empty object returns 400.
func TestPutBoolField_EmptyObject(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
body := `{}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/force-model-mappings", bytes.NewBufferString(body))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusBadRequest {
t.Fatalf("expected status %d for empty object, got %d", http.StatusBadRequest, w.Code)
}
}
// TestComplexMappingsWorkflow tests a full workflow: PUT, PATCH, DELETE, and GET.
func TestComplexMappingsWorkflow(t *testing.T) {
h, _ := newAmpTestHandler(t)
r := setupAmpRouter(h)
putBody := `{"value": [{"from": "m1", "to": "t1"}, {"from": "m2", "to": "t2"}, {"from": "m3", "to": "t3"}, {"from": "m4", "to": "t4"}]}`
req := httptest.NewRequest(http.MethodPut, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(putBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
patchBody := `{"value": [{"from": "m2", "to": "t2-updated"}, {"from": "m5", "to": "t5"}]}`
req = httptest.NewRequest(http.MethodPatch, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(patchBody))
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
delBody := `{"value": ["m1", "m3"]}`
req = httptest.NewRequest(http.MethodDelete, "/v0/management/ampcode/model-mappings", bytes.NewBufferString(delBody))
req.Header.Set("Content-Type", "application/json")
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
req = httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w = httptest.NewRecorder()
r.ServeHTTP(w, req)
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
mappings := resp["model-mappings"]
if len(mappings) != 3 {
t.Fatalf("expected 3 mappings (m2, m4, m5), got %d", len(mappings))
}
expected := map[string]string{"m2": "t2-updated", "m4": "t4", "m5": "t5"}
found := make(map[string]string)
for _, m := range mappings {
found[m.From] = m.To
}
for from, to := range expected {
if found[from] != to {
t.Errorf("mapping %s: expected %q, got %q", from, to, found[from])
}
}
}
// TestNilHandlerGetAmpCode verifies handler works with empty config.
func TestNilHandlerGetAmpCode(t *testing.T) {
cfg := &config.Config{}
h := management.NewHandler(cfg, "", nil)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
}
// TestEmptyConfigGetAmpModelMappings verifies GET returns empty array for fresh config.
func TestEmptyConfigGetAmpModelMappings(t *testing.T) {
cfg := &config.Config{}
tmpDir := t.TempDir()
configPath := filepath.Join(tmpDir, "config.yaml")
if err := os.WriteFile(configPath, []byte("port: 8080\n"), 0644); err != nil {
t.Fatalf("failed to write config: %v", err)
}
h := management.NewHandler(cfg, configPath, nil)
r := setupAmpRouter(h)
req := httptest.NewRequest(http.MethodGet, "/v0/management/ampcode/model-mappings", nil)
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("expected status %d, got %d", http.StatusOK, w.Code)
}
var resp map[string][]config.AmpModelMapping
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("failed to unmarshal: %v", err)
}
if len(resp["model-mappings"]) != 0 {
t.Errorf("expected 0 mappings, got %d", len(resp["model-mappings"]))
}
}