mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-04 21:40:51 +08:00
Compare commits
67 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
1434bc38e5 | ||
|
|
0fd2abbc3b | ||
|
|
0ebb654019 | ||
|
|
08a1d2edf9 | ||
|
|
3409f4e336 | ||
|
|
9354b87e54 | ||
|
|
54e24110ec | ||
|
|
717c703bff | ||
|
|
1c6f4be8ae | ||
|
|
0de2560cee | ||
|
|
85eb926482 | ||
|
|
c52ef08e67 | ||
|
|
cb580cd083 | ||
|
|
75e278c7a5 | ||
|
|
73208c4e55 | ||
|
|
32d3809f8c | ||
|
|
a748e93fd9 | ||
|
|
54a9c4c3c7 | ||
|
|
18b5c35dea | ||
|
|
7b7871ede2 | ||
|
|
c4e3646b75 | ||
|
|
022aa81be1 | ||
|
|
c43f0ea7b1 | ||
|
|
6a191358af | ||
|
|
db1119dd78 | ||
|
|
33a5656235 | ||
|
|
2cd59806e2 | ||
|
|
5983e3ec87 | ||
|
|
f8cebb9343 | ||
|
|
72c7ef7647 | ||
|
|
d2e4639b2a | ||
|
|
08321223c4 | ||
|
|
7e30157590 | ||
|
|
e73cdf5cff | ||
|
|
39621a0340 | ||
|
|
346b663079 | ||
|
|
0bcae68c6c | ||
|
|
c8cee547fd | ||
|
|
36755421fe | ||
|
|
6c17dbc4da | ||
|
|
ee6429cc75 | ||
|
|
a4a26d978e | ||
|
|
ed9f6e897e | ||
|
|
9c1e3c0687 | ||
|
|
2e5681ea32 | ||
|
|
52c17f03a5 | ||
|
|
d0e694d4ed | ||
|
|
506f1117dd | ||
|
|
113db3c5bf | ||
|
|
1aa0b6cd11 | ||
|
|
0895533400 | ||
|
|
43f007c234 | ||
|
|
0ceee56d99 | ||
|
|
943a8c74df | ||
|
|
0a47b452e9 | ||
|
|
261f08a82a | ||
|
|
d114d8d0bd | ||
|
|
bb9955e461 | ||
|
|
7063a176f4 | ||
|
|
e3082887a6 | ||
|
|
ddb0c0ec1c | ||
|
|
d1736cb29c | ||
|
|
62bfd62871 | ||
|
|
257621c5ed | ||
|
|
ac064389ca | ||
|
|
8d23ffc873 | ||
|
|
4307f08bbc |
5
.gitignore
vendored
5
.gitignore
vendored
@@ -15,6 +15,7 @@ pgstore/*
|
||||
gitstore/*
|
||||
objectstore/*
|
||||
static/*
|
||||
refs/*
|
||||
|
||||
# Authentication data
|
||||
auths/*
|
||||
@@ -30,3 +31,7 @@ GEMINI.md
|
||||
.vscode/*
|
||||
.claude/*
|
||||
.serena/*
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
._*
|
||||
|
||||
@@ -56,6 +56,7 @@ CLIProxyAPI includes integrated support for [Amp CLI](https://ampcode.com) and A
|
||||
- Provider route aliases for Amp's API patterns (`/api/provider/{provider}/v1...`)
|
||||
- Management proxy for OAuth authentication and account features
|
||||
- Smart model fallback with automatic routing
|
||||
- **Model mapping** to route unavailable models to alternatives (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- Security-first design with localhost-only management endpoints
|
||||
|
||||
**→ [Complete Amp CLI Integration Guide](docs/amp-cli-integration.md)**
|
||||
@@ -90,6 +91,10 @@ Native macOS menu bar app to use your Claude Code & ChatGPT subscriptions with A
|
||||
|
||||
Browser-based tool to translate SRT subtitles using your Gemini subscription via CLIProxyAPI with automatic validation/error correction - no API keys needed
|
||||
|
||||
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
|
||||
|
||||
CLI wrapper for instant switching between multiple Claude accounts and alternative models (Gemini, Codex, Antigravity) via CLIProxyAPI OAuth - no API keys needed
|
||||
|
||||
> [!NOTE]
|
||||
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.
|
||||
|
||||
|
||||
@@ -89,6 +89,10 @@ CLIProxyAPI 已内置对 [Amp CLI](https://ampcode.com) 和 Amp IDE 扩展的支
|
||||
|
||||
一款基于浏览器的 SRT 字幕翻译工具,可通过 CLI 代理 API 使用您的 Gemini 订阅。内置自动验证与错误修正功能,无需 API 密钥。
|
||||
|
||||
### [CCS (Claude Code Switch)](https://github.com/kaitranntt/ccs)
|
||||
|
||||
CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户和替代模型(Gemini, Codex, Antigravity),无需 API 密钥。
|
||||
|
||||
> [!NOTE]
|
||||
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR(拉取请求)将其添加到此列表中。
|
||||
|
||||
|
||||
@@ -1,6 +1,12 @@
|
||||
# Server port
|
||||
port: 8317
|
||||
|
||||
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
|
||||
tls:
|
||||
enable: false
|
||||
cert: ""
|
||||
key: ""
|
||||
|
||||
# Management API settings
|
||||
remote-management:
|
||||
# Whether to allow remote (non-localhost) management access.
|
||||
@@ -38,6 +44,9 @@ proxy-url: ""
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
||||
max-retry-interval: 30
|
||||
|
||||
# Quota exceeded behavior
|
||||
quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
@@ -46,6 +55,28 @@ quota-exceeded:
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
|
||||
# Amp CLI Integration
|
||||
# Configure upstream URL for Amp CLI OAuth and management features
|
||||
#amp-upstream-url: "https://ampcode.com"
|
||||
|
||||
# Optional: Override API key for Amp upstream (otherwise uses env or file)
|
||||
#amp-upstream-api-key: ""
|
||||
|
||||
# Restrict Amp management routes (/api/auth, /api/user, etc.) to localhost only (recommended)
|
||||
#amp-restrict-management-to-localhost: true
|
||||
|
||||
# Amp Model Mappings
|
||||
# Route unavailable Amp models to alternative models available in your local proxy.
|
||||
# Useful when Amp CLI requests models you don't have access to (e.g., Claude Opus 4.5)
|
||||
# but you have a similar model available (e.g., Claude Sonnet 4).
|
||||
#amp-model-mappings:
|
||||
# - from: "claude-opus-4.5" # Model requested by Amp CLI
|
||||
# to: "claude-sonnet-4" # Route to this available model instead
|
||||
# - from: "gpt-5"
|
||||
# to: "gemini-2.5-pro"
|
||||
# - from: "claude-3-opus-20240229"
|
||||
# to: "claude-3-5-sonnet-20241022"
|
||||
|
||||
# Gemini API keys (preferred)
|
||||
#gemini-api-key:
|
||||
# - api-key: "AIzaSy...01"
|
||||
@@ -53,6 +84,11 @@ ws-auth: false
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080"
|
||||
# excluded-models:
|
||||
# - "gemini-2.5-pro" # exclude specific models from this provider (exact match)
|
||||
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
|
||||
# - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview)
|
||||
# - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite)
|
||||
# - api-key: "AIzaSy...02"
|
||||
|
||||
# API keys for official Generative Language API (legacy compatibility)
|
||||
@@ -67,6 +103,11 @@ ws-auth: false
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional: per-key proxy override
|
||||
# excluded-models:
|
||||
# - "gpt-5.1" # exclude specific models (exact match)
|
||||
# - "gpt-5-*" # wildcard matching prefix (e.g. gpt-5-medium, gpt-5-codex)
|
||||
# - "*-mini" # wildcard matching suffix (e.g. gpt-5-codex-mini)
|
||||
# - "*codex*" # wildcard matching substring (e.g. gpt-5-codex-low)
|
||||
|
||||
# Claude API keys
|
||||
#claude-api-key:
|
||||
@@ -79,6 +120,11 @@ ws-auth: false
|
||||
# models:
|
||||
# - name: "claude-3-5-sonnet-20241022" # upstream model name
|
||||
# alias: "claude-sonnet-latest" # client alias mapped to the upstream model
|
||||
# excluded-models:
|
||||
# - "claude-opus-4-5-20251101" # exclude specific models (exact match)
|
||||
# - "claude-3-*" # wildcard matching prefix (e.g. claude-3-7-sonnet-20250219)
|
||||
# - "*-think" # wildcard matching suffix (e.g. claude-opus-4-5-thinking)
|
||||
# - "*haiku*" # wildcard matching substring (e.g. claude-3-5-haiku-20241022)
|
||||
|
||||
# OpenAI compatibility providers
|
||||
#openai-compatibility:
|
||||
@@ -99,6 +145,19 @@ ws-auth: false
|
||||
# - name: "moonshotai/kimi-k2:free" # The actual model name.
|
||||
# alias: "kimi-k2" # The alias used in the API.
|
||||
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
#vertex-api-key:
|
||||
# - api-key: "vk-123..." # x-goog-api-key header
|
||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
# models: # optional: map aliases to upstream model names
|
||||
# - name: "gemini-2.0-flash" # upstream model name
|
||||
# alias: "vertex-flash" # client-visible alias
|
||||
# - name: "gemini-1.5-pro"
|
||||
# alias: "vertex-pro"
|
||||
|
||||
#payload: # Optional payload configuration
|
||||
# default: # Default rules only set parameters when they are missing in the payload.
|
||||
# - models:
|
||||
@@ -112,3 +171,25 @@ ws-auth: false
|
||||
# protocol: "codex" # restricts the rule to a specific protocol, options: openai, gemini, claude, codex
|
||||
# params: # JSON path (gjson/sjson syntax) -> value
|
||||
# "reasoning.effort": "high"
|
||||
|
||||
# OAuth provider excluded models
|
||||
#oauth-excluded-models:
|
||||
# gemini-cli:
|
||||
# - "gemini-2.5-pro" # exclude specific models (exact match)
|
||||
# - "gemini-2.5-*" # wildcard matching prefix (e.g. gemini-2.5-flash, gemini-2.5-pro)
|
||||
# - "*-preview" # wildcard matching suffix (e.g. gemini-3-pro-preview)
|
||||
# - "*flash*" # wildcard matching substring (e.g. gemini-2.5-flash-lite)
|
||||
# vertex:
|
||||
# - "gemini-3-pro-preview"
|
||||
# aistudio:
|
||||
# - "gemini-3-pro-preview"
|
||||
# antigravity:
|
||||
# - "gemini-3-pro-preview"
|
||||
# claude:
|
||||
# - "claude-3-5-haiku-20241022"
|
||||
# codex:
|
||||
# - "gpt-5-codex-mini"
|
||||
# qwen:
|
||||
# - "vision-model"
|
||||
# iflow:
|
||||
# - "tstars2.0"
|
||||
|
||||
@@ -8,6 +8,7 @@ This guide explains how to use CLIProxyAPI with Amp CLI and Amp IDE extensions,
|
||||
- [Which Providers Should You Authenticate?](#which-providers-should-you-authenticate)
|
||||
- [Architecture](#architecture)
|
||||
- [Configuration](#configuration)
|
||||
- [Model Mapping Configuration](#model-mapping-configuration)
|
||||
- [Setup](#setup)
|
||||
- [Usage](#usage)
|
||||
- [Troubleshooting](#troubleshooting)
|
||||
@@ -21,6 +22,7 @@ The Amp CLI integration adds specialized routing to support Amp's API patterns w
|
||||
- **Provider route aliases**: Maps Amp's `/api/provider/{provider}/v1...` patterns to CLIProxyAPI handlers
|
||||
- **Management proxy**: Forwards OAuth and account management requests to Amp's control plane
|
||||
- **Smart fallback**: Automatically routes unconfigured models to ampcode.com
|
||||
- **Model mapping**: Route unavailable models to alternatives you have access to (e.g., `claude-opus-4.5` → `claude-sonnet-4`)
|
||||
- **Secret management**: Configurable precedence (config > env > file) with 5-minute caching
|
||||
- **Security-first**: Management routes restricted to localhost by default
|
||||
- **Automatic gzip handling**: Decompresses responses from Amp upstream
|
||||
@@ -75,7 +77,10 @@ Amp CLI/IDE
|
||||
│ ↓
|
||||
│ ├─ Model configured locally?
|
||||
│ │ YES → Use local OAuth tokens (OpenAI/Claude/Gemini handlers)
|
||||
│ │ NO → Forward to ampcode.com (reverse proxy)
|
||||
│ │ NO ↓
|
||||
│ │ ├─ Model mapping configured?
|
||||
│ │ │ YES → Rewrite model → Use local handler (free)
|
||||
│ │ │ NO → Forward to ampcode.com (uses Amp credits)
|
||||
│ ↓
|
||||
│ Response
|
||||
│
|
||||
@@ -115,6 +120,49 @@ amp-upstream-url: "https://ampcode.com"
|
||||
amp-restrict-management-to-localhost: true
|
||||
```
|
||||
|
||||
### Model Mapping Configuration
|
||||
|
||||
When Amp CLI requests a model that you don't have access to, you can configure mappings to route those requests to alternative models that you DO have available. This avoids consuming Amp credits for models you could handle locally.
|
||||
|
||||
```yaml
|
||||
# Route unavailable models to alternatives
|
||||
amp-model-mappings:
|
||||
# Example: Route Claude Opus 4.5 requests to Claude Sonnet 4
|
||||
- from: "claude-opus-4.5"
|
||||
to: "claude-sonnet-4"
|
||||
|
||||
# Example: Route GPT-5 requests to Gemini 2.5 Pro
|
||||
- from: "gpt-5"
|
||||
to: "gemini-2.5-pro"
|
||||
|
||||
# Example: Map older model names to newer versions
|
||||
- from: "claude-3-opus-20240229"
|
||||
to: "claude-3-5-sonnet-20241022"
|
||||
```
|
||||
|
||||
**How it works:**
|
||||
|
||||
1. Amp CLI requests a model (e.g., `claude-opus-4.5`)
|
||||
2. CLIProxyAPI checks if a local provider is available for that model
|
||||
3. If not available, it checks the model mappings
|
||||
4. If a mapping exists, the request is rewritten to use the target model
|
||||
5. The request is then handled locally (free, using your OAuth subscription)
|
||||
|
||||
**Benefits:**
|
||||
- **Save Amp credits**: Use your local subscriptions instead of forwarding to ampcode.com
|
||||
- **Hot-reload**: Mappings can be updated without restarting the proxy
|
||||
- **Structured logging**: Clear logs show when mappings are applied
|
||||
|
||||
**Routing Decision Logs:**
|
||||
|
||||
The proxy logs each routing decision with structured fields:
|
||||
|
||||
```
|
||||
[AMP] Using local provider for model: gemini-2.5-pro # Local provider (free)
|
||||
[AMP] Model mapped: claude-opus-4.5 -> claude-sonnet-4 # Mapping applied (free)
|
||||
[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: gpt-5 # Fallback (costs credits)
|
||||
```
|
||||
|
||||
### Secret Resolution Precedence
|
||||
|
||||
The Amp module resolves API keys using this precedence order:
|
||||
@@ -301,11 +349,14 @@ When Amp requests a model:
|
||||
|
||||
1. **Check local configuration**: Does CLIProxyAPI have OAuth tokens for this model's provider?
|
||||
2. **If YES**: Route to local handler (use your OAuth subscription)
|
||||
3. **If NO**: Forward to ampcode.com (use Amp's default routing)
|
||||
3. **If NO**: Check if a model mapping exists
|
||||
4. **If mapping exists**: Rewrite request to mapped model → Route to local handler (free)
|
||||
5. **If no mapping**: Forward to ampcode.com (uses Amp credits)
|
||||
|
||||
This enables seamless mixed usage:
|
||||
- Models you've configured (Gemini, ChatGPT, Claude) → Your OAuth subscriptions
|
||||
- Models you haven't configured → Amp's default providers
|
||||
- Models with mappings configured → Routed to alternative local models (free)
|
||||
- Models you haven't configured and have no mapping → Amp's default providers (uses credits)
|
||||
|
||||
### Example API Calls
|
||||
|
||||
|
||||
@@ -235,7 +235,11 @@ func (h *Handler) managementCallbackURL(path string) (string, error) {
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
return fmt.Sprintf("http://127.0.0.1:%d%s", h.cfg.Port, path), nil
|
||||
scheme := "http"
|
||||
if h.cfg.TLS.Enable {
|
||||
scheme = "https"
|
||||
}
|
||||
return fmt.Sprintf("%s://127.0.0.1:%d%s", scheme, h.cfg.Port, path), nil
|
||||
}
|
||||
|
||||
func (h *Handler) ListAuthFiles(c *gin.Context) {
|
||||
|
||||
@@ -172,6 +172,14 @@ func (h *Handler) PutRequestRetry(c *gin.Context) {
|
||||
h.updateIntField(c, func(v int) { h.cfg.RequestRetry = v })
|
||||
}
|
||||
|
||||
// Max retry interval
|
||||
func (h *Handler) GetMaxRetryInterval(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"max-retry-interval": h.cfg.MaxRetryInterval})
|
||||
}
|
||||
func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
|
||||
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
|
||||
}
|
||||
|
||||
// Proxy URL
|
||||
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
|
||||
func (h *Handler) PutProxyURL(c *gin.Context) {
|
||||
|
||||
@@ -223,6 +223,7 @@ func (h *Handler) PatchGeminiKey(c *gin.Context) {
|
||||
value.APIKey = strings.TrimSpace(value.APIKey)
|
||||
value.BaseURL = strings.TrimSpace(value.BaseURL)
|
||||
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
|
||||
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels)
|
||||
if value.APIKey == "" {
|
||||
// Treat empty API key as delete.
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
||||
@@ -504,6 +505,91 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
|
||||
c.JSON(400, gin.H{"error": "missing name or index"})
|
||||
}
|
||||
|
||||
// oauth-excluded-models: map[string][]string
|
||||
func (h *Handler) GetOAuthExcludedModels(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)})
|
||||
}
|
||||
|
||||
func (h *Handler) PutOAuthExcludedModels(c *gin.Context) {
|
||||
data, err := c.GetRawData()
|
||||
if err != nil {
|
||||
c.JSON(400, gin.H{"error": "failed to read body"})
|
||||
return
|
||||
}
|
||||
var entries map[string][]string
|
||||
if err = json.Unmarshal(data, &entries); err != nil {
|
||||
var wrapper struct {
|
||||
Items map[string][]string `json:"items"`
|
||||
}
|
||||
if err2 := json.Unmarshal(data, &wrapper); err2 != nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
entries = wrapper.Items
|
||||
}
|
||||
h.cfg.OAuthExcludedModels = config.NormalizeOAuthExcludedModels(entries)
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) PatchOAuthExcludedModels(c *gin.Context) {
|
||||
var body struct {
|
||||
Provider *string `json:"provider"`
|
||||
Models []string `json:"models"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Provider == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(*body.Provider))
|
||||
if provider == "" {
|
||||
c.JSON(400, gin.H{"error": "invalid provider"})
|
||||
return
|
||||
}
|
||||
normalized := config.NormalizeExcludedModels(body.Models)
|
||||
if len(normalized) == 0 {
|
||||
if h.cfg.OAuthExcludedModels == nil {
|
||||
c.JSON(404, gin.H{"error": "provider not found"})
|
||||
return
|
||||
}
|
||||
if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok {
|
||||
c.JSON(404, gin.H{"error": "provider not found"})
|
||||
return
|
||||
}
|
||||
delete(h.cfg.OAuthExcludedModels, provider)
|
||||
if len(h.cfg.OAuthExcludedModels) == 0 {
|
||||
h.cfg.OAuthExcludedModels = nil
|
||||
}
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if h.cfg.OAuthExcludedModels == nil {
|
||||
h.cfg.OAuthExcludedModels = make(map[string][]string)
|
||||
}
|
||||
h.cfg.OAuthExcludedModels[provider] = normalized
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
|
||||
provider := strings.ToLower(strings.TrimSpace(c.Query("provider")))
|
||||
if provider == "" {
|
||||
c.JSON(400, gin.H{"error": "missing provider"})
|
||||
return
|
||||
}
|
||||
if h.cfg.OAuthExcludedModels == nil {
|
||||
c.JSON(404, gin.H{"error": "provider not found"})
|
||||
return
|
||||
}
|
||||
if _, ok := h.cfg.OAuthExcludedModels[provider]; !ok {
|
||||
c.JSON(404, gin.H{"error": "provider not found"})
|
||||
return
|
||||
}
|
||||
delete(h.cfg.OAuthExcludedModels, provider)
|
||||
if len(h.cfg.OAuthExcludedModels) == 0 {
|
||||
h.cfg.OAuthExcludedModels = nil
|
||||
}
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
// codex-api-key: []CodexKey
|
||||
func (h *Handler) GetCodexKeys(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey})
|
||||
@@ -533,6 +619,7 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
||||
if entry.BaseURL == "" {
|
||||
continue
|
||||
}
|
||||
@@ -557,6 +644,7 @@ func (h *Handler) PatchCodexKey(c *gin.Context) {
|
||||
value.BaseURL = strings.TrimSpace(value.BaseURL)
|
||||
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
|
||||
value.Headers = config.NormalizeHeaders(value.Headers)
|
||||
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels)
|
||||
// If base-url becomes empty, delete instead of update
|
||||
if value.BaseURL == "" {
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
||||
@@ -694,6 +782,7 @@ func normalizeClaudeKey(entry *config.ClaudeKey) {
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = config.NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(entry.ExcludedModels)
|
||||
if len(entry.Models) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
@@ -58,8 +58,14 @@ func (h *Handler) GetLogs(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
limit, errLimit := parseLimit(c.Query("limit"))
|
||||
if errLimit != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid limit: %v", errLimit)})
|
||||
return
|
||||
}
|
||||
|
||||
cutoff := parseCutoff(c.Query("after"))
|
||||
acc := newLogAccumulator(cutoff)
|
||||
acc := newLogAccumulator(cutoff, limit)
|
||||
for i := range files {
|
||||
if errProcess := acc.consumeFile(files[i]); errProcess != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file %s: %v", files[i], errProcess)})
|
||||
@@ -139,6 +145,126 @@ func (h *Handler) DeleteLogs(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// GetRequestErrorLogs lists error request log files when RequestLog is disabled.
|
||||
// It returns an empty list when RequestLog is enabled.
|
||||
func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg.RequestLog {
|
||||
c.JSON(http.StatusOK, gin.H{"files": []any{}})
|
||||
return
|
||||
}
|
||||
|
||||
dir := h.logDirectory()
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusOK, gin.H{"files": []any{}})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list request error logs: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
type errorLog struct {
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
Modified int64 `json:"modified"`
|
||||
}
|
||||
|
||||
files := make([]errorLog, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
|
||||
continue
|
||||
}
|
||||
info, errInfo := entry.Info()
|
||||
if errInfo != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log info for %s: %v", name, errInfo)})
|
||||
return
|
||||
}
|
||||
files = append(files, errorLog{
|
||||
Name: name,
|
||||
Size: info.Size(),
|
||||
Modified: info.ModTime().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(files, func(i, j int) bool { return files[i].Modified > files[j].Modified })
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// DownloadRequestErrorLog downloads a specific error request log file by name.
|
||||
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
dir := h.logDirectory()
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(c.Param("name"))
|
||||
if name == "" || strings.Contains(name, "/") || strings.Contains(name, "\\") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file name"})
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
||||
return
|
||||
}
|
||||
|
||||
dirAbs, errAbs := filepath.Abs(dir)
|
||||
if errAbs != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
|
||||
return
|
||||
}
|
||||
fullPath := filepath.Clean(filepath.Join(dirAbs, name))
|
||||
prefix := dirAbs + string(os.PathSeparator)
|
||||
if !strings.HasPrefix(fullPath, prefix) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
|
||||
return
|
||||
}
|
||||
|
||||
info, errStat := os.Stat(fullPath)
|
||||
if errStat != nil {
|
||||
if os.IsNotExist(errStat) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
|
||||
return
|
||||
}
|
||||
if info.IsDir() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
|
||||
return
|
||||
}
|
||||
|
||||
c.FileAttachment(fullPath, name)
|
||||
}
|
||||
|
||||
func (h *Handler) logDirectory() string {
|
||||
if h == nil {
|
||||
return ""
|
||||
@@ -194,16 +320,22 @@ func (h *Handler) collectLogFiles(dir string) ([]string, error) {
|
||||
|
||||
type logAccumulator struct {
|
||||
cutoff int64
|
||||
limit int
|
||||
lines []string
|
||||
total int
|
||||
latest int64
|
||||
include bool
|
||||
}
|
||||
|
||||
func newLogAccumulator(cutoff int64) *logAccumulator {
|
||||
func newLogAccumulator(cutoff int64, limit int) *logAccumulator {
|
||||
capacity := 256
|
||||
if limit > 0 && limit < capacity {
|
||||
capacity = limit
|
||||
}
|
||||
return &logAccumulator{
|
||||
cutoff: cutoff,
|
||||
lines: make([]string, 0, 256),
|
||||
limit: limit,
|
||||
lines: make([]string, 0, capacity),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,7 +347,9 @@ func (acc *logAccumulator) consumeFile(path string) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
defer func() {
|
||||
_ = file.Close()
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
buf := make([]byte, 0, logScannerInitialBuffer)
|
||||
@@ -239,12 +373,19 @@ func (acc *logAccumulator) addLine(raw string) {
|
||||
if ts > 0 {
|
||||
acc.include = acc.cutoff == 0 || ts > acc.cutoff
|
||||
if acc.cutoff == 0 || acc.include {
|
||||
acc.lines = append(acc.lines, line)
|
||||
acc.append(line)
|
||||
}
|
||||
return
|
||||
}
|
||||
if acc.cutoff == 0 || acc.include {
|
||||
acc.lines = append(acc.lines, line)
|
||||
acc.append(line)
|
||||
}
|
||||
}
|
||||
|
||||
func (acc *logAccumulator) append(line string) {
|
||||
acc.lines = append(acc.lines, line)
|
||||
if acc.limit > 0 && len(acc.lines) > acc.limit {
|
||||
acc.lines = acc.lines[len(acc.lines)-acc.limit:]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,6 +408,21 @@ func parseCutoff(raw string) int64 {
|
||||
return ts
|
||||
}
|
||||
|
||||
func parseLimit(raw string) (int, error) {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return 0, nil
|
||||
}
|
||||
limit, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("must be a positive integer")
|
||||
}
|
||||
if limit <= 0 {
|
||||
return 0, fmt.Errorf("must be greater than zero")
|
||||
}
|
||||
return limit, nil
|
||||
}
|
||||
|
||||
func parseTimestamp(line string) int64 {
|
||||
if strings.HasPrefix(line, "[") {
|
||||
line = line[1:]
|
||||
|
||||
@@ -6,6 +6,7 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -15,8 +16,8 @@ import (
|
||||
|
||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||
// It captures detailed information about the request and response, including headers and body,
|
||||
// and uses the provided RequestLogger to record this data. If logging is disabled in the
|
||||
// logger, the middleware has minimal overhead.
|
||||
// and uses the provided RequestLogger to record this data. When logging is disabled in the
|
||||
// logger, it still captures data so that upstream errors can be persisted.
|
||||
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if logger == nil {
|
||||
@@ -24,14 +25,13 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
path := c.Request.URL.Path
|
||||
if !shouldLogRequest(path) {
|
||||
if c.Request.Method == http.MethodGet {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Early return if logging is disabled (zero overhead)
|
||||
if !logger.IsEnabled() {
|
||||
path := c.Request.URL.Path
|
||||
if !shouldLogRequest(path) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -47,6 +47,9 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
|
||||
// Create response writer wrapper
|
||||
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||
if !logger.IsEnabled() {
|
||||
wrapper.logOnErrorOnly = true
|
||||
}
|
||||
c.Writer = wrapper
|
||||
|
||||
// Process the request
|
||||
|
||||
@@ -5,6 +5,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -24,15 +25,16 @@ type RequestInfo struct {
|
||||
// It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
|
||||
type ResponseWriterWrapper struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
||||
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
||||
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
||||
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
||||
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
||||
logger logging.RequestLogger // logger is the instance of the request logger service.
|
||||
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
||||
statusCode int // statusCode stores the HTTP status code of the response.
|
||||
headers map[string][]string // headers stores the response headers.
|
||||
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
||||
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
||||
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
||||
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
||||
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
||||
logger logging.RequestLogger // logger is the instance of the request logger service.
|
||||
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
||||
statusCode int // statusCode stores the HTTP status code of the response.
|
||||
headers map[string][]string // headers stores the response headers.
|
||||
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
|
||||
}
|
||||
|
||||
// NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
|
||||
@@ -192,12 +194,34 @@ func (w *ResponseWriterWrapper) processStreamingChunks(done chan struct{}) {
|
||||
// For non-streaming responses, it logs the complete request and response details,
|
||||
// including any API-specific request/response data stored in the Gin context.
|
||||
func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
if !w.logger.IsEnabled() {
|
||||
if w.logger == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
finalStatusCode := w.statusCode
|
||||
if finalStatusCode == 0 {
|
||||
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
|
||||
finalStatusCode = statusWriter.Status()
|
||||
} else {
|
||||
finalStatusCode = 200
|
||||
}
|
||||
}
|
||||
|
||||
var slicesAPIResponseError []*interfaces.ErrorMessage
|
||||
apiResponseError, isExist := c.Get("API_RESPONSE_ERROR")
|
||||
if isExist {
|
||||
if apiErrors, ok := apiResponseError.([]*interfaces.ErrorMessage); ok {
|
||||
slicesAPIResponseError = apiErrors
|
||||
}
|
||||
}
|
||||
|
||||
hasAPIError := len(slicesAPIResponseError) > 0 || finalStatusCode >= http.StatusBadRequest
|
||||
forceLog := w.logOnErrorOnly && hasAPIError && !w.logger.IsEnabled()
|
||||
if !w.logger.IsEnabled() && !forceLog {
|
||||
return nil
|
||||
}
|
||||
|
||||
if w.isStreaming {
|
||||
// Close streaming channel and writer
|
||||
if w.chunkChannel != nil {
|
||||
close(w.chunkChannel)
|
||||
w.chunkChannel = nil
|
||||
@@ -209,80 +233,98 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
}
|
||||
|
||||
if w.streamWriter != nil {
|
||||
err := w.streamWriter.Close()
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Capture final status code and headers if not already captured
|
||||
finalStatusCode := w.statusCode
|
||||
if finalStatusCode == 0 {
|
||||
// Get status from underlying ResponseWriter if available
|
||||
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
|
||||
finalStatusCode = statusWriter.Status()
|
||||
} else {
|
||||
finalStatusCode = 200 // Default
|
||||
}
|
||||
if forceLog {
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure we have the latest headers before finalizing
|
||||
w.ensureHeadersCaptured()
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
|
||||
// Use the captured headers as the final headers
|
||||
finalHeaders := make(map[string][]string)
|
||||
for key, values := range w.headers {
|
||||
// Make a copy of the values slice to avoid reference issues
|
||||
headerValues := make([]string, len(values))
|
||||
copy(headerValues, values)
|
||||
finalHeaders[key] = headerValues
|
||||
}
|
||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||
w.ensureHeadersCaptured()
|
||||
|
||||
var apiRequestBody []byte
|
||||
apiRequest, isExist := c.Get("API_REQUEST")
|
||||
if isExist {
|
||||
var ok bool
|
||||
apiRequestBody, ok = apiRequest.([]byte)
|
||||
if !ok {
|
||||
apiRequestBody = nil
|
||||
}
|
||||
}
|
||||
finalHeaders := make(map[string][]string, len(w.headers))
|
||||
for key, values := range w.headers {
|
||||
headerValues := make([]string, len(values))
|
||||
copy(headerValues, values)
|
||||
finalHeaders[key] = headerValues
|
||||
}
|
||||
|
||||
var apiResponseBody []byte
|
||||
apiResponse, isExist := c.Get("API_RESPONSE")
|
||||
if isExist {
|
||||
var ok bool
|
||||
apiResponseBody, ok = apiResponse.([]byte)
|
||||
if !ok {
|
||||
apiResponseBody = nil
|
||||
}
|
||||
}
|
||||
return finalHeaders
|
||||
}
|
||||
|
||||
var slicesAPIResponseError []*interfaces.ErrorMessage
|
||||
apiResponseError, isExist := c.Get("API_RESPONSE_ERROR")
|
||||
if isExist {
|
||||
var ok bool
|
||||
slicesAPIResponseError, ok = apiResponseError.([]*interfaces.ErrorMessage)
|
||||
if !ok {
|
||||
slicesAPIResponseError = nil
|
||||
}
|
||||
}
|
||||
func (w *ResponseWriterWrapper) extractAPIRequest(c *gin.Context) []byte {
|
||||
apiRequest, isExist := c.Get("API_REQUEST")
|
||||
if !isExist {
|
||||
return nil
|
||||
}
|
||||
data, ok := apiRequest.([]byte)
|
||||
if !ok || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// Log complete non-streaming response
|
||||
return w.logger.LogRequest(
|
||||
func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
|
||||
apiResponse, isExist := c.Get("API_RESPONSE")
|
||||
if !isExist {
|
||||
return nil
|
||||
}
|
||||
data, ok := apiResponse.([]byte)
|
||||
if !ok || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
if w.requestInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if len(w.requestInfo.Body) > 0 {
|
||||
requestBody = w.requestInfo.Body
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
|
||||
}); ok {
|
||||
return loggerWithOptions.LogRequestWithOptions(
|
||||
w.requestInfo.URL,
|
||||
w.requestInfo.Method,
|
||||
w.requestInfo.Headers,
|
||||
w.requestInfo.Body,
|
||||
finalStatusCode,
|
||||
finalHeaders,
|
||||
w.body.Bytes(),
|
||||
requestBody,
|
||||
statusCode,
|
||||
headers,
|
||||
body,
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
slicesAPIResponseError,
|
||||
apiResponseErrors,
|
||||
forceLog,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
return w.logger.LogRequest(
|
||||
w.requestInfo.URL,
|
||||
w.requestInfo.Method,
|
||||
w.requestInfo.Headers,
|
||||
requestBody,
|
||||
statusCode,
|
||||
headers,
|
||||
body,
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
apiResponseErrors,
|
||||
)
|
||||
}
|
||||
|
||||
// Status returns the HTTP response status code captured by the wrapper.
|
||||
|
||||
@@ -23,11 +23,13 @@ type Option func(*AmpModule)
|
||||
// - Reverse proxy to Amp control plane for OAuth/management
|
||||
// - Provider-specific route aliases (/api/provider/{provider}/...)
|
||||
// - Automatic gzip decompression for misconfigured upstreams
|
||||
// - Model mapping for routing unavailable models to alternatives
|
||||
type AmpModule struct {
|
||||
secretSource SecretSource
|
||||
proxy *httputil.ReverseProxy
|
||||
accessManager *sdkaccess.Manager
|
||||
authMiddleware_ gin.HandlerFunc
|
||||
modelMapper *DefaultModelMapper
|
||||
enabled bool
|
||||
registerOnce sync.Once
|
||||
}
|
||||
@@ -101,6 +103,9 @@ func (m *AmpModule) Register(ctx modules.Context) error {
|
||||
// Use registerOnce to ensure routes are only registered once
|
||||
var regErr error
|
||||
m.registerOnce.Do(func() {
|
||||
// Initialize model mapper from config (for routing unavailable models to alternatives)
|
||||
m.modelMapper = NewModelMapper(ctx.Config.AmpModelMappings)
|
||||
|
||||
// Always register provider aliases - these work without an upstream
|
||||
m.registerProviderAliases(ctx.Engine, ctx.BaseHandler, auth)
|
||||
|
||||
@@ -159,8 +164,16 @@ func (m *AmpModule) getAuthMiddleware(ctx modules.Context) gin.HandlerFunc {
|
||||
// OnConfigUpdated handles configuration updates.
|
||||
// Currently requires restart for URL changes (could be enhanced for dynamic updates).
|
||||
func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
// Update model mappings (hot-reload supported)
|
||||
if m.modelMapper != nil {
|
||||
log.Infof("amp config updated: reloading %d model mapping(s)", len(cfg.AmpModelMappings))
|
||||
m.modelMapper.UpdateMappings(cfg.AmpModelMappings)
|
||||
} else {
|
||||
log.Warnf("amp model mapper not initialized, skipping model mapping update")
|
||||
}
|
||||
|
||||
if !m.enabled {
|
||||
log.Debug("Amp routing not enabled, skipping config update")
|
||||
log.Debug("Amp routing not enabled, skipping other config updates")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -182,4 +195,7 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// GetModelMapper returns the model mapper instance (for testing/debugging).
|
||||
func (m *AmpModule) GetModelMapper() *DefaultModelMapper {
|
||||
return m.modelMapper
|
||||
}
|
||||
|
||||
@@ -6,16 +6,75 @@ import (
|
||||
"io"
|
||||
"net/http/httputil"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// AmpRouteType represents the type of routing decision made for an Amp request
|
||||
type AmpRouteType string
|
||||
|
||||
const (
|
||||
// RouteTypeLocalProvider indicates the request is handled by a local OAuth provider (free)
|
||||
RouteTypeLocalProvider AmpRouteType = "LOCAL_PROVIDER"
|
||||
// RouteTypeModelMapping indicates the request was remapped to another available model (free)
|
||||
RouteTypeModelMapping AmpRouteType = "MODEL_MAPPING"
|
||||
// RouteTypeAmpCredits indicates the request is forwarded to ampcode.com (uses Amp credits)
|
||||
RouteTypeAmpCredits AmpRouteType = "AMP_CREDITS"
|
||||
// RouteTypeNoProvider indicates no provider or fallback available
|
||||
RouteTypeNoProvider AmpRouteType = "NO_PROVIDER"
|
||||
)
|
||||
|
||||
// logAmpRouting logs the routing decision for an Amp request with structured fields
|
||||
func logAmpRouting(routeType AmpRouteType, requestedModel, resolvedModel, provider, path string) {
|
||||
fields := log.Fields{
|
||||
"component": "amp-routing",
|
||||
"route_type": string(routeType),
|
||||
"requested_model": requestedModel,
|
||||
"path": path,
|
||||
"timestamp": time.Now().Format(time.RFC3339),
|
||||
}
|
||||
|
||||
if resolvedModel != "" && resolvedModel != requestedModel {
|
||||
fields["resolved_model"] = resolvedModel
|
||||
}
|
||||
if provider != "" {
|
||||
fields["provider"] = provider
|
||||
}
|
||||
|
||||
switch routeType {
|
||||
case RouteTypeLocalProvider:
|
||||
fields["cost"] = "free"
|
||||
fields["source"] = "local_oauth"
|
||||
log.WithFields(fields).Infof("[AMP] Using local provider for model: %s", requestedModel)
|
||||
|
||||
case RouteTypeModelMapping:
|
||||
fields["cost"] = "free"
|
||||
fields["source"] = "local_oauth"
|
||||
fields["mapping"] = requestedModel + " -> " + resolvedModel
|
||||
log.WithFields(fields).Infof("[AMP] Model mapped: %s -> %s", requestedModel, resolvedModel)
|
||||
|
||||
case RouteTypeAmpCredits:
|
||||
fields["cost"] = "amp_credits"
|
||||
fields["source"] = "ampcode.com"
|
||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||
log.WithFields(fields).Warnf("[AMP] Forwarding to ampcode.com (uses Amp credits) - model_id: %s | To use local proxy, add to config: amp-model-mappings: [{from: \"%s\", to: \"<your-local-model>\"}]", requestedModel, requestedModel)
|
||||
|
||||
case RouteTypeNoProvider:
|
||||
fields["cost"] = "none"
|
||||
fields["source"] = "error"
|
||||
fields["model_id"] = requestedModel // Explicit model_id for easy config reference
|
||||
log.WithFields(fields).Warnf("[AMP] No provider available for model_id: %s", requestedModel)
|
||||
}
|
||||
}
|
||||
|
||||
// FallbackHandler wraps a standard handler with fallback logic to ampcode.com
|
||||
// when the model's provider is not available in CLIProxyAPI
|
||||
type FallbackHandler struct {
|
||||
getProxy func() *httputil.ReverseProxy
|
||||
getProxy func() *httputil.ReverseProxy
|
||||
modelMapper ModelMapper
|
||||
}
|
||||
|
||||
// NewFallbackHandler creates a new fallback handler wrapper
|
||||
@@ -26,10 +85,25 @@ func NewFallbackHandler(getProxy func() *httputil.ReverseProxy) *FallbackHandler
|
||||
}
|
||||
}
|
||||
|
||||
// NewFallbackHandlerWithMapper creates a new fallback handler with model mapping support
|
||||
func NewFallbackHandlerWithMapper(getProxy func() *httputil.ReverseProxy, mapper ModelMapper) *FallbackHandler {
|
||||
return &FallbackHandler{
|
||||
getProxy: getProxy,
|
||||
modelMapper: mapper,
|
||||
}
|
||||
}
|
||||
|
||||
// SetModelMapper sets the model mapper for this handler (allows late binding)
|
||||
func (fh *FallbackHandler) SetModelMapper(mapper ModelMapper) {
|
||||
fh.modelMapper = mapper
|
||||
}
|
||||
|
||||
// WrapHandler wraps a gin.HandlerFunc with fallback logic
|
||||
// If the model's provider is not configured in CLIProxyAPI, it forwards to ampcode.com
|
||||
func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
requestPath := c.Request.URL.Path
|
||||
|
||||
// Read the request body to extract the model name
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
@@ -55,12 +129,33 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
// Check if we have providers for this model
|
||||
providers := util.GetProviderName(normalizedModel)
|
||||
|
||||
// Track resolved model for logging (may change if mapping is applied)
|
||||
resolvedModel := normalizedModel
|
||||
usedMapping := false
|
||||
|
||||
if len(providers) == 0 {
|
||||
// No providers configured - check if we have a proxy for fallback
|
||||
// No providers configured - check if we have a model mapping
|
||||
if fh.modelMapper != nil {
|
||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||
// Mapping found - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInBody(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
resolvedModel = mappedModel
|
||||
usedMapping = true
|
||||
|
||||
// Get providers for the mapped model
|
||||
providers = util.GetProviderName(mappedModel)
|
||||
|
||||
// Continue to handler with remapped model
|
||||
goto handleRequest
|
||||
}
|
||||
}
|
||||
|
||||
// No mapping found - check if we have a proxy for fallback
|
||||
proxy := fh.getProxy()
|
||||
if proxy != nil {
|
||||
// Fallback to ampcode.com
|
||||
log.Infof("amp fallback: model %s has no configured provider, forwarding to ampcode.com", modelName)
|
||||
// Log: Forwarding to ampcode.com (uses Amp credits)
|
||||
logAmpRouting(RouteTypeAmpCredits, modelName, "", "", requestPath)
|
||||
|
||||
// Restore body again for the proxy
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
@@ -71,7 +166,23 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
}
|
||||
|
||||
// No proxy available, let the normal handler return the error
|
||||
log.Debugf("amp fallback: model %s has no configured provider and no proxy available", modelName)
|
||||
logAmpRouting(RouteTypeNoProvider, modelName, "", "", requestPath)
|
||||
}
|
||||
|
||||
handleRequest:
|
||||
|
||||
// Log the routing decision
|
||||
providerName := ""
|
||||
if len(providers) > 0 {
|
||||
providerName = providers[0]
|
||||
}
|
||||
|
||||
if usedMapping {
|
||||
// Log: Model was mapped to another model
|
||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||
} else if len(providers) > 0 {
|
||||
// Log: Using local provider (free)
|
||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||
}
|
||||
|
||||
// Providers available or no proxy for fallback, restore body and use normal handler
|
||||
@@ -91,6 +202,27 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
}
|
||||
}
|
||||
|
||||
// rewriteModelInBody replaces the model name in a JSON request body
|
||||
func rewriteModelInBody(body []byte, newModel string) []byte {
|
||||
var payload map[string]interface{}
|
||||
if err := json.Unmarshal(body, &payload); err != nil {
|
||||
log.Warnf("amp model mapping: failed to parse body for rewrite: %v", err)
|
||||
return body
|
||||
}
|
||||
|
||||
if _, exists := payload["model"]; exists {
|
||||
payload["model"] = newModel
|
||||
newBody, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
log.Warnf("amp model mapping: failed to marshal rewritten body: %v", err)
|
||||
return body
|
||||
}
|
||||
return newBody
|
||||
}
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
// extractModelFromRequest attempts to extract the model name from various request formats
|
||||
func extractModelFromRequest(body []byte, c *gin.Context) string {
|
||||
// First try to parse from JSON body (OpenAI, Claude, etc.)
|
||||
|
||||
113
internal/api/modules/amp/model_mapping.go
Normal file
113
internal/api/modules/amp/model_mapping.go
Normal file
@@ -0,0 +1,113 @@
|
||||
// Package amp provides model mapping functionality for routing Amp CLI requests
|
||||
// to alternative models when the requested model is not available locally.
|
||||
package amp
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ModelMapper provides model name mapping/aliasing for Amp CLI requests.
|
||||
// When an Amp request comes in for a model that isn't available locally,
|
||||
// this mapper can redirect it to an alternative model that IS available.
|
||||
type ModelMapper interface {
|
||||
// MapModel returns the target model name if a mapping exists and the target
|
||||
// model has available providers. Returns empty string if no mapping applies.
|
||||
MapModel(requestedModel string) string
|
||||
|
||||
// UpdateMappings refreshes the mapping configuration (for hot-reload).
|
||||
UpdateMappings(mappings []config.AmpModelMapping)
|
||||
}
|
||||
|
||||
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
|
||||
type DefaultModelMapper struct {
|
||||
mu sync.RWMutex
|
||||
mappings map[string]string // from -> to (normalized lowercase keys)
|
||||
}
|
||||
|
||||
// NewModelMapper creates a new model mapper with the given initial mappings.
|
||||
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
m := &DefaultModelMapper{
|
||||
mappings: make(map[string]string),
|
||||
}
|
||||
m.UpdateMappings(mappings)
|
||||
return m
|
||||
}
|
||||
|
||||
// MapModel checks if a mapping exists for the requested model and if the
|
||||
// target model has available local providers. Returns the mapped model name
|
||||
// or empty string if no valid mapping exists.
|
||||
func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
if requestedModel == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
// Normalize the requested model for lookup
|
||||
normalizedRequest := strings.ToLower(strings.TrimSpace(requestedModel))
|
||||
|
||||
// Check for direct mapping
|
||||
targetModel, exists := m.mappings[normalizedRequest]
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Verify target model has available providers
|
||||
providers := util.GetProviderName(targetModel)
|
||||
if len(providers) == 0 {
|
||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||
return ""
|
||||
}
|
||||
|
||||
// Note: Detailed routing log is handled by logAmpRouting in fallback_handlers.go
|
||||
log.Debugf("amp model mapping: resolved %s -> %s", requestedModel, targetModel)
|
||||
return targetModel
|
||||
}
|
||||
|
||||
// UpdateMappings refreshes the mapping configuration from config.
|
||||
// This is called during initialization and on config hot-reload.
|
||||
func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
|
||||
// Clear and rebuild mappings
|
||||
m.mappings = make(map[string]string, len(mappings))
|
||||
|
||||
for _, mapping := range mappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
|
||||
if from == "" || to == "" {
|
||||
log.Warnf("amp model mapping: skipping invalid mapping (from=%q, to=%q)", from, to)
|
||||
continue
|
||||
}
|
||||
|
||||
// Store with normalized lowercase key for case-insensitive lookup
|
||||
normalizedFrom := strings.ToLower(from)
|
||||
m.mappings[normalizedFrom] = to
|
||||
|
||||
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
||||
}
|
||||
|
||||
if len(m.mappings) > 0 {
|
||||
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
|
||||
}
|
||||
}
|
||||
|
||||
// GetMappings returns a copy of current mappings (for debugging/status).
|
||||
func (m *DefaultModelMapper) GetMappings() map[string]string {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
|
||||
result := make(map[string]string, len(m.mappings))
|
||||
for k, v := range m.mappings {
|
||||
result[k] = v
|
||||
}
|
||||
return result
|
||||
}
|
||||
186
internal/api/modules/amp/model_mapping_test.go
Normal file
186
internal/api/modules/amp/model_mapping_test.go
Normal file
@@ -0,0 +1,186 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func TestNewModelMapper(t *testing.T) {
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
||||
{From: "gpt-5", To: "gemini-2.5-pro"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
if mapper == nil {
|
||||
t.Fatal("Expected non-nil mapper")
|
||||
}
|
||||
|
||||
result := mapper.GetMappings()
|
||||
if len(result) != 2 {
|
||||
t.Errorf("Expected 2 mappings, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewModelMapper_Empty(t *testing.T) {
|
||||
mapper := NewModelMapper(nil)
|
||||
if mapper == nil {
|
||||
t.Fatal("Expected non-nil mapper")
|
||||
}
|
||||
|
||||
result := mapper.GetMappings()
|
||||
if len(result) != 0 {
|
||||
t.Errorf("Expected 0 mappings, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_NoProvider(t *testing.T) {
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Without a registered provider for the target, mapping should return empty
|
||||
result := mapper.MapModel("claude-opus-4.5")
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty result when target has no provider, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_WithProvider(t *testing.T) {
|
||||
// Register a mock provider for the target model
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// With a registered provider, mapping should work
|
||||
result := mapper.MapModel("claude-opus-4.5")
|
||||
if result != "claude-sonnet-4" {
|
||||
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client2")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "Claude-Opus-4.5", To: "claude-sonnet-4"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Should match case-insensitively
|
||||
result := mapper.MapModel("claude-opus-4.5")
|
||||
if result != "claude-sonnet-4" {
|
||||
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_NotFound(t *testing.T) {
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Unknown model should return empty
|
||||
result := mapper.MapModel("unknown-model")
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty for unknown model, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_EmptyInput(t *testing.T) {
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "claude-opus-4.5", To: "claude-sonnet-4"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("")
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty for empty input, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_UpdateMappings(t *testing.T) {
|
||||
mapper := NewModelMapper(nil)
|
||||
|
||||
// Initially empty
|
||||
if len(mapper.GetMappings()) != 0 {
|
||||
t.Error("Expected 0 initial mappings")
|
||||
}
|
||||
|
||||
// Update with new mappings
|
||||
mapper.UpdateMappings([]config.AmpModelMapping{
|
||||
{From: "model-a", To: "model-b"},
|
||||
{From: "model-c", To: "model-d"},
|
||||
})
|
||||
|
||||
result := mapper.GetMappings()
|
||||
if len(result) != 2 {
|
||||
t.Errorf("Expected 2 mappings after update, got %d", len(result))
|
||||
}
|
||||
|
||||
// Update again should replace, not append
|
||||
mapper.UpdateMappings([]config.AmpModelMapping{
|
||||
{From: "model-x", To: "model-y"},
|
||||
})
|
||||
|
||||
result = mapper.GetMappings()
|
||||
if len(result) != 1 {
|
||||
t.Errorf("Expected 1 mapping after second update, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_UpdateMappings_SkipsInvalid(t *testing.T) {
|
||||
mapper := NewModelMapper(nil)
|
||||
|
||||
mapper.UpdateMappings([]config.AmpModelMapping{
|
||||
{From: "", To: "model-b"}, // Invalid: empty from
|
||||
{From: "model-a", To: ""}, // Invalid: empty to
|
||||
{From: " ", To: "model-b"}, // Invalid: whitespace from
|
||||
{From: "model-c", To: "model-d"}, // Valid
|
||||
})
|
||||
|
||||
result := mapper.GetMappings()
|
||||
if len(result) != 1 {
|
||||
t.Errorf("Expected 1 valid mapping, got %d", len(result))
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "model-a", To: "model-b"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Get mappings and modify the returned map
|
||||
result := mapper.GetMappings()
|
||||
result["new-key"] = "new-value"
|
||||
|
||||
// Original should be unchanged
|
||||
original := mapper.GetMappings()
|
||||
if len(original) != 1 {
|
||||
t.Errorf("Expected original to have 1 mapping, got %d", len(original))
|
||||
}
|
||||
if _, exists := original["new-key"]; exists {
|
||||
t.Error("Original map was modified")
|
||||
}
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Peek at first 2 bytes to detect gzip magic bytes
|
||||
header := make([]byte, 2)
|
||||
n, _ := io.ReadFull(originalBody, header)
|
||||
|
||||
|
||||
// Check for gzip magic bytes (0x1f 0x8b)
|
||||
// If n < 2, we didn't get enough bytes, so it's not gzip
|
||||
if n >= 2 && header[0] == 0x1f && header[1] == 0x8b {
|
||||
@@ -97,7 +97,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// Reconstruct complete gzipped data
|
||||
gzippedData := append(header[:n], rest...)
|
||||
|
||||
@@ -129,8 +129,8 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
resp.ContentLength = int64(len(decompressed))
|
||||
|
||||
// Update headers to reflect decompressed state
|
||||
resp.Header.Del("Content-Encoding") // No longer compressed
|
||||
resp.Header.Del("Content-Length") // Remove stale compressed length
|
||||
resp.Header.Del("Content-Encoding") // No longer compressed
|
||||
resp.Header.Del("Content-Length") // Remove stale compressed length
|
||||
resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) // Set decompressed length
|
||||
|
||||
log.Debugf("amp proxy: decompressed gzip response (%d -> %d bytes)", len(gzippedData), len(decompressed))
|
||||
|
||||
@@ -440,52 +440,52 @@ func TestIsStreamingResponse(t *testing.T) {
|
||||
|
||||
func TestFilterBetaFeatures(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
name string
|
||||
header string
|
||||
featureToRemove string
|
||||
expected string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Remove context-1m from middle",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20",
|
||||
name: "Remove context-1m from middle",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
},
|
||||
{
|
||||
name: "Remove context-1m from start",
|
||||
header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14",
|
||||
name: "Remove context-1m from start",
|
||||
header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "Remove context-1m from end",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07",
|
||||
name: "Remove context-1m from end",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "Feature not present",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
name: "Feature not present",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
},
|
||||
{
|
||||
name: "Only feature to remove",
|
||||
header: "context-1m-2025-08-07",
|
||||
name: "Only feature to remove",
|
||||
header: "context-1m-2025-08-07",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Empty header",
|
||||
header: "",
|
||||
name: "Empty header",
|
||||
header: "",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Header with spaces",
|
||||
header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20",
|
||||
name: "Header with spaces",
|
||||
header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -111,6 +111,14 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
ampAPI.Any("/otel", proxyHandler)
|
||||
ampAPI.Any("/otel/*path", proxyHandler)
|
||||
|
||||
// Root-level routes that AMP CLI expects without /api prefix
|
||||
// These need the same security middleware as the /api/* routes
|
||||
rootMiddleware := []gin.HandlerFunc{noCORSMiddleware()}
|
||||
if restrictToLocalhost {
|
||||
rootMiddleware = append(rootMiddleware, localhostOnlyMiddleware())
|
||||
}
|
||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
// Google v1beta1 passthrough with OAuth fallback
|
||||
// AMP CLI uses non-standard paths like /publishers/google/models/...
|
||||
// We bridge these to our standard Gemini handler to enable local OAuth.
|
||||
@@ -162,9 +170,10 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
|
||||
// Create fallback handler wrapper that forwards to ampcode.com when provider not found
|
||||
// Uses lazy evaluation to access proxy (which is created after routes are registered)
|
||||
fallbackHandler := NewFallbackHandler(func() *httputil.ReverseProxy {
|
||||
// Also includes model mapping support for routing unavailable models to alternatives
|
||||
fallbackHandler := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy {
|
||||
return m.proxy
|
||||
})
|
||||
}, m.modelMapper)
|
||||
|
||||
// Provider-specific routes under /api/provider/:provider
|
||||
ampProviders := engine.Group("/api/provider")
|
||||
|
||||
@@ -37,6 +37,7 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
{"/api/meta", http.MethodGet},
|
||||
{"/api/telemetry", http.MethodGet},
|
||||
{"/api/threads", http.MethodGet},
|
||||
{"/threads.rss", http.MethodGet}, // Root-level route (no /api prefix)
|
||||
{"/api/otel", http.MethodGet},
|
||||
// Google v1beta1 bridge should still proxy non-model requests (GET) and allow POST
|
||||
{"/api/provider/google/v1beta1/models", http.MethodGet},
|
||||
|
||||
@@ -150,6 +150,9 @@ type Server struct {
|
||||
// management handler
|
||||
mgmt *managementHandlers.Handler
|
||||
|
||||
// ampModule is the Amp routing module for model mapping hot-reload
|
||||
ampModule *ampmodule.AmpModule
|
||||
|
||||
// managementRoutesRegistered tracks whether the management routes have been attached to the engine.
|
||||
managementRoutesRegistered atomic.Bool
|
||||
// managementRoutesEnabled controls whether management endpoints serve real handlers.
|
||||
@@ -247,6 +250,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
// Save initial YAML snapshot
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
s.applyAccessConfig(nil, cfg)
|
||||
if authManager != nil {
|
||||
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
}
|
||||
managementasset.SetCurrentConfig(cfg)
|
||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
// Initialize management handler
|
||||
@@ -265,14 +271,14 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
s.setupRoutes()
|
||||
|
||||
// Register Amp module using V2 interface with Context
|
||||
ampModule := ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager))
|
||||
s.ampModule = ampmodule.NewLegacy(accessManager, AuthMiddleware(accessManager))
|
||||
ctx := modules.Context{
|
||||
Engine: engine,
|
||||
BaseHandler: s.handlers,
|
||||
Config: cfg,
|
||||
AuthMiddleware: AuthMiddleware(accessManager),
|
||||
}
|
||||
if err := modules.RegisterModule(ctx, ampModule); err != nil {
|
||||
if err := modules.RegisterModule(ctx, s.ampModule); err != nil {
|
||||
log.Errorf("Failed to register Amp module: %v", err)
|
||||
}
|
||||
|
||||
@@ -509,6 +515,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
|
||||
mgmt.GET("/logs", s.mgmt.GetLogs)
|
||||
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
||||
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
||||
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
||||
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
||||
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
||||
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
||||
@@ -519,6 +527,9 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
|
||||
mgmt.GET("/max-retry-interval", s.mgmt.GetMaxRetryInterval)
|
||||
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||
|
||||
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
|
||||
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
|
||||
@@ -535,6 +546,11 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat)
|
||||
mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat)
|
||||
|
||||
mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels)
|
||||
mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels)
|
||||
mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels)
|
||||
mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels)
|
||||
|
||||
mgmt.GET("/auth-files", s.mgmt.ListAuthFiles)
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
@@ -686,17 +702,33 @@ func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, cl
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins listening for and serving HTTP requests.
|
||||
// Start begins listening for and serving HTTP or HTTPS requests.
|
||||
// It's a blocking call and will only return on an unrecoverable error.
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the server fails to start
|
||||
func (s *Server) Start() error {
|
||||
log.Debugf("Starting API server on %s", s.server.Addr)
|
||||
if s == nil || s.server == nil {
|
||||
return fmt.Errorf("failed to start HTTP server: server not initialized")
|
||||
}
|
||||
|
||||
// Start the HTTP server.
|
||||
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", err)
|
||||
useTLS := s.cfg != nil && s.cfg.TLS.Enable
|
||||
if useTLS {
|
||||
cert := strings.TrimSpace(s.cfg.TLS.Cert)
|
||||
key := strings.TrimSpace(s.cfg.TLS.Key)
|
||||
if cert == "" || key == "" {
|
||||
return fmt.Errorf("failed to start HTTPS server: tls.cert or tls.key is empty")
|
||||
}
|
||||
log.Debugf("Starting API server on %s with TLS", s.server.Addr)
|
||||
if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start HTTPS server: %v", errServeTLS)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("Starting API server on %s", s.server.Addr)
|
||||
if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errServe)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -814,6 +846,9 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
|
||||
}
|
||||
}
|
||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
}
|
||||
|
||||
// Update log level dynamically when debug flag changes
|
||||
if oldCfg == nil || oldCfg.Debug != cfg.Debug {
|
||||
@@ -884,11 +919,22 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||
}
|
||||
|
||||
// Notify Amp module of config changes (for model mapping hot-reload)
|
||||
if s.ampModule != nil {
|
||||
log.Debugf("triggering amp module config update")
|
||||
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
|
||||
log.Errorf("failed to update Amp module config: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Warnf("amp module is nil, skipping config update")
|
||||
}
|
||||
|
||||
// Count client sources from configuration and auth directory
|
||||
authFiles := util.CountAuthFiles(cfg.AuthDir)
|
||||
geminiAPIKeyCount := len(cfg.GeminiKey)
|
||||
claudeAPIKeyCount := len(cfg.ClaudeKey)
|
||||
codexAPIKeyCount := len(cfg.CodexKey)
|
||||
vertexAICompatCount := len(cfg.VertexCompatAPIKey)
|
||||
openAICompatCount := 0
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
entry := cfg.OpenAICompatibility[i]
|
||||
@@ -899,13 +945,14 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
openAICompatCount += len(entry.APIKeys)
|
||||
}
|
||||
|
||||
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
||||
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)\n",
|
||||
total := authFiles + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + vertexAICompatCount + openAICompatCount
|
||||
fmt.Printf("server clients and configuration updated: %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d Vertex-compat + %d OpenAI-compat)\n",
|
||||
total,
|
||||
authFiles,
|
||||
geminiAPIKeyCount,
|
||||
claudeAPIKeyCount,
|
||||
codexAPIKeyCount,
|
||||
vertexAICompatCount,
|
||||
openAICompatCount,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -23,6 +23,9 @@ type Config struct {
|
||||
// Port is the network port on which the API server will listen.
|
||||
Port int `yaml:"port" json:"-"`
|
||||
|
||||
// TLS config controls HTTPS server settings.
|
||||
TLS TLSConfig `yaml:"tls" json:"tls"`
|
||||
|
||||
// AmpUpstreamURL defines the upstream Amp control plane used for non-provider calls.
|
||||
AmpUpstreamURL string `yaml:"amp-upstream-url" json:"amp-upstream-url"`
|
||||
|
||||
@@ -34,6 +37,12 @@ type Config struct {
|
||||
// browser attacks and remote access to management endpoints. Default: true (recommended).
|
||||
AmpRestrictManagementToLocalhost bool `yaml:"amp-restrict-management-to-localhost" json:"amp-restrict-management-to-localhost"`
|
||||
|
||||
// AmpModelMappings defines model name mappings for Amp CLI requests.
|
||||
// When Amp requests a model that isn't available locally, these mappings
|
||||
// allow routing to an alternative model that IS available.
|
||||
// Example: Map "claude-opus-4.5" -> "claude-sonnet-4" when opus isn't available.
|
||||
AmpModelMappings []AmpModelMapping `yaml:"amp-model-mappings" json:"amp-model-mappings"`
|
||||
|
||||
// AuthDir is the directory where authentication token files are stored.
|
||||
AuthDir string `yaml:"auth-dir" json:"-"`
|
||||
|
||||
@@ -61,8 +70,14 @@ type Config struct {
|
||||
// GeminiKey defines Gemini API key configurations with optional routing overrides.
|
||||
GeminiKey []GeminiKey `yaml:"gemini-api-key" json:"gemini-api-key"`
|
||||
|
||||
// VertexCompatAPIKey defines Vertex AI-compatible API key configurations for third-party providers.
|
||||
// Used for services that use Vertex AI-style paths but with simple API key authentication.
|
||||
VertexCompatAPIKey []VertexCompatKey `yaml:"vertex-api-key" json:"vertex-api-key"`
|
||||
|
||||
// RequestRetry defines the retry times when the request failed.
|
||||
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
||||
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
|
||||
MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"`
|
||||
|
||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||
@@ -78,6 +93,19 @@ type Config struct {
|
||||
|
||||
// Payload defines default and override rules for provider payload parameters.
|
||||
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
||||
|
||||
// OAuthExcludedModels defines per-provider global model exclusions applied to OAuth/file-backed auth entries.
|
||||
OAuthExcludedModels map[string][]string `yaml:"oauth-excluded-models,omitempty" json:"oauth-excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
// TLSConfig holds HTTPS server settings.
|
||||
type TLSConfig struct {
|
||||
// Enable toggles HTTPS server mode.
|
||||
Enable bool `yaml:"enable" json:"enable"`
|
||||
// Cert is the path to the TLS certificate file.
|
||||
Cert string `yaml:"cert" json:"cert"`
|
||||
// Key is the path to the TLS private key file.
|
||||
Key string `yaml:"key" json:"key"`
|
||||
}
|
||||
|
||||
// RemoteManagement holds management API configuration under 'remote-management'.
|
||||
@@ -100,6 +128,18 @@ type QuotaExceeded struct {
|
||||
SwitchPreviewModel bool `yaml:"switch-preview-model" json:"switch-preview-model"`
|
||||
}
|
||||
|
||||
// AmpModelMapping defines a model name mapping for Amp CLI requests.
|
||||
// When Amp requests a model that isn't available locally, this mapping
|
||||
// allows routing to an alternative model that IS available.
|
||||
type AmpModelMapping struct {
|
||||
// From is the model name that Amp CLI requests (e.g., "claude-opus-4.5").
|
||||
From string `yaml:"from" json:"from"`
|
||||
|
||||
// To is the target model name to route to (e.g., "claude-sonnet-4").
|
||||
// The target model must have available providers in the registry.
|
||||
To string `yaml:"to" json:"to"`
|
||||
}
|
||||
|
||||
// PayloadConfig defines default and override parameter rules applied to provider payloads.
|
||||
type PayloadConfig struct {
|
||||
// Default defines rules that only set parameters when they are missing in the payload.
|
||||
@@ -142,6 +182,9 @@ type ClaudeKey struct {
|
||||
|
||||
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
// ExcludedModels lists model IDs that should be excluded for this provider.
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
// ClaudeModel describes a mapping between an alias and the actual upstream model name.
|
||||
@@ -168,6 +211,9 @@ type CodexKey struct {
|
||||
|
||||
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
// ExcludedModels lists model IDs that should be excluded for this provider.
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
// GeminiKey represents the configuration for a Gemini API key,
|
||||
@@ -184,6 +230,9 @@ type GeminiKey struct {
|
||||
|
||||
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
// ExcludedModels lists model IDs that should be excluded for this provider.
|
||||
ExcludedModels []string `yaml:"excluded-models,omitempty" json:"excluded-models,omitempty"`
|
||||
}
|
||||
|
||||
// OpenAICompatibility represents the configuration for OpenAI API compatibility
|
||||
@@ -298,6 +347,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Sanitize Gemini API key configuration and migrate legacy entries.
|
||||
cfg.SanitizeGeminiKeys()
|
||||
|
||||
// Sanitize Vertex-compatible API keys: drop entries without base-url
|
||||
cfg.SanitizeVertexCompatKeys()
|
||||
|
||||
// Sanitize Codex keys: drop entries without base-url
|
||||
cfg.SanitizeCodexKeys()
|
||||
|
||||
@@ -307,6 +359,9 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Sanitize OpenAI compatibility providers: drop entries without base-url
|
||||
cfg.SanitizeOpenAICompatibility()
|
||||
|
||||
// Normalize OAuth provider model exclusion map.
|
||||
cfg.OAuthExcludedModels = NormalizeOAuthExcludedModels(cfg.OAuthExcludedModels)
|
||||
|
||||
// Return the populated configuration struct.
|
||||
return &cfg, nil
|
||||
}
|
||||
@@ -344,6 +399,7 @@ func (cfg *Config) SanitizeCodexKeys() {
|
||||
e := cfg.CodexKey[i]
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
e.ExcludedModels = NormalizeExcludedModels(e.ExcludedModels)
|
||||
if e.BaseURL == "" {
|
||||
continue
|
||||
}
|
||||
@@ -360,6 +416,7 @@ func (cfg *Config) SanitizeClaudeKeys() {
|
||||
for i := range cfg.ClaudeKey {
|
||||
entry := &cfg.ClaudeKey[i]
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -380,6 +437,7 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||
if _, exists := seen[entry.APIKey]; exists {
|
||||
continue
|
||||
}
|
||||
@@ -442,6 +500,55 @@ func NormalizeHeaders(headers map[string]string) map[string]string {
|
||||
return clean
|
||||
}
|
||||
|
||||
// NormalizeExcludedModels trims, lowercases, and deduplicates model exclusion patterns.
|
||||
// It preserves the order of first occurrences and drops empty entries.
|
||||
func NormalizeExcludedModels(models []string) []string {
|
||||
if len(models) == 0 {
|
||||
return nil
|
||||
}
|
||||
seen := make(map[string]struct{}, len(models))
|
||||
out := make([]string, 0, len(models))
|
||||
for _, raw := range models {
|
||||
trimmed := strings.ToLower(strings.TrimSpace(raw))
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := seen[trimmed]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
out = append(out, trimmed)
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// NormalizeOAuthExcludedModels cleans provider -> excluded models mappings by normalizing provider keys
|
||||
// and applying model exclusion normalization to each entry.
|
||||
func NormalizeOAuthExcludedModels(entries map[string][]string) map[string][]string {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string][]string, len(entries))
|
||||
for provider, models := range entries {
|
||||
key := strings.ToLower(strings.TrimSpace(provider))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
normalized := NormalizeExcludedModels(models)
|
||||
if len(normalized) == 0 {
|
||||
continue
|
||||
}
|
||||
out[key] = normalized
|
||||
}
|
||||
if len(out) == 0 {
|
||||
return nil
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// hashSecret hashes the given secret using bcrypt.
|
||||
func hashSecret(secret string) (string, error) {
|
||||
// Use default cost for simplicity.
|
||||
@@ -731,6 +838,7 @@ func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool {
|
||||
switch key {
|
||||
case "generative-language-api-key",
|
||||
"gemini-api-key",
|
||||
"vertex-api-key",
|
||||
"claude-api-key",
|
||||
"codex-api-key",
|
||||
"openai-compatibility":
|
||||
|
||||
84
internal/config/vertex_compat.go
Normal file
84
internal/config/vertex_compat.go
Normal file
@@ -0,0 +1,84 @@
|
||||
package config
|
||||
|
||||
import "strings"
|
||||
|
||||
// VertexCompatKey represents the configuration for Vertex AI-compatible API keys.
|
||||
// This supports third-party services that use Vertex AI-style endpoint paths
|
||||
// (/publishers/google/models/{model}:streamGenerateContent) but authenticate
|
||||
// with simple API keys instead of Google Cloud service account credentials.
|
||||
//
|
||||
// Example services: zenmux.ai and similar Vertex-compatible providers.
|
||||
type VertexCompatKey struct {
|
||||
// APIKey is the authentication key for accessing the Vertex-compatible API.
|
||||
// Maps to the x-goog-api-key header.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||
|
||||
// ProxyURL optionally overrides the global proxy for this API key.
|
||||
ProxyURL string `yaml:"proxy-url,omitempty" json:"proxy-url,omitempty"`
|
||||
|
||||
// Headers optionally adds extra HTTP headers for requests sent with this key.
|
||||
// Commonly used for cookies, user-agent, and other authentication headers.
|
||||
Headers map[string]string `yaml:"headers,omitempty" json:"headers,omitempty"`
|
||||
|
||||
// Models defines the model configurations including aliases for routing.
|
||||
Models []VertexCompatModel `yaml:"models,omitempty" json:"models,omitempty"`
|
||||
}
|
||||
|
||||
// VertexCompatModel represents a model configuration for Vertex compatibility,
|
||||
// including the actual model name and its alias for API routing.
|
||||
type VertexCompatModel struct {
|
||||
// Name is the actual model name used by the external provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Alias is the model name alias that clients will use to reference this model.
|
||||
Alias string `yaml:"alias" json:"alias"`
|
||||
}
|
||||
|
||||
// SanitizeVertexCompatKeys deduplicates and normalizes Vertex-compatible API key credentials.
|
||||
func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
if cfg == nil {
|
||||
return
|
||||
}
|
||||
|
||||
seen := make(map[string]struct{}, len(cfg.VertexCompatAPIKey))
|
||||
out := cfg.VertexCompatAPIKey[:0]
|
||||
for i := range cfg.VertexCompatAPIKey {
|
||||
entry := cfg.VertexCompatAPIKey[i]
|
||||
entry.APIKey = strings.TrimSpace(entry.APIKey)
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
if entry.BaseURL == "" {
|
||||
// BaseURL is required for Vertex API key entries
|
||||
continue
|
||||
}
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
|
||||
// Sanitize models: remove entries without valid alias
|
||||
sanitizedModels := make([]VertexCompatModel, 0, len(entry.Models))
|
||||
for _, model := range entry.Models {
|
||||
model.Alias = strings.TrimSpace(model.Alias)
|
||||
model.Name = strings.TrimSpace(model.Name)
|
||||
if model.Alias != "" && model.Name != "" {
|
||||
sanitizedModels = append(sanitizedModels, model)
|
||||
}
|
||||
}
|
||||
entry.Models = sanitizedModels
|
||||
|
||||
// Use API key + base URL as uniqueness key
|
||||
uniqueKey := entry.APIKey + "|" + entry.BaseURL
|
||||
if _, exists := seen[uniqueKey]; exists {
|
||||
continue
|
||||
}
|
||||
seen[uniqueKey] = struct{}{}
|
||||
out = append(out, entry)
|
||||
}
|
||||
cfg.VertexCompatAPIKey = out
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -156,17 +157,30 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error {
|
||||
if !l.enabled {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false)
|
||||
}
|
||||
|
||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force)
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
||||
if !l.enabled && !force {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure logs directory exists
|
||||
if err := l.ensureLogsDir(); err != nil {
|
||||
return fmt.Errorf("failed to create logs directory: %w", err)
|
||||
if errEnsure := l.ensureLogsDir(); errEnsure != nil {
|
||||
return fmt.Errorf("failed to create logs directory: %w", errEnsure)
|
||||
}
|
||||
|
||||
// Generate filename
|
||||
filename := l.generateFilename(url)
|
||||
if force && !l.enabled {
|
||||
filename = l.generateErrorFilename(url)
|
||||
}
|
||||
filePath := filepath.Join(l.logsDir, filename)
|
||||
|
||||
// Decompress response if needed
|
||||
@@ -184,6 +198,12 @@ func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[st
|
||||
return fmt.Errorf("failed to write log file: %w", err)
|
||||
}
|
||||
|
||||
if force && !l.enabled {
|
||||
if errCleanup := l.cleanupOldErrorLogs(); errCleanup != nil {
|
||||
log.WithError(errCleanup).Warn("failed to clean up old error logs")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -239,6 +259,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
||||
return writer, nil
|
||||
}
|
||||
|
||||
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
|
||||
func (l *FileRequestLogger) generateErrorFilename(url string) string {
|
||||
return fmt.Sprintf("error-%s", l.generateFilename(url))
|
||||
}
|
||||
|
||||
// ensureLogsDir creates the logs directory if it doesn't exist.
|
||||
//
|
||||
// Returns:
|
||||
@@ -312,6 +337,52 @@ func (l *FileRequestLogger) sanitizeForFilename(path string) string {
|
||||
return sanitized
|
||||
}
|
||||
|
||||
// cleanupOldErrorLogs keeps only the newest 10 forced error log files.
|
||||
func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
||||
entries, errRead := os.ReadDir(l.logsDir)
|
||||
if errRead != nil {
|
||||
return errRead
|
||||
}
|
||||
|
||||
type logFile struct {
|
||||
name string
|
||||
modTime time.Time
|
||||
}
|
||||
|
||||
var files []logFile
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
|
||||
continue
|
||||
}
|
||||
info, errInfo := entry.Info()
|
||||
if errInfo != nil {
|
||||
log.WithError(errInfo).Warn("failed to read error log info")
|
||||
continue
|
||||
}
|
||||
files = append(files, logFile{name: name, modTime: info.ModTime()})
|
||||
}
|
||||
|
||||
if len(files) <= 10 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].modTime.After(files[j].modTime)
|
||||
})
|
||||
|
||||
for _, file := range files[10:] {
|
||||
if errRemove := os.Remove(filepath.Join(l.logsDir, file.name)); errRemove != nil {
|
||||
log.WithError(errRemove).Warnf("failed to remove old error log: %s", file.name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatLogContent creates the complete log content for non-streaming requests.
|
||||
//
|
||||
// Parameters:
|
||||
|
||||
@@ -8,60 +8,140 @@ func GetClaudeModels() []*ModelInfo {
|
||||
return []*ModelInfo{
|
||||
|
||||
{
|
||||
ID: "claude-haiku-4-5-20251001",
|
||||
Object: "model",
|
||||
Created: 1759276800, // 2025-10-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Haiku",
|
||||
ID: "claude-haiku-4-5-20251001",
|
||||
Object: "model",
|
||||
Created: 1759276800, // 2025-10-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Haiku",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
Object: "model",
|
||||
Created: 1759104000, // 2025-09-29
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Sonnet",
|
||||
ID: "claude-sonnet-4-5-20250929",
|
||||
Object: "model",
|
||||
Created: 1759104000, // 2025-09-29
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Sonnet",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-1-20250805",
|
||||
Object: "model",
|
||||
Created: 1722945600, // 2025-08-05
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.1 Opus",
|
||||
ID: "claude-sonnet-4-5-thinking",
|
||||
Object: "model",
|
||||
Created: 1759104000, // 2025-09-29
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Sonnet Thinking",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-20250514",
|
||||
Object: "model",
|
||||
Created: 1715644800, // 2025-05-14
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4 Opus",
|
||||
ID: "claude-opus-4-5-thinking",
|
||||
Object: "model",
|
||||
Created: 1761955200, // 2025-11-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Opus Thinking",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-20250514",
|
||||
Object: "model",
|
||||
Created: 1715644800, // 2025-05-14
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4 Sonnet",
|
||||
ID: "claude-opus-4-5-thinking-low",
|
||||
Object: "model",
|
||||
Created: 1761955200, // 2025-11-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Opus Thinking Low",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "claude-3-7-sonnet-20250219",
|
||||
Object: "model",
|
||||
Created: 1708300800, // 2025-02-19
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 3.7 Sonnet",
|
||||
ID: "claude-opus-4-5-thinking-medium",
|
||||
Object: "model",
|
||||
Created: 1761955200, // 2025-11-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Opus Thinking Medium",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "claude-3-5-haiku-20241022",
|
||||
Object: "model",
|
||||
Created: 1729555200, // 2024-10-22
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 3.5 Haiku",
|
||||
ID: "claude-opus-4-5-thinking-high",
|
||||
Object: "model",
|
||||
Created: 1761955200, // 2025-11-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Opus Thinking High",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-20251101",
|
||||
Object: "model",
|
||||
Created: 1761955200, // 2025-11-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Opus",
|
||||
Description: "Premium model combining maximum intelligence with practical performance",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-1-20250805",
|
||||
Object: "model",
|
||||
Created: 1722945600, // 2025-08-05
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.1 Opus",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32000,
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-20250514",
|
||||
Object: "model",
|
||||
Created: 1715644800, // 2025-05-14
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4 Opus",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 32000,
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-20250514",
|
||||
Object: "model",
|
||||
Created: 1715644800, // 2025-05-14
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4 Sonnet",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
},
|
||||
{
|
||||
ID: "claude-3-7-sonnet-20250219",
|
||||
Object: "model",
|
||||
Created: 1708300800, // 2025-02-19
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 3.7 Sonnet",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 8192,
|
||||
},
|
||||
{
|
||||
ID: "claude-3-5-haiku-20241022",
|
||||
Object: "model",
|
||||
Created: 1729555200, // 2024-10-22
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 3.5 Haiku",
|
||||
ContextLength: 128000,
|
||||
MaxCompletionTokens: 8192,
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -129,6 +209,21 @@ func GetGeminiModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
Object: "model",
|
||||
Created: 1737158400,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-pro-image-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Pro Image Preview",
|
||||
Description: "Gemini 3 Pro Image Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -207,6 +302,7 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -858,7 +954,6 @@ func GetIFlowModels() []*ModelInfo {
|
||||
}{
|
||||
{ID: "tstars2.0", DisplayName: "TStars-2.0", Description: "iFlow TStars-2.0 multimodal assistant", Created: 1746489600},
|
||||
{ID: "qwen3-coder-plus", DisplayName: "Qwen3-Coder-Plus", Description: "Qwen3 Coder Plus code generation", Created: 1753228800},
|
||||
{ID: "qwen3-coder", DisplayName: "Qwen3-Coder-480B-A35B", Description: "Qwen3 Coder 480B A35B", Created: 1753228800},
|
||||
{ID: "qwen3-max", DisplayName: "Qwen3-Max", Description: "Qwen3 flagship model", Created: 1758672000},
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
||||
|
||||
@@ -826,7 +826,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// GetFirstAvailableModel returns the first available model for the given handler type.
|
||||
// It prioritizes models by their creation timestamp (newest first) and checks if they have
|
||||
// available clients that are not suspended or over quota.
|
||||
|
||||
@@ -308,17 +308,14 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
payload = util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
}
|
||||
payload = applyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload)
|
||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
||||
payload = applyPayloadConfig(e.cfg, req.Model, payload)
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
||||
metadataAction := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
if action, _ := req.Metadata["action"].(string); action == "countTokens" {
|
||||
|
||||
@@ -26,16 +26,18 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.3 windows/amd64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 5 * time.Minute
|
||||
streamScannerBuffer int = 20_971_520
|
||||
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
streamScannerBuffer int = 20_971_520
|
||||
)
|
||||
|
||||
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
@@ -73,43 +75,78 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt)
|
||||
if errReq != nil {
|
||||
return resp, errReq
|
||||
}
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return resp, errDo
|
||||
}
|
||||
defer func() {
|
||||
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL)
|
||||
if errReq != nil {
|
||||
err = errReq
|
||||
return resp, err
|
||||
}
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errDo
|
||||
return resp, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
err = errRead
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return resp, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
return resp, err
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// ExecuteStream handles streaming requests via the antigravity upstream.
|
||||
@@ -131,75 +168,123 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt)
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
}
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return nil, errDo
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
|
||||
if errReq != nil {
|
||||
err = errReq
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errDo
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, streamScannerBuffer)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
|
||||
// Filter usage metadata for all models
|
||||
// Only retain usage statistics in the terminal chunk
|
||||
line = FilterSSEUsageMetadata(line)
|
||||
|
||||
payload := jsonPayload(line)
|
||||
if payload == nil {
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errRead
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errRead
|
||||
return nil, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(nil, streamScannerBuffer)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
// Filter usage metadata for all models
|
||||
// Only retain usage statistics in the terminal chunk
|
||||
line = FilterSSEUsageMetadata(line)
|
||||
|
||||
payload := jsonPayload(line)
|
||||
if payload == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
}
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
||||
for i := range tail {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
} else {
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
||||
for i := range tail {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
} else {
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Refresh refreshes the OAuth token using the refresh token.
|
||||
@@ -230,54 +315,86 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
auth = updatedAuth
|
||||
}
|
||||
|
||||
modelsURL := buildBaseURL(auth) + antigravityModelsPath
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
||||
if errReq != nil {
|
||||
return nil
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
if host := resolveHost(auth); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
||||
if errReq != nil {
|
||||
return nil
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
if host := resolveHost(baseURL); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
if errRead != nil {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
return nil
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
return nil
|
||||
}
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
return nil
|
||||
now := time.Now().Unix()
|
||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
||||
for id := range result.Map() {
|
||||
id = modelName2Alias(id)
|
||||
if id != "" {
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: id,
|
||||
Name: id,
|
||||
Description: id,
|
||||
DisplayName: id,
|
||||
Version: id,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: antigravityAuthType,
|
||||
Type: antigravityAuthType,
|
||||
}
|
||||
// Add Thinking support for thinking models
|
||||
if strings.HasSuffix(id, "-thinking") || strings.Contains(id, "-thinking-") {
|
||||
modelInfo.Thinking = ®istry.ThinkingSupport{
|
||||
Min: 1024,
|
||||
Max: 100000,
|
||||
ZeroAllowed: false,
|
||||
DynamicAllowed: true,
|
||||
}
|
||||
}
|
||||
models = append(models, modelInfo)
|
||||
}
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
||||
for id := range result.Map() {
|
||||
models = append(models, ®istry.ModelInfo{
|
||||
ID: id,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: antigravityAuthType,
|
||||
Type: antigravityAuthType,
|
||||
})
|
||||
}
|
||||
return models
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||
@@ -363,12 +480,15 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt string) (*http.Request, error) {
|
||||
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) {
|
||||
if token == "" {
|
||||
return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
|
||||
base := buildBaseURL(auth)
|
||||
base := strings.TrimSuffix(baseURL, "/")
|
||||
if base == "" {
|
||||
base = buildBaseURL(auth)
|
||||
}
|
||||
path := antigravityGeneratePath
|
||||
if stream {
|
||||
path = antigravityStreamPath
|
||||
@@ -389,6 +509,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
}
|
||||
|
||||
payload = geminiToAntigravity(modelName, payload)
|
||||
payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName))
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
@@ -401,7 +522,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
} else {
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
}
|
||||
if host := resolveHost(auth); host != "" {
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
@@ -485,26 +606,13 @@ func int64Value(value any) (int64, bool) {
|
||||
}
|
||||
|
||||
func buildBaseURL(auth *cliproxyauth.Auth) string {
|
||||
if auth != nil {
|
||||
if auth.Attributes != nil {
|
||||
if v := strings.TrimSpace(auth.Attributes["base_url"]); v != "" {
|
||||
return strings.TrimSuffix(v, "/")
|
||||
}
|
||||
}
|
||||
if auth.Metadata != nil {
|
||||
if v, ok := auth.Metadata["base_url"].(string); ok {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" {
|
||||
return strings.TrimSuffix(v, "/")
|
||||
}
|
||||
}
|
||||
}
|
||||
if baseURLs := antigravityBaseURLFallbackOrder(auth); len(baseURLs) > 0 {
|
||||
return baseURLs[0]
|
||||
}
|
||||
return antigravityBaseURL
|
||||
return antigravityBaseURLAutopush
|
||||
}
|
||||
|
||||
func resolveHost(auth *cliproxyauth.Auth) string {
|
||||
base := buildBaseURL(auth)
|
||||
func resolveHost(base string) string {
|
||||
parsed, errParse := url.Parse(base)
|
||||
if errParse != nil {
|
||||
return ""
|
||||
@@ -531,6 +639,37 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string {
|
||||
return defaultAntigravityAgent
|
||||
}
|
||||
|
||||
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
|
||||
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
|
||||
return []string{base}
|
||||
}
|
||||
return []string{
|
||||
antigravityBaseURLDaily,
|
||||
antigravityBaseURLAutopush,
|
||||
// antigravityBaseURLProd,
|
||||
}
|
||||
}
|
||||
|
||||
func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if v := strings.TrimSpace(auth.Attributes["base_url"]); v != "" {
|
||||
return strings.TrimSuffix(v, "/")
|
||||
}
|
||||
}
|
||||
if auth.Metadata != nil {
|
||||
if v, ok := auth.Metadata["base_url"].(string); ok {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" {
|
||||
return strings.TrimSuffix(v, "/")
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func geminiToAntigravity(modelName string, payload []byte) []byte {
|
||||
template, _ := sjson.Set(string(payload), "model", modelName)
|
||||
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
||||
@@ -540,18 +679,27 @@ func geminiToAntigravity(modelName string, payload []byte) []byte {
|
||||
|
||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
|
||||
if !strings.HasPrefix(modelName, "gemini-3-") {
|
||||
if thinkingLevel := gjson.Get(template, "request.generationConfig.thinkingConfig.thinkingLevel"); thinkingLevel.Exists() {
|
||||
template, _ = sjson.Delete(template, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
template, _ = sjson.Set(template, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
}
|
||||
}
|
||||
|
||||
gjson.Get(template, "request.contents").ForEach(func(key, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
template, _ = sjson.Set(template, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
if strings.HasPrefix(modelName, "claude-sonnet-") {
|
||||
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
|
||||
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
|
||||
if funcDecl.Get("parametersJsonSchema").Exists() {
|
||||
template, _ = sjson.SetRaw(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters", key.Int(), funKey.Int()), funcDecl.Get("parametersJsonSchema").Raw)
|
||||
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters.$schema", key.Int(), funKey.Int()))
|
||||
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parametersJsonSchema", key.Int(), funKey.Int()))
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return []byte(template)
|
||||
}
|
||||
@@ -573,3 +721,39 @@ func generateProjectID() string {
|
||||
randomPart := strings.ToLower(uuid.NewString())[:5]
|
||||
return adj + "-" + noun + "-" + randomPart
|
||||
}
|
||||
|
||||
func modelName2Alias(modelName string) string {
|
||||
switch modelName {
|
||||
case "rev19-uic3-1p":
|
||||
return "gemini-2.5-computer-use-preview-10-2025"
|
||||
case "gemini-3-pro-image":
|
||||
return "gemini-3-pro-image-preview"
|
||||
case "gemini-3-pro-high":
|
||||
return "gemini-3-pro-preview"
|
||||
case "claude-sonnet-4-5":
|
||||
return "gemini-claude-sonnet-4-5"
|
||||
case "claude-sonnet-4-5-thinking":
|
||||
return "gemini-claude-sonnet-4-5-thinking"
|
||||
case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro":
|
||||
return ""
|
||||
default:
|
||||
return modelName
|
||||
}
|
||||
}
|
||||
|
||||
func alias2ModelName(modelName string) string {
|
||||
switch modelName {
|
||||
case "gemini-2.5-computer-use-preview-10-2025":
|
||||
return "rev19-uic3-1p"
|
||||
case "gemini-3-pro-image-preview":
|
||||
return "gemini-3-pro-image"
|
||||
case "gemini-3-pro-preview":
|
||||
return "gemini-3-pro-high"
|
||||
case "gemini-claude-sonnet-4-5":
|
||||
return "claude-sonnet-4-5"
|
||||
case "gemini-claude-sonnet-4-5-thinking":
|
||||
return "claude-sonnet-4-5-thinking"
|
||||
default:
|
||||
return modelName
|
||||
}
|
||||
}
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
claudeauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -58,18 +59,27 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
body, _ = sjson.SetBytes(body, "model", modelOverride)
|
||||
modelForUpstream = modelOverride
|
||||
}
|
||||
// Inject thinking config based on model suffix for thinking variants
|
||||
body = e.injectThinkingConfig(req.Model, body)
|
||||
|
||||
if !strings.HasPrefix(modelForUpstream, "claude-3-5-haiku") {
|
||||
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
||||
body = ensureMaxTokensForThinking(req.Model, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -154,15 +164,24 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", modelOverride)
|
||||
}
|
||||
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
|
||||
// Inject thinking config based on model suffix for thinking variants
|
||||
body = e.injectThinkingConfig(req.Model, body)
|
||||
body = checkSystemInstructions(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
// Ensure max_tokens > thinking.budget_tokens when thinking is enabled
|
||||
body = ensureMaxTokensForThinking(req.Model, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -283,15 +302,19 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(modelForUpstream, "claude-3-5-haiku") {
|
||||
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
|
||||
// Extract betas from body and convert to header (for count_tokens too)
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -383,10 +406,101 @@ func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// extractAndRemoveBetas extracts the "betas" array from the body and removes it.
|
||||
// Returns the extracted betas as a string slice and the modified body.
|
||||
func extractAndRemoveBetas(body []byte) ([]string, []byte) {
|
||||
betasResult := gjson.GetBytes(body, "betas")
|
||||
if !betasResult.Exists() {
|
||||
return nil, body
|
||||
}
|
||||
var betas []string
|
||||
if betasResult.IsArray() {
|
||||
for _, item := range betasResult.Array() {
|
||||
if s := strings.TrimSpace(item.String()); s != "" {
|
||||
betas = append(betas, s)
|
||||
}
|
||||
}
|
||||
} else if s := strings.TrimSpace(betasResult.String()); s != "" {
|
||||
betas = append(betas, s)
|
||||
}
|
||||
body, _ = sjson.DeleteBytes(body, "betas")
|
||||
return betas, body
|
||||
}
|
||||
|
||||
// injectThinkingConfig adds thinking configuration based on model name suffix
|
||||
func (e *ClaudeExecutor) injectThinkingConfig(modelName string, body []byte) []byte {
|
||||
// Only inject if thinking config is not already present
|
||||
if gjson.GetBytes(body, "thinking").Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
var budgetTokens int
|
||||
switch {
|
||||
case strings.HasSuffix(modelName, "-thinking-low"):
|
||||
budgetTokens = 1024
|
||||
case strings.HasSuffix(modelName, "-thinking-medium"):
|
||||
budgetTokens = 8192
|
||||
case strings.HasSuffix(modelName, "-thinking-high"):
|
||||
budgetTokens = 24576
|
||||
case strings.HasSuffix(modelName, "-thinking"):
|
||||
// Default thinking without suffix uses medium budget
|
||||
budgetTokens = 8192
|
||||
default:
|
||||
return body
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", budgetTokens)
|
||||
return body
|
||||
}
|
||||
|
||||
// ensureMaxTokensForThinking ensures max_tokens > thinking.budget_tokens when thinking is enabled.
|
||||
// Anthropic API requires this constraint; violating it returns a 400 error.
|
||||
// This function should be called after all thinking configuration is finalized.
|
||||
// It looks up the model's MaxCompletionTokens from the registry to use as the cap.
|
||||
func ensureMaxTokensForThinking(modelName string, body []byte) []byte {
|
||||
thinkingType := gjson.GetBytes(body, "thinking.type").String()
|
||||
if thinkingType != "enabled" {
|
||||
return body
|
||||
}
|
||||
|
||||
budgetTokens := gjson.GetBytes(body, "thinking.budget_tokens").Int()
|
||||
if budgetTokens <= 0 {
|
||||
return body
|
||||
}
|
||||
|
||||
maxTokens := gjson.GetBytes(body, "max_tokens").Int()
|
||||
|
||||
// Look up the model's max completion tokens from the registry
|
||||
maxCompletionTokens := 0
|
||||
if modelInfo := registry.GetGlobalRegistry().GetModelInfo(modelName); modelInfo != nil {
|
||||
maxCompletionTokens = modelInfo.MaxCompletionTokens
|
||||
}
|
||||
|
||||
// Fall back to budget + buffer if registry lookup fails or returns 0
|
||||
const fallbackBuffer = 4000
|
||||
requiredMaxTokens := budgetTokens + fallbackBuffer
|
||||
if maxCompletionTokens > 0 {
|
||||
requiredMaxTokens = int64(maxCompletionTokens)
|
||||
}
|
||||
|
||||
if maxTokens < requiredMaxTokens {
|
||||
body, _ = sjson.SetBytes(body, "max_tokens", requiredMaxTokens)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||
if alias == "" {
|
||||
return ""
|
||||
}
|
||||
// Hardcoded mappings for thinking models to actual Claude model names
|
||||
switch alias {
|
||||
case "claude-opus-4-5-thinking", "claude-opus-4-5-thinking-low", "claude-opus-4-5-thinking-medium", "claude-opus-4-5-thinking-high":
|
||||
return "claude-opus-4-5-20251101"
|
||||
case "claude-sonnet-4-5-thinking":
|
||||
return "claude-sonnet-4-5-20250929"
|
||||
}
|
||||
entry := e.resolveClaudeConfig(auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
@@ -530,7 +644,7 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool) {
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
|
||||
r.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
@@ -539,15 +653,30 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14"
|
||||
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
||||
baseBetas = val
|
||||
if !strings.Contains(val, "oauth") {
|
||||
val += ",oauth-2025-04-20"
|
||||
baseBetas += ",oauth-2025-04-20"
|
||||
}
|
||||
r.Header.Set("Anthropic-Beta", val)
|
||||
} else {
|
||||
r.Header.Set("Anthropic-Beta", "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14")
|
||||
}
|
||||
|
||||
// Merge extra betas from request body
|
||||
if len(extraBetas) > 0 {
|
||||
existingSet := make(map[string]bool)
|
||||
for _, b := range strings.Split(baseBetas, ",") {
|
||||
existingSet[strings.TrimSpace(b)] = true
|
||||
}
|
||||
for _, beta := range extraBetas {
|
||||
beta = strings.TrimSpace(beta)
|
||||
if beta != "" && !existingSet[beta] {
|
||||
baseBetas += "," + beta
|
||||
existingSet[beta] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
r.Header.Set("Anthropic-Beta", baseBetas)
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
@@ -590,3 +719,22 @@ func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func checkSystemInstructions(payload []byte) []byte {
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
|
||||
if system.IsArray() {
|
||||
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
|
||||
system.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
} else {
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
@@ -62,15 +62,8 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if hasOverride && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
basePayload = util.ApplyGeminiCLIThinkingConfig(basePayload, budgetOverride, includeOverride)
|
||||
}
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
||||
@@ -204,15 +197,8 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if hasOverride && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
basePayload = util.ApplyGeminiCLIThinkingConfig(basePayload, budgetOverride, includeOverride)
|
||||
}
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
||||
basePayload = applyPayloadConfigWithRoot(e.cfg, req.Model, "gemini", "request", basePayload)
|
||||
@@ -408,16 +394,9 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
|
||||
budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata)
|
||||
for _, attemptModel := range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
|
||||
if hasOverride && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
payload = util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
}
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
|
||||
@@ -79,13 +79,7 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -174,13 +168,7 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
body = applyThinkingMetadata(body, req.Metadata, req.Model)
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
@@ -288,13 +276,7 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||
}
|
||||
translatedReq = applyThinkingMetadata(translatedReq, req.Metadata, req.Model)
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
@@ -51,11 +51,238 @@ func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.A
|
||||
|
||||
// Execute handles non-streaming requests.
|
||||
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||
if errCreds != nil {
|
||||
return resp, errCreds
|
||||
// Try API key authentication first
|
||||
apiKey, baseURL := vertexAPICreds(auth)
|
||||
|
||||
// If no API key found, fall back to service account authentication
|
||||
if apiKey == "" {
|
||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||
if errCreds != nil {
|
||||
return resp, errCreds
|
||||
}
|
||||
return e.executeWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
|
||||
}
|
||||
|
||||
// Use API key authentication
|
||||
return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
||||
}
|
||||
|
||||
// ExecuteStream handles SSE streaming for Vertex.
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
// Try API key authentication first
|
||||
apiKey, baseURL := vertexAPICreds(auth)
|
||||
|
||||
// If no API key found, fall back to service account authentication
|
||||
if apiKey == "" {
|
||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||
if errCreds != nil {
|
||||
return nil, errCreds
|
||||
}
|
||||
return e.executeStreamWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
|
||||
}
|
||||
|
||||
// Use API key authentication
|
||||
return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
||||
}
|
||||
|
||||
// CountTokens calls Vertex countTokens endpoint.
|
||||
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
// Try API key authentication first
|
||||
apiKey, baseURL := vertexAPICreds(auth)
|
||||
|
||||
// If no API key found, fall back to service account authentication
|
||||
if apiKey == "" {
|
||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||
if errCreds != nil {
|
||||
return cliproxyexecutor.Response{}, errCreds
|
||||
}
|
||||
return e.countTokensWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
|
||||
}
|
||||
|
||||
// Use API key authentication
|
||||
return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
||||
}
|
||||
|
||||
// countTokensWithServiceAccount handles token counting using service account credentials.
|
||||
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||
}
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||
if errNewReq != nil {
|
||||
return cliproxyexecutor.Response{}, errNewReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
} else if errTok != nil {
|
||||
log.Errorf("vertex executor: access token error: %v", errTok)
|
||||
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translatedReq,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
}
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
}
|
||||
|
||||
// countTokensWithAPIKey handles token counting using API key credentials.
|
||||
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||
}
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||
if errNewReq != nil {
|
||||
return cliproxyexecutor.Response{}, errNewReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translatedReq,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
}
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
}
|
||||
|
||||
// Refresh is a no-op for service account based credentials.
|
||||
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// executeWithServiceAccount handles authentication using service account credentials.
|
||||
// This method contains the original service account authentication logic.
|
||||
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
@@ -149,13 +376,104 @@ func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.A
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream handles SSE streaming for Vertex.
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||
if errCreds != nil {
|
||||
return nil, errCreds
|
||||
// executeWithAPIKey handles authentication using API key credentials.
|
||||
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||
action = "countTokens"
|
||||
}
|
||||
}
|
||||
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if errNewReq != nil {
|
||||
return resp, errNewReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if apiKey != "" {
|
||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return resp, errDo
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return resp, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
@@ -266,42 +584,44 @@ func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxy
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// CountTokens calls Vertex countTokens endpoint.
|
||||
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||
if errCreds != nil {
|
||||
return cliproxyexecutor.Response{}, errCreds
|
||||
}
|
||||
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
||||
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
|
||||
// For API key auth, use simpler URL format without project/location
|
||||
if baseURL == "" {
|
||||
baseURL = "https://generativelanguage.googleapis.com"
|
||||
}
|
||||
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if errNewReq != nil {
|
||||
return cliproxyexecutor.Response{}, errNewReq
|
||||
return nil, errNewReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
} else if errTok != nil {
|
||||
log.Errorf("vertex executor: access token error: %v", errTok)
|
||||
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
||||
if apiKey != "" {
|
||||
httpReq.Header.Set("x-goog-api-key", apiKey)
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
|
||||
@@ -315,7 +635,7 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translatedReq,
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
@@ -327,38 +647,53 @@ func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyau
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
return nil, errDo
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
}
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
}
|
||||
|
||||
// Refresh is a no-op for service account based credentials.
|
||||
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
return auth, nil
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, 20_971_520)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||
@@ -401,6 +736,23 @@ func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccou
|
||||
return projectID, location, saJSON, nil
|
||||
}
|
||||
|
||||
// vertexAPICreds extracts API key and base URL from auth attributes following the claudeCreds pattern.
|
||||
func vertexAPICreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
if a == nil {
|
||||
return "", ""
|
||||
}
|
||||
if a.Attributes != nil {
|
||||
apiKey = a.Attributes["api_key"]
|
||||
baseURL = a.Attributes["base_url"]
|
||||
}
|
||||
if apiKey == "" && a.Metadata != nil {
|
||||
if v, ok := a.Metadata["access_token"].(string); ok {
|
||||
apiKey = v
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func vertexBaseURL(location string) string {
|
||||
loc := strings.TrimSpace(location)
|
||||
if loc == "" {
|
||||
|
||||
@@ -4,10 +4,42 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// applyThinkingMetadata applies thinking config from model suffix metadata (e.g., -reasoning, -thinking-N)
|
||||
// for standard Gemini format payloads. It normalizes the budget when the model supports thinking.
|
||||
func applyThinkingMetadata(payload []byte, metadata map[string]any, model string) []byte {
|
||||
budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(metadata)
|
||||
if !ok {
|
||||
return payload
|
||||
}
|
||||
if !util.ModelSupportsThinking(model) {
|
||||
return payload
|
||||
}
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
return util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
}
|
||||
|
||||
// applyThinkingMetadataCLI applies thinking config from model suffix metadata (e.g., -reasoning, -thinking-N)
|
||||
// for Gemini CLI format payloads (nested under "request"). It normalizes the budget when the model supports thinking.
|
||||
func applyThinkingMetadataCLI(payload []byte, metadata map[string]any, model string) []byte {
|
||||
budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(metadata)
|
||||
if !ok {
|
||||
return payload
|
||||
}
|
||||
if budgetOverride != nil && util.ModelSupportsThinking(model) {
|
||||
norm := util.NormalizeThinkingBudget(model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
return util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride)
|
||||
}
|
||||
|
||||
// applyPayloadConfig applies payload default and override rules from configuration
|
||||
// to the given JSON payload for the specified model.
|
||||
// Defaults only fill missing fields, while overrides always overwrite existing values.
|
||||
|
||||
@@ -37,7 +37,7 @@ func newUsageReporter(ctx context.Context, provider, model string, auth *cliprox
|
||||
}
|
||||
if auth != nil {
|
||||
reporter.authID = auth.ID
|
||||
reporter.authIndex = auth.Index
|
||||
reporter.authIndex = auth.EnsureIndex()
|
||||
}
|
||||
return reporter
|
||||
}
|
||||
|
||||
@@ -98,6 +98,20 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
|
||||
}
|
||||
}
|
||||
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
} else if part.Get("thoughtSignature").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
|
||||
}
|
||||
|
||||
|
||||
@@ -271,7 +271,15 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
// Handle non-JSON output gracefully (matches dev branch approach)
|
||||
if resp != "null" {
|
||||
parsed := gjson.Parse(resp)
|
||||
if parsed.Type == gjson.JSON {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw))
|
||||
} else {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", resp)
|
||||
}
|
||||
}
|
||||
pp++
|
||||
}
|
||||
}
|
||||
|
||||
@@ -105,14 +105,19 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
thoughtSignatureResult := partResult.Get("thoughtSignature")
|
||||
if !thoughtSignatureResult.Exists() {
|
||||
thoughtSignatureResult = partResult.Get("thought_signature")
|
||||
}
|
||||
inlineDataResult := partResult.Get("inlineData")
|
||||
if !inlineDataResult.Exists() {
|
||||
inlineDataResult = partResult.Get("inline_data")
|
||||
}
|
||||
|
||||
// Handle thoughtSignature - this is encrypted reasoning content that should not be exposed to the client
|
||||
if thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != "" {
|
||||
// Skip thoughtSignature processing - it's internal encrypted data
|
||||
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
|
||||
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
|
||||
|
||||
// Ignore encrypted thoughtSignature but keep any actual content in the same part.
|
||||
if hasThoughtSignature && !hasContentPayload {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -98,6 +98,20 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
|
||||
}
|
||||
}
|
||||
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
} else if part.Get("thoughtSignature").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
|
||||
}
|
||||
|
||||
|
||||
@@ -105,14 +105,19 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
thoughtSignatureResult := partResult.Get("thoughtSignature")
|
||||
if !thoughtSignatureResult.Exists() {
|
||||
thoughtSignatureResult = partResult.Get("thought_signature")
|
||||
}
|
||||
inlineDataResult := partResult.Get("inlineData")
|
||||
if !inlineDataResult.Exists() {
|
||||
inlineDataResult = partResult.Get("inline_data")
|
||||
}
|
||||
|
||||
// Handle thoughtSignature - this is encrypted reasoning content that should not be exposed to the client
|
||||
if thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != "" {
|
||||
// Skip thoughtSignature processing - it's internal encrypted data
|
||||
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
|
||||
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
|
||||
|
||||
// Ignore encrypted thoughtSignature but keep any actual content in the same part.
|
||||
if hasThoughtSignature && !hasContentPayload {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -46,5 +46,19 @@ func ConvertGeminiCLIRequestToGemini(_ string, inputRawJSON []byte, _ bool) []by
|
||||
}
|
||||
}
|
||||
|
||||
gjson.GetBytes(rawJSON, "contents").ForEach(func(key, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
} else if part.Get("thoughtSignature").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return common.AttachDefaultSafetySettings(rawJSON, "safetySettings")
|
||||
}
|
||||
|
||||
@@ -30,6 +30,11 @@ func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte
|
||||
if toolsResult.Exists() && toolsResult.IsArray() {
|
||||
toolResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolResults); i++ {
|
||||
if gjson.GetBytes(rawJSON, fmt.Sprintf("tools.%d.functionDeclarations", i)).Exists() {
|
||||
strJson, _ := util.RenameKey(string(rawJSON), fmt.Sprintf("tools.%d.functionDeclarations", i), fmt.Sprintf("tools.%d.function_declarations", i))
|
||||
rawJSON = []byte(strJson)
|
||||
}
|
||||
|
||||
functionDeclarationsResult := gjson.GetBytes(rawJSON, fmt.Sprintf("tools.%d.function_declarations", i))
|
||||
if functionDeclarationsResult.Exists() && functionDeclarationsResult.IsArray() {
|
||||
functionDeclarationsResults := functionDeclarationsResult.Array()
|
||||
@@ -72,7 +77,20 @@ func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte
|
||||
return true
|
||||
})
|
||||
|
||||
out = common.AttachDefaultSafetySettings(out, "safetySettings")
|
||||
gjson.GetBytes(out, "contents").ForEach(func(key, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
} else if part.Get("thoughtSignature").Exists() {
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
out = common.AttachDefaultSafetySettings(out, "safetySettings")
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -116,8 +116,11 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
thoughtSignatureResult = partResult.Get("thought_signature")
|
||||
}
|
||||
|
||||
// Skip thoughtSignature parts (encrypted reasoning not exposed downstream).
|
||||
if thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != "" {
|
||||
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
|
||||
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
|
||||
|
||||
// Skip pure thoughtSignature parts but keep any actual payload in the same part.
|
||||
if hasThoughtSignature && !hasContentPayload {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,83 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
|
||||
// Convert input messages to Gemini contents format
|
||||
if input := root.Get("input"); input.Exists() && input.IsArray() {
|
||||
input.ForEach(func(_, item gjson.Result) bool {
|
||||
items := input.Array()
|
||||
|
||||
// Normalize consecutive function calls and outputs so each call is immediately followed by its response
|
||||
normalized := make([]gjson.Result, 0, len(items))
|
||||
for i := 0; i < len(items); {
|
||||
item := items[i]
|
||||
itemType := item.Get("type").String()
|
||||
itemRole := item.Get("role").String()
|
||||
if itemType == "" && itemRole != "" {
|
||||
itemType = "message"
|
||||
}
|
||||
|
||||
if itemType == "function_call" {
|
||||
var calls []gjson.Result
|
||||
var outputs []gjson.Result
|
||||
|
||||
for i < len(items) {
|
||||
next := items[i]
|
||||
nextType := next.Get("type").String()
|
||||
nextRole := next.Get("role").String()
|
||||
if nextType == "" && nextRole != "" {
|
||||
nextType = "message"
|
||||
}
|
||||
if nextType != "function_call" {
|
||||
break
|
||||
}
|
||||
calls = append(calls, next)
|
||||
i++
|
||||
}
|
||||
|
||||
for i < len(items) {
|
||||
next := items[i]
|
||||
nextType := next.Get("type").String()
|
||||
nextRole := next.Get("role").String()
|
||||
if nextType == "" && nextRole != "" {
|
||||
nextType = "message"
|
||||
}
|
||||
if nextType != "function_call_output" {
|
||||
break
|
||||
}
|
||||
outputs = append(outputs, next)
|
||||
i++
|
||||
}
|
||||
|
||||
if len(calls) > 0 {
|
||||
outputMap := make(map[string]gjson.Result, len(outputs))
|
||||
for _, out := range outputs {
|
||||
outputMap[out.Get("call_id").String()] = out
|
||||
}
|
||||
for _, call := range calls {
|
||||
normalized = append(normalized, call)
|
||||
callID := call.Get("call_id").String()
|
||||
if resp, ok := outputMap[callID]; ok {
|
||||
normalized = append(normalized, resp)
|
||||
delete(outputMap, callID)
|
||||
}
|
||||
}
|
||||
for _, out := range outputs {
|
||||
if _, ok := outputMap[out.Get("call_id").String()]; ok {
|
||||
normalized = append(normalized, out)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if itemType == "function_call_output" {
|
||||
normalized = append(normalized, item)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
normalized = append(normalized, item)
|
||||
i++
|
||||
}
|
||||
|
||||
for _, item := range normalized {
|
||||
itemType := item.Get("type").String()
|
||||
itemRole := item.Get("role").String()
|
||||
if itemType == "" && itemRole != "" {
|
||||
@@ -59,7 +135,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
||||
}
|
||||
}
|
||||
return true
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle regular messages
|
||||
@@ -186,7 +262,8 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
case "function_call_output":
|
||||
// Handle function call outputs - convert to function message with functionResponse
|
||||
callID := item.Get("call_id").String()
|
||||
output := item.Get("output").String()
|
||||
// Use .Raw to preserve the JSON encoding (includes quotes for strings)
|
||||
outputRaw := item.Get("output").Str
|
||||
|
||||
functionContent := `{"role":"function","parts":[]}`
|
||||
functionResponse := `{"functionResponse":{"name":"","response":{}}}`
|
||||
@@ -209,18 +286,19 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
|
||||
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.name", functionName)
|
||||
|
||||
// Parse output JSON string and set as response content
|
||||
if output != "" {
|
||||
outputResult := gjson.Parse(output)
|
||||
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.response.result", outputResult.Raw)
|
||||
// Set the raw JSON output directly (preserves string encoding)
|
||||
if outputRaw != "" && outputRaw != "null" {
|
||||
output := gjson.Parse(outputRaw)
|
||||
if output.Type == gjson.JSON {
|
||||
functionResponse, _ = sjson.SetRaw(functionResponse, "functionResponse.response.result", output.Raw)
|
||||
} else {
|
||||
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.response.result", outputRaw)
|
||||
}
|
||||
}
|
||||
|
||||
functionContent, _ = sjson.SetRaw(functionContent, "parts.-1", functionResponse)
|
||||
out, _ = sjson.SetRaw(out, "contents.-1", functionContent)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
} else if input.Exists() && input.Type == gjson.String {
|
||||
// Simple string input conversion to user message
|
||||
userContent := `{"role":"user","parts":[{"text":""}]}`
|
||||
|
||||
@@ -433,12 +433,18 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
// output tokens
|
||||
if v := um.Get("candidatesTokenCount"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens", v.Int())
|
||||
} else {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens", 0)
|
||||
}
|
||||
if v := um.Get("thoughtsTokenCount"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", v.Int())
|
||||
} else {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", 0)
|
||||
}
|
||||
if v := um.Get("totalTokenCount"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.usage.total_tokens", v.Int())
|
||||
} else {
|
||||
completed, _ = sjson.Set(completed, "response.usage.total_tokens", 0)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -202,6 +202,8 @@ func ConvertOpenAIResponsesRequestToOpenAIChatCompletions(modelName string, inpu
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "medium")
|
||||
case "high":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "high")
|
||||
case "xhigh":
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "xhigh")
|
||||
default:
|
||||
out, _ = sjson.Set(out, "reasoning_effort", "auto")
|
||||
}
|
||||
|
||||
@@ -34,6 +34,15 @@ func ParseGeminiThinkingSuffix(model string) (string, *int, *bool, bool) {
|
||||
return base, &budgetValue, &include, true
|
||||
}
|
||||
|
||||
// Handle "-reasoning" suffix: enables thinking with dynamic budget (-1)
|
||||
// Maps: gemini-2.5-flash-reasoning -> gemini-2.5-flash with thinkingBudget=-1
|
||||
if strings.HasSuffix(lower, "-reasoning") {
|
||||
base := model[:len(model)-len("-reasoning")]
|
||||
budgetValue := -1 // Dynamic budget
|
||||
include := true
|
||||
return base, &budgetValue, &include, true
|
||||
}
|
||||
|
||||
idx := strings.LastIndex(lower, "-thinking-")
|
||||
if idx == -1 {
|
||||
return model, nil, nil, false
|
||||
|
||||
@@ -30,6 +30,16 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func matchProvider(provider string, targets []string) (string, bool) {
|
||||
p := strings.ToLower(strings.TrimSpace(provider))
|
||||
for _, t := range targets {
|
||||
if strings.EqualFold(p, strings.TrimSpace(t)) {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return p, false
|
||||
}
|
||||
|
||||
// storePersister captures persistence-capable token store methods used by the watcher.
|
||||
type storePersister interface {
|
||||
PersistConfig(ctx context.Context) error
|
||||
@@ -54,6 +64,7 @@ type Watcher struct {
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
currentAuths map[string]*coreauth.Auth
|
||||
runtimeAuths map[string]*coreauth.Auth
|
||||
dispatchMu sync.Mutex
|
||||
dispatchCond *sync.Cond
|
||||
pendingUpdates map[string]AuthUpdate
|
||||
@@ -151,12 +162,14 @@ func NewWatcher(configPath, authDir string, reloadCallback func(*config.Config))
|
||||
|
||||
// Start begins watching the configuration file and authentication directory
|
||||
func (w *Watcher) Start(ctx context.Context) error {
|
||||
// Watch the config file
|
||||
if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil {
|
||||
log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig)
|
||||
// Watch the config file's parent directory instead of the file itself.
|
||||
// This handles editors that use atomic save (write to temp, then rename).
|
||||
configDir := filepath.Dir(w.configPath)
|
||||
if errAddConfig := w.watcher.Add(configDir); errAddConfig != nil {
|
||||
log.Errorf("failed to watch config directory %s: %v", configDir, errAddConfig)
|
||||
return errAddConfig
|
||||
}
|
||||
log.Debugf("watching config file: %s", w.configPath)
|
||||
log.Debugf("watching config directory: %s (for file: %s)", configDir, filepath.Base(w.configPath))
|
||||
|
||||
// Watch the auth directory
|
||||
if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil {
|
||||
@@ -169,7 +182,7 @@ func (w *Watcher) Start(ctx context.Context) error {
|
||||
go w.processEvents(ctx)
|
||||
|
||||
// Perform an initial full reload based on current config and auth dir
|
||||
w.reloadClients(true)
|
||||
w.reloadClients(true, nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -221,9 +234,57 @@ func (w *Watcher) SetAuthUpdateQueue(queue chan<- AuthUpdate) {
|
||||
}
|
||||
}
|
||||
|
||||
// DispatchRuntimeAuthUpdate allows external runtime providers (e.g., websocket-driven auths)
|
||||
// to push auth updates through the same queue used by file/config watchers.
|
||||
// Returns true if the update was enqueued; false if no queue is configured.
|
||||
func (w *Watcher) DispatchRuntimeAuthUpdate(update AuthUpdate) bool {
|
||||
if w == nil {
|
||||
return false
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
if w.runtimeAuths == nil {
|
||||
w.runtimeAuths = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
switch update.Action {
|
||||
case AuthUpdateActionAdd, AuthUpdateActionModify:
|
||||
if update.Auth != nil && update.Auth.ID != "" {
|
||||
clone := update.Auth.Clone()
|
||||
w.runtimeAuths[clone.ID] = clone
|
||||
if w.currentAuths == nil {
|
||||
w.currentAuths = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
w.currentAuths[clone.ID] = clone.Clone()
|
||||
}
|
||||
case AuthUpdateActionDelete:
|
||||
id := update.ID
|
||||
if id == "" && update.Auth != nil {
|
||||
id = update.Auth.ID
|
||||
}
|
||||
if id != "" {
|
||||
delete(w.runtimeAuths, id)
|
||||
if w.currentAuths != nil {
|
||||
delete(w.currentAuths, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
if w.getAuthQueue() == nil {
|
||||
return false
|
||||
}
|
||||
w.dispatchAuthUpdates([]AuthUpdate{update})
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *Watcher) refreshAuthState() {
|
||||
auths := w.SnapshotCoreAuths()
|
||||
w.clientsMutex.Lock()
|
||||
if len(w.runtimeAuths) > 0 {
|
||||
for _, a := range w.runtimeAuths {
|
||||
if a != nil {
|
||||
auths = append(auths, a.Clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
updates := w.prepareAuthUpdatesLocked(auths)
|
||||
w.clientsMutex.Unlock()
|
||||
w.dispatchAuthUpdates(updates)
|
||||
@@ -437,6 +498,18 @@ func computeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) str
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func computeVertexCompatModelsHash(models []config.VertexCompatModel) string {
|
||||
if len(models) == 0 {
|
||||
return ""
|
||||
}
|
||||
data, err := json.Marshal(models)
|
||||
if err != nil || len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
// computeClaudeModelsHash returns a stable hash for Claude model aliases.
|
||||
func computeClaudeModelsHash(models []config.ClaudeModel) string {
|
||||
if len(models) == 0 {
|
||||
@@ -450,6 +523,142 @@ func computeClaudeModelsHash(models []config.ClaudeModel) string {
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func computeExcludedModelsHash(excluded []string) string {
|
||||
if len(excluded) == 0 {
|
||||
return ""
|
||||
}
|
||||
normalized := make([]string, 0, len(excluded))
|
||||
for _, entry := range excluded {
|
||||
if trimmed := strings.TrimSpace(entry); trimmed != "" {
|
||||
normalized = append(normalized, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return ""
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
data, err := json.Marshal(normalized)
|
||||
if err != nil || len(data) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
type excludedModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
func summarizeExcludedModels(list []string) excludedModelsSummary {
|
||||
if len(list) == 0 {
|
||||
return excludedModelsSummary{}
|
||||
}
|
||||
seen := make(map[string]struct{}, len(list))
|
||||
normalized := make([]string, 0, len(list))
|
||||
for _, entry := range list {
|
||||
if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" {
|
||||
if _, exists := seen[trimmed]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
return excludedModelsSummary{
|
||||
hash: computeExcludedModelsHash(normalized),
|
||||
count: len(normalized),
|
||||
}
|
||||
}
|
||||
|
||||
func summarizeOAuthExcludedModels(entries map[string][]string) map[string]excludedModelsSummary {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]excludedModelsSummary, len(entries))
|
||||
for k, v := range entries {
|
||||
key := strings.ToLower(strings.TrimSpace(k))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = summarizeExcludedModels(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func diffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string, []string) {
|
||||
oldSummary := summarizeOAuthExcludedModels(oldMap)
|
||||
newSummary := summarizeOAuthExcludedModels(newMap)
|
||||
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
|
||||
for k := range oldSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for k := range newSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
changes := make([]string, 0, len(keys))
|
||||
affected := make([]string, 0, len(keys))
|
||||
for key := range keys {
|
||||
oldInfo, okOld := oldSummary[key]
|
||||
newInfo, okNew := newSummary[key]
|
||||
switch {
|
||||
case okOld && !okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: removed", key))
|
||||
affected = append(affected, key)
|
||||
case !okOld && okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: added (%d entries)", key, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
case okOld && okNew && oldInfo.hash != newInfo.hash:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(changes)
|
||||
sort.Strings(affected)
|
||||
return changes, affected
|
||||
}
|
||||
|
||||
func applyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
|
||||
if auth == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
authKindKey := strings.ToLower(strings.TrimSpace(authKind))
|
||||
seen := make(map[string]struct{})
|
||||
add := func(list []string) {
|
||||
for _, entry := range list {
|
||||
if trimmed := strings.TrimSpace(entry); trimmed != "" {
|
||||
key := strings.ToLower(trimmed)
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
if authKindKey == "apikey" {
|
||||
add(perKey)
|
||||
} else if cfg.OAuthExcludedModels != nil {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
add(cfg.OAuthExcludedModels[providerKey])
|
||||
}
|
||||
combined := make([]string, 0, len(seen))
|
||||
for k := range seen {
|
||||
combined = append(combined, k)
|
||||
}
|
||||
sort.Strings(combined)
|
||||
hash := computeExcludedModelsHash(combined)
|
||||
if auth.Attributes == nil {
|
||||
auth.Attributes = make(map[string]string)
|
||||
}
|
||||
if hash != "" {
|
||||
auth.Attributes["excluded_models_hash"] = hash
|
||||
}
|
||||
if authKind != "" {
|
||||
auth.Attributes["auth_kind"] = authKind
|
||||
}
|
||||
}
|
||||
|
||||
// SetClients sets the file-based clients.
|
||||
// SetClients removed
|
||||
// SetAPIKeyClients removed
|
||||
@@ -474,11 +683,54 @@ func (w *Watcher) processEvents(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) authFileUnchanged(path string) (bool, error) {
|
||||
data, errRead := os.ReadFile(path)
|
||||
if errRead != nil {
|
||||
return false, errRead
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
curHash := hex.EncodeToString(sum[:])
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
prevHash, ok := w.lastAuthHashes[path]
|
||||
w.clientsMutex.RUnlock()
|
||||
if ok && prevHash == curHash {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (w *Watcher) isKnownAuthFile(path string) bool {
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
_, ok := w.lastAuthHashes[path]
|
||||
return ok
|
||||
}
|
||||
|
||||
// handleEvent processes individual file system events
|
||||
func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
// Filter only relevant events: config file or auth-dir JSON files.
|
||||
configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename
|
||||
isConfigEvent := event.Name == w.configPath && event.Op&configOps != 0
|
||||
// Check if this event is for our config file (handle both exact match and basename match for directory watching)
|
||||
isConfigEvent := false
|
||||
if event.Op&configOps != 0 {
|
||||
// Exact path match
|
||||
if event.Name == w.configPath {
|
||||
isConfigEvent = true
|
||||
} else {
|
||||
// Check if basename matches and it's in the config directory (for atomic save detection)
|
||||
configDir := filepath.Dir(w.configPath)
|
||||
configBase := filepath.Base(w.configPath)
|
||||
eventDir := filepath.Dir(event.Name)
|
||||
eventBase := filepath.Base(event.Name)
|
||||
if eventDir == configDir && eventBase == configBase {
|
||||
isConfigEvent = true
|
||||
}
|
||||
}
|
||||
}
|
||||
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
||||
isAuthJSON := strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") && event.Op&authOps != 0
|
||||
if !isConfigEvent && !isAuthJSON {
|
||||
@@ -497,19 +749,33 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
}
|
||||
|
||||
// Handle auth directory changes incrementally (.json only)
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
||||
// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
|
||||
// Wait briefly; if the path exists again, treat as an update instead of removal.
|
||||
time.Sleep(replaceCheckDelay)
|
||||
if _, statErr := os.Stat(event.Name); statErr == nil {
|
||||
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
w.addOrUpdateClient(event.Name)
|
||||
return
|
||||
}
|
||||
if !w.isKnownAuthFile(event.Name) {
|
||||
log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
w.removeClient(event.Name)
|
||||
return
|
||||
}
|
||||
if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
|
||||
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
w.addOrUpdateClient(event.Name)
|
||||
}
|
||||
}
|
||||
@@ -593,6 +859,11 @@ func (w *Watcher) reloadConfig() bool {
|
||||
w.config = newConfig
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
var affectedOAuthProviders []string
|
||||
if oldConfig != nil {
|
||||
_, affectedOAuthProviders = diffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels)
|
||||
}
|
||||
|
||||
// Always apply the current log level based on the latest config.
|
||||
// This ensures logrus reflects the desired level even if change detection misses.
|
||||
util.SetLogLevel(newConfig)
|
||||
@@ -618,12 +889,12 @@ func (w *Watcher) reloadConfig() bool {
|
||||
|
||||
log.Infof("config successfully reloaded, triggering client reload")
|
||||
// Reload clients with new config
|
||||
w.reloadClients(authDirChanged)
|
||||
w.reloadClients(authDirChanged, affectedOAuthProviders)
|
||||
return true
|
||||
}
|
||||
|
||||
// reloadClients performs a full scan and reload of all clients.
|
||||
func (w *Watcher) reloadClients(rescanAuth bool) {
|
||||
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string) {
|
||||
log.Debugf("starting full client load process")
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
@@ -635,12 +906,34 @@ func (w *Watcher) reloadClients(rescanAuth bool) {
|
||||
return
|
||||
}
|
||||
|
||||
if len(affectedOAuthProviders) > 0 {
|
||||
w.clientsMutex.Lock()
|
||||
if w.currentAuths != nil {
|
||||
filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
|
||||
for id, auth := range w.currentAuths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if _, match := matchProvider(provider, affectedOAuthProviders); match {
|
||||
continue
|
||||
}
|
||||
filtered[id] = auth
|
||||
}
|
||||
w.currentAuths = filtered
|
||||
log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders)
|
||||
} else {
|
||||
w.currentAuths = nil
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
}
|
||||
|
||||
// Unregister all old API key clients before creating new ones
|
||||
// no legacy clients to unregister
|
||||
|
||||
// Create new API key clients based on the new config
|
||||
geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
|
||||
totalAPIKeyClients := geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
||||
geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
|
||||
totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
||||
log.Debugf("loaded %d API key clients", totalAPIKeyClients)
|
||||
|
||||
var authFileCount int
|
||||
@@ -683,7 +976,7 @@ func (w *Watcher) reloadClients(rescanAuth bool) {
|
||||
w.clientsMutex.Unlock()
|
||||
}
|
||||
|
||||
totalNewClients := authFileCount + geminiAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
||||
totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
||||
|
||||
// Ensure consumers observe the new configuration before auth updates dispatch.
|
||||
if w.reloadCallback != nil {
|
||||
@@ -693,10 +986,11 @@ func (w *Watcher) reloadClients(rescanAuth bool) {
|
||||
|
||||
w.refreshAuthState()
|
||||
|
||||
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
|
||||
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
|
||||
totalNewClients,
|
||||
authFileCount,
|
||||
geminiAPIKeyCount,
|
||||
vertexCompatAPIKeyCount,
|
||||
claudeAPIKeyCount,
|
||||
codexAPIKeyCount,
|
||||
openAICompatCount,
|
||||
@@ -808,8 +1102,10 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
applyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
|
||||
// Claude API keys -> synthesize auths
|
||||
for i := range cfg.ClaudeKey {
|
||||
ck := cfg.ClaudeKey[i]
|
||||
@@ -841,6 +1137,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
// Codex API keys -> synthesize auths
|
||||
@@ -870,6 +1167,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
applyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
@@ -974,6 +1272,43 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process Vertex API key providers (Vertex-compatible endpoints)
|
||||
for i := range cfg.VertexCompatAPIKey {
|
||||
compat := &cfg.VertexCompatAPIKey[i]
|
||||
providerName := "vertex"
|
||||
base := strings.TrimSpace(compat.BaseURL)
|
||||
|
||||
key := strings.TrimSpace(compat.APIKey)
|
||||
proxyURL := strings.TrimSpace(compat.ProxyURL)
|
||||
idKind := fmt.Sprintf("vertex:apikey:%s", base)
|
||||
id, token := idGen.next(idKind, key, base, proxyURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:vertex-apikey[%s]", token),
|
||||
"base_url": base,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
if hash := computeVertexCompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: "vertex-apikey",
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
applyAuthExcludedModelsMeta(a, cfg, nil, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
|
||||
// Also synthesize auth entries directly from auth files (for OAuth/file-backed providers)
|
||||
entries, _ := os.ReadDir(w.authDir)
|
||||
for _, e := range entries {
|
||||
@@ -1030,8 +1365,12 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
applyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
|
||||
if provider == "gemini-cli" {
|
||||
if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||
for _, v := range virtuals {
|
||||
applyAuthExcludedModelsMeta(v, cfg, nil, "oauth")
|
||||
}
|
||||
out = append(out, a)
|
||||
out = append(out, virtuals...)
|
||||
continue
|
||||
@@ -1186,8 +1525,9 @@ func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
||||
return authFileCount
|
||||
}
|
||||
|
||||
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) {
|
||||
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
|
||||
geminiAPIKeyCount := 0
|
||||
vertexCompatAPIKeyCount := 0
|
||||
claudeAPIKeyCount := 0
|
||||
codexAPIKeyCount := 0
|
||||
openAICompatCount := 0
|
||||
@@ -1196,6 +1536,9 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) {
|
||||
// Stateless executor handles Gemini API keys; avoid constructing legacy clients.
|
||||
geminiAPIKeyCount += len(cfg.GeminiKey)
|
||||
}
|
||||
if len(cfg.VertexCompatAPIKey) > 0 {
|
||||
vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey)
|
||||
}
|
||||
if len(cfg.ClaudeKey) > 0 {
|
||||
claudeAPIKeyCount += len(cfg.ClaudeKey)
|
||||
}
|
||||
@@ -1213,7 +1556,7 @@ func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int) {
|
||||
}
|
||||
}
|
||||
}
|
||||
return geminiAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
|
||||
return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
|
||||
}
|
||||
|
||||
func diffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string {
|
||||
@@ -1378,6 +1721,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.RequestRetry != newCfg.RequestRetry {
|
||||
changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
|
||||
}
|
||||
if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
|
||||
changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
|
||||
}
|
||||
if oldCfg.ProxyURL != newCfg.ProxyURL {
|
||||
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
|
||||
}
|
||||
@@ -1420,6 +1766,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := summarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := summarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
if !reflect.DeepEqual(trimStrings(oldCfg.GlAPIKey), trimStrings(newCfg.GlAPIKey)) {
|
||||
changes = append(changes, "generative-language-api-key: values updated (legacy view, redacted)")
|
||||
@@ -1448,6 +1799,11 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := summarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := summarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1473,9 +1829,18 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := summarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := summarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if entries, _ := diffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
|
||||
// Remote management (never print the key)
|
||||
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))
|
||||
|
||||
@@ -69,6 +69,27 @@ func (h *GeminiAPIHandler) GeminiGetHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switch request.Action {
|
||||
case "gemini-3-pro-preview":
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"name": "models/gemini-3-pro-preview",
|
||||
"version": "3",
|
||||
"displayName": "Gemini 3 Pro Preview",
|
||||
"description": "Gemini 3 Pro Preview",
|
||||
"inputTokenLimit": 1048576,
|
||||
"outputTokenLimit": 65536,
|
||||
"supportedGenerationMethods": []string{
|
||||
"generateContent",
|
||||
"countTokens",
|
||||
"createCachedContent",
|
||||
"batchGenerateContent",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
)
|
||||
case "gemini-2.5-pro":
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"name": "models/gemini-2.5-pro",
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -120,11 +121,11 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
data := params[0]
|
||||
switch data.(type) {
|
||||
case []byte:
|
||||
c.Set("API_RESPONSE", data.([]byte))
|
||||
appendAPIResponse(c, data.([]byte))
|
||||
case error:
|
||||
c.Set("API_RESPONSE", []byte(data.(error).Error()))
|
||||
appendAPIResponse(c, []byte(data.(error).Error()))
|
||||
case string:
|
||||
c.Set("API_RESPONSE", []byte(data.(string)))
|
||||
appendAPIResponse(c, []byte(data.(string)))
|
||||
case bool:
|
||||
case nil:
|
||||
}
|
||||
@@ -135,6 +136,28 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
}
|
||||
}
|
||||
|
||||
// appendAPIResponse preserves any previously captured API response and appends new data.
|
||||
func appendAPIResponse(c *gin.Context, data []byte) {
|
||||
if c == nil || len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if existing, exists := c.Get("API_RESPONSE"); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||
combined := make([]byte, 0, len(existingBytes)+len(data)+1)
|
||||
combined = append(combined, existingBytes...)
|
||||
if existingBytes[len(existingBytes)-1] != '\n' {
|
||||
combined = append(combined, '\n')
|
||||
}
|
||||
combined = append(combined, data...)
|
||||
c.Set("API_RESPONSE", combined)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Set("API_RESPONSE", bytes.Clone(data))
|
||||
}
|
||||
|
||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
@@ -297,7 +320,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
|
||||
|
||||
providerName, extractedModelName, isDynamic := h.parseDynamicModel(resolvedModelName)
|
||||
|
||||
// First, normalize the model name to handle suffixes like "-thinking-128"
|
||||
|
||||
@@ -106,6 +106,10 @@ type Manager struct {
|
||||
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
|
||||
providerOffsets map[string]int
|
||||
|
||||
// Retry controls request retry behavior.
|
||||
requestRetry atomic.Int32
|
||||
maxRetryInterval atomic.Int64
|
||||
|
||||
// Optional HTTP RoundTripper provider injected by host.
|
||||
rtProvider RoundTripperProvider
|
||||
|
||||
@@ -145,6 +149,21 @@ func (m *Manager) SetRoundTripperProvider(p RoundTripperProvider) {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetRetryConfig updates retry attempts and cooldown wait interval.
|
||||
func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if retry < 0 {
|
||||
retry = 0
|
||||
}
|
||||
if maxRetryInterval < 0 {
|
||||
maxRetryInterval = 0
|
||||
}
|
||||
m.requestRetry.Store(int32(retry))
|
||||
m.maxRetryInterval.Store(maxRetryInterval.Nanoseconds())
|
||||
}
|
||||
|
||||
// RegisterExecutor registers a provider executor with the manager.
|
||||
func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
|
||||
if executor == nil {
|
||||
@@ -188,8 +207,12 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
if auth == nil || auth.ID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
m.mu.Lock()
|
||||
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == 0 {
|
||||
auth.Index = existing.Index
|
||||
auth.indexAssigned = existing.indexAssigned
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
m.mu.Unlock()
|
||||
_ = m.persist(ctx, auth)
|
||||
@@ -229,13 +252,28 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
defer m.advanceProviderCursor(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
if attempts < 1 {
|
||||
attempts = 1
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, provider := range rotated {
|
||||
resp, errExec := m.executeWithProvider(ctx, provider, req, opts)
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
||||
return m.executeWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
if errWait := waitForCooldown(ctx, wait); errWait != nil {
|
||||
return cliproxyexecutor.Response{}, errWait
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -253,13 +291,28 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
defer m.advanceProviderCursor(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
if attempts < 1 {
|
||||
attempts = 1
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, provider := range rotated {
|
||||
resp, errExec := m.executeCountWithProvider(ctx, provider, req, opts)
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
||||
return m.executeCountWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
if errWait := waitForCooldown(ctx, wait); errWait != nil {
|
||||
return cliproxyexecutor.Response{}, errWait
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -277,13 +330,28 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
defer m.advanceProviderCursor(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
if attempts < 1 {
|
||||
attempts = 1
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, provider := range rotated {
|
||||
chunks, errStream := m.executeStreamWithProvider(ctx, provider, req, opts)
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
chunks, errStream := m.executeStreamProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
return m.executeStreamWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
if errStream == nil {
|
||||
return chunks, nil
|
||||
}
|
||||
lastErr = errStream
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, rotated, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
if errWait := waitForCooldown(ctx, wait); errWait != nil {
|
||||
return nil, errWait
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
@@ -507,6 +575,123 @@ func (m *Manager) advanceProviderCursor(model string, providers []string) {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *Manager) retrySettings() (int, time.Duration) {
|
||||
if m == nil {
|
||||
return 0, 0
|
||||
}
|
||||
return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
|
||||
}
|
||||
|
||||
func (m *Manager) closestCooldownWait(providers []string, model string) (time.Duration, bool) {
|
||||
if m == nil || len(providers) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
now := time.Now()
|
||||
providerSet := make(map[string]struct{}, len(providers))
|
||||
for i := range providers {
|
||||
key := strings.TrimSpace(strings.ToLower(providers[i]))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
providerSet[key] = struct{}{}
|
||||
}
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
var (
|
||||
found bool
|
||||
minWait time.Duration
|
||||
)
|
||||
for _, auth := range m.auths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
|
||||
if _, ok := providerSet[providerKey]; !ok {
|
||||
continue
|
||||
}
|
||||
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
|
||||
if !blocked || next.IsZero() || reason == blockReasonDisabled {
|
||||
continue
|
||||
}
|
||||
wait := next.Sub(now)
|
||||
if wait < 0 {
|
||||
continue
|
||||
}
|
||||
if !found || wait < minWait {
|
||||
minWait = wait
|
||||
found = true
|
||||
}
|
||||
}
|
||||
return minWait, found
|
||||
}
|
||||
|
||||
func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
|
||||
if err == nil || attempt >= maxAttempts-1 {
|
||||
return 0, false
|
||||
}
|
||||
if maxWait <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
if status := statusCodeFromError(err); status == http.StatusOK {
|
||||
return 0, false
|
||||
}
|
||||
wait, found := m.closestCooldownWait(providers, model)
|
||||
if !found || wait > maxWait {
|
||||
return 0, false
|
||||
}
|
||||
return wait, true
|
||||
}
|
||||
|
||||
func waitForCooldown(ctx context.Context, wait time.Duration) error {
|
||||
if wait <= 0 {
|
||||
return nil
|
||||
}
|
||||
timer := time.NewTimer(wait)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (cliproxyexecutor.Response, error)) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
var lastErr error
|
||||
for _, provider := range providers {
|
||||
resp, errExec := fn(ctx, provider)
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
}
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (<-chan cliproxyexecutor.StreamChunk, error)) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
var lastErr error
|
||||
for _, provider := range providers {
|
||||
chunks, errExec := fn(ctx, provider)
|
||||
if errExec == nil {
|
||||
return chunks, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
// MarkResult records an execution result and notifies hooks.
|
||||
func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
if result.AuthID == "" {
|
||||
@@ -762,6 +947,20 @@ func cloneError(err *Error) *Error {
|
||||
}
|
||||
}
|
||||
|
||||
func statusCodeFromError(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
type statusCoder interface {
|
||||
StatusCode() int
|
||||
}
|
||||
var sc statusCoder
|
||||
if errors.As(err, &sc) && sc != nil {
|
||||
return sc.StatusCode()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func retryAfterFromError(err error) *time.Duration {
|
||||
if err == nil {
|
||||
return nil
|
||||
@@ -919,6 +1118,14 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
||||
}
|
||||
authCopy := selected.Clone()
|
||||
m.mu.RUnlock()
|
||||
if !selected.indexAssigned {
|
||||
m.mu.Lock()
|
||||
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
||||
current.EnsureIndex()
|
||||
authCopy = current.Clone()
|
||||
}
|
||||
m.mu.Unlock()
|
||||
}
|
||||
return authCopy, executor, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -29,7 +29,7 @@ func NewAPIKeyClientProvider() APIKeyClientProvider {
|
||||
type apiKeyClientProvider struct{}
|
||||
|
||||
func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error) {
|
||||
geminiCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg)
|
||||
geminiCount, vertexCompatCount, claudeCount, codexCount, openAICompat := watcher.BuildAPIKeyClients(cfg)
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -38,9 +38,10 @@ func (p *apiKeyClientProvider) Load(ctx context.Context, cfg *config.Config) (*A
|
||||
}
|
||||
}
|
||||
return &APIKeyClientResult{
|
||||
GeminiKeyCount: geminiCount,
|
||||
ClaudeKeyCount: claudeCount,
|
||||
CodexKeyCount: codexCount,
|
||||
OpenAICompatCount: openAICompat,
|
||||
GeminiKeyCount: geminiCount,
|
||||
VertexCompatKeyCount: vertexCompatCount,
|
||||
ClaudeKeyCount: claudeCount,
|
||||
CodexKeyCount: codexCount,
|
||||
OpenAICompatCount: openAICompat,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -146,6 +146,27 @@ func (s *Service) consumeAuthUpdates(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) emitAuthUpdate(ctx context.Context, update watcher.AuthUpdate) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
if s.watcher != nil && s.watcher.DispatchRuntimeAuthUpdate(update) {
|
||||
return
|
||||
}
|
||||
if s.authUpdates != nil {
|
||||
select {
|
||||
case s.authUpdates <- update:
|
||||
return
|
||||
default:
|
||||
log.Debugf("auth update queue saturated, applying inline action=%v id=%s", update.Action, update.ID)
|
||||
}
|
||||
}
|
||||
s.handleAuthUpdate(ctx, update)
|
||||
}
|
||||
|
||||
func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdate) {
|
||||
if s == nil {
|
||||
return
|
||||
@@ -220,7 +241,11 @@ func (s *Service) wsOnConnected(channelID string) {
|
||||
Metadata: map[string]any{"email": channelID}, // metadata drives logging and usage tracking
|
||||
}
|
||||
log.Infof("websocket provider connected: %s", channelID)
|
||||
s.applyCoreAuthAddOrUpdate(context.Background(), auth)
|
||||
s.emitAuthUpdate(context.Background(), watcher.AuthUpdate{
|
||||
Action: watcher.AuthUpdateActionAdd,
|
||||
ID: auth.ID,
|
||||
Auth: auth,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) wsOnDisconnected(channelID string, reason error) {
|
||||
@@ -237,7 +262,10 @@ func (s *Service) wsOnDisconnected(channelID string, reason error) {
|
||||
log.Infof("websocket provider disconnected: %s", channelID)
|
||||
}
|
||||
ctx := context.Background()
|
||||
s.applyCoreAuthRemoval(ctx, channelID)
|
||||
s.emitAuthUpdate(ctx, watcher.AuthUpdate{
|
||||
Action: watcher.AuthUpdateActionDelete,
|
||||
ID: channelID,
|
||||
})
|
||||
}
|
||||
|
||||
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
|
||||
@@ -281,6 +309,14 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) applyRetryConfig(cfg *config.Config) {
|
||||
if s == nil || s.coreManager == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
maxInterval := time.Duration(cfg.MaxRetryInterval) * time.Second
|
||||
s.coreManager.SetRetryConfig(cfg.RequestRetry, maxInterval)
|
||||
}
|
||||
|
||||
func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName string, ok bool) {
|
||||
if a == nil {
|
||||
return "", "", false
|
||||
@@ -288,7 +324,7 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName
|
||||
if len(a.Attributes) > 0 {
|
||||
providerKey = strings.TrimSpace(a.Attributes["provider_key"])
|
||||
compatName = strings.TrimSpace(a.Attributes["compat_name"])
|
||||
if providerKey != "" || compatName != "" {
|
||||
if compatName != "" {
|
||||
if providerKey == "" {
|
||||
providerKey = compatName
|
||||
}
|
||||
@@ -394,6 +430,8 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
s.applyRetryConfig(s.cfg)
|
||||
|
||||
if s.coreManager != nil {
|
||||
if errLoad := s.coreManager.Load(ctx); errLoad != nil {
|
||||
log.Warnf("failed to load auth store: %v", errLoad)
|
||||
@@ -460,7 +498,7 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
}()
|
||||
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
fmt.Println("API server started successfully")
|
||||
fmt.Printf("API server started successfully on: %d\n", s.cfg.Port)
|
||||
|
||||
if s.hooks.OnAfterStart != nil {
|
||||
s.hooks.OnAfterStart(s)
|
||||
@@ -476,6 +514,7 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
if newCfg == nil {
|
||||
return
|
||||
}
|
||||
s.applyRetryConfig(newCfg)
|
||||
if s.server != nil {
|
||||
s.server.UpdateClients(newCfg)
|
||||
}
|
||||
@@ -606,6 +645,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if a == nil || a.ID == "" {
|
||||
return
|
||||
}
|
||||
authKind := strings.ToLower(strings.TrimSpace(a.Attributes["auth_kind"]))
|
||||
if a.Attributes != nil {
|
||||
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
@@ -625,32 +665,62 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if compatDetected {
|
||||
provider = "openai-compatibility"
|
||||
}
|
||||
excluded := s.oauthExcludedModels(provider, authKind)
|
||||
var models []*ModelInfo
|
||||
switch provider {
|
||||
case "gemini":
|
||||
models = registry.GetGeminiModels()
|
||||
if entry := s.resolveConfigGeminiKey(a); entry != nil {
|
||||
if authKind == "apikey" {
|
||||
excluded = entry.ExcludedModels
|
||||
}
|
||||
}
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "vertex":
|
||||
// Vertex AI Gemini supports the same model identifiers as Gemini.
|
||||
models = registry.GetGeminiVertexModels()
|
||||
if authKind == "apikey" {
|
||||
if entry := s.resolveConfigVertexCompatKey(a); entry != nil && len(entry.Models) > 0 {
|
||||
models = buildVertexCompatConfigModels(entry)
|
||||
}
|
||||
}
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "gemini-cli":
|
||||
models = registry.GetGeminiCLIModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "aistudio":
|
||||
models = registry.GetAIStudioModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "antigravity":
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
models = executor.FetchAntigravityModels(ctx, a, s.cfg)
|
||||
cancel()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "claude":
|
||||
models = registry.GetClaudeModels()
|
||||
if entry := s.resolveConfigClaudeKey(a); entry != nil && len(entry.Models) > 0 {
|
||||
models = buildClaudeConfigModels(entry)
|
||||
if entry := s.resolveConfigClaudeKey(a); entry != nil {
|
||||
if len(entry.Models) > 0 {
|
||||
models = buildClaudeConfigModels(entry)
|
||||
}
|
||||
if authKind == "apikey" {
|
||||
excluded = entry.ExcludedModels
|
||||
}
|
||||
}
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "codex":
|
||||
models = registry.GetOpenAIModels()
|
||||
if entry := s.resolveConfigCodexKey(a); entry != nil {
|
||||
if authKind == "apikey" {
|
||||
excluded = entry.ExcludedModels
|
||||
}
|
||||
}
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "qwen":
|
||||
models = registry.GetQwenModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
case "iflow":
|
||||
models = registry.GetIFlowModels()
|
||||
models = applyExcludedModels(models, excluded)
|
||||
default:
|
||||
// Handle OpenAI-compatibility providers by name using config
|
||||
if s.cfg != nil {
|
||||
@@ -738,7 +808,10 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, models)
|
||||
return
|
||||
}
|
||||
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
}
|
||||
|
||||
func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey {
|
||||
@@ -780,6 +853,222 @@ func (s *Service) resolveConfigClaudeKey(auth *coreauth.Auth) *config.ClaudeKey
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) resolveConfigGeminiKey(auth *coreauth.Auth) *config.GeminiKey {
|
||||
if auth == nil || s.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range s.cfg.GeminiKey {
|
||||
entry := &s.cfg.GeminiKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) resolveConfigVertexCompatKey(auth *coreauth.Auth) *config.VertexCompatKey {
|
||||
if auth == nil || s.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range s.cfg.VertexCompatAPIKey {
|
||||
entry := &s.cfg.VertexCompatAPIKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
if attrKey != "" {
|
||||
for i := range s.cfg.VertexCompatAPIKey {
|
||||
entry := &s.cfg.VertexCompatAPIKey[i]
|
||||
if strings.EqualFold(strings.TrimSpace(entry.APIKey), attrKey) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) resolveConfigCodexKey(auth *coreauth.Auth) *config.CodexKey {
|
||||
if auth == nil || s.cfg == nil {
|
||||
return nil
|
||||
}
|
||||
var attrKey, attrBase string
|
||||
if auth.Attributes != nil {
|
||||
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
||||
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
||||
}
|
||||
for i := range s.cfg.CodexKey {
|
||||
entry := &s.cfg.CodexKey[i]
|
||||
cfgKey := strings.TrimSpace(entry.APIKey)
|
||||
cfgBase := strings.TrimSpace(entry.BaseURL)
|
||||
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
||||
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
continue
|
||||
}
|
||||
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
||||
return entry
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Service) oauthExcludedModels(provider, authKind string) []string {
|
||||
cfg := s.cfg
|
||||
if cfg == nil {
|
||||
return nil
|
||||
}
|
||||
authKindKey := strings.ToLower(strings.TrimSpace(authKind))
|
||||
providerKey := strings.ToLower(strings.TrimSpace(provider))
|
||||
if authKindKey == "apikey" {
|
||||
return nil
|
||||
}
|
||||
return cfg.OAuthExcludedModels[providerKey]
|
||||
}
|
||||
|
||||
func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
|
||||
if len(models) == 0 || len(excluded) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
patterns := make([]string, 0, len(excluded))
|
||||
for _, item := range excluded {
|
||||
if trimmed := strings.TrimSpace(item); trimmed != "" {
|
||||
patterns = append(patterns, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
if len(patterns) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
filtered := make([]*ModelInfo, 0, len(models))
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
modelID := strings.ToLower(strings.TrimSpace(model.ID))
|
||||
blocked := false
|
||||
for _, pattern := range patterns {
|
||||
if matchWildcard(pattern, modelID) {
|
||||
blocked = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !blocked {
|
||||
filtered = append(filtered, model)
|
||||
}
|
||||
}
|
||||
return filtered
|
||||
}
|
||||
|
||||
// matchWildcard performs case-insensitive wildcard matching where '*' matches any substring.
|
||||
func matchWildcard(pattern, value string) bool {
|
||||
if pattern == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Fast path for exact match (no wildcard present).
|
||||
if !strings.Contains(pattern, "*") {
|
||||
return pattern == value
|
||||
}
|
||||
|
||||
parts := strings.Split(pattern, "*")
|
||||
// Handle prefix.
|
||||
if prefix := parts[0]; prefix != "" {
|
||||
if !strings.HasPrefix(value, prefix) {
|
||||
return false
|
||||
}
|
||||
value = value[len(prefix):]
|
||||
}
|
||||
|
||||
// Handle suffix.
|
||||
if suffix := parts[len(parts)-1]; suffix != "" {
|
||||
if !strings.HasSuffix(value, suffix) {
|
||||
return false
|
||||
}
|
||||
value = value[:len(value)-len(suffix)]
|
||||
}
|
||||
|
||||
// Handle middle segments in order.
|
||||
for i := 1; i < len(parts)-1; i++ {
|
||||
segment := parts[i]
|
||||
if segment == "" {
|
||||
continue
|
||||
}
|
||||
idx := strings.Index(value, segment)
|
||||
if idx < 0 {
|
||||
return false
|
||||
}
|
||||
value = value[idx+len(segment):]
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func buildVertexCompatConfigModels(entry *config.VertexCompatKey) []*ModelInfo {
|
||||
if entry == nil || len(entry.Models) == 0 {
|
||||
return nil
|
||||
}
|
||||
now := time.Now().Unix()
|
||||
out := make([]*ModelInfo, 0, len(entry.Models))
|
||||
seen := make(map[string]struct{}, len(entry.Models))
|
||||
for i := range entry.Models {
|
||||
model := entry.Models[i]
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if alias == "" {
|
||||
alias = name
|
||||
}
|
||||
if alias == "" {
|
||||
continue
|
||||
}
|
||||
key := strings.ToLower(alias)
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
display := name
|
||||
if display == "" {
|
||||
display = alias
|
||||
}
|
||||
out = append(out, &ModelInfo{
|
||||
ID: alias,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: "vertex",
|
||||
Type: "vertex",
|
||||
DisplayName: display,
|
||||
})
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func buildClaudeConfigModels(entry *config.ClaudeKey) []*ModelInfo {
|
||||
if entry == nil || len(entry.Models) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -49,19 +49,21 @@ type APIKeyClientProvider interface {
|
||||
Load(ctx context.Context, cfg *config.Config) (*APIKeyClientResult, error)
|
||||
}
|
||||
|
||||
// APIKeyClientResult contains API key based clients along with type counts.
|
||||
// It provides metadata about the number of clients loaded for each provider type.
|
||||
// APIKeyClientResult is returned by APIKeyClientProvider.Load()
|
||||
type APIKeyClientResult struct {
|
||||
// GeminiKeyCount is the number of Gemini API key clients loaded.
|
||||
// GeminiKeyCount is the number of Gemini API keys loaded
|
||||
GeminiKeyCount int
|
||||
|
||||
// ClaudeKeyCount is the number of Claude API key clients loaded.
|
||||
// VertexCompatKeyCount is the number of Vertex-compatible API keys loaded
|
||||
VertexCompatKeyCount int
|
||||
|
||||
// ClaudeKeyCount is the number of Claude API keys loaded
|
||||
ClaudeKeyCount int
|
||||
|
||||
// CodexKeyCount is the number of Codex API key clients loaded.
|
||||
// CodexKeyCount is the number of Codex API keys loaded
|
||||
CodexKeyCount int
|
||||
|
||||
// OpenAICompatCount is the number of OpenAI-compatible API key clients loaded.
|
||||
// OpenAICompatCount is the number of OpenAI compatibility API keys loaded
|
||||
OpenAICompatCount int
|
||||
}
|
||||
|
||||
@@ -83,9 +85,10 @@ type WatcherWrapper struct {
|
||||
start func(ctx context.Context) error
|
||||
stop func() error
|
||||
|
||||
setConfig func(cfg *config.Config)
|
||||
snapshotAuths func() []*coreauth.Auth
|
||||
setUpdateQueue func(queue chan<- watcher.AuthUpdate)
|
||||
setConfig func(cfg *config.Config)
|
||||
snapshotAuths func() []*coreauth.Auth
|
||||
setUpdateQueue func(queue chan<- watcher.AuthUpdate)
|
||||
dispatchRuntimeUpdate func(update watcher.AuthUpdate) bool
|
||||
}
|
||||
|
||||
// Start proxies to the underlying watcher Start implementation.
|
||||
@@ -112,6 +115,16 @@ func (w *WatcherWrapper) SetConfig(cfg *config.Config) {
|
||||
w.setConfig(cfg)
|
||||
}
|
||||
|
||||
// DispatchRuntimeAuthUpdate forwards runtime auth updates (e.g., websocket providers)
|
||||
// into the watcher-managed auth update queue when available.
|
||||
// Returns true if the update was enqueued successfully.
|
||||
func (w *WatcherWrapper) DispatchRuntimeAuthUpdate(update watcher.AuthUpdate) bool {
|
||||
if w == nil || w.dispatchRuntimeUpdate == nil {
|
||||
return false
|
||||
}
|
||||
return w.dispatchRuntimeUpdate(update)
|
||||
}
|
||||
|
||||
// SetClients updates the watcher file-backed clients registry.
|
||||
// SetClients and SetAPIKeyClients removed; watcher manages its own caches
|
||||
|
||||
|
||||
@@ -28,5 +28,8 @@ func defaultWatcherFactory(configPath, authDir string, reload func(*config.Confi
|
||||
setUpdateQueue: func(queue chan<- watcher.AuthUpdate) {
|
||||
w.SetAuthUpdateQueue(queue)
|
||||
},
|
||||
dispatchRuntimeUpdate: func(update watcher.AuthUpdate) bool {
|
||||
return w.DispatchRuntimeAuthUpdate(update)
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user