mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-13 01:40:50 +08:00
Compare commits
10 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a709e5a12d | ||
|
|
f0ac77197b | ||
|
|
da0bbf2a3f | ||
|
|
295f34d7f0 | ||
|
|
c41ce77eea | ||
|
|
4eb1e6093f | ||
|
|
189a066807 | ||
|
|
d0bada7a43 | ||
|
|
9dc0e6d08b | ||
|
|
8510fc313e |
@@ -8,6 +8,7 @@ import (
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
@@ -103,6 +104,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
Headers: headers,
|
||||
Body: body,
|
||||
RequestID: logging.GetGinRequestID(c),
|
||||
Timestamp: time.Now(),
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -7,6 +7,7 @@ import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
@@ -20,22 +21,24 @@ type RequestInfo struct {
|
||||
Headers map[string][]string // Headers contains the request headers.
|
||||
Body []byte // Body is the raw request body.
|
||||
RequestID string // RequestID is the unique identifier for the request.
|
||||
Timestamp time.Time // Timestamp is when the request was received.
|
||||
}
|
||||
|
||||
// ResponseWriterWrapper wraps the standard gin.ResponseWriter to intercept and log response data.
|
||||
// It is designed to handle both standard and streaming responses, ensuring that logging operations do not block the client response.
|
||||
type ResponseWriterWrapper struct {
|
||||
gin.ResponseWriter
|
||||
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
||||
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
||||
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
||||
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
||||
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
||||
logger logging.RequestLogger // logger is the instance of the request logger service.
|
||||
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
||||
statusCode int // statusCode stores the HTTP status code of the response.
|
||||
headers map[string][]string // headers stores the response headers.
|
||||
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
|
||||
body *bytes.Buffer // body is a buffer to store the response body for non-streaming responses.
|
||||
isStreaming bool // isStreaming indicates whether the response is a streaming type (e.g., text/event-stream).
|
||||
streamWriter logging.StreamingLogWriter // streamWriter is a writer for handling streaming log entries.
|
||||
chunkChannel chan []byte // chunkChannel is a channel for asynchronously passing response chunks to the logger.
|
||||
streamDone chan struct{} // streamDone signals when the streaming goroutine completes.
|
||||
logger logging.RequestLogger // logger is the instance of the request logger service.
|
||||
requestInfo *RequestInfo // requestInfo holds the details of the original request.
|
||||
statusCode int // statusCode stores the HTTP status code of the response.
|
||||
headers map[string][]string // headers stores the response headers.
|
||||
logOnErrorOnly bool // logOnErrorOnly enables logging only when an error response is detected.
|
||||
firstChunkTimestamp time.Time // firstChunkTimestamp captures TTFB for streaming responses.
|
||||
}
|
||||
|
||||
// NewResponseWriterWrapper creates and initializes a new ResponseWriterWrapper.
|
||||
@@ -73,6 +76,10 @@ func (w *ResponseWriterWrapper) Write(data []byte) (int, error) {
|
||||
|
||||
// THEN: Handle logging based on response type
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
// Capture TTFB on first chunk (synchronous, before async channel send)
|
||||
if w.firstChunkTimestamp.IsZero() {
|
||||
w.firstChunkTimestamp = time.Now()
|
||||
}
|
||||
// For streaming responses: Send to async logging channel (non-blocking)
|
||||
select {
|
||||
case w.chunkChannel <- append([]byte(nil), data...): // Non-blocking send with copy
|
||||
@@ -117,6 +124,10 @@ func (w *ResponseWriterWrapper) WriteString(data string) (int, error) {
|
||||
|
||||
// THEN: Capture for logging
|
||||
if w.isStreaming && w.chunkChannel != nil {
|
||||
// Capture TTFB on first chunk (synchronous, before async channel send)
|
||||
if w.firstChunkTimestamp.IsZero() {
|
||||
w.firstChunkTimestamp = time.Now()
|
||||
}
|
||||
select {
|
||||
case w.chunkChannel <- []byte(data):
|
||||
default:
|
||||
@@ -280,6 +291,8 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
w.streamDone = nil
|
||||
}
|
||||
|
||||
w.streamWriter.SetFirstChunkTimestamp(w.firstChunkTimestamp)
|
||||
|
||||
// Write API Request and Response to the streaming log before closing
|
||||
apiRequest := w.extractAPIRequest(c)
|
||||
if len(apiRequest) > 0 {
|
||||
@@ -297,7 +310,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), slicesAPIResponseError, forceLog)
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||
@@ -337,7 +350,18 @@ func (w *ResponseWriterWrapper) extractAPIResponse(c *gin.Context) []byte {
|
||||
return data
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time.Time {
|
||||
ts, isExist := c.Get("API_RESPONSE_TIMESTAMP")
|
||||
if !isExist {
|
||||
return time.Time{}
|
||||
}
|
||||
if t, ok := ts.(time.Time); ok {
|
||||
return t
|
||||
}
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
if w.requestInfo == nil {
|
||||
return nil
|
||||
}
|
||||
@@ -348,7 +372,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string) error
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||
}); ok {
|
||||
return loggerWithOptions.LogRequestWithOptions(
|
||||
w.requestInfo.URL,
|
||||
@@ -363,6 +387,8 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiResponseErrors,
|
||||
forceLog,
|
||||
w.requestInfo.RequestID,
|
||||
w.requestInfo.Timestamp,
|
||||
apiResponseTimestamp,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -378,5 +404,7 @@ func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]
|
||||
apiResponseBody,
|
||||
apiResponseErrors,
|
||||
w.requestInfo.RequestID,
|
||||
w.requestInfo.Timestamp,
|
||||
apiResponseTimestamp,
|
||||
)
|
||||
}
|
||||
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"net/http"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"reflect"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -990,14 +991,17 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||
}
|
||||
|
||||
// Notify Amp module of config changes (for model mapping hot-reload)
|
||||
if s.ampModule != nil {
|
||||
log.Debugf("triggering amp module config update")
|
||||
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
|
||||
log.Errorf("failed to update Amp module config: %v", err)
|
||||
// Notify Amp module only when Amp config has changed.
|
||||
ampConfigChanged := oldCfg == nil || !reflect.DeepEqual(oldCfg.AmpCode, cfg.AmpCode)
|
||||
if ampConfigChanged {
|
||||
if s.ampModule != nil {
|
||||
log.Debugf("triggering amp module config update")
|
||||
if err := s.ampModule.OnConfigUpdated(cfg); err != nil {
|
||||
log.Errorf("failed to update Amp module config: %v", err)
|
||||
}
|
||||
} else {
|
||||
log.Warnf("amp module is nil, skipping config update")
|
||||
}
|
||||
} else {
|
||||
log.Warnf("amp module is nil, skipping config update")
|
||||
}
|
||||
|
||||
// Count client sources from configuration and auth store.
|
||||
|
||||
@@ -923,6 +923,7 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
removeLegacyGenerativeLanguageKeys(original.Content[0])
|
||||
|
||||
pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-excluded-models")
|
||||
pruneMappingToGeneratedKeys(original.Content[0], generated.Content[0], "oauth-model-alias")
|
||||
|
||||
// Merge generated into original in-place, preserving comments/order of existing nodes.
|
||||
mergeMappingPreserve(original.Content[0], generated.Content[0])
|
||||
@@ -1413,6 +1414,16 @@ func pruneMappingToGeneratedKeys(dstRoot, srcRoot *yaml.Node, key string) {
|
||||
}
|
||||
srcIdx := findMapKeyIndex(srcRoot, key)
|
||||
if srcIdx < 0 {
|
||||
// Keep an explicit empty mapping for oauth-model-alias when it was previously present.
|
||||
//
|
||||
// Rationale: LoadConfig runs MigrateOAuthModelAlias before unmarshalling. If the
|
||||
// oauth-model-alias key is missing, migration will add the default antigravity aliases.
|
||||
// When users delete the last channel from oauth-model-alias via the management API,
|
||||
// we want that deletion to persist across hot reloads and restarts.
|
||||
if key == "oauth-model-alias" {
|
||||
dstRoot.Content[dstIdx+1] = &yaml.Node{Kind: yaml.MappingNode, Tag: "!!map"}
|
||||
return
|
||||
}
|
||||
removeMapKey(dstRoot, key)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -44,10 +44,12 @@ type RequestLogger interface {
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - requestID: Optional request ID for log file naming
|
||||
// - requestTimestamp: When the request was received
|
||||
// - apiResponseTimestamp: When the API response was received
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error
|
||||
LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error
|
||||
|
||||
// LogStreamingRequest initiates logging for a streaming request and returns a writer for chunks.
|
||||
//
|
||||
@@ -109,6 +111,12 @@ type StreamingLogWriter interface {
|
||||
// - error: An error if writing fails, nil otherwise
|
||||
WriteAPIResponse(apiResponse []byte) error
|
||||
|
||||
// SetFirstChunkTimestamp sets the TTFB timestamp captured when first chunk was received.
|
||||
//
|
||||
// Parameters:
|
||||
// - timestamp: The time when first response chunk was received
|
||||
SetFirstChunkTimestamp(timestamp time.Time)
|
||||
|
||||
// Close finalizes the log file and cleans up resources.
|
||||
//
|
||||
// Returns:
|
||||
@@ -180,20 +188,22 @@ func (l *FileRequestLogger) SetEnabled(enabled bool) {
|
||||
// - apiRequest: The API request data
|
||||
// - apiResponse: The API response data
|
||||
// - requestID: Optional request ID for log file naming
|
||||
// - requestTimestamp: When the request was received
|
||||
// - apiResponseTimestamp: When the API response was received
|
||||
//
|
||||
// Returns:
|
||||
// - error: An error if logging fails, nil otherwise
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID)
|
||||
func (l *FileRequestLogger) LogRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, false, requestID, requestTimestamp, apiResponseTimestamp)
|
||||
}
|
||||
|
||||
// LogRequestWithOptions logs a request with optional forced logging behavior.
|
||||
// The force flag allows writing error logs even when regular request logging is disabled.
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID)
|
||||
func (l *FileRequestLogger) LogRequestWithOptions(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
return l.logRequest(url, method, requestHeaders, body, statusCode, responseHeaders, response, apiRequest, apiResponse, apiResponseErrors, force, requestID, requestTimestamp, apiResponseTimestamp)
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string) error {
|
||||
func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[string][]string, body []byte, statusCode int, responseHeaders map[string][]string, response, apiRequest, apiResponse []byte, apiResponseErrors []*interfaces.ErrorMessage, force bool, requestID string, requestTimestamp, apiResponseTimestamp time.Time) error {
|
||||
if !l.enabled && !force {
|
||||
return nil
|
||||
}
|
||||
@@ -247,6 +257,8 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
||||
responseHeaders,
|
||||
responseToWrite,
|
||||
decompressErr,
|
||||
requestTimestamp,
|
||||
apiResponseTimestamp,
|
||||
)
|
||||
if errClose := logFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request log file")
|
||||
@@ -499,17 +511,22 @@ func (l *FileRequestLogger) writeNonStreamingLog(
|
||||
responseHeaders map[string][]string,
|
||||
response []byte,
|
||||
decompressErr error,
|
||||
requestTimestamp time.Time,
|
||||
apiResponseTimestamp time.Time,
|
||||
) error {
|
||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, time.Now()); errWrite != nil {
|
||||
if requestTimestamp.IsZero() {
|
||||
requestTimestamp = time.Now()
|
||||
}
|
||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, requestTimestamp); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest); errWrite != nil {
|
||||
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest, time.Time{}); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPIErrorResponses(w, apiResponseErrors); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse); errWrite != nil {
|
||||
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse, apiResponseTimestamp); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
||||
@@ -583,7 +600,7 @@ func writeRequestInfoWithBody(
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte) error {
|
||||
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte, timestamp time.Time) error {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -601,6 +618,11 @@ func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, pa
|
||||
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if !timestamp.IsZero() {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
@@ -974,6 +996,9 @@ type FileStreamingLogWriter struct {
|
||||
|
||||
// apiResponse stores the upstream API response data.
|
||||
apiResponse []byte
|
||||
|
||||
// apiResponseTimestamp captures when the API response was received.
|
||||
apiResponseTimestamp time.Time
|
||||
}
|
||||
|
||||
// WriteChunkAsync writes a response chunk asynchronously (non-blocking).
|
||||
@@ -1053,6 +1078,12 @@ func (w *FileStreamingLogWriter) WriteAPIResponse(apiResponse []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *FileStreamingLogWriter) SetFirstChunkTimestamp(timestamp time.Time) {
|
||||
if !timestamp.IsZero() {
|
||||
w.apiResponseTimestamp = timestamp
|
||||
}
|
||||
}
|
||||
|
||||
// Close finalizes the log file and cleans up resources.
|
||||
// It writes all buffered data to the file in the correct order:
|
||||
// API REQUEST -> API RESPONSE -> RESPONSE (status, headers, body chunks)
|
||||
@@ -1140,10 +1171,10 @@ func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest); errWrite != nil {
|
||||
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest, time.Time{}); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse); errWrite != nil {
|
||||
if errWrite := writeAPISection(logFile, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse, w.apiResponseTimestamp); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
@@ -1220,6 +1251,8 @@ func (w *NoOpStreamingLogWriter) WriteAPIResponse(_ []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *NoOpStreamingLogWriter) SetFirstChunkTimestamp(_ time.Time) {}
|
||||
|
||||
// Close is a no-op implementation that does nothing and always returns nil.
|
||||
//
|
||||
// Returns:
|
||||
|
||||
@@ -41,6 +41,7 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
chunk = []byte(responseResult.Raw)
|
||||
chunk = restoreUsageMetadata(chunk)
|
||||
}
|
||||
} else {
|
||||
chunkTemplate := "[]"
|
||||
@@ -76,7 +77,8 @@ func ConvertAntigravityResponseToGemini(ctx context.Context, _ string, originalR
|
||||
func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, originalRequestRawJSON, requestRawJSON, rawJSON []byte, _ *any) string {
|
||||
responseResult := gjson.GetBytes(rawJSON, "response")
|
||||
if responseResult.Exists() {
|
||||
return responseResult.Raw
|
||||
chunk := restoreUsageMetadata([]byte(responseResult.Raw))
|
||||
return string(chunk)
|
||||
}
|
||||
return string(rawJSON)
|
||||
}
|
||||
@@ -84,3 +86,15 @@ func ConvertAntigravityResponseToGeminiNonStream(_ context.Context, _ string, or
|
||||
func GeminiTokenCount(ctx context.Context, count int64) string {
|
||||
return fmt.Sprintf(`{"totalTokens":%d,"promptTokensDetails":[{"modality":"TEXT","tokenCount":%d}]}`, count, count)
|
||||
}
|
||||
|
||||
// restoreUsageMetadata renames cpaUsageMetadata back to usageMetadata.
|
||||
// The executor renames usageMetadata to cpaUsageMetadata in non-terminal chunks
|
||||
// to preserve usage data while hiding it from clients that don't expect it.
|
||||
// When returning standard Gemini API format, we must restore the original name.
|
||||
func restoreUsageMetadata(chunk []byte) []byte {
|
||||
if cpaUsage := gjson.GetBytes(chunk, "cpaUsageMetadata"); cpaUsage.Exists() {
|
||||
chunk, _ = sjson.SetRawBytes(chunk, "usageMetadata", []byte(cpaUsage.Raw))
|
||||
chunk, _ = sjson.DeleteBytes(chunk, "cpaUsageMetadata")
|
||||
}
|
||||
return chunk
|
||||
}
|
||||
|
||||
@@ -0,0 +1,95 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestRestoreUsageMetadata(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "cpaUsageMetadata renamed to usageMetadata",
|
||||
input: []byte(`{"modelVersion":"gemini-3-pro","cpaUsageMetadata":{"promptTokenCount":100,"candidatesTokenCount":200}}`),
|
||||
expected: `{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100,"candidatesTokenCount":200}}`,
|
||||
},
|
||||
{
|
||||
name: "no cpaUsageMetadata unchanged",
|
||||
input: []byte(`{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}`),
|
||||
expected: `{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}`,
|
||||
},
|
||||
{
|
||||
name: "empty input",
|
||||
input: []byte(`{}`),
|
||||
expected: `{}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := restoreUsageMetadata(tt.input)
|
||||
if string(result) != tt.expected {
|
||||
t.Errorf("restoreUsageMetadata() = %s, want %s", string(result), tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertAntigravityResponseToGeminiNonStream(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "cpaUsageMetadata restored in response",
|
||||
input: []byte(`{"response":{"modelVersion":"gemini-3-pro","cpaUsageMetadata":{"promptTokenCount":100}}}`),
|
||||
expected: `{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}`,
|
||||
},
|
||||
{
|
||||
name: "usageMetadata preserved",
|
||||
input: []byte(`{"response":{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}}`),
|
||||
expected: `{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := ConvertAntigravityResponseToGeminiNonStream(context.Background(), "", nil, nil, tt.input, nil)
|
||||
if result != tt.expected {
|
||||
t.Errorf("ConvertAntigravityResponseToGeminiNonStream() = %s, want %s", result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertAntigravityResponseToGeminiStream(t *testing.T) {
|
||||
ctx := context.WithValue(context.Background(), "alt", "")
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "cpaUsageMetadata restored in streaming response",
|
||||
input: []byte(`data: {"response":{"modelVersion":"gemini-3-pro","cpaUsageMetadata":{"promptTokenCount":100}}}`),
|
||||
expected: `{"modelVersion":"gemini-3-pro","usageMetadata":{"promptTokenCount":100}}`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
results := ConvertAntigravityResponseToGemini(ctx, "", nil, nil, tt.input, nil)
|
||||
if len(results) != 1 {
|
||||
t.Fatalf("expected 1 result, got %d", len(results))
|
||||
}
|
||||
if results[0] != tt.expected {
|
||||
t.Errorf("ConvertAntigravityResponseToGemini() = %s, want %s", results[0], tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -124,6 +124,7 @@ func (h *GeminiCLIAPIHandler) CLIHandler(c *gin.Context) {
|
||||
log.Errorf("Failed to read response body: %v", err)
|
||||
return
|
||||
}
|
||||
c.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
||||
_, _ = c.Writer.Write(output)
|
||||
c.Set("API_RESPONSE", output)
|
||||
}
|
||||
|
||||
@@ -361,6 +361,11 @@ func appendAPIResponse(c *gin.Context, data []byte) {
|
||||
return
|
||||
}
|
||||
|
||||
// Capture timestamp on first API response
|
||||
if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); !exists {
|
||||
c.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
||||
}
|
||||
|
||||
if existing, exists := c.Get("API_RESPONSE"); exists {
|
||||
if existingBytes, ok := existing.([]byte); ok && len(existingBytes) > 0 {
|
||||
combined := make([]byte, 0, len(existingBytes)+len(data)+1)
|
||||
|
||||
@@ -70,6 +70,58 @@ func (e *failOnceStreamExecutor) Calls() int {
|
||||
return e.calls
|
||||
}
|
||||
|
||||
type payloadThenErrorStreamExecutor struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
}
|
||||
|
||||
func (e *payloadThenErrorStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *payloadThenErrorStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
e.mu.Lock()
|
||||
e.calls++
|
||||
e.mu.Unlock()
|
||||
|
||||
ch := make(chan coreexecutor.StreamChunk, 2)
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("partial")}
|
||||
ch <- coreexecutor.StreamChunk{
|
||||
Err: &coreauth.Error{
|
||||
Code: "upstream_closed",
|
||||
Message: "upstream closed",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusBadGateway,
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *payloadThenErrorStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *payloadThenErrorStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, &coreauth.Error{
|
||||
Code: "not_implemented",
|
||||
Message: "HttpRequest not implemented",
|
||||
HTTPStatus: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *payloadThenErrorStreamExecutor) Calls() int {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.calls
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||
executor := &failOnceStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
@@ -130,3 +182,73 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
||||
executor := &payloadThenErrorStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
auth2 := &coreauth.Auth{
|
||||
ID: "auth2",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test2@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||
t.Fatalf("manager.Register(auth2): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: 1,
|
||||
},
|
||||
}, manager)
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
|
||||
var gotErr error
|
||||
var gotStatus int
|
||||
for msg := range errChan {
|
||||
if msg != nil && msg.Error != nil {
|
||||
gotErr = msg.Error
|
||||
gotStatus = msg.StatusCode
|
||||
}
|
||||
}
|
||||
|
||||
if string(got) != "partial" {
|
||||
t.Fatalf("expected payload partial, got %q", string(got))
|
||||
}
|
||||
if gotErr == nil {
|
||||
t.Fatalf("expected terminal error, got nil")
|
||||
}
|
||||
if gotStatus != http.StatusBadGateway {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, gotStatus)
|
||||
}
|
||||
if executor.Calls() != 1 {
|
||||
t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user