Compare commits

..

60 Commits

Author SHA1 Message Date
Luis Pater
7646a2b877 Fixed: #749
fix(translators): ensure `gjson.String` content is non-empty before setting `parts` in OpenAI request logic
2025-12-28 00:54:26 +08:00
Luis Pater
62090f2568 Merge pull request #750 from router-for-me/config
fix(config): preserve original config structure and avoid default value pollution
2025-12-27 22:10:01 +08:00
Luis Pater
c281f4cbaf Fixed: #747
fix(translators): rename and integrate `usageMetadata` as `cpaUsageMetadata` in Claude processing logic
2025-12-27 22:02:11 +08:00
hkfires
09455f9e85 fix(config): make streaming keepalive and retries ints 2025-12-27 20:56:47 +08:00
hkfires
c8e72ba0dc fix(config): smart merge writes non-default new keys only 2025-12-27 20:28:54 +08:00
hkfires
375ef252ab docs(config): clarify merge mapping behavior 2025-12-27 19:30:21 +08:00
hkfires
ee552f8720 chore(config): update ignore patterns 2025-12-27 19:13:14 +08:00
hkfires
2e88c4858e fix(config): avoid adding new keys when merging 2025-12-27 19:00:47 +08:00
Luis Pater
3f50da85c1 Merge pull request #745 from router-for-me/auth
fix(auth): make provider rotation atomic
2025-12-27 13:01:22 +08:00
hkfires
8be06255f7 fix(auth): make provider rotation atomic 2025-12-27 12:56:48 +08:00
Luis Pater
72274099aa Fixed: #738
fix(translators): refine prompt token calculation by incorporating cached tokens in Claude response handling
2025-12-27 03:56:11 +08:00
Luis Pater
dcae098e23 Fixed: #736
fix(translators): handle gjson string types in Claude request processing to ensure consistent content parsing
2025-12-27 01:25:47 +08:00
Luis Pater
2eb05ec640 Merge pull request #727 from nguyenphutrong/main
docs: add Quotio to community projects
2025-12-26 11:53:09 +08:00
Luis Pater
3ce0d76aa4 feat(usage): add import/export functionality for usage statistics and enhance deduplication logic 2025-12-26 11:49:51 +08:00
Trong Nguyen
a00b79d9be docs(readme): add Quotio to community projects section 2025-12-26 10:46:05 +07:00
Luis Pater
33e53a2a56 fix(translators): ensure correct handling and output of multimodal assistant content across request handlers 2025-12-26 05:08:04 +08:00
Luis Pater
cd5b80785f Merge pull request #722 from hungthai1401/bugfix/remove-extra-args
Fixed incorrect function signature call to `NewBaseAPIHandlers`
2025-12-26 02:56:56 +08:00
Thai Nguyen Hung
54f71aa273 fix(test): remove extra argument from ExecuteStreamWithAuthManager call 2025-12-25 21:55:35 +07:00
Luis Pater
3f949b7f84 Merge pull request #704 from tinyc0der/add-index
fix(openai): add index field to image response for LiteLLM compatibility
2025-12-25 21:35:12 +08:00
Luis Pater
443c4538bb feat(config): add commercial-mode to optimize HTTP middleware for lower memory usage 2025-12-25 21:05:01 +08:00
TinyCoder
a7fc2ee4cf refactor(image): avoid using json.Marshal 2025-12-25 14:21:01 +07:00
Luis Pater
8e749ac22d docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:59:48 +08:00
Luis Pater
69e09d9bc7 docs(readme): update GLM model version from 4.6 to 4.7 in README and README_CN 2025-12-24 23:46:27 +08:00
Luis Pater
06ad527e8c Fixed: #696
fix(translators): adjust prompt token calculation by subtracting cached tokens across Gemini, OpenAI, and Claude handlers
2025-12-24 23:29:18 +08:00
Luis Pater
b7409dd2de Merge pull request #706 from router-for-me/log
Log
2025-12-24 22:24:39 +08:00
hkfires
5ba325a8fc refactor(logging): standardize request id formatting and layout 2025-12-24 22:03:07 +08:00
Luis Pater
d502840f91 Merge pull request #695 from NguyenSiTrung/main
feat: add cached token parsing for Gemini and Antigravity API responses
2025-12-24 21:58:55 +08:00
hkfires
99238a4b59 fix(logging): normalize warning level to warn 2025-12-24 21:11:37 +08:00
hkfires
6d43a2ff9a refactor(logging): inline request id in log output 2025-12-24 21:07:18 +08:00
Luis Pater
3faa1ca9af Merge pull request #700 from router-for-me/log
refactor(sdk/auth): rename manager.go to conductor.go
2025-12-24 19:36:24 +08:00
Luis Pater
9d975e0375 feat(models): add support for GLM-4.7 and MiniMax-M2.1 2025-12-24 19:30:57 +08:00
hkfires
2a6d8b78d4 feat(api): add endpoint to retrieve request logs by ID 2025-12-24 19:24:51 +08:00
TinyCoder
671558a822 fix(openai): add index field to image response for LiteLLM compatibility
LiteLLM's Pydantic model requires an index field in each image object.
Without it, responses fail validation with "images.0.index Field required".
2025-12-24 17:43:31 +07:00
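A minimal sketch of the payload shape this commit enforces; the struct and the fields besides `index` are illustrative, following the OpenAI images payload (note that commit a7fc2ee4cf above avoids json.Marshal in the actual implementation):

    type imageData struct {
        Index   int    `json:"index"` // required by LiteLLM's Pydantic model
        B64JSON string `json:"b64_json,omitempty"`
        URL     string `json:"url,omitempty"`
    }
    // Marshals to {"index":0,"b64_json":"..."}, so "images.0.index" validates.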
hkfires
26fbb77901 refactor(sdk/auth): rename manager.go to conductor.go 2025-12-24 15:21:03 +08:00
NguyenSiTrung
a277302262 Merge remote-tracking branch 'upstream/main' 2025-12-24 10:54:09 +07:00
NguyenSiTrung
969c1a5b72 refactor: extract parseGeminiFamilyUsageDetail helper to reduce duplication 2025-12-24 10:22:31 +07:00
NguyenSiTrung
872339bceb feat: add cached token parsing for Gemini API responses 2025-12-24 10:20:11 +07:00
Luis Pater
5dc0dbc7aa Merge pull request #697 from Cubence-com/main
docs(readme): add Cubence sponsor and fix PackyCode link
2025-12-24 11:19:32 +08:00
Luis Pater
2b7ba54a2f Merge pull request #688 from router-for-me/feature/request-id-tracking
feat(logging): implement request ID tracking and propagation
2025-12-24 10:54:13 +08:00
hkfires
007c3304f2 feat(logging): scope request ID tracking to AI API endpoints 2025-12-24 09:17:09 +08:00
hkfires
e76ba0ede9 feat(logging): implement request ID tracking and propagation 2025-12-24 08:32:17 +08:00
Luis Pater
c06ac07e23 Merge pull request #686 from ajkdrag/main
feat: regex support for model-mappings
2025-12-24 04:37:44 +08:00
Luis Pater
66769ec657 fix(translators): update role from tool to user in Gemini and Gemini-CLI requests 2025-12-24 04:24:07 +08:00
Luis Pater
f413feec61 refactor(handlers): streamline error and data channel handling in streaming logic
Improved consistency across OpenAI, Claude, and Gemini handlers by replacing the initial `select` statement with a `for` loop for better readability and error-handling robustness.
2025-12-24 04:07:24 +08:00
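A minimal sketch of the loop pattern described above, assuming one data channel and one error channel per stream; not the handlers' exact code:

    func forwardStream(dataCh <-chan []byte, errCh <-chan error, emit func([]byte)) error {
        for {
            select {
            case err, ok := <-errCh:
                if !ok {
                    errCh = nil // closed: a nil channel blocks, disabling this case
                    continue
                }
                if err != nil {
                    return err
                }
            case chunk, ok := <-dataCh:
                if !ok {
                    return nil // upstream finished cleanly
                }
                emit(chunk)
            }
        }
    }

Unlike a one-shot initial select, the loop keeps handling errors that arrive at any point during the stream.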
Luis Pater
2e538e3486 Merge pull request #661 from jroth1111/fix/streaming-bootstrap-forwarder
fix: improve streaming bootstrap and forwarding
2025-12-24 03:51:40 +08:00
Luis Pater
9617a7b0d6 Merge pull request #621 from dacsang97/fix/antigravity-prompt-caching
Antigravity Prompt Caching Fix
2025-12-24 03:50:25 +08:00
Luis Pater
7569320770 Merge branch 'dev' into fix/antigravity-prompt-caching 2025-12-24 03:49:46 +08:00
Fetters
8d25cf0d75 fix(readme): update PackyCode sponsorship link and remove redundant tbody 2025-12-23 23:44:40 +08:00
Fetters
64e85e7019 docs(readme): add Cubence sponsor 2025-12-23 23:30:57 +08:00
altamash
0c0aae1eac Robust change detection: replaced string concat with struct-based compare in hasModelMappingsChanged; removed boolTo01.
• Performance: pre-allocate map and regex slice capacities in UpdateMappings.
• Verified with amp module tests (all passing).
2025-12-23 18:52:28 +05:30
altamash
5dcf7cb846 feat: regex support for model-mappings 2025-12-23 18:41:58 +05:30
gwizz
5bf89dd757 fix: keep streaming defaults legacy-safe 2025-12-23 00:53:18 +11:00
gwizz
4442574e53 fix: stop streaming loop on context cancel 2025-12-23 00:37:55 +11:00
gwizz
71a6dffbb6 fix: improve streaming bootstrap and forwarding 2025-12-22 23:34:23 +11:00
Evan Nguyen
24e8e20b59 Merge branch 'main' into fix/antigravity-prompt-caching 2025-12-21 19:43:24 +07:00
Evan Nguyen
a87f09bad2 feat(antigravity): add session ID generation and mutex for random source 2025-12-21 17:50:41 +07:00
evann
bc6c4cdbfc feat(antigravity): add logging for cached token setting errors in responses 2025-12-19 16:49:50 +07:00
evann
404546ce93 refactor(antigravity): production endpoint caching 2025-12-19 16:36:54 +07:00
evann
6dd1cf1dd6 Merge branch 'main' into fix/antigravity-prompt-caching 2025-12-19 16:34:28 +07:00
evann
9058d406a3 feat(antigravity): enhance prompt caching support and update agent version 2025-12-19 16:33:41 +07:00
45 changed files with 1783 additions and 493 deletions

View File

@@ -13,8 +13,6 @@ Dockerfile
docs/*
README.md
README_CN.md
MANAGEMENT_API.md
MANAGEMENT_API_CN.md
LICENSE
# Runtime data folders (should be mounted as volumes)
@@ -32,3 +30,4 @@ bin/*
.agent/*
.bmad/*
_bmad/*
_bmad-output/*

.gitignore (vendored): 7 lines changed
View File

@@ -11,11 +11,15 @@ bin/*
logs/*
conv/*
temp/*
refs/*
# Storage backends
pgstore/*
gitstore/*
objectstore/*
# Static assets
static/*
refs/*
# Authentication data
auths/*
@@ -35,6 +39,7 @@ GEMINI.md
.agent/*
.bmad/*
_bmad/*
_bmad-output/*
# macOS
.DS_Store

View File

@@ -10,11 +10,11 @@ So you can use local or multi-account CLI access with OpenAI(include Responses)/
## Sponsor
[![z.ai](https://assets.router-for.me/english.png)](https://z.ai/subscribe?ic=8JVLJQFSKB)
[![z.ai](https://assets.router-for.me/english-4.7.png)](https://z.ai/subscribe?ic=8JVLJQFSKB)
This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN.
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.6 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.7 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
Get 10% OFF GLM CODING PLAN: https://z.ai/subscribe?ic=8JVLJQFSKB
@@ -26,6 +26,10 @@ Get 10% OFF GLM CODING PLANhttps://z.ai/subscribe?ic=8JVLJQFSKB
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using <a href="https://www.packyapi.com/register?aff=cliproxyapi">this link</a> and enter the "cliproxyapi" promo code during recharge to get 10% off.</td>
</tr>
<tr>
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
<td>Thanks to Cubence for sponsoring this project! Cubence is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. Cubence provides special discounts for our software users: register using <a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">this link</a> and enter the "CLIPROXYAPI" promo code during recharge to get 10% off.</td>
</tr>
</tbody>
</table>
@@ -110,6 +114,10 @@ CLI wrapper for instant switching between multiple Claude accounts and alternati
Native macOS GUI for managing CLIProxyAPI: configure providers, model mappings, and endpoints via OAuth - no API keys needed.
### [Quotio](https://github.com/nguyenphutrong/quotio)
Native macOS menu bar app that unifies Claude, Gemini, OpenAI, Qwen, and Antigravity subscriptions with real-time quota tracking and smart auto-failover for AI coding tools like Claude Code, OpenCode, and Droid - no API keys needed.
> [!NOTE]
> If you developed a project based on CLIProxyAPI, please open a PR to add it to this list.

View File

@@ -10,11 +10,11 @@
## 赞助商
[![bigmodel.cn](https://assets.router-for.me/chinese.png)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
[![bigmodel.cn](https://assets.router-for.me/chinese-4.7.png)](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。
GLM CODING PLAN 是专为AI编码打造的订阅套餐每月最低仅需20元即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.6,为开发者提供顶尖的编码体验。
GLM CODING PLAN 是专为AI编码打造的订阅套餐每月最低仅需20元即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7,为开发者提供顶尖的编码体验。
智谱AI为本软件提供了特别优惠使用以下链接购买可以享受九折优惠https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
@@ -26,9 +26,14 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐每月最低仅需20元
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
<td>感谢 PackyCode 对本项目的赞助PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用<a href="https://www.packyapi.com/register?aff=cliproxyapi">此链接</a>注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。</td>
</tr>
<tr>
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
<td>感谢 Cubence 对本项目的赞助Cubence 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。Cubence 为本软件用户提供了特别优惠:使用<a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">此链接</a>注册,并在充值时输入 "CLIPROXYAPI" 优惠码即可享受九折优惠。</td>
</tr>
</tbody>
</table>
## 功能特性
- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点
@@ -108,6 +113,10 @@ CLI 封装器,用于通过 CLIProxyAPI OAuth 即时切换多个 Claude 账户
基于 macOS 平台的原生 CLIProxyAPI GUI配置供应商、模型映射以及OAuth端点无需 API 密钥。
### [Quotio](https://github.com/nguyenphutrong/quotio)
原生 macOS 菜单栏应用,统一管理 Claude、Gemini、OpenAI、Qwen 和 Antigravity 订阅,提供实时配额追踪和智能自动故障转移,支持 Claude Code、OpenCode 和 Droid 等 AI 编程工具,无需 API 密钥。
> [!NOTE]
> 如果你开发了基于 CLIProxyAPI 的项目,请提交一个 PR拉取请求将其添加到此列表中。

assets/cubence.png (new binary file, 51 KiB; not shown)

View File

@@ -39,6 +39,9 @@ api-keys:
# Enable debug logging
debug: false
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
commercial-mode: false
# When true, write application logs to rotating files instead of stdout
logging-to-file: false
@@ -73,6 +76,11 @@ routing:
# When true, enable authentication for the WebSocket API (/v1/ws).
ws-auth: false
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
# streaming:
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
# Gemini API keys
# gemini-api-key:
# - api-key: "AIzaSy...01"

View File

@@ -209,6 +209,94 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{"files": files})
}
// GetRequestLogByID finds and downloads a request log file by its request ID.
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
func (h *Handler) GetRequestLogByID(c *gin.Context) {
if h == nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
return
}
if h.cfg == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
return
}
dir := h.logDirectory()
if strings.TrimSpace(dir) == "" {
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
return
}
requestID := strings.TrimSpace(c.Param("id"))
if requestID == "" {
requestID = strings.TrimSpace(c.Query("id"))
}
if requestID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
return
}
if strings.ContainsAny(requestID, "/\\") {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
return
}
entries, err := os.ReadDir(dir)
if err != nil {
if os.IsNotExist(err) {
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
return
}
suffix := "-" + requestID + ".log"
var matchedFile string
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if strings.HasSuffix(name, suffix) {
matchedFile = name
break
}
}
if matchedFile == "" {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
return
}
dirAbs, errAbs := filepath.Abs(dir)
if errAbs != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
return
}
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
prefix := dirAbs + string(os.PathSeparator)
if !strings.HasPrefix(fullPath, prefix) {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
return
}
info, errStat := os.Stat(fullPath)
if errStat != nil {
if os.IsNotExist(errStat) {
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
return
}
if info.IsDir() {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
return
}
c.FileAttachment(fullPath, matchedFile)
}
// DownloadRequestErrorLog downloads a specific error request log file by name.
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
if h == nil {

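A hedged sketch of calling the new handler once it is routed (it is registered below as request-log-by-id/:id). The /v0/management prefix, port, and bearer-style key are assumptions; adjust for your deployment:

    req, _ := http.NewRequest("GET",
        "http://127.0.0.1:8317/v0/management/request-log-by-id/a1b2c3d4", nil)
    req.Header.Set("Authorization", "Bearer <management-key>") // assumed auth scheme
    resp, err := http.DefaultClient.Do(req)
    if err != nil {
        log.Fatal(err)
    }
    defer resp.Body.Close()
    out, _ := os.Create("a1b2c3d4.log") // handler replies with a *-a1b2c3d4.log attachment
    defer out.Close()
    io.Copy(out, resp.Body) // imports: io, log, net/http, os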
View File

@@ -1,12 +1,25 @@
package management
import (
"encoding/json"
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
)
type usageExportPayload struct {
Version int `json:"version"`
ExportedAt time.Time `json:"exported_at"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
type usageImportPayload struct {
Version int `json:"version"`
Usage usage.StatisticsSnapshot `json:"usage"`
}
// GetUsageStatistics returns the in-memory request statistics snapshot.
func (h *Handler) GetUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
@@ -18,3 +31,49 @@ func (h *Handler) GetUsageStatistics(c *gin.Context) {
"failed_requests": snapshot.FailureCount,
})
}
// ExportUsageStatistics returns a complete usage snapshot for backup/migration.
func (h *Handler) ExportUsageStatistics(c *gin.Context) {
var snapshot usage.StatisticsSnapshot
if h != nil && h.usageStats != nil {
snapshot = h.usageStats.Snapshot()
}
c.JSON(http.StatusOK, usageExportPayload{
Version: 1,
ExportedAt: time.Now().UTC(),
Usage: snapshot,
})
}
// ImportUsageStatistics merges a previously exported usage snapshot into memory.
func (h *Handler) ImportUsageStatistics(c *gin.Context) {
if h == nil || h.usageStats == nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "usage statistics unavailable"})
return
}
data, err := c.GetRawData()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "failed to read request body"})
return
}
var payload usageImportPayload
if err := json.Unmarshal(data, &payload); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json"})
return
}
if payload.Version != 0 && payload.Version != 1 {
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported version"})
return
}
result := h.usageStats.MergeSnapshot(payload.Usage)
snapshot := h.usageStats.Snapshot()
c.JSON(http.StatusOK, gin.H{
"added": result.Added,
"skipped": result.Skipped,
"total_requests": snapshot.TotalRequests,
"failed_requests": snapshot.FailureCount,
})
}

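A minimal round-trip sketch for the new export/import endpoints; the base path is an assumption and auth headers are omitted (real deployments sit behind the management middleware):

    base := "http://127.0.0.1:8317/v0/management" // assumed prefix
    resp, err := http.Get(base + "/usage/export")
    if err != nil {
        log.Fatal(err)
    }
    snapshot, _ := io.ReadAll(resp.Body)
    resp.Body.Close()
    // Re-import the snapshot; the handler answers with added/skipped counts
    // plus the post-merge totals.
    resp2, err := http.Post(base+"/usage/import", "application/json", bytes.NewReader(snapshot))
    if err != nil {
        log.Fatal(err)
    }
    defer resp2.Body.Close()
    body, _ := io.ReadAll(resp2.Body)
    fmt.Println(string(body)) // {"added":...,"skipped":...,"total_requests":...,"failed_requests":...}
    // imports: bytes, fmt, io, log, net/http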
View File

@@ -98,10 +98,11 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
}
return &RequestInfo{
URL: url,
Method: method,
Headers: headers,
Body: body,
URL: url,
Method: method,
Headers: headers,
Body: body,
RequestID: logging.GetGinRequestID(c),
}, nil
}

View File

@@ -15,10 +15,11 @@ import (
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
URL string // URL is the request URL.
Method string // Method is the HTTP method (e.g., GET, POST).
Headers map[string][]string // Headers contains the request headers.
Body []byte // Body is the raw request body.
URL string // URL is the request URL.
Method string // Method is the HTTP method (e.g., GET, POST).
Headers map[string][]string // Headers contains the request headers.
Body []byte // Body is the raw request body.
RequestID string // RequestID is the unique identifier for the request.
}
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
@@ -149,6 +150,7 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
w.requestInfo.Method,
w.requestInfo.Headers,
w.requestInfo.Body,
w.requestInfo.RequestID,
)
if err == nil {
w.streamWriter = streamWriter
@@ -346,7 +348,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
}); ok {
return loggerWithOptions.LogRequestWithOptions(
w.requestInfo.URL,
@@ -360,6 +362,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
apiResponseBody,
apiResponseErrors,
forceLog,
w.requestInfo.RequestID,
)
}
@@ -374,5 +377,6 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
apiRequestBody,
apiResponseBody,
apiResponseErrors,
w.requestInfo.RequestID,
)
}

View File

@@ -279,16 +279,23 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp
return true
}
// Build map for efficient comparison
oldMap := make(map[string]string, len(old.ModelMappings))
// Build map for efficient and robust comparison
type mappingInfo struct {
to string
regex bool
}
oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
for _, mapping := range old.ModelMappings {
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To)
oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
to: strings.TrimSpace(mapping.To),
regex: mapping.Regex,
}
}
for _, mapping := range new.ModelMappings {
from := strings.TrimSpace(mapping.From)
to := strings.TrimSpace(mapping.To)
if oldTo, exists := oldMap[from]; !exists || oldTo != to {
if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
return true
}
}

View File

@@ -3,6 +3,7 @@
package amp
import (
"regexp"
"strings"
"sync"
@@ -26,13 +27,15 @@ type ModelMapper interface {
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
type DefaultModelMapper struct {
mu sync.RWMutex
mappings map[string]string // from -> to (normalized lowercase keys)
mappings map[string]string // exact: from -> to (normalized lowercase keys)
regexps []regexMapping // regex rules evaluated in order
}
// NewModelMapper creates a new model mapper with the given initial mappings.
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
m := &DefaultModelMapper{
mappings: make(map[string]string),
regexps: nil,
}
m.UpdateMappings(mappings)
return m
@@ -55,7 +58,18 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
// Check for direct mapping
targetModel, exists := m.mappings[normalizedRequest]
if !exists {
return ""
// Try regex mappings in order
base, _ := util.NormalizeThinkingModel(requestedModel)
for _, rm := range m.regexps {
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
targetModel = rm.to
exists = true
break
}
}
if !exists {
return ""
}
}
// Verify target model has available providers
@@ -78,6 +92,7 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
// Clear and rebuild mappings
m.mappings = make(map[string]string, len(mappings))
m.regexps = make([]regexMapping, 0, len(mappings))
for _, mapping := range mappings {
from := strings.TrimSpace(mapping.From)
@@ -88,16 +103,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
continue
}
// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to
log.Debugf("amp model mapping registered: %s -> %s", from, to)
if mapping.Regex {
// Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
pattern := "(?i)" + from
re, err := regexp.Compile(pattern)
if err != nil {
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
continue
}
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
} else {
// Store with normalized lowercase key for case-insensitive lookup
normalizedFrom := strings.ToLower(from)
m.mappings[normalizedFrom] = to
log.Debugf("amp model mapping registered: %s -> %s", from, to)
}
}
if len(m.mappings) > 0 {
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
}
if n := len(m.regexps); n > 0 {
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
}
}
// GetMappings returns a copy of current mappings (for debugging/status).
@@ -111,3 +140,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
}
return result
}
type regexMapping struct {
re *regexp.Regexp
to string
}

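Wiring a regex rule then mirrors the tests that follow; a hedged example using the config type from this changeset:

    mappings := []config.AmpModelMapping{
        {From: "gpt-5", To: "claude-sonnet-4"},                 // exact match, checked first
        {From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex fallback, tried in order
    }
    mapper := NewModelMapper(mappings)
    _ = mapper.MapModel("GPT-5(high)") // hits the regex; (?i) makes it case-insensitive,
    // and the result is "" unless gemini-2.5-pro has a registered provider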
View File

@@ -203,3 +203,81 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
t.Error("Original map was modified")
}
}
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-1")
mappings := []config.AmpModelMapping{
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
}
mapper := NewModelMapper(mappings)
// Incoming model has reasoning suffix but should match base via regex
result := mapper.MapModel("gpt-5(high)")
if result != "gemini-2.5-pro" {
t.Errorf("Expected gemini-2.5-pro, got %s", result)
}
}
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
})
defer reg.UnregisterClient("test-client-regex-2")
defer reg.UnregisterClient("test-client-regex-3")
mappings := []config.AmpModelMapping{
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
}
mapper := NewModelMapper(mappings)
// Exact match should win over regex
result := mapper.MapModel("gpt-5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
// Invalid regex should be skipped and not cause panic
mappings := []config.AmpModelMapping{
{From: "(", To: "target", Regex: true},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("anything")
if result != "" {
t.Errorf("Expected empty result due to invalid regex, got %s", result)
}
}
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
})
defer reg.UnregisterClient("test-client-regex-4")
mappings := []config.AmpModelMapping{
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("claude-opus-4.5")
if result != "claude-sonnet-4" {
t.Errorf("Expected claude-sonnet-4, got %s", result)
}
}

View File

@@ -209,13 +209,15 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
// Resolve logs directory relative to the configuration file directory.
var requestLogger logging.RequestLogger
var toggle func(bool)
if optionState.requestLoggerFactory != nil {
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
}
if requestLogger != nil {
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
toggle = setter.SetEnabled
if !cfg.CommercialMode {
if optionState.requestLoggerFactory != nil {
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
}
if requestLogger != nil {
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
toggle = setter.SetEnabled
}
}
}
@@ -474,6 +476,8 @@ func (s *Server) registerManagementRoutes() {
mgmt.Use(s.managementAvailabilityMiddleware(), s.mgmt.Middleware())
{
mgmt.GET("/usage", s.mgmt.GetUsageStatistics)
mgmt.GET("/usage/export", s.mgmt.ExportUsageStatistics)
mgmt.POST("/usage/import", s.mgmt.ImportUsageStatistics)
mgmt.GET("/config", s.mgmt.GetConfig)
mgmt.GET("/config.yaml", s.mgmt.GetConfigYAML)
mgmt.PUT("/config.yaml", s.mgmt.PutConfigYAML)
@@ -518,6 +522,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)

View File

@@ -39,6 +39,9 @@ type Config struct {
// Debug enables or disables debug-level logging and other debug features.
Debug bool `yaml:"debug" json:"debug"`
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
// LoggingToFile controls whether application logs are written to rotating files or stdout.
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
@@ -144,6 +147,11 @@ type AmpModelMapping struct {
// To is the target model name to route to (e.g., "claude-sonnet-4").
// The target model must have available providers in the registry.
To string `yaml:"to" json:"to"`
// Regex indicates whether the 'from' field should be interpreted as a regular
// expression for matching model names. When true, this mapping is evaluated
// after exact matches and in the order provided. Defaults to false (exact match).
Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"`
}
// AmpCode groups Amp CLI integration settings including upstream routing,
@@ -809,8 +817,8 @@ func getOrCreateMapValue(mapNode *yaml.Node, key string) *yaml.Node {
}
// mergeMappingPreserve merges keys from src into dst mapping node while preserving
// key order and comments of existing keys in dst. Unknown keys from src are appended
// to dst at the end, copying their node structure from src.
// key order and comments of existing keys in dst. New keys are only added if their
// value is non-zero to avoid polluting the config with defaults.
func mergeMappingPreserve(dst, src *yaml.Node) {
if dst == nil || src == nil {
return
@@ -821,20 +829,19 @@ func mergeMappingPreserve(dst, src *yaml.Node) {
copyNodeShallow(dst, src)
return
}
// Build a lookup of existing keys in dst
for i := 0; i+1 < len(src.Content); i += 2 {
sk := src.Content[i]
sv := src.Content[i+1]
idx := findMapKeyIndex(dst, sk.Value)
if idx >= 0 {
// Merge into existing value node
// Merge into existing value node (always update, even to zero values)
dv := dst.Content[idx+1]
mergeNodePreserve(dv, sv)
} else {
if shouldSkipEmptyCollectionOnPersist(sk.Value, sv) {
// New key: only add if value is non-zero to avoid polluting config with defaults
if isZeroValueNode(sv) {
continue
}
// Append new key/value pair by deep-copying from src
dst.Content = append(dst.Content, deepCopyNode(sk), deepCopyNode(sv))
}
}
@@ -917,32 +924,49 @@ func findMapKeyIndex(mapNode *yaml.Node, key string) int {
return -1
}
func shouldSkipEmptyCollectionOnPersist(key string, node *yaml.Node) bool {
switch key {
case "generative-language-api-key",
"gemini-api-key",
"vertex-api-key",
"claude-api-key",
"codex-api-key",
"openai-compatibility":
return isEmptyCollectionNode(node)
default:
return false
}
}
func isEmptyCollectionNode(node *yaml.Node) bool {
// isZeroValueNode returns true if the YAML node represents a zero/default value
// that should not be written as a new key to preserve config cleanliness.
// For mappings and sequences, recursively checks if all children are zero values.
func isZeroValueNode(node *yaml.Node) bool {
if node == nil {
return true
}
switch node.Kind {
case yaml.SequenceNode:
return len(node.Content) == 0
case yaml.ScalarNode:
return node.Tag == "!!null"
default:
return false
switch node.Tag {
case "!!bool":
return node.Value == "false"
case "!!int", "!!float":
return node.Value == "0" || node.Value == "0.0"
case "!!str":
return node.Value == ""
case "!!null":
return true
}
case yaml.SequenceNode:
if len(node.Content) == 0 {
return true
}
// Check if all elements are zero values
for _, child := range node.Content {
if !isZeroValueNode(child) {
return false
}
}
return true
case yaml.MappingNode:
if len(node.Content) == 0 {
return true
}
// Check if all values are zero values (values are at odd indices)
for i := 1; i < len(node.Content); i += 2 {
if !isZeroValueNode(node.Content[i]) {
return false
}
}
return true
}
return false
}
// deepCopyNode creates a deep copy of a yaml.Node graph.

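A self-contained sketch of the zero-value rule, with a simplified isZero that mirrors isZeroValueNode above (gopkg.in/yaml.v3 assumed):

    package main

    import (
        "fmt"

        "gopkg.in/yaml.v3"
    )

    func isZero(n *yaml.Node) bool {
        if n == nil {
            return true
        }
        switch n.Kind {
        case yaml.ScalarNode:
            switch n.Tag {
            case "!!bool":
                return n.Value == "false"
            case "!!int", "!!float":
                return n.Value == "0" || n.Value == "0.0"
            case "!!str":
                return n.Value == ""
            case "!!null":
                return true
            }
        case yaml.SequenceNode, yaml.MappingNode:
            start, step := 0, 1
            if n.Kind == yaml.MappingNode {
                start, step = 1, 2 // values sit at odd indices
            }
            for i := start; i < len(n.Content); i += step {
                if !isZero(n.Content[i]) {
                    return false
                }
            }
            return true
        }
        return false
    }

    func main() {
        var doc yaml.Node
        _ = yaml.Unmarshal([]byte("commercial-mode: false\nport: 8317\n"), &doc)
        m := doc.Content[0] // the mapping node
        for i := 0; i+1 < len(m.Content); i += 2 {
            fmt.Printf("%s: zero=%v\n", m.Content[i].Value, isZero(m.Content[i+1]))
        }
        // commercial-mode: zero=true  -> a smart merge would not add it as a new key
        // port: zero=false            -> would be added
    }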
View File

@@ -22,6 +22,21 @@ type SDKConfig struct {
// Access holds request authentication provider configuration.
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
}
// StreamingConfig holds server streaming behavior configuration.
type StreamingConfig struct {
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
// <= 0 disables keep-alives. Default is 0.
KeepAliveSeconds int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
// to allow auth rotation / transient recovery.
// <= 0 disables bootstrap retries. Default is 0.
BootstrapRetries int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
}
// AccessConfig groups request authentication providers.

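A minimal sketch (not the project's implementation) of how a handler might honor these settings, emitting the ": keep-alive\n\n" SSE comment between data chunks:

    func streamWithKeepAlive(w http.ResponseWriter, chunks <-chan []byte, cfg StreamingConfig) {
        flusher, _ := w.(http.Flusher)
        var heartbeat <-chan time.Time
        if cfg.KeepAliveSeconds > 0 { // <= 0 leaves heartbeat nil, i.e. disabled
            t := time.NewTicker(time.Duration(cfg.KeepAliveSeconds) * time.Second)
            defer t.Stop()
            heartbeat = t.C
        }
        for {
            select {
            case chunk, ok := <-chunks:
                if !ok {
                    return
                }
                w.Write(chunk)
                if flusher != nil {
                    flusher.Flush()
                }
            case <-heartbeat:
                io.WriteString(w, ": keep-alive\n\n") // SSE comment, ignored by clients
                if flusher != nil {
                    flusher.Flush()
                }
            }
        }
    }
    // imports: io, net/http, time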
View File

@@ -7,6 +7,7 @@ import (
"fmt"
"net/http"
"runtime/debug"
"strings"
"time"
"github.com/gin-gonic/gin"
@@ -14,11 +15,24 @@ import (
log "github.com/sirupsen/logrus"
)
// aiAPIPrefixes defines path prefixes for AI API requests that should have request ID tracking.
var aiAPIPrefixes = []string{
"/v1/chat/completions",
"/v1/completions",
"/v1/messages",
"/v1/responses",
"/v1beta/models/",
"/api/provider/",
}
const skipGinLogKey = "__gin_skip_request_logging__"
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
// using logrus. It captures request details including method, path, status code, latency,
// client IP, and any error messages, formatting them in a Gin-style log format.
// client IP, and any error messages. Request ID is only added for AI API requests.
//
// Output format (AI API): [2025-12-23 20:14:10] [a1b2c3d4] [info ] 200 | 23.559s | ...
// Output format (others): [2025-12-23 20:14:10] [--------] [info ] 200 | 23.559s | ...
//
// Returns:
// - gin.HandlerFunc: A middleware handler for request logging
@@ -28,6 +42,15 @@ func GinLogrusLogger() gin.HandlerFunc {
path := c.Request.URL.Path
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
// Only generate request ID for AI API paths
var requestID string
if isAIAPIPath(path) {
requestID = GenerateRequestID()
SetGinRequestID(c, requestID)
ctx := WithRequestID(c.Request.Context(), requestID)
c.Request = c.Request.WithContext(ctx)
}
c.Next()
if shouldSkipGinRequestLogging(c) {
@@ -49,23 +72,38 @@ func GinLogrusLogger() gin.HandlerFunc {
clientIP := c.ClientIP()
method := c.Request.Method
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
timestamp := time.Now().Format("2006/01/02 - 15:04:05")
logLine := fmt.Sprintf("[GIN] %s | %3d | %13v | %15s | %-7s \"%s\"", timestamp, statusCode, latency, clientIP, method, path)
if requestID == "" {
requestID = "--------"
}
logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
if errorMessage != "" {
logLine = logLine + " | " + errorMessage
}
entry := log.WithField("request_id", requestID)
switch {
case statusCode >= http.StatusInternalServerError:
log.Error(logLine)
entry.Error(logLine)
case statusCode >= http.StatusBadRequest:
log.Warn(logLine)
entry.Warn(logLine)
default:
log.Info(logLine)
entry.Info(logLine)
}
}
}
// isAIAPIPath checks if the given path is an AI API endpoint that should have request ID tracking.
func isAIAPIPath(path string) bool {
for _, prefix := range aiAPIPrefixes {
if strings.HasPrefix(path, prefix) {
return true
}
}
return false
}
// GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs
// them using logrus. When a panic occurs, it captures the panic value, stack trace,
// and request path, then returns a 500 Internal Server Error response to the client.

View File

@@ -24,7 +24,8 @@ var (
)
// LogFormatter defines a custom log format for logrus.
// This formatter adds timestamp, level, and source location to each log entry.
// This formatter adds timestamp, level, request ID, and source location to each log entry.
// Format: [2025-12-23 20:14:04] [a1b2c3d4] [debug] [manager.go:524] Use API key sk-9...0RHO for model gpt-5.2
type LogFormatter struct{}
// Format renders a single log entry with custom formatting.
@@ -39,11 +40,22 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
timestamp := entry.Time.Format("2006-01-02 15:04:05")
message := strings.TrimRight(entry.Message, "\r\n")
reqID := "--------"
if id, ok := entry.Data["request_id"].(string); ok && id != "" {
reqID = id
}
level := entry.Level.String()
if level == "warning" {
level = "warn"
}
levelStr := fmt.Sprintf("%-5s", level)
var formatted string
if entry.Caller != nil {
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
} else {
formatted = fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, message)
formatted = fmt.Sprintf("[%s] [%s] [%s] %s\n", timestamp, reqID, levelStr, message)
}
buffer.WriteString(formatted)

View File

@@ -43,10 +43,11 @@ type RequestLogger interface {
// - response: The raw response data
// - apiRequest: The API request data
// - apiResponse: The API response data
// - requestID: Optional request ID for log file naming
//
// Returns:
// - error: An error if logging fails, nil otherwise
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
//
@@ -55,11 +56,12 @@ type RequestLogger interface {
// - method: The HTTP method
// - headers: The request headers
// - body: The request body
// - requestID: Optional request ID for log file naming
//
// Returns:
// - StreamingLogWriter: A writer for streaming response chunks
// - error: An error if logging initialization fails, nil otherwise
LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error)
LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error)
// IsEnabled returns whether request logging is currently enabled.
//
@@ -177,20 +179,21 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
// - response: The raw response data
// - apiRequest: The API request data
// - apiResponse: The API response data
// - requestID: Optional request ID for log file naming
//
// Returns:
// - error: An error if logging fails, nil otherwise
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false)
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID)
}
// LogRequestWithOptions logs a request with optional forced logging behavior.
// The force flag allows writing error logs even when regular request logging is disabled.
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force)
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID)
}
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
if !l.enabled && !force {
return nil
}
@@ -200,10 +203,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
return fmt.Errorf("failed to create logs directory: %w", errEnsure)
}
// Generate filename
filename := l.generateFilename(url)
// Generate filename with request ID
filename := l.generateFilename(url, requestID)
if force && !l.enabled {
filename = l.generateErrorFilename(url)
filename = l.generateErrorFilename(url, requestID)
}
filePath := filepath.Join(l.logsDir, filename)
@@ -271,11 +274,12 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
// - method: The HTTP method
// - headers: The request headers
// - body: The request body
// - requestID: Optional request ID for log file naming
//
// Returns:
// - StreamingLogWriter: A writer for streaming response chunks
// - error: An error if logging initialization fails, nil otherwise
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) {
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error) {
if !l.enabled {
return &NoOpStreamingLogWriter{}, nil
}
@@ -285,8 +289,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
return nil, fmt.Errorf("failed to create logs directory: %w", err)
}
// Generate filename
filename := l.generateFilename(url)
// Generate filename with request ID
filename := l.generateFilename(url, requestID)
filePath := filepath.Join(l.logsDir, filename)
requestHeaders := make(map[string][]string, len(headers))
@@ -330,8 +334,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
}
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
func (l *FileRequestLogger) generateErrorFilename(url string) string {
return fmt.Sprintf("error-%s", l.generateFilename(url))
func (l *FileRequestLogger) generateErrorFilename(url string, requestID ...string) string {
return fmt.Sprintf("error-%s", l.generateFilename(url, requestID...))
}
// ensureLogsDir creates the logs directory if it doesn't exist.
@@ -346,13 +350,15 @@ func (l *FileRequestLogger) ensureLogsDir() error {
}
// generateFilename creates a sanitized filename from the URL path and current timestamp.
// Format: v1-responses-2025-12-23T195811-a1b2c3d4.log
//
// Parameters:
// - url: The request URL
// - requestID: Optional request ID to include in filename
//
// Returns:
// - string: A sanitized filename for the log file
func (l *FileRequestLogger) generateFilename(url string) string {
func (l *FileRequestLogger) generateFilename(url string, requestID ...string) string {
// Extract path from URL
path := url
if strings.Contains(url, "?") {
@@ -368,12 +374,18 @@ func (l *FileRequestLogger) generateFilename(url string) string {
sanitized := l.sanitizeForFilename(path)
// Add timestamp
timestamp := time.Now().Format("2006-01-02T150405-.000000000")
timestamp = strings.Replace(timestamp, ".", "", -1)
timestamp := time.Now().Format("2006-01-02T150405")
id := requestLogID.Add(1)
// Use request ID if provided, otherwise use sequential ID
var idPart string
if len(requestID) > 0 && requestID[0] != "" {
idPart = requestID[0]
} else {
id := requestLogID.Add(1)
idPart = fmt.Sprintf("%d", id)
}
return fmt.Sprintf("%s-%s-%d.log", sanitized, timestamp, id)
return fmt.Sprintf("%s-%s-%s.log", sanitized, timestamp, idPart)
}
// sanitizeForFilename replaces characters that are not safe for filenames.

View File

@@ -0,0 +1,61 @@
package logging
import (
"context"
"crypto/rand"
"encoding/hex"
"github.com/gin-gonic/gin"
)
// requestIDKey is the context key for storing/retrieving request IDs.
type requestIDKey struct{}
// ginRequestIDKey is the Gin context key for request IDs.
const ginRequestIDKey = "__request_id__"
// GenerateRequestID creates a new 8-character hex request ID.
func GenerateRequestID() string {
b := make([]byte, 4)
if _, err := rand.Read(b); err != nil {
return "00000000"
}
return hex.EncodeToString(b)
}
// WithRequestID returns a new context with the request ID attached.
func WithRequestID(ctx context.Context, requestID string) context.Context {
return context.WithValue(ctx, requestIDKey{}, requestID)
}
// GetRequestID retrieves the request ID from the context.
// Returns empty string if not found.
func GetRequestID(ctx context.Context) string {
if ctx == nil {
return ""
}
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
return id
}
return ""
}
// SetGinRequestID stores the request ID in the Gin context.
func SetGinRequestID(c *gin.Context, requestID string) {
if c != nil {
c.Set(ginRequestIDKey, requestID)
}
}
// GetGinRequestID retrieves the request ID from the Gin context.
func GetGinRequestID(c *gin.Context) string {
if c == nil {
return ""
}
if id, exists := c.Get(ginRequestIDKey); exists {
if s, ok := id.(string); ok {
return s
}
}
return ""
}

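A small usage sketch of the helpers above (import path assumed):

    id := logging.GenerateRequestID() // e.g. "a1b2c3d4"
    ctx := logging.WithRequestID(context.Background(), id)
    fmt.Println(logging.GetRequestID(ctx))                  // "a1b2c3d4"
    fmt.Println(logging.GetRequestID(context.Background())) // "" when absent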
View File

@@ -727,6 +727,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
@@ -740,6 +741,7 @@ func GetIFlowModels() []*ModelInfo {
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000},
}
models := make([]*ModelInfo, 0, len(entries))
for _, entry := range entries {

View File

@@ -17,6 +17,7 @@ import (
"net/url"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
@@ -41,12 +42,15 @@ const (
antigravityModelsPath = "/v1internal:fetchAvailableModels"
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
antigravityClientSecret = "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"
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
antigravityAuthType = "antigravity"
refreshSkew = 3000 * time.Second
)
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
var (
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
randSourceMutex sync.Mutex
)
// AntigravityExecutor proxies requests to the antigravity upstream.
type AntigravityExecutor struct {
@@ -1224,7 +1228,9 @@ func generateRequestID() string {
}
func generateSessionID() string {
randSourceMutex.Lock()
n := randSource.Int63n(9_000_000_000_000_000_000)
randSourceMutex.Unlock()
return "-" + strconv.FormatInt(n, 10)
}
@@ -1248,8 +1254,10 @@ func generateStableSessionID(payload []byte) string {
func generateProjectID() string {
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
randSourceMutex.Lock()
adj := adjectives[randSource.Intn(len(adjectives))]
noun := nouns[randSource.Intn(len(nouns))]
randSourceMutex.Unlock()
randomPart := strings.ToLower(uuid.NewString())[:5]
return adj + "-" + noun + "-" + randomPart
}

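The mutex exists because a shared *rand.Rand is not safe for concurrent use; a hedged helper showing the serialized-access pattern applied above:

    func randomInt63n(n int64) int64 {
        randSourceMutex.Lock()
        defer randSourceMutex.Unlock()
        return randSource.Int63n(n)
    }

(The top-level math/rand functions, e.g. rand.Int63n, are already goroutine-safe, at the cost of a single global source.)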
View File

@@ -19,7 +19,7 @@ type usageReporter struct {
provider string
model string
authID string
authIndex uint64
authIndex string
apiKey string
source string
requestedAt time.Time
@@ -275,6 +275,20 @@ func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
return detail, true
}
func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
CachedTokens: node.Get("cachedContentTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
}
func parseGeminiCLIUsage(data []byte) usage.Detail {
usageNode := gjson.ParseBytes(data)
node := usageNode.Get("response.usageMetadata")
@@ -284,16 +298,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
if !node.Exists() {
return usage.Detail{}
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
return parseGeminiFamilyUsageDetail(node)
}
func parseGeminiUsage(data []byte) usage.Detail {
@@ -305,16 +310,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
if !node.Exists() {
return usage.Detail{}
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
return parseGeminiFamilyUsageDetail(node)
}
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
@@ -329,16 +325,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() {
return usage.Detail{}, false
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
return parseGeminiFamilyUsageDetail(node), true
}
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
@@ -353,16 +340,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() {
return usage.Detail{}, false
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
return parseGeminiFamilyUsageDetail(node), true
}
func parseAntigravityUsage(data []byte) usage.Detail {
@@ -377,16 +355,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
if !node.Exists() {
return usage.Detail{}
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail
return parseGeminiFamilyUsageDetail(node)
}
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
@@ -404,16 +373,7 @@ func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
if !node.Exists() {
return usage.Detail{}, false
}
detail := usage.Detail{
InputTokens: node.Get("promptTokenCount").Int(),
OutputTokens: node.Get("candidatesTokenCount").Int(),
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
TotalTokens: node.Get("totalTokenCount").Int(),
}
if detail.TotalTokens == 0 {
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
}
return detail, true
return parseGeminiFamilyUsageDetail(node), true
}
var stopChunkWithoutUsage sync.Map
@@ -522,12 +482,16 @@ func StripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
cleaned := jsonBytes
var changed bool
if gjson.GetBytes(cleaned, "usageMetadata").Exists() {
if usageMetadata = gjson.GetBytes(cleaned, "usageMetadata"); usageMetadata.Exists() {
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
cleaned, _ = sjson.SetRawBytes(cleaned, "cpaUsageMetadata", []byte(usageMetadata.Raw))
cleaned, _ = sjson.DeleteBytes(cleaned, "usageMetadata")
changed = true
}
if gjson.GetBytes(cleaned, "response.usageMetadata").Exists() {
if usageMetadata = gjson.GetBytes(cleaned, "response.usageMetadata"); usageMetadata.Exists() {
// Rename usageMetadata to cpaUsageMetadata in the message_start event of Claude
cleaned, _ = sjson.SetRawBytes(cleaned, "response.cpaUsageMetadata", []byte(usageMetadata.Raw))
cleaned, _ = sjson.DeleteBytes(cleaned, "response.usageMetadata")
changed = true
}

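A hedged, self-contained illustration of what parseGeminiFamilyUsageDetail reads from a Gemini-family usageMetadata node (values invented):

    raw := `{"usageMetadata":{"promptTokenCount":120,"candidatesTokenCount":40,"thoughtsTokenCount":8,"cachedContentTokenCount":100}}`
    node := gjson.Get(raw, "usageMetadata")
    in := node.Get("promptTokenCount").Int()            // 120
    out := node.Get("candidatesTokenCount").Int()       // 40
    think := node.Get("thoughtsTokenCount").Int()       // 8
    cached := node.Get("cachedContentTokenCount").Int() // 100
    total := in + out + think                           // 168; used when totalTokenCount is absent
    _, _ = cached, total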
View File

@@ -35,6 +35,7 @@ type Params struct {
CandidatesTokenCount int64 // Cached candidate token count from usage metadata
ThoughtsTokenCount int64 // Cached thinking token count from usage metadata
TotalTokenCount int64 // Cached total token count from usage metadata
CachedTokenCount int64 // Cached content token count (indicates prompt caching)
HasSentFinalEvents bool // Indicates if final content/message events have been sent
HasToolUse bool // Indicates if tool use was observed in the stream
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
@@ -98,6 +99,14 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// This follows the Claude Code API specification for streaming message initialization
messageStartTemplate := `{"type": "message_start", "message": {"id": "msg_1nZdL29xx5MUA1yADyHTEsnR8uuvGzszyY", "type": "message", "role": "assistant", "content": [], "model": "claude-3-5-sonnet-20241022", "stop_reason": null, "stop_sequence": null, "usage": {"input_tokens": 0, "output_tokens": 0}}}`
// Use cpaUsageMetadata within the message_start event for Claude.
if promptTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.promptTokenCount"); promptTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.input_tokens", promptTokenCount.Int())
}
if candidatesTokenCount := gjson.GetBytes(rawJSON, "response.cpaUsageMetadata.candidatesTokenCount"); candidatesTokenCount.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.usage.output_tokens", candidatesTokenCount.Int())
}
// Override default values with actual response metadata if available from the Gemini CLI response
if modelVersionResult := gjson.GetBytes(rawJSON, "response.modelVersion"); modelVersionResult.Exists() {
messageStartTemplate, _ = sjson.Set(messageStartTemplate, "message.model", modelVersionResult.String())
@@ -270,7 +279,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
params.HasUsageMetadata = true
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int()
params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int() - params.CachedTokenCount
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
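Worked numbers for the subtraction above (illustrative): if usageMetadata reports promptTokenCount=1000 and cachedContentTokenCount=800, then CachedTokenCount=800 and PromptTokenCount=1000-800=200, so input_tokens reflects only uncached prompt tokens while the cache reads are surfaced separately via cache_read_input_tokens in the hunks below.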
@@ -322,6 +332,14 @@ func appendFinalEvents(params *Params, output *string, force bool) {
*output = *output + "event: message_delta\n"
*output = *output + "data: "
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if params.CachedTokenCount > 0 {
var err error
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
}
}
*output = *output + delta + "\n\n\n"
params.HasSentFinalEvents = true
@@ -361,6 +379,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int()
thoughtTokens := root.Get("response.usageMetadata.thoughtsTokenCount").Int()
totalTokens := root.Get("response.usageMetadata.totalTokenCount").Int()
cachedTokens := root.Get("response.usageMetadata.cachedContentTokenCount").Int()
outputTokens := candidateTokens + thoughtTokens
if outputTokens == 0 && totalTokens > 0 {
outputTokens = totalTokens - promptTokens
@@ -374,6 +393,14 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
if cachedTokens > 0 {
var err error
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
if err != nil {
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
}
}
contentArrayInitialized := false
ensureContentArray := func() {

View File

@@ -247,10 +247,30 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
} else if role == "assistant" {
node := []byte(`{"role":"model","parts":[]}`)
p := 0
if content.Type == gjson.String {
if content.Type == gjson.String && content.String() != "" {
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 { // expects a "data:" URL; 5 == len("data:")
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 { // 7 == len("base64,")
mime := pieces[0]
data := pieces[1][7:] // strip the "base64," marker
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
p++
}
}
}
}
}
// Tool calls -> single model content with functionCall parts
@@ -305,6 +325,8 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
if pp > 0 {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
}
} else {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
}
}
}
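The slicing above hand-parses an RFC 2397 data URL. Traced with an illustrative payload:

// imageURL          = "data:image/png;base64,iVBORw0KGgo..."
// imageURL[5:]      = "image/png;base64,iVBORw0KGgo..."   (drops "data:")
// SplitN(_, ";", 2) = ["image/png", "base64,iVBORw0KGgo..."]
// pieces[0]         = "image/png"        -> inlineData.mime_type
// pieces[1][7:]     = "iVBORw0KGgo..."   -> inlineData.data (drops "base64,")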

View File

@@ -13,6 +13,8 @@ import (
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
@@ -85,18 +87,27 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
}
}
}
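With illustrative counts promptTokenCount=1000, cachedContentTokenCount=800, thoughtsTokenCount=50, candidatesTokenCount=40, and totalTokenCount=1090, the resulting OpenAI usage object would be shaped like this (a sketch, not output captured from the proxy):

"usage": {
  "completion_tokens": 40,
  "total_tokens": 1090,
  "prompt_tokens": 250,
  "completion_tokens_details": {"reasoning_tokens": 50},
  "prompt_tokens_details": {"cached_tokens": 800}
}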
// Process the main content part of the response.
@@ -170,12 +181,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
}
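Each streamed image now carries an explicit index; an illustrative delta (base64 truncated):

{"choices":[{"delta":{"role":"assistant","images":[{"type":"image_url","image_url":{"url":"data:image/png;base64,..."},"index":0}]}}]}

The index is taken from the current length of the images array, so a second image in the same stream would get index 1.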

View File

@@ -205,9 +205,12 @@ func ConvertClaudeResponseToOpenAI(_ context.Context, modelName string, original
if usage := root.Get("usage"); usage.Exists() {
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens)
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
template, _ = sjson.Set(template, "usage.completion_tokens", outputTokens)
template, _ = sjson.Set(template, "usage.total_tokens", inputTokens+outputTokens)
template, _ = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
}
return []string{template}
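Worked numbers for the mapping above (illustrative Claude usage): {input_tokens: 20, output_tokens: 150, cache_read_input_tokens: 800, cache_creation_input_tokens: 30} yields prompt_tokens = 20 + 30 = 50, completion_tokens = 150, total_tokens = 20 + 150 = 170, and prompt_tokens_details.cached_tokens = 800.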
@@ -281,8 +284,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
var messageID string
var model string
var createdAt int64
var inputTokens, outputTokens int64
var reasoningTokens int64
var stopReason string
var contentParts []string
var reasoningParts []string
@@ -299,9 +300,6 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
messageID = message.Get("id").String()
model = message.Get("model").String()
createdAt = time.Now().Unix()
if usage := message.Get("usage"); usage.Exists() {
inputTokens = usage.Get("input_tokens").Int()
}
}
case "content_block_start":
@@ -364,11 +362,14 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
}
}
if usage := root.Get("usage"); usage.Exists() {
outputTokens = usage.Get("output_tokens").Int()
// Estimate reasoning tokens from accumulated thinking content
if len(reasoningParts) > 0 {
reasoningTokens = int64(len(strings.Join(reasoningParts, "")) / 4) // Rough estimation
}
inputTokens := usage.Get("input_tokens").Int()
outputTokens := usage.Get("output_tokens").Int()
cacheReadInputTokens := usage.Get("cache_read_input_tokens").Int()
cacheCreationInputTokens := usage.Get("cache_creation_input_tokens").Int()
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens+cacheCreationInputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", inputTokens+outputTokens)
out, _ = sjson.Set(out, "usage.prompt_tokens_details.cached_tokens", cacheReadInputTokens)
}
}
}
@@ -427,16 +428,5 @@ func ConvertClaudeResponseToOpenAINonStream(_ context.Context, _ string, origina
out, _ = sjson.Set(out, "choices.0.finish_reason", mapAnthropicStopReasonToOpenAI(stopReason))
}
// Set usage information including prompt tokens, completion tokens, and total tokens
totalTokens := inputTokens + outputTokens
out, _ = sjson.Set(out, "usage.prompt_tokens", inputTokens)
out, _ = sjson.Set(out, "usage.completion_tokens", outputTokens)
out, _ = sjson.Set(out, "usage.total_tokens", totalTokens)
// Add reasoning tokens to usage details if any reasoning content was processed
if reasoningTokens > 0 {
out, _ = sjson.Set(out, "usage.completion_tokens_details.reasoning_tokens", reasoningTokens)
}
return out
}

View File

@@ -114,13 +114,16 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
var builder strings.Builder
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
parts.ForEach(func(_, part gjson.Result) bool {
text := part.Get("text").String()
textResult := part.Get("text")
text := textResult.String()
if builder.Len() > 0 && text != "" {
builder.WriteByte('\n')
}
builder.WriteString(text)
return true
})
} else if parts.Type == gjson.String {
builder.WriteString(parts.String())
}
instructionsText = builder.String()
if instructionsText != "" {
@@ -207,6 +210,8 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
}
return true
})
} else if parts.Type == gjson.String {
textAggregate.WriteString(parts.String())
}
// Fallback to given role if content types not decisive

View File

@@ -218,8 +218,29 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if content.Type == gjson.String {
// Assistant text -> single model content
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
imageURL := item.Get("image_url.url").String()
if len(imageURL) > 5 { // expects a "data:" URL; 5 == len("data:")
pieces := strings.SplitN(imageURL[5:], ";", 2)
if len(pieces) == 2 && len(pieces[1]) > 7 { // 7 == len("base64,")
mime := pieces[0]
data := pieces[1][7:] // strip the "base64," marker
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
p++
}
}
}
}
}
// Tool calls -> single model content with functionCall parts
@@ -244,7 +265,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
// Append a single tool content combining name + response per function
toolNode := []byte(`{"role":"tool","parts":[]}`)
toolNode := []byte(`{"role":"user","parts":[]}`)
pp := 0
for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok {
@@ -260,6 +281,8 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
if pp > 0 {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
}
} else {
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
}
}
}

View File

@@ -170,12 +170,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
}

View File

@@ -233,18 +233,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
} else if role == "assistant" {
node := []byte(`{"role":"model","parts":[]}`)
p := 0
if content.Type == gjson.String {
// Assistant text -> single model content
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
p++
} else if content.IsArray() {
// Assistant multimodal content (e.g. text + image) -> single model content with parts
for _, item := range content.Array() {
switch item.Get("type").String() {
case "text":
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
p++
case "image_url":
// If the assistant returned an inline data URL, preserve it for history fidelity.
@@ -261,7 +258,6 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
}
}
}
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
}
// Tool calls -> single model content with functionCall parts
@@ -286,7 +282,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
// Append a single tool content combining name + response per function
toolNode := []byte(`{"role":"tool","parts":[]}`)
toolNode := []byte(`{"role":"user","parts":[]}`)
pp := 0
for _, fid := range fIDs {
if name, ok := tcID2Name[fid]; ok {
@@ -302,6 +298,8 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
if pp > 0 {
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
}
} else {
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
}
}
}

View File

@@ -13,6 +13,7 @@ import (
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
@@ -88,18 +89,27 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
// Extract and set usage metadata (token counts).
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
}
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("gemini openai response: failed to set cached_tokens in streaming: %v", err)
}
}
}
// Process the main content part of the response.
@@ -172,12 +182,14 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.delta.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
}
@@ -240,10 +252,19 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
}
promptTokenCount := usageResult.Get("promptTokenCount").Int()
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
if thoughtsTokenCount > 0 {
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
}
// Include cached token count if present (indicates prompt caching is working)
if cachedTokenCount > 0 {
var err error
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
if err != nil {
log.Warnf("gemini openai response: failed to set cached_tokens in non-streaming: %v", err)
}
}
}
// Process the main content part of the response.
@@ -297,12 +318,14 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
mimeType = "image/png"
}
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
imagesResult := gjson.Get(template, "choices.0.message.images")
if !imagesResult.Exists() || !imagesResult.IsArray() {
template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`)
}
imageIndex := len(gjson.Get(template, "choices.0.message.images").Array())
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload)
}

View File

@@ -6,6 +6,7 @@ package usage
import (
"context"
"fmt"
"strings"
"sync"
"sync/atomic"
"time"
@@ -90,7 +91,7 @@ type modelStats struct {
type RequestDetail struct {
Timestamp time.Time `json:"timestamp"`
Source string `json:"source"`
AuthIndex uint64 `json:"auth_index"`
AuthIndex string `json:"auth_index"`
Tokens TokenStats `json:"tokens"`
Failed bool `json:"failed"`
}
@@ -281,6 +282,118 @@ func (s *RequestStatistics) Snapshot() StatisticsSnapshot {
return result
}
type MergeResult struct {
Added int64 `json:"added"`
Skipped int64 `json:"skipped"`
}
// MergeSnapshot merges an exported statistics snapshot into the current store.
// Existing data is preserved and duplicate request details are skipped.
func (s *RequestStatistics) MergeSnapshot(snapshot StatisticsSnapshot) MergeResult {
result := MergeResult{}
if s == nil {
return result
}
s.mu.Lock()
defer s.mu.Unlock()
seen := make(map[string]struct{})
for apiName, stats := range s.apis {
if stats == nil {
continue
}
for modelName, modelStatsValue := range stats.Models {
if modelStatsValue == nil {
continue
}
for _, detail := range modelStatsValue.Details {
seen[dedupKey(apiName, modelName, detail)] = struct{}{}
}
}
}
for apiName, apiSnapshot := range snapshot.APIs {
apiName = strings.TrimSpace(apiName)
if apiName == "" {
continue
}
stats, ok := s.apis[apiName]
if !ok || stats == nil {
stats = &apiStats{Models: make(map[string]*modelStats)}
s.apis[apiName] = stats
} else if stats.Models == nil {
stats.Models = make(map[string]*modelStats)
}
for modelName, modelSnapshot := range apiSnapshot.Models {
modelName = strings.TrimSpace(modelName)
if modelName == "" {
modelName = "unknown"
}
for _, detail := range modelSnapshot.Details {
detail.Tokens = normaliseTokenStats(detail.Tokens)
if detail.Timestamp.IsZero() {
detail.Timestamp = time.Now()
}
key := dedupKey(apiName, modelName, detail)
if _, exists := seen[key]; exists {
result.Skipped++
continue
}
seen[key] = struct{}{}
s.recordImported(apiName, modelName, stats, detail)
result.Added++
}
}
}
return result
}
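A minimal call-site sketch (hypothetical surrounding code, not part of this diff):

// imported is a StatisticsSnapshot decoded from an exported statistics file
var imported StatisticsSnapshot
if err := json.Unmarshal(exportBytes, &imported); err == nil {
	res := stats.MergeSnapshot(imported)
	log.Infof("usage import: %d added, %d duplicates skipped", res.Added, res.Skipped)
}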
func (s *RequestStatistics) recordImported(apiName, modelName string, stats *apiStats, detail RequestDetail) {
totalTokens := detail.Tokens.TotalTokens
if totalTokens < 0 {
totalTokens = 0
}
s.totalRequests++
if detail.Failed {
s.failureCount++
} else {
s.successCount++
}
s.totalTokens += totalTokens
s.updateAPIStats(stats, modelName, detail)
dayKey := detail.Timestamp.Format("2006-01-02")
hourKey := detail.Timestamp.Hour()
s.requestsByDay[dayKey]++
s.requestsByHour[hourKey]++
s.tokensByDay[dayKey] += totalTokens
s.tokensByHour[hourKey] += totalTokens
}
func dedupKey(apiName, modelName string, detail RequestDetail) string {
timestamp := detail.Timestamp.UTC().Format(time.RFC3339Nano)
tokens := normaliseTokenStats(detail.Tokens)
return fmt.Sprintf(
"%s|%s|%s|%s|%s|%t|%d|%d|%d|%d|%d",
apiName,
modelName,
timestamp,
detail.Source,
detail.AuthIndex,
detail.Failed,
tokens.InputTokens,
tokens.OutputTokens,
tokens.ReasoningTokens,
tokens.CachedTokens,
tokens.TotalTokens,
)
}
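For illustration, a successful gemini request against gemini-2.5-pro might produce a key like:

gemini|gemini-2.5-pro|2025-12-26T03:56:11Z|api-key|auth-1|false|912|40|0|0|952

i.e. API name, model, UTC RFC 3339 timestamp, source, auth index, failed flag, then the five normalised token counts. Any field differing between two details yields a distinct key, which is what makes the merge idempotent.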
func resolveAPIIdentifier(ctx context.Context, record coreusage.Record) string {
if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil {
@@ -340,6 +453,16 @@ func normaliseDetail(detail coreusage.Detail) TokenStats {
return tokens
}
func normaliseTokenStats(tokens TokenStats) TokenStats {
// Prefer deriving a missing total from input + output + reasoning tokens.
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens
}
// If that sum is still zero, fall back to counting cached tokens as well.
if tokens.TotalTokens == 0 {
tokens.TotalTokens = tokens.InputTokens + tokens.OutputTokens + tokens.ReasoningTokens + tokens.CachedTokens
}
return tokens
}
func formatHour(hour int) string {
if hour < 0 {
hour = 0

View File

@@ -14,7 +14,6 @@ import (
"fmt"
"io"
"net/http"
"time"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
@@ -185,14 +184,6 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
// - c: The Gin context for the request.
// - rawJSON: The raw JSON request body.
func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
// Set up Server-Sent Events (SSE) headers for streaming response
// These headers are essential for maintaining a persistent connection
// and enabling real-time streaming of chat completions
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
// This is crucial for streaming as it allows immediate sending of data chunks
flusher, ok := c.Writer.(http.Flusher)
@@ -213,58 +204,82 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
// OpenAI-style stream forwarding: write each SSE chunk and flush immediately.
// This guarantees clients see incremental output even for small responses.
// Peek at the first chunk to determine success or failure before setting headers
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
cliCancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
case errMsg, ok := <-errChan:
if !ok {
// Err channel closed cleanly; wait for data channel.
errChan = nil
continue
}
// Upstream failed immediately. Return proper error status and JSON.
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
// Stream closed without data; send the SSE headers and finish.
setSSEHeaders()
flusher.Flush()
cancel(nil)
cliCancel(nil)
return
}
// Success! Set headers now.
setSSEHeaders()
// Write the first chunk
if len(chunk) > 0 {
_, _ = c.Writer.Write(chunk)
flusher.Flush()
}
case errMsg, ok := <-errs:
if !ok {
continue
}
if errMsg != nil {
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
c.Status(status)
// An error occurred: emit as a proper SSE error event
errorBytes, _ := json.Marshal(h.toClaudeError(errMsg))
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
// Continue streaming the rest
h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
return
case <-time.After(500 * time.Millisecond):
}
}
}
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
WriteChunk: func(chunk []byte) {
if len(chunk) == 0 {
return
}
_, _ = c.Writer.Write(chunk)
},
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {
return
}
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
c.Status(status)
errorBytes, _ := json.Marshal(h.toClaudeError(errMsg))
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes)
},
})
}
type claudeErrorDetail struct {
Type string `json:"type"`
Message string `json:"message"`

View File

@@ -182,19 +182,18 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
}
func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
if !ok {
cancel(nil)
return
}
var keepAliveInterval *time.Duration
if alt != "" {
disabled := time.Duration(0)
keepAliveInterval = &disabled
}
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
KeepAliveInterval: keepAliveInterval,
WriteChunk: func(chunk []byte) {
if alt == "" {
if bytes.Equal(chunk, []byte("data: [DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) {
continue
return
}
if !bytes.HasPrefix(chunk, []byte("data:")) {
@@ -206,22 +205,25 @@ func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flus
} else {
_, _ = c.Writer.Write(chunk)
}
flusher.Flush()
case errMsg, ok := <-errs:
if !ok {
continue
},
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {
return
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
errText := http.StatusText(status)
if errMsg.Error != nil && errMsg.Error.Error() != "" {
errText = errMsg.Error.Error()
}
cancel(execErr)
return
case <-time.After(500 * time.Millisecond):
}
}
body := handlers.BuildErrorResponseBody(status, errText)
if alt == "" {
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
} else {
_, _ = c.Writer.Write(body)
}
},
})
}

View File

@@ -226,13 +226,6 @@ func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) {
func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) {
alt := h.GetAlt(c)
if alt == "" {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
@@ -247,8 +240,65 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
return
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Peek at the first chunk
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case errMsg, ok := <-errChan:
if !ok {
// Err channel closed cleanly; wait for data channel.
errChan = nil
continue
}
// Upstream failed immediately. Return proper error status and JSON.
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
// Closed without data
if alt == "" {
setSSEHeaders()
}
flusher.Flush()
cliCancel(nil)
return
}
// Success! Set headers.
if alt == "" {
setSSEHeaders()
}
// Write first chunk
if alt == "" {
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n\n"))
} else {
_, _ = c.Writer.Write(chunk)
}
flusher.Flush()
// Continue
h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
}
}
// handleCountTokens handles token counting requests for Gemini models.
@@ -297,16 +347,15 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
}
func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
if !ok {
cancel(nil)
return
}
var keepAliveInterval *time.Duration
if alt != "" {
disabled := time.Duration(0)
keepAliveInterval = &disabled
}
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
KeepAliveInterval: keepAliveInterval,
WriteChunk: func(chunk []byte) {
if alt == "" {
_, _ = c.Writer.Write([]byte("data: "))
_, _ = c.Writer.Write(chunk)
@@ -314,22 +363,25 @@ func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flus
} else {
_, _ = c.Writer.Write(chunk)
}
flusher.Flush()
case errMsg, ok := <-errs:
if !ok {
continue
},
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {
return
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
errText := http.StatusText(status)
if errMsg.Error != nil && errMsg.Error.Error() != "" {
errText = errMsg.Error.Error()
}
cancel(execErr)
return
case <-time.After(500 * time.Millisecond):
}
}
body := handlers.BuildErrorResponseBody(status, errText)
if alt == "" {
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
} else {
_, _ = c.Writer.Write(body)
}
},
})
}

View File

@@ -9,9 +9,12 @@ import (
"fmt"
"net/http"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -40,6 +43,117 @@ type ErrorDetail struct {
Code string `json:"code,omitempty"`
}
const idempotencyKeyMetadataKey = "idempotency_key"
const (
defaultStreamingKeepAliveSeconds = 0
defaultStreamingBootstrapRetries = 0
)
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
func BuildErrorResponseBody(status int, errText string) []byte {
if status <= 0 {
status = http.StatusInternalServerError
}
if strings.TrimSpace(errText) == "" {
errText = http.StatusText(status)
}
trimmed := strings.TrimSpace(errText)
if trimmed != "" && json.Valid([]byte(trimmed)) {
return []byte(trimmed)
}
errType := "invalid_request_error"
var code string
switch status {
case http.StatusUnauthorized:
errType = "authentication_error"
code = "invalid_api_key"
case http.StatusForbidden:
errType = "permission_error"
code = "insufficient_quota"
case http.StatusTooManyRequests:
errType = "rate_limit_error"
code = "rate_limit_exceeded"
case http.StatusNotFound:
errType = "invalid_request_error"
code = "model_not_found"
default:
if status >= http.StatusInternalServerError {
errType = "server_error"
code = "internal_server_error"
}
}
payload, err := json.Marshal(ErrorResponse{
Error: ErrorDetail{
Message: errText,
Type: errType,
Code: code,
},
})
if err != nil {
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error","code":"internal_server_error"}}`, errText))
}
return payload
}
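Two illustrative outputs, following the switch above:

// BuildErrorResponseBody(http.StatusTooManyRequests, "quota exceeded")
//   -> {"error":{"message":"quota exceeded","type":"rate_limit_error","code":"rate_limit_exceeded"}}
// BuildErrorResponseBody(http.StatusBadGateway, "")
//   -> {"error":{"message":"Bad Gateway","type":"server_error","code":"internal_server_error"}}

If errText is itself valid JSON (e.g. an upstream error body), it is passed through untouched instead.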
// StreamingKeepAliveInterval returns the SSE keep-alive interval for this server.
// Returning 0 disables keep-alives (default when unset).
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
seconds := defaultStreamingKeepAliveSeconds
if cfg != nil {
seconds = cfg.Streaming.KeepAliveSeconds
}
if seconds <= 0 {
return 0
}
return time.Duration(seconds) * time.Second
}
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
retries := defaultStreamingBootstrapRetries
if cfg != nil {
retries = cfg.Streaming.BootstrapRetries
}
if retries < 0 {
retries = 0
}
return retries
}
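A minimal sketch of how these helpers read the config, using the StreamingConfig fields referenced above:

cfg := &config.SDKConfig{Streaming: config.StreamingConfig{
	KeepAliveSeconds: 15, // emit an SSE heartbeat every 15s; 0 or unset disables keep-alives
	BootstrapRetries: 2,  // retry a failed stream up to twice before the first byte is sent
}}
interval := StreamingKeepAliveInterval(cfg) // 15 * time.Second
retries := StreamingBootstrapRetries(cfg)   // 2
_, _ = interval, retries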
func requestExecutionMetadata(ctx context.Context) map[string]any {
// Idempotency-Key is an optional client-supplied header used to correlate retries.
// It is forwarded as execution metadata; when absent we generate a UUID.
key := ""
if ctx != nil {
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key"))
}
}
if key == "" {
key = uuid.NewString()
}
return map[string]any{idempotencyKeyMetadataKey: key}
}
func mergeMetadata(base, overlay map[string]any) map[string]any {
if len(base) == 0 && len(overlay) == 0 {
return nil
}
out := make(map[string]any, len(base)+len(overlay))
for k, v := range base {
out[k] = v
}
for k, v := range overlay {
out[k] = v
}
return out
}
// BaseAPIHandler contains the handlers for API endpoints.
// It holds a pool of clients to interact with the backend service and manages
// load balancing, client selection, and configuration.
@@ -103,13 +217,39 @@ func (h *BaseAPIHandler) GetAlt(c *gin.Context) string {
// Parameters:
// - handler: The API handler associated with the request.
// - c: The Gin context of the current request.
// - ctx: The parent context.
// - ctx: The parent context (caller values/deadlines are preserved; request context adds cancellation and request ID).
//
// Returns:
// - context.Context: The new context with cancellation and embedded values.
// - APIHandlerCancelFunc: A function to cancel the context and log the response.
func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *gin.Context, ctx context.Context) (context.Context, APIHandlerCancelFunc) {
newCtx, cancel := context.WithCancel(ctx)
parentCtx := ctx
if parentCtx == nil {
parentCtx = context.Background()
}
var requestCtx context.Context
if c != nil && c.Request != nil {
requestCtx = c.Request.Context()
}
if requestCtx != nil && logging.GetRequestID(parentCtx) == "" {
if requestID := logging.GetRequestID(requestCtx); requestID != "" {
parentCtx = logging.WithRequestID(parentCtx, requestID)
} else if requestID := logging.GetGinRequestID(c); requestID != "" {
parentCtx = logging.WithRequestID(parentCtx, requestID)
}
}
newCtx, cancel := context.WithCancel(parentCtx)
if requestCtx != nil && requestCtx != parentCtx {
go func() {
select {
case <-requestCtx.Done():
cancel()
case <-newCtx.Done():
}
}()
}
newCtx = context.WithValue(newCtx, "gin", c)
newCtx = context.WithValue(newCtx, "handler", handler)
return newCtx, func(params ...interface{}) {
@@ -182,6 +322,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
if errMsg != nil {
return nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
req := coreexecutor.Request{
Model: normalizedModel,
Payload: cloneBytes(rawJSON),
@@ -195,9 +336,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType),
}
if cloned := cloneMetadata(metadata); cloned != nil {
opts.Metadata = cloned
}
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
if err != nil {
status := http.StatusInternalServerError
@@ -224,6 +363,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
if errMsg != nil {
return nil, errMsg
}
reqMeta := requestExecutionMetadata(ctx)
req := coreexecutor.Request{
Model: normalizedModel,
Payload: cloneBytes(rawJSON),
@@ -237,9 +377,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType),
}
if cloned := cloneMetadata(metadata); cloned != nil {
opts.Metadata = cloned
}
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
if err != nil {
status := http.StatusInternalServerError
@@ -269,6 +407,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
close(errChan)
return nil, errChan
}
reqMeta := requestExecutionMetadata(ctx)
req := coreexecutor.Request{
Model: normalizedModel,
Payload: cloneBytes(rawJSON),
@@ -282,9 +421,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
OriginalRequest: cloneBytes(rawJSON),
SourceFormat: sdktranslator.FromString(handlerType),
}
if cloned := cloneMetadata(metadata); cloned != nil {
opts.Metadata = cloned
}
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if err != nil {
errChan := make(chan *interfaces.ErrorMessage, 1)
@@ -309,31 +446,94 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
go func() {
defer close(dataChan)
defer close(errChan)
for chunk := range chunks {
if chunk.Err != nil {
status := http.StatusInternalServerError
if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon}
return
sentPayload := false
bootstrapRetries := 0
maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg)
bootstrapEligible := func(err error) bool {
status := statusFromError(err)
if status == 0 {
return true
}
if len(chunk.Payload) > 0 {
dataChan <- cloneBytes(chunk.Payload)
switch status {
case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired,
http.StatusRequestTimeout, http.StatusTooManyRequests:
return true
default:
return status >= http.StatusInternalServerError
}
}
outer:
for {
for {
var chunk coreexecutor.StreamChunk
var ok bool
if ctx != nil {
select {
case <-ctx.Done():
return
case chunk, ok = <-chunks:
}
} else {
chunk, ok = <-chunks
}
if !ok {
return
}
if chunk.Err != nil {
streamErr := chunk.Err
// Safe bootstrap recovery: if the upstream fails before any payload bytes are sent,
// retry a few times (to allow auth rotation / transient recovery) and then attempt model fallback.
if !sentPayload {
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
bootstrapRetries++
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
if retryErr == nil {
chunks = retryChunks
continue outer
}
streamErr = retryErr
}
}
status := http.StatusInternalServerError
if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
status = code
}
}
var addon http.Header
if he, ok := streamErr.(interface{ Headers() http.Header }); ok && he != nil {
if hdr := he.Headers(); hdr != nil {
addon = hdr.Clone()
}
}
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon}
return
}
if len(chunk.Payload) > 0 {
sentPayload = true
dataChan <- cloneBytes(chunk.Payload)
}
}
}
}()
return dataChan, errChan
}
func statusFromError(err error) int {
if err == nil {
return 0
}
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
if code := se.StatusCode(); code > 0 {
return code
}
}
return 0
}
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
// Resolve "auto" model to an actual available model first
resolvedModelName := util.ResolveAutoModel(modelName)
@@ -417,38 +617,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
}
}
// Prefer preserving upstream JSON error bodies when possible.
buildJSONBody := func() []byte {
trimmed := strings.TrimSpace(errText)
if trimmed != "" && json.Valid([]byte(trimmed)) {
return []byte(trimmed)
}
errType := "invalid_request_error"
switch status {
case http.StatusUnauthorized:
errType = "authentication_error"
case http.StatusForbidden:
errType = "permission_error"
case http.StatusTooManyRequests:
errType = "rate_limit_error"
default:
if status >= http.StatusInternalServerError {
errType = "server_error"
}
}
payload, err := json.Marshal(ErrorResponse{
Error: ErrorDetail{
Message: errText,
Type: errType,
},
})
if err != nil {
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText))
}
return payload
}
body := buildJSONBody()
body := BuildErrorResponseBody(status, errText)
c.Set("API_RESPONSE", bytes.Clone(body))
if !c.Writer.Written() {

View File

@@ -0,0 +1,124 @@
package handlers
import (
"context"
"net/http"
"sync"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
type failOnceStreamExecutor struct {
mu sync.Mutex
calls int
}
func (e *failOnceStreamExecutor) Identifier() string { return "codex" }
func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
e.mu.Lock()
e.calls++
call := e.calls
e.mu.Unlock()
ch := make(chan coreexecutor.StreamChunk, 1)
if call == 1 {
ch <- coreexecutor.StreamChunk{
Err: &coreauth.Error{
Code: "unauthorized",
Message: "unauthorized",
Retryable: false,
HTTPStatus: http.StatusUnauthorized,
},
}
close(ch)
return ch, nil
}
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
close(ch)
return ch, nil
}
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}
func (e *failOnceStreamExecutor) Calls() int {
e.mu.Lock()
defer e.mu.Unlock()
return e.calls
}
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
executor := &failOnceStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
auth2 := &coreauth.Auth{
ID: "auth2",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test2@example.com"},
}
if _, err := manager.Register(context.Background(), auth2); err != nil {
t.Fatalf("manager.Register(auth2): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
}, manager)
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
for msg := range errChan {
if msg != nil {
t.Fatalf("unexpected error: %+v", msg)
}
}
if string(got) != "ok" {
t.Fatalf("expected payload ok, got %q", string(got))
}
if executor.Calls() != 2 {
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
}
}

View File

@@ -11,7 +11,7 @@ import (
"encoding/json"
"fmt"
"net/http"
"time"
"sync"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
@@ -443,11 +443,6 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
@@ -463,7 +458,55 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Peek at the first chunk to determine success or failure before setting headers
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case errMsg, ok := <-errChan:
if !ok {
// Err channel closed cleanly; wait for data channel.
errChan = nil
continue
}
// Upstream failed immediately. Return proper error status and JSON.
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
// Stream closed without data; send headers and [DONE].
setSSEHeaders()
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel(nil)
return
}
// Success! Commit to streaming headers.
setSSEHeaders()
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
flusher.Flush()
// Continue streaming the rest
h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
return
}
}
}
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
@@ -500,11 +543,6 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
@@ -524,71 +562,109 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Peek at the first chunk
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case chunk, isOk := <-dataChan:
if !isOk {
case errMsg, ok := <-errChan:
if !ok {
// Err channel closed cleanly; wait for data channel.
errChan = nil
continue
}
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
setSSEHeaders()
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cliCancel()
cliCancel(nil)
return
}
// Success! Set headers.
setSSEHeaders()
// Write the first chunk
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
if converted != nil {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted))
flusher.Flush()
}
case errMsg, isOk := <-errChan:
if !isOk {
continue
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cliCancel(execErr)
done := make(chan struct{})
var doneOnce sync.Once
stop := func() { doneOnce.Do(func() { close(done) }) }
convertedChan := make(chan []byte)
go func() {
defer close(convertedChan)
for {
select {
case <-done:
return
case chunk, ok := <-dataChan:
if !ok {
return
}
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
if converted == nil {
continue
}
select {
case <-done:
return
case convertedChan <- converted:
}
}
}
}()
h.handleStreamResult(c, flusher, func(err error) {
stop()
cliCancel(err)
}, convertedChan, errChan)
return
case <-time.After(500 * time.Millisecond):
}
}
}
func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
if !ok {
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
flusher.Flush()
cancel(nil)
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
WriteChunk: func(chunk []byte) {
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
},
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {
return
}
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
flusher.Flush()
case errMsg, ok := <-errs:
if !ok {
continue
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
if errMsg != nil {
h.WriteErrorResponse(c, errMsg)
flusher.Flush()
errText := http.StatusText(status)
if errMsg.Error != nil && errMsg.Error.Error() != "" {
errText = errMsg.Error.Error()
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
return
case <-time.After(500 * time.Millisecond):
}
}
body := handlers.BuildErrorResponseBody(status, errText)
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body))
},
WriteDone: func() {
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
},
})
}

View File

@@ -11,7 +11,6 @@ import (
"context"
"fmt"
"net/http"
"time"
"github.com/gin-gonic/gin"
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
@@ -128,11 +127,6 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
// - c: The Gin context containing the HTTP request and response
// - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request
func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
// Get the http.Flusher interface to manually flush the response.
flusher, ok := c.Writer.(http.Flusher)
if !ok {
@@ -149,46 +143,88 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
modelName := gjson.GetBytes(rawJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
setSSEHeaders := func() {
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Header("Access-Control-Allow-Origin", "*")
}
// Peek at the first result so an immediate upstream failure can still be
// returned as a proper HTTP status with a JSON body, before any SSE headers
// are committed.
for {
select {
case <-c.Request.Context().Done():
cliCancel(c.Request.Context().Err())
return
case errMsg, ok := <-errChan:
if !ok {
// Err channel closed cleanly; wait for data channel.
errChan = nil
continue
}
// Upstream failed immediately. Return proper error status and JSON.
h.WriteErrorResponse(c, errMsg)
if errMsg != nil {
cliCancel(errMsg.Error)
} else {
cliCancel(nil)
}
return
case chunk, ok := <-dataChan:
if !ok {
// Stream closed without data: send headers and finish.
setSSEHeaders()
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
cliCancel(nil)
return
}
// Success: commit the SSE headers, then write the first chunk
// (matching the forwardResponsesStream framing).
setSSEHeaders()
if bytes.HasPrefix(chunk, []byte("event:")) {
_, _ = c.Writer.Write([]byte("\n"))
}
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n"))
flusher.Flush()
// Hand the remainder of the stream to the shared forwarding loop.
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
return
case <-time.After(500 * time.Millisecond):
}
}
}
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
WriteChunk: func(chunk []byte) {
if bytes.HasPrefix(chunk, []byte("event:")) {
_, _ = c.Writer.Write([]byte("\n"))
}
_, _ = c.Writer.Write(chunk)
_, _ = c.Writer.Write([]byte("\n"))
},
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
if errMsg == nil {
return
}
status := http.StatusInternalServerError
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
}
errText := http.StatusText(status)
if errMsg.Error != nil && errMsg.Error.Error() != "" {
errText = errMsg.Error.Error()
}
body := handlers.BuildErrorResponseBody(status, errText)
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
},
WriteDone: func() {
_, _ = c.Writer.Write([]byte("\n"))
},
})
}
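The two forwarders frame SSE differently: the Chat Completions path wraps every chunk as a bare `data:` record, while the Responses path relays upstream lines that already carry `event:`/`data:` prefixes and only inserts a blank-line separator before each new event block. A hedged, standalone sketch of the two write styles (helper names are illustrative, not from this repository):

package main

import (
	"bytes"
	"fmt"
	"os"
)

// writeDataRecord frames a chunk as a bare SSE data record (Chat Completions style).
func writeDataRecord(w *bytes.Buffer, chunk []byte) {
	fmt.Fprintf(w, "data: %s\n\n", chunk)
}

// writeEventRecord relays a pre-framed Responses line; a chunk opening a new
// "event:" block gets a blank-line separator first.
func writeEventRecord(w *bytes.Buffer, chunk []byte) {
	if bytes.HasPrefix(chunk, []byte("event:")) {
		w.WriteByte('\n')
	}
	w.Write(chunk)
	w.WriteByte('\n')
}

func main() {
	var buf bytes.Buffer
	writeDataRecord(&buf, []byte(`{"id":1}`))
	writeEventRecord(&buf, []byte("event: response.output_text.delta"))
	writeEventRecord(&buf, []byte(`data: {"delta":"hi"}`))
	os.Stdout.WriteString(buf.String())
}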

View File

@@ -0,0 +1,121 @@
package handlers
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
)
type StreamForwardOptions struct {
// KeepAliveInterval overrides the configured streaming keep-alive interval.
// If nil, the configured default is used. If set to <= 0, keep-alives are disabled.
KeepAliveInterval *time.Duration
// WriteChunk writes a single data chunk to the response body. It should not flush.
WriteChunk func(chunk []byte)
// WriteTerminalError writes an error payload to the response body when streaming fails
// after headers have already been committed. It should not flush.
WriteTerminalError func(errMsg *interfaces.ErrorMessage)
// WriteDone optionally writes a terminal marker when the upstream data channel closes
// without an error (e.g. OpenAI's `[DONE]`). It should not flush.
WriteDone func()
// WriteKeepAlive optionally writes a keep-alive heartbeat. It should not flush.
// When nil, a standard SSE comment heartbeat is used.
WriteKeepAlive func()
}
func (h *BaseAPIHandler) ForwardStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, opts StreamForwardOptions) {
if c == nil {
return
}
if cancel == nil {
return
}
writeChunk := opts.WriteChunk
if writeChunk == nil {
writeChunk = func([]byte) {}
}
writeKeepAlive := opts.WriteKeepAlive
if writeKeepAlive == nil {
writeKeepAlive = func() {
_, _ = c.Writer.Write([]byte(": keep-alive\n\n"))
}
}
keepAliveInterval := StreamingKeepAliveInterval(h.Cfg)
if opts.KeepAliveInterval != nil {
keepAliveInterval = *opts.KeepAliveInterval
}
var keepAlive *time.Ticker
var keepAliveC <-chan time.Time
if keepAliveInterval > 0 {
keepAlive = time.NewTicker(keepAliveInterval)
defer keepAlive.Stop()
keepAliveC = keepAlive.C
}
var terminalErr *interfaces.ErrorMessage
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return
case chunk, ok := <-data:
if !ok {
// Prefer surfacing a terminal error if one is pending.
if terminalErr == nil {
select {
case errMsg, ok := <-errs:
if ok && errMsg != nil {
terminalErr = errMsg
}
default:
}
}
if terminalErr != nil {
if opts.WriteTerminalError != nil {
opts.WriteTerminalError(terminalErr)
}
flusher.Flush()
cancel(terminalErr.Error)
return
}
if opts.WriteDone != nil {
opts.WriteDone()
}
flusher.Flush()
cancel(nil)
return
}
writeChunk(chunk)
flusher.Flush()
case errMsg, ok := <-errs:
if !ok {
// Error channel closed cleanly; stop selecting on it so the closed
// channel cannot spin this loop.
errs = nil
continue
}
if errMsg != nil {
terminalErr = errMsg
if opts.WriteTerminalError != nil {
opts.WriteTerminalError(errMsg)
flusher.Flush()
}
}
var execErr error
if errMsg != nil {
execErr = errMsg.Error
}
cancel(execErr)
return
case <-keepAliveC:
writeKeepAlive()
flusher.Flush()
}
}
}
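With this helper, an endpoint only supplies protocol-specific writers; cancellation, terminal-error preference, and heartbeats are owned by `ForwardStream`. A hedged usage sketch in the same `handlers` package (the `streamCompletions` wrapper is hypothetical and assumes `fmt` is added to the file's imports; the other types come from this diff):

// Hypothetical caller wiring StreamForwardOptions for a "data:"-framed SSE
// endpoint with a 15s keep-alive override.
func streamCompletions(h *BaseAPIHandler, c *gin.Context, flusher http.Flusher,
	cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
	interval := 15 * time.Second // overrides the configured keep-alive default
	h.ForwardStream(c, flusher, cancel, data, errs, StreamForwardOptions{
		KeepAliveInterval: &interval,
		WriteChunk: func(chunk []byte) {
			_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", chunk)
		},
		WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
			if errMsg == nil {
				return
			}
			status := http.StatusInternalServerError
			if errMsg.StatusCode > 0 {
				status = errMsg.StatusCode
			}
			msg := http.StatusText(status)
			if errMsg.Error != nil && errMsg.Error.Error() != "" {
				msg = errMsg.Error.Error()
			}
			_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", BuildErrorResponseBody(status, msg))
		},
		WriteDone: func() { _, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n") },
	})
}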

View File

@@ -12,6 +12,7 @@ import (
"time"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
@@ -202,10 +203,10 @@ func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
if auth == nil {
return nil, nil
}
if auth.ID == "" {
auth.ID = uuid.NewString()
}
auth.EnsureIndex()
m.mu.Lock()
m.auths[auth.ID] = auth.Clone()
m.mu.Unlock()
@@ -220,7 +221,7 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
return nil, nil
}
m.mu.Lock()
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == "" {
auth.Index = existing.Index
auth.indexAssigned = existing.indexAssigned
}
@@ -262,7 +263,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
rotated := m.rotateProviders(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
@@ -301,7 +301,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
rotated := m.rotateProviders(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
@@ -340,7 +339,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
}
rotated := m.rotateProviders(req.Model, normalized)
retryTimes, maxWait := m.retrySettings()
attempts := retryTimes + 1
@@ -389,17 +387,18 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
accountType, accountInfo := auth.AccountInfo()
proxyInfo := auth.ProxyInfo()
entry := logEntryWithRequestID(ctx)
if accountType == "api_key" {
if proxyInfo != "" {
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
} else {
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
}
} else if accountType == "oauth" {
if proxyInfo != "" {
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
} else {
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
}
}
@@ -449,17 +448,18 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
accountType, accountInfo := auth.AccountInfo()
proxyInfo := auth.ProxyInfo()
entry := logEntryWithRequestID(ctx)
if accountType == "api_key" {
if proxyInfo != "" {
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
} else {
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
}
} else if accountType == "oauth" {
if proxyInfo != "" {
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
} else {
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
}
}
@@ -509,17 +509,18 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
accountType, accountInfo := auth.AccountInfo()
proxyInfo := auth.ProxyInfo()
entry := logEntryWithRequestID(ctx)
if accountType == "api_key" {
if proxyInfo != "" {
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
} else {
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
}
} else if accountType == "oauth" {
if proxyInfo != "" {
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
} else {
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
}
}
@@ -636,13 +637,20 @@ func (m *Manager) normalizeProviders(providers []string) []string {
return result
}
// rotateProviders returns a rotated view of the providers list starting from the
// current offset for the model, and atomically increments the offset for the next call.
// This ensures concurrent requests get different starting providers.
func (m *Manager) rotateProviders(model string, providers []string) []string {
if len(providers) == 0 {
return nil
}
// Atomic read-and-increment: get current offset and advance cursor in one lock
m.mu.Lock()
offset := m.providerOffsets[model]
m.providerOffsets[model] = (offset + 1) % len(providers)
m.mu.Unlock()
if len(providers) > 0 {
offset %= len(providers)
}
@@ -658,19 +666,6 @@ func (m *Manager) rotateProviders(model string, providers []string) []string {
return rotated
}
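Folding the cursor advance into the same critical section removes the read-then-advance race: two concurrent requests for one model can no longer observe the same offset. A standalone sketch of the atomic read-and-increment rotation (simplified; no cursor cleanup for removed models):

package main

import (
	"fmt"
	"sync"
)

type rotator struct {
	mu      sync.Mutex
	offsets map[string]int
}

// rotate returns providers rotated by the model's cursor and advances the
// cursor under the same lock, so concurrent callers start at different providers.
func (r *rotator) rotate(model string, providers []string) []string {
	if len(providers) == 0 {
		return nil
	}
	r.mu.Lock()
	offset := r.offsets[model] % len(providers)
	r.offsets[model] = offset + 1
	r.mu.Unlock()
	rotated := make([]string, 0, len(providers))
	rotated = append(rotated, providers[offset:]...)
	rotated = append(rotated, providers[:offset]...)
	return rotated
}

func main() {
	r := &rotator{offsets: map[string]int{}}
	ps := []string{"gemini", "claude", "openai"}
	for i := 0; i < 4; i++ {
		fmt.Println(r.rotate("m", ps)) // starts at gemini, then claude, openai, gemini...
	}
}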
func (m *Manager) retrySettings() (int, time.Duration) {
if m == nil {
return 0, 0
@@ -1604,6 +1599,17 @@ type RequestPreparer interface {
PrepareRequest(req *http.Request, auth *Auth) error
}
// logEntryWithRequestID returns a logrus entry with request_id field if available in context.
func logEntryWithRequestID(ctx context.Context) *log.Entry {
if ctx == nil {
return log.NewEntry(log.StandardLogger())
}
if reqID := logging.GetRequestID(ctx); reqID != "" {
return log.WithField("request_id", reqID)
}
return log.NewEntry(log.StandardLogger())
}
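Call sites then log through the returned entry so every line carries the request's correlation id. A standalone sketch of the same idea, with a stand-in for `logging.GetRequestID` (the context key and `entryFor` helper are illustrative):

package main

import (
	"context"

	log "github.com/sirupsen/logrus"
)

type requestIDKey struct{}

// getRequestID stands in for logging.GetRequestID from this diff.
func getRequestID(ctx context.Context) string {
	if v, ok := ctx.Value(requestIDKey{}).(string); ok {
		return v
	}
	return ""
}

// entryFor mirrors logEntryWithRequestID: a plain entry without a context,
// a request_id-tagged entry when the context carries one.
func entryFor(ctx context.Context) *log.Entry {
	if ctx == nil {
		return log.NewEntry(log.StandardLogger())
	}
	if id := getRequestID(ctx); id != "" {
		return log.WithField("request_id", id)
	}
	return log.NewEntry(log.StandardLogger())
}

func main() {
	log.SetLevel(log.DebugLevel)
	ctx := context.WithValue(context.Background(), requestIDKey{}, "req-123")
	entryFor(ctx).Debugf("Use API key %s for model %s", "sk-***", "gemini-3-pro-preview")
}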
// InjectCredentials delegates per-provider HTTP request preparation when supported.
// If the registered executor for the auth provider implements RequestPreparer,
// it will be invoked to modify the request (e.g., add headers).

View File

@@ -1,11 +1,12 @@
package auth
import (
"crypto/sha256"
"encoding/hex"
"encoding/json"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
@@ -15,8 +16,8 @@ import (
type Auth struct {
// ID uniquely identifies the auth record across restarts.
ID string `json:"id"`
// Index is a stable runtime identifier derived from auth metadata (not persisted).
Index string `json:"-"`
// Provider is the upstream provider key (e.g. "gemini", "claude").
Provider string `json:"provider"`
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
@@ -94,12 +95,6 @@ type ModelState struct {
UpdatedAt time.Time `json:"updated_at"`
}
// Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation.
func (a *Auth) Clone() *Auth {
if a == nil {
@@ -128,15 +123,41 @@ func (a *Auth) Clone() *Auth {
return &copyAuth
}
func stableAuthIndex(seed string) string {
seed = strings.TrimSpace(seed)
if seed == "" {
return ""
}
sum := sha256.Sum256([]byte(seed))
return hex.EncodeToString(sum[:8])
}
// EnsureIndex returns a stable index derived from the auth file name or API key.
func (a *Auth) EnsureIndex() string {
if a == nil {
return ""
}
if a.indexAssigned && a.Index != "" {
return a.Index
}
seed := strings.TrimSpace(a.FileName)
if seed != "" {
seed = "file:" + seed
} else if a.Attributes != nil {
if apiKey := strings.TrimSpace(a.Attributes["api_key"]); apiKey != "" {
seed = "api_key:" + apiKey
}
}
if seed == "" {
if id := strings.TrimSpace(a.ID); id != "" {
seed = "id:" + id
} else {
return ""
}
}
idx := stableAuthIndex(seed)
a.Index = idx
a.indexAssigned = true
return idx
}
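Because the index is the first 8 bytes of a SHA-256 over a deterministic seed, the same auth file yields the same index across restarts, which is what keeps usage records addressable after a reload. A quick standalone check of the derivation (same scheme as `stableAuthIndex` above; the seeds are made up):

package main

import (
	"crypto/sha256"
	"encoding/hex"
	"fmt"
)

func stableIndex(seed string) string {
	sum := sha256.Sum256([]byte(seed))
	return hex.EncodeToString(sum[:8]) // 16 hex chars, stable across restarts
}

func main() {
	fmt.Println(stableIndex("file:gemini-user1.json"))
	fmt.Println(stableIndex("file:gemini-user1.json")) // identical output
	fmt.Println(stableIndex("api_key:sk-abc"))         // different seed, different index
}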

View File

@@ -14,7 +14,7 @@ type Record struct {
Model string
APIKey string
AuthID string
AuthIndex string
Source string
RequestedAt time.Time
Failed bool

View File

@@ -12,6 +12,7 @@ type AccessProvider = internalconfig.AccessProvider
type Config = internalconfig.Config
type StreamingConfig = internalconfig.StreamingConfig
type TLSConfig = internalconfig.TLSConfig
type RemoteManagement = internalconfig.RemoteManagement
type AmpCode = internalconfig.AmpCode