From 39597267ae0158a092424103b29126321568ec57 Mon Sep 17 00:00:00 2001 From: BigUncle Date: Wed, 17 Dec 2025 23:33:17 +0800 Subject: [PATCH 01/38] fix(auth): prevent token refresh loop by ignoring timestamp fields MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Add metadataEqualIgnoringTimestamps() function to compare metadata JSON without timestamp/expired/expires_in/last_refresh/access_token fields. This prevents unnecessary file writes when only these fields change during refresh, breaking the fsnotify event → Watcher callback → refresh loop. Key insight: Google OAuth returns a new access_token on each refresh, which was causing file writes and triggering the refresh loop. Fixes antigravity channel excessive log generation issue. 🤖 Generated with [Claude Code](https://claude.com/claude-code) Co-Authored-By: Claude --- sdk/auth/filestore.go | 32 +++++++++++++++++++++++++++++++- 1 file changed, 31 insertions(+), 1 deletion(-) diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go index 3c2d60c4..2fa963df 100644 --- a/sdk/auth/filestore.go +++ b/sdk/auth/filestore.go @@ -72,7 +72,9 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal) } if existing, errRead := os.ReadFile(path); errRead == nil { - if jsonEqual(existing, raw) { + // Use metadataEqualIgnoringTimestamps to skip writes when only timestamp fields change. + // This prevents the token refresh loop caused by timestamp/expired/expires_in changes. + if metadataEqualIgnoringTimestamps(existing, raw) { return path, nil } } else if errRead != nil && !os.IsNotExist(errRead) { @@ -264,6 +266,8 @@ func (s *FileTokenStore) baseDirSnapshot() string { return s.baseDir } +// DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata. +// This function is kept for backward compatibility but can cause refresh loops. 
func jsonEqual(a, b []byte) bool { var objA any var objB any @@ -276,6 +280,32 @@ func jsonEqual(a, b []byte) bool { return deepEqualJSON(objA, objB) } +// metadataEqualIgnoringTimestamps compares two metadata JSON blobs, +// ignoring fields that change on every refresh but don't affect functionality. +// This prevents unnecessary file writes that would trigger watcher events and +// create refresh loops. +func metadataEqualIgnoringTimestamps(a, b []byte) bool { + var objA, objB map[string]any + if err := json.Unmarshal(a, &objA); err != nil { + return false + } + if err := json.Unmarshal(b, &objB); err != nil { + return false + } + + // Fields to ignore: these change on every refresh but don't affect authentication logic. + // - timestamp, expired, expires_in, last_refresh: time-related fields that change on refresh + // - access_token: Google OAuth returns a new access_token on each refresh, this is expected + // and shouldn't trigger file writes (the new token will be fetched again when needed) + ignoredFields := []string{"timestamp", "expired", "expires_in", "last_refresh", "access_token"} + for _, field := range ignoredFields { + delete(objA, field) + delete(objB, field) + } + + return deepEqualJSON(objA, objB) +} + func deepEqualJSON(a, b any) bool { switch valA := a.(type) { case map[string]any: From cfa8ddb59f3dbddd7288c565163b5cd4b7ae25ad Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Fri, 19 Dec 2025 00:38:29 +0800 Subject: [PATCH 02/38] feat(oauth): add remote OAuth callback support with session management Introduce a centralized OAuth session store with TTL-based expiration to replace the previous simple map-based status tracking. Add a new /api/oauth/callback endpoint that allows remote clients to relay OAuth callback data back to the CLI proxy, enabling OAuth flows when the callback cannot reach the local machine directly. 
- Add oauth_sessions.go with thread-safe session store and validation - Add oauth_callback.go with POST handler for remote callback relay - Refactor auth_files.go to use new session management APIs - Register new callback route in server.go --- .../api/handlers/management/auth_files.go | 165 +++++------ .../api/handlers/management/oauth_callback.go | 100 +++++++ .../api/handlers/management/oauth_sessions.go | 258 ++++++++++++++++++ internal/api/server.go | 32 ++- 4 files changed, 466 insertions(+), 89 deletions(-) create mode 100644 internal/api/handlers/management/oauth_callback.go create mode 100644 internal/api/handlers/management/oauth_sessions.go diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index e29dc29f..433bae92 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -36,10 +36,6 @@ import ( "golang.org/x/oauth2/google" ) -var ( - oauthStatus = make(map[string]string) -) - var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} const ( @@ -786,6 +782,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { return } + RegisterOAuthSession(state, "anthropic") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/anthropic/callback") @@ -812,7 +810,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { deadline := time.Now().Add(timeout) for { if time.Now().After(deadline) { - oauthStatus[state] = "Timeout waiting for OAuth callback" + SetOAuthSessionError(state, "Timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback") } data, errRead := os.ReadFile(path) @@ -837,13 +835,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errStr := resultMap["error"]; errStr != "" { oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(claude.GetUserFriendlyMessage(oauthErr)) - 
oauthStatus[state] = "Bad request" + SetOAuthSessionError(state, "Bad request") return } if resultMap["state"] != state { authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"])) log.Error(claude.GetUserFriendlyMessage(authErr)) - oauthStatus[state] = "State code error" + SetOAuthSessionError(state, "State code error") return } @@ -876,7 +874,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { if errDo != nil { authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) - oauthStatus[state] = "Failed to exchange authorization code for tokens" + SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") return } defer func() { @@ -887,7 +885,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) - oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode) + SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)) return } var tResp struct { @@ -900,7 +898,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } if errU := json.Unmarshal(respBody, &tResp); errU != nil { log.Errorf("failed to parse token response: %v", errU) - oauthStatus[state] = "Failed to parse token response" + SetOAuthSessionError(state, "Failed to parse token response") return } bundle := &claude.ClaudeAuthBundle{ @@ -925,7 +923,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - oauthStatus[state] = "Failed to save authentication tokens" + SetOAuthSessionError(state, "Failed to save authentication tokens") 
return } @@ -934,10 +932,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Claude services through this CLI") - delete(oauthStatus, state) + CompleteOAuthSession(state) }() - oauthStatus[state] = "" c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -968,6 +965,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { state := fmt.Sprintf("gem-%d", time.Now().UnixNano()) authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) + RegisterOAuthSession(state, "gemini") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/google/callback") @@ -996,7 +995,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - oauthStatus[state] = "OAuth flow timed out" + SetOAuthSessionError(state, "OAuth flow timed out") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -1005,13 +1004,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := m["error"]; errStr != "" { log.Errorf("Authentication failed: %s", errStr) - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") return } authCode = m["code"] if authCode == "" { log.Errorf("Authentication failed: code not found") - oauthStatus[state] = "Authentication failed: code not found" + SetOAuthSessionError(state, "Authentication failed: code not found") return } break @@ -1023,7 +1022,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { token, err := conf.Exchange(ctx, authCode) if err != nil { log.Errorf("Failed to exchange token: %v", err) - oauthStatus[state] = "Failed to exchange token" + SetOAuthSessionError(state, "Failed to exchange token") return } @@ -1034,7 +1033,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { req, 
errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errNewRequest != nil { log.Errorf("Could not get user info: %v", errNewRequest) - oauthStatus[state] = "Could not get user info" + SetOAuthSessionError(state, "Could not get user info") return } req.Header.Set("Content-Type", "application/json") @@ -1043,7 +1042,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { resp, errDo := authHTTPClient.Do(req) if errDo != nil { log.Errorf("Failed to execute request: %v", errDo) - oauthStatus[state] = "Failed to execute request" + SetOAuthSessionError(state, "Failed to execute request") return } defer func() { @@ -1055,7 +1054,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { bodyBytes, _ := io.ReadAll(resp.Body) if resp.StatusCode < 200 || resp.StatusCode >= 300 { log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode) + SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)) return } @@ -1064,7 +1063,6 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { fmt.Printf("Authenticated user email: %s\n", email) } else { fmt.Println("Failed to get user email from token") - oauthStatus[state] = "Failed to get user email from token" } // Marshal/unmarshal oauth2.Token to generic map and enrich fields @@ -1072,7 +1070,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { jsonData, _ := json.Marshal(token) if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil { log.Errorf("Failed to unmarshal token: %v", errUnmarshal) - oauthStatus[state] = "Failed to unmarshal token" + SetOAuthSessionError(state, "Failed to unmarshal token") return } @@ -1098,7 +1096,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { gemClient, errGetClient := 
gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) if errGetClient != nil { log.Errorf("failed to get authenticated client: %v", errGetClient) - oauthStatus[state] = "Failed to get authenticated client" + SetOAuthSessionError(state, "Failed to get authenticated client") return } fmt.Println("Authentication successful.") @@ -1108,12 +1106,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts) if errAll != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll) - oauthStatus[state] = "Failed to complete Gemini CLI onboarding" + SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") return } if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errVerify) - oauthStatus[state] = "Failed to verify Cloud AI API status" + SetOAuthSessionError(state, "Failed to verify Cloud AI API status") return } ts.ProjectID = strings.Join(projects, ",") @@ -1121,26 +1119,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } else { if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil { log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure) - oauthStatus[state] = "Failed to complete Gemini CLI onboarding" + SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding") return } if strings.TrimSpace(ts.ProjectID) == "" { log.Error("Onboarding did not return a project ID") - oauthStatus[state] = "Failed to resolve project ID" + SetOAuthSessionError(state, "Failed to resolve project ID") return } isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) if errCheck != nil { log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) - oauthStatus[state] = "Failed to verify Cloud AI API status" + SetOAuthSessionError(state, "Failed to verify Cloud AI API status") return } ts.Checked 
= isChecked if !isChecked { log.Error("Cloud AI API is not enabled for the selected project") - oauthStatus[state] = "Cloud AI API not enabled" + SetOAuthSessionError(state, "Cloud AI API not enabled") return } } @@ -1163,15 +1161,14 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save token to file: %v", errSave) - oauthStatus[state] = "Failed to save token to file" + SetOAuthSessionError(state, "Failed to save token to file") return } - delete(oauthStatus, state) + CompleteOAuthSession(state) fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) }() - oauthStatus[state] = "" c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1207,6 +1204,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { return } + RegisterOAuthSession(state, "codex") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/codex/callback") @@ -1235,7 +1234,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { if time.Now().After(deadline) { authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) log.Error(codex.GetUserFriendlyMessage(authErr)) - oauthStatus[state] = "Timeout waiting for OAuth callback" + SetOAuthSessionError(state, "Timeout waiting for OAuth callback") return } if data, errR := os.ReadFile(waitFile); errR == nil { @@ -1245,12 +1244,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { if errStr := m["error"]; errStr != "" { oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest) log.Error(codex.GetUserFriendlyMessage(oauthErr)) - oauthStatus[state] = "Bad Request" + SetOAuthSessionError(state, "Bad Request") return } if m["state"] != state { authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"])) - oauthStatus[state] = 
"State code error" + SetOAuthSessionError(state, "State code error") log.Error(codex.GetUserFriendlyMessage(authErr)) return } @@ -1281,14 +1280,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo) - oauthStatus[state] = "Failed to exchange authorization code for tokens" + SetOAuthSessionError(state, "Failed to exchange authorization code for tokens") log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) return } defer func() { _ = resp.Body.Close() }() respBody, _ := io.ReadAll(resp.Body) if resp.StatusCode != http.StatusOK { - oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode) + SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)) log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) return } @@ -1299,7 +1298,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { ExpiresIn int `json:"expires_in"` } if errU := json.Unmarshal(respBody, &tokenResp); errU != nil { - oauthStatus[state] = "Failed to parse token response" + SetOAuthSessionError(state, "Failed to parse token response") log.Errorf("failed to parse token response: %v", errU) return } @@ -1337,7 +1336,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - oauthStatus[state] = "Failed to save authentication tokens" + SetOAuthSessionError(state, "Failed to save authentication tokens") log.Errorf("Failed to save authentication tokens: %v", errSave) return } @@ -1346,10 +1345,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use Codex services through this CLI") - delete(oauthStatus, state) + CompleteOAuthSession(state) }() - oauthStatus[state] = "" c.JSON(200, gin.H{"status": 
"ok", "url": authURL, "state": state}) } @@ -1390,6 +1388,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { params.Set("state", state) authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode() + RegisterOAuthSession(state, "antigravity") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/antigravity/callback") @@ -1416,7 +1416,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { for { if time.Now().After(deadline) { log.Error("oauth flow timed out") - oauthStatus[state] = "OAuth flow timed out" + SetOAuthSessionError(state, "OAuth flow timed out") return } if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil { @@ -1425,18 +1425,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { _ = os.Remove(waitFile) if errStr := strings.TrimSpace(payload["error"]); errStr != "" { log.Errorf("Authentication failed: %s", errStr) - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") return } if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state { log.Errorf("Authentication failed: state mismatch") - oauthStatus[state] = "Authentication failed: state mismatch" + SetOAuthSessionError(state, "Authentication failed: state mismatch") return } authCode = strings.TrimSpace(payload["code"]) if authCode == "" { log.Error("Authentication failed: code not found") - oauthStatus[state] = "Authentication failed: code not found" + SetOAuthSessionError(state, "Authentication failed: code not found") return } break @@ -1455,7 +1455,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode())) if errNewRequest != nil { log.Errorf("Failed to build token request: %v", errNewRequest) - oauthStatus[state] = "Failed to build token request" + 
SetOAuthSessionError(state, "Failed to build token request") return } req.Header.Set("Content-Type", "application/x-www-form-urlencoded") @@ -1463,7 +1463,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { resp, errDo := httpClient.Do(req) if errDo != nil { log.Errorf("Failed to execute token request: %v", errDo) - oauthStatus[state] = "Failed to exchange token" + SetOAuthSessionError(state, "Failed to exchange token") return } defer func() { @@ -1475,7 +1475,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { bodyBytes, _ := io.ReadAll(resp.Body) log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode) + SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)) return } @@ -1487,7 +1487,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil { log.Errorf("Failed to parse token response: %v", errDecode) - oauthStatus[state] = "Failed to parse token response" + SetOAuthSessionError(state, "Failed to parse token response") return } @@ -1496,7 +1496,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) if errInfoReq != nil { log.Errorf("Failed to build user info request: %v", errInfoReq) - oauthStatus[state] = "Failed to build user info request" + SetOAuthSessionError(state, "Failed to build user info request") return } infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) @@ -1504,7 +1504,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { infoResp, errInfo := httpClient.Do(infoReq) if errInfo != nil { log.Errorf("Failed to execute user info 
request: %v", errInfo) - oauthStatus[state] = "Failed to execute user info request" + SetOAuthSessionError(state, "Failed to execute user info request") return } defer func() { @@ -1523,7 +1523,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } else { bodyBytes, _ := io.ReadAll(infoResp.Body) log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes)) - oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode) + SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)) return } } @@ -1571,11 +1571,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save token to file: %v", errSave) - oauthStatus[state] = "Failed to save token to file" + SetOAuthSessionError(state, "Failed to save token to file") return } - delete(oauthStatus, state) + CompleteOAuthSession(state) fmt.Printf("Authentication successful! 
Token saved to %s\n", savedPath) if projectID != "" { fmt.Printf("Using GCP project: %s\n", projectID) @@ -1583,7 +1583,6 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { fmt.Println("You can now use Antigravity services through this CLI") }() - oauthStatus[state] = "" c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1605,11 +1604,13 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { } authURL := deviceFlow.VerificationURIComplete + RegisterOAuthSession(state, "qwen") + go func() { fmt.Println("Waiting for authentication...") tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) if errPollForToken != nil { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errPollForToken) return } @@ -1628,16 +1629,15 @@ func (h *Handler) RequestQwenToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { log.Errorf("Failed to save authentication tokens: %v", errSave) - oauthStatus[state] = "Failed to save authentication tokens" + SetOAuthSessionError(state, "Failed to save authentication tokens") return } fmt.Printf("Authentication successful! 
Token saved to %s\n", savedPath) fmt.Println("You can now use Qwen services through this CLI") - delete(oauthStatus, state) + CompleteOAuthSession(state) }() - oauthStatus[state] = "" c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -1650,6 +1650,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { authSvc := iflowauth.NewIFlowAuth(h.cfg) authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort) + RegisterOAuthSession(state, "iflow") + isWebUI := isWebUIRequest(c) if isWebUI { targetURL, errTarget := h.managementCallbackURL("/iflow/callback") @@ -1676,7 +1678,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { var resultMap map[string]string for { if time.Now().After(deadline) { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: timeout waiting for callback") return } @@ -1689,26 +1691,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { } if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Printf("Authentication failed: %s\n", errStr) return } if resultState := strings.TrimSpace(resultMap["state"]); resultState != state { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: state mismatch") return } code := strings.TrimSpace(resultMap["code"]) if code == "" { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: code missing") return } tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI) if errExchange != nil { - oauthStatus[state] = "Authentication failed" + SetOAuthSessionError(state, "Authentication failed") fmt.Printf("Authentication failed: %v\n", errExchange) return } @@ -1730,7 +1732,7 @@ func (h 
*Handler) RequestIFlowToken(c *gin.Context) { savedPath, errSave := h.saveTokenRecord(ctx, record) if errSave != nil { - oauthStatus[state] = "Failed to save authentication tokens" + SetOAuthSessionError(state, "Failed to save authentication tokens") log.Errorf("Failed to save authentication tokens: %v", errSave) return } @@ -1740,10 +1742,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { fmt.Println("API key obtained and saved") } fmt.Println("You can now use iFlow services through this CLI") - delete(oauthStatus, state) + CompleteOAuthSession(state) }() - oauthStatus[state] = "" c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) } @@ -2179,16 +2180,24 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec } func (h *Handler) GetAuthStatus(c *gin.Context) { - state := c.Query("state") - if err, ok := oauthStatus[state]; ok { - if err != "" { - c.JSON(200, gin.H{"status": "error", "error": err}) - } else { - c.JSON(200, gin.H{"status": "wait"}) - return - } - } else { - c.JSON(200, gin.H{"status": "ok"}) + state := strings.TrimSpace(c.Query("state")) + if state == "" { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"}) + return } - delete(oauthStatus, state) + if err := ValidateOAuthState(state); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"}) + return + } + + _, status, ok := GetOAuthSession(state) + if !ok { + c.JSON(http.StatusOK, gin.H{"status": "ok"}) + return + } + if status != "" { + c.JSON(http.StatusOK, gin.H{"status": "error", "error": status}) + return + } + c.JSON(http.StatusOK, gin.H{"status": "wait"}) } diff --git a/internal/api/handlers/management/oauth_callback.go b/internal/api/handlers/management/oauth_callback.go new file mode 100644 index 00000000..c69a332e --- /dev/null +++ b/internal/api/handlers/management/oauth_callback.go @@ -0,0 +1,100 @@ +package management + +import ( + "errors" + 
"net/http" + "net/url" + "strings" + + "github.com/gin-gonic/gin" +) + +type oauthCallbackRequest struct { + Provider string `json:"provider"` + RedirectURL string `json:"redirect_url"` + Code string `json:"code"` + State string `json:"state"` + Error string `json:"error"` +} + +func (h *Handler) PostOAuthCallback(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"}) + return + } + + var req oauthCallbackRequest + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"}) + return + } + + canonicalProvider, err := NormalizeOAuthProvider(req.Provider) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"}) + return + } + + state := strings.TrimSpace(req.State) + code := strings.TrimSpace(req.Code) + errMsg := strings.TrimSpace(req.Error) + + if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" { + u, errParse := url.Parse(rawRedirect) + if errParse != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"}) + return + } + q := u.Query() + if state == "" { + state = strings.TrimSpace(q.Get("state")) + } + if code == "" { + code = strings.TrimSpace(q.Get("code")) + } + if errMsg == "" { + errMsg = strings.TrimSpace(q.Get("error")) + if errMsg == "" { + errMsg = strings.TrimSpace(q.Get("error_description")) + } + } + } + + if state == "" { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"}) + return + } + if err := ValidateOAuthState(state); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"}) + return + } + if code == "" && errMsg == "" { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"}) + return + } + + sessionProvider, sessionStatus, ok := GetOAuthSession(state) + 
if !ok { + c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"}) + return + } + if sessionStatus != "" { + c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"}) + return + } + if !strings.EqualFold(sessionProvider, canonicalProvider) { + c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"}) + return + } + + if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil { + if errors.Is(errWrite, errOAuthSessionNotPending) { + c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"}) + return + } + c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"}) + return + } + + c.JSON(http.StatusOK, gin.H{"status": "ok"}) +} diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go new file mode 100644 index 00000000..f23b608c --- /dev/null +++ b/internal/api/handlers/management/oauth_sessions.go @@ -0,0 +1,258 @@ +package management + +import ( + "encoding/json" + "errors" + "fmt" + "os" + "path/filepath" + "strings" + "sync" + "time" +) + +const ( + oauthSessionTTL = 10 * time.Minute + maxOAuthStateLength = 128 +) + +var ( + errInvalidOAuthState = errors.New("invalid oauth state") + errUnsupportedOAuthFlow = errors.New("unsupported oauth provider") + errOAuthSessionNotPending = errors.New("oauth session is not pending") +) + +type oauthSession struct { + Provider string + Status string + CreatedAt time.Time + ExpiresAt time.Time +} + +type oauthSessionStore struct { + mu sync.RWMutex + ttl time.Duration + sessions map[string]oauthSession +} + +func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore { + if ttl <= 0 { + ttl = oauthSessionTTL + } + return &oauthSessionStore{ + ttl: ttl, + sessions: make(map[string]oauthSession), + } +} + 
+func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) { + for state, session := range s.sessions { + if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) { + delete(s.sessions, state) + } + } +} + +func (s *oauthSessionStore) Register(state, provider string) { + state = strings.TrimSpace(state) + provider = strings.ToLower(strings.TrimSpace(provider)) + if state == "" || provider == "" { + return + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + s.sessions[state] = oauthSession{ + Provider: provider, + Status: "", + CreatedAt: now, + ExpiresAt: now.Add(s.ttl), + } +} + +func (s *oauthSessionStore) SetError(state, message string) { + state = strings.TrimSpace(state) + message = strings.TrimSpace(message) + if state == "" { + return + } + if message == "" { + message = "Authentication failed" + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + session, ok := s.sessions[state] + if !ok { + return + } + session.Status = message + session.ExpiresAt = now.Add(s.ttl) + s.sessions[state] = session +} + +func (s *oauthSessionStore) Complete(state string) { + state = strings.TrimSpace(state) + if state == "" { + return + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + delete(s.sessions, state) +} + +func (s *oauthSessionStore) Get(state string) (oauthSession, bool) { + state = strings.TrimSpace(state) + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + session, ok := s.sessions[state] + return session, ok +} + +func (s *oauthSessionStore) IsPending(state, provider string) bool { + state = strings.TrimSpace(state) + provider = strings.ToLower(strings.TrimSpace(provider)) + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + session, ok := s.sessions[state] + if !ok { + return false + } + if session.Status != "" { + return false + } + if provider == "" { + 
return true + } + return strings.EqualFold(session.Provider, provider) +} + +var oauthSessions = newOAuthSessionStore(oauthSessionTTL) + +func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) } + +func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) } + +func CompleteOAuthSession(state string) { oauthSessions.Complete(state) } + +func GetOAuthSession(state string) (provider string, status string, ok bool) { + session, ok := oauthSessions.Get(state) + if !ok { + return "", "", false + } + return session.Provider, session.Status, true +} + +func IsOAuthSessionPending(state, provider string) bool { + return oauthSessions.IsPending(state, provider) +} + +func ValidateOAuthState(state string) error { + trimmed := strings.TrimSpace(state) + if trimmed == "" { + return fmt.Errorf("%w: empty", errInvalidOAuthState) + } + if len(trimmed) > maxOAuthStateLength { + return fmt.Errorf("%w: too long", errInvalidOAuthState) + } + if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") { + return fmt.Errorf("%w: contains path separator", errInvalidOAuthState) + } + if strings.Contains(trimmed, "..") { + return fmt.Errorf("%w: contains '..'", errInvalidOAuthState) + } + for _, r := range trimmed { + switch { + case r >= 'a' && r <= 'z': + case r >= 'A' && r <= 'Z': + case r >= '0' && r <= '9': + case r == '-' || r == '_' || r == '.': + default: + return fmt.Errorf("%w: invalid character", errInvalidOAuthState) + } + } + return nil +} + +func NormalizeOAuthProvider(provider string) (string, error) { + switch strings.ToLower(strings.TrimSpace(provider)) { + case "anthropic", "claude": + return "anthropic", nil + case "codex", "openai": + return "codex", nil + case "gemini", "google": + return "gemini", nil + case "iflow", "i-flow": + return "iflow", nil + case "antigravity", "anti-gravity": + return "antigravity", nil + case "qwen": + return "qwen", nil + default: + return "", 
errUnsupportedOAuthFlow + } +} + +type oauthCallbackFilePayload struct { + Code string `json:"code"` + State string `json:"state"` + Error string `json:"error"` +} + +func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) { + if strings.TrimSpace(authDir) == "" { + return "", fmt.Errorf("auth dir is empty") + } + canonicalProvider, err := NormalizeOAuthProvider(provider) + if err != nil { + return "", err + } + if err := ValidateOAuthState(state); err != nil { + return "", err + } + + fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state) + filePath := filepath.Join(authDir, fileName) + payload := oauthCallbackFilePayload{ + Code: strings.TrimSpace(code), + State: strings.TrimSpace(state), + Error: strings.TrimSpace(errorMessage), + } + data, err := json.Marshal(payload) + if err != nil { + return "", fmt.Errorf("marshal oauth callback payload: %w", err) + } + if err := os.WriteFile(filePath, data, 0o600); err != nil { + return "", fmt.Errorf("write oauth callback file: %w", err) + } + return filePath, nil +} + +func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) { + canonicalProvider, err := NormalizeOAuthProvider(provider) + if err != nil { + return "", err + } + if !IsOAuthSessionPending(state, canonicalProvider) { + return "", errOAuthSessionNotPending + } + return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage) +} diff --git a/internal/api/server.go b/internal/api/server.go index e6d03bc3..d6fe91bf 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -354,10 +354,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") - // Persist to a temporary file keyed by state + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state) - _ = 
os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -367,9 +368,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -379,9 +382,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -391,9 +396,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = 
managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -403,9 +410,11 @@ func (s *Server) setupRoutes() { code := c.Query("code") state := c.Query("state") errStr := c.Query("error") + if errStr == "" { + errStr = c.Query("error_description") + } if state != "" { - file := fmt.Sprintf("%s/.oauth-antigravity-%s.oauth", s.cfg.AuthDir, state) - _ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600) + _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr) } c.Header("Content-Type", "text/html; charset=utf-8") c.String(http.StatusOK, oauthCallbackSuccessHTML) @@ -577,6 +586,7 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken) mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken) + mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) } } From 1b8cb7b77bc3462aad2ce5458a82cfa89288dca5 Mon Sep 17 00:00:00 2001 From: Ben Vargas Date: Thu, 18 Dec 2025 12:50:51 -0700 Subject: [PATCH 03/38] fix: remove propertyNames from JSON schema for Gemini compatibility Gemini API does not support the JSON Schema `propertyNames` keyword, causing 400 errors when Claude tool schemas containing this field are proxied through the Antigravity provider. Add `propertyNames` to the list of unsupported keywords removed by CleanJSONSchemaForGemini(), alongside existing removals like $ref, definitions, and additionalProperties. 
--- internal/util/gemini_schema.go | 1 + internal/util/gemini_schema_test.go | 65 +++++++++++++++++++++++++++++ 2 files changed, 66 insertions(+) diff --git a/internal/util/gemini_schema.go b/internal/util/gemini_schema.go index b25d14e4..bbd3e339 100644 --- a/internal/util/gemini_schema.go +++ b/internal/util/gemini_schema.go @@ -296,6 +296,7 @@ func flattenTypeArrays(jsonStr string) string { func removeUnsupportedKeywords(jsonStr string) string { keywords := append(unsupportedConstraints, "$schema", "$defs", "definitions", "const", "$ref", "additionalProperties", + "propertyNames", // Gemini doesn't support property name validation ) for _, key := range keywords { for _, p := range findPaths(jsonStr, key) { diff --git a/internal/util/gemini_schema_test.go b/internal/util/gemini_schema_test.go index 655511d9..55a3c5fd 100644 --- a/internal/util/gemini_schema_test.go +++ b/internal/util/gemini_schema_test.go @@ -596,6 +596,71 @@ func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) { } } +func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) { + // propertyNames is used to validate object property names (e.g., must match a pattern) + // Gemini doesn't support this keyword and will reject requests containing it + input := `{ + "type": "object", + "properties": { + "metadata": { + "type": "object", + "propertyNames": { + "pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$" + }, + "additionalProperties": { + "type": "string" + } + } + } + }` + + expected := `{ + "type": "object", + "properties": { + "metadata": { + "type": "object" + } + } + }` + + result := CleanJSONSchemaForGemini(input) + compareJSON(t, expected, result) + + // Verify propertyNames is completely removed + if strings.Contains(result, "propertyNames") { + t.Errorf("propertyNames keyword should be removed, got: %s", result) + } +} + +func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) { + // Test deeply nested propertyNames (as seen in real Claude tool schemas) + 
input := `{ + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object", + "properties": { + "config": { + "type": "object", + "propertyNames": { + "type": "string" + } + } + } + } + } + } + }` + + result := CleanJSONSchemaForGemini(input) + + if strings.Contains(result, "propertyNames") { + t.Errorf("Nested propertyNames should be removed, got: %s", result) + } +} + func compareJSON(t *testing.T, expectedJSON, actualJSON string) { var expMap, actMap map[string]interface{} errExp := json.Unmarshal([]byte(expectedJSON), &expMap) From bbcb5552f34e6b107e8e52230d8cb446163865ad Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Fri, 19 Dec 2025 10:27:24 +0900 Subject: [PATCH 04/38] feat(cache): add signature cache for Claude thinking blocks --- internal/cache/signature_cache.go | 166 +++++++++++++++++++ internal/cache/signature_cache_test.go | 216 +++++++++++++++++++++++++ 2 files changed, 382 insertions(+) create mode 100644 internal/cache/signature_cache.go create mode 100644 internal/cache/signature_cache_test.go diff --git a/internal/cache/signature_cache.go b/internal/cache/signature_cache.go new file mode 100644 index 00000000..12f19cf0 --- /dev/null +++ b/internal/cache/signature_cache.go @@ -0,0 +1,166 @@ +package cache + +import ( + "crypto/sha256" + "encoding/hex" + "sync" + "time" +) + +// SignatureEntry holds a cached thinking signature with timestamp +type SignatureEntry struct { + Signature string + Timestamp time.Time +} + +const ( + // SignatureCacheTTL is how long signatures are valid + SignatureCacheTTL = 1 * time.Hour + + // MaxEntriesPerSession limits memory usage per session + MaxEntriesPerSession = 100 + + // SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space) + SignatureTextHashLen = 16 + + // MinValidSignatureLen is the minimum length for a signature to be considered valid + MinValidSignatureLen = 50 +) + +// signatureCache stores signatures by 
sessionID -> textHash -> SignatureEntry
+var signatureCache sync.Map
+
+// sessionCache is the inner map type
+type sessionCache struct {
+	mu      sync.RWMutex
+	entries map[string]SignatureEntry
+}
+
+// hashText creates a stable, Unicode-safe key from text content
+func hashText(text string) string {
+	h := sha256.Sum256([]byte(text))
+	return hex.EncodeToString(h[:])[:SignatureTextHashLen]
+}
+
+// getOrCreateSession gets or creates a session cache
+func getOrCreateSession(sessionID string) *sessionCache {
+	if val, ok := signatureCache.Load(sessionID); ok {
+		return val.(*sessionCache)
+	}
+	sc := &sessionCache{entries: make(map[string]SignatureEntry)}
+	actual, _ := signatureCache.LoadOrStore(sessionID, sc)
+	return actual.(*sessionCache)
+}
+
+// CacheSignature stores a thinking signature for a given session and text.
+// Used for Claude models that require signed thinking blocks in multi-turn conversations.
+func CacheSignature(sessionID, text, signature string) {
+	if sessionID == "" || text == "" || signature == "" {
+		return
+	}
+	if len(signature) < MinValidSignatureLen {
+		return
+	}
+
+	sc := getOrCreateSession(sessionID)
+	textHash := hashText(text)
+
+	sc.mu.Lock()
+	defer sc.mu.Unlock()
+
+	// Evict expired entries if at capacity
+	if len(sc.entries) >= MaxEntriesPerSession {
+		now := time.Now()
+		for key, entry := range sc.entries {
+			if now.Sub(entry.Timestamp) > SignatureCacheTTL {
+				delete(sc.entries, key)
+			}
+		}
+		// If still at capacity, remove oldest entries
+		if len(sc.entries) >= MaxEntriesPerSession {
+			// Find and remove the oldest quarter
+			oldest := make([]struct {
+				key string
+				ts  time.Time
+			}, 0, len(sc.entries))
+			for key, entry := range sc.entries {
+				oldest = append(oldest, struct {
+					key string
+					ts  time.Time
+				}{key, entry.Timestamp})
+			}
+			// Simple approach: remove the oldest quarter of entries
+			toRemove := len(oldest) / 4
+			if toRemove < 1 {
+				toRemove = 1
+			}
+			// Sort by timestamp (oldest first) - partial selection sort for small N
+			for i := 
0; i < toRemove; i++ { + minIdx := i + for j := i + 1; j < len(oldest); j++ { + if oldest[j].ts.Before(oldest[minIdx].ts) { + minIdx = j + } + } + oldest[i], oldest[minIdx] = oldest[minIdx], oldest[i] + delete(sc.entries, oldest[i].key) + } + } + } + + sc.entries[textHash] = SignatureEntry{ + Signature: signature, + Timestamp: time.Now(), + } +} + +// GetCachedSignature retrieves a cached signature for a given session and text. +// Returns empty string if not found or expired. +func GetCachedSignature(sessionID, text string) string { + if sessionID == "" || text == "" { + return "" + } + + val, ok := signatureCache.Load(sessionID) + if !ok { + return "" + } + sc := val.(*sessionCache) + + textHash := hashText(text) + + sc.mu.RLock() + entry, exists := sc.entries[textHash] + sc.mu.RUnlock() + + if !exists { + return "" + } + + // Check if expired + if time.Since(entry.Timestamp) > SignatureCacheTTL { + sc.mu.Lock() + delete(sc.entries, textHash) + sc.mu.Unlock() + return "" + } + + return entry.Signature +} + +// ClearSignatureCache clears signature cache for a specific session or all sessions. 
+func ClearSignatureCache(sessionID string) { + if sessionID != "" { + signatureCache.Delete(sessionID) + } else { + signatureCache.Range(func(key, _ any) bool { + signatureCache.Delete(key) + return true + }) + } +} + +// HasValidSignature checks if a signature is valid (non-empty and long enough) +func HasValidSignature(signature string) bool { + return signature != "" && len(signature) >= MinValidSignatureLen +} diff --git a/internal/cache/signature_cache_test.go b/internal/cache/signature_cache_test.go new file mode 100644 index 00000000..e4bddbe4 --- /dev/null +++ b/internal/cache/signature_cache_test.go @@ -0,0 +1,216 @@ +package cache + +import ( + "testing" + "time" +) + +func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) { + ClearSignatureCache("") + + sessionID := "test-session-1" + text := "This is some thinking text content" + signature := "abc123validSignature1234567890123456789012345678901234567890" + + // Store signature + CacheSignature(sessionID, text, signature) + + // Retrieve signature + retrieved := GetCachedSignature(sessionID, text) + if retrieved != signature { + t.Errorf("Expected signature '%s', got '%s'", signature, retrieved) + } +} + +func TestCacheSignature_DifferentSessions(t *testing.T) { + ClearSignatureCache("") + + text := "Same text in different sessions" + sig1 := "signature1_1234567890123456789012345678901234567890123456" + sig2 := "signature2_1234567890123456789012345678901234567890123456" + + CacheSignature("session-a", text, sig1) + CacheSignature("session-b", text, sig2) + + if GetCachedSignature("session-a", text) != sig1 { + t.Error("Session-a signature mismatch") + } + if GetCachedSignature("session-b", text) != sig2 { + t.Error("Session-b signature mismatch") + } +} + +func TestCacheSignature_NotFound(t *testing.T) { + ClearSignatureCache("") + + // Non-existent session + if got := GetCachedSignature("nonexistent", "some text"); got != "" { + t.Errorf("Expected empty string for nonexistent session, got 
'%s'", got) + } + + // Existing session but different text + CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890") + if got := GetCachedSignature("session-x", "text-b"); got != "" { + t.Errorf("Expected empty string for different text, got '%s'", got) + } +} + +func TestCacheSignature_EmptyInputs(t *testing.T) { + ClearSignatureCache("") + + // All empty/invalid inputs should be no-ops + CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890") + CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890") + CacheSignature("session", "text", "") + CacheSignature("session", "text", "short") // Too short + + if got := GetCachedSignature("session", "text"); got != "" { + t.Errorf("Expected empty after invalid cache attempts, got '%s'", got) + } +} + +func TestCacheSignature_ShortSignatureRejected(t *testing.T) { + ClearSignatureCache("") + + sessionID := "test-short-sig" + text := "Some text" + shortSig := "abc123" // Less than 50 chars + + CacheSignature(sessionID, text, shortSig) + + if got := GetCachedSignature(sessionID, text); got != "" { + t.Errorf("Short signature should be rejected, got '%s'", got) + } +} + +func TestClearSignatureCache_SpecificSession(t *testing.T) { + ClearSignatureCache("") + + sig := "validSig1234567890123456789012345678901234567890123456" + CacheSignature("session-1", "text", sig) + CacheSignature("session-2", "text", sig) + + ClearSignatureCache("session-1") + + if got := GetCachedSignature("session-1", "text"); got != "" { + t.Error("session-1 should be cleared") + } + if got := GetCachedSignature("session-2", "text"); got != sig { + t.Error("session-2 should still exist") + } +} + +func TestClearSignatureCache_AllSessions(t *testing.T) { + ClearSignatureCache("") + + sig := "validSig1234567890123456789012345678901234567890123456" + CacheSignature("session-1", "text", sig) + CacheSignature("session-2", "text", sig) + + ClearSignatureCache("") + 
+ if got := GetCachedSignature("session-1", "text"); got != "" { + t.Error("session-1 should be cleared") + } + if got := GetCachedSignature("session-2", "text"); got != "" { + t.Error("session-2 should be cleared") + } +} + +func TestHasValidSignature(t *testing.T) { + tests := []struct { + name string + signature string + expected bool + }{ + {"valid long signature", "abc123validSignature1234567890123456789012345678901234567890", true}, + {"exactly 50 chars", "12345678901234567890123456789012345678901234567890", true}, + {"49 chars - invalid", "1234567890123456789012345678901234567890123456789", false}, + {"empty string", "", false}, + {"short signature", "abc", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := HasValidSignature(tt.signature) + if result != tt.expected { + t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected) + } + }) + } +} + +func TestCacheSignature_TextHashCollisionResistance(t *testing.T) { + ClearSignatureCache("") + + sessionID := "hash-test-session" + + // Different texts should produce different hashes + text1 := "First thinking text" + text2 := "Second thinking text" + sig1 := "signature1_1234567890123456789012345678901234567890123456" + sig2 := "signature2_1234567890123456789012345678901234567890123456" + + CacheSignature(sessionID, text1, sig1) + CacheSignature(sessionID, text2, sig2) + + if GetCachedSignature(sessionID, text1) != sig1 { + t.Error("text1 signature mismatch") + } + if GetCachedSignature(sessionID, text2) != sig2 { + t.Error("text2 signature mismatch") + } +} + +func TestCacheSignature_UnicodeText(t *testing.T) { + ClearSignatureCache("") + + sessionID := "unicode-session" + text := "한글 텍스트와 이모지 🎉 그리고 特殊文字" + sig := "unicodeSig123456789012345678901234567890123456789012345" + + CacheSignature(sessionID, text, sig) + + if got := GetCachedSignature(sessionID, text); got != sig { + t.Errorf("Unicode text signature retrieval failed, got '%s'", got) + 
} +} + +func TestCacheSignature_Overwrite(t *testing.T) { + ClearSignatureCache("") + + sessionID := "overwrite-session" + text := "Same text" + sig1 := "firstSignature12345678901234567890123456789012345678901" + sig2 := "secondSignature1234567890123456789012345678901234567890" + + CacheSignature(sessionID, text, sig1) + CacheSignature(sessionID, text, sig2) // Overwrite + + if got := GetCachedSignature(sessionID, text); got != sig2 { + t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got) + } +} + +// Note: TTL expiration test is tricky to test without mocking time +// We test the logic path exists but actual expiration would require time manipulation +func TestCacheSignature_ExpirationLogic(t *testing.T) { + ClearSignatureCache("") + + // This test verifies the expiration check exists + // In a real scenario, we'd mock time.Now() + sessionID := "expiration-test" + text := "text" + sig := "validSig1234567890123456789012345678901234567890123456" + + CacheSignature(sessionID, text, sig) + + // Fresh entry should be retrievable + if got := GetCachedSignature(sessionID, text); got != sig { + t.Errorf("Fresh entry should be retrievable, got '%s'", got) + } + + // We can't easily test actual expiration without time mocking + // but the logic is verified by the implementation + _ = time.Now() // Acknowledge we're not testing time passage +} From 1bfa75f7801b80f2fd61d17293b8b23729c6e62b Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Fri, 19 Dec 2025 10:27:24 +0900 Subject: [PATCH 05/38] feat(util): add helper to detect Claude thinking models --- internal/util/claude_model.go | 10 ++++++++ internal/util/claude_model_test.go | 41 ++++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+) create mode 100644 internal/util/claude_model.go create mode 100644 internal/util/claude_model_test.go diff --git a/internal/util/claude_model.go b/internal/util/claude_model.go new file mode 100644 index 00000000..1534f02c --- /dev/null 
+++ b/internal/util/claude_model.go @@ -0,0 +1,10 @@ +package util + +import "strings" + +// IsClaudeThinkingModel checks if the model is a Claude thinking model +// that requires the interleaved-thinking beta header. +func IsClaudeThinkingModel(model string) bool { + lower := strings.ToLower(model) + return strings.Contains(lower, "claude") && strings.Contains(lower, "thinking") +} diff --git a/internal/util/claude_model_test.go b/internal/util/claude_model_test.go new file mode 100644 index 00000000..17f6106e --- /dev/null +++ b/internal/util/claude_model_test.go @@ -0,0 +1,41 @@ +package util + +import "testing" + +func TestIsClaudeThinkingModel(t *testing.T) { + tests := []struct { + name string + model string + expected bool + }{ + // Claude thinking models - should return true + {"claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true}, + {"claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true}, + {"Claude-Sonnet-Thinking uppercase", "Claude-Sonnet-4-5-Thinking", true}, + {"claude thinking mixed case", "Claude-THINKING-Model", true}, + + // Non-thinking Claude models - should return false + {"claude-sonnet-4-5 (no thinking)", "claude-sonnet-4-5", false}, + {"claude-opus-4-5 (no thinking)", "claude-opus-4-5", false}, + {"claude-3-5-sonnet", "claude-3-5-sonnet-20240620", false}, + + // Non-Claude models - should return false + {"gemini-3-pro-preview", "gemini-3-pro-preview", false}, + {"gemini-thinking model", "gemini-3-pro-thinking", false}, // not Claude + {"gpt-4o", "gpt-4o", false}, + {"empty string", "", false}, + + // Edge cases + {"thinking without claude", "thinking-model", false}, + {"claude without thinking", "claude-model", false}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := IsClaudeThinkingModel(tt.model) + if result != tt.expected { + t.Errorf("IsClaudeThinkingModel(%q) = %v, expected %v", tt.model, result, tt.expected) + } + }) + } +} From e44167d7a482aea95441243fdaa4e00355173971 Mon Sep 17 
00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Fri, 19 Dec 2025 10:27:24 +0900 Subject: [PATCH 06/38] refactor(util/schema): rename and extend Gemini schema cleaning for Antigravity and add empty-schema placeholders --- internal/util/gemini_schema.go | 56 +++++- internal/util/gemini_schema_test.go | 293 +++++++++++++++++++++++----- 2 files changed, 302 insertions(+), 47 deletions(-) diff --git a/internal/util/gemini_schema.go b/internal/util/gemini_schema.go index b25d14e4..7ca9cf79 100644 --- a/internal/util/gemini_schema.go +++ b/internal/util/gemini_schema.go @@ -12,10 +12,10 @@ import ( var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?") -// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini/Antigravity API. +// CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API. // It handles unsupported keywords, type flattening, and schema simplification while preserving // semantic information as description hints. 
-func CleanJSONSchemaForGemini(jsonStr string) string { +func CleanJSONSchemaForAntigravity(jsonStr string) string { // Phase 1: Convert and add hints jsonStr = convertRefsToHints(jsonStr) jsonStr = convertConstToEnum(jsonStr) @@ -32,6 +32,9 @@ func CleanJSONSchemaForGemini(jsonStr string) string { jsonStr = removeUnsupportedKeywords(jsonStr) jsonStr = cleanupRequiredFields(jsonStr) + // Phase 4: Add placeholder for empty object schemas (Claude VALIDATED mode requirement) + jsonStr = addEmptySchemaPlaceholder(jsonStr) + return jsonStr } @@ -105,7 +108,8 @@ func addAdditionalPropertiesHints(jsonStr string) string { var unsupportedConstraints = []string{ "minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum", - "pattern", "minItems", "maxItems", + "pattern", "minItems", "maxItems", "format", + "default", "examples", // Claude rejects these in VALIDATED mode } func moveConstraintsToDescription(jsonStr string) string { @@ -338,6 +342,52 @@ func cleanupRequiredFields(jsonStr string) string { return jsonStr } +// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas. +// Claude VALIDATED mode requires at least one property in tool schemas. 
+func addEmptySchemaPlaceholder(jsonStr string) string { + // Find all "type" fields + paths := findPaths(jsonStr, "type") + + // Process from deepest to shallowest (to handle nested objects properly) + sortByDepth(paths) + + for _, p := range paths { + typeVal := gjson.Get(jsonStr, p) + if typeVal.String() != "object" { + continue + } + + // Get the parent path (the object containing "type") + parentPath := trimSuffix(p, ".type") + + // Check if properties exists and is empty or missing + propsPath := joinPath(parentPath, "properties") + propsVal := gjson.Get(jsonStr, propsPath) + + needsPlaceholder := false + if !propsVal.Exists() { + // No properties field at all + needsPlaceholder = true + } else if propsVal.IsObject() && len(propsVal.Map()) == 0 { + // Empty properties object + needsPlaceholder = true + } + + if needsPlaceholder { + // Add placeholder "reason" property + reasonPath := joinPath(propsPath, "reason") + jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string") + jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool") + + // Add to required array + reqPath := joinPath(parentPath, "required") + jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"}) + } + } + + return jsonStr +} + // --- Helpers --- func findPaths(jsonStr, field string) []string { diff --git a/internal/util/gemini_schema_test.go b/internal/util/gemini_schema_test.go index 655511d9..0f9e3eba 100644 --- a/internal/util/gemini_schema_test.go +++ b/internal/util/gemini_schema_test.go @@ -5,9 +5,11 @@ import ( "reflect" "strings" "testing" + + "github.com/tidwall/gjson" ) -func TestCleanJSONSchemaForGemini_ConstToEnum(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_ConstToEnum(t *testing.T) { input := `{ "type": "object", "properties": { @@ -28,11 +30,11 @@ func TestCleanJSONSchemaForGemini_ConstToEnum(t *testing.T) { } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) 
compareJSON(t, expected, result) } -func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_TypeFlattening_Nullable(t *testing.T) { input := `{ "type": "object", "properties": { @@ -60,11 +62,11 @@ func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable(t *testing.T) { "required": ["other"] }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) compareJSON(t, expected, result) } -func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_ConstraintsToDescription(t *testing.T) { input := `{ "type": "object", "properties": { @@ -81,7 +83,7 @@ func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) { } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) // minItems should be REMOVED and moved to description if strings.Contains(result, `"minItems"`) { @@ -100,7 +102,7 @@ func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) { } } -func TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_SmartSelection(t *testing.T) { input := `{ "type": "object", "properties": { @@ -131,11 +133,11 @@ func TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection(t *testing.T) { } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) compareJSON(t, expected, result) } -func TestCleanJSONSchemaForGemini_OneOfFlattening(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_OneOfFlattening(t *testing.T) { input := `{ "type": "object", "properties": { @@ -158,11 +160,11 @@ func TestCleanJSONSchemaForGemini_OneOfFlattening(t *testing.T) { } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) compareJSON(t, expected, result) } -func TestCleanJSONSchemaForGemini_AllOfMerging(t *testing.T) { 
+func TestCleanJSONSchemaForAntigravity_AllOfMerging(t *testing.T) { input := `{ "type": "object", "allOf": [ @@ -190,11 +192,11 @@ func TestCleanJSONSchemaForGemini_AllOfMerging(t *testing.T) { "required": ["a", "b"] }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) compareJSON(t, expected, result) } -func TestCleanJSONSchemaForGemini_RefHandling(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_RefHandling(t *testing.T) { input := `{ "definitions": { "User": { @@ -210,21 +212,29 @@ func TestCleanJSONSchemaForGemini_RefHandling(t *testing.T) { } }` + // After $ref is converted to placeholder object, empty schema placeholder is also added expected := `{ "type": "object", "properties": { "customer": { "type": "object", - "description": "See: User" + "description": "See: User", + "properties": { + "reason": { + "type": "string", + "description": "Brief explanation of why you are calling this tool" + } + }, + "required": ["reason"] } } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) compareJSON(t, expected, result) } -func TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_RefHandling_DescriptionEscaping(t *testing.T) { input := `{ "definitions": { "User": { @@ -243,21 +253,29 @@ func TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping(t *testing.T) } }` + // After $ref is converted, empty schema placeholder is also added expected := `{ "type": "object", "properties": { "customer": { "type": "object", - "description": "He said \"hi\"\\nsecond line (See: User)" + "description": "He said \"hi\"\\nsecond line (See: User)", + "properties": { + "reason": { + "type": "string", + "description": "Brief explanation of why you are calling this tool" + } + }, + "required": ["reason"] } } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) compareJSON(t, expected, 
result) } -func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_CyclicRefDefaults(t *testing.T) { input := `{ "definitions": { "Node": { @@ -270,7 +288,7 @@ func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) { "$ref": "#/definitions/Node" }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) var resMap map[string]interface{} json.Unmarshal([]byte(result), &resMap) @@ -285,7 +303,7 @@ func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) { } } -func TestCleanJSONSchemaForGemini_RequiredCleanup(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_RequiredCleanup(t *testing.T) { input := `{ "type": "object", "properties": { @@ -304,11 +322,11 @@ func TestCleanJSONSchemaForGemini_RequiredCleanup(t *testing.T) { "required": ["a", "b"] }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) compareJSON(t, expected, result) } -func TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_AllOfMerging_DotKeys(t *testing.T) { input := `{ "type": "object", "allOf": [ @@ -336,11 +354,11 @@ func TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys(t *testing.T) { "required": ["my.param", "b"] }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) compareJSON(t, expected, result) } -func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_PropertyNameCollision(t *testing.T) { // A tool has an argument named "pattern" - should NOT be treated as a constraint input := `{ "type": "object", @@ -364,7 +382,7 @@ func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) { "required": ["pattern"] }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) compareJSON(t, expected, result) var resMap map[string]interface{} @@ -375,7 
+393,7 @@ func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) { } } -func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_DotKeys(t *testing.T) { input := `{ "type": "object", "properties": { @@ -389,7 +407,7 @@ func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) { } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) var resMap map[string]interface{} if err := json.Unmarshal([]byte(result), &resMap); err != nil { @@ -414,7 +432,7 @@ func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) { } } -func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_AnyOfAlternativeHints(t *testing.T) { input := `{ "type": "object", "properties": { @@ -428,7 +446,7 @@ func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) { } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) if !strings.Contains(result, "Accepts:") { t.Errorf("Expected alternative types hint, got: %s", result) @@ -438,7 +456,7 @@ func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) { } } -func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_NullableHint(t *testing.T) { input := `{ "type": "object", "properties": { @@ -450,7 +468,7 @@ func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) { "required": ["name"] }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) if !strings.Contains(result, "(nullable)") { t.Errorf("Expected nullable hint, got: %s", result) @@ -460,7 +478,7 @@ func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) { } } -func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_TypeFlattening_Nullable_DotKey(t *testing.T) { input := `{ "type": "object", "properties": { @@ -488,11 +506,11 @@ 
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey(t *testing.T) { "required": ["other"] }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) compareJSON(t, expected, result) } -func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_EnumHint(t *testing.T) { input := `{ "type": "object", "properties": { @@ -504,7 +522,7 @@ func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) { } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) if !strings.Contains(result, "Allowed:") { t.Errorf("Expected enum values hint, got: %s", result) @@ -514,7 +532,7 @@ func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) { } } -func TestCleanJSONSchemaForGemini_AdditionalPropertiesHint(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_AdditionalPropertiesHint(t *testing.T) { input := `{ "type": "object", "properties": { @@ -523,14 +541,14 @@ func TestCleanJSONSchemaForGemini_AdditionalPropertiesHint(t *testing.T) { "additionalProperties": false }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) if !strings.Contains(result, "No extra properties allowed") { t.Errorf("Expected additionalProperties hint, got: %s", result) } } -func TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_PreservesDescription(t *testing.T) { input := `{ "type": "object", "properties": { @@ -554,11 +572,11 @@ func TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription(t *testin } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) compareJSON(t, expected, result) } -func TestCleanJSONSchemaForGemini_SingleEnumNoHint(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_SingleEnumNoHint(t *testing.T) { input := `{ "type": "object", "properties": { @@ -569,14 +587,14 @@ func 
TestCleanJSONSchemaForGemini_SingleEnumNoHint(t *testing.T) { } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) if strings.Contains(result, "Allowed:") { t.Errorf("Single value enum should not add Allowed hint, got: %s", result) } } -func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) { +func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) { input := `{ "type": "object", "properties": { @@ -586,7 +604,7 @@ func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) { } }` - result := CleanJSONSchemaForGemini(input) + result := CleanJSONSchemaForAntigravity(input) if !strings.Contains(result, "Accepts:") { t.Errorf("Expected multiple types hint, got: %s", result) @@ -611,3 +629,190 @@ func compareJSON(t *testing.T, expectedJSON, actualJSON string) { t.Errorf("JSON mismatch:\nExpected:\n%s\n\nActual:\n%s", string(expBytes), string(actBytes)) } } + +// ============================================================================ +// P0-1: Empty Schema Placeholder Tests +// ============================================================================ + +func TestCleanJSONSchemaForAntigravity_EmptySchemaPlaceholder(t *testing.T) { + // Empty object schema with no properties should get a placeholder + input := `{ + "type": "object" + }` + + result := CleanJSONSchemaForAntigravity(input) + + // Should have placeholder property added + if !strings.Contains(result, `"reason"`) { + t.Errorf("Empty schema should have 'reason' placeholder property, got: %s", result) + } + if !strings.Contains(result, `"required"`) { + t.Errorf("Empty schema should have 'required' with 'reason', got: %s", result) + } +} + +func TestCleanJSONSchemaForAntigravity_EmptyPropertiesPlaceholder(t *testing.T) { + // Object with empty properties object + input := `{ + "type": "object", + "properties": {} + }` + + result := CleanJSONSchemaForAntigravity(input) + + // Should have placeholder property added + if 
!strings.Contains(result, `"reason"`) { + t.Errorf("Empty properties should have 'reason' placeholder, got: %s", result) + } +} + +func TestCleanJSONSchemaForAntigravity_NonEmptySchemaUnchanged(t *testing.T) { + // Schema with properties should NOT get placeholder + input := `{ + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"] + }` + + result := CleanJSONSchemaForAntigravity(input) + + // Should NOT have placeholder property + if strings.Contains(result, `"reason"`) { + t.Errorf("Non-empty schema should NOT have 'reason' placeholder, got: %s", result) + } + // Original properties should be preserved + if !strings.Contains(result, `"name"`) { + t.Errorf("Original property 'name' should be preserved, got: %s", result) + } +} + +func TestCleanJSONSchemaForAntigravity_NestedEmptySchema(t *testing.T) { + // Nested empty object in items should also get placeholder + input := `{ + "type": "object", + "properties": { + "items": { + "type": "array", + "items": { + "type": "object" + } + } + } + }` + + result := CleanJSONSchemaForAntigravity(input) + + // Nested empty object should also get placeholder + // Check that the nested object has a reason property + parsed := gjson.Parse(result) + nestedProps := parsed.Get("properties.items.items.properties") + if !nestedProps.Exists() || !nestedProps.Get("reason").Exists() { + t.Errorf("Nested empty object should have 'reason' placeholder, got: %s", result) + } +} + +func TestCleanJSONSchemaForAntigravity_EmptySchemaWithDescription(t *testing.T) { + // Empty schema with description should preserve description and add placeholder + input := `{ + "type": "object", + "description": "An empty object" + }` + + result := CleanJSONSchemaForAntigravity(input) + + // Should have both description and placeholder + if !strings.Contains(result, `"An empty object"`) { + t.Errorf("Description should be preserved, got: %s", result) + } + if !strings.Contains(result, `"reason"`) { + t.Errorf("Empty 
schema should have 'reason' placeholder, got: %s", result) + } +} + +// ============================================================================ +// P0-2: Format field handling (ad-hoc patch removal) +// ============================================================================ + +func TestCleanJSONSchemaForAntigravity_FormatFieldRemoval(t *testing.T) { + // format:"uri" should be removed and added as hint + input := `{ + "type": "object", + "properties": { + "url": { + "type": "string", + "format": "uri", + "description": "A URL" + } + } + }` + + result := CleanJSONSchemaForAntigravity(input) + + // format should be removed + if strings.Contains(result, `"format"`) { + t.Errorf("format field should be removed, got: %s", result) + } + // hint should be added to description + if !strings.Contains(result, "format: uri") { + t.Errorf("format hint should be added to description, got: %s", result) + } + // original description should be preserved + if !strings.Contains(result, "A URL") { + t.Errorf("Original description should be preserved, got: %s", result) + } +} + +func TestCleanJSONSchemaForAntigravity_FormatFieldNoDescription(t *testing.T) { + // format without description should create description with hint + input := `{ + "type": "object", + "properties": { + "email": { + "type": "string", + "format": "email" + } + } + }` + + result := CleanJSONSchemaForAntigravity(input) + + // format should be removed + if strings.Contains(result, `"format"`) { + t.Errorf("format field should be removed, got: %s", result) + } + // hint should be added + if !strings.Contains(result, "format: email") { + t.Errorf("format hint should be added, got: %s", result) + } +} + +func TestCleanJSONSchemaForAntigravity_MultipleFormats(t *testing.T) { + // Multiple format fields should all be handled + input := `{ + "type": "object", + "properties": { + "url": {"type": "string", "format": "uri"}, + "email": {"type": "string", "format": "email"}, + "date": {"type": "string", "format": 
"date-time"} + } + }` + + result := CleanJSONSchemaForAntigravity(input) + + // All format fields should be removed + if strings.Contains(result, `"format"`) { + t.Errorf("All format fields should be removed, got: %s", result) + } + // All hints should be added + if !strings.Contains(result, "format: uri") { + t.Errorf("uri format hint should be added, got: %s", result) + } + if !strings.Contains(result, "format: email") { + t.Errorf("email format hint should be added, got: %s", result) + } + if !strings.Contains(result, "format: date-time") { + t.Errorf("date-time format hint should be added, got: %s", result) + } +} From b6ba15fcbd2c163555f439170790eecd76196513 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Fri, 19 Dec 2025 10:27:24 +0900 Subject: [PATCH 07/38] fix(runtime/executor): Antigravity executor schema handling and Claude-specific headers --- .../runtime/executor/antigravity_executor.go | 36 ++++++++++--------- .../claude/antigravity_claude_request.go | 21 +++++++++++ 2 files changed, 41 insertions(+), 16 deletions(-) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 8b4e37ee..6be5bf46 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -70,10 +70,6 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au // Execute performs a non-streaming request to the Antigravity API. 
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { - if strings.Contains(req.Model, "claude") { - return e.executeClaudeNonStream(ctx, auth, req, opts) - } - token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth) if errToken != nil { return resp, errToken @@ -997,21 +993,23 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau payload = geminiToAntigravity(modelName, payload, projectID) payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName)) - if strings.Contains(modelName, "claude") { - strJSON := string(payload) - paths := make([]string, 0) - util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths) - for _, p := range paths { - strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters") - } + // Apply schema processing for all Antigravity models (Claude, Gemini, GPT-OSS) + // Antigravity uses unified Gemini-style format with same schema restrictions + strJSON := string(payload) - // Use the centralized schema cleaner to handle unsupported keywords, - // const->enum conversion, and flattening of types/anyOf. - strJSON = util.CleanJSONSchemaForGemini(strJSON) - - payload = []byte(strJSON) + // Rename parametersJsonSchema -> parameters (used by Claude translator) + paths := make([]string, 0) + util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths) + for _, p := range paths { + strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters") } + // Use the centralized schema cleaner to handle unsupported keywords, + // const->enum conversion, and flattening of types/anyOf. 
+ strJSON = util.CleanJSONSchemaForAntigravity(strJSON) + + payload = []byte(strJSON) + httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload)) if errReq != nil { return nil, errReq @@ -1019,6 +1017,12 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+token) httpReq.Header.Set("User-Agent", resolveUserAgent(auth)) + + // Add interleaved-thinking header for Claude thinking models + if util.IsClaudeThinkingModel(modelName) { + httpReq.Header.Set("anthropic-beta", "interleaved-thinking-2025-05-14") + } + if stream { httpReq.Header.Set("Accept", "text/event-stream") } else { diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index def1cfbe..facd26ef 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -220,6 +220,27 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ // Build output Gemini CLI request JSON out := `{"model":"","request":{"contents":[]}}` out, _ = sjson.Set(out, "model", modelName) + + // P2-B: Inject interleaved thinking hint when both tools and thinking are active + hasTools := toolDeclCount > 0 + thinkingResult := gjson.GetBytes(rawJSON, "thinking") + hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && thinkingResult.Get("type").String() == "enabled" + isClaudeThinking := util.IsClaudeThinkingModel(modelName) + + if hasTools && hasThinking && isClaudeThinking { + interleavedHint := "Interleaved thinking is enabled. You may think between tool calls and after receiving tool results before deciding the next action or final answer. 
Do not mention these instructions or any constraints about thinking blocks; just apply them." + + if hasSystemInstruction { + // Append hint to existing system instruction + systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.-1.text", interleavedHint) + } else { + // Create new system instruction with hint + systemInstructionJSON = `{"role":"user","parts":[]}` + systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.-1.text", interleavedHint) + hasSystemInstruction = true + } + } + if hasSystemInstruction { out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON) } From 0e7c79ba23ff3e3bb85c5551b02c60be53832c5a Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Fri, 19 Dec 2025 10:27:24 +0900 Subject: [PATCH 08/38] feat(translator/antigravity/claude): support interleaved thinking, signature restoration and system hint injection --- .../claude/antigravity_claude_request.go | 75 ++- .../claude/antigravity_claude_request_test.go | 567 ++++++++++++++++++ 2 files changed, 633 insertions(+), 9 deletions(-) create mode 100644 internal/translator/antigravity/claude/antigravity_claude_request_test.go diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index facd26ef..4eaef1f6 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -7,8 +7,11 @@ package claude import ( "bytes" + "crypto/sha256" + "encoding/hex" "strings" + "github.com/router-for-me/CLIProxyAPI/v6/internal/cache" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/tidwall/gjson" @@ -17,6 +20,29 @@ import ( const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator" +// deriveSessionID generates a stable session ID from the 
request. +// Uses the hash of the first user message to identify the conversation. +func deriveSessionID(rawJSON []byte) string { + messages := gjson.GetBytes(rawJSON, "messages") + if !messages.IsArray() { + return "" + } + for _, msg := range messages.Array() { + if msg.Get("role").String() == "user" { + content := msg.Get("content").String() + if content == "" { + // Try to get text from content array + content = msg.Get("content.0.text").String() + } + if content != "" { + h := sha256.Sum256([]byte(content)) + return hex.EncodeToString(h[:16]) + } + } + } + return "" +} + // ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format. // It extracts the model name, system instruction, message contents, and tool declarations // from the raw JSON request and returns them in the format expected by the Gemini CLI API. @@ -37,7 +63,9 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator" // - []byte: The transformed request data in Gemini CLI API format func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte { rawJSON := bytes.Clone(inputRawJSON) - rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1) + + // Derive session ID for signature caching + sessionID := deriveSessionID(rawJSON) // system instruction systemInstructionJSON := "" @@ -67,13 +95,15 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ messagesResult := gjson.GetBytes(rawJSON, "messages") if messagesResult.IsArray() { messageResults := messagesResult.Array() - for i := 0; i < len(messageResults); i++ { + numMessages := len(messageResults) + for i := 0; i < numMessages; i++ { messageResult := messageResults[i] roleResult := messageResult.Get("role") if roleResult.Type != gjson.String { continue } - role := roleResult.String() + originalRole := roleResult.String() + role := originalRole if role 
== "assistant" { role = "model" } @@ -82,20 +112,47 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ contentsResult := messageResult.Get("content") if contentsResult.IsArray() { contentResults := contentsResult.Array() - for j := 0; j < len(contentResults); j++ { + numContents := len(contentResults) + for j := 0; j < numContents; j++ { contentResult := contentResults[j] contentTypeResult := contentResult.Get("type") if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" { - prompt := contentResult.Get("thinking").String() + thinkingText := contentResult.Get("thinking").String() signatureResult := contentResult.Get("signature") - signature := geminiCLIClaudeThoughtSignature - if signatureResult.Exists() { + signature := "" + if signatureResult.Exists() && signatureResult.String() != "" { signature = signatureResult.String() } + + // P3: Try to restore signature from cache for unsigned thinking blocks + if !cache.HasValidSignature(signature) && sessionID != "" && thinkingText != "" { + if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" { + signature = cachedSig + log.Debugf("Restored cached signature for thinking block") + } + } + + // P2-A: Skip trailing unsigned thinking blocks on last assistant message + isLastMessage := (i == numMessages-1) + isLastContent := (j == numContents-1) + isAssistant := (originalRole == "assistant") + isUnsigned := !cache.HasValidSignature(signature) + + if isLastMessage && isLastContent && isAssistant && isUnsigned { + // Skip this trailing unsigned thinking block + continue + } + + // Apply sentinel for unsigned thinking blocks that are not trailing + // (includes empty string and short/invalid signatures < 50 chars) + if isUnsigned { + signature = geminiCLIClaudeThoughtSignature + } + partJSON := `{}` partJSON, _ = sjson.Set(partJSON, "thought", true) - if prompt != "" { - partJSON, _ = sjson.Set(partJSON, "text", prompt) + if 
thinkingText != "" { + partJSON, _ = sjson.Set(partJSON, "text", thinkingText) } if signature != "" { partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature) diff --git a/internal/translator/antigravity/claude/antigravity_claude_request_test.go b/internal/translator/antigravity/claude/antigravity_claude_request_test.go new file mode 100644 index 00000000..a5bfe49b --- /dev/null +++ b/internal/translator/antigravity/claude/antigravity_claude_request_test.go @@ -0,0 +1,567 @@ +package claude + +import ( + "strings" + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertClaudeRequestToAntigravity_BasicStructure(t *testing.T) { + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "user", + "content": [ + {"type": "text", "text": "Hello"} + ] + } + ], + "system": [ + {"type": "text", "text": "You are helpful"} + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + // Check model + if gjson.Get(outputStr, "model").String() != "claude-sonnet-4-5" { + t.Errorf("Expected model 'claude-sonnet-4-5', got '%s'", gjson.Get(outputStr, "model").String()) + } + + // Check contents exist + contents := gjson.Get(outputStr, "request.contents") + if !contents.Exists() || !contents.IsArray() { + t.Error("request.contents should exist and be an array") + } + + // Check role mapping (assistant -> model) + firstContent := gjson.Get(outputStr, "request.contents.0") + if firstContent.Get("role").String() != "user" { + t.Errorf("Expected role 'user', got '%s'", firstContent.Get("role").String()) + } + + // Check systemInstruction + sysInstruction := gjson.Get(outputStr, "request.systemInstruction") + if !sysInstruction.Exists() { + t.Error("systemInstruction should exist") + } + if sysInstruction.Get("parts.0.text").String() != "You are helpful" { + t.Error("systemInstruction text mismatch") + } +} + +func TestConvertClaudeRequestToAntigravity_RoleMapping(t 
*testing.T) { + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "Hi"}]}, + {"role": "assistant", "content": [{"type": "text", "text": "Hello"}]} + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + // assistant should be mapped to model + secondContent := gjson.Get(outputStr, "request.contents.1") + if secondContent.Get("role").String() != "model" { + t.Errorf("Expected role 'model' (mapped from 'assistant'), got '%s'", secondContent.Get("role").String()) + } +} + +func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) { + // Valid signature must be at least 50 characters + validSignature := "abc123validSignature1234567890123456789012345678901234567890" + inputJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"}, + {"type": "text", "text": "Answer"} + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) + outputStr := string(output) + + // Check thinking block conversion + firstPart := gjson.Get(outputStr, "request.contents.0.parts.0") + if !firstPart.Get("thought").Bool() { + t.Error("thinking block should have thought: true") + } + if firstPart.Get("text").String() != "Let me think..." 
{ + t.Error("thinking text mismatch") + } + if firstPart.Get("thoughtSignature").String() != validSignature { + t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, firstPart.Get("thoughtSignature").String()) + } +} + +func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) { + inputJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "Let me think..."}, + {"type": "text", "text": "Answer"} + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) + outputStr := string(output) + + // Without signature, should use sentinel value + firstPart := gjson.Get(outputStr, "request.contents.0.parts.0") + if firstPart.Get("thoughtSignature").String() != geminiCLIClaudeThoughtSignature { + t.Errorf("Expected sentinel signature '%s', got '%s'", + geminiCLIClaudeThoughtSignature, firstPart.Get("thoughtSignature").String()) + } +} + +func TestConvertClaudeRequestToAntigravity_ToolDeclarations(t *testing.T) { + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [], + "tools": [ + { + "name": "test_tool", + "description": "A test tool", + "input_schema": { + "type": "object", + "properties": { + "name": {"type": "string"} + }, + "required": ["name"] + } + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("gemini-1.5-pro", inputJSON, false) + outputStr := string(output) + + // Check tools structure + tools := gjson.Get(outputStr, "request.tools") + if !tools.Exists() { + t.Error("Tools should exist in output") + } + + funcDecl := gjson.Get(outputStr, "request.tools.0.functionDeclarations.0") + if funcDecl.Get("name").String() != "test_tool" { + t.Errorf("Expected tool name 'test_tool', got '%s'", funcDecl.Get("name").String()) + } + + // Check input_schema renamed to parametersJsonSchema + if funcDecl.Get("parametersJsonSchema").Exists() { + 
t.Log("parametersJsonSchema exists (expected)") + } + if funcDecl.Get("input_schema").Exists() { + t.Error("input_schema should be removed") + } +} + +func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) { + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "assistant", + "content": [ + { + "type": "tool_use", + "id": "call_123", + "name": "get_weather", + "input": "{\"location\": \"Paris\"}" + } + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + // Check function call conversion + funcCall := gjson.Get(outputStr, "request.contents.0.parts.0.functionCall") + if !funcCall.Exists() { + t.Error("functionCall should exist") + } + if funcCall.Get("name").String() != "get_weather" { + t.Errorf("Expected function name 'get_weather', got '%s'", funcCall.Get("name").String()) + } + if funcCall.Get("id").String() != "call_123" { + t.Errorf("Expected function id 'call_123', got '%s'", funcCall.Get("id").String()) + } +} + +func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) { + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "tool_result", + "tool_use_id": "get_weather-call-123", + "content": "22C sunny" + } + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + // Check function response conversion + funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse") + if !funcResp.Exists() { + t.Error("functionResponse should exist") + } + if funcResp.Get("id").String() != "get_weather-call-123" { + t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String()) + } +} + +func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) { + // Note: This test requires the model to be registered in the registry + // with 
Thinking metadata. If the registry is not populated in test environment, + // thinkingConfig won't be added. We'll test the basic structure only. + inputJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", + "messages": [], + "thinking": { + "type": "enabled", + "budget_tokens": 8000 + } + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) + outputStr := string(output) + + // Check thinking config conversion (only if model supports thinking in registry) + thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig") + if thinkingConfig.Exists() { + if thinkingConfig.Get("thinkingBudget").Int() != 8000 { + t.Errorf("Expected thinkingBudget 8000, got %d", thinkingConfig.Get("thinkingBudget").Int()) + } + if !thinkingConfig.Get("include_thoughts").Bool() { + t.Error("include_thoughts should be true") + } + } else { + t.Log("thinkingConfig not present - model may not be registered in test registry") + } +} + +func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) { + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [ + { + "role": "user", + "content": [ + { + "type": "image", + "source": { + "type": "base64", + "media_type": "image/png", + "data": "iVBORw0KGgoAAAANSUhEUg==" + } + } + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + // Check inline data conversion + inlineData := gjson.Get(outputStr, "request.contents.0.parts.0.inlineData") + if !inlineData.Exists() { + t.Error("inlineData should exist") + } + if inlineData.Get("mime_type").String() != "image/png" { + t.Error("mime_type mismatch") + } + if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") { + t.Error("data mismatch") + } +} + +func TestConvertClaudeRequestToAntigravity_GenerationConfig(t *testing.T) { + inputJSON := []byte(`{ + "model": "claude-3-5-sonnet-20240620", + "messages": [], 
+ "temperature": 0.7, + "top_p": 0.9, + "top_k": 40, + "max_tokens": 2000 + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + genConfig := gjson.Get(outputStr, "request.generationConfig") + if genConfig.Get("temperature").Float() != 0.7 { + t.Errorf("Expected temperature 0.7, got %f", genConfig.Get("temperature").Float()) + } + if genConfig.Get("topP").Float() != 0.9 { + t.Errorf("Expected topP 0.9, got %f", genConfig.Get("topP").Float()) + } + if genConfig.Get("topK").Float() != 40 { + t.Errorf("Expected topK 40, got %f", genConfig.Get("topK").Float()) + } + if genConfig.Get("maxOutputTokens").Float() != 2000 { + t.Errorf("Expected maxOutputTokens 2000, got %f", genConfig.Get("maxOutputTokens").Float()) + } +} + +// ============================================================================ +// P2-A: Trailing Unsigned Thinking Block Removal +// ============================================================================ + +func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *testing.T) { + // Last assistant message ends with unsigned thinking block - should be removed + inputJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Here is my answer"}, + {"type": "thinking", "thinking": "I should think more..."} + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) + outputStr := string(output) + + // The last part of the last assistant message should NOT be a thinking block + lastMessageParts := gjson.Get(outputStr, "request.contents.1.parts") + if !lastMessageParts.IsArray() { + t.Fatal("Last message should have parts array") + } + parts := lastMessageParts.Array() + if len(parts) == 0 { + t.Fatal("Last message should have at least one part") 
+ } + + // The unsigned thinking should be removed, leaving only the text + lastPart := parts[len(parts)-1] + if lastPart.Get("thought").Bool() { + t.Error("Trailing unsigned thinking block should be removed") + } +} + +func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) { + // Last assistant message ends with signed thinking block - should be kept + inputJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", + "messages": [ + { + "role": "user", + "content": [{"type": "text", "text": "Hello"}] + }, + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Here is my answer"}, + {"type": "thinking", "thinking": "Valid thinking...", "signature": "abc123validSignature1234567890123456789012345678901234567890"} + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) + outputStr := string(output) + + // The signed thinking block should be preserved + lastMessageParts := gjson.Get(outputStr, "request.contents.1.parts") + parts := lastMessageParts.Array() + if len(parts) < 2 { + t.Error("Signed thinking block should be preserved") + } +} + +func TestConvertClaudeRequestToAntigravity_MiddleUnsignedThinking_SentinelApplied(t *testing.T) { + // Middle message has unsigned thinking - should use sentinel (existing behavior) + inputJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "Middle thinking..."}, + {"type": "text", "text": "Answer"} + ] + }, + { + "role": "user", + "content": [{"type": "text", "text": "Follow up"}] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) + outputStr := string(output) + + // Middle unsigned thinking should have sentinel applied + thinkingPart := gjson.Get(outputStr, "request.contents.0.parts.0") + if !thinkingPart.Get("thought").Bool() { + t.Error("Middle thinking block should be 
preserved with sentinel") + } + if thinkingPart.Get("thoughtSignature").String() != geminiCLIClaudeThoughtSignature { + t.Errorf("Middle unsigned thinking should use sentinel signature, got: %s", thinkingPart.Get("thoughtSignature").String()) + } +} + +// ============================================================================ +// P2-B: Tool + Thinking System Hint Injection +// ============================================================================ + +func TestConvertClaudeRequestToAntigravity_ToolAndThinking_HintInjected(t *testing.T) { + // When both tools and thinking are enabled, hint should be injected into system instruction + inputJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", + "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], + "system": [{"type": "text", "text": "You are helpful."}], + "tools": [ + { + "name": "get_weather", + "description": "Get weather", + "input_schema": {"type": "object", "properties": {"location": {"type": "string"}}} + } + ], + "thinking": {"type": "enabled", "budget_tokens": 8000} + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) + outputStr := string(output) + + // System instruction should contain the interleaved thinking hint + sysInstruction := gjson.Get(outputStr, "request.systemInstruction") + if !sysInstruction.Exists() { + t.Fatal("systemInstruction should exist") + } + + // Check if hint is appended + sysText := sysInstruction.Get("parts").Array() + found := false + for _, part := range sysText { + if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") { + found = true + break + } + } + if !found { + t.Errorf("Interleaved thinking hint should be injected when tools and thinking are both active, got: %v", sysInstruction.Raw) + } +} + +func TestConvertClaudeRequestToAntigravity_ToolsOnly_NoHint(t *testing.T) { + // When only tools are present (no thinking), hint should NOT be injected + inputJSON 
:= []byte(`{ + "model": "claude-sonnet-4-5", + "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], + "system": [{"type": "text", "text": "You are helpful."}], + "tools": [ + { + "name": "get_weather", + "description": "Get weather", + "input_schema": {"type": "object", "properties": {"location": {"type": "string"}}} + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) + outputStr := string(output) + + // System instruction should NOT contain the hint + sysInstruction := gjson.Get(outputStr, "request.systemInstruction") + if sysInstruction.Exists() { + for _, part := range sysInstruction.Get("parts").Array() { + if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") { + t.Error("Hint should NOT be injected when only tools are present (no thinking)") + } + } + } +} + +func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) { + // When only thinking is enabled (no tools), hint should NOT be injected + inputJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", + "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], + "system": [{"type": "text", "text": "You are helpful."}], + "thinking": {"type": "enabled", "budget_tokens": 8000} + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) + outputStr := string(output) + + // System instruction should NOT contain the hint (no tools) + sysInstruction := gjson.Get(outputStr, "request.systemInstruction") + if sysInstruction.Exists() { + for _, part := range sysInstruction.Get("parts").Array() { + if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") { + t.Error("Hint should NOT be injected when only thinking is present (no tools)") + } + } + } +} + +func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) { + // When tools + thinking but no system instruction, should 
create one with hint + inputJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", + "messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}], + "tools": [ + { + "name": "get_weather", + "description": "Get weather", + "input_schema": {"type": "object", "properties": {"location": {"type": "string"}}} + } + ], + "thinking": {"type": "enabled", "budget_tokens": 8000} + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) + outputStr := string(output) + + // System instruction should be created with hint + sysInstruction := gjson.Get(outputStr, "request.systemInstruction") + if !sysInstruction.Exists() { + t.Fatal("systemInstruction should be created when tools + thinking are active") + } + + sysText := sysInstruction.Get("parts").Array() + found := false + for _, part := range sysText { + if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") { + found = true + break + } + } + if !found { + t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw) + } +} From 98fa2a15971d535bdee3412857ab20bec7c03bd4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Fri, 19 Dec 2025 10:27:24 +0900 Subject: [PATCH 09/38] feat(translator/antigravity/claude): support interleaved thinking, signature restoration and system hint injection --- .../claude/antigravity_claude_response.go | 43 ++ .../antigravity_claude_response_test.go | 389 ++++++++++++++++++ 2 files changed, 432 insertions(+) create mode 100644 internal/translator/antigravity/claude/antigravity_claude_response_test.go diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index 52fc358e..d26a1c9f 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ 
-9,11 +9,16 @@ package claude import ( "bytes" "context" + "crypto/sha256" + "encoding/hex" "fmt" "strings" "sync/atomic" "time" + "github.com/router-for-me/CLIProxyAPI/v6/internal/cache" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -35,6 +40,31 @@ type Params struct { HasSentFinalEvents bool // Indicates if final content/message events have been sent HasToolUse bool // Indicates if tool use was observed in the stream HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output + + // P3: Signature caching support + SessionID string // Session ID derived from request for signature caching + CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching +} + +// deriveSessionIDFromRequest generates a stable session ID from the request JSON. +func deriveSessionIDFromRequest(rawJSON []byte) string { + messages := gjson.GetBytes(rawJSON, "messages") + if !messages.IsArray() { + return "" + } + for _, msg := range messages.Array() { + if msg.Get("role").String() == "user" { + content := msg.Get("content").String() + if content == "" { + content = msg.Get("content.0.text").String() + } + if content != "" { + h := sha256.Sum256([]byte(content)) + return hex.EncodeToString(h[:16]) + } + } + } + return "" } // toolUseIDCounter provides a process-wide unique counter for tool use identifiers. 
@@ -62,6 +92,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq HasFirstResponse: false, ResponseType: 0, ResponseIndex: 0, + SessionID: deriveSessionIDFromRequest(originalRequestRawJSON), } } @@ -119,11 +150,20 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq // Process thinking content (internal reasoning) if partResult.Get("thought").Bool() { if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" { + log.Debug("Branch: signature_delta") + + if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 { + cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String()) + log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len()) + params.CurrentThinkingText.Reset() + } + output = output + "event: content_block_delta\n" data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String()) output = output + fmt.Sprintf("data: %s\n\n\n", data) params.HasContent = true } else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state + params.CurrentThinkingText.WriteString(partTextResult.String()) output = output + "event: content_block_delta\n" data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String()) output = output + fmt.Sprintf("data: %s\n\n\n", data) @@ -152,6 +192,9 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq output = output + fmt.Sprintf("data: %s\n\n\n", data) params.ResponseType = 2 // Set state to thinking params.HasContent = true + // P3: Start accumulating thinking text for 
signature caching + params.CurrentThinkingText.Reset() + params.CurrentThinkingText.WriteString(partTextResult.String()) } } else { finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason") diff --git a/internal/translator/antigravity/claude/antigravity_claude_response_test.go b/internal/translator/antigravity/claude/antigravity_claude_response_test.go new file mode 100644 index 00000000..7ffd7666 --- /dev/null +++ b/internal/translator/antigravity/claude/antigravity_claude_response_test.go @@ -0,0 +1,389 @@ +package claude + +import ( + "context" + "strings" + "testing" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/cache" +) + +func TestConvertBashCommandToCmdField(t *testing.T) { + tests := []struct { + name string + input string + expected string + }{ + { + name: "basic command to cmd conversion", + input: `{"command": "git diff"}`, + expected: `{"cmd":"git diff"}`, + }, + { + name: "already has cmd field - no change", + input: `{"cmd": "git diff"}`, + expected: `{"cmd": "git diff"}`, + }, + { + name: "both cmd and command - keep cmd only", + input: `{"command": "git diff", "cmd": "ls"}`, + expected: `{"command": "git diff", "cmd": "ls"}`, // no change when cmd exists + }, + { + name: "command with special characters in value", + input: `{"command": "echo \"command\": test"}`, + expected: `{"cmd":"echo \"command\": test"}`, + }, + { + name: "command with nested quotes", + input: `{"command": "bash -c 'echo \"hello\"'"}`, + expected: `{"cmd":"bash -c 'echo \"hello\"'"}`, + }, + { + name: "command with newlines", + input: `{"command": "echo hello\necho world"}`, + expected: `{"cmd":"echo hello\necho world"}`, + }, + { + name: "empty command value", + input: `{"command": ""}`, + expected: `{"cmd":""}`, + }, + { + name: "command with other fields - preserves them", + input: `{"command": "git diff", "timeout": 30}`, + expected: `{ "timeout": 30,"cmd":"git diff"}`, + }, + { + name: "invalid JSON - returns unchanged", + input: 
`{invalid json`, + expected: `{invalid json`, + }, + { + name: "empty object", + input: `{}`, + expected: `{}`, + }, + { + name: "no command field", + input: `{"restart": true}`, + expected: `{"restart": true}`, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + result := convertBashCommandToCmdField(tt.input) + if result != tt.expected { + t.Errorf("convertBashCommandToCmdField(%q) = %q, want %q", tt.input, result, tt.expected) + } + }) + } +} + +// ============================================================================ +// P3: Signature Caching Tests +// ============================================================================ + +func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) { + cache.ClearSignatureCache("") + + // Request with user message - should derive session ID + requestJSON := []byte(`{ + "messages": [ + {"role": "user", "content": [{"type": "text", "text": "Hello world"}]} + ] + }`) + + // First response chunk with thinking + responseJSON := []byte(`{ + "response": { + "candidates": [{ + "content": { + "parts": [{"text": "Let me think...", "thought": true}] + } + }] + } + }`) + + var param any + ctx := context.Background() + ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, ¶m) + + // Verify session ID was set + params := param.(*Params) + if params.SessionID == "" { + t.Error("SessionID should be derived from request") + } +} + +func TestConvertAntigravityResponseToClaude_ThinkingTextAccumulated(t *testing.T) { + cache.ClearSignatureCache("") + + requestJSON := []byte(`{ + "messages": [{"role": "user", "content": [{"type": "text", "text": "Test"}]}] + }`) + + // First thinking chunk + chunk1 := []byte(`{ + "response": { + "candidates": [{ + "content": { + "parts": [{"text": "First part of thinking...", "thought": true}] + } + }] + } + }`) + + // Second thinking chunk (continuation) + chunk2 := []byte(`{ + "response": { + 
"candidates": [{ + "content": { + "parts": [{"text": " Second part of thinking...", "thought": true}] + } + }] + } + }`) + + var param any + ctx := context.Background() + + // Process first chunk - starts new thinking block + ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk1, ¶m) + params := param.(*Params) + + if params.CurrentThinkingText.Len() == 0 { + t.Error("Thinking text should be accumulated after first chunk") + } + + // Process second chunk - continues thinking block + ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk2, ¶m) + + text := params.CurrentThinkingText.String() + if !strings.Contains(text, "First part") || !strings.Contains(text, "Second part") { + t.Errorf("Thinking text should accumulate both parts, got: %s", text) + } +} + +func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) { + cache.ClearSignatureCache("") + + requestJSON := []byte(`{ + "messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}] + }`) + + // Thinking chunk + thinkingChunk := []byte(`{ + "response": { + "candidates": [{ + "content": { + "parts": [{"text": "My thinking process here", "thought": true}] + } + }] + } + }`) + + // Signature chunk + validSignature := "abc123validSignature1234567890123456789012345678901234567890" + signatureChunk := []byte(`{ + "response": { + "candidates": [{ + "content": { + "parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSignature + `"}] + } + }] + } + }`) + + var param any + ctx := context.Background() + + // Process thinking chunk + ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, ¶m) + params := param.(*Params) + sessionID := params.SessionID + thinkingText := params.CurrentThinkingText.String() + + if sessionID == "" { + t.Fatal("SessionID should be set") + } + if thinkingText == "" { + t.Fatal("Thinking text 
should be accumulated") + } + + // Process signature chunk - should cache the signature + ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, ¶m) + + // Verify signature was cached + cachedSig := cache.GetCachedSignature(sessionID, thinkingText) + if cachedSig != validSignature { + t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig) + } + + // Verify thinking text was reset after caching + if params.CurrentThinkingText.Len() != 0 { + t.Error("Thinking text should be reset after signature is cached") + } +} + +func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T) { + cache.ClearSignatureCache("") + + requestJSON := []byte(`{ + "messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}] + }`) + + validSig1 := "signature1_12345678901234567890123456789012345678901234567" + validSig2 := "signature2_12345678901234567890123456789012345678901234567" + + // First thinking block with signature + block1Thinking := []byte(`{ + "response": { + "candidates": [{ + "content": { + "parts": [{"text": "First thinking block", "thought": true}] + } + }] + } + }`) + block1Sig := []byte(`{ + "response": { + "candidates": [{ + "content": { + "parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig1 + `"}] + } + }] + } + }`) + + // Text content (breaks thinking) + textBlock := []byte(`{ + "response": { + "candidates": [{ + "content": { + "parts": [{"text": "Regular text output"}] + } + }] + } + }`) + + // Second thinking block with signature + block2Thinking := []byte(`{ + "response": { + "candidates": [{ + "content": { + "parts": [{"text": "Second thinking block", "thought": true}] + } + }] + } + }`) + block2Sig := []byte(`{ + "response": { + "candidates": [{ + "content": { + "parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig2 + `"}] + } + }] + } + }`) + + var param any + ctx := context.Background() + + // 
Process first thinking block + ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, ¶m) + params := param.(*Params) + sessionID := params.SessionID + firstThinkingText := params.CurrentThinkingText.String() + + ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, ¶m) + + // Verify first signature cached + if cache.GetCachedSignature(sessionID, firstThinkingText) != validSig1 { + t.Error("First thinking block signature should be cached") + } + + // Process text (transitions out of thinking) + ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, textBlock, ¶m) + + // Process second thinking block + ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Thinking, ¶m) + secondThinkingText := params.CurrentThinkingText.String() + + ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, ¶m) + + // Verify second signature cached + if cache.GetCachedSignature(sessionID, secondThinkingText) != validSig2 { + t.Error("Second thinking block signature should be cached") + } +} + +func TestDeriveSessionIDFromRequest(t *testing.T) { + tests := []struct { + name string + input []byte + wantEmpty bool + }{ + { + name: "valid user message", + input: []byte(`{"messages": [{"role": "user", "content": "Hello"}]}`), + wantEmpty: false, + }, + { + name: "user message with content array", + input: []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}`), + wantEmpty: false, + }, + { + name: "no user message", + input: []byte(`{"messages": [{"role": "assistant", "content": "Hi"}]}`), + wantEmpty: true, + }, + { + name: "empty messages", + input: []byte(`{"messages": []}`), + wantEmpty: true, + }, + { + name: "no messages field", + input: []byte(`{}`), + wantEmpty: true, + }, + } + + for _, tt := range 
tests { + t.Run(tt.name, func(t *testing.T) { + result := deriveSessionIDFromRequest(tt.input) + if tt.wantEmpty && result != "" { + t.Errorf("Expected empty session ID, got '%s'", result) + } + if !tt.wantEmpty && result == "" { + t.Error("Expected non-empty session ID") + } + }) + } +} + +func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) { + input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`) + + id1 := deriveSessionIDFromRequest(input) + id2 := deriveSessionIDFromRequest(input) + + if id1 != id2 { + t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2) + } +} + +func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) { + input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`) + input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`) + + id1 := deriveSessionIDFromRequest(input1) + id2 := deriveSessionIDFromRequest(input2) + + if id1 == id2 { + t.Error("Different messages should produce different session IDs") + } +} From c1f8211acb64ffa473c1236592e462a35d7204a9 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Fri, 19 Dec 2025 11:12:16 +0900 Subject: [PATCH 10/38] fix: Normalize Bash tool args and add signature caching support Normalize Bash tool arguments by converting a "command" key into "cmd" using JSON-aware parsing, avoiding brittle string replacements that could corrupt values. Apply this conversion in both streaming and non-streaming response paths so bash-style tool calls are emitted with the expected "cmd" field. Add support for accumulating thinking text and carrying session identifiers to enable signature caching/restore for unsigned thinking blocks, improving handling of thinking-state continuity across requests/responses. Also perform small cleanups: import logging, tidy comments and test descriptions. 
These changes make tool-argument handling more robust and enable reliable signature restoration for thinking blocks. --- .../claude/antigravity_claude_request.go | 7 +-- .../claude/antigravity_claude_request_test.go | 4 +- .../claude/antigravity_claude_response.go | 48 +++++++++++++++++-- .../antigravity_claude_response_test.go | 2 +- internal/util/gemini_schema_test.go | 4 +- 5 files changed, 53 insertions(+), 12 deletions(-) diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index 4eaef1f6..fdc0f93e 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -14,6 +14,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/cache" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" "github.com/tidwall/sjson" ) @@ -124,7 +125,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ signature = signatureResult.String() } - // P3: Try to restore signature from cache for unsigned thinking blocks + // Try to restore signature from cache for unsigned thinking blocks if !cache.HasValidSignature(signature) && sessionID != "" && thinkingText != "" { if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" { signature = cachedSig @@ -132,7 +133,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ } } - // P2-A: Skip trailing unsigned thinking blocks on last assistant message + // Skip trailing unsigned thinking blocks on last assistant message isLastMessage := (i == numMessages-1) isLastContent := (j == numContents-1) isAssistant := (originalRole == "assistant") @@ -278,7 +279,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, 
inputRawJSON []byte, _ out := `{"model":"","request":{"contents":[]}}` out, _ = sjson.Set(out, "model", modelName) - // P2-B: Inject interleaved thinking hint when both tools and thinking are active + // Inject interleaved thinking hint when both tools and thinking are active hasTools := toolDeclCount > 0 thinkingResult := gjson.GetBytes(rawJSON, "thinking") hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && thinkingResult.Get("type").String() == "enabled" diff --git a/internal/translator/antigravity/claude/antigravity_claude_request_test.go b/internal/translator/antigravity/claude/antigravity_claude_request_test.go index a5bfe49b..796ce0d3 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request_test.go @@ -330,7 +330,7 @@ func TestConvertClaudeRequestToAntigravity_GenerationConfig(t *testing.T) { } // ============================================================================ -// P2-A: Trailing Unsigned Thinking Block Removal +// Trailing Unsigned Thinking Block Removal // ============================================================================ func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *testing.T) { @@ -435,7 +435,7 @@ func TestConvertClaudeRequestToAntigravity_MiddleUnsignedThinking_SentinelApplie } // ============================================================================ -// P2-B: Tool + Thinking System Hint Injection +// Tool + Thinking System Hint Injection // ============================================================================ func TestConvertClaudeRequestToAntigravity_ToolAndThinking_HintInjected(t *testing.T) { diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index d26a1c9f..939551ba 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ 
b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -41,7 +41,7 @@ type Params struct { HasToolUse bool // Indicates if tool use was observed in the stream HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output - // P3: Signature caching support + // Signature caching support SessionID string // Session ID derived from request for signature caching CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching } @@ -192,7 +192,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq output = output + fmt.Sprintf("data: %s\n\n\n", data) params.ResponseType = 2 // Set state to thinking params.HasContent = true - // P3: Start accumulating thinking text for signature caching + // Start accumulating thinking text for signature caching params.CurrentThinkingText.Reset() params.CurrentThinkingText.WriteString(partTextResult.String()) } @@ -276,8 +276,13 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq output = output + fmt.Sprintf("data: %s\n\n\n", data) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { + argsRaw := fcArgsResult.Raw + // Convert command → cmd for Bash tools using proper JSON parsing + if fcName == "Bash" || fcName == "bash" || fcName == "bash_20241022" { + argsRaw = convertBashCommandToCmdField(argsRaw) + } output = output + "event: content_block_delta\n" - data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", fcArgsResult.Raw) + data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", argsRaw) output = output + fmt.Sprintf("data: %s\n\n\n", data) } params.ResponseType = 3 @@ -365,6 +370,36 @@ func resolveStopReason(params *Params) string { return 
"end_turn" } +// convertBashCommandToCmdField converts "command" field to "cmd" field for Bash tools. +// Amp expects "cmd" but Gemini sends "command". This uses proper JSON parsing +// to avoid accidentally replacing "command" that appears in values. +func convertBashCommandToCmdField(argsRaw string) string { + // Only process valid JSON + if !gjson.Valid(argsRaw) { + return argsRaw + } + + // Check if "command" key exists and "cmd" doesn't + commandVal := gjson.Get(argsRaw, "command") + cmdVal := gjson.Get(argsRaw, "cmd") + + if commandVal.Exists() && !cmdVal.Exists() { + // Set "cmd" to the value of "command", preserve the raw value type + result, err := sjson.SetRaw(argsRaw, "cmd", commandVal.Raw) + if err != nil { + return argsRaw + } + // Delete "command" key + result, err = sjson.Delete(result, "command") + if err != nil { + return argsRaw + } + return result + } + + return argsRaw +} + // ConvertAntigravityResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response. 
// // Parameters: @@ -476,7 +511,12 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or toolBlock, _ = sjson.Set(toolBlock, "name", name) if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) { - toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw) + argsRaw := args.Raw + // Convert command → cmd for Bash tools + if name == "Bash" || name == "bash" || name == "bash_20241022" { + argsRaw = convertBashCommandToCmdField(argsRaw) + } + toolBlock, _ = sjson.SetRaw(toolBlock, "input", argsRaw) } ensureContentArray() diff --git a/internal/translator/antigravity/claude/antigravity_claude_response_test.go b/internal/translator/antigravity/claude/antigravity_claude_response_test.go index 7ffd7666..4c2f31c1 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response_test.go @@ -82,7 +82,7 @@ func TestConvertBashCommandToCmdField(t *testing.T) { } // ============================================================================ -// P3: Signature Caching Tests +// Signature Caching Tests // ============================================================================ func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) { diff --git a/internal/util/gemini_schema_test.go b/internal/util/gemini_schema_test.go index 0f9e3eba..69adbcdb 100644 --- a/internal/util/gemini_schema_test.go +++ b/internal/util/gemini_schema_test.go @@ -631,7 +631,7 @@ func compareJSON(t *testing.T, expectedJSON, actualJSON string) { } // ============================================================================ -// P0-1: Empty Schema Placeholder Tests +// Empty Schema Placeholder Tests // ============================================================================ func TestCleanJSONSchemaForAntigravity_EmptySchemaPlaceholder(t *testing.T) { @@ -732,7 +732,7 @@ func 
TestCleanJSONSchemaForAntigravity_EmptySchemaWithDescription(t *testing.T) } // ============================================================================ -// P0-2: Format field handling (ad-hoc patch removal) +// Format field handling (ad-hoc patch removal) // ============================================================================ func TestCleanJSONSchemaForAntigravity_FormatFieldRemoval(t *testing.T) { From 3275494fde6efb71a3b12ee66fa0b5f191f2d81f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Fri, 19 Dec 2025 13:09:36 +0900 Subject: [PATCH 11/38] refactor: Use helper to extract wrapped "thinking" text Improve robustness when handling "thinking" content by using a dedicated helper to extract the thinking text. This ensures wrapped or nested thinking objects are handled correctly instead of relying on a direct string extraction, reducing parsing errors for complex payloads. --- .../claude/antigravity_claude_request.go | 3 +- internal/util/thinking_text.go | 87 +++++++++++++++++++ litellm | 1 + opencode-antigravity-auth | 1 + opencode-google-antigravity-auth | 1 + 5 files changed, 92 insertions(+), 1 deletion(-) create mode 100644 internal/util/thinking_text.go create mode 160000 litellm create mode 160000 opencode-antigravity-auth create mode 160000 opencode-google-antigravity-auth diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index fdc0f93e..fdfdf469 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -118,7 +118,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ contentResult := contentResults[j] contentTypeResult := contentResult.Get("type") if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" { - thinkingText := 
contentResult.Get("thinking").String() + // Use GetThinkingText to handle wrapped thinking objects + thinkingText := util.GetThinkingText(contentResult) signatureResult := contentResult.Get("signature") signature := "" if signatureResult.Exists() && signatureResult.String() != "" { diff --git a/internal/util/thinking_text.go b/internal/util/thinking_text.go new file mode 100644 index 00000000..c36d202d --- /dev/null +++ b/internal/util/thinking_text.go @@ -0,0 +1,87 @@ +package util + +import ( + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +// GetThinkingText extracts the thinking text from a content part. +// Handles various formats: +// - Simple string: { "thinking": "text" } or { "text": "text" } +// - Wrapped object: { "thinking": { "text": "text", "cache_control": {...} } } +// - Gemini-style: { "thought": true, "text": "text" } +// Returns the extracted text string. +func GetThinkingText(part gjson.Result) string { + // Try direct text field first (Gemini-style) + if text := part.Get("text"); text.Exists() && text.Type == gjson.String { + return text.String() + } + + // Try thinking field + thinkingField := part.Get("thinking") + if !thinkingField.Exists() { + return "" + } + + // thinking is a string + if thinkingField.Type == gjson.String { + return thinkingField.String() + } + + // thinking is an object with inner text/thinking + if thinkingField.IsObject() { + if inner := thinkingField.Get("text"); inner.Exists() && inner.Type == gjson.String { + return inner.String() + } + if inner := thinkingField.Get("thinking"); inner.Exists() && inner.Type == gjson.String { + return inner.String() + } + } + + return "" +} + +// GetThinkingTextFromJSON extracts thinking text from a raw JSON string. +func GetThinkingTextFromJSON(jsonStr string) string { + return GetThinkingText(gjson.Parse(jsonStr)) +} + +// SanitizeThinkingPart normalizes a thinking part to a canonical form. +// Strips cache_control and other non-essential fields. 
+// Returns the sanitized part as JSON string. +func SanitizeThinkingPart(part gjson.Result) string { + // Gemini-style: { thought: true, text, thoughtSignature } + if part.Get("thought").Bool() { + result := `{"thought":true}` + if text := GetThinkingText(part); text != "" { + result, _ = sjson.Set(result, "text", text) + } + if sig := part.Get("thoughtSignature"); sig.Exists() && sig.Type == gjson.String { + result, _ = sjson.Set(result, "thoughtSignature", sig.String()) + } + return result + } + + // Anthropic-style: { type: "thinking", thinking, signature } + if part.Get("type").String() == "thinking" || part.Get("thinking").Exists() { + result := `{"type":"thinking"}` + if text := GetThinkingText(part); text != "" { + result, _ = sjson.Set(result, "thinking", text) + } + if sig := part.Get("signature"); sig.Exists() && sig.Type == gjson.String { + result, _ = sjson.Set(result, "signature", sig.String()) + } + return result + } + + // Not a thinking part, return as-is but strip cache_control + return StripCacheControl(part.Raw) +} + +// StripCacheControl removes cache_control and providerOptions from a JSON object. 
+func StripCacheControl(jsonStr string) string { + result := jsonStr + result, _ = sjson.Delete(result, "cache_control") + result, _ = sjson.Delete(result, "providerOptions") + return result +} diff --git a/litellm b/litellm new file mode 160000 index 00000000..0c48826c --- /dev/null +++ b/litellm @@ -0,0 +1 @@ +Subproject commit 0c48826cdc14a30953f55173c1eecdbfc859952d diff --git a/opencode-antigravity-auth b/opencode-antigravity-auth new file mode 160000 index 00000000..261a91f2 --- /dev/null +++ b/opencode-antigravity-auth @@ -0,0 +1 @@ +Subproject commit 261a91f21bd3bc1660168eb2b82301a6cf372e58 diff --git a/opencode-google-antigravity-auth b/opencode-google-antigravity-auth new file mode 160000 index 00000000..9f9493c7 --- /dev/null +++ b/opencode-google-antigravity-auth @@ -0,0 +1 @@ +Subproject commit 9f9493c730cbf0f17429e107a5bde00794752175 From e04b02113a2eaf39d20aecb8375c40f10a3de80f Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Fri, 19 Dec 2025 13:14:51 +0900 Subject: [PATCH 12/38] refactor: Improve cache eviction ordering and clean up session ID usage Improve the cache eviction routine to sort entries by timestamp using the standard library sort routine (clearer and faster than the prior manual selection/bubble logic; note that sort.Slice is not guaranteed to be stable, which is acceptable here since entries with equal timestamps are interchangeable eviction candidates), and remove a redundant request-derived session ID helper in favor of the centralized session ID function. Also drop now-unused crypto/encoding imports. This yields clearer, more maintainable eviction logic and removes duplicated/unused code and imports to reduce surface area and potential inconsistencies.
--- internal/cache/signature_cache.go | 16 ++++++------ .../claude/antigravity_claude_response.go | 25 +------------------ 2 files changed, 8 insertions(+), 33 deletions(-) diff --git a/internal/cache/signature_cache.go b/internal/cache/signature_cache.go index 12f19cf0..c1326629 100644 --- a/internal/cache/signature_cache.go +++ b/internal/cache/signature_cache.go @@ -3,6 +3,7 @@ package cache import ( "crypto/sha256" "encoding/hex" + "sort" "sync" "time" ) @@ -89,20 +90,17 @@ func CacheSignature(sessionID, text, signature string) { ts time.Time }{key, entry.Timestamp}) } - // Simple approach: remove first quarter of entries + // Sort by timestamp (oldest first) using sort.Slice + sort.Slice(oldest, func(i, j int) bool { + return oldest[i].ts.Before(oldest[j].ts) + }) + toRemove := len(oldest) / 4 if toRemove < 1 { toRemove = 1 } - // Sort by timestamp (oldest first) - simple bubble for small N + for i := 0; i < toRemove; i++ { - minIdx := i - for j := i + 1; j < len(oldest); j++ { - if oldest[j].ts.Before(oldest[minIdx].ts) { - minIdx = j - } - } - oldest[i], oldest[minIdx] = oldest[minIdx], oldest[i] delete(sc.entries, oldest[i].key) } } diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index 939551ba..8f47b9bf 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -9,8 +9,6 @@ package claude import ( "bytes" "context" - "crypto/sha256" - "encoding/hex" "fmt" "strings" "sync/atomic" @@ -46,27 +44,6 @@ type Params struct { CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching } -// deriveSessionIDFromRequest generates a stable session ID from the request JSON. 
-func deriveSessionIDFromRequest(rawJSON []byte) string { - messages := gjson.GetBytes(rawJSON, "messages") - if !messages.IsArray() { - return "" - } - for _, msg := range messages.Array() { - if msg.Get("role").String() == "user" { - content := msg.Get("content").String() - if content == "" { - content = msg.Get("content.0.text").String() - } - if content != "" { - h := sha256.Sum256([]byte(content)) - return hex.EncodeToString(h[:16]) - } - } - } - return "" -} - // toolUseIDCounter provides a process-wide unique counter for tool use identifiers. var toolUseIDCounter uint64 @@ -92,7 +69,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq HasFirstResponse: false, ResponseType: 0, ResponseIndex: 0, - SessionID: deriveSessionIDFromRequest(originalRequestRawJSON), + SessionID: deriveSessionID(originalRequestRawJSON), } } From 1b358c931cc84c5a8786fd2f6816ca0181637031 Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Fri, 19 Dec 2025 12:15:22 +0800 Subject: [PATCH 13/38] fix: restore get-auth-status ok fallback and document it --- internal/api/handlers/management/auth_files.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 433bae92..bf5a5b9c 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -2182,7 +2182,7 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec func (h *Handler) GetAuthStatus(c *gin.Context) { state := strings.TrimSpace(c.Query("state")) if state == "" { - c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"}) + c.JSON(http.StatusOK, gin.H{"status": "ok"}) return } if err := ValidateOAuthState(state); err != nil { From 13aa82f3f3bb2a8d8b6ce4797ec88c0c19b8c5f1 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Fri, 19 Dec 2025 13:11:15 +0800 
Subject: [PATCH 14/38] fix(util): disable default thinking for gemini 3 flash --- internal/util/gemini_thinking.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/internal/util/gemini_thinking.go b/internal/util/gemini_thinking.go index af244b60..5a5a29a9 100644 --- a/internal/util/gemini_thinking.go +++ b/internal/util/gemini_thinking.go @@ -242,7 +242,7 @@ func ThinkingBudgetToGemini3Level(model string, budget int) (string, bool) { var modelsWithDefaultThinking = map[string]bool{ "gemini-3-pro-preview": true, "gemini-3-pro-image-preview": true, - "gemini-3-flash-preview": true, + // "gemini-3-flash-preview": true, } // ModelHasDefaultThinking returns true if the model should have thinking enabled by default. From 9d9b9e7a0d9b2dbd41f8c5b911ffbdbddad376d8 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Fri, 19 Dec 2025 13:57:47 +0800 Subject: [PATCH 15/38] fix(amp): add management auth skipper --- internal/api/modules/amp/routes.go | 20 ++++++++++++++++++-- 1 file changed, 18 insertions(+), 2 deletions(-) diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 0abd943a..911d2b7d 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -95,6 +95,20 @@ func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc { } } +// wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere. +func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc { + return func(c *gin.Context) { + path := c.Request.URL.Path + for _, prefix := range prefixes { + if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') { + c.Next() + return + } + } + auth(c) + } +} + // registerManagementRoutes registers Amp management proxy routes // These routes proxy through to the Amp control plane for OAuth, user management, etc. 
// Uses dynamic middleware and proxy getter for hot-reload support. @@ -109,8 +123,10 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha ampAPI.Use(m.localhostOnlyMiddleware()) // Apply authentication middleware - requires valid API key in Authorization header + var authWithBypass gin.HandlerFunc if auth != nil { ampAPI.Use(auth) + authWithBypass = wrapManagementAuth(auth, "/threads", "/auth") } // Dynamic proxy handler that uses m.getProxy() for hot-reload support @@ -156,8 +172,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha // Root-level routes that AMP CLI expects without /api prefix // These need the same security middleware as the /api/* routes (dynamic for hot-reload) rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()} - if auth != nil { - rootMiddleware = append(rootMiddleware, auth) + if authWithBypass != nil { + rootMiddleware = append(rootMiddleware, authWithBypass) } engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...) engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) 
From 20390628451cc9e26dade50f820a59a36e2a4a9a Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Fri, 19 Dec 2025 22:07:43 +0800 Subject: [PATCH 16/38] fix(gemini): add optional skip for gemini3 thinking conversion --- .../runtime/executor/aistudio_executor.go | 4 +- internal/util/gemini_thinking.go | 59 ++++++------------- internal/util/thinking.go | 28 +++++++++ 3 files changed, 49 insertions(+), 42 deletions(-) diff --git a/internal/runtime/executor/aistudio_executor.go b/internal/runtime/executor/aistudio_executor.go index cac23c87..17c8170f 100644 --- a/internal/runtime/executor/aistudio_executor.go +++ b/internal/runtime/executor/aistudio_executor.go @@ -325,8 +325,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model) payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload) payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload) - payload = util.ConvertThinkingLevelToBudget(payload, req.Model) - payload = util.NormalizeGeminiThinkingBudget(req.Model, payload) + payload = util.ConvertThinkingLevelToBudget(payload, req.Model, true) + payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true) payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) payload = fixGeminiImageAspectRatio(req.Model, payload) payload = applyPayloadConfig(e.cfg, req.Model, payload) diff --git a/internal/util/gemini_thinking.go b/internal/util/gemini_thinking.go index 5a5a29a9..ba9e13ef 100644 --- a/internal/util/gemini_thinking.go +++ b/internal/util/gemini_thinking.go @@ -352,8 +352,9 @@ func StripThinkingConfigIfUnsupported(model string, body []byte) []byte { // NormalizeGeminiThinkingBudget normalizes the thinkingBudget value in a standard Gemini // request body (generationConfig.thinkingConfig.thinkingBudget path). 
-// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation. -func NormalizeGeminiThinkingBudget(model string, body []byte) []byte { +// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation, +// unless skipGemini3Check is provided and true. +func NormalizeGeminiThinkingBudget(model string, body []byte, skipGemini3Check ...bool) []byte { const budgetPath = "generationConfig.thinkingConfig.thinkingBudget" const levelPath = "generationConfig.thinkingConfig.thinkingLevel" @@ -363,7 +364,8 @@ func NormalizeGeminiThinkingBudget(model string, body []byte) []byte { } // For Gemini 3 models, convert thinkingBudget to thinkingLevel - if IsGemini3Model(model) { + skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0] + if IsGemini3Model(model) && !skipGemini3 { if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok { updated, _ := sjson.SetBytes(body, levelPath, level) updated, _ = sjson.DeleteBytes(updated, budgetPath) @@ -382,8 +384,9 @@ func NormalizeGeminiThinkingBudget(model string, body []byte) []byte { // NormalizeGeminiCLIThinkingBudget normalizes the thinkingBudget value in a Gemini CLI // request body (request.generationConfig.thinkingConfig.thinkingBudget path). -// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation. -func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte { +// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation, +// unless skipGemini3Check is provided and true. 
+func NormalizeGeminiCLIThinkingBudget(model string, body []byte, skipGemini3Check ...bool) []byte { const budgetPath = "request.generationConfig.thinkingConfig.thinkingBudget" const levelPath = "request.generationConfig.thinkingConfig.thinkingLevel" @@ -393,7 +396,8 @@ func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte { } // For Gemini 3 models, convert thinkingBudget to thinkingLevel - if IsGemini3Model(model) { + skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0] + if IsGemini3Model(model) && !skipGemini3 { if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok { updated, _ := sjson.SetBytes(body, levelPath, level) updated, _ = sjson.DeleteBytes(updated, budgetPath) @@ -477,7 +481,7 @@ func ApplyReasoningEffortToGeminiCLI(body []byte, effort string) []byte { // ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel" // and converts it to "thinkingBudget" for Gemini 2.5 models. -// For Gemini 3 models, preserves thinkingLevel as-is (does not convert). +// For Gemini 3 models, preserves thinkingLevel unless skipGemini3Check is provided and true. // Mappings for Gemini 2.5: // - "high" -> 32768 // - "medium" -> 8192 @@ -485,43 +489,31 @@ func ApplyReasoningEffortToGeminiCLI(body []byte, effort string) []byte { // - "minimal" -> 512 // // It removes "thinkingLevel" after conversion (for Gemini 2.5 only). 
-func ConvertThinkingLevelToBudget(body []byte, model string) []byte { +func ConvertThinkingLevelToBudget(body []byte, model string, skipGemini3Check ...bool) []byte { levelPath := "generationConfig.thinkingConfig.thinkingLevel" res := gjson.GetBytes(body, levelPath) if !res.Exists() { return body } - // For Gemini 3 models, preserve thinkingLevel - don't convert to budget - if IsGemini3Model(model) { + // For Gemini 3 models, preserve thinkingLevel unless explicitly skipped + skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0] + if IsGemini3Model(model) && !skipGemini3 { return body } - level := strings.ToLower(res.String()) - var budget int - switch level { - case "high": - budget = 32768 - case "medium": - budget = 8192 - case "low": - budget = 1024 - case "minimal": - budget = 512 - default: - // Unknown level - remove it and let the API use defaults + budget, ok := ThinkingLevelToBudget(res.String()) + if !ok { updated, _ := sjson.DeleteBytes(body, levelPath) return updated } - // Set budget budgetPath := "generationConfig.thinkingConfig.thinkingBudget" updated, err := sjson.SetBytes(body, budgetPath, budget) if err != nil { return body } - // Remove level updated, err = sjson.DeleteBytes(updated, levelPath) if err != nil { return body @@ -544,31 +536,18 @@ func ConvertThinkingLevelToBudgetCLI(body []byte, model string) []byte { return body } - level := strings.ToLower(res.String()) - var budget int - switch level { - case "high": - budget = 32768 - case "medium": - budget = 8192 - case "low": - budget = 1024 - case "minimal": - budget = 512 - default: - // Unknown level - remove it and let the API use defaults + budget, ok := ThinkingLevelToBudget(res.String()) + if !ok { updated, _ := sjson.DeleteBytes(body, levelPath) return updated } - // Set budget budgetPath := "request.generationConfig.thinkingConfig.thinkingBudget" updated, err := sjson.SetBytes(body, budgetPath, budget) if err != nil { return body } - // Remove level updated, err = 
sjson.DeleteBytes(updated, levelPath) if err != nil { return body diff --git a/internal/util/thinking.go b/internal/util/thinking.go index 77ec16ba..74808669 100644 --- a/internal/util/thinking.go +++ b/internal/util/thinking.go @@ -160,6 +160,34 @@ func ThinkingEffortToBudget(model, effort string) (int, bool) { } } +// ThinkingLevelToBudget maps a Gemini thinkingLevel to a numeric thinking budget (tokens). +// +// Mappings: +// - "minimal" -> 512 +// - "low" -> 1024 +// - "medium" -> 8192 +// - "high" -> 32768 +// +// Returns false when the level is empty or unsupported. +func ThinkingLevelToBudget(level string) (int, bool) { + if level == "" { + return 0, false + } + normalized := strings.ToLower(strings.TrimSpace(level)) + switch normalized { + case "minimal": + return 512, true + case "low": + return 1024, true + case "medium": + return 8192, true + case "high": + return 32768, true + default: + return 0, false + } +} + // ThinkingBudgetToEffort maps a numeric thinking budget (tokens) // to a reasoning effort level for level-based models. // From d7afb6eb0c9e2330fcd60964ae8d7c483b10556f Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 20 Dec 2025 03:11:28 +0800 Subject: [PATCH 17/38] fix(gemini): improve reasoning effort conversion for Gemini 3 models Refactors the reasoning effort conversion logic for Gemini models. The update specifically addresses how `reasoning_effort` is translated into Gemini 3 specific thinking configurations (`thinkingLevel`, `includeThoughts`) and ensures that numeric budgets are not incorrectly applied to level-based models. Changes include: - Differentiating conversion logic for Gemini 3 models versus other models. - Handling `none`, `auto`, and validated thinking levels for Gemini 3. - Maintaining existing conversion for models not using discrete thinking levels. 
--- .../antigravity_openai_request.go | 19 +++++++++++++-- .../chat-completions/gemini_openai_request.go | 24 +++++++++++++++---- 2 files changed, 37 insertions(+), 6 deletions(-) diff --git a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go index d11afceb..293add0f 100644 --- a/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go +++ b/internal/translator/antigravity/openai/chat-completions/antigravity_openai_request.go @@ -39,8 +39,23 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _ // Note: OpenAI official fields take precedence over extra_body.google.thinking_config re := gjson.GetBytes(rawJSON, "reasoning_effort") hasOfficialThinking := re.Exists() - if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) { - out = util.ApplyReasoningEffortToGeminiCLI(out, re.String()) + if hasOfficialThinking && util.ModelSupportsThinking(modelName) { + effort := strings.ToLower(strings.TrimSpace(re.String())) + if util.IsGemini3Model(modelName) { + switch effort { + case "none": + out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig") + case "auto": + includeThoughts := true + out = util.ApplyGeminiCLIThinkingLevel(out, "", &includeThoughts) + default: + if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok { + out = util.ApplyGeminiCLIThinkingLevel(out, level, nil) + } + } + } else if !util.ModelUsesThinkingLevels(modelName) { + out = util.ApplyReasoningEffortToGeminiCLI(out, effort) + } } // Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent) diff --git a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go index bc10ee34..c5f26fbd 100644 --- 
a/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go +++ b/internal/translator/gemini/openai/chat-completions/gemini_openai_request.go @@ -37,12 +37,28 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool) // Reasoning effort -> thinkingBudget/include_thoughts // Note: OpenAI official fields take precedence over extra_body.google.thinking_config - // Only convert for models that use numeric budgets (not discrete levels) to avoid - // incorrectly applying thinkingBudget for level-based models like gpt-5. + // Only apply numeric budgets for models that use budgets (not discrete levels) to avoid + // incorrectly applying thinkingBudget for level-based models like gpt-5. Gemini 3 models + // use thinkingLevel/includeThoughts instead. re := gjson.GetBytes(rawJSON, "reasoning_effort") hasOfficialThinking := re.Exists() - if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) { - out = util.ApplyReasoningEffortToGemini(out, re.String()) + if hasOfficialThinking && util.ModelSupportsThinking(modelName) { + effort := strings.ToLower(strings.TrimSpace(re.String())) + if util.IsGemini3Model(modelName) { + switch effort { + case "none": + out, _ = sjson.DeleteBytes(out, "generationConfig.thinkingConfig") + case "auto": + includeThoughts := true + out = util.ApplyGeminiThinkingLevel(out, "", &includeThoughts) + default: + if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok { + out = util.ApplyGeminiThinkingLevel(out, level, nil) + } + } + } else if !util.ModelUsesThinkingLevels(modelName) { + out = util.ApplyReasoningEffortToGemini(out, effort) + } } // Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent) From 8a5db0216571c959e31a5e39e03748c380e71579 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 20 Dec 2025 04:49:02 +0800 Subject: [PATCH 18/38] Fixed: #607 refactor(config): re-export 
internal configuration types for SDK consumers --- examples/custom-provider/main.go | 6 +- internal/config/config.go | 5 +- internal/config/sdk_config.go | 87 +++++++++++++++++++++++ sdk/api/options.go | 46 ++++++++++++ sdk/cliproxy/builder.go | 2 +- sdk/cliproxy/providers.go | 2 +- sdk/cliproxy/service.go | 2 +- sdk/cliproxy/types.go | 2 +- sdk/cliproxy/watcher.go | 2 +- sdk/config/config.go | 118 ++++++++++++------------------- sdk/logging/request_logger.go | 18 +++++ 11 files changed, 206 insertions(+), 84 deletions(-) create mode 100644 internal/config/sdk_config.go create mode 100644 sdk/api/options.go create mode 100644 sdk/logging/request_logger.go diff --git a/examples/custom-provider/main.go b/examples/custom-provider/main.go index ffb5e346..930afdcf 100644 --- a/examples/custom-provider/main.go +++ b/examples/custom-provider/main.go @@ -23,13 +23,13 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/router-for-me/CLIProxyAPI/v6/internal/api" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/logging" sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" ) diff --git a/internal/config/config.go b/internal/config/config.go index 63ac1cb0..2ced3796 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,7 +12,6 @@ import ( "strings" "syscall" - "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" "golang.org/x/crypto/bcrypt" "gopkg.in/yaml.v3" ) @@ -21,7 +20,7 @@ const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy // 
Config represents the application's configuration, loaded from a YAML file. type Config struct { - config.SDKConfig `yaml:",inline"` + SDKConfig `yaml:",inline"` // Host is the network host/interface on which the API server will bind. // Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access. Host string `yaml:"host" json:"-"` @@ -692,7 +691,7 @@ func sanitizeConfigForPersist(cfg *Config) *Config { } clone := *cfg clone.SDKConfig = cfg.SDKConfig - clone.SDKConfig.Access = config.AccessConfig{} + clone.SDKConfig.Access = AccessConfig{} return &clone } diff --git a/internal/config/sdk_config.go b/internal/config/sdk_config.go new file mode 100644 index 00000000..f6f20d5c --- /dev/null +++ b/internal/config/sdk_config.go @@ -0,0 +1,87 @@ +// Package config provides configuration management for the CLI Proxy API server. +// It handles loading and parsing YAML configuration files, and provides structured +// access to application settings including server port, authentication directory, +// debug settings, proxy configuration, and API keys. +package config + +// SDKConfig represents the application's configuration, loaded from a YAML file. +type SDKConfig struct { + // ProxyURL is the URL of an optional proxy server to use for outbound requests. + ProxyURL string `yaml:"proxy-url" json:"proxy-url"` + + // ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview") + // to target prefixed credentials. When false, unprefixed model requests may use prefixed + // credentials as well. + ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"` + + // RequestLog enables or disables detailed request logging functionality. + RequestLog bool `yaml:"request-log" json:"request-log"` + + // APIKeys is a list of keys for authenticating clients to this proxy server. 
+ APIKeys []string `yaml:"api-keys" json:"api-keys"` + + // Access holds request authentication provider configuration. + Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"` +} + +// AccessConfig groups request authentication providers. +type AccessConfig struct { + // Providers lists configured authentication providers. + Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"` +} + +// AccessProvider describes a request authentication provider entry. +type AccessProvider struct { + // Name is the instance identifier for the provider. + Name string `yaml:"name" json:"name"` + + // Type selects the provider implementation registered via the SDK. + Type string `yaml:"type" json:"type"` + + // SDK optionally names a third-party SDK module providing this provider. + SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"` + + // APIKeys lists inline keys for providers that require them. + APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"` + + // Config passes provider-specific options to the implementation. + Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"` +} + +const ( + // AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys. + AccessProviderTypeConfigAPIKey = "config-api-key" + + // DefaultAccessProviderName is applied when no provider name is supplied. + DefaultAccessProviderName = "config-inline" +) + +// ConfigAPIKeyProvider returns the first inline API key provider if present. +func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider { + if c == nil { + return nil + } + for i := range c.Access.Providers { + if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey { + if c.Access.Providers[i].Name == "" { + c.Access.Providers[i].Name = DefaultAccessProviderName + } + return &c.Access.Providers[i] + } + } + return nil +} + +// MakeInlineAPIKeyProvider constructs an inline API key provider configuration. 
+// It returns nil when no keys are supplied. +func MakeInlineAPIKeyProvider(keys []string) *AccessProvider { + if len(keys) == 0 { + return nil + } + provider := &AccessProvider{ + Name: DefaultAccessProviderName, + Type: AccessProviderTypeConfigAPIKey, + APIKeys: append([]string(nil), keys...), + } + return provider +} diff --git a/sdk/api/options.go b/sdk/api/options.go new file mode 100644 index 00000000..8497884b --- /dev/null +++ b/sdk/api/options.go @@ -0,0 +1,46 @@ +// Package api exposes server option helpers for embedding CLIProxyAPI. +// +// It wraps internal server option types so external projects can configure the embedded +// HTTP server without importing internal packages. +package api + +import ( + "time" + + "github.com/gin-gonic/gin" + internalapi "github.com/router-for-me/CLIProxyAPI/v6/internal/api" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/logging" +) + +// ServerOption customises HTTP server construction. +type ServerOption = internalapi.ServerOption + +// WithMiddleware appends additional Gin middleware during server construction. +func WithMiddleware(mw ...gin.HandlerFunc) ServerOption { return internalapi.WithMiddleware(mw...) } + +// WithEngineConfigurator allows callers to mutate the Gin engine prior to middleware setup. +func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption { + return internalapi.WithEngineConfigurator(fn) +} + +// WithRouterConfigurator appends a callback after default routes are registered. +func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption { + return internalapi.WithRouterConfigurator(fn) +} + +// WithLocalManagementPassword stores a runtime-only management password accepted for localhost requests. 
+func WithLocalManagementPassword(password string) ServerOption { + return internalapi.WithLocalManagementPassword(password) +} + +// WithKeepAliveEndpoint enables a keep-alive endpoint with the provided timeout and callback. +func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption { + return internalapi.WithKeepAliveEndpoint(timeout, onTimeout) +} + +// WithRequestLoggerFactory customises request logger creation. +func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption { + return internalapi.WithRequestLoggerFactory(factory) +} diff --git a/sdk/cliproxy/builder.go b/sdk/cliproxy/builder.go index e1a1d503..a85e91d9 100644 --- a/sdk/cliproxy/builder.go +++ b/sdk/cliproxy/builder.go @@ -7,10 +7,10 @@ import ( "fmt" "github.com/router-for-me/CLIProxyAPI/v6/internal/api" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) // Builder constructs a Service instance with customizable providers. diff --git a/sdk/cliproxy/providers.go b/sdk/cliproxy/providers.go index 401885f5..7ce89f76 100644 --- a/sdk/cliproxy/providers.go +++ b/sdk/cliproxy/providers.go @@ -3,8 +3,8 @@ package cliproxy import ( "context" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) // NewFileTokenClientProvider returns the default token-backed client loader. 
diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index f3cbf484..e4cd9e5d 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -13,7 +13,6 @@ import ( "time" "github.com/router-for-me/CLIProxyAPI/v6/internal/api" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" @@ -23,6 +22,7 @@ import ( sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" log "github.com/sirupsen/logrus" ) diff --git a/sdk/cliproxy/types.go b/sdk/cliproxy/types.go index 42c7c488..1521dffe 100644 --- a/sdk/cliproxy/types.go +++ b/sdk/cliproxy/types.go @@ -6,9 +6,9 @@ package cliproxy import ( "context" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) // TokenClientProvider loads clients backed by stored authentication tokens. 
diff --git a/sdk/cliproxy/watcher.go b/sdk/cliproxy/watcher.go index 921e2068..caeadf19 100644 --- a/sdk/cliproxy/watcher.go +++ b/sdk/cliproxy/watcher.go @@ -3,9 +3,9 @@ package cliproxy import ( "context" - "github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) func defaultWatcherFactory(configPath, authDir string, reload func(*config.Config)) (*WatcherWrapper, error) { diff --git a/sdk/config/config.go b/sdk/config/config.go index f6f20d5c..6e4efad5 100644 --- a/sdk/config/config.go +++ b/sdk/config/config.go @@ -1,87 +1,59 @@ -// Package config provides configuration management for the CLI Proxy API server. -// It handles loading and parsing YAML configuration files, and provides structured -// access to application settings including server port, authentication directory, -// debug settings, proxy configuration, and API keys. +// Package config provides the public SDK configuration API. +// +// It re-exports the server configuration types and helpers so external projects can +// embed CLIProxyAPI without importing internal packages. package config -// SDKConfig represents the application's configuration, loaded from a YAML file. -type SDKConfig struct { - // ProxyURL is the URL of an optional proxy server to use for outbound requests. - ProxyURL string `yaml:"proxy-url" json:"proxy-url"` +import internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config" - // ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview") - // to target prefixed credentials. When false, unprefixed model requests may use prefixed - // credentials as well. 
- ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"` +type SDKConfig = internalconfig.SDKConfig +type AccessConfig = internalconfig.AccessConfig +type AccessProvider = internalconfig.AccessProvider - // RequestLog enables or disables detailed request logging functionality. - RequestLog bool `yaml:"request-log" json:"request-log"` +type Config = internalconfig.Config - // APIKeys is a list of keys for authenticating clients to this proxy server. - APIKeys []string `yaml:"api-keys" json:"api-keys"` +type TLSConfig = internalconfig.TLSConfig +type RemoteManagement = internalconfig.RemoteManagement +type AmpCode = internalconfig.AmpCode +type PayloadConfig = internalconfig.PayloadConfig +type PayloadRule = internalconfig.PayloadRule +type PayloadModelRule = internalconfig.PayloadModelRule - // Access holds request authentication provider configuration. - Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"` -} +type GeminiKey = internalconfig.GeminiKey +type CodexKey = internalconfig.CodexKey +type ClaudeKey = internalconfig.ClaudeKey +type VertexCompatKey = internalconfig.VertexCompatKey +type VertexCompatModel = internalconfig.VertexCompatModel +type OpenAICompatibility = internalconfig.OpenAICompatibility +type OpenAICompatibilityAPIKey = internalconfig.OpenAICompatibilityAPIKey +type OpenAICompatibilityModel = internalconfig.OpenAICompatibilityModel -// AccessConfig groups request authentication providers. -type AccessConfig struct { - // Providers lists configured authentication providers. - Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"` -} - -// AccessProvider describes a request authentication provider entry. -type AccessProvider struct { - // Name is the instance identifier for the provider. - Name string `yaml:"name" json:"name"` - - // Type selects the provider implementation registered via the SDK. 
- Type string `yaml:"type" json:"type"` - - // SDK optionally names a third-party SDK module providing this provider. - SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"` - - // APIKeys lists inline keys for providers that require them. - APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"` - - // Config passes provider-specific options to the implementation. - Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"` -} +type TLS = internalconfig.TLSConfig const ( - // AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys. - AccessProviderTypeConfigAPIKey = "config-api-key" - - // DefaultAccessProviderName is applied when no provider name is supplied. - DefaultAccessProviderName = "config-inline" + AccessProviderTypeConfigAPIKey = internalconfig.AccessProviderTypeConfigAPIKey + DefaultAccessProviderName = internalconfig.DefaultAccessProviderName + DefaultPanelGitHubRepository = internalconfig.DefaultPanelGitHubRepository ) -// ConfigAPIKeyProvider returns the first inline API key provider if present. -func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider { - if c == nil { - return nil - } - for i := range c.Access.Providers { - if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey { - if c.Access.Providers[i].Name == "" { - c.Access.Providers[i].Name = DefaultAccessProviderName - } - return &c.Access.Providers[i] - } - } - return nil +func MakeInlineAPIKeyProvider(keys []string) *AccessProvider { + return internalconfig.MakeInlineAPIKeyProvider(keys) } -// MakeInlineAPIKeyProvider constructs an inline API key provider configuration. -// It returns nil when no keys are supplied. 
-func MakeInlineAPIKeyProvider(keys []string) *AccessProvider { - if len(keys) == 0 { - return nil - } - provider := &AccessProvider{ - Name: DefaultAccessProviderName, - Type: AccessProviderTypeConfigAPIKey, - APIKeys: append([]string(nil), keys...), - } - return provider +func LoadConfig(configFile string) (*Config, error) { return internalconfig.LoadConfig(configFile) } + +func LoadConfigOptional(configFile string, optional bool) (*Config, error) { + return internalconfig.LoadConfigOptional(configFile, optional) +} + +func SaveConfigPreserveComments(configFile string, cfg *Config) error { + return internalconfig.SaveConfigPreserveComments(configFile, cfg) +} + +func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error { + return internalconfig.SaveConfigPreserveCommentsUpdateNestedScalar(configFile, path, value) +} + +func NormalizeCommentIndentation(data []byte) []byte { + return internalconfig.NormalizeCommentIndentation(data) } diff --git a/sdk/logging/request_logger.go b/sdk/logging/request_logger.go new file mode 100644 index 00000000..39ff5ba8 --- /dev/null +++ b/sdk/logging/request_logger.go @@ -0,0 +1,18 @@ +// Package logging re-exports request logging primitives for SDK consumers. +package logging + +import internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" + +// RequestLogger defines the interface for logging HTTP requests and responses. +type RequestLogger = internallogging.RequestLogger + +// StreamingLogWriter handles real-time logging of streaming response chunks. +type StreamingLogWriter = internallogging.StreamingLogWriter + +// FileRequestLogger implements RequestLogger using file-based storage. +type FileRequestLogger = internallogging.FileRequestLogger + +// NewFileRequestLogger creates a new file-based request logger. 
+func NewFileRequestLogger(enabled bool, logsDir string, configDir string) *FileRequestLogger { + return internallogging.NewFileRequestLogger(enabled, logsDir, configDir) +} From c84ff42bcdf64d6f160ec61ffa3c3abd69aded44 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 20 Dec 2025 10:15:25 +0800 Subject: [PATCH 19/38] fix(amp): add /docs routes to proxy --- internal/api/modules/amp/routes.go | 6 +++++- sdk/auth/filestore.go | 2 +- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 911d2b7d..50900f24 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -126,7 +126,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha var authWithBypass gin.HandlerFunc if auth != nil { ampAPI.Use(auth) - authWithBypass = wrapManagementAuth(auth, "/threads", "/auth") + authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs") } // Dynamic proxy handler that uses m.getProxy() for hot-reload support @@ -175,7 +175,11 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha if authWithBypass != nil { rootMiddleware = append(rootMiddleware, authWithBypass) } + engine.GET("/threads", append(rootMiddleware, proxyHandler)...) engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...) + engine.GET("/docs", append(rootMiddleware, proxyHandler)...) + engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...) + engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...) diff --git a/sdk/auth/filestore.go b/sdk/auth/filestore.go index 2fa963df..84092d37 100644 --- a/sdk/auth/filestore.go +++ b/sdk/auth/filestore.go @@ -267,7 +267,7 @@ func (s *FileTokenStore) baseDirSnapshot() string { } // DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata. 
-// This function is kept for backward compatibility but can cause refresh loops. +// This function is kept for backward compatibility but can cause refresh loops. func jsonEqual(a, b []byte) bool { var objA any var objB any From 1231dc9cdaea52c22c48b2146b67d817055f1d95 Mon Sep 17 00:00:00 2001 From: Ben Vargas Date: Fri, 19 Dec 2025 17:38:05 -0700 Subject: [PATCH 20/38] feat(antigravity): add payload config support to Antigravity executor Add applyPayloadConfig calls to all Antigravity executor paths (Execute, executeClaudeNonStream, ExecuteStream) to enable config.yaml payload overrides for Antigravity/Gemini-Claude models. This allows users to configure thinking budget and other parameters via payload.override in config.yaml for models like gemini-claude-opus-4-5*. --- internal/runtime/executor/antigravity_executor.go | 3 +++ 1 file changed, 3 insertions(+) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 8b4e37ee..0d1bd175 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -93,6 +93,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated) + translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -187,6 +188,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth * translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, 
translated) + translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) @@ -520,6 +522,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated) + translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated) baseURLs := antigravityBaseURLFallbackOrder(auth) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) From 3e4858a624cb6fb2df963129b3b1a44c315cc182 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 20 Dec 2025 15:52:59 +0800 Subject: [PATCH 21/38] feat(config): add log file size limit configuration #535 This commit introduces a new configuration option `logs-max-total-size-mb` that allows users to set a maximum total size (in MB) for log files in the logs directory. When this limit is exceeded, the oldest log files will be automatically deleted to stay within the specified size. Setting this value to 0 (the default) disables this feature. This change enhances log management by preventing excessive disk space usage. 
--- cmd/server/main.go | 2 +- config.example.yaml | 4 + internal/api/server.go | 15 +- internal/config/config.go | 9 ++ internal/logging/global_logger.go | 32 +++-- internal/logging/log_dir_cleaner.go | 166 +++++++++++++++++++++++ internal/logging/log_dir_cleaner_test.go | 70 ++++++++++ 7 files changed, 282 insertions(+), 16 deletions(-) create mode 100644 internal/logging/log_dir_cleaner.go create mode 100644 internal/logging/log_dir_cleaner_test.go diff --git a/cmd/server/main.go b/cmd/server/main.go index aec51ab8..2b20bcb5 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -405,7 +405,7 @@ func main() { usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled) coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling) - if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil { + if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil { log.Errorf("failed to configure log output: %v", err) return } diff --git a/config.example.yaml b/config.example.yaml index 563dd06c..1e084cb4 100644 --- a/config.example.yaml +++ b/config.example.yaml @@ -42,6 +42,10 @@ debug: false # When true, write application logs to rotating files instead of stdout logging-to-file: false +# Maximum total size (MB) of log files under the logs directory. When exceeded, the oldest log +# files are deleted until within the limit. Set to 0 to disable. 
+logs-max-total-size-mb: 0 + # When false, disable in-memory usage statistics aggregation usage-statistics-enabled: false diff --git a/internal/api/server.go b/internal/api/server.go index d6fe91bf..094da118 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -844,11 +844,20 @@ func (s *Server) UpdateClients(cfg *config.Config) { } } - if oldCfg != nil && oldCfg.LoggingToFile != cfg.LoggingToFile { - if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil { + if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB { + if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil { log.Errorf("failed to reconfigure log output: %v", err) } else { - log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile) + if oldCfg == nil { + log.Debug("log output configuration refreshed") + } else { + if oldCfg.LoggingToFile != cfg.LoggingToFile { + log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile) + } + if oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB { + log.Debugf("logs_max_total_size_mb updated from %d to %d", oldCfg.LogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB) + } + } } } diff --git a/internal/config/config.go b/internal/config/config.go index 2ced3796..cd56bd77 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -42,6 +42,10 @@ type Config struct { // LoggingToFile controls whether application logs are written to rotating files or stdout. LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"` + // LogsMaxTotalSizeMB limits the total size (in MB) of log files under the logs directory. + // When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable. 
+ LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"` + // UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded. UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"` @@ -341,6 +345,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { // Set defaults before unmarshal so that absent keys keep defaults. cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6) cfg.LoggingToFile = false + cfg.LogsMaxTotalSizeMB = 0 cfg.UsageStatisticsEnabled = false cfg.DisableCooling = false cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient @@ -385,6 +390,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) { cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository } + if cfg.LogsMaxTotalSizeMB < 0 { + cfg.LogsMaxTotalSizeMB = 0 + } + // Sync request authentication providers with inline API keys for backwards compatibility. syncInlineAccessProvider(&cfg) diff --git a/internal/logging/global_logger.go b/internal/logging/global_logger.go index 28fde213..e7d795fa 100644 --- a/internal/logging/global_logger.go +++ b/internal/logging/global_logger.go @@ -72,39 +72,45 @@ func SetupBaseLogger() { } // ConfigureLogOutput switches the global log destination between rotating files and stdout. -func ConfigureLogOutput(loggingToFile bool) error { +// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory +// until the total size is within the limit. 
+func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error { SetupBaseLogger() writerMu.Lock() defer writerMu.Unlock() + logDir := "logs" + if base := util.WritablePath(); base != "" { + logDir = filepath.Join(base, "logs") + } + + protectedPath := "" if loggingToFile { - logDir := "logs" - if base := util.WritablePath(); base != "" { - logDir = filepath.Join(base, "logs") - } if err := os.MkdirAll(logDir, 0o755); err != nil { return fmt.Errorf("logging: failed to create log directory: %w", err) } if logWriter != nil { _ = logWriter.Close() } + protectedPath = filepath.Join(logDir, "main.log") logWriter = &lumberjack.Logger{ - Filename: filepath.Join(logDir, "main.log"), + Filename: protectedPath, MaxSize: 10, MaxBackups: 0, MaxAge: 0, Compress: false, } log.SetOutput(logWriter) - return nil + } else { + if logWriter != nil { + _ = logWriter.Close() + logWriter = nil + } + log.SetOutput(os.Stdout) } - if logWriter != nil { - _ = logWriter.Close() - logWriter = nil - } - log.SetOutput(os.Stdout) + configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath) return nil } @@ -112,6 +118,8 @@ func closeLogOutputs() { writerMu.Lock() defer writerMu.Unlock() + stopLogDirCleanerLocked() + if logWriter != nil { _ = logWriter.Close() logWriter = nil diff --git a/internal/logging/log_dir_cleaner.go b/internal/logging/log_dir_cleaner.go new file mode 100644 index 00000000..e563b381 --- /dev/null +++ b/internal/logging/log_dir_cleaner.go @@ -0,0 +1,166 @@ +package logging + +import ( + "context" + "os" + "path/filepath" + "sort" + "strings" + "time" + + log "github.com/sirupsen/logrus" +) + +const logDirCleanerInterval = time.Minute + +var logDirCleanerCancel context.CancelFunc + +func configureLogDirCleanerLocked(logDir string, maxTotalSizeMB int, protectedPath string) { + stopLogDirCleanerLocked() + + if maxTotalSizeMB <= 0 { + return + } + + maxBytes := int64(maxTotalSizeMB) * 1024 * 1024 + if maxBytes <= 0 { + return + } + + dir := 
strings.TrimSpace(logDir) + if dir == "" { + return + } + + ctx, cancel := context.WithCancel(context.Background()) + logDirCleanerCancel = cancel + go runLogDirCleaner(ctx, filepath.Clean(dir), maxBytes, strings.TrimSpace(protectedPath)) +} + +func stopLogDirCleanerLocked() { + if logDirCleanerCancel == nil { + return + } + logDirCleanerCancel() + logDirCleanerCancel = nil +} + +func runLogDirCleaner(ctx context.Context, logDir string, maxBytes int64, protectedPath string) { + ticker := time.NewTicker(logDirCleanerInterval) + defer ticker.Stop() + + cleanOnce := func() { + deleted, errClean := enforceLogDirSizeLimit(logDir, maxBytes, protectedPath) + if errClean != nil { + log.WithError(errClean).Warn("logging: failed to enforce log directory size limit") + return + } + if deleted > 0 { + log.Debugf("logging: removed %d old log file(s) to enforce log directory size limit", deleted) + } + } + + cleanOnce() + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + cleanOnce() + } + } +} + +func enforceLogDirSizeLimit(logDir string, maxBytes int64, protectedPath string) (int, error) { + if maxBytes <= 0 { + return 0, nil + } + + dir := strings.TrimSpace(logDir) + if dir == "" { + return 0, nil + } + dir = filepath.Clean(dir) + + entries, errRead := os.ReadDir(dir) + if errRead != nil { + if os.IsNotExist(errRead) { + return 0, nil + } + return 0, errRead + } + + protected := strings.TrimSpace(protectedPath) + if protected != "" { + protected = filepath.Clean(protected) + } + + type logFile struct { + path string + size int64 + modTime time.Time + } + + var ( + files []logFile + total int64 + ) + for _, entry := range entries { + if entry.IsDir() { + continue + } + name := entry.Name() + if !isLogFileName(name) { + continue + } + info, errInfo := entry.Info() + if errInfo != nil { + continue + } + if !info.Mode().IsRegular() { + continue + } + path := filepath.Join(dir, name) + files = append(files, logFile{ + path: path, + size: info.Size(), + modTime: 
info.ModTime(), + }) + total += info.Size() + } + + if total <= maxBytes { + return 0, nil + } + + sort.Slice(files, func(i, j int) bool { + return files[i].modTime.Before(files[j].modTime) + }) + + deleted := 0 + for _, file := range files { + if total <= maxBytes { + break + } + if protected != "" && filepath.Clean(file.path) == protected { + continue + } + if errRemove := os.Remove(file.path); errRemove != nil { + log.WithError(errRemove).Warnf("logging: failed to remove old log file: %s", filepath.Base(file.path)) + continue + } + total -= file.size + deleted++ + } + + return deleted, nil +} + +func isLogFileName(name string) bool { + trimmed := strings.TrimSpace(name) + if trimmed == "" { + return false + } + lower := strings.ToLower(trimmed) + return strings.HasSuffix(lower, ".log") || strings.HasSuffix(lower, ".log.gz") +} diff --git a/internal/logging/log_dir_cleaner_test.go b/internal/logging/log_dir_cleaner_test.go new file mode 100644 index 00000000..3670da50 --- /dev/null +++ b/internal/logging/log_dir_cleaner_test.go @@ -0,0 +1,70 @@ +package logging + +import ( + "os" + "path/filepath" + "testing" + "time" +) + +func TestEnforceLogDirSizeLimitDeletesOldest(t *testing.T) { + dir := t.TempDir() + + writeLogFile(t, filepath.Join(dir, "old.log"), 60, time.Unix(1, 0)) + writeLogFile(t, filepath.Join(dir, "mid.log"), 60, time.Unix(2, 0)) + protected := filepath.Join(dir, "main.log") + writeLogFile(t, protected, 60, time.Unix(3, 0)) + + deleted, err := enforceLogDirSizeLimit(dir, 120, protected) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if deleted != 1 { + t.Fatalf("expected 1 deleted file, got %d", deleted) + } + + if _, err := os.Stat(filepath.Join(dir, "old.log")); !os.IsNotExist(err) { + t.Fatalf("expected old.log to be removed, stat error: %v", err) + } + if _, err := os.Stat(filepath.Join(dir, "mid.log")); err != nil { + t.Fatalf("expected mid.log to remain, stat error: %v", err) + } + if _, err := os.Stat(protected); err != nil { 
+ t.Fatalf("expected protected main.log to remain, stat error: %v", err) + } +} + +func TestEnforceLogDirSizeLimitSkipsProtected(t *testing.T) { + dir := t.TempDir() + + protected := filepath.Join(dir, "main.log") + writeLogFile(t, protected, 200, time.Unix(1, 0)) + writeLogFile(t, filepath.Join(dir, "other.log"), 50, time.Unix(2, 0)) + + deleted, err := enforceLogDirSizeLimit(dir, 100, protected) + if err != nil { + t.Fatalf("unexpected error: %v", err) + } + if deleted != 1 { + t.Fatalf("expected 1 deleted file, got %d", deleted) + } + + if _, err := os.Stat(protected); err != nil { + t.Fatalf("expected protected main.log to remain, stat error: %v", err) + } + if _, err := os.Stat(filepath.Join(dir, "other.log")); !os.IsNotExist(err) { + t.Fatalf("expected other.log to be removed, stat error: %v", err) + } +} + +func writeLogFile(t *testing.T, path string, size int, modTime time.Time) { + t.Helper() + + data := make([]byte, size) + if err := os.WriteFile(path, data, 0o644); err != nil { + t.Fatalf("write file: %v", err) + } + if err := os.Chtimes(path, modTime, modTime); err != nil { + t.Fatalf("set times: %v", err) + } +} From 93414f1baa7d3734defb5b6937576b1474916a7e Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sat, 20 Dec 2025 18:25:55 +0800 Subject: [PATCH 22/38] feat(auth): CLI OAuth supports pasting callback URLs to complete login - Add callback URL resolution and terminal prompt logic - Allow Codex/Claude/iFlow/Antigravity/Gemini login to complete via a pasted callback URL or the local callback - Update the Gemini login option signature and manager call - Make the default CLI prompt function tolerate empty input and keep waiting --- .../api/handlers/management/auth_files.go | 4 +- internal/auth/gemini/gemini_auth.go | 60 ++++++++++++-- internal/cmd/anthropic_login.go | 7 +- internal/cmd/antigravity_login.go | 7 +- internal/cmd/iflow_login.go | 8 +- internal/cmd/login.go | 18 +++-- internal/cmd/openai_login.go | 7 +- internal/misc/oauth.go | 80 +++++++++++++++++++
sdk/auth/antigravity.go | 9 +++ sdk/auth/claude.go | 31 ++++++- sdk/auth/codex.go | 31 ++++++- sdk/auth/gemini.go | 5 +- sdk/auth/iflow.go | 27 ++++++- sdk/auth/oauth_callback.go | 41 ++++++++++ 14 files changed, 302 insertions(+), 33 deletions(-) create mode 100644 sdk/auth/oauth_callback.go diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index bf5a5b9c..4f42bd7a 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -1093,7 +1093,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { // Initialize authenticated HTTP client via GeminiAuth to honor proxy settings gemAuth := geminiAuth.NewGeminiAuth() - gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) + gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{ + NoBrowser: true, + }) if errGetClient != nil { log.Errorf("failed to get authenticated client: %v", errGetClient) SetOAuthSessionError(state, "Failed to get authenticated client") diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index f173c95f..dc9b1034 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -18,6 +18,7 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" "github.com/router-for-me/CLIProxyAPI/v6/internal/util" log "github.com/sirupsen/logrus" "github.com/tidwall/gjson" @@ -46,6 +47,12 @@ var ( type GeminiAuth struct { } +// WebLoginOptions customizes the interactive OAuth flow. +type WebLoginOptions struct { + NoBrowser bool + Prompt func(string) (string, error) +} + // NewGeminiAuth creates a new instance of GeminiAuth. 
func NewGeminiAuth() *GeminiAuth { return &GeminiAuth{} @@ -59,12 +66,12 @@ func NewGeminiAuth() *GeminiAuth { // - ctx: The context for the HTTP client // - ts: The Gemini token storage containing authentication tokens // - cfg: The configuration containing proxy settings -// - noBrowser: Optional parameter to disable browser opening +// - opts: Optional parameters to customize browser and prompt behavior // // Returns: // - *http.Client: An HTTP client configured with authentication // - error: An error if the client configuration fails, nil otherwise -func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) { +func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) { // Configure proxy settings for the HTTP client if a proxy URL is provided. proxyURL, err := url.Parse(cfg.ProxyURL) if err == nil { @@ -109,7 +116,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken // If no token is found in storage, initiate the web-based OAuth flow. if ts.Token == nil { fmt.Printf("Could not load token from file, starting OAuth flow.\n") - token, err = g.getTokenFromWeb(ctx, conf, noBrowser...) 
+ token, err = g.getTokenFromWeb(ctx, conf, opts) if err != nil { return nil, fmt.Errorf("failed to get token from web: %w", err) } @@ -205,12 +212,12 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf // Parameters: // - ctx: The context for the HTTP client // - config: The OAuth2 configuration -// - noBrowser: Optional parameter to disable browser opening +// - opts: Optional parameters to customize browser and prompt behavior // // Returns: // - *oauth2.Token: The OAuth2 token obtained from the authorization flow // - error: An error if the token acquisition fails, nil otherwise -func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) { +func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { // Use a channel to pass the authorization code from the HTTP handler to the main function. codeChan := make(chan string) errChan := make(chan error) @@ -250,7 +257,12 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, // Open the authorization URL in the user's browser. authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) - if len(noBrowser) == 1 && !noBrowser[0] { + noBrowser := false + if opts != nil { + noBrowser = opts.NoBrowser + } + + if !noBrowser { fmt.Println("Opening browser for authentication...") // Check if browser is available @@ -281,11 +293,47 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, // Wait for the authorization code or an error. 
var authCode string + manualCodeChan := make(chan string, 1) + manualErrChan := make(chan error, 1) + if opts != nil && opts.Prompt != nil { + go func() { + input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ") + if err != nil { + manualErrChan <- err + return + } + parsed, err := misc.ParseOAuthCallback(input) + if err != nil { + manualErrChan <- err + return + } + if parsed == nil { + return + } + if parsed.Error != "" { + manualErrChan <- fmt.Errorf("authentication failed via callback: %s", parsed.Error) + return + } + if parsed.Code == "" { + manualErrChan <- fmt.Errorf("code not found in callback") + return + } + manualCodeChan <- parsed.Code + }() + } else { + manualCodeChan = nil + manualErrChan = nil + } + select { case code := <-codeChan: authCode = code case err := <-errChan: return nil, err + case code := <-manualCodeChan: + authCode = code + case err := <-manualErrChan: + return nil, err case <-time.After(5 * time.Minute): // Timeout return nil, fmt.Errorf("oauth flow timed out") } diff --git a/internal/cmd/anthropic_login.go b/internal/cmd/anthropic_login.go index 8e9d01cd..6efd87a8 100644 --- a/internal/cmd/anthropic_login.go +++ b/internal/cmd/anthropic_login.go @@ -24,12 +24,17 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) { options = &LoginOptions{} } + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + manager := newAuthManager() authOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: promptFn, } _, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts) diff --git a/internal/cmd/antigravity_login.go b/internal/cmd/antigravity_login.go index b2602638..1cd42899 100644 --- a/internal/cmd/antigravity_login.go +++ b/internal/cmd/antigravity_login.go @@ -15,11 +15,16 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) { options = 
&LoginOptions{} } + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + manager := newAuthManager() authOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: promptFn, } record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts) diff --git a/internal/cmd/iflow_login.go b/internal/cmd/iflow_login.go index ba43470b..cf00b63c 100644 --- a/internal/cmd/iflow_login.go +++ b/internal/cmd/iflow_login.go @@ -20,13 +20,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) { promptFn := options.Prompt if promptFn == nil { - promptFn = func(prompt string) (string, error) { - fmt.Println() - fmt.Println(prompt) - var value string - _, err := fmt.Scanln(&value) - return value, err - } + promptFn = defaultProjectPrompt() } authOpts := &sdkAuth.LoginOptions{ diff --git a/internal/cmd/login.go b/internal/cmd/login.go index de01cec5..0f079b4b 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -55,11 +55,17 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { ctx := context.Background() + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + options.Prompt = promptFn + } + loginOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, ProjectID: strings.TrimSpace(projectID), Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: promptFn, } authenticator := sdkAuth.NewGeminiAuthenticator() @@ -76,7 +82,10 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { } geminiAuth := gemini.NewGeminiAuth() - httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser) + httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{ + NoBrowser: options.NoBrowser, + Prompt: promptFn, + }) if errClient != nil { log.Errorf("Gemini authentication 
failed: %v", errClient) return @@ -90,11 +99,6 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { return } - promptFn := options.Prompt - if promptFn == nil { - promptFn = defaultProjectPrompt() - } - selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn) projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects) if errSelection != nil { diff --git a/internal/cmd/openai_login.go b/internal/cmd/openai_login.go index e402e476..d981f6ae 100644 --- a/internal/cmd/openai_login.go +++ b/internal/cmd/openai_login.go @@ -35,12 +35,17 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) { options = &LoginOptions{} } + promptFn := options.Prompt + if promptFn == nil { + promptFn = defaultProjectPrompt() + } + manager := newAuthManager() authOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, Metadata: map[string]string{}, - Prompt: options.Prompt, + Prompt: promptFn, } _, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts) diff --git a/internal/misc/oauth.go b/internal/misc/oauth.go index acf034b2..d5cae403 100644 --- a/internal/misc/oauth.go +++ b/internal/misc/oauth.go @@ -4,6 +4,8 @@ import ( "crypto/rand" "encoding/hex" "fmt" + "net/url" + "strings" ) // GenerateRandomState generates a cryptographically secure random state parameter @@ -19,3 +21,81 @@ func GenerateRandomState() (string, error) { } return hex.EncodeToString(bytes), nil } + +// OAuthCallback captures the parsed OAuth callback parameters. +type OAuthCallback struct { + Code string + State string + Error string + ErrorDescription string +} + +// ParseOAuthCallback extracts OAuth parameters from a callback URL. +// It returns nil when the input is empty. 
+func ParseOAuthCallback(input string) (*OAuthCallback, error) { + trimmed := strings.TrimSpace(input) + if trimmed == "" { + return nil, nil + } + + candidate := trimmed + if !strings.Contains(candidate, "://") { + if strings.HasPrefix(candidate, "?") { + candidate = "http://localhost" + candidate + } else if strings.Contains(candidate, "=") { + candidate = "http://localhost/?" + candidate + } else { + return nil, fmt.Errorf("invalid callback URL") + } + } + + parsedURL, err := url.Parse(candidate) + if err != nil { + return nil, err + } + + query := parsedURL.Query() + code := strings.TrimSpace(query.Get("code")) + state := strings.TrimSpace(query.Get("state")) + errCode := strings.TrimSpace(query.Get("error")) + errDesc := strings.TrimSpace(query.Get("error_description")) + + if parsedURL.Fragment != "" { + if fragQuery, errFrag := url.ParseQuery(parsedURL.Fragment); errFrag == nil { + if code == "" { + code = strings.TrimSpace(fragQuery.Get("code")) + } + if state == "" { + state = strings.TrimSpace(fragQuery.Get("state")) + } + if errCode == "" { + errCode = strings.TrimSpace(fragQuery.Get("error")) + } + if errDesc == "" { + errDesc = strings.TrimSpace(fragQuery.Get("error_description")) + } + } + } + + if code != "" && state == "" && strings.Contains(code, "#") { + parts := strings.SplitN(code, "#", 2) + code = parts[0] + state = parts[1] + } + + if errCode == "" && errDesc != "" { + errCode = errDesc + errDesc = "" + } + + if code == "" && errCode == "" { + return nil, fmt.Errorf("callback URL missing code") + } + + return &OAuthCallback{ + Code: code, + State: state, + Error: errCode, + ErrorDescription: errDesc, + }, nil +} diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index b3d7f6c5..832bd88e 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -99,9 +99,18 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o fmt.Println("Waiting for antigravity authentication callback...") var cbRes 
callbackResult + manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "antigravity") select { case res := <-cbChan: cbRes = res + case manual := <-manualCh: + cbRes = callbackResult{ + Code: manual.Code, + State: manual.State, + Error: manual.Error, + } + case err = <-manualErrCh: + return nil, err case <-time.After(5 * time.Minute): return nil, fmt.Errorf("antigravity: authentication timed out") } diff --git a/sdk/auth/claude.go b/sdk/auth/claude.go index da9e5065..d88cdf29 100644 --- a/sdk/auth/claude.go +++ b/sdk/auth/claude.go @@ -98,16 +98,41 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt fmt.Println("Waiting for Claude authentication callback...") - result, err := oauthServer.WaitForCallback(5 * time.Minute) - if err != nil { + callbackCh := make(chan *claude.OAuthResult, 1) + callbackErrCh := make(chan error, 1) + manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Claude") + manualDescription := "" + + go func() { + result, errWait := oauthServer.WaitForCallback(5 * time.Minute) + if errWait != nil { + callbackErrCh <- errWait + return + } + callbackCh <- result + }() + + var result *claude.OAuthResult + select { + case result = <-callbackCh: + case err = <-callbackErrCh: if strings.Contains(err.Error(), "timeout") { return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) } return nil, err + case manual := <-manualCh: + manualDescription = manual.ErrorDescription + result = &claude.OAuthResult{ + Code: manual.Code, + State: manual.State, + Error: manual.Error, + } + case err = <-manualErrCh: + return nil, err } if result.Error != "" { - return nil, claude.NewOAuthError(result.Error, "", http.StatusBadRequest) + return nil, claude.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest) } if result.State != state { diff --git a/sdk/auth/codex.go b/sdk/auth/codex.go index 138c2141..b0a6b4a4 100644 --- a/sdk/auth/codex.go +++ b/sdk/auth/codex.go @@ -97,16 +97,41 @@ func (a 
*CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts fmt.Println("Waiting for Codex authentication callback...") - result, err := oauthServer.WaitForCallback(5 * time.Minute) - if err != nil { + callbackCh := make(chan *codex.OAuthResult, 1) + callbackErrCh := make(chan error, 1) + manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Codex") + manualDescription := "" + + go func() { + result, errWait := oauthServer.WaitForCallback(5 * time.Minute) + if errWait != nil { + callbackErrCh <- errWait + return + } + callbackCh <- result + }() + + var result *codex.OAuthResult + select { + case result = <-callbackCh: + case err = <-callbackErrCh: if strings.Contains(err.Error(), "timeout") { return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) } return nil, err + case manual := <-manualCh: + manualDescription = manual.ErrorDescription + result = &codex.OAuthResult{ + Code: manual.Code, + State: manual.State, + Error: manual.Error, + } + case err = <-manualErrCh: + return nil, err } if result.Error != "" { - return nil, codex.NewOAuthError(result.Error, "", http.StatusBadRequest) + return nil, codex.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest) } if result.State != state { diff --git a/sdk/auth/gemini.go b/sdk/auth/gemini.go index 7110101f..75ef4579 100644 --- a/sdk/auth/gemini.go +++ b/sdk/auth/gemini.go @@ -44,7 +44,10 @@ func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opt } geminiAuth := gemini.NewGeminiAuth() - _, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, opts.NoBrowser) + _, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{ + NoBrowser: opts.NoBrowser, + Prompt: opts.Prompt, + }) if err != nil { return nil, fmt.Errorf("gemini authentication failed: %w", err) } diff --git a/sdk/auth/iflow.go b/sdk/auth/iflow.go index ee96bdaa..d7621a99 100644 --- a/sdk/auth/iflow.go +++ b/sdk/auth/iflow.go @@ -84,9 +84,32 @@ func (a 
*IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts fmt.Println("Waiting for iFlow authentication callback...") - result, err := oauthServer.WaitForCallback(5 * time.Minute) - if err != nil { + callbackCh := make(chan *iflow.OAuthResult, 1) + callbackErrCh := make(chan error, 1) + manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "iFlow") + + go func() { + result, errWait := oauthServer.WaitForCallback(5 * time.Minute) + if errWait != nil { + callbackErrCh <- errWait + return + } + callbackCh <- result + }() + + var result *iflow.OAuthResult + select { + case result = <-callbackCh: + case err = <-callbackErrCh: return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + case manual := <-manualCh: + result = &iflow.OAuthResult{ + Code: manual.Code, + State: manual.State, + Error: manual.Error, + } + case err = <-manualErrCh: + return nil, err } if result.Error != "" { return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error) diff --git a/sdk/auth/oauth_callback.go b/sdk/auth/oauth_callback.go new file mode 100644 index 00000000..3f0ac925 --- /dev/null +++ b/sdk/auth/oauth_callback.go @@ -0,0 +1,41 @@ +package auth + +import ( + "fmt" + + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" +) + +func promptForOAuthCallback(prompt func(string) (string, error), provider string) (<-chan *misc.OAuthCallback, <-chan error) { + if prompt == nil { + return nil, nil + } + + resultCh := make(chan *misc.OAuthCallback, 1) + errCh := make(chan error, 1) + + go func() { + label := provider + if label == "" { + label = "OAuth" + } + input, err := prompt(fmt.Sprintf("Paste the %s callback URL (or press Enter to keep waiting): ", label)) + if err != nil { + errCh <- err + return + } + + parsed, err := misc.ParseOAuthCallback(input) + if err != nil { + errCh <- err + return + } + if parsed == nil { + return + } + + resultCh <- parsed + }() + + return resultCh, errCh +} From 9855615f1ee7c936c9d349c3be1babaa2559ab69 Mon 
Sep 17 00:00:00 2001 From: Supra4E8C Date: Sat, 20 Dec 2025 19:03:38 +0800 Subject: [PATCH 23/38] fix(gemini): avoid stale manual oauth prompt and accept schemeless callbacks --- internal/auth/gemini/gemini_auth.go | 90 ++++++++++++++++++----------- internal/misc/oauth.go | 2 + 2 files changed, 57 insertions(+), 35 deletions(-) diff --git a/internal/auth/gemini/gemini_auth.go b/internal/auth/gemini/gemini_auth.go index dc9b1034..7b18e738 100644 --- a/internal/auth/gemini/gemini_auth.go +++ b/internal/auth/gemini/gemini_auth.go @@ -219,8 +219,8 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf // - error: An error if the token acquisition fails, nil otherwise func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) { // Use a channel to pass the authorization code from the HTTP handler to the main function. - codeChan := make(chan string) - errChan := make(chan error) + codeChan := make(chan string, 1) + errChan := make(chan error, 1) // Create a new HTTP server with its own multiplexer. mux := http.NewServeMux() @@ -230,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) { if err := r.URL.Query().Get("error"); err != "" { _, _ = fmt.Fprintf(w, "Authentication failed: %s", err) - errChan <- fmt.Errorf("authentication failed via callback: %s", err) + select { + case errChan <- fmt.Errorf("authentication failed via callback: %s", err): + default: + } return } code := r.URL.Query().Get("code") if code == "" { _, _ = fmt.Fprint(w, "Authentication failed: code not found.") - errChan <- fmt.Errorf("code not found in callback") + select { + case errChan <- fmt.Errorf("code not found in callback"): + default: + } return } _, _ = fmt.Fprint(w, "

Authentication successful!

You can close this window.

") - codeChan <- code + select { + case codeChan <- code: + default: + } }) // Start the server in a goroutine. @@ -293,49 +302,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, // Wait for the authorization code or an error. var authCode string - manualCodeChan := make(chan string, 1) - manualErrChan := make(chan error, 1) + timeoutTimer := time.NewTimer(5 * time.Minute) + defer timeoutTimer.Stop() + + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time if opts != nil && opts.Prompt != nil { - go func() { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case code := <-codeChan: + authCode = code + break waitForCallback + case err := <-errChan: + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case code := <-codeChan: + authCode = code + break waitForCallback + case err := <-errChan: + return nil, err + default: + } input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ") if err != nil { - manualErrChan <- err - return + return nil, err } parsed, err := misc.ParseOAuthCallback(input) if err != nil { - manualErrChan <- err - return + return nil, err } if parsed == nil { - return + continue } if parsed.Error != "" { - manualErrChan <- fmt.Errorf("authentication failed via callback: %s", parsed.Error) - return + return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error) } if parsed.Code == "" { - manualErrChan <- fmt.Errorf("code not found in callback") - return + return nil, fmt.Errorf("code not found in callback") } - manualCodeChan <- parsed.Code - }() - } else { - manualCodeChan = nil - manualErrChan = nil - } - - select { - case code := <-codeChan: - authCode = code - case err := <-errChan: - return nil, err - case code := <-manualCodeChan: - 
authCode = code - case err := <-manualErrChan: - return nil, err - case <-time.After(5 * time.Minute): // Timeout - return nil, fmt.Errorf("oauth flow timed out") + authCode = parsed.Code + break waitForCallback + case <-timeoutTimer.C: + return nil, fmt.Errorf("oauth flow timed out") + } } // Shutdown the server. diff --git a/internal/misc/oauth.go b/internal/misc/oauth.go index d5cae403..c14f39d2 100644 --- a/internal/misc/oauth.go +++ b/internal/misc/oauth.go @@ -42,6 +42,8 @@ func ParseOAuthCallback(input string) (*OAuthCallback, error) { if !strings.Contains(candidate, "://") { if strings.HasPrefix(candidate, "?") { candidate = "http://localhost" + candidate + } else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") { + candidate = "http://" + candidate } else if strings.Contains(candidate, "=") { candidate = "http://localhost/?" + candidate } else { From df777650ac2f25e11be55b7c5a2b961e2009ceb4 Mon Sep 17 00:00:00 2001 From: sheauhuu Date: Sat, 20 Dec 2025 20:05:20 +0800 Subject: [PATCH 24/38] feat: add gemini-3-flash-preview model definition in GetGeminiModels --- internal/registry/model_definitions.go | 15 +++++++++++++++ 1 file changed, 15 insertions(+) diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index fe0f85cb..67898bbc 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -162,6 +162,21 @@ func GetGeminiModels() []*ModelInfo { SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, }, + { + ID: "gemini-3-flash-preview", + Object: "model", + Created: 1765929600, + OwnedBy: "google", + Type: "gemini", + Name: "models/gemini-3-flash-preview", + Version: "3.0", + DisplayName: "Gemini 3 Flash Preview", + Description: "Gemini 3 Flash Preview", + 
InputTokenLimit: 1048576, + OutputTokenLimit: 65536, + SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, + Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, + }, { ID: "gemini-3-pro-image-preview", Object: "model", From ed5ec5b55c9b0bfc70a64b8776f2527e67b42760 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sat, 20 Dec 2025 22:19:35 +0800 Subject: [PATCH 25/38] feat(amp): enhance model mapping and Gemini thinking configuration This commit introduces several improvements to the AMP (Advanced Model Proxy) module: - **Model Mapping Logic:** The `FallbackHandler` now uses a more robust approach for model mapping. It includes the extraction and preservation of dynamic "thinking suffixes" (e.g., `(xhigh)`) during mapping, ensuring that these configurations are correctly applied to the mapped model. A new `resolveMappedModel` function centralizes this logic for cleaner code. - **ModelMapper Verification:** The `ModelMapper` in `model_mapping.go` now verifies that the target model of a mapping has available providers *after* normalizing it. This prevents mappings to non-existent or unresolvable models. - **Gemini Thinking Configuration Cleanup:** In `gemini_thinking.go`, unnecessary `generationConfig.thinkingConfig.include_thoughts` and `generationConfig.thinkingConfig.thinkingBudget` fields are now deleted from the request body when applying Gemini thinking levels. This prevents potential conflicts or redundant configurations. - **Testing:** A new test case `TestModelMapper_MapModel_TargetWithThinkingSuffix` has been added to `model_mapping_test.go` to specifically cover the preservation of thinking suffixes during model mapping. 
--- internal/api/modules/amp/fallback_handlers.go | 90 ++++++++++++------- .../api/modules/amp/fallback_handlers_test.go | 73 +++++++++++++++ internal/api/modules/amp/model_mapping.go | 3 +- .../api/modules/amp/model_mapping_test.go | 19 ++++ internal/util/gemini_thinking.go | 12 +++ 5 files changed, 163 insertions(+), 34 deletions(-) create mode 100644 internal/api/modules/amp/fallback_handlers_test.go diff --git a/internal/api/modules/amp/fallback_handlers.go b/internal/api/modules/amp/fallback_handlers.go index e7132b81..940bd5e8 100644 --- a/internal/api/modules/amp/fallback_handlers.go +++ b/internal/api/modules/amp/fallback_handlers.go @@ -134,7 +134,43 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc } // Normalize model (handles dynamic thinking suffixes) - normalizedModel, _ := util.NormalizeThinkingModel(modelName) + normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName) + thinkingSuffix := "" + if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) { + thinkingSuffix = modelName[len(normalizedModel):] + } + + resolveMappedModel := func() (string, []string) { + if fh.modelMapper == nil { + return "", nil + } + + mappedModel := fh.modelMapper.MapModel(modelName) + if mappedModel == "" { + mappedModel = fh.modelMapper.MapModel(normalizedModel) + } + mappedModel = strings.TrimSpace(mappedModel) + if mappedModel == "" { + return "", nil + } + + // Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target + // already specifies its own thinking suffix. 
+ if thinkingSuffix != "" { + _, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel) + if mappedThinkingMetadata == nil { + mappedModel += thinkingSuffix + } + } + + mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel) + mappedProviders := util.GetProviderName(mappedBaseModel) + if len(mappedProviders) == 0 { + return "", nil + } + + return mappedModel, mappedProviders + } // Track resolved model for logging (may change if mapping is applied) resolvedModel := normalizedModel @@ -147,21 +183,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc if forceMappings { // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys) // This allows users to route Amp requests to their preferred OAuth providers - if fh.modelMapper != nil { - if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { - // Mapping found - check if we have a provider for the mapped model - mappedProviders := util.GetProviderName(mappedModel) - if len(mappedProviders) > 0 { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - providers = mappedProviders - } - } + if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders } // If no mapping applied, check for local 
providers @@ -174,21 +204,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc if len(providers) == 0 { // No providers configured - check if we have a model mapping - if fh.modelMapper != nil { - if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { - // Mapping found - check if we have a provider for the mapped model - mappedProviders := util.GetProviderName(mappedModel) - if len(mappedProviders) > 0 { - // Mapping found and provider available - rewrite the model in request body - bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) - c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) - // Store mapped model in context for handlers that check it (like gemini bridge) - c.Set(MappedModelContextKey, mappedModel) - resolvedModel = mappedModel - usedMapping = true - providers = mappedProviders - } - } + if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" { + // Mapping found and provider available - rewrite the model in request body + bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) + c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) + // Store mapped model in context for handlers that check it (like gemini bridge) + c.Set(MappedModelContextKey, mappedModel) + resolvedModel = mappedModel + usedMapping = true + providers = mappedProviders } } } @@ -222,14 +246,14 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc // Log: Model was mapped to another model log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel) logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) - rewriter := NewResponseRewriter(c.Writer, normalizedModel) + rewriter := NewResponseRewriter(c.Writer, modelName) c.Writer = rewriter // Filter Anthropic-Beta header only for local handling paths filterAntropicBetaHeader(c) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) handler(c) rewriter.Flush() 
- log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel) + log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName) } else if len(providers) > 0 { // Log: Using local provider (free) logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) diff --git a/internal/api/modules/amp/fallback_handlers_test.go b/internal/api/modules/amp/fallback_handlers_test.go new file mode 100644 index 00000000..a687fd11 --- /dev/null +++ b/internal/api/modules/amp/fallback_handlers_test.go @@ -0,0 +1,73 @@ +package amp + +import ( + "bytes" + "encoding/json" + "net/http" + "net/http/httptest" + "net/http/httputil" + "testing" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" +) + +func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) { + gin.SetMode(gin.TestMode) + + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{ + {ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"}, + }) + defer reg.UnregisterClient("test-client-amp-fallback") + + mapper := NewModelMapper([]config.AmpModelMapping{ + {From: "gpt-5.2", To: "test/gpt-5.2"}, + }) + + fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil) + + handler := func(c *gin.Context) { + var req struct { + Model string `json:"model"` + } + if err := c.ShouldBindJSON(&req); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "model": req.Model, + "seen_model": req.Model, + }) + } + + r := gin.New() + r.POST("/chat/completions", fallback.WrapHandler(handler)) + + reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`) + req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody)) + req.Header.Set("Content-Type", 
"application/json") + w := httptest.NewRecorder() + r.ServeHTTP(w, req) + + if w.Code != http.StatusOK { + t.Fatalf("Expected status 200, got %d", w.Code) + } + + var resp struct { + Model string `json:"model"` + SeenModel string `json:"seen_model"` + } + if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil { + t.Fatalf("Failed to parse response JSON: %v", err) + } + + if resp.Model != "gpt-5.2(xhigh)" { + t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model) + } + if resp.SeenModel != "test/gpt-5.2(xhigh)" { + t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel) + } +} diff --git a/internal/api/modules/amp/model_mapping.go b/internal/api/modules/amp/model_mapping.go index 87384a80..bc31c4e5 100644 --- a/internal/api/modules/amp/model_mapping.go +++ b/internal/api/modules/amp/model_mapping.go @@ -59,7 +59,8 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string { } // Verify target model has available providers - providers := util.GetProviderName(targetModel) + normalizedTarget, _ := util.NormalizeThinkingModel(targetModel) + providers := util.GetProviderName(normalizedTarget) if len(providers) == 0 { log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) return "" diff --git a/internal/api/modules/amp/model_mapping_test.go b/internal/api/modules/amp/model_mapping_test.go index 4f4e5a8e..664a17c5 100644 --- a/internal/api/modules/amp/model_mapping_test.go +++ b/internal/api/modules/amp/model_mapping_test.go @@ -71,6 +71,25 @@ func TestModelMapper_MapModel_WithProvider(t *testing.T) { } } +func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) { + reg := registry.GetGlobalRegistry() + reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{ + {ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"}, + }) + defer reg.UnregisterClient("test-client-thinking") + + mappings := []config.AmpModelMapping{ + {From: "gpt-5.2-alias", 
To: "gpt-5.2(xhigh)"}, + } + + mapper := NewModelMapper(mappings) + + result := mapper.MapModel("gpt-5.2-alias") + if result != "gpt-5.2(xhigh)" { + t.Errorf("Expected gpt-5.2(xhigh), got %s", result) + } +} + func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) { reg := registry.GetGlobalRegistry() reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{ diff --git a/internal/util/gemini_thinking.go b/internal/util/gemini_thinking.go index ba9e13ef..290d5f92 100644 --- a/internal/util/gemini_thinking.go +++ b/internal/util/gemini_thinking.go @@ -136,6 +136,12 @@ func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool) updated = rewritten } } + if it := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); it.Exists() { + updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.include_thoughts") + } + if tb := gjson.GetBytes(body, "generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() { + updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.thinkingBudget") + } return updated } @@ -167,6 +173,12 @@ func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *boo updated = rewritten } } + if it := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); it.Exists() { + updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts") + } + if tb := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() { + updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.thinkingBudget") + } return updated } From 24970baa576dc76754233a6eae1582e04926435c Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sun, 21 Dec 2025 02:14:28 +0800 Subject: [PATCH 26/38] management: allow prefix updates in provider PATCH handlers --- .../api/handlers/management/config_lists.go | 369 ++++++++++-------- 1 file changed, 212 insertions(+), 157 deletions(-) 
diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index a0d0b169..7e42b64b 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -145,71 +145,74 @@ func (h *Handler) PutGeminiKeys(c *gin.Context) { h.persist(c) } func (h *Handler) PatchGeminiKey(c *gin.Context) { + type geminiKeyPatch struct { + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Headers *map[string]string `json:"headers"` + ExcludedModels *[]string `json:"excluded-models"` + } var body struct { - Index *int `json:"index"` - Match *string `json:"match"` - Value *config.GeminiKey `json:"value"` + Index *int `json:"index"` + Match *string `json:"match"` + Value *geminiKeyPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - value := *body.Value - value.APIKey = strings.TrimSpace(value.APIKey) - value.BaseURL = strings.TrimSpace(value.BaseURL) - value.ProxyURL = strings.TrimSpace(value.ProxyURL) - value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) - if value.APIKey == "" { - // Treat empty API key as delete. - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { - h.cfg.GeminiKey = append(h.cfg.GeminiKey[:*body.Index], h.cfg.GeminiKey[*body.Index+1:]...) 
- h.cfg.SanitizeGeminiKeys() - h.persist(c) - return - } - if body.Match != nil { - match := strings.TrimSpace(*body.Match) - if match != "" { - out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) - removed := false - for i := range h.cfg.GeminiKey { - if !removed && h.cfg.GeminiKey[i].APIKey == match { - removed = true - continue - } - out = append(out, h.cfg.GeminiKey[i]) - } - if removed { - h.cfg.GeminiKey = out - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return + targetIndex := -1 + if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { + targetIndex = *body.Index + } + if targetIndex == -1 && body.Match != nil { + match := strings.TrimSpace(*body.Match) + if match != "" { + for i := range h.cfg.GeminiKey { + if h.cfg.GeminiKey[i].APIKey == match { + targetIndex = i + break } } } + } + if targetIndex == -1 { c.JSON(404, gin.H{"error": "item not found"}) return } - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { - h.cfg.GeminiKey[*body.Index] = value - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return - } - if body.Match != nil { - match := strings.TrimSpace(*body.Match) - for i := range h.cfg.GeminiKey { - if h.cfg.GeminiKey[i].APIKey == match { - h.cfg.GeminiKey[i] = value - h.cfg.SanitizeGeminiKeys() - h.persist(c) - return - } + entry := h.cfg.GeminiKey[targetIndex] + if body.Value.APIKey != nil { + trimmed := strings.TrimSpace(*body.Value.APIKey) + if trimmed == "" { + h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...) 
+ h.cfg.SanitizeGeminiKeys() + h.persist(c) + return } + entry.APIKey = trimmed } - c.JSON(404, gin.H{"error": "item not found"}) + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL) + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + h.cfg.GeminiKey[targetIndex] = entry + h.cfg.SanitizeGeminiKeys() + h.persist(c) } + func (h *Handler) DeleteGeminiKey(c *gin.Context) { if val := strings.TrimSpace(c.Query("api-key")); val != "" { out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) @@ -268,35 +271,70 @@ func (h *Handler) PutClaudeKeys(c *gin.Context) { h.persist(c) } func (h *Handler) PatchClaudeKey(c *gin.Context) { + type claudeKeyPatch struct { + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Models *[]config.ClaudeModel `json:"models"` + Headers *map[string]string `json:"headers"` + ExcludedModels *[]string `json:"excluded-models"` + } var body struct { - Index *int `json:"index"` - Match *string `json:"match"` - Value *config.ClaudeKey `json:"value"` + Index *int `json:"index"` + Match *string `json:"match"` + Value *claudeKeyPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - value := *body.Value - normalizeClaudeKey(&value) + targetIndex := -1 if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) { - h.cfg.ClaudeKey[*body.Index] = value - h.cfg.SanitizeClaudeKeys() - h.persist(c) - return + targetIndex = *body.Index } - if body.Match != nil { + 
if targetIndex == -1 && body.Match != nil { + match := strings.TrimSpace(*body.Match) for i := range h.cfg.ClaudeKey { - if h.cfg.ClaudeKey[i].APIKey == *body.Match { - h.cfg.ClaudeKey[i] = value - h.cfg.SanitizeClaudeKeys() - h.persist(c) - return + if h.cfg.ClaudeKey[i].APIKey == match { + targetIndex = i + break } } } - c.JSON(404, gin.H{"error": "item not found"}) + if targetIndex == -1 { + c.JSON(404, gin.H{"error": "item not found"}) + return + } + + entry := h.cfg.ClaudeKey[targetIndex] + if body.Value.APIKey != nil { + entry.APIKey = strings.TrimSpace(*body.Value.APIKey) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL) + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Models != nil { + entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + normalizeClaudeKey(&entry) + h.cfg.ClaudeKey[targetIndex] = entry + h.cfg.SanitizeClaudeKeys() + h.persist(c) } + func (h *Handler) DeleteClaudeKey(c *gin.Context) { if val := c.Query("api-key"); val != "" { out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey)) @@ -356,62 +394,73 @@ func (h *Handler) PutOpenAICompat(c *gin.Context) { h.persist(c) } func (h *Handler) PatchOpenAICompat(c *gin.Context) { + type openAICompatPatch struct { + Name *string `json:"name"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"` + Models *[]config.OpenAICompatibilityModel `json:"models"` + Headers *map[string]string `json:"headers"` + } var body struct { - Name *string `json:"name"` - Index 
*int `json:"index"` - Value *config.OpenAICompatibility `json:"value"` + Name *string `json:"name"` + Index *int `json:"index"` + Value *openAICompatPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - normalizeOpenAICompatibilityEntry(body.Value) - // If base-url becomes empty, delete the provider instead of updating - if strings.TrimSpace(body.Value.BaseURL) == "" { - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { - h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:*body.Index], h.cfg.OpenAICompatibility[*body.Index+1:]...) + targetIndex := -1 + if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { + targetIndex = *body.Index + } + if targetIndex == -1 && body.Name != nil { + match := strings.TrimSpace(*body.Name) + for i := range h.cfg.OpenAICompatibility { + if h.cfg.OpenAICompatibility[i].Name == match { + targetIndex = i + break + } + } + } + if targetIndex == -1 { + c.JSON(404, gin.H{"error": "item not found"}) + return + } + + entry := h.cfg.OpenAICompatibility[targetIndex] + if body.Value.Name != nil { + entry.Name = strings.TrimSpace(*body.Value.Name) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + trimmed := strings.TrimSpace(*body.Value.BaseURL) + if trimmed == "" { + h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...) 
h.cfg.SanitizeOpenAICompatibility() h.persist(c) return } - if body.Name != nil { - out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility)) - removed := false - for i := range h.cfg.OpenAICompatibility { - if !removed && h.cfg.OpenAICompatibility[i].Name == *body.Name { - removed = true - continue - } - out = append(out, h.cfg.OpenAICompatibility[i]) - } - if removed { - h.cfg.OpenAICompatibility = out - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) - return - } - } - c.JSON(404, gin.H{"error": "item not found"}) - return + entry.BaseURL = trimmed } - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { - h.cfg.OpenAICompatibility[*body.Index] = *body.Value - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) - return + if body.Value.APIKeyEntries != nil { + entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...) } - if body.Name != nil { - for i := range h.cfg.OpenAICompatibility { - if h.cfg.OpenAICompatibility[i].Name == *body.Name { - h.cfg.OpenAICompatibility[i] = *body.Value - h.cfg.SanitizeOpenAICompatibility() - h.persist(c) - return - } - } + if body.Value.Models != nil { + entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...) 
} - c.JSON(404, gin.H{"error": "item not found"}) + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + normalizeOpenAICompatibilityEntry(&entry) + h.cfg.OpenAICompatibility[targetIndex] = entry + h.cfg.SanitizeOpenAICompatibility() + h.persist(c) } + func (h *Handler) DeleteOpenAICompat(c *gin.Context) { if name := c.Query("name"); name != "" { out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility)) @@ -563,66 +612,72 @@ func (h *Handler) PutCodexKeys(c *gin.Context) { h.persist(c) } func (h *Handler) PatchCodexKey(c *gin.Context) { + type codexKeyPatch struct { + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Headers *map[string]string `json:"headers"` + ExcludedModels *[]string `json:"excluded-models"` + } var body struct { - Index *int `json:"index"` - Match *string `json:"match"` - Value *config.CodexKey `json:"value"` + Index *int `json:"index"` + Match *string `json:"match"` + Value *codexKeyPatch `json:"value"` } if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { c.JSON(400, gin.H{"error": "invalid body"}) return } - value := *body.Value - value.APIKey = strings.TrimSpace(value.APIKey) - value.BaseURL = strings.TrimSpace(value.BaseURL) - value.ProxyURL = strings.TrimSpace(value.ProxyURL) - value.Headers = config.NormalizeHeaders(value.Headers) - value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) - // If base-url becomes empty, delete instead of update - if value.BaseURL == "" { - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { - h.cfg.CodexKey = append(h.cfg.CodexKey[:*body.Index], h.cfg.CodexKey[*body.Index+1:]...) 
- h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } - if body.Match != nil { - out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) - removed := false - for i := range h.cfg.CodexKey { - if !removed && h.cfg.CodexKey[i].APIKey == *body.Match { - removed = true - continue - } - out = append(out, h.cfg.CodexKey[i]) - } - if removed { - h.cfg.CodexKey = out - h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } - } - } else { - if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { - h.cfg.CodexKey[*body.Index] = value - h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } - if body.Match != nil { - for i := range h.cfg.CodexKey { - if h.cfg.CodexKey[i].APIKey == *body.Match { - h.cfg.CodexKey[i] = value - h.cfg.SanitizeCodexKeys() - h.persist(c) - return - } + targetIndex := -1 + if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { + targetIndex = *body.Index + } + if targetIndex == -1 && body.Match != nil { + match := strings.TrimSpace(*body.Match) + for i := range h.cfg.CodexKey { + if h.cfg.CodexKey[i].APIKey == match { + targetIndex = i + break } } } - c.JSON(404, gin.H{"error": "item not found"}) + if targetIndex == -1 { + c.JSON(404, gin.H{"error": "item not found"}) + return + } + + entry := h.cfg.CodexKey[targetIndex] + if body.Value.APIKey != nil { + entry.APIKey = strings.TrimSpace(*body.Value.APIKey) + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + trimmed := strings.TrimSpace(*body.Value.BaseURL) + if trimmed == "" { + h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...) 
+ h.cfg.SanitizeCodexKeys() + h.persist(c) + return + } + entry.BaseURL = trimmed + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.ExcludedModels != nil { + entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels) + } + h.cfg.CodexKey[targetIndex] = entry + h.cfg.SanitizeCodexKeys() + h.persist(c) } + func (h *Handler) DeleteCodexKey(c *gin.Context) { if val := c.Query("api-key"); val != "" { out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) From 653439698e0c698901de6dd2ec599249cee1339f Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 21 Dec 2025 03:13:58 +0800 Subject: [PATCH 27/38] Fixed: #606 fix: unify response field naming across translators Standardize `text` to `delta` and add missing `output` field in all response payloads for consistency across OpenAI, Claude, and Gemini translators. 
--- .../openai/responses/claude_openai-responses_response.go | 6 +++--- .../openai/responses/gemini_openai-responses_response.go | 6 +++--- .../openai/responses/openai_openai-responses_response.go | 6 +++--- 3 files changed, 9 insertions(+), 9 deletions(-) diff --git a/internal/translator/claude/openai/responses/claude_openai-responses_response.go b/internal/translator/claude/openai/responses/claude_openai-responses_response.go index 252967f1..8ce17f5c 100644 --- a/internal/translator/claude/openai/responses/claude_openai-responses_response.go +++ b/internal/translator/claude/openai/responses/claude_openai-responses_response.go @@ -95,7 +95,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin } } // response.created - created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"instructions":""}}` + created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}` created, _ = sjson.Set(created, "sequence_number", nextSeq()) created, _ = sjson.Set(created, "response.id", st.ResponseID) created, _ = sjson.Set(created, "response.created_at", st.CreatedAt) @@ -197,11 +197,11 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin if st.ReasoningActive { if t := d.Get("thinking"); t.Exists() { st.ReasoningBuf.WriteString(t.String()) - msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}` + msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}` msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID) msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex) - msg, _ = 
sjson.Set(msg, "text", t.String()) + msg, _ = sjson.Set(msg, "delta", t.String()) out = append(out, emitEvent("response.reasoning_summary_text.delta", msg)) } } diff --git a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go index e08b265d..1e2874c4 100644 --- a/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go +++ b/internal/translator/gemini/openai/responses/gemini_openai-responses_response.go @@ -117,7 +117,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, st.CreatedAt = time.Now().Unix() } - created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null}}` + created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}` created, _ = sjson.Set(created, "sequence_number", nextSeq()) created, _ = sjson.Set(created, "response.id", st.ResponseID) created, _ = sjson.Set(created, "response.created_at", st.CreatedAt) @@ -160,11 +160,11 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string, } if t := part.Get("text"); t.Exists() && t.String() != "" { st.ReasoningBuf.WriteString(t.String()) - msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}` + msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}` msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID) msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex) - msg, _ = sjson.Set(msg, "text", t.String()) + msg, _ = sjson.Set(msg, "delta", t.String()) out = append(out, 
emitEvent("response.reasoning_summary_text.delta", msg)) } return true diff --git a/internal/translator/openai/openai/responses/openai_openai-responses_response.go b/internal/translator/openai/openai/responses/openai_openai-responses_response.go index c698b93f..2bda2029 100644 --- a/internal/translator/openai/openai/responses/openai_openai-responses_response.go +++ b/internal/translator/openai/openai/responses/openai_openai-responses_response.go @@ -143,7 +143,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, st.ReasoningTokens = 0 st.UsageSeen = false // response.created - created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null}}` + created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}` created, _ = sjson.Set(created, "sequence_number", nextSeq()) created, _ = sjson.Set(created, "response.id", st.ResponseID) created, _ = sjson.Set(created, "response.created_at", st.Created) @@ -216,11 +216,11 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context, } // Append incremental text to reasoning buffer st.ReasoningBuf.WriteString(rc.String()) - msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}` + msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}` msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) msg, _ = sjson.Set(msg, "item_id", st.ReasoningID) msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex) - msg, _ = sjson.Set(msg, "text", rc.String()) + msg, _ = sjson.Set(msg, "delta", rc.String()) out = append(out, emitRespEvent("response.reasoning_summary_text.delta", msg)) } From 
453e744abfb4dbfab1d729ddb88e04815997e70e Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 21 Dec 2025 03:38:38 +0800 Subject: [PATCH 28/38] Fixed: #642 fix: remove unsupported fields `type` and `cache_control` across translators --- .../translator/antigravity/claude/antigravity_claude_request.go | 2 ++ .../translator/gemini-cli/claude/gemini-cli_claude_request.go | 2 ++ internal/translator/gemini/claude/gemini_claude_request.go | 2 ++ 3 files changed, 6 insertions(+) diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index def1cfbe..9216628e 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -211,6 +211,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) tool, _ = sjson.Delete(tool, "strict") tool, _ = sjson.Delete(tool, "input_examples") + tool, _ = sjson.Delete(tool, "type") + tool, _ = sjson.Delete(tool, "cache_control") toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool) toolDeclCount++ } diff --git a/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go b/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go index 913727ce..a74ced7c 100644 --- a/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go +++ b/internal/translator/gemini-cli/claude/gemini-cli_claude_request.go @@ -134,6 +134,8 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) [] tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) tool, _ = sjson.Delete(tool, "strict") tool, _ = sjson.Delete(tool, "input_examples") + tool, _ = sjson.Delete(tool, "type") + tool, _ = sjson.Delete(tool, "cache_control") var toolDeclaration any if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == 
nil { tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration) diff --git a/internal/translator/gemini/claude/gemini_claude_request.go b/internal/translator/gemini/claude/gemini_claude_request.go index f626a581..40f4fac2 100644 --- a/internal/translator/gemini/claude/gemini_claude_request.go +++ b/internal/translator/gemini/claude/gemini_claude_request.go @@ -127,6 +127,8 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool) tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) tool, _ = sjson.Delete(tool, "strict") tool, _ = sjson.Delete(tool, "input_examples") + tool, _ = sjson.Delete(tool, "type") + tool, _ = sjson.Delete(tool, "cache_control") var toolDeclaration any if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil { tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration) From cd0c94f48acc9b66f51e4c1710923c37c9d2255c Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sun, 21 Dec 2025 07:06:28 +0800 Subject: [PATCH 29/38] fix(sdk/auth): prevent OAuth manual prompt goroutine leak,Use timer-based manual prompt per provider and remove oauth_callback helper. 
--- sdk/auth/antigravity.go | 60 ++++++++++++++++++++++++++-------- sdk/auth/claude.go | 67 +++++++++++++++++++++++++++++--------- sdk/auth/codex.go | 67 +++++++++++++++++++++++++++++--------- sdk/auth/iflow.go | 56 ++++++++++++++++++++++++------- sdk/auth/oauth_callback.go | 41 ----------------------- 5 files changed, 193 insertions(+), 98 deletions(-) delete mode 100644 sdk/auth/oauth_callback.go diff --git a/sdk/auth/antigravity.go b/sdk/auth/antigravity.go index 832bd88e..ae22f772 100644 --- a/sdk/auth/antigravity.go +++ b/sdk/auth/antigravity.go @@ -99,20 +99,54 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o fmt.Println("Waiting for antigravity authentication callback...") var cbRes callbackResult - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "antigravity") - select { - case res := <-cbChan: - cbRes = res - case manual := <-manualCh: - cbRes = callbackResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, + timeoutTimer := time.NewTimer(5 * time.Minute) + defer timeoutTimer.Stop() + + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case res := <-cbChan: + cbRes = res + break waitForCallback + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case res := <-cbChan: + cbRes = res + break waitForCallback + default: + } + input, errPrompt := opts.Prompt("Paste the antigravity callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + cbRes = callbackResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + 
break waitForCallback + case <-timeoutTimer.C: + return nil, fmt.Errorf("antigravity: authentication timed out") } - case err = <-manualErrCh: - return nil, err - case <-time.After(5 * time.Minute): - return nil, fmt.Errorf("antigravity: authentication timed out") } if cbRes.Error != "" { diff --git a/sdk/auth/claude.go b/sdk/auth/claude.go index d88cdf29..c43b78cd 100644 --- a/sdk/auth/claude.go +++ b/sdk/auth/claude.go @@ -100,7 +100,6 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt callbackCh := make(chan *claude.OAuthResult, 1) callbackErrCh := make(chan error, 1) - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Claude") manualDescription := "" go func() { @@ -113,22 +112,58 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt }() var result *claude.OAuthResult - select { - case result = <-callbackCh: - case err = <-callbackErrCh: - if strings.Contains(err.Error(), "timeout") { - return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + } + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err) + } + return nil, err + default: + } + input, errPrompt := opts.Prompt("Paste the Claude callback URL (or press Enter to 
keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + manualDescription = parsed.ErrorDescription + result = &claude.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback } - return nil, err - case manual := <-manualCh: - manualDescription = manual.ErrorDescription - result = &claude.OAuthResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, - } - case err = <-manualErrCh: - return nil, err } if result.Error != "" { diff --git a/sdk/auth/codex.go b/sdk/auth/codex.go index b0a6b4a4..99992525 100644 --- a/sdk/auth/codex.go +++ b/sdk/auth/codex.go @@ -99,7 +99,6 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts callbackCh := make(chan *codex.OAuthResult, 1) callbackErrCh := make(chan error, 1) - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "Codex") manualDescription := "" go func() { @@ -112,22 +111,58 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts }() var result *codex.OAuthResult - select { - case result = <-callbackCh: - case err = <-callbackErrCh: - if strings.Contains(err.Error(), "timeout") { - return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer = time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + } + return nil, err + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + 
manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + if strings.Contains(err.Error(), "timeout") { + return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err) + } + return nil, err + default: + } + input, errPrompt := opts.Prompt("Paste the Codex callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + manualDescription = parsed.ErrorDescription + result = &codex.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback } - return nil, err - case manual := <-manualCh: - manualDescription = manual.ErrorDescription - result = &codex.OAuthResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, - } - case err = <-manualErrCh: - return nil, err } if result.Error != "" { diff --git a/sdk/auth/iflow.go b/sdk/auth/iflow.go index d7621a99..3fd82f1d 100644 --- a/sdk/auth/iflow.go +++ b/sdk/auth/iflow.go @@ -86,7 +86,6 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts callbackCh := make(chan *iflow.OAuthResult, 1) callbackErrCh := make(chan error, 1) - manualCh, manualErrCh := promptForOAuthCallback(opts.Prompt, "iFlow") go func() { result, errWait := oauthServer.WaitForCallback(5 * time.Minute) @@ -98,18 +97,51 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts }() var result *iflow.OAuthResult - select { - case result = <-callbackCh: - case err = <-callbackErrCh: - return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) - case manual := <-manualCh: - result = &iflow.OAuthResult{ - Code: manual.Code, - State: manual.State, - Error: manual.Error, + var manualPromptTimer *time.Timer + var manualPromptC <-chan time.Time + if opts.Prompt != nil { + manualPromptTimer 
= time.NewTimer(15 * time.Second) + manualPromptC = manualPromptTimer.C + defer manualPromptTimer.Stop() + } + +waitForCallback: + for { + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + case <-manualPromptC: + manualPromptC = nil + if manualPromptTimer != nil { + manualPromptTimer.Stop() + } + select { + case result = <-callbackCh: + break waitForCallback + case err = <-callbackErrCh: + return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err) + default: + } + input, errPrompt := opts.Prompt("Paste the iFlow callback URL (or press Enter to keep waiting): ") + if errPrompt != nil { + return nil, errPrompt + } + parsed, errParse := misc.ParseOAuthCallback(input) + if errParse != nil { + return nil, errParse + } + if parsed == nil { + continue + } + result = &iflow.OAuthResult{ + Code: parsed.Code, + State: parsed.State, + Error: parsed.Error, + } + break waitForCallback } - case err = <-manualErrCh: - return nil, err } if result.Error != "" { return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error) diff --git a/sdk/auth/oauth_callback.go b/sdk/auth/oauth_callback.go deleted file mode 100644 index 3f0ac925..00000000 --- a/sdk/auth/oauth_callback.go +++ /dev/null @@ -1,41 +0,0 @@ -package auth - -import ( - "fmt" - - "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" -) - -func promptForOAuthCallback(prompt func(string) (string, error), provider string) (<-chan *misc.OAuthCallback, <-chan error) { - if prompt == nil { - return nil, nil - } - - resultCh := make(chan *misc.OAuthCallback, 1) - errCh := make(chan error, 1) - - go func() { - label := provider - if label == "" { - label = "OAuth" - } - input, err := prompt(fmt.Sprintf("Paste the %s callback URL (or press Enter to keep waiting): ", label)) - if err != nil { - errCh <- err - return - } - - parsed, err := misc.ParseOAuthCallback(input) - if err != nil { - 
errCh <- err - return - } - if parsed == nil { - return - } - - resultCh <- parsed - }() - - return resultCh, errCh -} From 05d201ece84ef10817a0b92544cd1ca816adface Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sun, 21 Dec 2025 07:21:12 +0800 Subject: [PATCH 30/38] fix(gemini): gate callback prompt on project_id --- internal/cmd/login.go | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/internal/cmd/login.go b/internal/cmd/login.go index 0f079b4b..3bb0b9a5 100644 --- a/internal/cmd/login.go +++ b/internal/cmd/login.go @@ -58,14 +58,19 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { promptFn := options.Prompt if promptFn == nil { promptFn = defaultProjectPrompt() - options.Prompt = promptFn + } + + trimmedProjectID := strings.TrimSpace(projectID) + callbackPrompt := promptFn + if trimmedProjectID == "" { + callbackPrompt = nil } loginOpts := &sdkAuth.LoginOptions{ NoBrowser: options.NoBrowser, - ProjectID: strings.TrimSpace(projectID), + ProjectID: trimmedProjectID, Metadata: map[string]string{}, - Prompt: promptFn, + Prompt: callbackPrompt, } authenticator := sdkAuth.NewGeminiAuthenticator() @@ -84,7 +89,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { geminiAuth := gemini.NewGeminiAuth() httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{ NoBrowser: options.NoBrowser, - Prompt: promptFn, + Prompt: callbackPrompt, }) if errClient != nil { log.Errorf("Gemini authentication failed: %v", errClient) @@ -99,7 +104,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) { return } - selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn) + selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn) projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects) if errSelection != nil { log.Errorf("Invalid project 
selection: %v", errSelection) From 781bc1521b827b4e2c9f60f17a5d8b5a5f28e85c Mon Sep 17 00:00:00 2001 From: Supra4E8C Date: Sun, 21 Dec 2025 10:48:40 +0800 Subject: [PATCH 31/38] fix(oauth): prevent stale session timeouts after login - stop callback forwarders by instance to avoid cross-session shutdowns - clear pending sessions for a provider after successful auth --- .../api/handlers/management/auth_files.go | 66 ++++++++++++++++--- .../api/handlers/management/oauth_sessions.go | 25 +++++++ 2 files changed, 81 insertions(+), 10 deletions(-) diff --git a/internal/api/handlers/management/auth_files.go b/internal/api/handlers/management/auth_files.go index 4f42bd7a..41a4fde4 100644 --- a/internal/api/handlers/management/auth_files.go +++ b/internal/api/handlers/management/auth_files.go @@ -197,6 +197,19 @@ func stopCallbackForwarder(port int) { stopForwarderInstance(port, forwarder) } +func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) { + if forwarder == nil { + return + } + callbackForwardersMu.Lock() + if current := callbackForwarders[port]; current == forwarder { + delete(callbackForwarders, port) + } + callbackForwardersMu.Unlock() + + stopForwarderInstance(port, forwarder) +} + func stopForwarderInstance(port int, forwarder *callbackForwarder) { if forwarder == nil || forwarder.server == nil { return @@ -785,6 +798,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { RegisterOAuthSession(state, "anthropic") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/anthropic/callback") if errTarget != nil { @@ -792,7 +806,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = 
startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start anthropic callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -801,7 +816,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(anthropicCallbackPort) + defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder) } // Helper: wait for callback file @@ -809,6 +824,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { waitForFile := func(path string, timeout time.Duration) (map[string]string, error) { deadline := time.Now().Add(timeout) for { + if !IsOAuthSessionPending(state, "anthropic") { + return nil, errOAuthSessionNotPending + } if time.Now().After(deadline) { SetOAuthSessionError(state, "Timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback") @@ -828,6 +846,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { // Wait up to 5 minutes resultMap, errWait := waitForFile(waitFile, 5*time.Minute) if errWait != nil { + if errors.Is(errWait, errOAuthSessionNotPending) { + return + } authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait) log.Error(claude.GetUserFriendlyMessage(authErr)) return @@ -933,6 +954,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) { } fmt.Println("You can now use Claude services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("anthropic") }() c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) @@ -968,6 +990,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { RegisterOAuthSession(state, "gemini") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/google/callback") if errTarget != nil { @@ -975,7 +998,8 @@ func (h *Handler) 
RequestGeminiCLIToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start gemini callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -984,7 +1008,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(geminiCallbackPort) + defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder) } // Wait for callback file written by server route @@ -993,6 +1017,9 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var authCode string for { + if !IsOAuthSessionPending(state, "gemini") { + return + } if time.Now().After(deadline) { log.Error("oauth flow timed out") SetOAuthSessionError(state, "OAuth flow timed out") @@ -1168,6 +1195,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) { } CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("gemini") fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) }() @@ -1209,6 +1237,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { RegisterOAuthSession(state, "codex") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/codex/callback") if errTarget != nil { @@ -1216,7 +1245,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = 
startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start codex callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1225,7 +1255,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(codexCallbackPort) + defer stopCallbackForwarderInstance(codexCallbackPort, forwarder) } // Wait for callback file @@ -1233,6 +1263,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var code string for { + if !IsOAuthSessionPending(state, "codex") { + return + } if time.Now().After(deadline) { authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) log.Error(codex.GetUserFriendlyMessage(authErr)) @@ -1348,6 +1381,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) { } fmt.Println("You can now use Codex services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("codex") }() c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) @@ -1393,6 +1427,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { RegisterOAuthSession(state, "antigravity") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/antigravity/callback") if errTarget != nil { @@ -1400,7 +1435,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start antigravity callback forwarder") 
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) return @@ -1409,13 +1445,16 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { go func() { if isWebUI { - defer stopCallbackForwarder(antigravityCallbackPort) + defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder) } waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state)) deadline := time.Now().Add(5 * time.Minute) var authCode string for { + if !IsOAuthSessionPending(state, "antigravity") { + return + } if time.Now().After(deadline) { log.Error("oauth flow timed out") SetOAuthSessionError(state, "OAuth flow timed out") @@ -1578,6 +1617,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) { } CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("antigravity") fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) if projectID != "" { fmt.Printf("Using GCP project: %s\n", projectID) @@ -1655,6 +1695,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { RegisterOAuthSession(state, "iflow") isWebUI := isWebUIRequest(c) + var forwarder *callbackForwarder if isWebUI { targetURL, errTarget := h.managementCallbackURL("/iflow/callback") if errTarget != nil { @@ -1662,7 +1703,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"}) return } - if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { + var errStart error + if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { log.WithError(errStart).Error("failed to start iflow callback forwarder") c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"}) return @@ -1671,7 +1713,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { go func() { if isWebUI { - defer 
stopCallbackForwarder(iflowauth.CallbackPort) + defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder) } fmt.Println("Waiting for authentication...") @@ -1679,6 +1721,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { deadline := time.Now().Add(5 * time.Minute) var resultMap map[string]string for { + if !IsOAuthSessionPending(state, "iflow") { + return + } if time.Now().After(deadline) { SetOAuthSessionError(state, "Authentication failed") fmt.Println("Authentication failed: timeout waiting for callback") @@ -1745,6 +1790,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) { } fmt.Println("You can now use iFlow services through this CLI") CompleteOAuthSession(state) + CompleteOAuthSessionsByProvider("iflow") }() c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) diff --git a/internal/api/handlers/management/oauth_sessions.go b/internal/api/handlers/management/oauth_sessions.go index f23b608c..05ff8d1f 100644 --- a/internal/api/handlers/management/oauth_sessions.go +++ b/internal/api/handlers/management/oauth_sessions.go @@ -111,6 +111,27 @@ func (s *oauthSessionStore) Complete(state string) { delete(s.sessions, state) } +func (s *oauthSessionStore) CompleteProvider(provider string) int { + provider = strings.ToLower(strings.TrimSpace(provider)) + if provider == "" { + return 0 + } + now := time.Now() + + s.mu.Lock() + defer s.mu.Unlock() + + s.purgeExpiredLocked(now) + removed := 0 + for state, session := range s.sessions { + if strings.EqualFold(session.Provider, provider) { + delete(s.sessions, state) + removed++ + } + } + return removed +} + func (s *oauthSessionStore) Get(state string) (oauthSession, bool) { state = strings.TrimSpace(state) now := time.Now() @@ -153,6 +174,10 @@ func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, func CompleteOAuthSession(state string) { oauthSessions.Complete(state) } +func CompleteOAuthSessionsByProvider(provider string) int { + return 
oauthSessions.CompleteProvider(provider) +} + func GetOAuthSession(state string) (provider string, status string, ok bool) { session, ok := oauthSessions.Get(state) if !ok { From 3fc410a253ffd38fa751fd7e9120ce3ffb6af682 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sun, 21 Dec 2025 12:51:35 +0800 Subject: [PATCH 32/38] fix(amp): add /settings routes to proxy --- internal/api/modules/amp/routes.go | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/internal/api/modules/amp/routes.go b/internal/api/modules/amp/routes.go index 50900f24..a37c0a15 100644 --- a/internal/api/modules/amp/routes.go +++ b/internal/api/modules/amp/routes.go @@ -126,7 +126,7 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha var authWithBypass gin.HandlerFunc if auth != nil { ampAPI.Use(auth) - authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs") + authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings") } // Dynamic proxy handler that uses m.getProxy() for hot-reload support @@ -179,6 +179,8 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...) engine.GET("/docs", append(rootMiddleware, proxyHandler)...) engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...) + engine.GET("/settings", append(rootMiddleware, proxyHandler)...) + engine.GET("/settings/*path", append(rootMiddleware, proxyHandler)...) engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...) From 9f9a4fc2af1088d0f5d130a8448b598b2f6493fe Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Sun, 21 Dec 2025 14:48:50 +0900 Subject: [PATCH 33/38] Remove unused submodules Removes two obsolete git submodules to clean up repository state and reduce maintenance overhead. 
This eliminates external references that are no longer needed, simplifying dependency management and repository maintenance going forward. --- litellm | 1 - opencode-google-antigravity-auth | 1 - 2 files changed, 2 deletions(-) delete mode 160000 litellm delete mode 160000 opencode-google-antigravity-auth diff --git a/litellm b/litellm deleted file mode 160000 index 0c48826c..00000000 --- a/litellm +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 0c48826cdc14a30953f55173c1eecdbfc859952d diff --git a/opencode-google-antigravity-auth b/opencode-google-antigravity-auth deleted file mode 160000 index 9f9493c7..00000000 --- a/opencode-google-antigravity-auth +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 9f9493c730cbf0f17429e107a5bde00794752175 From 406a27271a39baed8e69eac83892ab2cc002d4d5 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Sun, 21 Dec 2025 14:54:49 +0900 Subject: [PATCH 34/38] Remove opencode-antigravity-auth submodule Remove the opencode-antigravity-auth submodule reference from the repository. Cleans up the project by eliminating an external submodule pointer that is no longer needed or maintained, reducing repository complexity and avoiding dangling submodule state. 
--- opencode-antigravity-auth | 1 - 1 file changed, 1 deletion(-) delete mode 160000 opencode-antigravity-auth diff --git a/opencode-antigravity-auth b/opencode-antigravity-auth deleted file mode 160000 index 261a91f2..00000000 --- a/opencode-antigravity-auth +++ /dev/null @@ -1 +0,0 @@ -Subproject commit 261a91f21bd3bc1660168eb2b82301a6cf372e58 From 1e9e4a86a29f94c668488b99795097cae630953e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Sun, 21 Dec 2025 15:15:50 +0900 Subject: [PATCH 35/38] Improve thinking/tool signature handling for Claude and Gemini requests MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Prefer cached signatures and avoid injecting dummy thinking blocks; instead remove unsigned thinking blocks and add a skip sentinel for tool calls without a valid signature. Generate stable session IDs from the first user message, apply schema cleaning only for Claude models, and reorder thinking parts so thinking appears first. For Gemini, remove thinking blocks and attach a skip sentinel to function calls. Simplify response handling by passing raw function args through (remove special Bash conversion). Update and add tests to reflect the new behavior. These changes prevent rejected dummy signatures, improve compatibility with Antigravity’s signature validation, provide more stable session IDs for conversation grouping, and make request/response translation more robust. 
--- .../runtime/executor/antigravity_executor.go | 51 +++++-- .../claude/antigravity_claude_request.go | 143 +++++++++++++++--- .../claude/antigravity_claude_request_test.go | 123 +++++++++++++-- .../claude/antigravity_claude_response.go | 44 +----- .../antigravity_claude_response_test.go | 83 +--------- .../gemini/antigravity_gemini_request.go | 28 +++- .../gemini/antigravity_gemini_request_test.go | 129 ++++++++++++++++ 7 files changed, 420 insertions(+), 181 deletions(-) create mode 100644 internal/translator/antigravity/gemini/antigravity_gemini_request_test.go diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index 6be5bf46..ddfcfc3f 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -7,6 +7,8 @@ import ( "bufio" "bytes" "context" + "crypto/sha256" + "encoding/binary" "encoding/json" "fmt" "io" @@ -70,6 +72,10 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au // Execute performs a non-streaming request to the Antigravity API. 
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + if strings.Contains(req.Model, "claude") { + return e.executeClaudeNonStream(ctx, auth, req, opts) + } + token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth) if errToken != nil { return resp, errToken @@ -993,23 +999,21 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau payload = geminiToAntigravity(modelName, payload, projectID) payload, _ = sjson.SetBytes(payload, "model", alias2ModelName(modelName)) - // Apply schema processing for all Antigravity models (Claude, Gemini, GPT-OSS) - // Antigravity uses unified Gemini-style format with same schema restrictions - strJSON := string(payload) + if strings.Contains(modelName, "claude") { + strJSON := string(payload) + paths := make([]string, 0) + util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths) + for _, p := range paths { + strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters") + } - // Rename parametersJsonSchema -> parameters (used by Claude translator) - paths := make([]string, 0) - util.Walk(gjson.ParseBytes(payload), "", "parametersJsonSchema", &paths) - for _, p := range paths { - strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters") + // Use the centralized schema cleaner to handle unsupported keywords, + // const->enum conversion, and flattening of types/anyOf. + strJSON = util.CleanJSONSchemaForAntigravity(strJSON) + + payload = []byte(strJSON) } - // Use the centralized schema cleaner to handle unsupported keywords, - // const->enum conversion, and flattening of types/anyOf. 
- strJSON = util.CleanJSONSchemaForAntigravity(strJSON) - - payload = []byte(strJSON) - httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload)) if errReq != nil { return nil, errReq @@ -1187,7 +1191,7 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b template, _ = sjson.Set(template, "project", generateProjectID()) } template, _ = sjson.Set(template, "requestId", generateRequestID()) - template, _ = sjson.Set(template, "request.sessionId", generateSessionID()) + template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload)) template, _ = sjson.Delete(template, "request.safetySettings") template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") @@ -1227,6 +1231,23 @@ func generateSessionID() string { return "-" + strconv.FormatInt(n, 10) } +func generateStableSessionID(payload []byte) string { + contents := gjson.GetBytes(payload, "request.contents") + if contents.IsArray() { + for _, content := range contents.Array() { + if content.Get("role").String() == "user" { + text := content.Get("parts.0.text").String() + if text != "" { + h := sha256.Sum256([]byte(text)) + n := int64(binary.BigEndian.Uint64(h[:8])) & 0x7FFFFFFFFFFFFFFF + return "-" + strconv.FormatInt(n, 10) + } + } + } + } + return generateSessionID() +} + func generateProjectID() string { adjectives := []string{"useful", "bright", "swift", "calm", "bold"} nouns := []string{"fuze", "wave", "spark", "flow", "core"} diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index fdfdf469..d5845e63 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -19,8 +19,6 @@ import ( "github.com/tidwall/sjson" ) -const geminiCLIClaudeThoughtSignature = 
"skip_thought_signature_validator" - // deriveSessionID generates a stable session ID from the request. // Uses the hash of the first user message to identify the conversation. func deriveSessionID(rawJSON []byte) string { @@ -93,6 +91,12 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ // contents contentsJSON := "[]" hasContents := false + + // Track if we need to disable thinking (LiteLLM approach) + // If the last assistant message with tool_use has no valid thinking block before it, + // we need to disable thinkingConfig to avoid "Expected thinking but found tool_use" error + lastAssistantHasToolWithoutThinking := false + messagesResult := gjson.GetBytes(rawJSON, "messages") if messagesResult.IsArray() { messageResults := messagesResult.Array() @@ -114,6 +118,7 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ if contentsResult.IsArray() { contentResults := contentsResult.Array() numContents := len(contentResults) + var currentMessageThinkingSignature string for j := 0; j < numContents; j++ { contentResult := contentResults[j] contentTypeResult := contentResult.Get("type") @@ -121,36 +126,46 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ // Use GetThinkingText to handle wrapped thinking objects thinkingText := util.GetThinkingText(contentResult) signatureResult := contentResult.Get("signature") - signature := "" - if signatureResult.Exists() && signatureResult.String() != "" { - signature = signatureResult.String() - } + clientSignature := "" + if signatureResult.Exists() && signatureResult.String() != "" { + clientSignature = signatureResult.String() + } - // Try to restore signature from cache for unsigned thinking blocks - if !cache.HasValidSignature(signature) && sessionID != "" && thinkingText != "" { - if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" { - signature = cachedSig - log.Debugf("Restored cached signature for 
thinking block") - } + // Always try cached signature first (more reliable than client-provided) + // Client may send stale or invalid signatures from different sessions + signature := "" + if sessionID != "" && thinkingText != "" { + if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" { + signature = cachedSig + log.Debugf("Using cached signature for thinking block") } + } + + // Fallback to client signature only if cache miss and client signature is valid + if signature == "" && cache.HasValidSignature(clientSignature) { + signature = clientSignature + log.Debugf("Using client-provided signature for thinking block") + } + + // Store for subsequent tool_use in the same message + if cache.HasValidSignature(signature) { + currentMessageThinkingSignature = signature + } + // Skip trailing unsigned thinking blocks on last assistant message - isLastMessage := (i == numMessages-1) - isLastContent := (j == numContents-1) - isAssistant := (originalRole == "assistant") isUnsigned := !cache.HasValidSignature(signature) - if isLastMessage && isLastContent && isAssistant && isUnsigned { - // Skip this trailing unsigned thinking block + // If unsigned, skip entirely (don't convert to text) + // Claude requires assistant messages to start with thinking blocks when thinking is enabled + // Converting to text would break this requirement + if isUnsigned { + // TypeScript plugin approach: drop unsigned thinking blocks entirely + log.Debugf("Dropping unsigned thinking block (no valid signature)") continue } - // Apply sentinel for unsigned thinking blocks that are not trailing - // (includes empty string and short/invalid signatures < 50 chars) - if isUnsigned { - signature = geminiCLIClaudeThoughtSignature - } - + // Valid signature, send as thought block partJSON := `{}` partJSON, _ = sjson.Set(partJSON, "thought", true) if thinkingText != "" { @@ -168,6 +183,10 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ } 
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON) } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { + // NOTE: Do NOT inject dummy thinking blocks here. + // Antigravity API validates signatures, so dummy values are rejected. + // The TypeScript plugin removes unsigned thinking blocks instead of injecting dummies. + functionName := contentResult.Get("name").String() functionArgs := contentResult.Get("input").String() functionID := contentResult.Get("id").String() @@ -175,9 +194,18 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ argsResult := gjson.Parse(functionArgs) if argsResult.IsObject() { partJSON := `{}` - if !strings.Contains(modelName, "claude") { - partJSON, _ = sjson.Set(partJSON, "thoughtSignature", geminiCLIClaudeThoughtSignature) + + // Use skip_thought_signature_validator for tool calls without valid thinking signature + // This is the approach used in opencode-google-antigravity-auth for Gemini + // and also works for Claude through Antigravity API + const skipSentinel = "skip_thought_signature_validator" + if cache.HasValidSignature(currentMessageThinkingSignature) { + partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature) + } else { + // No valid signature - use skip sentinel to bypass validation + partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel) } + if functionID != "" { partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID) } @@ -239,6 +267,64 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ } } } + + // Reorder parts for 'model' role to ensure thinking block is first + if role == "model" { + partsResult := gjson.Get(clientContentJSON, "parts") + if partsResult.IsArray() { + parts := partsResult.Array() + var thinkingParts []gjson.Result + var otherParts []gjson.Result + for _, part := range parts { + if part.Get("thought").Bool() { + thinkingParts 
= append(thinkingParts, part) + } else { + otherParts = append(otherParts, part) + } + } + if len(thinkingParts) > 0 { + firstPartIsThinking := parts[0].Get("thought").Bool() + if !firstPartIsThinking || len(thinkingParts) > 1 { + var newParts []interface{} + for _, p := range thinkingParts { + newParts = append(newParts, p.Value()) + } + for _, p := range otherParts { + newParts = append(newParts, p.Value()) + } + clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts) + } + } + } + } + + // Check if this assistant message has tool_use without valid thinking + if role == "model" { + partsResult := gjson.Get(clientContentJSON, "parts") + if partsResult.IsArray() { + parts := partsResult.Array() + hasValidThinking := false + hasToolUse := false + + for _, part := range parts { + if part.Get("thought").Bool() { + hasValidThinking = true + } + if part.Get("functionCall").Exists() { + hasToolUse = true + } + } + + // If this message has tool_use but no valid thinking, mark it + // This will be used to disable thinking mode if needed + if hasToolUse && !hasValidThinking { + lastAssistantHasToolWithoutThinking = true + } else { + lastAssistantHasToolWithoutThinking = false + } + } + } + contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON) hasContents = true } else if contentsResult.Type == gjson.String { @@ -333,6 +419,13 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num) } + // Note: We do NOT drop thinkingConfig here anymore. + // Instead, we: + // 1. Remove unsigned thinking blocks (done during message processing) + // 2. Add skip_thought_signature_validator to tool_use without valid thinking signature + // This approach keeps thinking mode enabled while handling the signature requirements. 
+ _ = lastAssistantHasToolWithoutThinking // Variable is tracked but not used to drop thinkingConfig + outBytes := []byte(out) outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings") diff --git a/internal/translator/antigravity/claude/antigravity_claude_request_test.go b/internal/translator/antigravity/claude/antigravity_claude_request_test.go index 796ce0d3..1d727c94 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request_test.go @@ -105,6 +105,7 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) { } func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) { + // Unsigned thinking blocks should be removed entirely (not converted to text) inputJSON := []byte(`{ "model": "claude-sonnet-4-5-thinking", "messages": [ @@ -121,11 +122,18 @@ func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *test output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) outputStr := string(output) - // Without signature, should use sentinel value - firstPart := gjson.Get(outputStr, "request.contents.0.parts.0") - if firstPart.Get("thoughtSignature").String() != geminiCLIClaudeThoughtSignature { - t.Errorf("Expected sentinel signature '%s', got '%s'", - geminiCLIClaudeThoughtSignature, firstPart.Get("thoughtSignature").String()) + // Without signature, thinking block should be removed (not converted to text) + parts := gjson.Get(outputStr, "request.contents.0.parts").Array() + if len(parts) != 1 { + t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts)) + } + + // Only text part should remain + if parts[0].Get("thought").Bool() { + t.Error("Thinking block should be removed, not preserved") + } + if parts[0].Get("text").String() != "Answer" { + t.Errorf("Expected text 'Answer', got '%s'", parts[0].Get("text").String()) } } @@ -192,10 +200,16 
@@ func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) { output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false) outputStr := string(output) - // Check function call conversion - funcCall := gjson.Get(outputStr, "request.contents.0.parts.0.functionCall") + // Now we expect only 1 part (tool_use), no dummy thinking block injected + parts := gjson.Get(outputStr, "request.contents.0.parts").Array() + if len(parts) != 1 { + t.Fatalf("Expected 1 part (tool only, no dummy injection), got %d", len(parts)) + } + + // Check function call conversion at parts[0] + funcCall := parts[0].Get("functionCall") if !funcCall.Exists() { - t.Error("functionCall should exist") + t.Error("functionCall should exist at parts[0]") } if funcCall.Get("name").String() != "get_weather" { t.Errorf("Expected function name 'get_weather', got '%s'", funcCall.Get("name").String()) @@ -203,6 +217,78 @@ func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) { if funcCall.Get("id").String() != "call_123" { t.Errorf("Expected function id 'call_123', got '%s'", funcCall.Get("id").String()) } + // Verify skip_thought_signature_validator is added (bypass for tools without valid thinking) + expectedSig := "skip_thought_signature_validator" + actualSig := parts[0].Get("thoughtSignature").String() + if actualSig != expectedSig { + t.Errorf("Expected thoughtSignature '%s', got '%s'", expectedSig, actualSig) + } +} + +func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) { + validSignature := "abc123validSignature1234567890123456789012345678901234567890" + inputJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"}, + { + "type": "tool_use", + "id": "call_123", + "name": "get_weather", + "input": "{\"location\": \"Paris\"}" + } + ] + } + ] + }`) + + output := 
ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) + outputStr := string(output) + + // Check function call has the signature from the preceding thinking block + part := gjson.Get(outputStr, "request.contents.0.parts.1") + if part.Get("functionCall.name").String() != "get_weather" { + t.Errorf("Expected functionCall, got %s", part.Raw) + } + if part.Get("thoughtSignature").String() != validSignature { + t.Errorf("Expected thoughtSignature '%s' on tool_use, got '%s'", validSignature, part.Get("thoughtSignature").String()) + } +} + +func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) { + // Case: text block followed by thinking block -> should be reordered to thinking first + validSignature := "abc123validSignature1234567890123456789012345678901234567890" + inputJSON := []byte(`{ + "model": "claude-sonnet-4-5-thinking", + "messages": [ + { + "role": "assistant", + "content": [ + {"type": "text", "text": "Here is the plan."}, + {"type": "thinking", "thinking": "Planning...", "signature": "` + validSignature + `"} + ] + } + ] + }`) + + output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) + outputStr := string(output) + + // Verify order: Thinking block MUST be first + parts := gjson.Get(outputStr, "request.contents.0.parts").Array() + if len(parts) != 2 { + t.Fatalf("Expected 2 parts, got %d", len(parts)) + } + + if !parts[0].Get("thought").Bool() { + t.Error("First part should be thinking block after reordering") + } + if parts[1].Get("text").String() != "Here is the plan." 
{ + t.Error("Second part should be text block") + } } func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) { @@ -402,8 +488,8 @@ func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testin } } -func TestConvertClaudeRequestToAntigravity_MiddleUnsignedThinking_SentinelApplied(t *testing.T) { - // Middle message has unsigned thinking - should use sentinel (existing behavior) +func TestConvertClaudeRequestToAntigravity_MiddleUnsignedThinking_Removed(t *testing.T) { + // Middle message has unsigned thinking - should be removed entirely inputJSON := []byte(`{ "model": "claude-sonnet-4-5-thinking", "messages": [ @@ -424,13 +510,18 @@ func TestConvertClaudeRequestToAntigravity_MiddleUnsignedThinking_SentinelApplie output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false) outputStr := string(output) - // Middle unsigned thinking should have sentinel applied - thinkingPart := gjson.Get(outputStr, "request.contents.0.parts.0") - if !thinkingPart.Get("thought").Bool() { - t.Error("Middle thinking block should be preserved with sentinel") + // Unsigned thinking should be removed entirely + parts := gjson.Get(outputStr, "request.contents.0.parts").Array() + if len(parts) != 1 { + t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts)) } - if thinkingPart.Get("thoughtSignature").String() != geminiCLIClaudeThoughtSignature { - t.Errorf("Middle unsigned thinking should use sentinel signature, got: %s", thinkingPart.Get("thoughtSignature").String()) + + // Only text part should remain + if parts[0].Get("thought").Bool() { + t.Error("Thinking block should be removed, not preserved") + } + if parts[0].Get("text").String() != "Answer" { + t.Errorf("Expected text 'Answer', got '%s'", parts[0].Get("text").String()) } } diff --git a/internal/translator/antigravity/claude/antigravity_claude_response.go b/internal/translator/antigravity/claude/antigravity_claude_response.go index 8f47b9bf..ddda5ddb 100644 --- 
a/internal/translator/antigravity/claude/antigravity_claude_response.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response.go @@ -253,13 +253,8 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq output = output + fmt.Sprintf("data: %s\n\n\n", data) if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() { - argsRaw := fcArgsResult.Raw - // Convert command → cmd for Bash tools using proper JSON parsing - if fcName == "Bash" || fcName == "bash" || fcName == "bash_20241022" { - argsRaw = convertBashCommandToCmdField(argsRaw) - } output = output + "event: content_block_delta\n" - data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", argsRaw) + data, _ = sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, params.ResponseIndex), "delta.partial_json", fcArgsResult.Raw) output = output + fmt.Sprintf("data: %s\n\n\n", data) } params.ResponseType = 3 @@ -347,36 +342,6 @@ func resolveStopReason(params *Params) string { return "end_turn" } -// convertBashCommandToCmdField converts "command" field to "cmd" field for Bash tools. -// Amp expects "cmd" but Gemini sends "command". This uses proper JSON parsing -// to avoid accidentally replacing "command" that appears in values. 
-func convertBashCommandToCmdField(argsRaw string) string { - // Only process valid JSON - if !gjson.Valid(argsRaw) { - return argsRaw - } - - // Check if "command" key exists and "cmd" doesn't - commandVal := gjson.Get(argsRaw, "command") - cmdVal := gjson.Get(argsRaw, "cmd") - - if commandVal.Exists() && !cmdVal.Exists() { - // Set "cmd" to the value of "command", preserve the raw value type - result, err := sjson.SetRaw(argsRaw, "cmd", commandVal.Raw) - if err != nil { - return argsRaw - } - // Delete "command" key - result, err = sjson.Delete(result, "command") - if err != nil { - return argsRaw - } - return result - } - - return argsRaw -} - // ConvertAntigravityResponseToClaudeNonStream converts a non-streaming Gemini CLI response to a non-streaming Claude response. // // Parameters: @@ -488,12 +453,7 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or toolBlock, _ = sjson.Set(toolBlock, "name", name) if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) { - argsRaw := args.Raw - // Convert command → cmd for Bash tools - if name == "Bash" || name == "bash" || name == "bash_20241022" { - argsRaw = convertBashCommandToCmdField(argsRaw) - } - toolBlock, _ = sjson.SetRaw(toolBlock, "input", argsRaw) + toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw) } ensureContentArray() diff --git a/internal/translator/antigravity/claude/antigravity_claude_response_test.go b/internal/translator/antigravity/claude/antigravity_claude_response_test.go index 4c2f31c1..afc3d937 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_response_test.go +++ b/internal/translator/antigravity/claude/antigravity_claude_response_test.go @@ -8,79 +8,6 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/cache" ) -func TestConvertBashCommandToCmdField(t *testing.T) { - tests := []struct { - name string - input string - expected string - }{ - { - name: "basic command to cmd conversion", - 
input: `{"command": "git diff"}`, - expected: `{"cmd":"git diff"}`, - }, - { - name: "already has cmd field - no change", - input: `{"cmd": "git diff"}`, - expected: `{"cmd": "git diff"}`, - }, - { - name: "both cmd and command - keep cmd only", - input: `{"command": "git diff", "cmd": "ls"}`, - expected: `{"command": "git diff", "cmd": "ls"}`, // no change when cmd exists - }, - { - name: "command with special characters in value", - input: `{"command": "echo \"command\": test"}`, - expected: `{"cmd":"echo \"command\": test"}`, - }, - { - name: "command with nested quotes", - input: `{"command": "bash -c 'echo \"hello\"'"}`, - expected: `{"cmd":"bash -c 'echo \"hello\"'"}`, - }, - { - name: "command with newlines", - input: `{"command": "echo hello\necho world"}`, - expected: `{"cmd":"echo hello\necho world"}`, - }, - { - name: "empty command value", - input: `{"command": ""}`, - expected: `{"cmd":""}`, - }, - { - name: "command with other fields - preserves them", - input: `{"command": "git diff", "timeout": 30}`, - expected: `{ "timeout": 30,"cmd":"git diff"}`, - }, - { - name: "invalid JSON - returns unchanged", - input: `{invalid json`, - expected: `{invalid json`, - }, - { - name: "empty object", - input: `{}`, - expected: `{}`, - }, - { - name: "no command field", - input: `{"restart": true}`, - expected: `{"restart": true}`, - }, - } - - for _, tt := range tests { - t.Run(tt.name, func(t *testing.T) { - result := convertBashCommandToCmdField(tt.input) - if result != tt.expected { - t.Errorf("convertBashCommandToCmdField(%q) = %q, want %q", tt.input, result, tt.expected) - } - }) - } -} - // ============================================================================ // Signature Caching Tests // ============================================================================ @@ -354,7 +281,7 @@ func TestDeriveSessionIDFromRequest(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { - result := deriveSessionIDFromRequest(tt.input) + 
result := deriveSessionID(tt.input) if tt.wantEmpty && result != "" { t.Errorf("Expected empty session ID, got '%s'", result) } @@ -368,8 +295,8 @@ func TestDeriveSessionIDFromRequest(t *testing.T) { func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) { input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`) - id1 := deriveSessionIDFromRequest(input) - id2 := deriveSessionIDFromRequest(input) + id1 := deriveSessionID(input) + id2 := deriveSessionID(input) if id1 != id2 { t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2) @@ -380,8 +307,8 @@ func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) { input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`) input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`) - id1 := deriveSessionIDFromRequest(input1) - id2 := deriveSessionIDFromRequest(input2) + id1 := deriveSessionID(input1) + id2 := deriveSessionID(input2) if id1 == id2 { t.Error("Different messages should produce different session IDs") diff --git a/internal/translator/antigravity/gemini/antigravity_gemini_request.go b/internal/translator/antigravity/gemini/antigravity_gemini_request.go index e694b790..394cc05b 100644 --- a/internal/translator/antigravity/gemini/antigravity_gemini_request.go +++ b/internal/translator/antigravity/gemini/antigravity_gemini_request.go @@ -98,16 +98,34 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) [] } } - gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool { + // Gemini-specific handling: add skip_thought_signature_validator to functionCall parts + // and remove thinking blocks entirely (Gemini doesn't need to preserve them) + const skipSentinel = "skip_thought_signature_validator" + + gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool { if content.Get("role").String() == "model" { - 
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool { + // First pass: collect indices of thinking parts to remove + var thinkingIndicesToRemove []int64 + content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool { + // Mark thinking blocks for removal + if part.Get("thought").Bool() { + thinkingIndicesToRemove = append(thinkingIndicesToRemove, partIdx.Int()) + } + // Add skip sentinel to functionCall parts if part.Get("functionCall").Exists() { - rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator") - } else if part.Get("thoughtSignature").Exists() { - rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator") + existingSig := part.Get("thoughtSignature").String() + if existingSig == "" || len(existingSig) < 50 { + rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel) + } } return true }) + + // Remove thinking blocks in reverse order to preserve indices + for i := len(thinkingIndicesToRemove) - 1; i >= 0; i-- { + idx := thinkingIndicesToRemove[i] + rawJSON, _ = sjson.DeleteBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d", contentIdx.Int(), idx)) + } } return true }) diff --git a/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go b/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go new file mode 100644 index 00000000..58cffd69 --- /dev/null +++ b/internal/translator/antigravity/gemini/antigravity_gemini_request_test.go @@ -0,0 +1,129 @@ +package gemini + +import ( + "fmt" + "testing" + + "github.com/tidwall/gjson" +) + +func TestConvertGeminiRequestToAntigravity_PreserveValidSignature(t *testing.T) { + // Valid signature on functionCall should be preserved + validSignature := 
"abc123validSignature1234567890123456789012345678901234567890" + inputJSON := []byte(fmt.Sprintf(`{ + "model": "gemini-3-pro-preview", + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "test_tool", "args": {}}, "thoughtSignature": "%s"} + ] + } + ] + }`, validSignature)) + + output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false) + outputStr := string(output) + + // Check that valid thoughtSignature is preserved + parts := gjson.Get(outputStr, "request.contents.0.parts").Array() + if len(parts) != 1 { + t.Fatalf("Expected 1 part, got %d", len(parts)) + } + + sig := parts[0].Get("thoughtSignature").String() + if sig != validSignature { + t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, sig) + } +} + +func TestConvertGeminiRequestToAntigravity_AddSkipSentinelToFunctionCall(t *testing.T) { + // functionCall without signature should get skip_thought_signature_validator + inputJSON := []byte(`{ + "model": "gemini-3-pro-preview", + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "test_tool", "args": {}}} + ] + } + ] + }`) + + output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false) + outputStr := string(output) + + // Check that skip_thought_signature_validator is added to functionCall + sig := gjson.Get(outputStr, "request.contents.0.parts.0.thoughtSignature").String() + expectedSig := "skip_thought_signature_validator" + if sig != expectedSig { + t.Errorf("Expected skip sentinel '%s', got '%s'", expectedSig, sig) + } +} + +func TestConvertGeminiRequestToAntigravity_RemoveThinkingBlocks(t *testing.T) { + // Thinking blocks should be removed entirely for Gemini + validSignature := "abc123validSignature1234567890123456789012345678901234567890" + inputJSON := []byte(fmt.Sprintf(`{ + "model": "gemini-3-pro-preview", + "contents": [ + { + "role": "model", + "parts": [ + {"thought": true, "text": "Thinking...", "thoughtSignature": "%s"}, + 
{"text": "Here is my response"} + ] + } + ] + }`, validSignature)) + + output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false) + outputStr := string(output) + + // Check that thinking block is removed + parts := gjson.Get(outputStr, "request.contents.0.parts").Array() + if len(parts) != 1 { + t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts)) + } + + // Only text part should remain + if parts[0].Get("thought").Bool() { + t.Error("Thinking block should be removed for Gemini") + } + if parts[0].Get("text").String() != "Here is my response" { + t.Errorf("Expected text 'Here is my response', got '%s'", parts[0].Get("text").String()) + } +} + +func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) { + // Multiple functionCalls should all get skip_thought_signature_validator + inputJSON := []byte(`{ + "model": "gemini-3-pro-preview", + "contents": [ + { + "role": "model", + "parts": [ + {"functionCall": {"name": "tool_one", "args": {"a": "1"}}}, + {"functionCall": {"name": "tool_two", "args": {"b": "2"}}} + ] + } + ] + }`) + + output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false) + outputStr := string(output) + + parts := gjson.Get(outputStr, "request.contents.0.parts").Array() + if len(parts) != 2 { + t.Fatalf("Expected 2 parts, got %d", len(parts)) + } + + expectedSig := "skip_thought_signature_validator" + for i, part := range parts { + sig := part.Get("thoughtSignature").String() + if sig != expectedSig { + t.Errorf("Part %d: Expected '%s', got '%s'", i, expectedSig, sig) + } + } +} From 4070c9de81c1378d4e488417eb46c1cb1f827cfb Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Sun, 21 Dec 2025 15:29:36 +0900 Subject: [PATCH 36/38] Remove interleaved-thinking header from requests Removes the addition of the "anthropic-beta: interleaved-thinking-2025-05-14" header for Claude thinking models when building HTTP requests. 
This prevents sending an experimental/feature flag header that is no longer required and avoids potential compatibility or routing issues with downstream services. Keeps request headers simpler and more standard. --- internal/runtime/executor/antigravity_executor.go | 6 ------ 1 file changed, 6 deletions(-) diff --git a/internal/runtime/executor/antigravity_executor.go b/internal/runtime/executor/antigravity_executor.go index ddfcfc3f..38723e77 100644 --- a/internal/runtime/executor/antigravity_executor.go +++ b/internal/runtime/executor/antigravity_executor.go @@ -1021,12 +1021,6 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau httpReq.Header.Set("Content-Type", "application/json") httpReq.Header.Set("Authorization", "Bearer "+token) httpReq.Header.Set("User-Agent", resolveUserAgent(auth)) - - // Add interleaved-thinking header for Claude thinking models - if util.IsClaudeThinkingModel(modelName) { - httpReq.Header.Set("anthropic-beta", "interleaved-thinking-2025-05-14") - } - if stream { httpReq.Header.Set("Accept", "text/event-stream") } else { From 7dc40ba6d4ebc019910951dc8f7a99953510cbef Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?=EC=9D=B4=EB=8C=80=ED=9D=AC?= Date: Sun, 21 Dec 2025 17:16:40 +0900 Subject: [PATCH 37/38] Improve tool-call parsing, schema sanitization, and hint injection Improve parsing of tool call inputs and Antigravity compatibility to avoid invalid thinking/tool_use errors. - Parse tool call inputs robustly by accepting both object and JSON-string formats and only produce a functionCall part when valid args exist, reducing spurious or malformed parts. - Preserve the skip_thought_signature_validator approach for calls without a valid thinking signature but stop toggling/tracking a separate "disable thinking" flag; this prevents unnecessary removal of thinkingConfig. - Sanitize tool input schemas before attaching them to the Antigravity request to improve compatibility. 
- Append the interleaved-thinking hint as a new parts entry instead of overwriting/setting text directly, preserving structure. - Remove unused tracking logic and related comments to simplify flow. These changes reduce errors related to missing/invalid thinking signatures, improve schema compatibility, and make hint injection safer and more consistent. --- .../claude/antigravity_claude_request.go | 106 +++++++----------- 1 file changed, 41 insertions(+), 65 deletions(-) diff --git a/internal/translator/antigravity/claude/antigravity_claude_request.go b/internal/translator/antigravity/claude/antigravity_claude_request.go index d5845e63..8a54e739 100644 --- a/internal/translator/antigravity/claude/antigravity_claude_request.go +++ b/internal/translator/antigravity/claude/antigravity_claude_request.go @@ -92,11 +92,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ contentsJSON := "[]" hasContents := false - // Track if we need to disable thinking (LiteLLM approach) - // If the last assistant message with tool_use has no valid thinking block before it, - // we need to disable thinkingConfig to avoid "Expected thinking but found tool_use" error - lastAssistantHasToolWithoutThinking := false - messagesResult := gjson.GetBytes(rawJSON, "messages") if messagesResult.IsArray() { messageResults := messagesResult.Array() @@ -188,32 +183,42 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ // The TypeScript plugin removes unsigned thinking blocks instead of injecting dummies. 
functionName := contentResult.Get("name").String() - functionArgs := contentResult.Get("input").String() + argsResult := contentResult.Get("input") functionID := contentResult.Get("id").String() - if gjson.Valid(functionArgs) { - argsResult := gjson.Parse(functionArgs) - if argsResult.IsObject() { - partJSON := `{}` - // Use skip_thought_signature_validator for tool calls without valid thinking signature - // This is the approach used in opencode-google-antigravity-auth for Gemini - // and also works for Claude through Antigravity API - const skipSentinel = "skip_thought_signature_validator" - if cache.HasValidSignature(currentMessageThinkingSignature) { - partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature) - } else { - // No valid signature - use skip sentinel to bypass validation - partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel) - } - - if functionID != "" { - partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID) - } - partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName) - partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsResult.Raw) - clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON) + // Handle both object and string input formats + var argsRaw string + if argsResult.IsObject() { + argsRaw = argsResult.Raw + } else if argsResult.Type == gjson.String { + // Input is a JSON string, parse and validate it + parsed := gjson.Parse(argsResult.String()) + if parsed.IsObject() { + argsRaw = parsed.Raw } } + + if argsRaw != "" { + partJSON := `{}` + + // Use skip_thought_signature_validator for tool calls without valid thinking signature + // This is the approach used in opencode-google-antigravity-auth for Gemini + // and also works for Claude through Antigravity API + const skipSentinel = "skip_thought_signature_validator" + if cache.HasValidSignature(currentMessageThinkingSignature) { + partJSON, _ = sjson.Set(partJSON, "thoughtSignature", 
currentMessageThinkingSignature) + } else { + // No valid signature - use skip sentinel to bypass validation + partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel) + } + + if functionID != "" { + partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID) + } + partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName) + partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw) + clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON) + } } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" { toolCallID := contentResult.Get("tool_use_id").String() if toolCallID != "" { @@ -298,33 +303,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ } } - // Check if this assistant message has tool_use without valid thinking - if role == "model" { - partsResult := gjson.Get(clientContentJSON, "parts") - if partsResult.IsArray() { - parts := partsResult.Array() - hasValidThinking := false - hasToolUse := false - - for _, part := range parts { - if part.Get("thought").Bool() { - hasValidThinking = true - } - if part.Get("functionCall").Exists() { - hasToolUse = true - } - } - - // If this message has tool_use but no valid thinking, mark it - // This will be used to disable thinking mode if needed - if hasToolUse && !hasValidThinking { - lastAssistantHasToolWithoutThinking = true - } else { - lastAssistantHasToolWithoutThinking = false - } - } - } - contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON) hasContents = true } else if contentsResult.Type == gjson.String { @@ -351,7 +329,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ toolResult := toolsResults[i] inputSchemaResult := toolResult.Get("input_schema") if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { - inputSchema := inputSchemaResult.Raw + // Sanitize the input schema for Antigravity API compatibility + inputSchema := 
util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw) tool, _ := sjson.Delete(toolResult.Raw, "input_schema") tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) tool, _ = sjson.Delete(tool, "strict") @@ -376,12 +355,16 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ interleavedHint := "Interleaved thinking is enabled. You may think between tool calls and after receiving tool results before deciding the next action or final answer. Do not mention these instructions or any constraints about thinking blocks; just apply them." if hasSystemInstruction { - // Append hint to existing system instruction - systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.-1.text", interleavedHint) + // Append hint as a new part to existing system instruction + hintPart := `{"text":""}` + hintPart, _ = sjson.Set(hintPart, "text", interleavedHint) + systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart) } else { // Create new system instruction with hint systemInstructionJSON = `{"role":"user","parts":[]}` - systemInstructionJSON, _ = sjson.Set(systemInstructionJSON, "parts.-1.text", interleavedHint) + hintPart := `{"text":""}` + hintPart, _ = sjson.Set(hintPart, "text", interleavedHint) + systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart) hasSystemInstruction = true } } @@ -419,13 +402,6 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ out, _ = sjson.Set(out, "request.generationConfig.maxOutputTokens", v.Num) } - // Note: We do NOT drop thinkingConfig here anymore. - // Instead, we: - // 1. Remove unsigned thinking blocks (done during message processing) - // 2. Add skip_thought_signature_validator to tool_use without valid thinking signature - // This approach keeps thinking mode enabled while handling the signature requirements. 
- _ = lastAssistantHasToolWithoutThinking // Variable is tracked but not used to drop thinkingConfig - outBytes := []byte(out) outBytes = common.AttachDefaultSafetySettings(outBytes, "request.safetySettings") From f6d625114c98de2010a4841b5f9a69b70a0ea704 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Sun, 21 Dec 2025 16:17:48 +0800 Subject: [PATCH 38/38] feat(logging): revamp request logger to support streaming and temporary file spooling This update enhances the `FileRequestLogger` by introducing support for spooling large request and response bodies to temporary files, reducing memory consumption. It adds atomic requestLogID generation for sequential log naming and new methods for non-streaming/streaming log assembly. Also includes better error handling during logging and temp file cleanups. --- internal/logging/request_logger.go | 494 +++++++++++++++++++++++------ 1 file changed, 403 insertions(+), 91 deletions(-) diff --git a/internal/logging/request_logger.go b/internal/logging/request_logger.go index f8c068c5..391f2869 100644 --- a/internal/logging/request_logger.go +++ b/internal/logging/request_logger.go @@ -14,6 +14,7 @@ import ( "regexp" "sort" "strings" + "sync/atomic" "time" "github.com/andybalholm/brotli" @@ -25,6 +26,8 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/util" ) +var requestLogID atomic.Uint64 + // RequestLogger defines the interface for logging HTTP requests and responses. // It provides methods for logging both regular and streaming HTTP request/response cycles. 
type RequestLogger interface { @@ -204,19 +207,52 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st } filePath := filepath.Join(l.logsDir, filename) - // Decompress response if needed - decompressedResponse, err := l.decompressResponse(responseHeaders, response) - if err != nil { - // If decompression fails, log the error but continue with original response - decompressedResponse = append(response, []byte(fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", err))...) + requestBodyPath, errTemp := l.writeRequestBodyTempFile(body) + if errTemp != nil { + log.WithError(errTemp).Warn("failed to create request body temp file, falling back to direct write") + } + if requestBodyPath != "" { + defer func() { + if errRemove := os.Remove(requestBodyPath); errRemove != nil { + log.WithError(errRemove).Warn("failed to remove request body temp file") + } + }() } - // Create log content - content := l.formatLogContent(url, method, requestHeaders, body, apiRequest, apiResponse, decompressedResponse, statusCode, responseHeaders, apiResponseErrors) + responseToWrite, decompressErr := l.decompressResponse(responseHeaders, response) + if decompressErr != nil { + // If decompression fails, continue with original response and annotate the log output. 
+ responseToWrite = response + } - // Write to file - if err = os.WriteFile(filePath, []byte(content), 0644); err != nil { - return fmt.Errorf("failed to write log file: %w", err) + logFile, errOpen := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if errOpen != nil { + return fmt.Errorf("failed to create log file: %w", errOpen) + } + + writeErr := l.writeNonStreamingLog( + logFile, + url, + method, + requestHeaders, + body, + requestBodyPath, + apiRequest, + apiResponse, + apiResponseErrors, + statusCode, + responseHeaders, + responseToWrite, + decompressErr, + ) + if errClose := logFile.Close(); errClose != nil { + log.WithError(errClose).Warn("failed to close request log file") + if writeErr == nil { + return errClose + } + } + if writeErr != nil { + return fmt.Errorf("failed to write log file: %w", writeErr) } if force && !l.enabled { @@ -253,26 +289,38 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[ filename := l.generateFilename(url) filePath := filepath.Join(l.logsDir, filename) - // Create and open file - file, err := os.Create(filePath) - if err != nil { - return nil, fmt.Errorf("failed to create log file: %w", err) + requestHeaders := make(map[string][]string, len(headers)) + for key, values := range headers { + headerValues := make([]string, len(values)) + copy(headerValues, values) + requestHeaders[key] = headerValues } - // Write initial request information - requestInfo := l.formatRequestInfo(url, method, headers, body) - if _, err = file.WriteString(requestInfo); err != nil { - _ = file.Close() - return nil, fmt.Errorf("failed to write request info: %w", err) + requestBodyPath, errTemp := l.writeRequestBodyTempFile(body) + if errTemp != nil { + return nil, fmt.Errorf("failed to create request body temp file: %w", errTemp) } + responseBodyFile, errCreate := os.CreateTemp(l.logsDir, "response-body-*.tmp") + if errCreate != nil { + _ = os.Remove(requestBodyPath) + return nil, fmt.Errorf("failed to 
create response body temp file: %w", errCreate) + } + responseBodyPath := responseBodyFile.Name() + // Create streaming writer writer := &FileStreamingLogWriter{ - file: file, - chunkChan: make(chan []byte, 100), // Buffered channel for async writes - closeChan: make(chan struct{}), - errorChan: make(chan error, 1), - bufferedChunks: &bytes.Buffer{}, + logFilePath: filePath, + url: url, + method: method, + timestamp: time.Now(), + requestHeaders: requestHeaders, + requestBodyPath: requestBodyPath, + responseBodyPath: responseBodyPath, + responseBodyFile: responseBodyFile, + chunkChan: make(chan []byte, 100), // Buffered channel for async writes + closeChan: make(chan struct{}), + errorChan: make(chan error, 1), } // Start async writer goroutine @@ -323,7 +371,9 @@ func (l *FileRequestLogger) generateFilename(url string) string { timestamp := time.Now().Format("2006-01-02T150405-.000000000") timestamp = strings.Replace(timestamp, ".", "", -1) - return fmt.Sprintf("%s-%s.log", sanitized, timestamp) + id := requestLogID.Add(1) + + return fmt.Sprintf("%s-%s-%d.log", sanitized, timestamp, id) } // sanitizeForFilename replaces characters that are not safe for filenames. 
@@ -405,6 +455,220 @@ func (l *FileRequestLogger) cleanupOldErrorLogs() error { return nil } +func (l *FileRequestLogger) writeRequestBodyTempFile(body []byte) (string, error) { + tmpFile, errCreate := os.CreateTemp(l.logsDir, "request-body-*.tmp") + if errCreate != nil { + return "", errCreate + } + tmpPath := tmpFile.Name() + + if _, errCopy := io.Copy(tmpFile, bytes.NewReader(body)); errCopy != nil { + _ = tmpFile.Close() + _ = os.Remove(tmpPath) + return "", errCopy + } + if errClose := tmpFile.Close(); errClose != nil { + _ = os.Remove(tmpPath) + return "", errClose + } + return tmpPath, nil +} + +func (l *FileRequestLogger) writeNonStreamingLog( + w io.Writer, + url, method string, + requestHeaders map[string][]string, + requestBody []byte, + requestBodyPath string, + apiRequest []byte, + apiResponse []byte, + apiResponseErrors []*interfaces.ErrorMessage, + statusCode int, + responseHeaders map[string][]string, + response []byte, + decompressErr error, +) error { + if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, time.Now()); errWrite != nil { + return errWrite + } + if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest); errWrite != nil { + return errWrite + } + if errWrite := writeAPIErrorResponses(w, apiResponseErrors); errWrite != nil { + return errWrite + } + if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse); errWrite != nil { + return errWrite + } + return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true) +} + +func writeRequestInfoWithBody( + w io.Writer, + url, method string, + headers map[string][]string, + body []byte, + bodyPath string, + timestamp time.Time, +) error { + if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil { + return errWrite + } + if _, errWrite := io.WriteString(w, fmt.Sprintf("Version: %s\n", buildinfo.Version)); 
errWrite != nil { + return errWrite + } + if _, errWrite := io.WriteString(w, fmt.Sprintf("URL: %s\n", url)); errWrite != nil { + return errWrite + } + if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil { + return errWrite + } + if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil { + return errWrite + } + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } + + if _, errWrite := io.WriteString(w, "=== HEADERS ===\n"); errWrite != nil { + return errWrite + } + for key, values := range headers { + for _, value := range values { + masked := util.MaskSensitiveHeaderValue(key, value) + if _, errWrite := io.WriteString(w, fmt.Sprintf("%s: %s\n", key, masked)); errWrite != nil { + return errWrite + } + } + } + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } + + if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil { + return errWrite + } + + if bodyPath != "" { + bodyFile, errOpen := os.Open(bodyPath) + if errOpen != nil { + return errOpen + } + if _, errCopy := io.Copy(w, bodyFile); errCopy != nil { + _ = bodyFile.Close() + return errCopy + } + if errClose := bodyFile.Close(); errClose != nil { + log.WithError(errClose).Warn("failed to close request body temp file") + } + } else if _, errWrite := w.Write(body); errWrite != nil { + return errWrite + } + + if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil { + return errWrite + } + return nil +} + +func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte) error { + if len(payload) == 0 { + return nil + } + + if bytes.HasPrefix(payload, []byte(sectionPrefix)) { + if _, errWrite := w.Write(payload); errWrite != nil { + return errWrite + } + if !bytes.HasSuffix(payload, []byte("\n")) { + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } + } + } else { + 
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil { + return errWrite + } + if _, errWrite := w.Write(payload); errWrite != nil { + return errWrite + } + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } + } + + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } + return nil +} + +func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMessage) error { + for i := 0; i < len(apiResponseErrors); i++ { + if apiResponseErrors[i] == nil { + continue + } + if _, errWrite := io.WriteString(w, "=== API ERROR RESPONSE ===\n"); errWrite != nil { + return errWrite + } + if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil { + return errWrite + } + if apiResponseErrors[i].Error != nil { + if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil { + return errWrite + } + } + if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil { + return errWrite + } + } + return nil +} + +func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, responseHeaders map[string][]string, responseReader io.Reader, decompressErr error, trailingNewline bool) error { + if _, errWrite := io.WriteString(w, "=== RESPONSE ===\n"); errWrite != nil { + return errWrite + } + if statusWritten { + if _, errWrite := io.WriteString(w, fmt.Sprintf("Status: %d\n", statusCode)); errWrite != nil { + return errWrite + } + } + + if responseHeaders != nil { + for key, values := range responseHeaders { + for _, value := range values { + if _, errWrite := io.WriteString(w, fmt.Sprintf("%s: %s\n", key, value)); errWrite != nil { + return errWrite + } + } + } + } + + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } + + if responseReader != nil { + if _, errCopy := io.Copy(w, responseReader); errCopy != nil { + return errCopy + } + } + if decompressErr != nil { + if 
_, errWrite := io.WriteString(w, fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", decompressErr)); errWrite != nil { + return errWrite + } + } + + if trailingNewline { + if _, errWrite := io.WriteString(w, "\n"); errWrite != nil { + return errWrite + } + } + return nil +} + // formatLogContent creates the complete log content for non-streaming requests. // // Parameters: @@ -648,13 +912,34 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st } // FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs. -// It handles asynchronous writing of streaming response chunks to a file. -// All data is buffered and written in the correct order when Close is called. +// It spools streaming response chunks to a temporary file to avoid retaining large responses in memory. +// The final log file is assembled when Close is called. type FileStreamingLogWriter struct { - // file is the file where log data is written. - file *os.File + // logFilePath is the final log file path. + logFilePath string - // chunkChan is a channel for receiving response chunks to buffer. + // url is the request URL (masked upstream in middleware). + url string + + // method is the HTTP method. + method string + + // timestamp is captured when the streaming log is initialized. + timestamp time.Time + + // requestHeaders stores the request headers. + requestHeaders map[string][]string + + // requestBodyPath is a temporary file path holding the request body. + requestBodyPath string + + // responseBodyPath is a temporary file path holding the streaming response body. + responseBodyPath string + + // responseBodyFile is the temp file where chunks are appended by the async writer. + responseBodyFile *os.File + + // chunkChan is a channel for receiving response chunks to spool. chunkChan chan []byte // closeChan is a channel for signaling when the writer is closed. 
@@ -663,9 +948,6 @@ type FileStreamingLogWriter struct { // errorChan is a channel for reporting errors during writing. errorChan chan error - // bufferedChunks stores the response chunks in order. - bufferedChunks *bytes.Buffer - // responseStatus stores the HTTP status code. responseStatus int @@ -770,85 +1052,115 @@ func (w *FileStreamingLogWriter) Close() error { close(w.chunkChan) } - // Wait for async writer to finish buffering chunks + // Wait for async writer to finish spooling chunks if w.closeChan != nil { <-w.closeChan w.chunkChan = nil } - if w.file == nil { + select { + case errWrite := <-w.errorChan: + w.cleanupTempFiles() + return errWrite + default: + } + + if w.logFilePath == "" { + w.cleanupTempFiles() return nil } - // Write all content in the correct order - var content strings.Builder - - // 1. Write API REQUEST section - if len(w.apiRequest) > 0 { - if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) { - content.Write(w.apiRequest) - if !bytes.HasSuffix(w.apiRequest, []byte("\n")) { - content.WriteString("\n") - } - } else { - content.WriteString("=== API REQUEST ===\n") - content.Write(w.apiRequest) - content.WriteString("\n") - } - content.WriteString("\n") + logFile, errOpen := os.OpenFile(w.logFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644) + if errOpen != nil { + w.cleanupTempFiles() + return fmt.Errorf("failed to create log file: %w", errOpen) } - // 2. Write API RESPONSE section - if len(w.apiResponse) > 0 { - if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) { - content.Write(w.apiResponse) - if !bytes.HasSuffix(w.apiResponse, []byte("\n")) { - content.WriteString("\n") - } - } else { - content.WriteString("=== API RESPONSE ===\n") - content.Write(w.apiResponse) - content.WriteString("\n") - } - content.WriteString("\n") - } - - // 3. 
Write RESPONSE section (status, headers, buffered chunks) - content.WriteString("=== RESPONSE ===\n") - if w.statusWritten { - content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus)) - } - - for key, values := range w.responseHeaders { - for _, value := range values { - content.WriteString(fmt.Sprintf("%s: %s\n", key, value)) + writeErr := w.writeFinalLog(logFile) + if errClose := logFile.Close(); errClose != nil { + log.WithError(errClose).Warn("failed to close request log file") + if writeErr == nil { + writeErr = errClose } } - content.WriteString("\n") - // Write buffered response body chunks - if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 { - content.Write(w.bufferedChunks.Bytes()) - } - - // Write the complete content to file - if _, err := w.file.WriteString(content.String()); err != nil { - _ = w.file.Close() - return err - } - - return w.file.Close() + w.cleanupTempFiles() + return writeErr } // asyncWriter runs in a goroutine to buffer chunks from the channel. -// It continuously reads chunks from the channel and buffers them for later writing. +// It continuously reads chunks from the channel and appends them to a temp file for later assembly. 
func (w *FileStreamingLogWriter) asyncWriter() { defer close(w.closeChan) for chunk := range w.chunkChan { - if w.bufferedChunks != nil { - w.bufferedChunks.Write(chunk) + if w.responseBodyFile == nil { + continue } + if _, errWrite := w.responseBodyFile.Write(chunk); errWrite != nil { + select { + case w.errorChan <- errWrite: + default: + } + if errClose := w.responseBodyFile.Close(); errClose != nil { + select { + case w.errorChan <- errClose: + default: + } + } + w.responseBodyFile = nil + } + } + + if w.responseBodyFile == nil { + return + } + if errClose := w.responseBodyFile.Close(); errClose != nil { + select { + case w.errorChan <- errClose: + default: + } + } + w.responseBodyFile = nil +} + +func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error { + if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil { + return errWrite + } + if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest); errWrite != nil { + return errWrite + } + if errWrite := writeAPISection(logFile, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse); errWrite != nil { + return errWrite + } + + responseBodyFile, errOpen := os.Open(w.responseBodyPath) + if errOpen != nil { + return errOpen + } + defer func() { + if errClose := responseBodyFile.Close(); errClose != nil { + log.WithError(errClose).Warn("failed to close response body temp file") + } + }() + + return writeResponseSection(logFile, w.responseStatus, w.statusWritten, w.responseHeaders, responseBodyFile, nil, false) +} + +func (w *FileStreamingLogWriter) cleanupTempFiles() { + if w.requestBodyPath != "" { + if errRemove := os.Remove(w.requestBodyPath); errRemove != nil { + log.WithError(errRemove).Warn("failed to remove request body temp file") + } + w.requestBodyPath = "" + } + + if w.responseBodyPath != "" { + if errRemove := os.Remove(w.responseBodyPath); errRemove != nil { + 
log.WithError(errRemove).Warn("failed to remove response body temp file") + } + w.responseBodyPath = "" } }