mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-04 05:20:52 +08:00
Compare commits
33 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
39621a0340 | ||
|
|
346b663079 | ||
|
|
0bcae68c6c | ||
|
|
c8cee547fd | ||
|
|
36755421fe | ||
|
|
6c17dbc4da | ||
|
|
ee6429cc75 | ||
|
|
a4a26d978e | ||
|
|
ed9f6e897e | ||
|
|
9c1e3c0687 | ||
|
|
2e5681ea32 | ||
|
|
52c17f03a5 | ||
|
|
d0e694d4ed | ||
|
|
506f1117dd | ||
|
|
113db3c5bf | ||
|
|
1aa0b6cd11 | ||
|
|
0895533400 | ||
|
|
43f007c234 | ||
|
|
0ceee56d99 | ||
|
|
943a8c74df | ||
|
|
0a47b452e9 | ||
|
|
261f08a82a | ||
|
|
d114d8d0bd | ||
|
|
bb9955e461 | ||
|
|
7063a176f4 | ||
|
|
e3082887a6 | ||
|
|
ddb0c0ec1c | ||
|
|
d1736cb29c | ||
|
|
62bfd62871 | ||
|
|
257621c5ed | ||
|
|
ac064389ca | ||
|
|
8d23ffc873 | ||
|
|
4307f08bbc |
@@ -1,6 +1,12 @@
|
||||
# Server port
|
||||
port: 8317
|
||||
|
||||
# TLS settings for HTTPS. When enabled, the server listens with the provided certificate and key.
|
||||
tls:
|
||||
enable: false
|
||||
cert: ""
|
||||
key: ""
|
||||
|
||||
# Management API settings
|
||||
remote-management:
|
||||
# Whether to allow remote (non-localhost) management access.
|
||||
@@ -38,6 +44,9 @@ proxy-url: ""
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
# Maximum wait time in seconds for a cooled-down credential before triggering a retry.
|
||||
max-retry-interval: 30
|
||||
|
||||
# Quota exceeded behavior
|
||||
quota-exceeded:
|
||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||
|
||||
@@ -235,7 +235,11 @@ func (h *Handler) managementCallbackURL(path string) (string, error) {
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
return fmt.Sprintf("http://127.0.0.1:%d%s", h.cfg.Port, path), nil
|
||||
scheme := "http"
|
||||
if h.cfg.TLS.Enable {
|
||||
scheme = "https"
|
||||
}
|
||||
return fmt.Sprintf("%s://127.0.0.1:%d%s", scheme, h.cfg.Port, path), nil
|
||||
}
|
||||
|
||||
func (h *Handler) ListAuthFiles(c *gin.Context) {
|
||||
|
||||
@@ -172,6 +172,14 @@ func (h *Handler) PutRequestRetry(c *gin.Context) {
|
||||
h.updateIntField(c, func(v int) { h.cfg.RequestRetry = v })
|
||||
}
|
||||
|
||||
// Max retry interval
|
||||
func (h *Handler) GetMaxRetryInterval(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"max-retry-interval": h.cfg.MaxRetryInterval})
|
||||
}
|
||||
func (h *Handler) PutMaxRetryInterval(c *gin.Context) {
|
||||
h.updateIntField(c, func(v int) { h.cfg.MaxRetryInterval = v })
|
||||
}
|
||||
|
||||
// Proxy URL
|
||||
func (h *Handler) GetProxyURL(c *gin.Context) { c.JSON(200, gin.H{"proxy-url": h.cfg.ProxyURL}) }
|
||||
func (h *Handler) PutProxyURL(c *gin.Context) {
|
||||
|
||||
@@ -58,8 +58,14 @@ func (h *Handler) GetLogs(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
limit, errLimit := parseLimit(c.Query("limit"))
|
||||
if errLimit != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("invalid limit: %v", errLimit)})
|
||||
return
|
||||
}
|
||||
|
||||
cutoff := parseCutoff(c.Query("after"))
|
||||
acc := newLogAccumulator(cutoff)
|
||||
acc := newLogAccumulator(cutoff, limit)
|
||||
for i := range files {
|
||||
if errProcess := acc.consumeFile(files[i]); errProcess != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file %s: %v", files[i], errProcess)})
|
||||
@@ -139,6 +145,126 @@ func (h *Handler) DeleteLogs(c *gin.Context) {
|
||||
})
|
||||
}
|
||||
|
||||
// GetRequestErrorLogs lists error request log files when RequestLog is disabled.
|
||||
// It returns an empty list when RequestLog is enabled.
|
||||
func (h *Handler) GetRequestErrorLogs(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg.RequestLog {
|
||||
c.JSON(http.StatusOK, gin.H{"files": []any{}})
|
||||
return
|
||||
}
|
||||
|
||||
dir := h.logDirectory()
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(dir)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
c.JSON(http.StatusOK, gin.H{"files": []any{}})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to list request error logs: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
type errorLog struct {
|
||||
Name string `json:"name"`
|
||||
Size int64 `json:"size"`
|
||||
Modified int64 `json:"modified"`
|
||||
}
|
||||
|
||||
files := make([]errorLog, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
|
||||
continue
|
||||
}
|
||||
info, errInfo := entry.Info()
|
||||
if errInfo != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log info for %s: %v", name, errInfo)})
|
||||
return
|
||||
}
|
||||
files = append(files, errorLog{
|
||||
Name: name,
|
||||
Size: info.Size(),
|
||||
Modified: info.ModTime().Unix(),
|
||||
})
|
||||
}
|
||||
|
||||
sort.Slice(files, func(i, j int) bool { return files[i].Modified > files[j].Modified })
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"files": files})
|
||||
}
|
||||
|
||||
// DownloadRequestErrorLog downloads a specific error request log file by name.
|
||||
func (h *Handler) DownloadRequestErrorLog(c *gin.Context) {
|
||||
if h == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "handler unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "configuration unavailable"})
|
||||
return
|
||||
}
|
||||
|
||||
dir := h.logDirectory()
|
||||
if strings.TrimSpace(dir) == "" {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "log directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
name := strings.TrimSpace(c.Param("name"))
|
||||
if name == "" || strings.Contains(name, "/") || strings.Contains(name, "\\") {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file name"})
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
||||
return
|
||||
}
|
||||
|
||||
dirAbs, errAbs := filepath.Abs(dir)
|
||||
if errAbs != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to resolve log directory: %v", errAbs)})
|
||||
return
|
||||
}
|
||||
fullPath := filepath.Clean(filepath.Join(dirAbs, name))
|
||||
prefix := dirAbs + string(os.PathSeparator)
|
||||
if !strings.HasPrefix(fullPath, prefix) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file path"})
|
||||
return
|
||||
}
|
||||
|
||||
info, errStat := os.Stat(fullPath)
|
||||
if errStat != nil {
|
||||
if os.IsNotExist(errStat) {
|
||||
c.JSON(http.StatusNotFound, gin.H{"error": "log file not found"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": fmt.Sprintf("failed to read log file: %v", errStat)})
|
||||
return
|
||||
}
|
||||
if info.IsDir() {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid log file"})
|
||||
return
|
||||
}
|
||||
|
||||
c.FileAttachment(fullPath, name)
|
||||
}
|
||||
|
||||
func (h *Handler) logDirectory() string {
|
||||
if h == nil {
|
||||
return ""
|
||||
@@ -194,16 +320,22 @@ func (h *Handler) collectLogFiles(dir string) ([]string, error) {
|
||||
|
||||
type logAccumulator struct {
|
||||
cutoff int64
|
||||
limit int
|
||||
lines []string
|
||||
total int
|
||||
latest int64
|
||||
include bool
|
||||
}
|
||||
|
||||
func newLogAccumulator(cutoff int64) *logAccumulator {
|
||||
func newLogAccumulator(cutoff int64, limit int) *logAccumulator {
|
||||
capacity := 256
|
||||
if limit > 0 && limit < capacity {
|
||||
capacity = limit
|
||||
}
|
||||
return &logAccumulator{
|
||||
cutoff: cutoff,
|
||||
lines: make([]string, 0, 256),
|
||||
limit: limit,
|
||||
lines: make([]string, 0, capacity),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -215,7 +347,9 @@ func (acc *logAccumulator) consumeFile(path string) error {
|
||||
}
|
||||
return err
|
||||
}
|
||||
defer file.Close()
|
||||
defer func() {
|
||||
_ = file.Close()
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
buf := make([]byte, 0, logScannerInitialBuffer)
|
||||
@@ -239,12 +373,19 @@ func (acc *logAccumulator) addLine(raw string) {
|
||||
if ts > 0 {
|
||||
acc.include = acc.cutoff == 0 || ts > acc.cutoff
|
||||
if acc.cutoff == 0 || acc.include {
|
||||
acc.lines = append(acc.lines, line)
|
||||
acc.append(line)
|
||||
}
|
||||
return
|
||||
}
|
||||
if acc.cutoff == 0 || acc.include {
|
||||
acc.lines = append(acc.lines, line)
|
||||
acc.append(line)
|
||||
}
|
||||
}
|
||||
|
||||
func (acc *logAccumulator) append(line string) {
|
||||
acc.lines = append(acc.lines, line)
|
||||
if acc.limit > 0 && len(acc.lines) > acc.limit {
|
||||
acc.lines = acc.lines[len(acc.lines)-acc.limit:]
|
||||
}
|
||||
}
|
||||
|
||||
@@ -267,6 +408,21 @@ func parseCutoff(raw string) int64 {
|
||||
return ts
|
||||
}
|
||||
|
||||
func parseLimit(raw string) (int, error) {
|
||||
value := strings.TrimSpace(raw)
|
||||
if value == "" {
|
||||
return 0, nil
|
||||
}
|
||||
limit, err := strconv.Atoi(value)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("must be a positive integer")
|
||||
}
|
||||
if limit <= 0 {
|
||||
return 0, fmt.Errorf("must be greater than zero")
|
||||
}
|
||||
return limit, nil
|
||||
}
|
||||
|
||||
func parseTimestamp(line string) int64 {
|
||||
if strings.HasPrefix(line, "[") {
|
||||
line = line[1:]
|
||||
|
||||
@@ -6,6 +6,7 @@ package middleware
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -15,8 +16,8 @@ import (
|
||||
|
||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||
// It captures detailed information about the request and response, including headers and body,
|
||||
// and uses the provided RequestLogger to record this data. If logging is disabled in the
|
||||
// logger, the middleware has minimal overhead.
|
||||
// and uses the provided RequestLogger to record this data. When logging is disabled in the
|
||||
// logger, it still captures data so that upstream errors can be persisted.
|
||||
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if logger == nil {
|
||||
@@ -24,14 +25,13 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
path := c.Request.URL.Path
|
||||
if !shouldLogRequest(path) {
|
||||
if c.Request.Method == http.MethodGet {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
|
||||
// Early return if logging is disabled (zero overhead)
|
||||
if !logger.IsEnabled() {
|
||||
path := c.Request.URL.Path
|
||||
if !shouldLogRequest(path) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -47,6 +47,9 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
|
||||
// Create response writer wrapper
|
||||
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||
if !logger.IsEnabled() {
|
||||
wrapper.logOnErrorOnly = true
|
||||
}
|
||||
c.Writer = wrapper
|
||||
|
||||
// Process the request
|
||||
|
||||
@@ -5,6 +5,7 @@ package middleware
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
@@ -24,15 +25,16 @@ type RequestInfo struct {
|
||||
// It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
|
||||
type ResponseWriterWrapper struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
||||
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
||||
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
||||
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
||||
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
||||
logger logging.RequestLogger // logger is the instance of the request logger service.
|
||||
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
||||
statusCode int // statusCode stores the HTTP status code of the response.
|
||||
headers map[string][]string // headers stores the response headers.
|
||||
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
||||
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
||||
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
||||
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
||||
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
||||
logger logging.RequestLogger // logger is the instance of the request logger service.
|
||||
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
||||
statusCode int // statusCode stores the HTTP status code of the response.
|
||||
headers map[string][]string // headers stores the response headers.
|
||||
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
|
||||
}
|
||||
|
||||
// NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
|
||||
@@ -192,12 +194,34 @@ func (w *ResponseWriterWrapper) processStreamingChunks(done chan struct{}) {
|
||||
// For non-streaming responses, it logs the complete request and response details,
|
||||
// including any API-specific request/response data stored in the Gin context.
|
||||
func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
if !w.logger.IsEnabled() {
|
||||
if w.logger == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
finalStatusCode := w.statusCode
|
||||
if finalStatusCode == 0 {
|
||||
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
|
||||
finalStatusCode = statusWriter.Status()
|
||||
} else {
|
||||
finalStatusCode = 200
|
||||
}
|
||||
}
|
||||
|
||||
var slicesAPIResponseError []*interfaces.ErrorMessage
|
||||
apiResponseError, isExist := c.Get("API_RESPONSE_ERROR")
|
||||
if isExist {
|
||||
if apiErrors, ok := apiResponseError.([]*interfaces.ErrorMessage); ok {
|
||||
slicesAPIResponseError = apiErrors
|
||||
}
|
||||
}
|
||||
|
||||
hasAPIError := len(slicesAPIResponseError) > 0 || finalStatusCode >= http.StatusBadRequest
|
||||
forceLog := w.logOnErrorOnly && hasAPIError && !w.logger.IsEnabled()
|
||||
if !w.logger.IsEnabled() && !forceLog {
|
||||
return nil
|
||||
}
|
||||
|
||||
if w.isStreaming {
|
||||
// Close streaming channel and writer
|
||||
if w.chunkChannel != nil {
|
||||
close(w.chunkChannel)
|
||||
w.chunkChannel = nil
|
||||
@@ -209,80 +233,98 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
}
|
||||
|
||||
if w.streamWriter != nil {
|
||||
err := w.streamWriter.Close()
|
||||
if err := w.streamWriter.Close(); err != nil {
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
w.streamWriter = nil
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// Capture final status code and headers if not already captured
|
||||
finalStatusCode := w.statusCode
|
||||
if finalStatusCode == 0 {
|
||||
// Get status from underlying ResponseWriter if available
|
||||
if statusWriter, ok := w.ResponseWriter.(interface{ Status() int }); ok {
|
||||
finalStatusCode = statusWriter.Status()
|
||||
} else {
|
||||
finalStatusCode = 200 // Default
|
||||
}
|
||||
if forceLog {
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure we have the latest headers before finalizing
|
||||
w.ensureHeadersCaptured()
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
|
||||
// Use the captured headers as the final headers
|
||||
finalHeaders := make(map[string][]string)
|
||||
for key, values := range w.headers {
|
||||
// Make a copy of the values slice to avoid reference issues
|
||||
headerValues := make([]string, len(values))
|
||||
copy(headerValues, values)
|
||||
finalHeaders[key] = headerValues
|
||||
}
|
||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||
w.ensureHeadersCaptured()
|
||||
|
||||
var apiRequestBody []byte
|
||||
apiRequest, isExist := c.Get("API_REQUEST")
|
||||
if isExist {
|
||||
var ok bool
|
||||
apiRequestBody, ok = apiRequest.([]byte)
|
||||
if !ok {
|
||||
apiRequestBody = nil
|
||||
}
|
||||
}
|
||||
finalHeaders := make(map[string][]string, len(w.headers))
|
||||
for key, values := range w.headers {
|
||||
headerValues := make([]string, len(values))
|
||||
copy(headerValues, values)
|
||||
finalHeaders[key] = headerValues
|
||||
}
|
||||
|
||||
var apiResponseBody []byte
|
||||
apiResponse, isExist := c.Get("API_RESPONSE")
|
||||
if isExist {
|
||||
var ok bool
|
||||
apiResponseBody, ok = apiResponse.([]byte)
|
||||
if !ok {
|
||||
apiResponseBody = nil
|
||||
}
|
||||
}
|
||||
return finalHeaders
|
||||
}
|
||||
|
||||
var slicesAPIResponseError []*interfaces.ErrorMessage
|
||||
apiResponseError, isExist := c.Get("API_RESPONSE_ERROR")
|
||||
if isExist {
|
||||
var ok bool
|
||||
slicesAPIResponseError, ok = apiResponseError.([]*interfaces.ErrorMessage)
|
||||
if !ok {
|
||||
slicesAPIResponseError = nil
|
||||
}
|
||||
}
|
||||
func (w *ResponseWriterWrapper) extractAPIRequest(c *gin.Context) []byte {
|
||||
apiRequest, isExist := c.Get("API_REQUEST")
|
||||
if !isExist {
|
||||
return nil
|
||||
}
|
||||
data, ok := apiRequest.([]byte)
|
||||
if !ok || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
// Log complete non-streaming response
|
||||
return w.logger.LogRequest(
|
||||
func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
|
||||
apiResponse, isExist := c.Get("API_RESPONSE")
|
||||
if !isExist {
|
||||
return nil
|
||||
}
|
||||
data, ok := apiResponse.([]byte)
|
||||
if !ok || len(data) == 0 {
|
||||
return nil
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
if w.requestInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if len(w.requestInfo.Body) > 0 {
|
||||
requestBody = w.requestInfo.Body
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool) error
|
||||
}); ok {
|
||||
return loggerWithOptions.LogRequestWithOptions(
|
||||
w.requestInfo.URL,
|
||||
w.requestInfo.Method,
|
||||
w.requestInfo.Headers,
|
||||
w.requestInfo.Body,
|
||||
finalStatusCode,
|
||||
finalHeaders,
|
||||
w.body.Bytes(),
|
||||
requestBody,
|
||||
statusCode,
|
||||
headers,
|
||||
body,
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
slicesAPIResponseError,
|
||||
apiResponseErrors,
|
||||
forceLog,
|
||||
)
|
||||
}
|
||||
|
||||
return nil
|
||||
return w.logger.LogRequest(
|
||||
w.requestInfo.URL,
|
||||
w.requestInfo.Method,
|
||||
w.requestInfo.Headers,
|
||||
requestBody,
|
||||
statusCode,
|
||||
headers,
|
||||
body,
|
||||
apiRequestBody,
|
||||
apiResponseBody,
|
||||
apiResponseErrors,
|
||||
)
|
||||
}
|
||||
|
||||
// Status returns the HTTP response status code captured by the wrapper.
|
||||
|
||||
@@ -181,5 +181,3 @@ func (m *AmpModule) OnConfigUpdated(cfg *config.Config) error {
|
||||
log.Debug("Amp config updated (restart required for URL changes)")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
|
||||
@@ -83,7 +83,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
// Peek at first 2 bytes to detect gzip magic bytes
|
||||
header := make([]byte, 2)
|
||||
n, _ := io.ReadFull(originalBody, header)
|
||||
|
||||
|
||||
// Check for gzip magic bytes (0x1f 0x8b)
|
||||
// If n < 2, we didn't get enough bytes, so it's not gzip
|
||||
if n >= 2 && header[0] == 0x1f && header[1] == 0x8b {
|
||||
@@ -97,7 +97,7 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
// Reconstruct complete gzipped data
|
||||
gzippedData := append(header[:n], rest...)
|
||||
|
||||
@@ -129,8 +129,8 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
resp.ContentLength = int64(len(decompressed))
|
||||
|
||||
// Update headers to reflect decompressed state
|
||||
resp.Header.Del("Content-Encoding") // No longer compressed
|
||||
resp.Header.Del("Content-Length") // Remove stale compressed length
|
||||
resp.Header.Del("Content-Encoding") // No longer compressed
|
||||
resp.Header.Del("Content-Length") // Remove stale compressed length
|
||||
resp.Header.Set("Content-Length", strconv.FormatInt(resp.ContentLength, 10)) // Set decompressed length
|
||||
|
||||
log.Debugf("amp proxy: decompressed gzip response (%d -> %d bytes)", len(gzippedData), len(decompressed))
|
||||
|
||||
@@ -440,52 +440,52 @@ func TestIsStreamingResponse(t *testing.T) {
|
||||
|
||||
func TestFilterBetaFeatures(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
header string
|
||||
name string
|
||||
header string
|
||||
featureToRemove string
|
||||
expected string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Remove context-1m from middle",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20",
|
||||
name: "Remove context-1m from middle",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07,oauth-2025-04-20",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
},
|
||||
{
|
||||
name: "Remove context-1m from start",
|
||||
header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14",
|
||||
name: "Remove context-1m from start",
|
||||
header: "context-1m-2025-08-07,fine-grained-tool-streaming-2025-05-14",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "Remove context-1m from end",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07",
|
||||
name: "Remove context-1m from end",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,context-1m-2025-08-07",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14",
|
||||
},
|
||||
{
|
||||
name: "Feature not present",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
name: "Feature not present",
|
||||
header: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
},
|
||||
{
|
||||
name: "Only feature to remove",
|
||||
header: "context-1m-2025-08-07",
|
||||
name: "Only feature to remove",
|
||||
header: "context-1m-2025-08-07",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Empty header",
|
||||
header: "",
|
||||
name: "Empty header",
|
||||
header: "",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "",
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Header with spaces",
|
||||
header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20",
|
||||
name: "Header with spaces",
|
||||
header: "fine-grained-tool-streaming-2025-05-14, context-1m-2025-08-07 , oauth-2025-04-20",
|
||||
featureToRemove: "context-1m-2025-08-07",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
expected: "fine-grained-tool-streaming-2025-05-14,oauth-2025-04-20",
|
||||
},
|
||||
}
|
||||
|
||||
|
||||
@@ -6,11 +6,11 @@ import (
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/claude"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/gemini"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers/openai"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
||||
@@ -247,6 +247,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
// Save initial YAML snapshot
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
s.applyAccessConfig(nil, cfg)
|
||||
if authManager != nil {
|
||||
authManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
}
|
||||
managementasset.SetCurrentConfig(cfg)
|
||||
auth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
// Initialize management handler
|
||||
@@ -509,6 +512,8 @@ func (s *Server) registerManagementRoutes() {
|
||||
|
||||
mgmt.GET("/logs", s.mgmt.GetLogs)
|
||||
mgmt.DELETE("/logs", s.mgmt.DeleteLogs)
|
||||
mgmt.GET("/request-error-logs", s.mgmt.GetRequestErrorLogs)
|
||||
mgmt.GET("/request-error-logs/:name", s.mgmt.DownloadRequestErrorLog)
|
||||
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
||||
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
||||
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
||||
@@ -519,6 +524,9 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||
mgmt.PATCH("/request-retry", s.mgmt.PutRequestRetry)
|
||||
mgmt.GET("/max-retry-interval", s.mgmt.GetMaxRetryInterval)
|
||||
mgmt.PUT("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||
mgmt.PATCH("/max-retry-interval", s.mgmt.PutMaxRetryInterval)
|
||||
|
||||
mgmt.GET("/claude-api-key", s.mgmt.GetClaudeKeys)
|
||||
mgmt.PUT("/claude-api-key", s.mgmt.PutClaudeKeys)
|
||||
@@ -686,17 +694,33 @@ func (s *Server) unifiedModelsHandler(openaiHandler *openai.OpenAIAPIHandler, cl
|
||||
}
|
||||
}
|
||||
|
||||
// Start begins listening for and serving HTTP requests.
|
||||
// Start begins listening for and serving HTTP or HTTPS requests.
|
||||
// It's a blocking call and will only return on an unrecoverable error.
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if the server fails to start
|
||||
func (s *Server) Start() error {
|
||||
log.Debugf("Starting API server on %s", s.server.Addr)
|
||||
if s == nil || s.server == nil {
|
||||
return fmt.Errorf("failed to start HTTP server: server not initialized")
|
||||
}
|
||||
|
||||
// Start the HTTP server.
|
||||
if err := s.server.ListenAndServe(); err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", err)
|
||||
useTLS := s.cfg != nil && s.cfg.TLS.Enable
|
||||
if useTLS {
|
||||
cert := strings.TrimSpace(s.cfg.TLS.Cert)
|
||||
key := strings.TrimSpace(s.cfg.TLS.Key)
|
||||
if cert == "" || key == "" {
|
||||
return fmt.Errorf("failed to start HTTPS server: tls.cert or tls.key is empty")
|
||||
}
|
||||
log.Debugf("Starting API server on %s with TLS", s.server.Addr)
|
||||
if errServeTLS := s.server.ListenAndServeTLS(cert, key); errServeTLS != nil && !errors.Is(errServeTLS, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start HTTPS server: %v", errServeTLS)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("Starting API server on %s", s.server.Addr)
|
||||
if errServe := s.server.ListenAndServe(); errServe != nil && !errors.Is(errServe, http.ErrServerClosed) {
|
||||
return fmt.Errorf("failed to start HTTP server: %v", errServe)
|
||||
}
|
||||
|
||||
return nil
|
||||
@@ -814,6 +838,9 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
log.Debugf("disable_cooling toggled to %t", cfg.DisableCooling)
|
||||
}
|
||||
}
|
||||
if s.handlers != nil && s.handlers.AuthManager != nil {
|
||||
s.handlers.AuthManager.SetRetryConfig(cfg.RequestRetry, time.Duration(cfg.MaxRetryInterval)*time.Second)
|
||||
}
|
||||
|
||||
// Update log level dynamically when debug flag changes
|
||||
if oldCfg == nil || oldCfg.Debug != cfg.Debug {
|
||||
|
||||
@@ -23,6 +23,9 @@ type Config struct {
|
||||
// Port is the network port on which the API server will listen.
|
||||
Port int `yaml:"port" json:"-"`
|
||||
|
||||
// TLS config controls HTTPS server settings.
|
||||
TLS TLSConfig `yaml:"tls" json:"tls"`
|
||||
|
||||
// AmpUpstreamURL defines the upstream Amp control plane used for non-provider calls.
|
||||
AmpUpstreamURL string `yaml:"amp-upstream-url" json:"amp-upstream-url"`
|
||||
|
||||
@@ -63,6 +66,8 @@ type Config struct {
|
||||
|
||||
// RequestRetry defines the retry times when the request failed.
|
||||
RequestRetry int `yaml:"request-retry" json:"request-retry"`
|
||||
// MaxRetryInterval defines the maximum wait time in seconds before retrying a cooled-down credential.
|
||||
MaxRetryInterval int `yaml:"max-retry-interval" json:"max-retry-interval"`
|
||||
|
||||
// ClaudeKey defines a list of Claude API key configurations as specified in the YAML configuration file.
|
||||
ClaudeKey []ClaudeKey `yaml:"claude-api-key" json:"claude-api-key"`
|
||||
@@ -80,6 +85,16 @@ type Config struct {
|
||||
Payload PayloadConfig `yaml:"payload" json:"payload"`
|
||||
}
|
||||
|
||||
// TLSConfig holds HTTPS server settings.
|
||||
type TLSConfig struct {
|
||||
// Enable toggles HTTPS server mode.
|
||||
Enable bool `yaml:"enable" json:"enable"`
|
||||
// Cert is the path to the TLS certificate file.
|
||||
Cert string `yaml:"cert" json:"cert"`
|
||||
// Key is the path to the TLS private key file.
|
||||
Key string `yaml:"key" json:"key"`
|
||||
}
|
||||
|
||||
// RemoteManagement holds management API configuration under 'remote-management'.
|
||||
type RemoteManagement struct {
|
||||
// AllowRemote toggles remote (non-localhost) access to management API.
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
@@ -156,17 +157,30 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage) error {
|
||||
if !l.enabled {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false)
|
||||
}
|
||||
|
||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force)
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool) error {
|
||||
if !l.enabled && !force {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Ensure logs directory exists
|
||||
if err := l.ensureLogsDir(); err != nil {
|
||||
return fmt.Errorf("failed to create logs directory: %w", err)
|
||||
if errEnsure := l.ensureLogsDir(); errEnsure != nil {
|
||||
return fmt.Errorf("failed to create logs directory: %w", errEnsure)
|
||||
}
|
||||
|
||||
// Generate filename
|
||||
filename := l.generateFilename(url)
|
||||
if force && !l.enabled {
|
||||
filename = l.generateErrorFilename(url)
|
||||
}
|
||||
filePath := filepath.Join(l.logsDir, filename)
|
||||
|
||||
// Decompress response if needed
|
||||
@@ -184,6 +198,12 @@ func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[st
|
||||
return fmt.Errorf("failed to write log file: %w", err)
|
||||
}
|
||||
|
||||
if force && !l.enabled {
|
||||
if errCleanup := l.cleanupOldErrorLogs(); errCleanup != nil {
|
||||
log.WithError(errCleanup).Warn("failed to clean up old error logs")
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -239,6 +259,11 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
||||
return writer, nil
|
||||
}
|
||||
|
||||
// generateErrorFilename creates a filename with an error prefix to differentiate forced error logs.
|
||||
func (l *FileRequestLogger) generateErrorFilename(url string) string {
|
||||
return fmt.Sprintf("error-%s", l.generateFilename(url))
|
||||
}
|
||||
|
||||
// ensureLogsDir creates the logs directory if it doesn't exist.
|
||||
//
|
||||
// Returns:
|
||||
@@ -312,6 +337,52 @@ func (l *FileRequestLogger) sanitizeForFilename(path string) string {
|
||||
return sanitized
|
||||
}
|
||||
|
||||
// cleanupOldErrorLogs keeps only the newest 10 forced error log files.
|
||||
func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
||||
entries, errRead := os.ReadDir(l.logsDir)
|
||||
if errRead != nil {
|
||||
return errRead
|
||||
}
|
||||
|
||||
type logFile struct {
|
||||
name string
|
||||
modTime time.Time
|
||||
}
|
||||
|
||||
var files []logFile
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !strings.HasPrefix(name, "error-") || !strings.HasSuffix(name, ".log") {
|
||||
continue
|
||||
}
|
||||
info, errInfo := entry.Info()
|
||||
if errInfo != nil {
|
||||
log.WithError(errInfo).Warn("failed to read error log info")
|
||||
continue
|
||||
}
|
||||
files = append(files, logFile{name: name, modTime: info.ModTime()})
|
||||
}
|
||||
|
||||
if len(files) <= 10 {
|
||||
return nil
|
||||
}
|
||||
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].modTime.After(files[j].modTime)
|
||||
})
|
||||
|
||||
for _, file := range files[10:] {
|
||||
if errRemove := os.Remove(filepath.Join(l.logsDir, file.name)); errRemove != nil {
|
||||
log.WithError(errRemove).Warnf("failed to remove old error log: %s", file.name)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatLogContent creates the complete log content for non-streaming requests.
|
||||
//
|
||||
// Parameters:
|
||||
|
||||
@@ -23,6 +23,62 @@ func GetClaudeModels() []*ModelInfo {
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Sonnet",
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-5-thinking",
|
||||
Object: "model",
|
||||
Created: 1759104000, // 2025-09-29
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Sonnet Thinking",
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-thinking",
|
||||
Object: "model",
|
||||
Created: 1761955200, // 2025-11-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Opus Thinking",
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-thinking-low",
|
||||
Object: "model",
|
||||
Created: 1761955200, // 2025-11-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Opus Thinking Low",
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-thinking-medium",
|
||||
Object: "model",
|
||||
Created: 1761955200, // 2025-11-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Opus Thinking Medium",
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-thinking-high",
|
||||
Object: "model",
|
||||
Created: 1761955200, // 2025-11-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Opus Thinking High",
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 100000, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-5-20251101",
|
||||
Object: "model",
|
||||
Created: 1761955200, // 2025-11-01
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.5 Opus",
|
||||
Description: "Premium model combining maximum intelligence with practical performance",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-1-20250805",
|
||||
Object: "model",
|
||||
@@ -129,6 +185,20 @@ func GetGeminiModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
Object: "model",
|
||||
Created: 1737158400,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-pro-image-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Pro Image Preview",
|
||||
Description: "Gemini 3 Pro Image Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -826,7 +826,6 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// GetFirstAvailableModel returns the first available model for the given handler type.
|
||||
// It prioritizes models by their creation timestamp (newest first) and checks if they have
|
||||
// available clients that are not suspended or over quota.
|
||||
|
||||
@@ -319,6 +319,9 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
||||
payload = applyPayloadConfig(e.cfg, req.Model, payload)
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.maxOutputTokens")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseMimeType")
|
||||
payload, _ = sjson.DeleteBytes(payload, "generationConfig.responseJsonSchema")
|
||||
metadataAction := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
if action, _ := req.Metadata["action"].(string); action == "countTokens" {
|
||||
|
||||
@@ -26,16 +26,18 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
antigravityBaseURL = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.3 windows/amd64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 5 * time.Minute
|
||||
streamScannerBuffer int = 20_971_520
|
||||
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
streamScannerBuffer int = 20_971_520
|
||||
)
|
||||
|
||||
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
@@ -73,43 +75,76 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt)
|
||||
if errReq != nil {
|
||||
return resp, errReq
|
||||
}
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return resp, errDo
|
||||
}
|
||||
defer func() {
|
||||
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, false, opts.Alt, baseURL)
|
||||
if errReq != nil {
|
||||
err = errReq
|
||||
return resp, err
|
||||
}
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errDo
|
||||
return resp, err
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
err = errRead
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return resp, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
log.Debugf("antigravity executor: upstream error status: %d, body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), bodyBytes))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
return resp, err
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(bodyBytes))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bodyBytes, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
return resp, nil
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// ExecuteStream handles streaming requests via the antigravity upstream.
|
||||
@@ -131,75 +166,121 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt)
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
}
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return nil, errDo
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
|
||||
if errReq != nil {
|
||||
err = errReq
|
||||
return nil, err
|
||||
}
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errDo
|
||||
return nil, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
scanner.Buffer(nil, streamScannerBuffer)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
|
||||
// Filter usage metadata for all models
|
||||
// Only retain usage statistics in the terminal chunk
|
||||
line = FilterSSEUsageMetadata(line)
|
||||
|
||||
payload := jsonPayload(line)
|
||||
if payload == nil {
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errRead
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errRead
|
||||
return nil, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(nil, streamScannerBuffer)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
// Filter usage metadata for all models
|
||||
// Only retain usage statistics in the terminal chunk
|
||||
line = FilterSSEUsageMetadata(line)
|
||||
|
||||
payload := jsonPayload(line)
|
||||
if payload == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
|
||||
chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, bytes.Clone(payload), ¶m)
|
||||
for i := range chunks {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}
|
||||
}
|
||||
}
|
||||
}
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
||||
for i := range tail {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
} else {
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
tail := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, []byte("[DONE]"), ¶m)
|
||||
for i := range tail {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(tail[i])}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
} else {
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Refresh refreshes the OAuth token using the refresh token.
|
||||
@@ -230,54 +311,86 @@ func FetchAntigravityModels(ctx context.Context, auth *cliproxyauth.Auth, cfg *c
|
||||
auth = updatedAuth
|
||||
}
|
||||
|
||||
modelsURL := buildBaseURL(auth) + antigravityModelsPath
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
||||
if errReq != nil {
|
||||
return nil
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
if host := resolveHost(auth); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
return nil
|
||||
}
|
||||
defer func() {
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
modelsURL := baseURL + antigravityModelsPath
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, modelsURL, bytes.NewReader([]byte(`{}`)))
|
||||
if errReq != nil {
|
||||
return nil
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
if host := resolveHost(baseURL); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
if errRead != nil {
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: models request rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
return nil
|
||||
}
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
return nil
|
||||
}
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
return nil
|
||||
}
|
||||
|
||||
result := gjson.GetBytes(bodyBytes, "models")
|
||||
if !result.Exists() {
|
||||
return nil
|
||||
now := time.Now().Unix()
|
||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
||||
for id := range result.Map() {
|
||||
id = modelName2Alias(id)
|
||||
if id != "" {
|
||||
modelInfo := ®istry.ModelInfo{
|
||||
ID: id,
|
||||
Name: id,
|
||||
Description: id,
|
||||
DisplayName: id,
|
||||
Version: id,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: antigravityAuthType,
|
||||
Type: antigravityAuthType,
|
||||
}
|
||||
// Add Thinking support for thinking models
|
||||
if strings.HasSuffix(id, "-thinking") || strings.Contains(id, "-thinking-") {
|
||||
modelInfo.Thinking = ®istry.ThinkingSupport{
|
||||
Min: 1024,
|
||||
Max: 100000,
|
||||
ZeroAllowed: false,
|
||||
DynamicAllowed: true,
|
||||
}
|
||||
}
|
||||
models = append(models, modelInfo)
|
||||
}
|
||||
}
|
||||
return models
|
||||
}
|
||||
|
||||
now := time.Now().Unix()
|
||||
models := make([]*registry.ModelInfo, 0, len(result.Map()))
|
||||
for id := range result.Map() {
|
||||
models = append(models, ®istry.ModelInfo{
|
||||
ID: id,
|
||||
Object: "model",
|
||||
Created: now,
|
||||
OwnedBy: antigravityAuthType,
|
||||
Type: antigravityAuthType,
|
||||
})
|
||||
}
|
||||
return models
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) ensureAccessToken(ctx context.Context, auth *cliproxyauth.Auth) (string, *cliproxyauth.Auth, error) {
|
||||
@@ -363,12 +476,15 @@ func (e *AntigravityExecutor) refreshToken(ctx context.Context, auth *cliproxyau
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt string) (*http.Request, error) {
|
||||
func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyauth.Auth, token, modelName string, payload []byte, stream bool, alt, baseURL string) (*http.Request, error) {
|
||||
if token == "" {
|
||||
return nil, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
|
||||
base := buildBaseURL(auth)
|
||||
base := strings.TrimSuffix(baseURL, "/")
|
||||
if base == "" {
|
||||
base = buildBaseURL(auth)
|
||||
}
|
||||
path := antigravityGeneratePath
|
||||
if stream {
|
||||
path = antigravityStreamPath
|
||||
@@ -389,6 +505,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
}
|
||||
|
||||
payload = geminiToAntigravity(modelName, payload)
|
||||
payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName))
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return nil, errReq
|
||||
@@ -401,7 +518,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
} else {
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
}
|
||||
if host := resolveHost(auth); host != "" {
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
@@ -485,26 +602,13 @@ func int64Value(value any) (int64, bool) {
|
||||
}
|
||||
|
||||
func buildBaseURL(auth *cliproxyauth.Auth) string {
|
||||
if auth != nil {
|
||||
if auth.Attributes != nil {
|
||||
if v := strings.TrimSpace(auth.Attributes["base_url"]); v != "" {
|
||||
return strings.TrimSuffix(v, "/")
|
||||
}
|
||||
}
|
||||
if auth.Metadata != nil {
|
||||
if v, ok := auth.Metadata["base_url"].(string); ok {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" {
|
||||
return strings.TrimSuffix(v, "/")
|
||||
}
|
||||
}
|
||||
}
|
||||
if baseURLs := antigravityBaseURLFallbackOrder(auth); len(baseURLs) > 0 {
|
||||
return baseURLs[0]
|
||||
}
|
||||
return antigravityBaseURL
|
||||
return antigravityBaseURLAutopush
|
||||
}
|
||||
|
||||
func resolveHost(auth *cliproxyauth.Auth) string {
|
||||
base := buildBaseURL(auth)
|
||||
func resolveHost(base string) string {
|
||||
parsed, errParse := url.Parse(base)
|
||||
if errParse != nil {
|
||||
return ""
|
||||
@@ -531,6 +635,37 @@ func resolveUserAgent(auth *cliproxyauth.Auth) string {
|
||||
return defaultAntigravityAgent
|
||||
}
|
||||
|
||||
func antigravityBaseURLFallbackOrder(auth *cliproxyauth.Auth) []string {
|
||||
if base := resolveCustomAntigravityBaseURL(auth); base != "" {
|
||||
return []string{base}
|
||||
}
|
||||
return []string{
|
||||
antigravityBaseURLDaily,
|
||||
antigravityBaseURLAutopush,
|
||||
// antigravityBaseURLProd,
|
||||
}
|
||||
}
|
||||
|
||||
func resolveCustomAntigravityBaseURL(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if v := strings.TrimSpace(auth.Attributes["base_url"]); v != "" {
|
||||
return strings.TrimSuffix(v, "/")
|
||||
}
|
||||
}
|
||||
if auth.Metadata != nil {
|
||||
if v, ok := auth.Metadata["base_url"].(string); ok {
|
||||
v = strings.TrimSpace(v)
|
||||
if v != "" {
|
||||
return strings.TrimSuffix(v, "/")
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func geminiToAntigravity(modelName string, payload []byte) []byte {
|
||||
template, _ := sjson.Set(string(payload), "model", modelName)
|
||||
template, _ = sjson.Set(template, "userAgent", "antigravity")
|
||||
@@ -540,18 +675,27 @@ func geminiToAntigravity(modelName string, payload []byte) []byte {
|
||||
|
||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
template, _ = sjson.Delete(template, "request.generationConfig.maxOutputTokens")
|
||||
if !strings.HasPrefix(modelName, "gemini-3-") {
|
||||
if thinkingLevel := gjson.Get(template, "request.generationConfig.thinkingConfig.thinkingLevel"); thinkingLevel.Exists() {
|
||||
template, _ = sjson.Delete(template, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
template, _ = sjson.Set(template, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
}
|
||||
}
|
||||
|
||||
gjson.Get(template, "request.contents").ForEach(func(key, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
template, _ = sjson.Set(template, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
if strings.HasPrefix(modelName, "claude-sonnet-") {
|
||||
gjson.Get(template, "request.tools").ForEach(func(key, tool gjson.Result) bool {
|
||||
tool.Get("functionDeclarations").ForEach(func(funKey, funcDecl gjson.Result) bool {
|
||||
if funcDecl.Get("parametersJsonSchema").Exists() {
|
||||
template, _ = sjson.SetRaw(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters", key.Int(), funKey.Int()), funcDecl.Get("parametersJsonSchema").Raw)
|
||||
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parameters.$schema", key.Int(), funKey.Int()))
|
||||
template, _ = sjson.Delete(template, fmt.Sprintf("request.tools.%d.functionDeclarations.%d.parametersJsonSchema", key.Int(), funKey.Int()))
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
return []byte(template)
|
||||
}
|
||||
@@ -573,3 +717,39 @@ func generateProjectID() string {
|
||||
randomPart := strings.ToLower(uuid.NewString())[:5]
|
||||
return adj + "-" + noun + "-" + randomPart
|
||||
}
|
||||
|
||||
func modelName2Alias(modelName string) string {
|
||||
switch modelName {
|
||||
case "rev19-uic3-1p":
|
||||
return "gemini-2.5-computer-use-preview-10-2025"
|
||||
case "gemini-3-pro-image":
|
||||
return "gemini-3-pro-image-preview"
|
||||
case "gemini-3-pro-high":
|
||||
return "gemini-3-pro-preview"
|
||||
case "claude-sonnet-4-5":
|
||||
return "gemini-claude-sonnet-4-5"
|
||||
case "claude-sonnet-4-5-thinking":
|
||||
return "gemini-claude-sonnet-4-5-thinking"
|
||||
case "chat_20706", "chat_23310", "gemini-2.5-flash-thinking", "gemini-3-pro-low", "gemini-2.5-pro":
|
||||
return ""
|
||||
default:
|
||||
return modelName
|
||||
}
|
||||
}
|
||||
|
||||
func alias2ModelName(modelName string) string {
|
||||
switch modelName {
|
||||
case "gemini-2.5-computer-use-preview-10-2025":
|
||||
return "rev19-uic3-1p"
|
||||
case "gemini-3-pro-image-preview":
|
||||
return "gemini-3-pro-image"
|
||||
case "gemini-3-pro-preview":
|
||||
return "gemini-3-pro-high"
|
||||
case "gemini-claude-sonnet-4-5":
|
||||
return "claude-sonnet-4-5"
|
||||
case "gemini-claude-sonnet-4-5-thinking":
|
||||
return "claude-sonnet-4-5-thinking"
|
||||
default:
|
||||
return modelName
|
||||
}
|
||||
}
|
||||
|
||||
@@ -58,18 +58,24 @@ func (e *ClaudeExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r
|
||||
body, _ = sjson.SetBytes(body, "model", modelOverride)
|
||||
modelForUpstream = modelOverride
|
||||
}
|
||||
// Inject thinking config based on model suffix for thinking variants
|
||||
body = e.injectThinkingConfig(req.Model, body)
|
||||
|
||||
if !strings.HasPrefix(modelForUpstream, "claude-3-5-haiku") {
|
||||
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return resp, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -154,15 +160,21 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A
|
||||
if modelOverride := e.resolveUpstreamModel(req.Model, auth); modelOverride != "" {
|
||||
body, _ = sjson.SetBytes(body, "model", modelOverride)
|
||||
}
|
||||
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
|
||||
// Inject thinking config based on model suffix for thinking variants
|
||||
body = e.injectThinkingConfig(req.Model, body)
|
||||
body = checkSystemInstructions(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
// Extract betas from body and convert to header
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages?beta=true", baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, true, extraBetas)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -283,15 +295,19 @@ func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut
|
||||
}
|
||||
|
||||
if !strings.HasPrefix(modelForUpstream, "claude-3-5-haiku") {
|
||||
body, _ = sjson.SetRawBytes(body, "system", []byte(misc.ClaudeCodeInstructions))
|
||||
body = checkSystemInstructions(body)
|
||||
}
|
||||
|
||||
// Extract betas from body and convert to header (for count_tokens too)
|
||||
var extraBetas []string
|
||||
extraBetas, body = extractAndRemoveBetas(body)
|
||||
|
||||
url := fmt.Sprintf("%s/v1/messages/count_tokens?beta=true", baseURL)
|
||||
httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return cliproxyexecutor.Response{}, err
|
||||
}
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false)
|
||||
applyClaudeHeaders(httpReq, auth, apiKey, false, extraBetas)
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
@@ -383,10 +399,65 @@ func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// extractAndRemoveBetas extracts the "betas" array from the body and removes it.
|
||||
// Returns the extracted betas as a string slice and the modified body.
|
||||
func extractAndRemoveBetas(body []byte) ([]string, []byte) {
|
||||
betasResult := gjson.GetBytes(body, "betas")
|
||||
if !betasResult.Exists() {
|
||||
return nil, body
|
||||
}
|
||||
var betas []string
|
||||
if betasResult.IsArray() {
|
||||
for _, item := range betasResult.Array() {
|
||||
if s := strings.TrimSpace(item.String()); s != "" {
|
||||
betas = append(betas, s)
|
||||
}
|
||||
}
|
||||
} else if s := strings.TrimSpace(betasResult.String()); s != "" {
|
||||
betas = append(betas, s)
|
||||
}
|
||||
body, _ = sjson.DeleteBytes(body, "betas")
|
||||
return betas, body
|
||||
}
|
||||
|
||||
// injectThinkingConfig adds thinking configuration based on model name suffix
|
||||
func (e *ClaudeExecutor) injectThinkingConfig(modelName string, body []byte) []byte {
|
||||
// Only inject if thinking config is not already present
|
||||
if gjson.GetBytes(body, "thinking").Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
var budgetTokens int
|
||||
switch {
|
||||
case strings.HasSuffix(modelName, "-thinking-low"):
|
||||
budgetTokens = 1024
|
||||
case strings.HasSuffix(modelName, "-thinking-medium"):
|
||||
budgetTokens = 8192
|
||||
case strings.HasSuffix(modelName, "-thinking-high"):
|
||||
budgetTokens = 24576
|
||||
case strings.HasSuffix(modelName, "-thinking"):
|
||||
// Default thinking without suffix uses medium budget
|
||||
budgetTokens = 8192
|
||||
default:
|
||||
return body
|
||||
}
|
||||
|
||||
body, _ = sjson.SetBytes(body, "thinking.type", "enabled")
|
||||
body, _ = sjson.SetBytes(body, "thinking.budget_tokens", budgetTokens)
|
||||
return body
|
||||
}
|
||||
|
||||
func (e *ClaudeExecutor) resolveUpstreamModel(alias string, auth *cliproxyauth.Auth) string {
|
||||
if alias == "" {
|
||||
return ""
|
||||
}
|
||||
// Hardcoded mappings for thinking models to actual Claude model names
|
||||
switch alias {
|
||||
case "claude-opus-4-5-thinking", "claude-opus-4-5-thinking-low", "claude-opus-4-5-thinking-medium", "claude-opus-4-5-thinking-high":
|
||||
return "claude-opus-4-5-20251101"
|
||||
case "claude-sonnet-4-5-thinking":
|
||||
return "claude-sonnet-4-5-20250929"
|
||||
}
|
||||
entry := e.resolveClaudeConfig(auth)
|
||||
if entry == nil {
|
||||
return ""
|
||||
@@ -530,7 +601,7 @@ func decodeResponseBody(body io.ReadCloser, contentEncoding string) (io.ReadClos
|
||||
return body, nil
|
||||
}
|
||||
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool) {
|
||||
func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string, stream bool, extraBetas []string) {
|
||||
r.Header.Set("Authorization", "Bearer "+apiKey)
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
|
||||
@@ -539,15 +610,30 @@ func applyClaudeHeaders(r *http.Request, auth *cliproxyauth.Auth, apiKey string,
|
||||
ginHeaders = ginCtx.Request.Header
|
||||
}
|
||||
|
||||
baseBetas := "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14"
|
||||
if val := strings.TrimSpace(ginHeaders.Get("Anthropic-Beta")); val != "" {
|
||||
baseBetas = val
|
||||
if !strings.Contains(val, "oauth") {
|
||||
val += ",oauth-2025-04-20"
|
||||
baseBetas += ",oauth-2025-04-20"
|
||||
}
|
||||
r.Header.Set("Anthropic-Beta", val)
|
||||
} else {
|
||||
r.Header.Set("Anthropic-Beta", "claude-code-20250219,oauth-2025-04-20,interleaved-thinking-2025-05-14,fine-grained-tool-streaming-2025-05-14")
|
||||
}
|
||||
|
||||
// Merge extra betas from request body
|
||||
if len(extraBetas) > 0 {
|
||||
existingSet := make(map[string]bool)
|
||||
for _, b := range strings.Split(baseBetas, ",") {
|
||||
existingSet[strings.TrimSpace(b)] = true
|
||||
}
|
||||
for _, beta := range extraBetas {
|
||||
beta = strings.TrimSpace(beta)
|
||||
if beta != "" && !existingSet[beta] {
|
||||
baseBetas += "," + beta
|
||||
existingSet[beta] = true
|
||||
}
|
||||
}
|
||||
}
|
||||
r.Header.Set("Anthropic-Beta", baseBetas)
|
||||
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Version", "2023-06-01")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "Anthropic-Dangerous-Direct-Browser-Access", "true")
|
||||
misc.EnsureHeader(r.Header, ginHeaders, "X-App", "cli")
|
||||
@@ -590,3 +676,22 @@ func claudeCreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func checkSystemInstructions(payload []byte) []byte {
|
||||
system := gjson.GetBytes(payload, "system")
|
||||
claudeCodeInstructions := `[{"type":"text","text":"You are Claude Code, Anthropic's official CLI for Claude."}]`
|
||||
if system.IsArray() {
|
||||
if gjson.GetBytes(payload, "system.0.text").String() != "You are Claude Code, Anthropic's official CLI for Claude." {
|
||||
system.ForEach(func(_, part gjson.Result) bool {
|
||||
if part.Get("type").String() == "text" {
|
||||
claudeCodeInstructions, _ = sjson.SetRaw(claudeCodeInstructions, "-1", part.Raw)
|
||||
}
|
||||
return true
|
||||
})
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
} else {
|
||||
payload, _ = sjson.SetRawBytes(payload, "system", []byte(claudeCodeInstructions))
|
||||
}
|
||||
return payload
|
||||
}
|
||||
|
||||
@@ -98,6 +98,20 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
|
||||
}
|
||||
}
|
||||
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
} else if part.Get("thoughtSignature").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
|
||||
}
|
||||
|
||||
|
||||
@@ -105,14 +105,19 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
thoughtSignatureResult := partResult.Get("thoughtSignature")
|
||||
if !thoughtSignatureResult.Exists() {
|
||||
thoughtSignatureResult = partResult.Get("thought_signature")
|
||||
}
|
||||
inlineDataResult := partResult.Get("inlineData")
|
||||
if !inlineDataResult.Exists() {
|
||||
inlineDataResult = partResult.Get("inline_data")
|
||||
}
|
||||
|
||||
// Handle thoughtSignature - this is encrypted reasoning content that should not be exposed to the client
|
||||
if thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != "" {
|
||||
// Skip thoughtSignature processing - it's internal encrypted data
|
||||
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
|
||||
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
|
||||
|
||||
// Ignore encrypted thoughtSignature but keep any actual content in the same part.
|
||||
if hasThoughtSignature && !hasContentPayload {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -98,6 +98,20 @@ func ConvertGeminiRequestToGeminiCLI(_ string, inputRawJSON []byte, _ bool) []by
|
||||
}
|
||||
}
|
||||
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
} else if part.Get("thoughtSignature").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return common.AttachDefaultSafetySettings(rawJSON, "request.safetySettings")
|
||||
}
|
||||
|
||||
|
||||
@@ -105,14 +105,19 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
partTextResult := partResult.Get("text")
|
||||
functionCallResult := partResult.Get("functionCall")
|
||||
thoughtSignatureResult := partResult.Get("thoughtSignature")
|
||||
if !thoughtSignatureResult.Exists() {
|
||||
thoughtSignatureResult = partResult.Get("thought_signature")
|
||||
}
|
||||
inlineDataResult := partResult.Get("inlineData")
|
||||
if !inlineDataResult.Exists() {
|
||||
inlineDataResult = partResult.Get("inline_data")
|
||||
}
|
||||
|
||||
// Handle thoughtSignature - this is encrypted reasoning content that should not be exposed to the client
|
||||
if thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != "" {
|
||||
// Skip thoughtSignature processing - it's internal encrypted data
|
||||
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
|
||||
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
|
||||
|
||||
// Ignore encrypted thoughtSignature but keep any actual content in the same part.
|
||||
if hasThoughtSignature && !hasContentPayload {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -46,5 +46,19 @@ func ConvertGeminiCLIRequestToGemini(_ string, inputRawJSON []byte, _ bool) []by
|
||||
}
|
||||
}
|
||||
|
||||
gjson.GetBytes(rawJSON, "contents").ForEach(func(key, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
} else if part.Get("thoughtSignature").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
return common.AttachDefaultSafetySettings(rawJSON, "safetySettings")
|
||||
}
|
||||
|
||||
@@ -30,6 +30,11 @@ func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte
|
||||
if toolsResult.Exists() && toolsResult.IsArray() {
|
||||
toolResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolResults); i++ {
|
||||
if gjson.GetBytes(rawJSON, fmt.Sprintf("tools.%d.functionDeclarations", i)).Exists() {
|
||||
strJson, _ := util.RenameKey(string(rawJSON), fmt.Sprintf("tools.%d.functionDeclarations", i), fmt.Sprintf("tools.%d.function_declarations", i))
|
||||
rawJSON = []byte(strJson)
|
||||
}
|
||||
|
||||
functionDeclarationsResult := gjson.GetBytes(rawJSON, fmt.Sprintf("tools.%d.function_declarations", i))
|
||||
if functionDeclarationsResult.Exists() && functionDeclarationsResult.IsArray() {
|
||||
functionDeclarationsResults := functionDeclarationsResult.Array()
|
||||
@@ -72,7 +77,20 @@ func ConvertGeminiRequestToGemini(_ string, inputRawJSON []byte, _ bool) []byte
|
||||
return true
|
||||
})
|
||||
|
||||
out = common.AttachDefaultSafetySettings(out, "safetySettings")
|
||||
gjson.GetBytes(out, "contents").ForEach(func(key, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
if part.Get("functionCall").Exists() {
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
} else if part.Get("thoughtSignature").Exists() {
|
||||
out, _ = sjson.SetBytes(out, fmt.Sprintf("contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
out = common.AttachDefaultSafetySettings(out, "safetySettings")
|
||||
return out
|
||||
}
|
||||
|
||||
@@ -116,8 +116,11 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
thoughtSignatureResult = partResult.Get("thought_signature")
|
||||
}
|
||||
|
||||
// Skip thoughtSignature parts (encrypted reasoning not exposed downstream).
|
||||
if thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != "" {
|
||||
hasThoughtSignature := thoughtSignatureResult.Exists() && thoughtSignatureResult.String() != ""
|
||||
hasContentPayload := partTextResult.Exists() || functionCallResult.Exists() || inlineDataResult.Exists()
|
||||
|
||||
// Skip pure thoughtSignature parts but keep any actual payload in the same part.
|
||||
if hasThoughtSignature && !hasContentPayload {
|
||||
continue
|
||||
}
|
||||
|
||||
|
||||
@@ -33,7 +33,83 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
|
||||
// Convert input messages to Gemini contents format
|
||||
if input := root.Get("input"); input.Exists() && input.IsArray() {
|
||||
input.ForEach(func(_, item gjson.Result) bool {
|
||||
items := input.Array()
|
||||
|
||||
// Normalize consecutive function calls and outputs so each call is immediately followed by its response
|
||||
normalized := make([]gjson.Result, 0, len(items))
|
||||
for i := 0; i < len(items); {
|
||||
item := items[i]
|
||||
itemType := item.Get("type").String()
|
||||
itemRole := item.Get("role").String()
|
||||
if itemType == "" && itemRole != "" {
|
||||
itemType = "message"
|
||||
}
|
||||
|
||||
if itemType == "function_call" {
|
||||
var calls []gjson.Result
|
||||
var outputs []gjson.Result
|
||||
|
||||
for i < len(items) {
|
||||
next := items[i]
|
||||
nextType := next.Get("type").String()
|
||||
nextRole := next.Get("role").String()
|
||||
if nextType == "" && nextRole != "" {
|
||||
nextType = "message"
|
||||
}
|
||||
if nextType != "function_call" {
|
||||
break
|
||||
}
|
||||
calls = append(calls, next)
|
||||
i++
|
||||
}
|
||||
|
||||
for i < len(items) {
|
||||
next := items[i]
|
||||
nextType := next.Get("type").String()
|
||||
nextRole := next.Get("role").String()
|
||||
if nextType == "" && nextRole != "" {
|
||||
nextType = "message"
|
||||
}
|
||||
if nextType != "function_call_output" {
|
||||
break
|
||||
}
|
||||
outputs = append(outputs, next)
|
||||
i++
|
||||
}
|
||||
|
||||
if len(calls) > 0 {
|
||||
outputMap := make(map[string]gjson.Result, len(outputs))
|
||||
for _, out := range outputs {
|
||||
outputMap[out.Get("call_id").String()] = out
|
||||
}
|
||||
for _, call := range calls {
|
||||
normalized = append(normalized, call)
|
||||
callID := call.Get("call_id").String()
|
||||
if resp, ok := outputMap[callID]; ok {
|
||||
normalized = append(normalized, resp)
|
||||
delete(outputMap, callID)
|
||||
}
|
||||
}
|
||||
for _, out := range outputs {
|
||||
if _, ok := outputMap[out.Get("call_id").String()]; ok {
|
||||
normalized = append(normalized, out)
|
||||
}
|
||||
}
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
if itemType == "function_call_output" {
|
||||
normalized = append(normalized, item)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
normalized = append(normalized, item)
|
||||
i++
|
||||
}
|
||||
|
||||
for _, item := range normalized {
|
||||
itemType := item.Get("type").String()
|
||||
itemRole := item.Get("role").String()
|
||||
if itemType == "" && itemRole != "" {
|
||||
@@ -59,7 +135,7 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
out, _ = sjson.SetRaw(out, "system_instruction", systemInstr)
|
||||
}
|
||||
}
|
||||
return true
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle regular messages
|
||||
@@ -186,7 +262,8 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
case "function_call_output":
|
||||
// Handle function call outputs - convert to function message with functionResponse
|
||||
callID := item.Get("call_id").String()
|
||||
output := item.Get("output").String()
|
||||
// Use .Raw to preserve the JSON encoding (includes quotes for strings)
|
||||
outputRaw := item.Get("output").Str
|
||||
|
||||
functionContent := `{"role":"function","parts":[]}`
|
||||
functionResponse := `{"functionResponse":{"name":"","response":{}}}`
|
||||
@@ -209,18 +286,19 @@ func ConvertOpenAIResponsesRequestToGemini(modelName string, inputRawJSON []byte
|
||||
|
||||
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.name", functionName)
|
||||
|
||||
// Parse output JSON string and set as response content
|
||||
if output != "" {
|
||||
outputResult := gjson.Parse(output)
|
||||
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.response.result", outputResult.Raw)
|
||||
// Set the raw JSON output directly (preserves string encoding)
|
||||
if outputRaw != "" && outputRaw != "null" {
|
||||
output := gjson.Parse(outputRaw)
|
||||
if output.Type == gjson.JSON {
|
||||
functionResponse, _ = sjson.SetRaw(functionResponse, "functionResponse.response.result", output.Raw)
|
||||
} else {
|
||||
functionResponse, _ = sjson.Set(functionResponse, "functionResponse.response.result", outputRaw)
|
||||
}
|
||||
}
|
||||
|
||||
functionContent, _ = sjson.SetRaw(functionContent, "parts.-1", functionResponse)
|
||||
out, _ = sjson.SetRaw(out, "contents.-1", functionContent)
|
||||
}
|
||||
|
||||
return true
|
||||
})
|
||||
}
|
||||
} else if input.Exists() && input.Type == gjson.String {
|
||||
// Simple string input conversion to user message
|
||||
userContent := `{"role":"user","parts":[{"text":""}]}`
|
||||
|
||||
@@ -433,12 +433,18 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
// output tokens
|
||||
if v := um.Get("candidatesTokenCount"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens", v.Int())
|
||||
} else {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens", 0)
|
||||
}
|
||||
if v := um.Get("thoughtsTokenCount"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", v.Int())
|
||||
} else {
|
||||
completed, _ = sjson.Set(completed, "response.usage.output_tokens_details.reasoning_tokens", 0)
|
||||
}
|
||||
if v := um.Get("totalTokenCount"); v.Exists() {
|
||||
completed, _ = sjson.Set(completed, "response.usage.total_tokens", v.Int())
|
||||
} else {
|
||||
completed, _ = sjson.Set(completed, "response.usage.total_tokens", 0)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -474,6 +474,33 @@ func (w *Watcher) processEvents(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) authFileUnchanged(path string) (bool, error) {
|
||||
data, errRead := os.ReadFile(path)
|
||||
if errRead != nil {
|
||||
return false, errRead
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
curHash := hex.EncodeToString(sum[:])
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
prevHash, ok := w.lastAuthHashes[path]
|
||||
w.clientsMutex.RUnlock()
|
||||
if ok && prevHash == curHash {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (w *Watcher) isKnownAuthFile(path string) bool {
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
_, ok := w.lastAuthHashes[path]
|
||||
return ok
|
||||
}
|
||||
|
||||
// handleEvent processes individual file system events
|
||||
func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
// Filter only relevant events: config file or auth-dir JSON files.
|
||||
@@ -497,19 +524,33 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
}
|
||||
|
||||
// Handle auth directory changes incrementally (.json only)
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
||||
// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
|
||||
// Wait briefly; if the path exists again, treat as an update instead of removal.
|
||||
time.Sleep(replaceCheckDelay)
|
||||
if _, statErr := os.Stat(event.Name); statErr == nil {
|
||||
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
w.addOrUpdateClient(event.Name)
|
||||
return
|
||||
}
|
||||
if !w.isKnownAuthFile(event.Name) {
|
||||
log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
w.removeClient(event.Name)
|
||||
return
|
||||
}
|
||||
if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
|
||||
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
w.addOrUpdateClient(event.Name)
|
||||
}
|
||||
}
|
||||
@@ -1378,6 +1419,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if oldCfg.RequestRetry != newCfg.RequestRetry {
|
||||
changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
|
||||
}
|
||||
if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
|
||||
changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
|
||||
}
|
||||
if oldCfg.ProxyURL != newCfg.ProxyURL {
|
||||
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
|
||||
}
|
||||
|
||||
@@ -69,6 +69,27 @@ func (h *GeminiAPIHandler) GeminiGetHandler(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
switch request.Action {
|
||||
case "gemini-3-pro-preview":
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"name": "models/gemini-3-pro-preview",
|
||||
"version": "3",
|
||||
"displayName": "Gemini 3 Pro Preview",
|
||||
"description": "Gemini 3 Pro Preview",
|
||||
"inputTokenLimit": 1048576,
|
||||
"outputTokenLimit": 65536,
|
||||
"supportedGenerationMethods": []string{
|
||||
"generateContent",
|
||||
"countTokens",
|
||||
"createCachedContent",
|
||||
"batchGenerateContent",
|
||||
},
|
||||
"temperature": 1,
|
||||
"topP": 0.95,
|
||||
"topK": 64,
|
||||
"maxTemperature": 2,
|
||||
"thinking": true,
|
||||
},
|
||||
)
|
||||
case "gemini-2.5-pro":
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"name": "models/gemini-2.5-pro",
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -120,11 +121,11 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
data := params[0]
|
||||
switch data.(type) {
|
||||
case []byte:
|
||||
c.Set("API_RESPONSE", data.([]byte))
|
||||
appendAPIResponse(c, data.([]byte))
|
||||
case error:
|
||||
c.Set("API_RESPONSE", []byte(data.(error).Error()))
|
||||
appendAPIResponse(c, []byte(data.(error).Error()))
|
||||
case string:
|
||||
c.Set("API_RESPONSE", []byte(data.(string)))
|
||||
appendAPIResponse(c, []byte(data.(string)))
|
||||
case bool:
|
||||
case nil:
|
||||
}
|
||||
@@ -135,6 +136,28 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c *
|
||||
}
|
||||
}
|
||||
|
||||
// appendAPIResponse preserves any previously captured API response and appends new data.
|
||||
func appendAPIResponse(c *gin.Context, data []byte) {
|
||||
if c == nil || len(data) == 0 {
|
||||
return
|
||||
}
|
||||
|
||||
if existing, exists := c.Get("API_RESPONSE"); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||
combined := make([]byte, 0, len(existingBytes)+len(data)+1)
|
||||
combined = append(combined, existingBytes...)
|
||||
if existingBytes[len(existingBytes)-1] != '\n' {
|
||||
combined = append(combined, '\n')
|
||||
}
|
||||
combined = append(combined, data...)
|
||||
c.Set("API_RESPONSE", combined)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
c.Set("API_RESPONSE", bytes.Clone(data))
|
||||
}
|
||||
|
||||
// ExecuteWithAuthManager executes a non-streaming request via the core auth manager.
|
||||
// This path is the only supported execution route.
|
||||
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) {
|
||||
@@ -297,7 +320,7 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
|
||||
|
||||
providerName, extractedModelName, isDynamic := h.parseDynamicModel(resolvedModelName)
|
||||
|
||||
// First, normalize the model name to handle suffixes like "-thinking-128"
|
||||
|
||||
@@ -106,6 +106,10 @@ type Manager struct {
|
||||
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
|
||||
providerOffsets map[string]int
|
||||
|
||||
// Retry controls request retry behavior.
|
||||
requestRetry atomic.Int32
|
||||
maxRetryInterval atomic.Int64
|
||||
|
||||
// Optional HTTP RoundTripper provider injected by host.
|
||||
rtProvider RoundTripperProvider
|
||||
|
||||
@@ -145,6 +149,21 @@ func (m *Manager) SetRoundTripperProvider(p RoundTripperProvider) {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetRetryConfig updates retry attempts and cooldown wait interval.
|
||||
func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration) {
|
||||
if m == nil {
|
||||
return
|
||||
}
|
||||
if retry < 0 {
|
||||
retry = 0
|
||||
}
|
||||
if maxRetryInterval < 0 {
|
||||
maxRetryInterval = 0
|
||||
}
|
||||
m.requestRetry.Store(int32(retry))
|
||||
m.maxRetryInterval.Store(maxRetryInterval.Nanoseconds())
|
||||
}
|
||||
|
||||
// RegisterExecutor registers a provider executor with the manager.
|
||||
func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
|
||||
if executor == nil {
|
||||
@@ -188,8 +207,12 @@ func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
if auth == nil || auth.ID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
m.mu.Lock()
|
||||
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == 0 {
|
||||
auth.Index = existing.Index
|
||||
auth.indexAssigned = existing.indexAssigned
|
||||
}
|
||||
auth.EnsureIndex()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
m.mu.Unlock()
|
||||
_ = m.persist(ctx, auth)
|
||||
@@ -229,13 +252,28 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
defer m.advanceProviderCursor(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
if attempts < 1 {
|
||||
attempts = 1
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, provider := range rotated {
|
||||
resp, errExec := m.executeWithProvider(ctx, provider, req, opts)
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
||||
return m.executeWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
if errWait := waitForCooldown(ctx, wait); errWait != nil {
|
||||
return cliproxyexecutor.Response{}, errWait
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -253,13 +291,28 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
defer m.advanceProviderCursor(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
if attempts < 1 {
|
||||
attempts = 1
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, provider := range rotated {
|
||||
resp, errExec := m.executeCountWithProvider(ctx, provider, req, opts)
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
||||
return m.executeCountWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
if errWait := waitForCooldown(ctx, wait); errWait != nil {
|
||||
return cliproxyexecutor.Response{}, errWait
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -277,13 +330,28 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
defer m.advanceProviderCursor(req.Model, normalized)
|
||||
|
||||
retryTimes, maxWait := m.retrySettings()
|
||||
attempts := retryTimes + 1
|
||||
if attempts < 1 {
|
||||
attempts = 1
|
||||
}
|
||||
|
||||
var lastErr error
|
||||
for _, provider := range rotated {
|
||||
chunks, errStream := m.executeStreamWithProvider(ctx, provider, req, opts)
|
||||
for attempt := 0; attempt < attempts; attempt++ {
|
||||
chunks, errStream := m.executeStreamProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
return m.executeStreamWithProvider(execCtx, provider, req, opts)
|
||||
})
|
||||
if errStream == nil {
|
||||
return chunks, nil
|
||||
}
|
||||
lastErr = errStream
|
||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, rotated, req.Model, maxWait)
|
||||
if !shouldRetry {
|
||||
break
|
||||
}
|
||||
if errWait := waitForCooldown(ctx, wait); errWait != nil {
|
||||
return nil, errWait
|
||||
}
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
@@ -507,6 +575,123 @@ func (m *Manager) advanceProviderCursor(model string, providers []string) {
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
func (m *Manager) retrySettings() (int, time.Duration) {
|
||||
if m == nil {
|
||||
return 0, 0
|
||||
}
|
||||
return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
|
||||
}
|
||||
|
||||
func (m *Manager) closestCooldownWait(providers []string, model string) (time.Duration, bool) {
|
||||
if m == nil || len(providers) == 0 {
|
||||
return 0, false
|
||||
}
|
||||
now := time.Now()
|
||||
providerSet := make(map[string]struct{}, len(providers))
|
||||
for i := range providers {
|
||||
key := strings.TrimSpace(strings.ToLower(providers[i]))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
providerSet[key] = struct{}{}
|
||||
}
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
var (
|
||||
found bool
|
||||
minWait time.Duration
|
||||
)
|
||||
for _, auth := range m.auths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
|
||||
if _, ok := providerSet[providerKey]; !ok {
|
||||
continue
|
||||
}
|
||||
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
|
||||
if !blocked || next.IsZero() || reason == blockReasonDisabled {
|
||||
continue
|
||||
}
|
||||
wait := next.Sub(now)
|
||||
if wait < 0 {
|
||||
continue
|
||||
}
|
||||
if !found || wait < minWait {
|
||||
minWait = wait
|
||||
found = true
|
||||
}
|
||||
}
|
||||
return minWait, found
|
||||
}
|
||||
|
||||
func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
|
||||
if err == nil || attempt >= maxAttempts-1 {
|
||||
return 0, false
|
||||
}
|
||||
if maxWait <= 0 {
|
||||
return 0, false
|
||||
}
|
||||
if status := statusCodeFromError(err); status == http.StatusOK {
|
||||
return 0, false
|
||||
}
|
||||
wait, found := m.closestCooldownWait(providers, model)
|
||||
if !found || wait > maxWait {
|
||||
return 0, false
|
||||
}
|
||||
return wait, true
|
||||
}
|
||||
|
||||
func waitForCooldown(ctx context.Context, wait time.Duration) error {
|
||||
if wait <= 0 {
|
||||
return nil
|
||||
}
|
||||
timer := time.NewTimer(wait)
|
||||
defer timer.Stop()
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
case <-timer.C:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (cliproxyexecutor.Response, error)) (cliproxyexecutor.Response, error) {
|
||||
if len(providers) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
var lastErr error
|
||||
for _, provider := range providers {
|
||||
resp, errExec := fn(ctx, provider)
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
}
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamProvidersOnce(ctx context.Context, providers []string, fn func(context.Context, string) (<-chan cliproxyexecutor.StreamChunk, error)) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if len(providers) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
var lastErr error
|
||||
for _, provider := range providers {
|
||||
chunks, errExec := fn(ctx, provider)
|
||||
if errExec == nil {
|
||||
return chunks, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
// MarkResult records an execution result and notifies hooks.
|
||||
func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
if result.AuthID == "" {
|
||||
@@ -762,6 +947,20 @@ func cloneError(err *Error) *Error {
|
||||
}
|
||||
}
|
||||
|
||||
func statusCodeFromError(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
}
|
||||
type statusCoder interface {
|
||||
StatusCode() int
|
||||
}
|
||||
var sc statusCoder
|
||||
if errors.As(err, &sc) && sc != nil {
|
||||
return sc.StatusCode()
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func retryAfterFromError(err error) *time.Duration {
|
||||
if err == nil {
|
||||
return nil
|
||||
|
||||
@@ -281,6 +281,14 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) applyRetryConfig(cfg *config.Config) {
|
||||
if s == nil || s.coreManager == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
maxInterval := time.Duration(cfg.MaxRetryInterval) * time.Second
|
||||
s.coreManager.SetRetryConfig(cfg.RequestRetry, maxInterval)
|
||||
}
|
||||
|
||||
func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName string, ok bool) {
|
||||
if a == nil {
|
||||
return "", "", false
|
||||
@@ -394,6 +402,8 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
return err
|
||||
}
|
||||
|
||||
s.applyRetryConfig(s.cfg)
|
||||
|
||||
if s.coreManager != nil {
|
||||
if errLoad := s.coreManager.Load(ctx); errLoad != nil {
|
||||
log.Warnf("failed to load auth store: %v", errLoad)
|
||||
@@ -476,6 +486,7 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
if newCfg == nil {
|
||||
return
|
||||
}
|
||||
s.applyRetryConfig(newCfg)
|
||||
if s.server != nil {
|
||||
s.server.UpdateClients(newCfg)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user