mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-04 05:20:52 +08:00
Compare commits
45 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
33e53a2a56 | ||
|
|
cd5b80785f | ||
|
|
54f71aa273 | ||
|
|
3f949b7f84 | ||
|
|
443c4538bb | ||
|
|
a7fc2ee4cf | ||
|
|
8e749ac22d | ||
|
|
69e09d9bc7 | ||
|
|
06ad527e8c | ||
|
|
b7409dd2de | ||
|
|
5ba325a8fc | ||
|
|
d502840f91 | ||
|
|
99238a4b59 | ||
|
|
6d43a2ff9a | ||
|
|
3faa1ca9af | ||
|
|
9d975e0375 | ||
|
|
2a6d8b78d4 | ||
|
|
671558a822 | ||
|
|
26fbb77901 | ||
|
|
a277302262 | ||
|
|
969c1a5b72 | ||
|
|
872339bceb | ||
|
|
5dc0dbc7aa | ||
|
|
2b7ba54a2f | ||
|
|
007c3304f2 | ||
|
|
e76ba0ede9 | ||
|
|
c06ac07e23 | ||
|
|
66769ec657 | ||
|
|
f413feec61 | ||
|
|
2e538e3486 | ||
|
|
9617a7b0d6 | ||
|
|
7569320770 | ||
|
|
8d25cf0d75 | ||
|
|
64e85e7019 | ||
|
|
0c0aae1eac | ||
|
|
5dcf7cb846 | ||
|
|
5bf89dd757 | ||
|
|
4442574e53 | ||
|
|
71a6dffbb6 | ||
|
|
24e8e20b59 | ||
|
|
a87f09bad2 | ||
|
|
bc6c4cdbfc | ||
|
|
404546ce93 | ||
|
|
6dd1cf1dd6 | ||
|
|
9058d406a3 |
@@ -10,11 +10,11 @@ So you can use local or multi-account CLI access with OpenAI(include Responses)/
|
||||
|
||||
## Sponsor
|
||||
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
[](https://z.ai/subscribe?ic=8JVLJQFSKB)
|
||||
|
||||
This project is sponsored by Z.ai, supporting us with their GLM CODING PLAN.
|
||||
|
||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.6 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
||||
GLM CODING PLAN is a subscription service designed for AI coding, starting at just $3/month. It provides access to their flagship GLM-4.7 model across 10+ popular AI coding tools (Claude Code, Cline, Roo Code, etc.), offering developers top-tier, fast, and stable coding experiences.
|
||||
|
||||
Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB
|
||||
|
||||
@@ -26,6 +26,10 @@ Get 10% OFF GLM CODING PLAN:https://z.ai/subscribe?ic=8JVLJQFSKB
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>Thanks to PackyCode for sponsoring this project! PackyCode is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. PackyCode provides special discounts for our software users: register using <a href="https://www.packyapi.com/register?aff=cliproxyapi">this link</a> and enter the "cliproxyapi" promo code during recharge to get 10% off.</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
|
||||
<td>Thanks to Cubence for sponsoring this project! Cubence is a reliable and efficient API relay service provider, offering relay services for Claude Code, Codex, Gemini, and more. Cubence provides special discounts for our software users: register using <a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">this link</a> and enter the "CLIPROXYAPI" promo code during recharge to get 10% off.</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
|
||||
@@ -10,11 +10,11 @@
|
||||
|
||||
## 赞助商
|
||||
|
||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
||||
[](https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII)
|
||||
|
||||
本项目由 Z智谱 提供赞助, 他们通过 GLM CODING PLAN 对本项目提供技术支持。
|
||||
|
||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.6,为开发者提供顶尖的编码体验。
|
||||
GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元,即可在十余款主流AI编码工具如 Claude Code、Cline、Roo Code 中畅享智谱旗舰模型GLM-4.7,为开发者提供顶尖的编码体验。
|
||||
|
||||
智谱AI为本软件提供了特别优惠,使用以下链接购买可以享受九折优惠:https://www.bigmodel.cn/claude-code?ic=RRVJPB5SII
|
||||
|
||||
@@ -26,9 +26,14 @@ GLM CODING PLAN 是专为AI编码打造的订阅套餐,每月最低仅需20元
|
||||
<td width="180"><a href="https://www.packyapi.com/register?aff=cliproxyapi"><img src="./assets/packycode.png" alt="PackyCode" width="150"></a></td>
|
||||
<td>感谢 PackyCode 对本项目的赞助!PackyCode 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。PackyCode 为本软件用户提供了特别优惠:使用<a href="https://www.packyapi.com/register?aff=cliproxyapi">此链接</a>注册,并在充值时输入 "cliproxyapi" 优惠码即可享受九折优惠。</td>
|
||||
</tr>
|
||||
<tr>
|
||||
<td width="180"><a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa"><img src="./assets/cubence.png" alt="Cubence" width="150"></a></td>
|
||||
<td>感谢 Cubence 对本项目的赞助!Cubence 是一家可靠高效的 API 中转服务商,提供 Claude Code、Codex、Gemini 等多种服务的中转。Cubence 为本软件用户提供了特别优惠:使用<a href="https://cubence.com/signup?code=CLIPROXYAPI&source=cpa">此链接</a>注册,并在充值时输入 "CLIPROXYAPI" 优惠码即可享受九折优惠。</td>
|
||||
</tr>
|
||||
</tbody>
|
||||
</table>
|
||||
|
||||
|
||||
## 功能特性
|
||||
|
||||
- 为 CLI 模型提供 OpenAI/Gemini/Claude/Codex 兼容的 API 端点
|
||||
|
||||
BIN
assets/cubence.png
Normal file
BIN
assets/cubence.png
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 51 KiB |
@@ -39,6 +39,9 @@ api-keys:
|
||||
# Enable debug logging
|
||||
debug: false
|
||||
|
||||
# When true, disable high-overhead HTTP middleware features to reduce per-request memory usage under high concurrency.
|
||||
commercial-mode: false
|
||||
|
||||
# When true, write application logs to rotating files instead of stdout
|
||||
logging-to-file: false
|
||||
|
||||
@@ -73,6 +76,11 @@ routing:
|
||||
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||
ws-auth: false
|
||||
|
||||
# Streaming behavior (SSE keep-alives + safe bootstrap retries).
|
||||
# streaming:
|
||||
# keepalive-seconds: 15 # Default: 0 (disabled). <= 0 disables keep-alives.
|
||||
# bootstrap-retries: 1 # Default: 0 (disabled). Retries before first byte is sent.
|
||||
|
||||
# Gemini API keys
|
||||
# gemini-api-key:
|
||||
# - api-key: "AIzaSy...01"
|
||||
|
||||
@@ -209,6 +209,94 @@ func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
|
||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// GetRequestLogByID finds and downloads a request log file by its request ID.
|
||||
// The ID is matched against the suffix of log file names (format: *-{requestID}.log).
|
||||
func (h *Handler) GetRequestLogByID(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
dir := h.logDirectory()
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
requestID := strings.TrimSpace(c.Param("id"))
|
||||
if requestID == "" {
|
||||
requestID = strings.TrimSpace(c.Query("id"))
|
||||
}
|
||||
if requestID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "missing request ID"})
|
||||
return
|
||||
}
|
||||
if strings.ContainsAny(requestID, "/\\") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log directory not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list log directory: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
suffix := "-" + requestID + ".log"
|
||||
var matchedFile string
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if strings.HasSuffix(name, suffix) {
|
||||
matchedFile = name
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if matchedFile == "" {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found for the given request ID"})
|
||||
return
|
||||
}
|
||||
|
||||
dirAbs, errAbs := filepath.Abs(dir)
|
||||
if errAbs != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
|
||||
return
|
||||
}
|
||||
fullPath := filepath.Clean(filepath.Join(dirAbs, matchedFile))
|
||||
prefix := dirAbs + string(os.PathSeparator)
|
||||
if !strings.HasPrefix(fullPath, prefix) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
|
||||
return
|
||||
}
|
||||
|
||||
info, errStat := os.Stat(fullPath)
|
||||
if errStat != nil {
|
||||
if os.IsNotExist(errStat) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
|
||||
return
|
||||
}
|
||||
if info.IsDir() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
|
||||
return
|
||||
}
|
||||
|
||||
c.FileAttachment(fullPath, matchedFile)
|
||||
}
|
||||
|
||||
// DownloadRequestErrorLog downloads a specific error request log file by name.
|
||||
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
|
||||
if h == nil {
|
||||
|
||||
@@ -98,10 +98,11 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
}
|
||||
|
||||
return &RequestInfo{
|
||||
URL: url,
|
||||
Method: method,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
URL: url,
|
||||
Method: method,
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
RequestID: logging.GetGinRequestID(c),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -15,10 +15,11 @@ import (
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
URL string // URL is the request URL.
|
||||
Method string // Method is the HTTP method (e.g., GET, POST).
|
||||
Headers map[string][]string // Headers contains the request headers.
|
||||
Body []byte // Body is the raw request body.
|
||||
URL string // URL is the request URL.
|
||||
Method string // Method is the HTTP method (e.g., GET, POST).
|
||||
Headers map[string][]string // Headers contains the request headers.
|
||||
Body []byte // Body is the raw request body.
|
||||
RequestID string // RequestID is the unique identifier for the request.
|
||||
}
|
||||
|
||||
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
|
||||
@@ -149,6 +150,7 @@ func (w *ResponseWriterWrapper) WriteHeader(statusCode int) {
|
||||
w.requestInfo.Method,
|
||||
w.requestInfo.Headers,
|
||||
w.requestInfo.Body,
|
||||
w.requestInfo.RequestID,
|
||||
)
|
||||
if err == nil {
|
||||
w.streamWriter = streamWriter
|
||||
@@ -346,7 +348,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
|
||||
}); ok {
|
||||
return loggerWithOptions.LogRequestWithOptions(
|
||||
w.requestInfo.URL,
|
||||
@@ -360,6 +362,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiResponseBody,
|
||||
apiResponseErrors,
|
||||
forceLog,
|
||||
w.requestInfo.RequestID,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -374,5 +377,6 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
apiResponseErrors,
|
||||
w.requestInfo.RequestID,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -279,16 +279,23 @@ func (m *AmpModule) hasModelMappingsChanged(old *config.AmpCode, new *config.Amp
|
||||
return true
|
||||
}
|
||||
|
||||
// Build map for efficient comparison
|
||||
oldMap := make(map[string]string, len(old.ModelMappings))
|
||||
// Build map for efficient and robust comparison
|
||||
type mappingInfo struct {
|
||||
to string
|
||||
regex bool
|
||||
}
|
||||
oldMap := make(map[string]mappingInfo, len(old.ModelMappings))
|
||||
for _, mapping := range old.ModelMappings {
|
||||
oldMap[strings.TrimSpace(mapping.From)] = strings.TrimSpace(mapping.To)
|
||||
oldMap[strings.TrimSpace(mapping.From)] = mappingInfo{
|
||||
to: strings.TrimSpace(mapping.To),
|
||||
regex: mapping.Regex,
|
||||
}
|
||||
}
|
||||
|
||||
for _, mapping := range new.ModelMappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
if oldTo, exists := oldMap[from]; !exists || oldTo != to {
|
||||
if oldVal, exists := oldMap[from]; !exists || oldVal.to != to || oldVal.regex != mapping.Regex {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
@@ -26,13 +27,15 @@ type ModelMapper interface {
|
||||
// DefaultModelMapper implements ModelMapper with thread-safe mapping storage.
|
||||
type DefaultModelMapper struct {
|
||||
mu sync.RWMutex
|
||||
mappings map[string]string // from -> to (normalized lowercase keys)
|
||||
mappings map[string]string // exact: from -> to (normalized lowercase keys)
|
||||
regexps []regexMapping // regex rules evaluated in order
|
||||
}
|
||||
|
||||
// NewModelMapper creates a new model mapper with the given initial mappings.
|
||||
func NewModelMapper(mappings []config.AmpModelMapping) *DefaultModelMapper {
|
||||
m := &DefaultModelMapper{
|
||||
mappings: make(map[string]string),
|
||||
regexps: nil,
|
||||
}
|
||||
m.UpdateMappings(mappings)
|
||||
return m
|
||||
@@ -55,7 +58,18 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
// Check for direct mapping
|
||||
targetModel, exists := m.mappings[normalizedRequest]
|
||||
if !exists {
|
||||
return ""
|
||||
// Try regex mappings in order
|
||||
base, _ := util.NormalizeThinkingModel(requestedModel)
|
||||
for _, rm := range m.regexps {
|
||||
if rm.re.MatchString(requestedModel) || (base != "" && rm.re.MatchString(base)) {
|
||||
targetModel = rm.to
|
||||
exists = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// Verify target model has available providers
|
||||
@@ -78,6 +92,7 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
||||
|
||||
// Clear and rebuild mappings
|
||||
m.mappings = make(map[string]string, len(mappings))
|
||||
m.regexps = make([]regexMapping, 0, len(mappings))
|
||||
|
||||
for _, mapping := range mappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
@@ -88,16 +103,30 @@ func (m *DefaultModelMapper) UpdateMappings(mappings []config.AmpModelMapping) {
|
||||
continue
|
||||
}
|
||||
|
||||
// Store with normalized lowercase key for case-insensitive lookup
|
||||
normalizedFrom := strings.ToLower(from)
|
||||
m.mappings[normalizedFrom] = to
|
||||
|
||||
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
||||
if mapping.Regex {
|
||||
// Compile case-insensitive regex; wrap with (?i) to match behavior of exact lookups
|
||||
pattern := "(?i)" + from
|
||||
re, err := regexp.Compile(pattern)
|
||||
if err != nil {
|
||||
log.Warnf("amp model mapping: invalid regex %q: %v", from, err)
|
||||
continue
|
||||
}
|
||||
m.regexps = append(m.regexps, regexMapping{re: re, to: to})
|
||||
log.Debugf("amp model regex mapping registered: /%s/ -> %s", from, to)
|
||||
} else {
|
||||
// Store with normalized lowercase key for case-insensitive lookup
|
||||
normalizedFrom := strings.ToLower(from)
|
||||
m.mappings[normalizedFrom] = to
|
||||
log.Debugf("amp model mapping registered: %s -> %s", from, to)
|
||||
}
|
||||
}
|
||||
|
||||
if len(m.mappings) > 0 {
|
||||
log.Infof("amp model mapping: loaded %d mapping(s)", len(m.mappings))
|
||||
}
|
||||
if n := len(m.regexps); n > 0 {
|
||||
log.Infof("amp model mapping: loaded %d regex mapping(s)", n)
|
||||
}
|
||||
}
|
||||
|
||||
// GetMappings returns a copy of current mappings (for debugging/status).
|
||||
@@ -111,3 +140,8 @@ func (m *DefaultModelMapper) GetMappings() map[string]string {
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
type regexMapping struct {
|
||||
re *regexp.Regexp
|
||||
to string
|
||||
}
|
||||
|
||||
@@ -203,3 +203,81 @@ func TestModelMapper_GetMappings_ReturnsCopy(t *testing.T) {
|
||||
t.Error("Original map was modified")
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_MatchBaseWithoutParens(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-1", "gemini", []*registry.ModelInfo{
|
||||
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-1")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "^gpt-5$", To: "gemini-2.5-pro", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Incoming model has reasoning suffix but should match base via regex
|
||||
result := mapper.MapModel("gpt-5(high)")
|
||||
if result != "gemini-2.5-pro" {
|
||||
t.Errorf("Expected gemini-2.5-pro, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_ExactPrecedence(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-2", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
reg.RegisterClient("test-client-regex-3", "gemini", []*registry.ModelInfo{
|
||||
{ID: "gemini-2.5-pro", OwnedBy: "google", Type: "gemini"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-2")
|
||||
defer reg.UnregisterClient("test-client-regex-3")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "gpt-5", To: "claude-sonnet-4"}, // exact
|
||||
{From: "^gpt-5.*$", To: "gemini-2.5-pro", Regex: true}, // regex
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
// Exact match should win over regex
|
||||
result := mapper.MapModel("gpt-5")
|
||||
if result != "claude-sonnet-4" {
|
||||
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_InvalidPattern_Skipped(t *testing.T) {
|
||||
// Invalid regex should be skipped and not cause panic
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "(", To: "target", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("anything")
|
||||
if result != "" {
|
||||
t.Errorf("Expected empty result due to invalid regex, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_Regex_CaseInsensitive(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-regex-4", "claude", []*registry.ModelInfo{
|
||||
{ID: "claude-sonnet-4", OwnedBy: "anthropic", Type: "claude"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-regex-4")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "^CLAUDE-OPUS-.*$", To: "claude-sonnet-4", Regex: true},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("claude-opus-4.5")
|
||||
if result != "claude-sonnet-4" {
|
||||
t.Errorf("Expected claude-sonnet-4, got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -209,13 +209,15 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
// Resolve logs directory relative to the configuration file directory.
|
||||
var requestLogger logging.RequestLogger
|
||||
var toggle func(bool)
|
||||
if optionState.requestLoggerFactory != nil {
|
||||
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
||||
}
|
||||
if requestLogger != nil {
|
||||
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
||||
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||
toggle = setter.SetEnabled
|
||||
if !cfg.CommercialMode {
|
||||
if optionState.requestLoggerFactory != nil {
|
||||
requestLogger = optionState.requestLoggerFactory(cfg, configFilePath)
|
||||
}
|
||||
if requestLogger != nil {
|
||||
engine.Use(middleware.RequestLoggingMiddleware(requestLogger))
|
||||
if setter, ok := requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||
toggle = setter.SetEnabled
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -518,6 +520,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
||||
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
||||
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
||||
mgmt.GET("/request-log-by-id/:id", s.mgmt.GetRequestLogByID)
|
||||
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
||||
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
||||
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
||||
|
||||
@@ -39,6 +39,9 @@ type Config struct {
|
||||
// Debug enables or disables debug-level logging and other debug features.
|
||||
Debug bool `yaml:"debug" json:"debug"`
|
||||
|
||||
// CommercialMode disables high-overhead HTTP middleware features to minimize per-request memory usage.
|
||||
CommercialMode bool `yaml:"commercial-mode" json:"commercial-mode"`
|
||||
|
||||
// LoggingToFile controls whether application logs are written to rotating files or stdout.
|
||||
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
|
||||
|
||||
@@ -144,6 +147,11 @@ type AmpModelMapping struct {
|
||||
// To is the target model name to route to (e.g., "claude-sonnet-4").
|
||||
// The target model must have available providers in the registry.
|
||||
To string `yaml:"to" json:"to"`
|
||||
|
||||
// Regex indicates whether the 'from' field should be interpreted as a regular
|
||||
// expression for matching model names. When true, this mapping is evaluated
|
||||
// after exact matches and in the order provided. Defaults to false (exact match).
|
||||
Regex bool `yaml:"regex,omitempty" json:"regex,omitempty"`
|
||||
}
|
||||
|
||||
// AmpCode groups Amp CLI integration settings including upstream routing,
|
||||
|
||||
@@ -22,6 +22,21 @@ type SDKConfig struct {
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
|
||||
// Streaming configures server-side streaming behavior (keep-alives and safe bootstrap retries).
|
||||
Streaming StreamingConfig `yaml:"streaming" json:"streaming"`
|
||||
}
|
||||
|
||||
// StreamingConfig holds server streaming behavior configuration.
|
||||
type StreamingConfig struct {
|
||||
// KeepAliveSeconds controls how often the server emits SSE heartbeats (": keep-alive\n\n").
|
||||
// nil means default (15 seconds). <= 0 disables keep-alives.
|
||||
KeepAliveSeconds *int `yaml:"keepalive-seconds,omitempty" json:"keepalive-seconds,omitempty"`
|
||||
|
||||
// BootstrapRetries controls how many times the server may retry a streaming request before any bytes are sent,
|
||||
// to allow auth rotation / transient recovery.
|
||||
// nil means default (2). 0 disables bootstrap retries.
|
||||
BootstrapRetries *int `yaml:"bootstrap-retries,omitempty" json:"bootstrap-retries,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -14,11 +15,24 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// aiAPIPrefixes defines path prefixes for AI API requests that should have request ID tracking.
|
||||
var aiAPIPrefixes = []string{
|
||||
"/v1/chat/completions",
|
||||
"/v1/completions",
|
||||
"/v1/messages",
|
||||
"/v1/responses",
|
||||
"/v1beta/models/",
|
||||
"/api/provider/",
|
||||
}
|
||||
|
||||
const skipGinLogKey = "__gin_skip_request_logging__"
|
||||
|
||||
// GinLogrusLogger returns a Gin middleware handler that logs HTTP requests and responses
|
||||
// using logrus. It captures request details including method, path, status code, latency,
|
||||
// client IP, and any error messages, formatting them in a Gin-style log format.
|
||||
// client IP, and any error messages. Request ID is only added for AI API requests.
|
||||
//
|
||||
// Output format (AI API): [2025-12-23 20:14:10] [info ] | a1b2c3d4 | 200 | 23.559s | ...
|
||||
// Output format (others): [2025-12-23 20:14:10] [info ] | -------- | 200 | 23.559s | ...
|
||||
//
|
||||
// Returns:
|
||||
// - gin.HandlerFunc: A middleware handler for request logging
|
||||
@@ -28,6 +42,15 @@ func GinLogrusLogger() gin.HandlerFunc {
|
||||
path := c.Request.URL.Path
|
||||
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
|
||||
// Only generate request ID for AI API paths
|
||||
var requestID string
|
||||
if isAIAPIPath(path) {
|
||||
requestID = GenerateRequestID()
|
||||
SetGinRequestID(c, requestID)
|
||||
ctx := WithRequestID(c.Request.Context(), requestID)
|
||||
c.Request = c.Request.WithContext(ctx)
|
||||
}
|
||||
|
||||
c.Next()
|
||||
|
||||
if shouldSkipGinRequestLogging(c) {
|
||||
@@ -49,23 +72,38 @@ func GinLogrusLogger() gin.HandlerFunc {
|
||||
clientIP := c.ClientIP()
|
||||
method := c.Request.Method
|
||||
errorMessage := c.Errors.ByType(gin.ErrorTypePrivate).String()
|
||||
timestamp := time.Now().Format("2006/01/02 - 15:04:05")
|
||||
logLine := fmt.Sprintf("[GIN] %s | %3d | %13v | %15s | %-7s \"%s\"", timestamp, statusCode, latency, clientIP, method, path)
|
||||
|
||||
if requestID == "" {
|
||||
requestID = "--------"
|
||||
}
|
||||
logLine := fmt.Sprintf("%3d | %13v | %15s | %-7s \"%s\"", statusCode, latency, clientIP, method, path)
|
||||
if errorMessage != "" {
|
||||
logLine = logLine + " | " + errorMessage
|
||||
}
|
||||
|
||||
entry := log.WithField("request_id", requestID)
|
||||
|
||||
switch {
|
||||
case statusCode >= http.StatusInternalServerError:
|
||||
log.Error(logLine)
|
||||
entry.Error(logLine)
|
||||
case statusCode >= http.StatusBadRequest:
|
||||
log.Warn(logLine)
|
||||
entry.Warn(logLine)
|
||||
default:
|
||||
log.Info(logLine)
|
||||
entry.Info(logLine)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// isAIAPIPath checks if the given path is an AI API endpoint that should have request ID tracking.
|
||||
func isAIAPIPath(path string) bool {
|
||||
for _, prefix := range aiAPIPrefixes {
|
||||
if strings.HasPrefix(path, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GinLogrusRecovery returns a Gin middleware handler that recovers from panics and logs
|
||||
// them using logrus. When a panic occurs, it captures the panic value, stack trace,
|
||||
// and request path, then returns a 500 Internal Server Error response to the client.
|
||||
|
||||
@@ -24,7 +24,8 @@ var (
|
||||
)
|
||||
|
||||
// LogFormatter defines a custom log format for logrus.
|
||||
// This formatter adds timestamp, level, and source location to each log entry.
|
||||
// This formatter adds timestamp, level, request ID, and source location to each log entry.
|
||||
// Format: [2025-12-23 20:14:04] [debug] [manager.go:524] | a1b2c3d4 | Use API key sk-9...0RHO for model gpt-5.2
|
||||
type LogFormatter struct{}
|
||||
|
||||
// Format renders a single log entry with custom formatting.
|
||||
@@ -39,11 +40,22 @@ func (m *LogFormatter) Format(entry *log.Entry) ([]byte, error) {
|
||||
timestamp := entry.Time.Format("2006-01-02 15:04:05")
|
||||
message := strings.TrimRight(entry.Message, "\r\n")
|
||||
|
||||
reqID := "--------"
|
||||
if id, ok := entry.Data["request_id"].(string); ok && id != "" {
|
||||
reqID = id
|
||||
}
|
||||
|
||||
level := entry.Level.String()
|
||||
if level == "warning" {
|
||||
level = "warn"
|
||||
}
|
||||
levelStr := fmt.Sprintf("%-5s", level)
|
||||
|
||||
var formatted string
|
||||
if entry.Caller != nil {
|
||||
formatted = fmt.Sprintf("[%s] [%s] [%s:%d] %s\n", timestamp, entry.Level, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
|
||||
formatted = fmt.Sprintf("[%s] [%s] [%s] [%s:%d] %s\n", timestamp, reqID, levelStr, filepath.Base(entry.Caller.File), entry.Caller.Line, message)
|
||||
} else {
|
||||
formatted = fmt.Sprintf("[%s] [%s] %s\n", timestamp, entry.Level, message)
|
||||
formatted = fmt.Sprintf("[%s] [%s] [%s] %s\n", timestamp, reqID, levelStr, message)
|
||||
}
|
||||
buffer.WriteString(formatted)
|
||||
|
||||
|
||||
@@ -43,10 +43,11 @@ type RequestLogger interface {
|
||||
// - response: The raw response data
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error
|
||||
|
||||
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
||||
//
|
||||
@@ -55,11 +56,12 @@ type RequestLogger interface {
|
||||
// - method: The HTTP method
|
||||
// - headers: The request headers
|
||||
// - body: The request body
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - StreamingLogWriter: A writer for streaming response chunks
|
||||
// - error: An error if logging initialization fails, nil otherwise
|
||||
LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error)
|
||||
LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error)
|
||||
|
||||
// IsEnabled returns whether request logging is currently enabled.
|
||||
//
|
||||
@@ -177,20 +179,21 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
|
||||
// - response: The raw response data
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false)
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID)
|
||||
}
|
||||
|
||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force)
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID)
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
||||
if !l.enabled && !force {
|
||||
return nil
|
||||
}
|
||||
@@ -200,10 +203,10 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
||||
return fmt.Errorf("failed to create logs directory: %w", errEnsure)
|
||||
}
|
||||
|
||||
// Generate filename
|
||||
filename := l.generateFilename(url)
|
||||
// Generate filename with request ID
|
||||
filename := l.generateFilename(url, requestID)
|
||||
if force && !l.enabled {
|
||||
filename = l.generateErrorFilename(url)
|
||||
filename = l.generateErrorFilename(url, requestID)
|
||||
}
|
||||
filePath := filepath.Join(l.logsDir, filename)
|
||||
|
||||
@@ -271,11 +274,12 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
||||
// - method: The HTTP method
|
||||
// - headers: The request headers
|
||||
// - body: The request body
|
||||
// - requestID: Optional request ID for log file naming
|
||||
//
|
||||
// Returns:
|
||||
// - StreamingLogWriter: A writer for streaming response chunks
|
||||
// - error: An error if logging initialization fails, nil otherwise
|
||||
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte) (StreamingLogWriter, error) {
|
||||
func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[string][]string, body []byte, requestID string) (StreamingLogWriter, error) {
|
||||
if !l.enabled {
|
||||
return &NoOpStreamingLogWriter{}, nil
|
||||
}
|
||||
@@ -285,8 +289,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
||||
return nil, fmt.Errorf("failed to create logs directory: %w", err)
|
||||
}
|
||||
|
||||
// Generate filename
|
||||
filename := l.generateFilename(url)
|
||||
// Generate filename with request ID
|
||||
filename := l.generateFilename(url, requestID)
|
||||
filePath := filepath.Join(l.logsDir, filename)
|
||||
|
||||
requestHeaders := make(map[string][]string, len(headers))
|
||||
@@ -330,8 +334,8 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
||||
}
|
||||
|
||||
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
|
||||
func (l *FileRequestLogger) generateErrorFilename(url string) string {
|
||||
return fmt.Sprintf("error-%s", l.generateFilename(url))
|
||||
func (l *FileRequestLogger) generateErrorFilename(url string, requestID ...string) string {
|
||||
return fmt.Sprintf("error-%s", l.generateFilename(url, requestID...))
|
||||
}
|
||||
|
||||
// ensureLogsDir creates the logs directory if it doesn't exist.
|
||||
@@ -346,13 +350,15 @@ func (l *FileRequestLogger) ensureLogsDir() error {
|
||||
}
|
||||
|
||||
// generateFilename creates a sanitized filename from the URL path and current timestamp.
|
||||
// Format: v1-responses-2025-12-23T195811-a1b2c3d4.log
|
||||
//
|
||||
// Parameters:
|
||||
// - url: The request URL
|
||||
// - requestID: Optional request ID to include in filename
|
||||
//
|
||||
// Returns:
|
||||
// - string: A sanitized filename for the log file
|
||||
func (l *FileRequestLogger) generateFilename(url string) string {
|
||||
func (l *FileRequestLogger) generateFilename(url string, requestID ...string) string {
|
||||
// Extract path from URL
|
||||
path := url
|
||||
if strings.Contains(url, "?") {
|
||||
@@ -368,12 +374,18 @@ func (l *FileRequestLogger) generateFilename(url string) string {
|
||||
sanitized := l.sanitizeForFilename(path)
|
||||
|
||||
// Add timestamp
|
||||
timestamp := time.Now().Format("2006-01-02T150405-.000000000")
|
||||
timestamp = strings.Replace(timestamp, ".", "", -1)
|
||||
timestamp := time.Now().Format("2006-01-02T150405")
|
||||
|
||||
id := requestLogID.Add(1)
|
||||
// Use request ID if provided, otherwise use sequential ID
|
||||
var idPart string
|
||||
if len(requestID) > 0 && requestID[0] != "" {
|
||||
idPart = requestID[0]
|
||||
} else {
|
||||
id := requestLogID.Add(1)
|
||||
idPart = fmt.Sprintf("%d", id)
|
||||
}
|
||||
|
||||
return fmt.Sprintf("%s-%s-%d.log", sanitized, timestamp, id)
|
||||
return fmt.Sprintf("%s-%s-%s.log", sanitized, timestamp, idPart)
|
||||
}
|
||||
|
||||
// sanitizeForFilename replaces characters that are not safe for filenames.
|
||||
|
||||
61
internal/logging/requestid.go
Normal file
61
internal/logging/requestid.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// requestIDKey is the context key for storing/retrieving request IDs.
|
||||
type requestIDKey struct{}
|
||||
|
||||
// ginRequestIDKey is the Gin context key for request IDs.
|
||||
const ginRequestIDKey = "__request_id__"
|
||||
|
||||
// GenerateRequestID creates a new 8-character hex request ID.
|
||||
func GenerateRequestID() string {
|
||||
b := make([]byte, 4)
|
||||
if _, err := rand.Read(b); err != nil {
|
||||
return "00000000"
|
||||
}
|
||||
return hex.EncodeToString(b)
|
||||
}
|
||||
|
||||
// WithRequestID returns a new context with the request ID attached.
|
||||
func WithRequestID(ctx context.Context, requestID string) context.Context {
|
||||
return context.WithValue(ctx, requestIDKey{}, requestID)
|
||||
}
|
||||
|
||||
// GetRequestID retrieves the request ID from the context.
|
||||
// Returns empty string if not found.
|
||||
func GetRequestID(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
if id, ok := ctx.Value(requestIDKey{}).(string); ok {
|
||||
return id
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// SetGinRequestID stores the request ID in the Gin context.
|
||||
func SetGinRequestID(c *gin.Context, requestID string) {
|
||||
if c != nil {
|
||||
c.Set(ginRequestIDKey, requestID)
|
||||
}
|
||||
}
|
||||
|
||||
// GetGinRequestID retrieves the request ID from the Gin context.
|
||||
func GetGinRequestID(c *gin.Context) string {
|
||||
if c == nil {
|
||||
return ""
|
||||
}
|
||||
if id, exists := c.Get(ginRequestIDKey); exists {
|
||||
if s, ok := id.(string); ok {
|
||||
return s
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
@@ -727,6 +727,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "glm-4.7", DisplayName: "GLM-4.7", Description: "Zhipu GLM 4.7 general model", Created: 1766448000, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
@@ -740,6 +741,7 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
||||
{ID: "minimax-m2.1", DisplayName: "MiniMax-M2.1", Description: "MiniMax M2.1", Created: 1766448000},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
@@ -41,12 +42,15 @@ const (
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||
defaultAntigravityAgent = "antigravity/1.104.0 darwin/arm64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
)
|
||||
|
||||
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
var (
|
||||
randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
randSourceMutex sync.Mutex
|
||||
)
|
||||
|
||||
// AntigravityExecutor proxies requests to the antigravity upstream.
|
||||
type AntigravityExecutor struct {
|
||||
@@ -1224,7 +1228,9 @@ func generateRequestID() string {
|
||||
}
|
||||
|
||||
func generateSessionID() string {
|
||||
randSourceMutex.Lock()
|
||||
n := randSource.Int63n(9_000_000_000_000_000_000)
|
||||
randSourceMutex.Unlock()
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
@@ -1248,8 +1254,10 @@ func generateStableSessionID(payload []byte) string {
|
||||
func generateProjectID() string {
|
||||
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
||||
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
||||
randSourceMutex.Lock()
|
||||
adj := adjectives[randSource.Intn(len(adjectives))]
|
||||
noun := nouns[randSource.Intn(len(nouns))]
|
||||
randSourceMutex.Unlock()
|
||||
randomPart := strings.ToLower(uuid.NewString())[:5]
|
||||
return adj + "-" + noun + "-" + randomPart
|
||||
}
|
||||
|
||||
@@ -275,6 +275,20 @@ func parseClaudeStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
return detail, true
|
||||
}
|
||||
|
||||
func parseGeminiFamilyUsageDetail(node gjson.Result) usage.Detail {
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
CachedTokens: node.Get("cachedContentTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail
|
||||
}
|
||||
|
||||
func parseGeminiCLIUsage(data []byte) usage.Detail {
|
||||
usageNode := gjson.ParseBytes(data)
|
||||
node := usageNode.Get("response.usageMetadata")
|
||||
@@ -284,16 +298,7 @@ func parseGeminiCLIUsage(data []byte) usage.Detail {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func parseGeminiUsage(data []byte) usage.Detail {
|
||||
@@ -305,16 +310,7 @@ func parseGeminiUsage(data []byte) usage.Detail {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
@@ -329,16 +325,7 @@ func parseGeminiStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail, true
|
||||
return parseGeminiFamilyUsageDetail(node), true
|
||||
}
|
||||
|
||||
func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
@@ -353,16 +340,7 @@ func parseGeminiCLIStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail, true
|
||||
return parseGeminiFamilyUsageDetail(node), true
|
||||
}
|
||||
|
||||
func parseAntigravityUsage(data []byte) usage.Detail {
|
||||
@@ -377,16 +355,7 @@ func parseAntigravityUsage(data []byte) usage.Detail {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail
|
||||
return parseGeminiFamilyUsageDetail(node)
|
||||
}
|
||||
|
||||
func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
@@ -404,16 +373,7 @@ func parseAntigravityStreamUsage(line []byte) (usage.Detail, bool) {
|
||||
if !node.Exists() {
|
||||
return usage.Detail{}, false
|
||||
}
|
||||
detail := usage.Detail{
|
||||
InputTokens: node.Get("promptTokenCount").Int(),
|
||||
OutputTokens: node.Get("candidatesTokenCount").Int(),
|
||||
ReasoningTokens: node.Get("thoughtsTokenCount").Int(),
|
||||
TotalTokens: node.Get("totalTokenCount").Int(),
|
||||
}
|
||||
if detail.TotalTokens == 0 {
|
||||
detail.TotalTokens = detail.InputTokens + detail.OutputTokens + detail.ReasoningTokens
|
||||
}
|
||||
return detail, true
|
||||
return parseGeminiFamilyUsageDetail(node), true
|
||||
}
|
||||
|
||||
var stopChunkWithoutUsage sync.Map
|
||||
|
||||
@@ -35,6 +35,7 @@ type Params struct {
|
||||
CandidatesTokenCount int64 // Cached candidate token count from usage metadata
|
||||
ThoughtsTokenCount int64 // Cached thinking token count from usage metadata
|
||||
TotalTokenCount int64 // Cached total token count from usage metadata
|
||||
CachedTokenCount int64 // Cached content token count (indicates prompt caching)
|
||||
HasSentFinalEvents bool // Indicates if final content/message events have been sent
|
||||
HasToolUse bool // Indicates if tool use was observed in the stream
|
||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||
@@ -270,7 +271,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
params.HasUsageMetadata = true
|
||||
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int()
|
||||
params.CachedTokenCount = usageResult.Get("cachedContentTokenCount").Int()
|
||||
params.PromptTokenCount = usageResult.Get("promptTokenCount").Int() - params.CachedTokenCount
|
||||
params.CandidatesTokenCount = usageResult.Get("candidatesTokenCount").Int()
|
||||
params.ThoughtsTokenCount = usageResult.Get("thoughtsTokenCount").Int()
|
||||
params.TotalTokenCount = usageResult.Get("totalTokenCount").Int()
|
||||
@@ -322,6 +324,14 @@ func appendFinalEvents(params *Params, output *string, force bool) {
|
||||
*output = *output + "event: message_delta\n"
|
||||
*output = *output + "data: "
|
||||
delta := fmt.Sprintf(`{"type":"message_delta","delta":{"stop_reason":"%s","stop_sequence":null},"usage":{"input_tokens":%d,"output_tokens":%d}}`, stopReason, params.PromptTokenCount, usageOutputTokens)
|
||||
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||
if params.CachedTokenCount > 0 {
|
||||
var err error
|
||||
delta, err = sjson.Set(delta, "usage.cache_read_input_tokens", params.CachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
*output = *output + delta + "\n\n\n"
|
||||
|
||||
params.HasSentFinalEvents = true
|
||||
@@ -361,6 +371,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
candidateTokens := root.Get("response.usageMetadata.candidatesTokenCount").Int()
|
||||
thoughtTokens := root.Get("response.usageMetadata.thoughtsTokenCount").Int()
|
||||
totalTokens := root.Get("response.usageMetadata.totalTokenCount").Int()
|
||||
cachedTokens := root.Get("response.usageMetadata.cachedContentTokenCount").Int()
|
||||
outputTokens := candidateTokens + thoughtTokens
|
||||
if outputTokens == 0 && totalTokens > 0 {
|
||||
outputTokens = totalTokens - promptTokens
|
||||
@@ -374,6 +385,14 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
|
||||
// Add cache_read_input_tokens if cached tokens are present (indicates prompt caching is working)
|
||||
if cachedTokens > 0 {
|
||||
var err error
|
||||
responseJSON, err = sjson.Set(responseJSON, "usage.cache_read_input_tokens", cachedTokens)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity claude response: failed to set cache_read_input_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
contentArrayInitialized := false
|
||||
ensureContentArray := func() {
|
||||
|
||||
@@ -249,8 +249,28 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
p := 0
|
||||
if content.Type == gjson.String {
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
p++
|
||||
} else if content.IsArray() {
|
||||
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
||||
for _, item := range content.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "text":
|
||||
p++
|
||||
case "image_url":
|
||||
// If the assistant returned an inline data URL, preserve it for history fidelity.
|
||||
imageURL := item.Get("image_url.url").String()
|
||||
if len(imageURL) > 5 { // expect data:...
|
||||
pieces := strings.SplitN(imageURL[5:], ";", 2)
|
||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||
mime := pieces[0]
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
p++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
@@ -305,6 +325,8 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
} else {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,8 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/openai/chat-completions"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -85,18 +87,27 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJSON, "response.usageMetadata"); usageResult.Exists() {
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("antigravity openai response: failed to set cached_tokens: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
@@ -170,12 +181,14 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||
}
|
||||
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||
}
|
||||
|
||||
@@ -218,8 +218,29 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
p++
|
||||
} else if content.IsArray() {
|
||||
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
||||
for _, item := range content.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "text":
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
|
||||
p++
|
||||
case "image_url":
|
||||
// If the assistant returned an inline data URL, preserve it for history fidelity.
|
||||
imageURL := item.Get("image_url.url").String()
|
||||
if len(imageURL) > 5 { // expect data:...
|
||||
pieces := strings.SplitN(imageURL[5:], ";", 2)
|
||||
if len(pieces) == 2 && len(pieces[1]) > 7 {
|
||||
mime := pieces[0]
|
||||
data := pieces[1][7:]
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.mime_type", mime)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".inlineData.data", data)
|
||||
p++
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
@@ -244,7 +265,7 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
@@ -260,6 +281,8 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
} else {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -170,12 +170,14 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||
}
|
||||
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||
}
|
||||
|
||||
@@ -233,18 +233,15 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
} else if role == "assistant" {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
p++
|
||||
} else if content.IsArray() {
|
||||
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
||||
for _, item := range content.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "text":
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".text", item.Get("text").String())
|
||||
p++
|
||||
case "image_url":
|
||||
// If the assistant returned an inline data URL, preserve it for history fidelity.
|
||||
@@ -261,7 +258,6 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
}
|
||||
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
@@ -286,7 +282,7 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
@@ -302,6 +298,8 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
|
||||
}
|
||||
} else {
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -88,18 +89,27 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
|
||||
// Extract and set usage metadata (token counts).
|
||||
if usageResult := gjson.GetBytes(rawJSON, "usageMetadata"); usageResult.Exists() {
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
if candidatesTokenCountResult := usageResult.Get("candidatesTokenCount"); candidatesTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens", candidatesTokenCountResult.Int())
|
||||
}
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("gemini openai response: failed to set cached_tokens in streaming: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
@@ -172,12 +182,14 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
imagesResult := gjson.Get(template, "choices.0.delta.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images", `[]`)
|
||||
}
|
||||
imageIndex := len(gjson.Get(template, "choices.0.delta.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.Set(template, "choices.0.delta.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.delta.images.-1", imagePayload)
|
||||
}
|
||||
@@ -240,10 +252,19 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
// Include cached token count if present (indicates prompt caching is working)
|
||||
if cachedTokenCount > 0 {
|
||||
var err error
|
||||
template, err = sjson.Set(template, "usage.prompt_tokens_details.cached_tokens", cachedTokenCount)
|
||||
if err != nil {
|
||||
log.Warnf("gemini openai response: failed to set cached_tokens in non-streaming: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Process the main content part of the response.
|
||||
@@ -297,12 +318,14 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
mimeType = "image/png"
|
||||
}
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
imagePayload := `{"image_url":{"url":""},"type":"image_url"}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
imagesResult := gjson.Get(template, "choices.0.message.images")
|
||||
if !imagesResult.Exists() || !imagesResult.IsArray() {
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.images", `[]`)
|
||||
}
|
||||
imageIndex := len(gjson.Get(template, "choices.0.message.images").Array())
|
||||
imagePayload := `{"type":"image_url","image_url":{"url":""}}`
|
||||
imagePayload, _ = sjson.Set(imagePayload, "index", imageIndex)
|
||||
imagePayload, _ = sjson.Set(imagePayload, "image_url.url", imageURL)
|
||||
template, _ = sjson.Set(template, "choices.0.message.role", "assistant")
|
||||
template, _ = sjson.SetRaw(template, "choices.0.message.images.-1", imagePayload)
|
||||
}
|
||||
|
||||
@@ -14,7 +14,6 @@ import (
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
@@ -185,14 +184,6 @@ func (h *ClaudeCodeAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSO
|
||||
// - c: The Gin context for the request.
|
||||
// - rawJSON: The raw JSON request body.
|
||||
func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
// Set up Server-Sent Events (SSE) headers for streaming response
|
||||
// These headers are essential for maintaining a persistent connection
|
||||
// and enabling real-time streaming of chat completions
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
// This is crucial for streaming as it allows immediate sending of data chunks
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
@@ -213,58 +204,82 @@ func (h *ClaudeCodeAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON [
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
return
|
||||
}
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
// OpenAI-style stream forwarding: write each SSE chunk and flush immediately.
|
||||
// This guarantees clients see incremental output even for small responses.
|
||||
// Peek at the first chunk to determine success or failure before setting headers
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
|
||||
case chunk, ok := <-data:
|
||||
case errMsg, ok := <-errChan:
|
||||
if !ok {
|
||||
// Err channel closed cleanly; wait for data channel.
|
||||
errChan = nil
|
||||
continue
|
||||
}
|
||||
// Upstream failed immediately. Return proper error status and JSON.
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
// Stream closed without data? Send DONE or just headers.
|
||||
setSSEHeaders()
|
||||
flusher.Flush()
|
||||
cancel(nil)
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Set headers now.
|
||||
setSSEHeaders()
|
||||
|
||||
// Write the first chunk
|
||||
if len(chunk) > 0 {
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
c.Status(status)
|
||||
|
||||
// An error occurred: emit as a proper SSE error event
|
||||
errorBytes, _ := json.Marshal(h.toClaudeError(errMsg))
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes)
|
||||
flusher.Flush()
|
||||
}
|
||||
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cancel(execErr)
|
||||
// Continue streaming the rest
|
||||
h.forwardClaudeStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *ClaudeCodeAPIHandler) forwardClaudeStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
if len(chunk) == 0 {
|
||||
return
|
||||
}
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
c.Status(status)
|
||||
|
||||
errorBytes, _ := json.Marshal(h.toClaudeError(errMsg))
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", errorBytes)
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
type claudeErrorDetail struct {
|
||||
Type string `json:"type"`
|
||||
Message string `json:"message"`
|
||||
|
||||
@@ -182,19 +182,18 @@ func (h *GeminiCLIAPIHandler) handleInternalGenerateContent(c *gin.Context, rawJ
|
||||
}
|
||||
|
||||
func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
cancel(nil)
|
||||
return
|
||||
}
|
||||
var keepAliveInterval *time.Duration
|
||||
if alt != "" {
|
||||
disabled := time.Duration(0)
|
||||
keepAliveInterval = &disabled
|
||||
}
|
||||
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
KeepAliveInterval: keepAliveInterval,
|
||||
WriteChunk: func(chunk []byte) {
|
||||
if alt == "" {
|
||||
if bytes.Equal(chunk, []byte("data: [DONE]")) || bytes.Equal(chunk, []byte("[DONE]")) {
|
||||
continue
|
||||
return
|
||||
}
|
||||
|
||||
if !bytes.HasPrefix(chunk, []byte("data:")) {
|
||||
@@ -206,22 +205,25 @@ func (h *GeminiCLIAPIHandler) forwardCLIStream(c *gin.Context, flusher http.Flus
|
||||
} else {
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
}
|
||||
flusher.Flush()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
flusher.Flush()
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
cancel(execErr)
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
if alt == "" {
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
|
||||
} else {
|
||||
_, _ = c.Writer.Write(body)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -226,13 +226,6 @@ func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) {
|
||||
func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName string, rawJSON []byte) {
|
||||
alt := h.GetAlt(c)
|
||||
|
||||
if alt == "" {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
@@ -247,8 +240,65 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName
|
||||
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt)
|
||||
h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
return
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Peek at the first chunk
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case errMsg, ok := <-errChan:
|
||||
if !ok {
|
||||
// Err channel closed cleanly; wait for data channel.
|
||||
errChan = nil
|
||||
continue
|
||||
}
|
||||
// Upstream failed immediately. Return proper error status and JSON.
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
// Closed without data
|
||||
if alt == "" {
|
||||
setSSEHeaders()
|
||||
}
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Set headers.
|
||||
if alt == "" {
|
||||
setSSEHeaders()
|
||||
}
|
||||
|
||||
// Write first chunk
|
||||
if alt == "" {
|
||||
_, _ = c.Writer.Write([]byte("data: "))
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n\n"))
|
||||
} else {
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
}
|
||||
flusher.Flush()
|
||||
|
||||
// Continue
|
||||
h.forwardGeminiStream(c, flusher, alt, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleCountTokens handles token counting requests for Gemini models.
|
||||
@@ -297,16 +347,15 @@ func (h *GeminiAPIHandler) handleGenerateContent(c *gin.Context, modelName strin
|
||||
}
|
||||
|
||||
func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flusher, alt string, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
cancel(nil)
|
||||
return
|
||||
}
|
||||
var keepAliveInterval *time.Duration
|
||||
if alt != "" {
|
||||
disabled := time.Duration(0)
|
||||
keepAliveInterval = &disabled
|
||||
}
|
||||
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
KeepAliveInterval: keepAliveInterval,
|
||||
WriteChunk: func(chunk []byte) {
|
||||
if alt == "" {
|
||||
_, _ = c.Writer.Write([]byte("data: "))
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
@@ -314,22 +363,25 @@ func (h *GeminiAPIHandler) forwardGeminiStream(c *gin.Context, flusher http.Flus
|
||||
} else {
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
}
|
||||
flusher.Flush()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
flusher.Flush()
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
cancel(execErr)
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
if alt == "" {
|
||||
_, _ = fmt.Fprintf(c.Writer, "event: error\ndata: %s\n\n", string(body))
|
||||
} else {
|
||||
_, _ = c.Writer.Write(body)
|
||||
}
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -9,9 +9,12 @@ import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -40,6 +43,117 @@ type ErrorDetail struct {
|
||||
Code string `json:"code,omitempty"`
|
||||
}
|
||||
|
||||
const idempotencyKeyMetadataKey = "idempotency_key"
|
||||
|
||||
const (
|
||||
defaultStreamingKeepAliveSeconds = 0
|
||||
defaultStreamingBootstrapRetries = 0
|
||||
)
|
||||
|
||||
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
|
||||
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
|
||||
func BuildErrorResponseBody(status int, errText string) []byte {
|
||||
if status <= 0 {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
if strings.TrimSpace(errText) == "" {
|
||||
errText = http.StatusText(status)
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(errText)
|
||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||
return []byte(trimmed)
|
||||
}
|
||||
|
||||
errType := "invalid_request_error"
|
||||
var code string
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errType = "authentication_error"
|
||||
code = "invalid_api_key"
|
||||
case http.StatusForbidden:
|
||||
errType = "permission_error"
|
||||
code = "insufficient_quota"
|
||||
case http.StatusTooManyRequests:
|
||||
errType = "rate_limit_error"
|
||||
code = "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
errType = "invalid_request_error"
|
||||
code = "model_not_found"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
errType = "server_error"
|
||||
code = "internal_server_error"
|
||||
}
|
||||
}
|
||||
|
||||
payload, err := json.Marshal(ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: errText,
|
||||
Type: errType,
|
||||
Code: code,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error","code":"internal_server_error"}}`, errText))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
// StreamingKeepAliveInterval returns the SSE keep-alive interval for this server.
|
||||
// Returning 0 disables keep-alives (default when unset).
|
||||
func StreamingKeepAliveInterval(cfg *config.SDKConfig) time.Duration {
|
||||
seconds := defaultStreamingKeepAliveSeconds
|
||||
if cfg != nil && cfg.Streaming.KeepAliveSeconds != nil {
|
||||
seconds = *cfg.Streaming.KeepAliveSeconds
|
||||
}
|
||||
if seconds <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
// StreamingBootstrapRetries returns how many times a streaming request may be retried before any bytes are sent.
|
||||
func StreamingBootstrapRetries(cfg *config.SDKConfig) int {
|
||||
retries := defaultStreamingBootstrapRetries
|
||||
if cfg != nil && cfg.Streaming.BootstrapRetries != nil {
|
||||
retries = *cfg.Streaming.BootstrapRetries
|
||||
}
|
||||
if retries < 0 {
|
||||
retries = 0
|
||||
}
|
||||
return retries
|
||||
}
|
||||
|
||||
func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
// Idempotency-Key is an optional client-supplied header used to correlate retries.
|
||||
// It is forwarded as execution metadata; when absent we generate a UUID.
|
||||
key := ""
|
||||
if ctx != nil {
|
||||
if ginCtx, ok := ctx.Value("gin").(*gin.Context); ok && ginCtx != nil && ginCtx.Request != nil {
|
||||
key = strings.TrimSpace(ginCtx.GetHeader("Idempotency-Key"))
|
||||
}
|
||||
}
|
||||
if key == "" {
|
||||
key = uuid.NewString()
|
||||
}
|
||||
return map[string]any{idempotencyKeyMetadataKey: key}
|
||||
}
|
||||
|
||||
func mergeMetadata(base, overlay map[string]any) map[string]any {
|
||||
if len(base) == 0 && len(overlay) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(base)+len(overlay))
|
||||
for k, v := range base {
|
||||
out[k] = v
|
||||
}
|
||||
for k, v := range overlay {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// BaseAPIHandler contains the handlers for API endpoints.
|
||||
// It holds a pool of clients to interact with the backend service and manages
|
||||
// load balancing, client selection, and configuration.
|
||||
@@ -103,13 +217,39 @@ func (h *BaseAPIHandler) GetAlt(c *gin.Context) string {
|
||||
// Parameters:
|
||||
// - handler: The API handler associated with the request.
|
||||
// - c: The Gin context of the current request.
|
||||
// - ctx: The parent context.
|
||||
// - ctx: The parent context (caller values/deadlines are preserved; request context adds cancellation and request ID).
|
||||
//
|
||||
// Returns:
|
||||
// - context.Context: The new context with cancellation and embedded values.
|
||||
// - APIHandlerCancelFunc: A function to cancel the context and log the response.
|
||||
func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *gin.Context, ctx context.Context) (context.Context, APIHandlerCancelFunc) {
|
||||
newCtx, cancel := context.WithCancel(ctx)
|
||||
parentCtx := ctx
|
||||
if parentCtx == nil {
|
||||
parentCtx = context.Background()
|
||||
}
|
||||
|
||||
var requestCtx context.Context
|
||||
if c != nil && c.Request != nil {
|
||||
requestCtx = c.Request.Context()
|
||||
}
|
||||
|
||||
if requestCtx != nil && logging.GetRequestID(parentCtx) == "" {
|
||||
if requestID := logging.GetRequestID(requestCtx); requestID != "" {
|
||||
parentCtx = logging.WithRequestID(parentCtx, requestID)
|
||||
} else if requestID := logging.GetGinRequestID(c); requestID != "" {
|
||||
parentCtx = logging.WithRequestID(parentCtx, requestID)
|
||||
}
|
||||
}
|
||||
newCtx, cancel := context.WithCancel(parentCtx)
|
||||
if requestCtx != nil && requestCtx != parentCtx {
|
||||
go func() {
|
||||
select {
|
||||
case <-requestCtx.Done():
|
||||
cancel()
|
||||
case <-newCtx.Done():
|
||||
}
|
||||
}()
|
||||
}
|
||||
newCtx = context.WithValue(newCtx, "gin", c)
|
||||
newCtx = context.WithValue(newCtx, "handler", handler)
|
||||
return newCtx, func(params ...interface{}) {
|
||||
@@ -182,6 +322,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -195,9 +336,7 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
opts.Metadata = cloned
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
resp, err := h.AuthManager.Execute(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -224,6 +363,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
if errMsg != nil {
|
||||
return nil, errMsg
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -237,9 +377,7 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
opts.Metadata = cloned
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
@@ -269,6 +407,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
close(errChan)
|
||||
return nil, errChan
|
||||
}
|
||||
reqMeta := requestExecutionMetadata(ctx)
|
||||
req := coreexecutor.Request{
|
||||
Model: normalizedModel,
|
||||
Payload: cloneBytes(rawJSON),
|
||||
@@ -282,9 +421,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
OriginalRequest: cloneBytes(rawJSON),
|
||||
SourceFormat: sdktranslator.FromString(handlerType),
|
||||
}
|
||||
if cloned := cloneMetadata(metadata); cloned != nil {
|
||||
opts.Metadata = cloned
|
||||
}
|
||||
opts.Metadata = mergeMetadata(cloneMetadata(metadata), reqMeta)
|
||||
chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if err != nil {
|
||||
errChan := make(chan *interfaces.ErrorMessage, 1)
|
||||
@@ -309,31 +446,94 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
go func() {
|
||||
defer close(dataChan)
|
||||
defer close(errChan)
|
||||
for chunk := range chunks {
|
||||
if chunk.Err != nil {
|
||||
status := http.StatusInternalServerError
|
||||
if se, ok := chunk.Err.(interface{ StatusCode() int }); ok && se != nil {
|
||||
if code := se.StatusCode(); code > 0 {
|
||||
status = code
|
||||
}
|
||||
}
|
||||
var addon http.Header
|
||||
if he, ok := chunk.Err.(interface{ Headers() http.Header }); ok && he != nil {
|
||||
if hdr := he.Headers(); hdr != nil {
|
||||
addon = hdr.Clone()
|
||||
}
|
||||
}
|
||||
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: chunk.Err, Addon: addon}
|
||||
return
|
||||
sentPayload := false
|
||||
bootstrapRetries := 0
|
||||
maxBootstrapRetries := StreamingBootstrapRetries(h.Cfg)
|
||||
|
||||
bootstrapEligible := func(err error) bool {
|
||||
status := statusFromError(err)
|
||||
if status == 0 {
|
||||
return true
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
dataChan <- cloneBytes(chunk.Payload)
|
||||
switch status {
|
||||
case http.StatusUnauthorized, http.StatusForbidden, http.StatusPaymentRequired,
|
||||
http.StatusRequestTimeout, http.StatusTooManyRequests:
|
||||
return true
|
||||
default:
|
||||
return status >= http.StatusInternalServerError
|
||||
}
|
||||
}
|
||||
|
||||
outer:
|
||||
for {
|
||||
for {
|
||||
var chunk coreexecutor.StreamChunk
|
||||
var ok bool
|
||||
if ctx != nil {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case chunk, ok = <-chunks:
|
||||
}
|
||||
} else {
|
||||
chunk, ok = <-chunks
|
||||
}
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if chunk.Err != nil {
|
||||
streamErr := chunk.Err
|
||||
// Safe bootstrap recovery: if the upstream fails before any payload bytes are sent,
|
||||
// retry a few times (to allow auth rotation / transient recovery) and then attempt model fallback.
|
||||
if !sentPayload {
|
||||
if bootstrapRetries < maxBootstrapRetries && bootstrapEligible(streamErr) {
|
||||
bootstrapRetries++
|
||||
retryChunks, retryErr := h.AuthManager.ExecuteStream(ctx, providers, req, opts)
|
||||
if retryErr == nil {
|
||||
chunks = retryChunks
|
||||
continue outer
|
||||
}
|
||||
streamErr = retryErr
|
||||
}
|
||||
}
|
||||
|
||||
status := http.StatusInternalServerError
|
||||
if se, ok := streamErr.(interface{ StatusCode() int }); ok && se != nil {
|
||||
if code := se.StatusCode(); code > 0 {
|
||||
status = code
|
||||
}
|
||||
}
|
||||
var addon http.Header
|
||||
if he, ok := streamErr.(interface{ Headers() http.Header }); ok && he != nil {
|
||||
if hdr := he.Headers(); hdr != nil {
|
||||
addon = hdr.Clone()
|
||||
}
|
||||
}
|
||||
errChan <- &interfaces.ErrorMessage{StatusCode: status, Error: streamErr, Addon: addon}
|
||||
return
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
sentPayload = true
|
||||
dataChan <- cloneBytes(chunk.Payload)
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return dataChan, errChan
|
||||
}
|
||||
|
||||
func statusFromError(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
if se, ok := err.(interface{ StatusCode() int }); ok && se != nil {
|
||||
if code := se.StatusCode(); code > 0 {
|
||||
return code
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
@@ -417,38 +617,7 @@ func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.Erro
|
||||
}
|
||||
}
|
||||
|
||||
// Prefer preserving upstream JSON error bodies when possible.
|
||||
buildJSONBody := func() []byte {
|
||||
trimmed := strings.TrimSpace(errText)
|
||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||
return []byte(trimmed)
|
||||
}
|
||||
errType := "invalid_request_error"
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
errType = "authentication_error"
|
||||
case http.StatusForbidden:
|
||||
errType = "permission_error"
|
||||
case http.StatusTooManyRequests:
|
||||
errType = "rate_limit_error"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
errType = "server_error"
|
||||
}
|
||||
}
|
||||
payload, err := json.Marshal(ErrorResponse{
|
||||
Error: ErrorDetail{
|
||||
Message: errText,
|
||||
Type: errType,
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return []byte(fmt.Sprintf(`{"error":{"message":%q,"type":"server_error"}}`, errText))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
body := buildJSONBody()
|
||||
body := BuildErrorResponseBody(status, errText)
|
||||
c.Set("API_RESPONSE", bytes.Clone(body))
|
||||
|
||||
if !c.Writer.Written() {
|
||||
|
||||
125
sdk/api/handlers/handlers_stream_bootstrap_test.go
Normal file
125
sdk/api/handlers/handlers_stream_bootstrap_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
type failOnceStreamExecutor struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
e.mu.Lock()
|
||||
e.calls++
|
||||
call := e.calls
|
||||
e.mu.Unlock()
|
||||
|
||||
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||
if call == 1 {
|
||||
ch <- coreexecutor.StreamChunk{
|
||||
Err: &coreauth.Error{
|
||||
Code: "unauthorized",
|
||||
Message: "unauthorized",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *failOnceStreamExecutor) Calls() int {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.calls
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||
executor := &failOnceStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
auth2 := &coreauth.Auth{
|
||||
ID: "auth2",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test2@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||
t.Fatalf("manager.Register(auth2): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||
})
|
||||
|
||||
bootstrapRetries := 1
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: &bootstrapRetries,
|
||||
},
|
||||
}, manager)
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
|
||||
for msg := range errChan {
|
||||
if msg != nil {
|
||||
t.Fatalf("unexpected error: %+v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
if string(got) != "ok" {
|
||||
t.Fatalf("expected payload ok, got %q", string(got))
|
||||
}
|
||||
if executor.Calls() != 2 {
|
||||
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
|
||||
}
|
||||
}
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
"sync"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
@@ -443,11 +443,6 @@ func (h *OpenAIAPIHandler) handleNonStreamingResponse(c *gin.Context, rawJSON []
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
// - rawJSON: The raw JSON bytes of the OpenAI-compatible request
|
||||
func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
@@ -463,7 +458,55 @@ func (h *OpenAIAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byt
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, h.GetAlt(c))
|
||||
h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Peek at the first chunk to determine success or failure before setting headers
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case errMsg, ok := <-errChan:
|
||||
if !ok {
|
||||
// Err channel closed cleanly; wait for data channel.
|
||||
errChan = nil
|
||||
continue
|
||||
}
|
||||
// Upstream failed immediately. Return proper error status and JSON.
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
// Stream closed without data? Send DONE or just headers.
|
||||
setSSEHeaders()
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Commit to streaming headers.
|
||||
setSSEHeaders()
|
||||
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
||||
flusher.Flush()
|
||||
|
||||
// Continue streaming the rest
|
||||
h.handleStreamResult(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// handleCompletionsNonStreamingResponse handles non-streaming completions responses.
|
||||
@@ -500,11 +543,6 @@ func (h *OpenAIAPIHandler) handleCompletionsNonStreamingResponse(c *gin.Context,
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
// - rawJSON: The raw JSON bytes of the OpenAI-compatible completions request
|
||||
func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
@@ -524,71 +562,109 @@ func (h *OpenAIAPIHandler) handleCompletionsStreamingResponse(c *gin.Context, ra
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, chatCompletionsJSON, "")
|
||||
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Peek at the first chunk
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, isOk := <-dataChan:
|
||||
if !isOk {
|
||||
case errMsg, ok := <-errChan:
|
||||
if !ok {
|
||||
// Err channel closed cleanly; wait for data channel.
|
||||
errChan = nil
|
||||
continue
|
||||
}
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
setSSEHeaders()
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cliCancel()
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Set headers.
|
||||
setSSEHeaders()
|
||||
|
||||
// Write the first chunk
|
||||
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
||||
if converted != nil {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(converted))
|
||||
flusher.Flush()
|
||||
}
|
||||
case errMsg, isOk := <-errChan:
|
||||
if !isOk {
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
flusher.Flush()
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cliCancel(execErr)
|
||||
|
||||
done := make(chan struct{})
|
||||
var doneOnce sync.Once
|
||||
stop := func() { doneOnce.Do(func() { close(done) }) }
|
||||
|
||||
convertedChan := make(chan []byte)
|
||||
go func() {
|
||||
defer close(convertedChan)
|
||||
for {
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
converted := convertChatCompletionsStreamChunkToCompletions(chunk)
|
||||
if converted == nil {
|
||||
continue
|
||||
}
|
||||
select {
|
||||
case <-done:
|
||||
return
|
||||
case convertedChan <- converted:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
h.handleStreamResult(c, flusher, func(err error) {
|
||||
stop()
|
||||
cliCancel(err)
|
||||
}, convertedChan, errChan)
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
func (h *OpenAIAPIHandler) handleStreamResult(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: [DONE]\n\n")
|
||||
flusher.Flush()
|
||||
cancel(nil)
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(chunk))
|
||||
flusher.Flush()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
flusher.Flush()
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cancel(execErr)
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "data: %s\n\n", string(body))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = fmt.Fprint(c.Writer, "data: [DONE]\n\n")
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
. "github.com/router-for-me/CLIProxyAPI/v6/internal/constant"
|
||||
@@ -128,11 +127,6 @@ func (h *OpenAIResponsesAPIHandler) handleNonStreamingResponse(c *gin.Context, r
|
||||
// - c: The Gin context containing the HTTP request and response
|
||||
// - rawJSON: The raw JSON bytes of the OpenAIResponses-compatible request
|
||||
func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJSON []byte) {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
|
||||
// Get the http.Flusher interface to manually flush the response.
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
@@ -149,46 +143,88 @@ func (h *OpenAIResponsesAPIHandler) handleStreamingResponse(c *gin.Context, rawJ
|
||||
modelName := gjson.GetBytes(rawJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, "")
|
||||
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
return
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
setSSEHeaders := func() {
|
||||
c.Header("Content-Type", "text/event-stream")
|
||||
c.Header("Cache-Control", "no-cache")
|
||||
c.Header("Connection", "keep-alive")
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
}
|
||||
|
||||
// Peek at the first chunk
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
cliCancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, ok := <-data:
|
||||
case errMsg, ok := <-errChan:
|
||||
if !ok {
|
||||
// Err channel closed cleanly; wait for data channel.
|
||||
errChan = nil
|
||||
continue
|
||||
}
|
||||
// Upstream failed immediately. Return proper error status and JSON.
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
if errMsg != nil {
|
||||
cliCancel(errMsg.Error)
|
||||
} else {
|
||||
cliCancel(nil)
|
||||
}
|
||||
return
|
||||
case chunk, ok := <-dataChan:
|
||||
if !ok {
|
||||
// Stream closed without data? Send headers and done.
|
||||
setSSEHeaders()
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
flusher.Flush()
|
||||
cancel(nil)
|
||||
cliCancel(nil)
|
||||
return
|
||||
}
|
||||
|
||||
// Success! Set headers.
|
||||
setSSEHeaders()
|
||||
|
||||
// Write first chunk logic (matching forwardResponsesStream)
|
||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
|
||||
flusher.Flush()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.WriteErrorResponse(c, errMsg)
|
||||
flusher.Flush()
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cancel(execErr)
|
||||
|
||||
// Continue
|
||||
h.forwardResponsesStream(c, flusher, func(err error) { cliCancel(err) }, dataChan, errChan)
|
||||
return
|
||||
case <-time.After(500 * time.Millisecond):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage) {
|
||||
h.ForwardStream(c, flusher, cancel, data, errs, handlers.StreamForwardOptions{
|
||||
WriteChunk: func(chunk []byte) {
|
||||
if bytes.HasPrefix(chunk, []byte("event:")) {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
}
|
||||
_, _ = c.Writer.Write(chunk)
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
},
|
||||
WriteTerminalError: func(errMsg *interfaces.ErrorMessage) {
|
||||
if errMsg == nil {
|
||||
return
|
||||
}
|
||||
status := http.StatusInternalServerError
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
}
|
||||
errText := http.StatusText(status)
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
},
|
||||
})
|
||||
}
|
||||
|
||||
121
sdk/api/handlers/stream_forwarder.go
Normal file
121
sdk/api/handlers/stream_forwarder.go
Normal file
@@ -0,0 +1,121 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
)
|
||||
|
||||
type StreamForwardOptions struct {
|
||||
// KeepAliveInterval overrides the configured streaming keep-alive interval.
|
||||
// If nil, the configured default is used. If set to <= 0, keep-alives are disabled.
|
||||
KeepAliveInterval *time.Duration
|
||||
|
||||
// WriteChunk writes a single data chunk to the response body. It should not flush.
|
||||
WriteChunk func(chunk []byte)
|
||||
|
||||
// WriteTerminalError writes an error payload to the response body when streaming fails
|
||||
// after headers have already been committed. It should not flush.
|
||||
WriteTerminalError func(errMsg *interfaces.ErrorMessage)
|
||||
|
||||
// WriteDone optionally writes a terminal marker when the upstream data channel closes
|
||||
// without an error (e.g. OpenAI's `[DONE]`). It should not flush.
|
||||
WriteDone func()
|
||||
|
||||
// WriteKeepAlive optionally writes a keep-alive heartbeat. It should not flush.
|
||||
// When nil, a standard SSE comment heartbeat is used.
|
||||
WriteKeepAlive func()
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) ForwardStream(c *gin.Context, flusher http.Flusher, cancel func(error), data <-chan []byte, errs <-chan *interfaces.ErrorMessage, opts StreamForwardOptions) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if cancel == nil {
|
||||
return
|
||||
}
|
||||
|
||||
writeChunk := opts.WriteChunk
|
||||
if writeChunk == nil {
|
||||
writeChunk = func([]byte) {}
|
||||
}
|
||||
|
||||
writeKeepAlive := opts.WriteKeepAlive
|
||||
if writeKeepAlive == nil {
|
||||
writeKeepAlive = func() {
|
||||
_, _ = c.Writer.Write([]byte(": keep-alive\n\n"))
|
||||
}
|
||||
}
|
||||
|
||||
keepAliveInterval := StreamingKeepAliveInterval(h.Cfg)
|
||||
if opts.KeepAliveInterval != nil {
|
||||
keepAliveInterval = *opts.KeepAliveInterval
|
||||
}
|
||||
var keepAlive *time.Ticker
|
||||
var keepAliveC <-chan time.Time
|
||||
if keepAliveInterval > 0 {
|
||||
keepAlive = time.NewTicker(keepAliveInterval)
|
||||
defer keepAlive.Stop()
|
||||
keepAliveC = keepAlive.C
|
||||
}
|
||||
|
||||
var terminalErr *interfaces.ErrorMessage
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
// Prefer surfacing a terminal error if one is pending.
|
||||
if terminalErr == nil {
|
||||
select {
|
||||
case errMsg, ok := <-errs:
|
||||
if ok && errMsg != nil {
|
||||
terminalErr = errMsg
|
||||
}
|
||||
default:
|
||||
}
|
||||
}
|
||||
if terminalErr != nil {
|
||||
if opts.WriteTerminalError != nil {
|
||||
opts.WriteTerminalError(terminalErr)
|
||||
}
|
||||
flusher.Flush()
|
||||
cancel(terminalErr.Error)
|
||||
return
|
||||
}
|
||||
if opts.WriteDone != nil {
|
||||
opts.WriteDone()
|
||||
}
|
||||
flusher.Flush()
|
||||
cancel(nil)
|
||||
return
|
||||
}
|
||||
writeChunk(chunk)
|
||||
flusher.Flush()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
terminalErr = errMsg
|
||||
if opts.WriteTerminalError != nil {
|
||||
opts.WriteTerminalError(errMsg)
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
var execErr error
|
||||
if errMsg != nil {
|
||||
execErr = errMsg.Error
|
||||
}
|
||||
cancel(execErr)
|
||||
return
|
||||
case <-keepAliveC:
|
||||
writeKeepAlive()
|
||||
flusher.Flush()
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -389,17 +390,18 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
if accountType == "api_key" {
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -449,17 +451,18 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
if accountType == "api_key" {
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -509,17 +512,18 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
|
||||
accountType, accountInfo := auth.AccountInfo()
|
||||
proxyInfo := auth.ProxyInfo()
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
if accountType == "api_key" {
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
entry.Debugf("Use API key %s for model %s %s", util.HideAPIKey(accountInfo), req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
entry.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model)
|
||||
}
|
||||
} else if accountType == "oauth" {
|
||||
if proxyInfo != "" {
|
||||
log.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
entry.Debugf("Use OAuth %s for model %s %s", accountInfo, req.Model, proxyInfo)
|
||||
} else {
|
||||
log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
entry.Debugf("Use OAuth %s for model %s", accountInfo, req.Model)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1604,6 +1608,17 @@ type RequestPreparer interface {
|
||||
PrepareRequest(req *http.Request, auth *Auth) error
|
||||
}
|
||||
|
||||
// logEntryWithRequestID returns a logrus entry with request_id field if available in context.
|
||||
func logEntryWithRequestID(ctx context.Context) *log.Entry {
|
||||
if ctx == nil {
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
if reqID := logging.GetRequestID(ctx); reqID != "" {
|
||||
return log.WithField("request_id", reqID)
|
||||
}
|
||||
return log.NewEntry(log.StandardLogger())
|
||||
}
|
||||
|
||||
// InjectCredentials delegates per-provider HTTP request preparation when supported.
|
||||
// If the registered executor for the auth provider implements RequestPreparer,
|
||||
// it will be invoked to modify the request (e.g., add headers).
|
||||
@@ -12,6 +12,7 @@ type AccessProvider = internalconfig.AccessProvider
|
||||
|
||||
type Config = internalconfig.Config
|
||||
|
||||
type StreamingConfig = internalconfig.StreamingConfig
|
||||
type TLSConfig = internalconfig.TLSConfig
|
||||
type RemoteManagement = internalconfig.RemoteManagement
|
||||
type AmpCode = internalconfig.AmpCode
|
||||
|
||||
Reference in New Issue
Block a user