Merge branch 'main' into fix/antigravity-prompt-caching

This commit is contained in:
Evan Nguyen
2025-12-21 19:43:24 +07:00
65 changed files with 4665 additions and 681 deletions

View File

@@ -405,7 +405,7 @@ func main() {
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled) usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling) coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil { if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
log.Errorf("failed to configure log output: %v", err) log.Errorf("failed to configure log output: %v", err)
return return
} }

View File

@@ -42,6 +42,10 @@ debug: false
# When true, write application logs to rotating files instead of stdout # When true, write application logs to rotating files instead of stdout
logging-to-file: false logging-to-file: false
# Maximum total size (MB) of log files under the logs directory. When exceeded, the oldest log
# files are deleted until within the limit. Set to 0 to disable.
logs-max-total-size-mb: 0
# When false, disable in-memory usage statistics aggregation # When false, disable in-memory usage statistics aggregation
usage-statistics-enabled: false usage-statistics-enabled: false

View File

@@ -23,13 +23,13 @@ import (
"time" "time"
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api" "github.com/router-for-me/CLIProxyAPI/v6/sdk/api"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
) )

View File

@@ -36,10 +36,6 @@ import (
"golang.org/x/oauth2/google" "golang.org/x/oauth2/google"
) )
var (
oauthStatus = make(map[string]string)
)
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
const ( const (
@@ -201,6 +197,19 @@ func stopCallbackForwarder(port int) {
stopForwarderInstance(port, forwarder) stopForwarderInstance(port, forwarder)
} }
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
if forwarder == nil {
return
}
callbackForwardersMu.Lock()
if current := callbackForwarders[port]; current == forwarder {
delete(callbackForwarders, port)
}
callbackForwardersMu.Unlock()
stopForwarderInstance(port, forwarder)
}
func stopForwarderInstance(port int, forwarder *callbackForwarder) { func stopForwarderInstance(port int, forwarder *callbackForwarder) {
if forwarder == nil || forwarder.server == nil { if forwarder == nil || forwarder.server == nil {
return return
@@ -786,7 +795,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
return return
} }
RegisterOAuthSession(state, "anthropic")
isWebUI := isWebUIRequest(c) isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI { if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback") targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
if errTarget != nil { if errTarget != nil {
@@ -794,7 +806,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return return
} }
if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil { var errStart error
if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start anthropic callback forwarder") log.WithError(errStart).Error("failed to start anthropic callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return return
@@ -803,7 +816,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
go func() { go func() {
if isWebUI { if isWebUI {
defer stopCallbackForwarder(anthropicCallbackPort) defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder)
} }
// Helper: wait for callback file // Helper: wait for callback file
@@ -811,8 +824,11 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) { waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
deadline := time.Now().Add(timeout) deadline := time.Now().Add(timeout)
for { for {
if !IsOAuthSessionPending(state, "anthropic") {
return nil, errOAuthSessionNotPending
}
if time.Now().After(deadline) { if time.Now().After(deadline) {
oauthStatus[state] = "Timeout waiting for OAuth callback" SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
return nil, fmt.Errorf("timeout waiting for OAuth callback") return nil, fmt.Errorf("timeout waiting for OAuth callback")
} }
data, errRead := os.ReadFile(path) data, errRead := os.ReadFile(path)
@@ -830,6 +846,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
// Wait up to 5 minutes // Wait up to 5 minutes
resultMap, errWait := waitForFile(waitFile, 5*time.Minute) resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
if errWait != nil { if errWait != nil {
if errors.Is(errWait, errOAuthSessionNotPending) {
return
}
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait) authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
log.Error(claude.GetUserFriendlyMessage(authErr)) log.Error(claude.GetUserFriendlyMessage(authErr))
return return
@@ -837,13 +856,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
if errStr := resultMap["error"]; errStr != "" { if errStr := resultMap["error"]; errStr != "" {
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest) oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
log.Error(claude.GetUserFriendlyMessage(oauthErr)) log.Error(claude.GetUserFriendlyMessage(oauthErr))
oauthStatus[state] = "Bad request" SetOAuthSessionError(state, "Bad request")
return return
} }
if resultMap["state"] != state { if resultMap["state"] != state {
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"])) authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
log.Error(claude.GetUserFriendlyMessage(authErr)) log.Error(claude.GetUserFriendlyMessage(authErr))
oauthStatus[state] = "State code error" SetOAuthSessionError(state, "State code error")
return return
} }
@@ -876,7 +895,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
if errDo != nil { if errDo != nil {
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo) authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
oauthStatus[state] = "Failed to exchange authorization code for tokens" SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
return return
} }
defer func() { defer func() {
@@ -887,7 +906,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
respBody, _ := io.ReadAll(resp.Body) respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode) SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
return return
} }
var tResp struct { var tResp struct {
@@ -900,7 +919,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
} }
if errU := json.Unmarshal(respBody, &tResp); errU != nil { if errU := json.Unmarshal(respBody, &tResp); errU != nil {
log.Errorf("failed to parse token response: %v", errU) log.Errorf("failed to parse token response: %v", errU)
oauthStatus[state] = "Failed to parse token response" SetOAuthSessionError(state, "Failed to parse token response")
return return
} }
bundle := &claude.ClaudeAuthBundle{ bundle := &claude.ClaudeAuthBundle{
@@ -925,7 +944,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
savedPath, errSave := h.saveTokenRecord(ctx, record) savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil { if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave) log.Errorf("Failed to save authentication tokens: %v", errSave)
oauthStatus[state] = "Failed to save authentication tokens" SetOAuthSessionError(state, "Failed to save authentication tokens")
return return
} }
@@ -934,10 +953,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
fmt.Println("API key obtained and saved") fmt.Println("API key obtained and saved")
} }
fmt.Println("You can now use Claude services through this CLI") fmt.Println("You can now use Claude services through this CLI")
delete(oauthStatus, state) CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("anthropic")
}() }()
oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
} }
@@ -968,7 +987,10 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
state := fmt.Sprintf("gem-%d", time.Now().UnixNano()) state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
RegisterOAuthSession(state, "gemini")
isWebUI := isWebUIRequest(c) isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI { if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/google/callback") targetURL, errTarget := h.managementCallbackURL("/google/callback")
if errTarget != nil { if errTarget != nil {
@@ -976,7 +998,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return return
} }
if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil { var errStart error
if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start gemini callback forwarder") log.WithError(errStart).Error("failed to start gemini callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return return
@@ -985,7 +1008,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
go func() { go func() {
if isWebUI { if isWebUI {
defer stopCallbackForwarder(geminiCallbackPort) defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder)
} }
// Wait for callback file written by server route // Wait for callback file written by server route
@@ -994,9 +1017,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
deadline := time.Now().Add(5 * time.Minute) deadline := time.Now().Add(5 * time.Minute)
var authCode string var authCode string
for { for {
if !IsOAuthSessionPending(state, "gemini") {
return
}
if time.Now().After(deadline) { if time.Now().After(deadline) {
log.Error("oauth flow timed out") log.Error("oauth flow timed out")
oauthStatus[state] = "OAuth flow timed out" SetOAuthSessionError(state, "OAuth flow timed out")
return return
} }
if data, errR := os.ReadFile(waitFile); errR == nil { if data, errR := os.ReadFile(waitFile); errR == nil {
@@ -1005,13 +1031,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
_ = os.Remove(waitFile) _ = os.Remove(waitFile)
if errStr := m["error"]; errStr != "" { if errStr := m["error"]; errStr != "" {
log.Errorf("Authentication failed: %s", errStr) log.Errorf("Authentication failed: %s", errStr)
oauthStatus[state] = "Authentication failed" SetOAuthSessionError(state, "Authentication failed")
return return
} }
authCode = m["code"] authCode = m["code"]
if authCode == "" { if authCode == "" {
log.Errorf("Authentication failed: code not found") log.Errorf("Authentication failed: code not found")
oauthStatus[state] = "Authentication failed: code not found" SetOAuthSessionError(state, "Authentication failed: code not found")
return return
} }
break break
@@ -1023,7 +1049,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
token, err := conf.Exchange(ctx, authCode) token, err := conf.Exchange(ctx, authCode)
if err != nil { if err != nil {
log.Errorf("Failed to exchange token: %v", err) log.Errorf("Failed to exchange token: %v", err)
oauthStatus[state] = "Failed to exchange token" SetOAuthSessionError(state, "Failed to exchange token")
return return
} }
@@ -1034,7 +1060,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if errNewRequest != nil { if errNewRequest != nil {
log.Errorf("Could not get user info: %v", errNewRequest) log.Errorf("Could not get user info: %v", errNewRequest)
oauthStatus[state] = "Could not get user info" SetOAuthSessionError(state, "Could not get user info")
return return
} }
req.Header.Set("Content-Type", "application/json") req.Header.Set("Content-Type", "application/json")
@@ -1043,7 +1069,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
resp, errDo := authHTTPClient.Do(req) resp, errDo := authHTTPClient.Do(req)
if errDo != nil { if errDo != nil {
log.Errorf("Failed to execute request: %v", errDo) log.Errorf("Failed to execute request: %v", errDo)
oauthStatus[state] = "Failed to execute request" SetOAuthSessionError(state, "Failed to execute request")
return return
} }
defer func() { defer func() {
@@ -1055,7 +1081,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
if resp.StatusCode < 200 || resp.StatusCode >= 300 { if resp.StatusCode < 200 || resp.StatusCode >= 300 {
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes)) log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode) SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
return return
} }
@@ -1064,7 +1090,6 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
fmt.Printf("Authenticated user email: %s\n", email) fmt.Printf("Authenticated user email: %s\n", email)
} else { } else {
fmt.Println("Failed to get user email from token") fmt.Println("Failed to get user email from token")
oauthStatus[state] = "Failed to get user email from token"
} }
// Marshal/unmarshal oauth2.Token to generic map and enrich fields // Marshal/unmarshal oauth2.Token to generic map and enrich fields
@@ -1072,7 +1097,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
jsonData, _ := json.Marshal(token) jsonData, _ := json.Marshal(token)
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil { if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
log.Errorf("Failed to unmarshal token: %v", errUnmarshal) log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
oauthStatus[state] = "Failed to unmarshal token" SetOAuthSessionError(state, "Failed to unmarshal token")
return return
} }
@@ -1095,10 +1120,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
// Initialize authenticated HTTP client via GeminiAuth to honor proxy settings // Initialize authenticated HTTP client via GeminiAuth to honor proxy settings
gemAuth := geminiAuth.NewGeminiAuth() gemAuth := geminiAuth.NewGeminiAuth()
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true) gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{
NoBrowser: true,
})
if errGetClient != nil { if errGetClient != nil {
log.Errorf("failed to get authenticated client: %v", errGetClient) log.Errorf("failed to get authenticated client: %v", errGetClient)
oauthStatus[state] = "Failed to get authenticated client" SetOAuthSessionError(state, "Failed to get authenticated client")
return return
} }
fmt.Println("Authentication successful.") fmt.Println("Authentication successful.")
@@ -1108,12 +1135,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts) projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
if errAll != nil { if errAll != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll) log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
oauthStatus[state] = "Failed to complete Gemini CLI onboarding" SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
return return
} }
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil { if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify) log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
oauthStatus[state] = "Failed to verify Cloud AI API status" SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
return return
} }
ts.ProjectID = strings.Join(projects, ",") ts.ProjectID = strings.Join(projects, ",")
@@ -1121,26 +1148,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
} else { } else {
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil { if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure) log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
oauthStatus[state] = "Failed to complete Gemini CLI onboarding" SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
return return
} }
if strings.TrimSpace(ts.ProjectID) == "" { if strings.TrimSpace(ts.ProjectID) == "" {
log.Error("Onboarding did not return a project ID") log.Error("Onboarding did not return a project ID")
oauthStatus[state] = "Failed to resolve project ID" SetOAuthSessionError(state, "Failed to resolve project ID")
return return
} }
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID) isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
if errCheck != nil { if errCheck != nil {
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck) log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
oauthStatus[state] = "Failed to verify Cloud AI API status" SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
return return
} }
ts.Checked = isChecked ts.Checked = isChecked
if !isChecked { if !isChecked {
log.Error("Cloud AI API is not enabled for the selected project") log.Error("Cloud AI API is not enabled for the selected project")
oauthStatus[state] = "Cloud AI API not enabled" SetOAuthSessionError(state, "Cloud AI API not enabled")
return return
} }
} }
@@ -1163,15 +1190,15 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
savedPath, errSave := h.saveTokenRecord(ctx, record) savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil { if errSave != nil {
log.Errorf("Failed to save token to file: %v", errSave) log.Errorf("Failed to save token to file: %v", errSave)
oauthStatus[state] = "Failed to save token to file" SetOAuthSessionError(state, "Failed to save token to file")
return return
} }
delete(oauthStatus, state) CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("gemini")
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath) fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
}() }()
oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
} }
@@ -1207,7 +1234,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
return return
} }
RegisterOAuthSession(state, "codex")
isWebUI := isWebUIRequest(c) isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI { if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/codex/callback") targetURL, errTarget := h.managementCallbackURL("/codex/callback")
if errTarget != nil { if errTarget != nil {
@@ -1215,7 +1245,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return return
} }
if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil { var errStart error
if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start codex callback forwarder") log.WithError(errStart).Error("failed to start codex callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return return
@@ -1224,7 +1255,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
go func() { go func() {
if isWebUI { if isWebUI {
defer stopCallbackForwarder(codexCallbackPort) defer stopCallbackForwarderInstance(codexCallbackPort, forwarder)
} }
// Wait for callback file // Wait for callback file
@@ -1232,10 +1263,13 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
deadline := time.Now().Add(5 * time.Minute) deadline := time.Now().Add(5 * time.Minute)
var code string var code string
for { for {
if !IsOAuthSessionPending(state, "codex") {
return
}
if time.Now().After(deadline) { if time.Now().After(deadline) {
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback")) authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
log.Error(codex.GetUserFriendlyMessage(authErr)) log.Error(codex.GetUserFriendlyMessage(authErr))
oauthStatus[state] = "Timeout waiting for OAuth callback" SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
return return
} }
if data, errR := os.ReadFile(waitFile); errR == nil { if data, errR := os.ReadFile(waitFile); errR == nil {
@@ -1245,12 +1279,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
if errStr := m["error"]; errStr != "" { if errStr := m["error"]; errStr != "" {
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest) oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
log.Error(codex.GetUserFriendlyMessage(oauthErr)) log.Error(codex.GetUserFriendlyMessage(oauthErr))
oauthStatus[state] = "Bad Request" SetOAuthSessionError(state, "Bad Request")
return return
} }
if m["state"] != state { if m["state"] != state {
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"])) authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
oauthStatus[state] = "State code error" SetOAuthSessionError(state, "State code error")
log.Error(codex.GetUserFriendlyMessage(authErr)) log.Error(codex.GetUserFriendlyMessage(authErr))
return return
} }
@@ -1281,14 +1315,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
resp, errDo := httpClient.Do(req) resp, errDo := httpClient.Do(req)
if errDo != nil { if errDo != nil {
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo) authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
oauthStatus[state] = "Failed to exchange authorization code for tokens" SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr) log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
return return
} }
defer func() { _ = resp.Body.Close() }() defer func() { _ = resp.Body.Close() }()
respBody, _ := io.ReadAll(resp.Body) respBody, _ := io.ReadAll(resp.Body)
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK {
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode) SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody)) log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
return return
} }
@@ -1299,7 +1333,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
ExpiresIn int `json:"expires_in"` ExpiresIn int `json:"expires_in"`
} }
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil { if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
oauthStatus[state] = "Failed to parse token response" SetOAuthSessionError(state, "Failed to parse token response")
log.Errorf("failed to parse token response: %v", errU) log.Errorf("failed to parse token response: %v", errU)
return return
} }
@@ -1337,7 +1371,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
} }
savedPath, errSave := h.saveTokenRecord(ctx, record) savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil { if errSave != nil {
oauthStatus[state] = "Failed to save authentication tokens" SetOAuthSessionError(state, "Failed to save authentication tokens")
log.Errorf("Failed to save authentication tokens: %v", errSave) log.Errorf("Failed to save authentication tokens: %v", errSave)
return return
} }
@@ -1346,10 +1380,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
fmt.Println("API key obtained and saved") fmt.Println("API key obtained and saved")
} }
fmt.Println("You can now use Codex services through this CLI") fmt.Println("You can now use Codex services through this CLI")
delete(oauthStatus, state) CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("codex")
}() }()
oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
} }
@@ -1390,7 +1424,10 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
params.Set("state", state) params.Set("state", state)
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode() authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
RegisterOAuthSession(state, "antigravity")
isWebUI := isWebUIRequest(c) isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI { if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/antigravity/callback") targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
if errTarget != nil { if errTarget != nil {
@@ -1398,7 +1435,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
return return
} }
if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil { var errStart error
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start antigravity callback forwarder") log.WithError(errStart).Error("failed to start antigravity callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"}) c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
return return
@@ -1407,16 +1445,19 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
go func() { go func() {
if isWebUI { if isWebUI {
defer stopCallbackForwarder(antigravityCallbackPort) defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
} }
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state)) waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
deadline := time.Now().Add(5 * time.Minute) deadline := time.Now().Add(5 * time.Minute)
var authCode string var authCode string
for { for {
if !IsOAuthSessionPending(state, "antigravity") {
return
}
if time.Now().After(deadline) { if time.Now().After(deadline) {
log.Error("oauth flow timed out") log.Error("oauth flow timed out")
oauthStatus[state] = "OAuth flow timed out" SetOAuthSessionError(state, "OAuth flow timed out")
return return
} }
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil { if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
@@ -1425,18 +1466,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
_ = os.Remove(waitFile) _ = os.Remove(waitFile)
if errStr := strings.TrimSpace(payload["error"]); errStr != "" { if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
log.Errorf("Authentication failed: %s", errStr) log.Errorf("Authentication failed: %s", errStr)
oauthStatus[state] = "Authentication failed" SetOAuthSessionError(state, "Authentication failed")
return return
} }
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state { if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
log.Errorf("Authentication failed: state mismatch") log.Errorf("Authentication failed: state mismatch")
oauthStatus[state] = "Authentication failed: state mismatch" SetOAuthSessionError(state, "Authentication failed: state mismatch")
return return
} }
authCode = strings.TrimSpace(payload["code"]) authCode = strings.TrimSpace(payload["code"])
if authCode == "" { if authCode == "" {
log.Error("Authentication failed: code not found") log.Error("Authentication failed: code not found")
oauthStatus[state] = "Authentication failed: code not found" SetOAuthSessionError(state, "Authentication failed: code not found")
return return
} }
break break
@@ -1455,7 +1496,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode())) req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
if errNewRequest != nil { if errNewRequest != nil {
log.Errorf("Failed to build token request: %v", errNewRequest) log.Errorf("Failed to build token request: %v", errNewRequest)
oauthStatus[state] = "Failed to build token request" SetOAuthSessionError(state, "Failed to build token request")
return return
} }
req.Header.Set("Content-Type", "application/x-www-form-urlencoded") req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
@@ -1463,7 +1504,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
resp, errDo := httpClient.Do(req) resp, errDo := httpClient.Do(req)
if errDo != nil { if errDo != nil {
log.Errorf("Failed to execute token request: %v", errDo) log.Errorf("Failed to execute token request: %v", errDo)
oauthStatus[state] = "Failed to exchange token" SetOAuthSessionError(state, "Failed to exchange token")
return return
} }
defer func() { defer func() {
@@ -1475,7 +1516,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices { if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
bodyBytes, _ := io.ReadAll(resp.Body) bodyBytes, _ := io.ReadAll(resp.Body)
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes)) log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode) SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
return return
} }
@@ -1487,7 +1528,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
} }
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil { if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
log.Errorf("Failed to parse token response: %v", errDecode) log.Errorf("Failed to parse token response: %v", errDecode)
oauthStatus[state] = "Failed to parse token response" SetOAuthSessionError(state, "Failed to parse token response")
return return
} }
@@ -1496,7 +1537,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil) infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
if errInfoReq != nil { if errInfoReq != nil {
log.Errorf("Failed to build user info request: %v", errInfoReq) log.Errorf("Failed to build user info request: %v", errInfoReq)
oauthStatus[state] = "Failed to build user info request" SetOAuthSessionError(state, "Failed to build user info request")
return return
} }
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken) infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
@@ -1504,7 +1545,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
infoResp, errInfo := httpClient.Do(infoReq) infoResp, errInfo := httpClient.Do(infoReq)
if errInfo != nil { if errInfo != nil {
log.Errorf("Failed to execute user info request: %v", errInfo) log.Errorf("Failed to execute user info request: %v", errInfo)
oauthStatus[state] = "Failed to execute user info request" SetOAuthSessionError(state, "Failed to execute user info request")
return return
} }
defer func() { defer func() {
@@ -1523,7 +1564,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
} else { } else {
bodyBytes, _ := io.ReadAll(infoResp.Body) bodyBytes, _ := io.ReadAll(infoResp.Body)
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes)) log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode) SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
return return
} }
} }
@@ -1571,11 +1612,12 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
savedPath, errSave := h.saveTokenRecord(ctx, record) savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil { if errSave != nil {
log.Errorf("Failed to save token to file: %v", errSave) log.Errorf("Failed to save token to file: %v", errSave)
oauthStatus[state] = "Failed to save token to file" SetOAuthSessionError(state, "Failed to save token to file")
return return
} }
delete(oauthStatus, state) CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("antigravity")
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
if projectID != "" { if projectID != "" {
fmt.Printf("Using GCP project: %s\n", projectID) fmt.Printf("Using GCP project: %s\n", projectID)
@@ -1583,7 +1625,6 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
fmt.Println("You can now use Antigravity services through this CLI") fmt.Println("You can now use Antigravity services through this CLI")
}() }()
oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
} }
@@ -1605,11 +1646,13 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
} }
authURL := deviceFlow.VerificationURIComplete authURL := deviceFlow.VerificationURIComplete
RegisterOAuthSession(state, "qwen")
go func() { go func() {
fmt.Println("Waiting for authentication...") fmt.Println("Waiting for authentication...")
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier) tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
if errPollForToken != nil { if errPollForToken != nil {
oauthStatus[state] = "Authentication failed" SetOAuthSessionError(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", errPollForToken) fmt.Printf("Authentication failed: %v\n", errPollForToken)
return return
} }
@@ -1628,16 +1671,15 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
savedPath, errSave := h.saveTokenRecord(ctx, record) savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil { if errSave != nil {
log.Errorf("Failed to save authentication tokens: %v", errSave) log.Errorf("Failed to save authentication tokens: %v", errSave)
oauthStatus[state] = "Failed to save authentication tokens" SetOAuthSessionError(state, "Failed to save authentication tokens")
return return
} }
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath) fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
fmt.Println("You can now use Qwen services through this CLI") fmt.Println("You can now use Qwen services through this CLI")
delete(oauthStatus, state) CompleteOAuthSession(state)
}() }()
oauthStatus[state] = ""
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state}) c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
} }
@@ -1650,7 +1692,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
authSvc := iflowauth.NewIFlowAuth(h.cfg) authSvc := iflowauth.NewIFlowAuth(h.cfg)
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort) authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
RegisterOAuthSession(state, "iflow")
isWebUI := isWebUIRequest(c) isWebUI := isWebUIRequest(c)
var forwarder *callbackForwarder
if isWebUI { if isWebUI {
targetURL, errTarget := h.managementCallbackURL("/iflow/callback") targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
if errTarget != nil { if errTarget != nil {
@@ -1658,7 +1703,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"}) c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
return return
} }
if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil { var errStart error
if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
log.WithError(errStart).Error("failed to start iflow callback forwarder") log.WithError(errStart).Error("failed to start iflow callback forwarder")
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"}) c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
return return
@@ -1667,7 +1713,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
go func() { go func() {
if isWebUI { if isWebUI {
defer stopCallbackForwarder(iflowauth.CallbackPort) defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder)
} }
fmt.Println("Waiting for authentication...") fmt.Println("Waiting for authentication...")
@@ -1675,8 +1721,11 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
deadline := time.Now().Add(5 * time.Minute) deadline := time.Now().Add(5 * time.Minute)
var resultMap map[string]string var resultMap map[string]string
for { for {
if !IsOAuthSessionPending(state, "iflow") {
return
}
if time.Now().After(deadline) { if time.Now().After(deadline) {
oauthStatus[state] = "Authentication failed" SetOAuthSessionError(state, "Authentication failed")
fmt.Println("Authentication failed: timeout waiting for callback") fmt.Println("Authentication failed: timeout waiting for callback")
return return
} }
@@ -1689,26 +1738,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
} }
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" { if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
oauthStatus[state] = "Authentication failed" SetOAuthSessionError(state, "Authentication failed")
fmt.Printf("Authentication failed: %s\n", errStr) fmt.Printf("Authentication failed: %s\n", errStr)
return return
} }
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state { if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
oauthStatus[state] = "Authentication failed" SetOAuthSessionError(state, "Authentication failed")
fmt.Println("Authentication failed: state mismatch") fmt.Println("Authentication failed: state mismatch")
return return
} }
code := strings.TrimSpace(resultMap["code"]) code := strings.TrimSpace(resultMap["code"])
if code == "" { if code == "" {
oauthStatus[state] = "Authentication failed" SetOAuthSessionError(state, "Authentication failed")
fmt.Println("Authentication failed: code missing") fmt.Println("Authentication failed: code missing")
return return
} }
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI) tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
if errExchange != nil { if errExchange != nil {
oauthStatus[state] = "Authentication failed" SetOAuthSessionError(state, "Authentication failed")
fmt.Printf("Authentication failed: %v\n", errExchange) fmt.Printf("Authentication failed: %v\n", errExchange)
return return
} }
@@ -1730,7 +1779,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
savedPath, errSave := h.saveTokenRecord(ctx, record) savedPath, errSave := h.saveTokenRecord(ctx, record)
if errSave != nil { if errSave != nil {
oauthStatus[state] = "Failed to save authentication tokens" SetOAuthSessionError(state, "Failed to save authentication tokens")
log.Errorf("Failed to save authentication tokens: %v", errSave) log.Errorf("Failed to save authentication tokens: %v", errSave)
return return
} }
@@ -1740,10 +1789,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
fmt.Println("API key obtained and saved") fmt.Println("API key obtained and saved")
} }
fmt.Println("You can now use iFlow services through this CLI") fmt.Println("You can now use iFlow services through this CLI")
delete(oauthStatus, state) CompleteOAuthSession(state)
CompleteOAuthSessionsByProvider("iflow")
}() }()
oauthStatus[state] = ""
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state}) c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
} }
@@ -2179,16 +2228,24 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
} }
func (h *Handler) GetAuthStatus(c *gin.Context) { func (h *Handler) GetAuthStatus(c *gin.Context) {
state := c.Query("state") state := strings.TrimSpace(c.Query("state"))
if err, ok := oauthStatus[state]; ok { if state == "" {
if err != "" { c.JSON(http.StatusOK, gin.H{"status": "ok"})
c.JSON(200, gin.H{"status": "error", "error": err}) return
} else {
c.JSON(200, gin.H{"status": "wait"})
return
}
} else {
c.JSON(200, gin.H{"status": "ok"})
} }
delete(oauthStatus, state) if err := ValidateOAuthState(state); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
return
}
_, status, ok := GetOAuthSession(state)
if !ok {
c.JSON(http.StatusOK, gin.H{"status": "ok"})
return
}
if status != "" {
c.JSON(http.StatusOK, gin.H{"status": "error", "error": status})
return
}
c.JSON(http.StatusOK, gin.H{"status": "wait"})
} }

View File

@@ -145,71 +145,74 @@ func (h *Handler) PutGeminiKeys(c *gin.Context) {
h.persist(c) h.persist(c)
} }
func (h *Handler) PatchGeminiKey(c *gin.Context) { func (h *Handler) PatchGeminiKey(c *gin.Context) {
type geminiKeyPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
}
var body struct { var body struct {
Index *int `json:"index"` Index *int `json:"index"`
Match *string `json:"match"` Match *string `json:"match"`
Value *config.GeminiKey `json:"value"` Value *geminiKeyPatch `json:"value"`
} }
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"}) c.JSON(400, gin.H{"error": "invalid body"})
return return
} }
value := *body.Value targetIndex := -1
value.APIKey = strings.TrimSpace(value.APIKey) if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
value.BaseURL = strings.TrimSpace(value.BaseURL) targetIndex = *body.Index
value.ProxyURL = strings.TrimSpace(value.ProxyURL) }
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) if targetIndex == -1 && body.Match != nil {
if value.APIKey == "" { match := strings.TrimSpace(*body.Match)
// Treat empty API key as delete. if match != "" {
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { for i := range h.cfg.GeminiKey {
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:*body.Index], h.cfg.GeminiKey[*body.Index+1:]...) if h.cfg.GeminiKey[i].APIKey == match {
h.cfg.SanitizeGeminiKeys() targetIndex = i
h.persist(c) break
return
}
if body.Match != nil {
match := strings.TrimSpace(*body.Match)
if match != "" {
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
removed := false
for i := range h.cfg.GeminiKey {
if !removed && h.cfg.GeminiKey[i].APIKey == match {
removed = true
continue
}
out = append(out, h.cfg.GeminiKey[i])
}
if removed {
h.cfg.GeminiKey = out
h.cfg.SanitizeGeminiKeys()
h.persist(c)
return
} }
} }
} }
}
if targetIndex == -1 {
c.JSON(404, gin.H{"error": "item not found"}) c.JSON(404, gin.H{"error": "item not found"})
return return
} }
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) { entry := h.cfg.GeminiKey[targetIndex]
h.cfg.GeminiKey[*body.Index] = value if body.Value.APIKey != nil {
h.cfg.SanitizeGeminiKeys() trimmed := strings.TrimSpace(*body.Value.APIKey)
h.persist(c) if trimmed == "" {
return h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...)
} h.cfg.SanitizeGeminiKeys()
if body.Match != nil { h.persist(c)
match := strings.TrimSpace(*body.Match) return
for i := range h.cfg.GeminiKey {
if h.cfg.GeminiKey[i].APIKey == match {
h.cfg.GeminiKey[i] = value
h.cfg.SanitizeGeminiKeys()
h.persist(c)
return
}
} }
entry.APIKey = trimmed
} }
c.JSON(404, gin.H{"error": "item not found"}) if body.Value.Prefix != nil {
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
}
if body.Value.BaseURL != nil {
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
}
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
}
h.cfg.GeminiKey[targetIndex] = entry
h.cfg.SanitizeGeminiKeys()
h.persist(c)
} }
func (h *Handler) DeleteGeminiKey(c *gin.Context) { func (h *Handler) DeleteGeminiKey(c *gin.Context) {
if val := strings.TrimSpace(c.Query("api-key")); val != "" { if val := strings.TrimSpace(c.Query("api-key")); val != "" {
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey)) out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
@@ -268,35 +271,70 @@ func (h *Handler) PutClaudeKeys(c *gin.Context) {
h.persist(c) h.persist(c)
} }
func (h *Handler) PatchClaudeKey(c *gin.Context) { func (h *Handler) PatchClaudeKey(c *gin.Context) {
type claudeKeyPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Models *[]config.ClaudeModel `json:"models"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
}
var body struct { var body struct {
Index *int `json:"index"` Index *int `json:"index"`
Match *string `json:"match"` Match *string `json:"match"`
Value *config.ClaudeKey `json:"value"` Value *claudeKeyPatch `json:"value"`
} }
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"}) c.JSON(400, gin.H{"error": "invalid body"})
return return
} }
value := *body.Value targetIndex := -1
normalizeClaudeKey(&value)
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) { if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) {
h.cfg.ClaudeKey[*body.Index] = value targetIndex = *body.Index
h.cfg.SanitizeClaudeKeys()
h.persist(c)
return
} }
if body.Match != nil { if targetIndex == -1 && body.Match != nil {
match := strings.TrimSpace(*body.Match)
for i := range h.cfg.ClaudeKey { for i := range h.cfg.ClaudeKey {
if h.cfg.ClaudeKey[i].APIKey == *body.Match { if h.cfg.ClaudeKey[i].APIKey == match {
h.cfg.ClaudeKey[i] = value targetIndex = i
h.cfg.SanitizeClaudeKeys() break
h.persist(c)
return
} }
} }
} }
c.JSON(404, gin.H{"error": "item not found"}) if targetIndex == -1 {
c.JSON(404, gin.H{"error": "item not found"})
return
}
entry := h.cfg.ClaudeKey[targetIndex]
if body.Value.APIKey != nil {
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
}
if body.Value.Prefix != nil {
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
}
if body.Value.BaseURL != nil {
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
}
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Models != nil {
entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
}
normalizeClaudeKey(&entry)
h.cfg.ClaudeKey[targetIndex] = entry
h.cfg.SanitizeClaudeKeys()
h.persist(c)
} }
func (h *Handler) DeleteClaudeKey(c *gin.Context) { func (h *Handler) DeleteClaudeKey(c *gin.Context) {
if val := c.Query("api-key"); val != "" { if val := c.Query("api-key"); val != "" {
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey)) out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
@@ -356,62 +394,73 @@ func (h *Handler) PutOpenAICompat(c *gin.Context) {
h.persist(c) h.persist(c)
} }
func (h *Handler) PatchOpenAICompat(c *gin.Context) { func (h *Handler) PatchOpenAICompat(c *gin.Context) {
type openAICompatPatch struct {
Name *string `json:"name"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"`
Models *[]config.OpenAICompatibilityModel `json:"models"`
Headers *map[string]string `json:"headers"`
}
var body struct { var body struct {
Name *string `json:"name"` Name *string `json:"name"`
Index *int `json:"index"` Index *int `json:"index"`
Value *config.OpenAICompatibility `json:"value"` Value *openAICompatPatch `json:"value"`
} }
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"}) c.JSON(400, gin.H{"error": "invalid body"})
return return
} }
normalizeOpenAICompatibilityEntry(body.Value) targetIndex := -1
// If base-url becomes empty, delete the provider instead of updating if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
if strings.TrimSpace(body.Value.BaseURL) == "" { targetIndex = *body.Index
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { }
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:*body.Index], h.cfg.OpenAICompatibility[*body.Index+1:]...) if targetIndex == -1 && body.Name != nil {
match := strings.TrimSpace(*body.Name)
for i := range h.cfg.OpenAICompatibility {
if h.cfg.OpenAICompatibility[i].Name == match {
targetIndex = i
break
}
}
}
if targetIndex == -1 {
c.JSON(404, gin.H{"error": "item not found"})
return
}
entry := h.cfg.OpenAICompatibility[targetIndex]
if body.Value.Name != nil {
entry.Name = strings.TrimSpace(*body.Value.Name)
}
if body.Value.Prefix != nil {
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
}
if body.Value.BaseURL != nil {
trimmed := strings.TrimSpace(*body.Value.BaseURL)
if trimmed == "" {
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...)
h.cfg.SanitizeOpenAICompatibility() h.cfg.SanitizeOpenAICompatibility()
h.persist(c) h.persist(c)
return return
} }
if body.Name != nil { entry.BaseURL = trimmed
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
removed := false
for i := range h.cfg.OpenAICompatibility {
if !removed && h.cfg.OpenAICompatibility[i].Name == *body.Name {
removed = true
continue
}
out = append(out, h.cfg.OpenAICompatibility[i])
}
if removed {
h.cfg.OpenAICompatibility = out
h.cfg.SanitizeOpenAICompatibility()
h.persist(c)
return
}
}
c.JSON(404, gin.H{"error": "item not found"})
return
} }
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) { if body.Value.APIKeyEntries != nil {
h.cfg.OpenAICompatibility[*body.Index] = *body.Value entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...)
h.cfg.SanitizeOpenAICompatibility()
h.persist(c)
return
} }
if body.Name != nil { if body.Value.Models != nil {
for i := range h.cfg.OpenAICompatibility { entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...)
if h.cfg.OpenAICompatibility[i].Name == *body.Name {
h.cfg.OpenAICompatibility[i] = *body.Value
h.cfg.SanitizeOpenAICompatibility()
h.persist(c)
return
}
}
} }
c.JSON(404, gin.H{"error": "item not found"}) if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
normalizeOpenAICompatibilityEntry(&entry)
h.cfg.OpenAICompatibility[targetIndex] = entry
h.cfg.SanitizeOpenAICompatibility()
h.persist(c)
} }
func (h *Handler) DeleteOpenAICompat(c *gin.Context) { func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
if name := c.Query("name"); name != "" { if name := c.Query("name"); name != "" {
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility)) out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
@@ -563,66 +612,72 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
h.persist(c) h.persist(c)
} }
func (h *Handler) PatchCodexKey(c *gin.Context) { func (h *Handler) PatchCodexKey(c *gin.Context) {
type codexKeyPatch struct {
APIKey *string `json:"api-key"`
Prefix *string `json:"prefix"`
BaseURL *string `json:"base-url"`
ProxyURL *string `json:"proxy-url"`
Headers *map[string]string `json:"headers"`
ExcludedModels *[]string `json:"excluded-models"`
}
var body struct { var body struct {
Index *int `json:"index"` Index *int `json:"index"`
Match *string `json:"match"` Match *string `json:"match"`
Value *config.CodexKey `json:"value"` Value *codexKeyPatch `json:"value"`
} }
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil { if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
c.JSON(400, gin.H{"error": "invalid body"}) c.JSON(400, gin.H{"error": "invalid body"})
return return
} }
value := *body.Value targetIndex := -1
value.APIKey = strings.TrimSpace(value.APIKey) if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
value.BaseURL = strings.TrimSpace(value.BaseURL) targetIndex = *body.Index
value.ProxyURL = strings.TrimSpace(value.ProxyURL) }
value.Headers = config.NormalizeHeaders(value.Headers) if targetIndex == -1 && body.Match != nil {
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels) match := strings.TrimSpace(*body.Match)
// If base-url becomes empty, delete instead of update for i := range h.cfg.CodexKey {
if value.BaseURL == "" { if h.cfg.CodexKey[i].APIKey == match {
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) { targetIndex = i
h.cfg.CodexKey = append(h.cfg.CodexKey[:*body.Index], h.cfg.CodexKey[*body.Index+1:]...) break
h.cfg.SanitizeCodexKeys()
h.persist(c)
return
}
if body.Match != nil {
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
removed := false
for i := range h.cfg.CodexKey {
if !removed && h.cfg.CodexKey[i].APIKey == *body.Match {
removed = true
continue
}
out = append(out, h.cfg.CodexKey[i])
}
if removed {
h.cfg.CodexKey = out
h.cfg.SanitizeCodexKeys()
h.persist(c)
return
}
}
} else {
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
h.cfg.CodexKey[*body.Index] = value
h.cfg.SanitizeCodexKeys()
h.persist(c)
return
}
if body.Match != nil {
for i := range h.cfg.CodexKey {
if h.cfg.CodexKey[i].APIKey == *body.Match {
h.cfg.CodexKey[i] = value
h.cfg.SanitizeCodexKeys()
h.persist(c)
return
}
} }
} }
} }
c.JSON(404, gin.H{"error": "item not found"}) if targetIndex == -1 {
c.JSON(404, gin.H{"error": "item not found"})
return
}
entry := h.cfg.CodexKey[targetIndex]
if body.Value.APIKey != nil {
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
}
if body.Value.Prefix != nil {
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
}
if body.Value.BaseURL != nil {
trimmed := strings.TrimSpace(*body.Value.BaseURL)
if trimmed == "" {
h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...)
h.cfg.SanitizeCodexKeys()
h.persist(c)
return
}
entry.BaseURL = trimmed
}
if body.Value.ProxyURL != nil {
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
}
if body.Value.Headers != nil {
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
}
if body.Value.ExcludedModels != nil {
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
}
h.cfg.CodexKey[targetIndex] = entry
h.cfg.SanitizeCodexKeys()
h.persist(c)
} }
func (h *Handler) DeleteCodexKey(c *gin.Context) { func (h *Handler) DeleteCodexKey(c *gin.Context) {
if val := c.Query("api-key"); val != "" { if val := c.Query("api-key"); val != "" {
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey)) out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))

View File

@@ -0,0 +1,100 @@
package management
import (
"errors"
"net/http"
"net/url"
"strings"
"github.com/gin-gonic/gin"
)
type oauthCallbackRequest struct {
Provider string `json:"provider"`
RedirectURL string `json:"redirect_url"`
Code string `json:"code"`
State string `json:"state"`
Error string `json:"error"`
}
func (h *Handler) PostOAuthCallback(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"})
return
}
var req oauthCallbackRequest
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
return
}
canonicalProvider, err := NormalizeOAuthProvider(req.Provider)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"})
return
}
state := strings.TrimSpace(req.State)
code := strings.TrimSpace(req.Code)
errMsg := strings.TrimSpace(req.Error)
if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" {
u, errParse := url.Parse(rawRedirect)
if errParse != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"})
return
}
q := u.Query()
if state == "" {
state = strings.TrimSpace(q.Get("state"))
}
if code == "" {
code = strings.TrimSpace(q.Get("code"))
}
if errMsg == "" {
errMsg = strings.TrimSpace(q.Get("error"))
if errMsg == "" {
errMsg = strings.TrimSpace(q.Get("error_description"))
}
}
}
if state == "" {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
return
}
if err := ValidateOAuthState(state); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
return
}
if code == "" && errMsg == "" {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"})
return
}
sessionProvider, sessionStatus, ok := GetOAuthSession(state)
if !ok {
c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"})
return
}
if sessionStatus != "" {
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
return
}
if !strings.EqualFold(sessionProvider, canonicalProvider) {
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"})
return
}
if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil {
if errors.Is(errWrite, errOAuthSessionNotPending) {
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
return
}
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"})
return
}
c.JSON(http.StatusOK, gin.H{"status": "ok"})
}

View File

@@ -0,0 +1,283 @@
package management
import (
"encoding/json"
"errors"
"fmt"
"os"
"path/filepath"
"strings"
"sync"
"time"
)
const (
oauthSessionTTL = 10 * time.Minute
maxOAuthStateLength = 128
)
var (
errInvalidOAuthState = errors.New("invalid oauth state")
errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")
errOAuthSessionNotPending = errors.New("oauth session is not pending")
)
type oauthSession struct {
Provider string
Status string
CreatedAt time.Time
ExpiresAt time.Time
}
type oauthSessionStore struct {
mu sync.RWMutex
ttl time.Duration
sessions map[string]oauthSession
}
func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
if ttl <= 0 {
ttl = oauthSessionTTL
}
return &oauthSessionStore{
ttl: ttl,
sessions: make(map[string]oauthSession),
}
}
func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
for state, session := range s.sessions {
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
delete(s.sessions, state)
}
}
}
func (s *oauthSessionStore) Register(state, provider string) {
state = strings.TrimSpace(state)
provider = strings.ToLower(strings.TrimSpace(provider))
if state == "" || provider == "" {
return
}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.purgeExpiredLocked(now)
s.sessions[state] = oauthSession{
Provider: provider,
Status: "",
CreatedAt: now,
ExpiresAt: now.Add(s.ttl),
}
}
func (s *oauthSessionStore) SetError(state, message string) {
state = strings.TrimSpace(state)
message = strings.TrimSpace(message)
if state == "" {
return
}
if message == "" {
message = "Authentication failed"
}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.purgeExpiredLocked(now)
session, ok := s.sessions[state]
if !ok {
return
}
session.Status = message
session.ExpiresAt = now.Add(s.ttl)
s.sessions[state] = session
}
func (s *oauthSessionStore) Complete(state string) {
state = strings.TrimSpace(state)
if state == "" {
return
}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.purgeExpiredLocked(now)
delete(s.sessions, state)
}
func (s *oauthSessionStore) CompleteProvider(provider string) int {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider == "" {
return 0
}
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.purgeExpiredLocked(now)
removed := 0
for state, session := range s.sessions {
if strings.EqualFold(session.Provider, provider) {
delete(s.sessions, state)
removed++
}
}
return removed
}
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
state = strings.TrimSpace(state)
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.purgeExpiredLocked(now)
session, ok := s.sessions[state]
return session, ok
}
func (s *oauthSessionStore) IsPending(state, provider string) bool {
state = strings.TrimSpace(state)
provider = strings.ToLower(strings.TrimSpace(provider))
now := time.Now()
s.mu.Lock()
defer s.mu.Unlock()
s.purgeExpiredLocked(now)
session, ok := s.sessions[state]
if !ok {
return false
}
if session.Status != "" {
return false
}
if provider == "" {
return true
}
return strings.EqualFold(session.Provider, provider)
}
var oauthSessions = newOAuthSessionStore(oauthSessionTTL)
func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) }
func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) }
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
func CompleteOAuthSessionsByProvider(provider string) int {
return oauthSessions.CompleteProvider(provider)
}
func GetOAuthSession(state string) (provider string, status string, ok bool) {
session, ok := oauthSessions.Get(state)
if !ok {
return "", "", false
}
return session.Provider, session.Status, true
}
func IsOAuthSessionPending(state, provider string) bool {
return oauthSessions.IsPending(state, provider)
}
func ValidateOAuthState(state string) error {
trimmed := strings.TrimSpace(state)
if trimmed == "" {
return fmt.Errorf("%w: empty", errInvalidOAuthState)
}
if len(trimmed) > maxOAuthStateLength {
return fmt.Errorf("%w: too long", errInvalidOAuthState)
}
if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") {
return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
}
if strings.Contains(trimmed, "..") {
return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
}
for _, r := range trimmed {
switch {
case r >= 'a' && r <= 'z':
case r >= 'A' && r <= 'Z':
case r >= '0' && r <= '9':
case r == '-' || r == '_' || r == '.':
default:
return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
}
}
return nil
}
func NormalizeOAuthProvider(provider string) (string, error) {
switch strings.ToLower(strings.TrimSpace(provider)) {
case "anthropic", "claude":
return "anthropic", nil
case "codex", "openai":
return "codex", nil
case "gemini", "google":
return "gemini", nil
case "iflow", "i-flow":
return "iflow", nil
case "antigravity", "anti-gravity":
return "antigravity", nil
case "qwen":
return "qwen", nil
default:
return "", errUnsupportedOAuthFlow
}
}
type oauthCallbackFilePayload struct {
Code string `json:"code"`
State string `json:"state"`
Error string `json:"error"`
}
func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
if strings.TrimSpace(authDir) == "" {
return "", fmt.Errorf("auth dir is empty")
}
canonicalProvider, err := NormalizeOAuthProvider(provider)
if err != nil {
return "", err
}
if err := ValidateOAuthState(state); err != nil {
return "", err
}
fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state)
filePath := filepath.Join(authDir, fileName)
payload := oauthCallbackFilePayload{
Code: strings.TrimSpace(code),
State: strings.TrimSpace(state),
Error: strings.TrimSpace(errorMessage),
}
data, err := json.Marshal(payload)
if err != nil {
return "", fmt.Errorf("marshal oauth callback payload: %w", err)
}
if err := os.WriteFile(filePath, data, 0o600); err != nil {
return "", fmt.Errorf("write oauth callback file: %w", err)
}
return filePath, nil
}
func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
canonicalProvider, err := NormalizeOAuthProvider(provider)
if err != nil {
return "", err
}
if !IsOAuthSessionPending(state, canonicalProvider) {
return "", errOAuthSessionNotPending
}
return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage)
}

View File

@@ -134,7 +134,43 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
} }
// Normalize model (handles dynamic thinking suffixes) // Normalize model (handles dynamic thinking suffixes)
normalizedModel, _ := util.NormalizeThinkingModel(modelName) normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName)
thinkingSuffix := ""
if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) {
thinkingSuffix = modelName[len(normalizedModel):]
}
resolveMappedModel := func() (string, []string) {
if fh.modelMapper == nil {
return "", nil
}
mappedModel := fh.modelMapper.MapModel(modelName)
if mappedModel == "" {
mappedModel = fh.modelMapper.MapModel(normalizedModel)
}
mappedModel = strings.TrimSpace(mappedModel)
if mappedModel == "" {
return "", nil
}
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
// already specifies its own thinking suffix.
if thinkingSuffix != "" {
_, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel)
if mappedThinkingMetadata == nil {
mappedModel += thinkingSuffix
}
}
mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel)
mappedProviders := util.GetProviderName(mappedBaseModel)
if len(mappedProviders) == 0 {
return "", nil
}
return mappedModel, mappedProviders
}
// Track resolved model for logging (may change if mapping is applied) // Track resolved model for logging (may change if mapping is applied)
resolvedModel := normalizedModel resolvedModel := normalizedModel
@@ -147,21 +183,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
if forceMappings { if forceMappings {
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys) // FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
// This allows users to route Amp requests to their preferred OAuth providers // This allows users to route Amp requests to their preferred OAuth providers
if fh.modelMapper != nil { if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { // Mapping found and provider available - rewrite the model in request body
// Mapping found - check if we have a provider for the mapped model bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
mappedProviders := util.GetProviderName(mappedModel) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
if len(mappedProviders) > 0 { // Store mapped model in context for handlers that check it (like gemini bridge)
// Mapping found and provider available - rewrite the model in request body c.Set(MappedModelContextKey, mappedModel)
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) resolvedModel = mappedModel
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) usedMapping = true
// Store mapped model in context for handlers that check it (like gemini bridge) providers = mappedProviders
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true
providers = mappedProviders
}
}
} }
// If no mapping applied, check for local providers // If no mapping applied, check for local providers
@@ -174,21 +204,15 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
if len(providers) == 0 { if len(providers) == 0 {
// No providers configured - check if we have a model mapping // No providers configured - check if we have a model mapping
if fh.modelMapper != nil { if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" { // Mapping found and provider available - rewrite the model in request body
// Mapping found - check if we have a provider for the mapped model bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
mappedProviders := util.GetProviderName(mappedModel) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
if len(mappedProviders) > 0 { // Store mapped model in context for handlers that check it (like gemini bridge)
// Mapping found and provider available - rewrite the model in request body c.Set(MappedModelContextKey, mappedModel)
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel) resolvedModel = mappedModel
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) usedMapping = true
// Store mapped model in context for handlers that check it (like gemini bridge) providers = mappedProviders
c.Set(MappedModelContextKey, mappedModel)
resolvedModel = mappedModel
usedMapping = true
providers = mappedProviders
}
}
} }
} }
} }
@@ -222,14 +246,14 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
// Log: Model was mapped to another model // Log: Model was mapped to another model
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel) log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath) logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
rewriter := NewResponseRewriter(c.Writer, normalizedModel) rewriter := NewResponseRewriter(c.Writer, modelName)
c.Writer = rewriter c.Writer = rewriter
// Filter Anthropic-Beta header only for local handling paths // Filter Anthropic-Beta header only for local handling paths
filterAntropicBetaHeader(c) filterAntropicBetaHeader(c)
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes)) c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
handler(c) handler(c)
rewriter.Flush() rewriter.Flush()
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel) log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
} else if len(providers) > 0 { } else if len(providers) > 0 {
// Log: Using local provider (free) // Log: Using local provider (free)
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath) logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)

View File

@@ -0,0 +1,73 @@
package amp
import (
"bytes"
"encoding/json"
"net/http"
"net/http/httptest"
"net/http/httputil"
"testing"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
)
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
gin.SetMode(gin.TestMode)
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
})
defer reg.UnregisterClient("test-client-amp-fallback")
mapper := NewModelMapper([]config.AmpModelMapping{
{From: "gpt-5.2", To: "test/gpt-5.2"},
})
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
handler := func(c *gin.Context) {
var req struct {
Model string `json:"model"`
}
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"model": req.Model,
"seen_model": req.Model,
})
}
r := gin.New()
r.POST("/chat/completions", fallback.WrapHandler(handler))
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
req.Header.Set("Content-Type", "application/json")
w := httptest.NewRecorder()
r.ServeHTTP(w, req)
if w.Code != http.StatusOK {
t.Fatalf("Expected status 200, got %d", w.Code)
}
var resp struct {
Model string `json:"model"`
SeenModel string `json:"seen_model"`
}
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
t.Fatalf("Failed to parse response JSON: %v", err)
}
if resp.Model != "gpt-5.2(xhigh)" {
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
}
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
}
}

View File

@@ -59,7 +59,8 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
} }
// Verify target model has available providers // Verify target model has available providers
providers := util.GetProviderName(targetModel) normalizedTarget, _ := util.NormalizeThinkingModel(targetModel)
providers := util.GetProviderName(normalizedTarget)
if len(providers) == 0 { if len(providers) == 0 {
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel) log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
return "" return ""

View File

@@ -71,6 +71,25 @@ func TestModelMapper_MapModel_WithProvider(t *testing.T) {
} }
} }
func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) {
reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{
{ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"},
})
defer reg.UnregisterClient("test-client-thinking")
mappings := []config.AmpModelMapping{
{From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"},
}
mapper := NewModelMapper(mappings)
result := mapper.MapModel("gpt-5.2-alias")
if result != "gpt-5.2(xhigh)" {
t.Errorf("Expected gpt-5.2(xhigh), got %s", result)
}
}
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) { func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
reg := registry.GetGlobalRegistry() reg := registry.GetGlobalRegistry()
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{ reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{

View File

@@ -95,6 +95,20 @@ func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
} }
} }
// wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere.
func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc {
return func(c *gin.Context) {
path := c.Request.URL.Path
for _, prefix := range prefixes {
if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') {
c.Next()
return
}
}
auth(c)
}
}
// registerManagementRoutes registers Amp management proxy routes // registerManagementRoutes registers Amp management proxy routes
// These routes proxy through to the Amp control plane for OAuth, user management, etc. // These routes proxy through to the Amp control plane for OAuth, user management, etc.
// Uses dynamic middleware and proxy getter for hot-reload support. // Uses dynamic middleware and proxy getter for hot-reload support.
@@ -109,8 +123,10 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
ampAPI.Use(m.localhostOnlyMiddleware()) ampAPI.Use(m.localhostOnlyMiddleware())
// Apply authentication middleware - requires valid API key in Authorization header // Apply authentication middleware - requires valid API key in Authorization header
var authWithBypass gin.HandlerFunc
if auth != nil { if auth != nil {
ampAPI.Use(auth) ampAPI.Use(auth)
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
} }
// Dynamic proxy handler that uses m.getProxy() for hot-reload support // Dynamic proxy handler that uses m.getProxy() for hot-reload support
@@ -156,10 +172,16 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
// Root-level routes that AMP CLI expects without /api prefix // Root-level routes that AMP CLI expects without /api prefix
// These need the same security middleware as the /api/* routes (dynamic for hot-reload) // These need the same security middleware as the /api/* routes (dynamic for hot-reload)
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()} rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
if auth != nil { if authWithBypass != nil {
rootMiddleware = append(rootMiddleware, auth) rootMiddleware = append(rootMiddleware, authWithBypass)
} }
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...) engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...)
engine.GET("/settings", append(rootMiddleware, proxyHandler)...)
engine.GET("/settings/*path", append(rootMiddleware, proxyHandler)...)
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...) engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...) engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)

View File

@@ -354,10 +354,11 @@ func (s *Server) setupRoutes() {
code := c.Query("code") code := c.Query("code")
state := c.Query("state") state := c.Query("state")
errStr := c.Query("error") errStr := c.Query("error")
// Persist to a temporary file keyed by state if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" { if state != "" {
file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state) _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
} }
c.Header("Content-Type", "text/html; charset=utf-8") c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML) c.String(http.StatusOK, oauthCallbackSuccessHTML)
@@ -367,9 +368,11 @@ func (s *Server) setupRoutes() {
code := c.Query("code") code := c.Query("code")
state := c.Query("state") state := c.Query("state")
errStr := c.Query("error") errStr := c.Query("error")
if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" { if state != "" {
file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state) _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
} }
c.Header("Content-Type", "text/html; charset=utf-8") c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML) c.String(http.StatusOK, oauthCallbackSuccessHTML)
@@ -379,9 +382,11 @@ func (s *Server) setupRoutes() {
code := c.Query("code") code := c.Query("code")
state := c.Query("state") state := c.Query("state")
errStr := c.Query("error") errStr := c.Query("error")
if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" { if state != "" {
file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state) _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
} }
c.Header("Content-Type", "text/html; charset=utf-8") c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML) c.String(http.StatusOK, oauthCallbackSuccessHTML)
@@ -391,9 +396,11 @@ func (s *Server) setupRoutes() {
code := c.Query("code") code := c.Query("code")
state := c.Query("state") state := c.Query("state")
errStr := c.Query("error") errStr := c.Query("error")
if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" { if state != "" {
file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state) _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
} }
c.Header("Content-Type", "text/html; charset=utf-8") c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML) c.String(http.StatusOK, oauthCallbackSuccessHTML)
@@ -403,9 +410,11 @@ func (s *Server) setupRoutes() {
code := c.Query("code") code := c.Query("code")
state := c.Query("state") state := c.Query("state")
errStr := c.Query("error") errStr := c.Query("error")
if errStr == "" {
errStr = c.Query("error_description")
}
if state != "" { if state != "" {
file := fmt.Sprintf("%s/.oauth-antigravity-%s.oauth", s.cfg.AuthDir, state) _, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr)
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
} }
c.Header("Content-Type", "text/html; charset=utf-8") c.Header("Content-Type", "text/html; charset=utf-8")
c.String(http.StatusOK, oauthCallbackSuccessHTML) c.String(http.StatusOK, oauthCallbackSuccessHTML)
@@ -577,6 +586,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken) mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken) mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken) mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus) mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
} }
} }
@@ -834,11 +844,20 @@ func (s *Server) UpdateClients(cfg *config.Config) {
} }
} }
if oldCfg != nil && oldCfg.LoggingToFile != cfg.LoggingToFile { if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil { if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
log.Errorf("failed to reconfigure log output: %v", err) log.Errorf("failed to reconfigure log output: %v", err)
} else { } else {
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile) if oldCfg == nil {
log.Debug("log output configuration refreshed")
} else {
if oldCfg.LoggingToFile != cfg.LoggingToFile {
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
}
if oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
log.Debugf("logs_max_total_size_mb updated from %d to %d", oldCfg.LogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB)
}
}
} }
} }

View File

@@ -18,6 +18,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex" "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser" "github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config" "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
@@ -46,6 +47,12 @@ var (
type GeminiAuth struct { type GeminiAuth struct {
} }
// WebLoginOptions customizes the interactive OAuth flow.
type WebLoginOptions struct {
NoBrowser bool
Prompt func(string) (string, error)
}
// NewGeminiAuth creates a new instance of GeminiAuth. // NewGeminiAuth creates a new instance of GeminiAuth.
func NewGeminiAuth() *GeminiAuth { func NewGeminiAuth() *GeminiAuth {
return &GeminiAuth{} return &GeminiAuth{}
@@ -59,12 +66,12 @@ func NewGeminiAuth() *GeminiAuth {
// - ctx: The context for the HTTP client // - ctx: The context for the HTTP client
// - ts: The Gemini token storage containing authentication tokens // - ts: The Gemini token storage containing authentication tokens
// - cfg: The configuration containing proxy settings // - cfg: The configuration containing proxy settings
// - noBrowser: Optional parameter to disable browser opening // - opts: Optional parameters to customize browser and prompt behavior
// //
// Returns: // Returns:
// - *http.Client: An HTTP client configured with authentication // - *http.Client: An HTTP client configured with authentication
// - error: An error if the client configuration fails, nil otherwise // - error: An error if the client configuration fails, nil otherwise
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) { func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
// Configure proxy settings for the HTTP client if a proxy URL is provided. // Configure proxy settings for the HTTP client if a proxy URL is provided.
proxyURL, err := url.Parse(cfg.ProxyURL) proxyURL, err := url.Parse(cfg.ProxyURL)
if err == nil { if err == nil {
@@ -109,7 +116,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
// If no token is found in storage, initiate the web-based OAuth flow. // If no token is found in storage, initiate the web-based OAuth flow.
if ts.Token == nil { if ts.Token == nil {
fmt.Printf("Could not load token from file, starting OAuth flow.\n") fmt.Printf("Could not load token from file, starting OAuth flow.\n")
token, err = g.getTokenFromWeb(ctx, conf, noBrowser...) token, err = g.getTokenFromWeb(ctx, conf, opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get token from web: %w", err) return nil, fmt.Errorf("failed to get token from web: %w", err)
} }
@@ -205,15 +212,15 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
// Parameters: // Parameters:
// - ctx: The context for the HTTP client // - ctx: The context for the HTTP client
// - config: The OAuth2 configuration // - config: The OAuth2 configuration
// - noBrowser: Optional parameter to disable browser opening // - opts: Optional parameters to customize browser and prompt behavior
// //
// Returns: // Returns:
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow // - *oauth2.Token: The OAuth2 token obtained from the authorization flow
// - error: An error if the token acquisition fails, nil otherwise // - error: An error if the token acquisition fails, nil otherwise
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) { func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
// Use a channel to pass the authorization code from the HTTP handler to the main function. // Use a channel to pass the authorization code from the HTTP handler to the main function.
codeChan := make(chan string) codeChan := make(chan string, 1)
errChan := make(chan error) errChan := make(chan error, 1)
// Create a new HTTP server with its own multiplexer. // Create a new HTTP server with its own multiplexer.
mux := http.NewServeMux() mux := http.NewServeMux()
@@ -223,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) { mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
if err := r.URL.Query().Get("error"); err != "" { if err := r.URL.Query().Get("error"); err != "" {
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err) _, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
errChan <- fmt.Errorf("authentication failed via callback: %s", err) select {
case errChan <- fmt.Errorf("authentication failed via callback: %s", err):
default:
}
return return
} }
code := r.URL.Query().Get("code") code := r.URL.Query().Get("code")
if code == "" { if code == "" {
_, _ = fmt.Fprint(w, "Authentication failed: code not found.") _, _ = fmt.Fprint(w, "Authentication failed: code not found.")
errChan <- fmt.Errorf("code not found in callback") select {
case errChan <- fmt.Errorf("code not found in callback"):
default:
}
return return
} }
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>") _, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
codeChan <- code select {
case codeChan <- code:
default:
}
}) })
// Start the server in a goroutine. // Start the server in a goroutine.
@@ -250,7 +266,12 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Open the authorization URL in the user's browser. // Open the authorization URL in the user's browser.
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent")) authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
if len(noBrowser) == 1 && !noBrowser[0] { noBrowser := false
if opts != nil {
noBrowser = opts.NoBrowser
}
if !noBrowser {
fmt.Println("Opening browser for authentication...") fmt.Println("Opening browser for authentication...")
// Check if browser is available // Check if browser is available
@@ -281,13 +302,60 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
// Wait for the authorization code or an error. // Wait for the authorization code or an error.
var authCode string var authCode string
select { timeoutTimer := time.NewTimer(5 * time.Minute)
case code := <-codeChan: defer timeoutTimer.Stop()
authCode = code
case err := <-errChan: var manualPromptTimer *time.Timer
return nil, err var manualPromptC <-chan time.Time
case <-time.After(5 * time.Minute): // Timeout if opts != nil && opts.Prompt != nil {
return nil, fmt.Errorf("oauth flow timed out") manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case code := <-codeChan:
authCode = code
break waitForCallback
case err := <-errChan:
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case code := <-codeChan:
authCode = code
break waitForCallback
case err := <-errChan:
return nil, err
default:
}
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
if err != nil {
return nil, err
}
parsed, err := misc.ParseOAuthCallback(input)
if err != nil {
return nil, err
}
if parsed == nil {
continue
}
if parsed.Error != "" {
return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error)
}
if parsed.Code == "" {
return nil, fmt.Errorf("code not found in callback")
}
authCode = parsed.Code
break waitForCallback
case <-timeoutTimer.C:
return nil, fmt.Errorf("oauth flow timed out")
}
} }
// Shutdown the server. // Shutdown the server.

164
internal/cache/signature_cache.go vendored Normal file
View File

@@ -0,0 +1,164 @@
package cache
import (
"crypto/sha256"
"encoding/hex"
"sort"
"sync"
"time"
)
// SignatureEntry holds a cached thinking signature with timestamp
type SignatureEntry struct {
Signature string
Timestamp time.Time
}
const (
// SignatureCacheTTL is how long signatures are valid
SignatureCacheTTL = 1 * time.Hour
// MaxEntriesPerSession limits memory usage per session
MaxEntriesPerSession = 100
// SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space)
SignatureTextHashLen = 16
// MinValidSignatureLen is the minimum length for a signature to be considered valid
MinValidSignatureLen = 50
)
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
var signatureCache sync.Map
// sessionCache is the inner map type
type sessionCache struct {
mu sync.RWMutex
entries map[string]SignatureEntry
}
// hashText creates a stable, Unicode-safe key from text content
func hashText(text string) string {
h := sha256.Sum256([]byte(text))
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
}
// getOrCreateSession gets or creates a session cache
func getOrCreateSession(sessionID string) *sessionCache {
if val, ok := signatureCache.Load(sessionID); ok {
return val.(*sessionCache)
}
sc := &sessionCache{entries: make(map[string]SignatureEntry)}
actual, _ := signatureCache.LoadOrStore(sessionID, sc)
return actual.(*sessionCache)
}
// CacheSignature stores a thinking signature for a given session and text.
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
func CacheSignature(sessionID, text, signature string) {
if sessionID == "" || text == "" || signature == "" {
return
}
if len(signature) < MinValidSignatureLen {
return
}
sc := getOrCreateSession(sessionID)
textHash := hashText(text)
sc.mu.Lock()
defer sc.mu.Unlock()
// Evict expired entries if at capacity
if len(sc.entries) >= MaxEntriesPerSession {
now := time.Now()
for key, entry := range sc.entries {
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
delete(sc.entries, key)
}
}
// If still at capacity, remove oldest entries
if len(sc.entries) >= MaxEntriesPerSession {
// Find and remove oldest quarter
oldest := make([]struct {
key string
ts time.Time
}, 0, len(sc.entries))
for key, entry := range sc.entries {
oldest = append(oldest, struct {
key string
ts time.Time
}{key, entry.Timestamp})
}
// Sort by timestamp (oldest first) using sort.Slice
sort.Slice(oldest, func(i, j int) bool {
return oldest[i].ts.Before(oldest[j].ts)
})
toRemove := len(oldest) / 4
if toRemove < 1 {
toRemove = 1
}
for i := 0; i < toRemove; i++ {
delete(sc.entries, oldest[i].key)
}
}
}
sc.entries[textHash] = SignatureEntry{
Signature: signature,
Timestamp: time.Now(),
}
}
// GetCachedSignature retrieves a cached signature for a given session and text.
// Returns empty string if not found or expired.
func GetCachedSignature(sessionID, text string) string {
if sessionID == "" || text == "" {
return ""
}
val, ok := signatureCache.Load(sessionID)
if !ok {
return ""
}
sc := val.(*sessionCache)
textHash := hashText(text)
sc.mu.RLock()
entry, exists := sc.entries[textHash]
sc.mu.RUnlock()
if !exists {
return ""
}
// Check if expired
if time.Since(entry.Timestamp) > SignatureCacheTTL {
sc.mu.Lock()
delete(sc.entries, textHash)
sc.mu.Unlock()
return ""
}
return entry.Signature
}
// ClearSignatureCache clears signature cache for a specific session or all sessions.
func ClearSignatureCache(sessionID string) {
if sessionID != "" {
signatureCache.Delete(sessionID)
} else {
signatureCache.Range(func(key, _ any) bool {
signatureCache.Delete(key)
return true
})
}
}
// HasValidSignature checks if a signature is valid (non-empty and long enough)
func HasValidSignature(signature string) bool {
return signature != "" && len(signature) >= MinValidSignatureLen
}

216
internal/cache/signature_cache_test.go vendored Normal file
View File

@@ -0,0 +1,216 @@
package cache
import (
"testing"
"time"
)
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
ClearSignatureCache("")
sessionID := "test-session-1"
text := "This is some thinking text content"
signature := "abc123validSignature1234567890123456789012345678901234567890"
// Store signature
CacheSignature(sessionID, text, signature)
// Retrieve signature
retrieved := GetCachedSignature(sessionID, text)
if retrieved != signature {
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
}
}
func TestCacheSignature_DifferentSessions(t *testing.T) {
ClearSignatureCache("")
text := "Same text in different sessions"
sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature("session-a", text, sig1)
CacheSignature("session-b", text, sig2)
if GetCachedSignature("session-a", text) != sig1 {
t.Error("Session-a signature mismatch")
}
if GetCachedSignature("session-b", text) != sig2 {
t.Error("Session-b signature mismatch")
}
}
func TestCacheSignature_NotFound(t *testing.T) {
ClearSignatureCache("")
// Non-existent session
if got := GetCachedSignature("nonexistent", "some text"); got != "" {
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
}
// Existing session but different text
CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890")
if got := GetCachedSignature("session-x", "text-b"); got != "" {
t.Errorf("Expected empty string for different text, got '%s'", got)
}
}
func TestCacheSignature_EmptyInputs(t *testing.T) {
ClearSignatureCache("")
// All empty/invalid inputs should be no-ops
CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890")
CacheSignature("session", "text", "")
CacheSignature("session", "text", "short") // Too short
if got := GetCachedSignature("session", "text"); got != "" {
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
}
}
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
ClearSignatureCache("")
sessionID := "test-short-sig"
text := "Some text"
shortSig := "abc123" // Less than 50 chars
CacheSignature(sessionID, text, shortSig)
if got := GetCachedSignature(sessionID, text); got != "" {
t.Errorf("Short signature should be rejected, got '%s'", got)
}
}
func TestClearSignatureCache_SpecificSession(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("session-1", "text", sig)
CacheSignature("session-2", "text", sig)
ClearSignatureCache("session-1")
if got := GetCachedSignature("session-1", "text"); got != "" {
t.Error("session-1 should be cleared")
}
if got := GetCachedSignature("session-2", "text"); got != sig {
t.Error("session-2 should still exist")
}
}
func TestClearSignatureCache_AllSessions(t *testing.T) {
ClearSignatureCache("")
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature("session-1", "text", sig)
CacheSignature("session-2", "text", sig)
ClearSignatureCache("")
if got := GetCachedSignature("session-1", "text"); got != "" {
t.Error("session-1 should be cleared")
}
if got := GetCachedSignature("session-2", "text"); got != "" {
t.Error("session-2 should be cleared")
}
}
func TestHasValidSignature(t *testing.T) {
tests := []struct {
name string
signature string
expected bool
}{
{"valid long signature", "abc123validSignature1234567890123456789012345678901234567890", true},
{"exactly 50 chars", "12345678901234567890123456789012345678901234567890", true},
{"49 chars - invalid", "1234567890123456789012345678901234567890123456789", false},
{"empty string", "", false},
{"short signature", "abc", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := HasValidSignature(tt.signature)
if result != tt.expected {
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
}
})
}
}
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
ClearSignatureCache("")
sessionID := "hash-test-session"
// Different texts should produce different hashes
text1 := "First thinking text"
text2 := "Second thinking text"
sig1 := "signature1_1234567890123456789012345678901234567890123456"
sig2 := "signature2_1234567890123456789012345678901234567890123456"
CacheSignature(sessionID, text1, sig1)
CacheSignature(sessionID, text2, sig2)
if GetCachedSignature(sessionID, text1) != sig1 {
t.Error("text1 signature mismatch")
}
if GetCachedSignature(sessionID, text2) != sig2 {
t.Error("text2 signature mismatch")
}
}
func TestCacheSignature_UnicodeText(t *testing.T) {
ClearSignatureCache("")
sessionID := "unicode-session"
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
sig := "unicodeSig123456789012345678901234567890123456789012345"
CacheSignature(sessionID, text, sig)
if got := GetCachedSignature(sessionID, text); got != sig {
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
}
}
func TestCacheSignature_Overwrite(t *testing.T) {
ClearSignatureCache("")
sessionID := "overwrite-session"
text := "Same text"
sig1 := "firstSignature12345678901234567890123456789012345678901"
sig2 := "secondSignature1234567890123456789012345678901234567890"
CacheSignature(sessionID, text, sig1)
CacheSignature(sessionID, text, sig2) // Overwrite
if got := GetCachedSignature(sessionID, text); got != sig2 {
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
}
}
// Note: TTL expiration test is tricky to test without mocking time
// We test the logic path exists but actual expiration would require time manipulation
func TestCacheSignature_ExpirationLogic(t *testing.T) {
ClearSignatureCache("")
// This test verifies the expiration check exists
// In a real scenario, we'd mock time.Now()
sessionID := "expiration-test"
text := "text"
sig := "validSig1234567890123456789012345678901234567890123456"
CacheSignature(sessionID, text, sig)
// Fresh entry should be retrievable
if got := GetCachedSignature(sessionID, text); got != sig {
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
}
// We can't easily test actual expiration without time mocking
// but the logic is verified by the implementation
_ = time.Now() // Acknowledge we're not testing time passage
}

View File

@@ -24,12 +24,17 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
options = &LoginOptions{} options = &LoginOptions{}
} }
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: promptFn,
} }
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts) _, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)

View File

@@ -15,11 +15,16 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
options = &LoginOptions{} options = &LoginOptions{}
} }
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: promptFn,
} }
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts) record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)

View File

@@ -20,13 +20,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
promptFn := options.Prompt promptFn := options.Prompt
if promptFn == nil { if promptFn == nil {
promptFn = func(prompt string) (string, error) { promptFn = defaultProjectPrompt()
fmt.Println()
fmt.Println(prompt)
var value string
_, err := fmt.Scanln(&value)
return value, err
}
} }
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{

View File

@@ -55,11 +55,22 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
ctx := context.Background() ctx := context.Background()
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
trimmedProjectID := strings.TrimSpace(projectID)
callbackPrompt := promptFn
if trimmedProjectID == "" {
callbackPrompt = nil
}
loginOpts := &sdkAuth.LoginOptions{ loginOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
ProjectID: strings.TrimSpace(projectID), ProjectID: trimmedProjectID,
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: callbackPrompt,
} }
authenticator := sdkAuth.NewGeminiAuthenticator() authenticator := sdkAuth.NewGeminiAuthenticator()
@@ -76,7 +87,10 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
} }
geminiAuth := gemini.NewGeminiAuth() geminiAuth := gemini.NewGeminiAuth()
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser) httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
NoBrowser: options.NoBrowser,
Prompt: callbackPrompt,
})
if errClient != nil { if errClient != nil {
log.Errorf("Gemini authentication failed: %v", errClient) log.Errorf("Gemini authentication failed: %v", errClient)
return return
@@ -90,12 +104,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
return return
} }
promptFn := options.Prompt selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn)
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects) projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
if errSelection != nil { if errSelection != nil {
log.Errorf("Invalid project selection: %v", errSelection) log.Errorf("Invalid project selection: %v", errSelection)

View File

@@ -35,12 +35,17 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
options = &LoginOptions{} options = &LoginOptions{}
} }
promptFn := options.Prompt
if promptFn == nil {
promptFn = defaultProjectPrompt()
}
manager := newAuthManager() manager := newAuthManager()
authOpts := &sdkAuth.LoginOptions{ authOpts := &sdkAuth.LoginOptions{
NoBrowser: options.NoBrowser, NoBrowser: options.NoBrowser,
Metadata: map[string]string{}, Metadata: map[string]string{},
Prompt: options.Prompt, Prompt: promptFn,
} }
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts) _, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)

View File

@@ -12,7 +12,6 @@ import (
"strings" "strings"
"syscall" "syscall"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"golang.org/x/crypto/bcrypt" "golang.org/x/crypto/bcrypt"
"gopkg.in/yaml.v3" "gopkg.in/yaml.v3"
) )
@@ -21,7 +20,7 @@ const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy
// Config represents the application's configuration, loaded from a YAML file. // Config represents the application's configuration, loaded from a YAML file.
type Config struct { type Config struct {
config.SDKConfig `yaml:",inline"` SDKConfig `yaml:",inline"`
// Host is the network host/interface on which the API server will bind. // Host is the network host/interface on which the API server will bind.
// Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access. // Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access.
Host string `yaml:"host" json:"-"` Host string `yaml:"host" json:"-"`
@@ -43,6 +42,10 @@ type Config struct {
// LoggingToFile controls whether application logs are written to rotating files or stdout. // LoggingToFile controls whether application logs are written to rotating files or stdout.
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"` LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
// LogsMaxTotalSizeMB limits the total size (in MB) of log files under the logs directory.
// When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable.
LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"`
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded. // UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"` UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
@@ -342,6 +345,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
// Set defaults before unmarshal so that absent keys keep defaults. // Set defaults before unmarshal so that absent keys keep defaults.
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6) cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
cfg.LoggingToFile = false cfg.LoggingToFile = false
cfg.LogsMaxTotalSizeMB = 0
cfg.UsageStatisticsEnabled = false cfg.UsageStatisticsEnabled = false
cfg.DisableCooling = false cfg.DisableCooling = false
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
@@ -386,6 +390,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
} }
if cfg.LogsMaxTotalSizeMB < 0 {
cfg.LogsMaxTotalSizeMB = 0
}
// Sync request authentication providers with inline API keys for backwards compatibility. // Sync request authentication providers with inline API keys for backwards compatibility.
syncInlineAccessProvider(&cfg) syncInlineAccessProvider(&cfg)
@@ -692,7 +700,7 @@ func sanitizeConfigForPersist(cfg *Config) *Config {
} }
clone := *cfg clone := *cfg
clone.SDKConfig = cfg.SDKConfig clone.SDKConfig = cfg.SDKConfig
clone.SDKConfig.Access = config.AccessConfig{} clone.SDKConfig.Access = AccessConfig{}
return &clone return &clone
} }

View File

@@ -0,0 +1,87 @@
// Package config provides configuration management for the CLI Proxy API server.
// It handles loading and parsing YAML configuration files, and provides structured
// access to application settings including server port, authentication directory,
// debug settings, proxy configuration, and API keys.
package config
// SDKConfig represents the application's configuration, loaded from a YAML file.
type SDKConfig struct {
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
// credentials as well.
ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`
// RequestLog enables or disables detailed request logging functionality.
RequestLog bool `yaml:"request-log" json:"request-log"`
// APIKeys is a list of keys for authenticating clients to this proxy server.
APIKeys []string `yaml:"api-keys" json:"api-keys"`
// Access holds request authentication provider configuration.
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
}
// AccessConfig groups request authentication providers.
type AccessConfig struct {
// Providers lists configured authentication providers.
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
}
// AccessProvider describes a request authentication provider entry.
type AccessProvider struct {
// Name is the instance identifier for the provider.
Name string `yaml:"name" json:"name"`
// Type selects the provider implementation registered via the SDK.
Type string `yaml:"type" json:"type"`
// SDK optionally names a third-party SDK module providing this provider.
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
// APIKeys lists inline keys for providers that require them.
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
// Config passes provider-specific options to the implementation.
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
}
const (
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
AccessProviderTypeConfigAPIKey = "config-api-key"
// DefaultAccessProviderName is applied when no provider name is supplied.
DefaultAccessProviderName = "config-inline"
)
// ConfigAPIKeyProvider returns the first inline API key provider if present.
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
if c == nil {
return nil
}
for i := range c.Access.Providers {
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
if c.Access.Providers[i].Name == "" {
c.Access.Providers[i].Name = DefaultAccessProviderName
}
return &c.Access.Providers[i]
}
}
return nil
}
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
// It returns nil when no keys are supplied.
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
if len(keys) == 0 {
return nil
}
provider := &AccessProvider{
Name: DefaultAccessProviderName,
Type: AccessProviderTypeConfigAPIKey,
APIKeys: append([]string(nil), keys...),
}
return provider
}

View File

@@ -72,39 +72,45 @@ func SetupBaseLogger() {
} }
// ConfigureLogOutput switches the global log destination between rotating files and stdout. // ConfigureLogOutput switches the global log destination between rotating files and stdout.
func ConfigureLogOutput(loggingToFile bool) error { // When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
// until the total size is within the limit.
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
SetupBaseLogger() SetupBaseLogger()
writerMu.Lock() writerMu.Lock()
defer writerMu.Unlock() defer writerMu.Unlock()
logDir := "logs"
if base := util.WritablePath(); base != "" {
logDir = filepath.Join(base, "logs")
}
protectedPath := ""
if loggingToFile { if loggingToFile {
logDir := "logs"
if base := util.WritablePath(); base != "" {
logDir = filepath.Join(base, "logs")
}
if err := os.MkdirAll(logDir, 0o755); err != nil { if err := os.MkdirAll(logDir, 0o755); err != nil {
return fmt.Errorf("logging: failed to create log directory: %w", err) return fmt.Errorf("logging: failed to create log directory: %w", err)
} }
if logWriter != nil { if logWriter != nil {
_ = logWriter.Close() _ = logWriter.Close()
} }
protectedPath = filepath.Join(logDir, "main.log")
logWriter = &lumberjack.Logger{ logWriter = &lumberjack.Logger{
Filename: filepath.Join(logDir, "main.log"), Filename: protectedPath,
MaxSize: 10, MaxSize: 10,
MaxBackups: 0, MaxBackups: 0,
MaxAge: 0, MaxAge: 0,
Compress: false, Compress: false,
} }
log.SetOutput(logWriter) log.SetOutput(logWriter)
return nil } else {
if logWriter != nil {
_ = logWriter.Close()
logWriter = nil
}
log.SetOutput(os.Stdout)
} }
if logWriter != nil { configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath)
_ = logWriter.Close()
logWriter = nil
}
log.SetOutput(os.Stdout)
return nil return nil
} }
@@ -112,6 +118,8 @@ func closeLogOutputs() {
writerMu.Lock() writerMu.Lock()
defer writerMu.Unlock() defer writerMu.Unlock()
stopLogDirCleanerLocked()
if logWriter != nil { if logWriter != nil {
_ = logWriter.Close() _ = logWriter.Close()
logWriter = nil logWriter = nil

View File

@@ -0,0 +1,166 @@
package logging
import (
"context"
"os"
"path/filepath"
"sort"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const logDirCleanerInterval = time.Minute
var logDirCleanerCancel context.CancelFunc
func configureLogDirCleanerLocked(logDir string, maxTotalSizeMB int, protectedPath string) {
stopLogDirCleanerLocked()
if maxTotalSizeMB <= 0 {
return
}
maxBytes := int64(maxTotalSizeMB) * 1024 * 1024
if maxBytes <= 0 {
return
}
dir := strings.TrimSpace(logDir)
if dir == "" {
return
}
ctx, cancel := context.WithCancel(context.Background())
logDirCleanerCancel = cancel
go runLogDirCleaner(ctx, filepath.Clean(dir), maxBytes, strings.TrimSpace(protectedPath))
}
func stopLogDirCleanerLocked() {
if logDirCleanerCancel == nil {
return
}
logDirCleanerCancel()
logDirCleanerCancel = nil
}
func runLogDirCleaner(ctx context.Context, logDir string, maxBytes int64, protectedPath string) {
ticker := time.NewTicker(logDirCleanerInterval)
defer ticker.Stop()
cleanOnce := func() {
deleted, errClean := enforceLogDirSizeLimit(logDir, maxBytes, protectedPath)
if errClean != nil {
log.WithError(errClean).Warn("logging: failed to enforce log directory size limit")
return
}
if deleted > 0 {
log.Debugf("logging: removed %d old log file(s) to enforce log directory size limit", deleted)
}
}
cleanOnce()
for {
select {
case <-ctx.Done():
return
case <-ticker.C:
cleanOnce()
}
}
}
func enforceLogDirSizeLimit(logDir string, maxBytes int64, protectedPath string) (int, error) {
if maxBytes <= 0 {
return 0, nil
}
dir := strings.TrimSpace(logDir)
if dir == "" {
return 0, nil
}
dir = filepath.Clean(dir)
entries, errRead := os.ReadDir(dir)
if errRead != nil {
if os.IsNotExist(errRead) {
return 0, nil
}
return 0, errRead
}
protected := strings.TrimSpace(protectedPath)
if protected != "" {
protected = filepath.Clean(protected)
}
type logFile struct {
path string
size int64
modTime time.Time
}
var (
files []logFile
total int64
)
for _, entry := range entries {
if entry.IsDir() {
continue
}
name := entry.Name()
if !isLogFileName(name) {
continue
}
info, errInfo := entry.Info()
if errInfo != nil {
continue
}
if !info.Mode().IsRegular() {
continue
}
path := filepath.Join(dir, name)
files = append(files, logFile{
path: path,
size: info.Size(),
modTime: info.ModTime(),
})
total += info.Size()
}
if total <= maxBytes {
return 0, nil
}
sort.Slice(files, func(i, j int) bool {
return files[i].modTime.Before(files[j].modTime)
})
deleted := 0
for _, file := range files {
if total <= maxBytes {
break
}
if protected != "" && filepath.Clean(file.path) == protected {
continue
}
if errRemove := os.Remove(file.path); errRemove != nil {
log.WithError(errRemove).Warnf("logging: failed to remove old log file: %s", filepath.Base(file.path))
continue
}
total -= file.size
deleted++
}
return deleted, nil
}
func isLogFileName(name string) bool {
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return false
}
lower := strings.ToLower(trimmed)
return strings.HasSuffix(lower, ".log") || strings.HasSuffix(lower, ".log.gz")
}

View File

@@ -0,0 +1,70 @@
package logging
import (
"os"
"path/filepath"
"testing"
"time"
)
func TestEnforceLogDirSizeLimitDeletesOldest(t *testing.T) {
dir := t.TempDir()
writeLogFile(t, filepath.Join(dir, "old.log"), 60, time.Unix(1, 0))
writeLogFile(t, filepath.Join(dir, "mid.log"), 60, time.Unix(2, 0))
protected := filepath.Join(dir, "main.log")
writeLogFile(t, protected, 60, time.Unix(3, 0))
deleted, err := enforceLogDirSizeLimit(dir, 120, protected)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if deleted != 1 {
t.Fatalf("expected 1 deleted file, got %d", deleted)
}
if _, err := os.Stat(filepath.Join(dir, "old.log")); !os.IsNotExist(err) {
t.Fatalf("expected old.log to be removed, stat error: %v", err)
}
if _, err := os.Stat(filepath.Join(dir, "mid.log")); err != nil {
t.Fatalf("expected mid.log to remain, stat error: %v", err)
}
if _, err := os.Stat(protected); err != nil {
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
}
}
func TestEnforceLogDirSizeLimitSkipsProtected(t *testing.T) {
dir := t.TempDir()
protected := filepath.Join(dir, "main.log")
writeLogFile(t, protected, 200, time.Unix(1, 0))
writeLogFile(t, filepath.Join(dir, "other.log"), 50, time.Unix(2, 0))
deleted, err := enforceLogDirSizeLimit(dir, 100, protected)
if err != nil {
t.Fatalf("unexpected error: %v", err)
}
if deleted != 1 {
t.Fatalf("expected 1 deleted file, got %d", deleted)
}
if _, err := os.Stat(protected); err != nil {
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
}
if _, err := os.Stat(filepath.Join(dir, "other.log")); !os.IsNotExist(err) {
t.Fatalf("expected other.log to be removed, stat error: %v", err)
}
}
func writeLogFile(t *testing.T, path string, size int, modTime time.Time) {
t.Helper()
data := make([]byte, size)
if err := os.WriteFile(path, data, 0o644); err != nil {
t.Fatalf("write file: %v", err)
}
if err := os.Chtimes(path, modTime, modTime); err != nil {
t.Fatalf("set times: %v", err)
}
}

View File

@@ -14,6 +14,7 @@ import (
"regexp" "regexp"
"sort" "sort"
"strings" "strings"
"sync/atomic"
"time" "time"
"github.com/andybalholm/brotli" "github.com/andybalholm/brotli"
@@ -25,6 +26,8 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
) )
var requestLogID atomic.Uint64
// RequestLogger defines the interface for logging HTTP requests and responses. // RequestLogger defines the interface for logging HTTP requests and responses.
// It provides methods for logging both regular and streaming HTTP request/response cycles. // It provides methods for logging both regular and streaming HTTP request/response cycles.
type RequestLogger interface { type RequestLogger interface {
@@ -204,19 +207,52 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
} }
filePath := filepath.Join(l.logsDir, filename) filePath := filepath.Join(l.logsDir, filename)
// Decompress response if needed requestBodyPath, errTemp := l.writeRequestBodyTempFile(body)
decompressedResponse, err := l.decompressResponse(responseHeaders, response) if errTemp != nil {
if err != nil { log.WithError(errTemp).Warn("failed to create request body temp file, falling back to direct write")
// If decompression fails, log the error but continue with original response }
decompressedResponse = append(response, []byte(fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", err))...) if requestBodyPath != "" {
defer func() {
if errRemove := os.Remove(requestBodyPath); errRemove != nil {
log.WithError(errRemove).Warn("failed to remove request body temp file")
}
}()
} }
// Create log content responseToWrite, decompressErr := l.decompressResponse(responseHeaders, response)
content := l.formatLogContent(url, method, requestHeaders, body, apiRequest, apiResponse, decompressedResponse, statusCode, responseHeaders, apiResponseErrors) if decompressErr != nil {
// If decompression fails, continue with original response and annotate the log output.
responseToWrite = response
}
// Write to file logFile, errOpen := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
if err = os.WriteFile(filePath, []byte(content), 0644); err != nil { if errOpen != nil {
return fmt.Errorf("failed to write log file: %w", err) return fmt.Errorf("failed to create log file: %w", errOpen)
}
writeErr := l.writeNonStreamingLog(
logFile,
url,
method,
requestHeaders,
body,
requestBodyPath,
apiRequest,
apiResponse,
apiResponseErrors,
statusCode,
responseHeaders,
responseToWrite,
decompressErr,
)
if errClose := logFile.Close(); errClose != nil {
log.WithError(errClose).Warn("failed to close request log file")
if writeErr == nil {
return errClose
}
}
if writeErr != nil {
return fmt.Errorf("failed to write log file: %w", writeErr)
} }
if force && !l.enabled { if force && !l.enabled {
@@ -253,26 +289,38 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
filename := l.generateFilename(url) filename := l.generateFilename(url)
filePath := filepath.Join(l.logsDir, filename) filePath := filepath.Join(l.logsDir, filename)
// Create and open file requestHeaders := make(map[string][]string, len(headers))
file, err := os.Create(filePath) for key, values := range headers {
if err != nil { headerValues := make([]string, len(values))
return nil, fmt.Errorf("failed to create log file: %w", err) copy(headerValues, values)
requestHeaders[key] = headerValues
} }
// Write initial request information requestBodyPath, errTemp := l.writeRequestBodyTempFile(body)
requestInfo := l.formatRequestInfo(url, method, headers, body) if errTemp != nil {
if _, err = file.WriteString(requestInfo); err != nil { return nil, fmt.Errorf("failed to create request body temp file: %w", errTemp)
_ = file.Close()
return nil, fmt.Errorf("failed to write request info: %w", err)
} }
responseBodyFile, errCreate := os.CreateTemp(l.logsDir, "response-body-*.tmp")
if errCreate != nil {
_ = os.Remove(requestBodyPath)
return nil, fmt.Errorf("failed to create response body temp file: %w", errCreate)
}
responseBodyPath := responseBodyFile.Name()
// Create streaming writer // Create streaming writer
writer := &FileStreamingLogWriter{ writer := &FileStreamingLogWriter{
file: file, logFilePath: filePath,
chunkChan: make(chan []byte, 100), // Buffered channel for async writes url: url,
closeChan: make(chan struct{}), method: method,
errorChan: make(chan error, 1), timestamp: time.Now(),
bufferedChunks: &bytes.Buffer{}, requestHeaders: requestHeaders,
requestBodyPath: requestBodyPath,
responseBodyPath: responseBodyPath,
responseBodyFile: responseBodyFile,
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
closeChan: make(chan struct{}),
errorChan: make(chan error, 1),
} }
// Start async writer goroutine // Start async writer goroutine
@@ -323,7 +371,9 @@ func (l *FileRequestLogger) generateFilename(url string) string {
timestamp := time.Now().Format("2006-01-02T150405-.000000000") timestamp := time.Now().Format("2006-01-02T150405-.000000000")
timestamp = strings.Replace(timestamp, ".", "", -1) timestamp = strings.Replace(timestamp, ".", "", -1)
return fmt.Sprintf("%s-%s.log", sanitized, timestamp) id := requestLogID.Add(1)
return fmt.Sprintf("%s-%s-%d.log", sanitized, timestamp, id)
} }
// sanitizeForFilename replaces characters that are not safe for filenames. // sanitizeForFilename replaces characters that are not safe for filenames.
@@ -405,6 +455,220 @@ func (l *FileRequestLogger) cleanupOldErrorLogs() error {
return nil return nil
} }
func (l *FileRequestLogger) writeRequestBodyTempFile(body []byte) (string, error) {
tmpFile, errCreate := os.CreateTemp(l.logsDir, "request-body-*.tmp")
if errCreate != nil {
return "", errCreate
}
tmpPath := tmpFile.Name()
if _, errCopy := io.Copy(tmpFile, bytes.NewReader(body)); errCopy != nil {
_ = tmpFile.Close()
_ = os.Remove(tmpPath)
return "", errCopy
}
if errClose := tmpFile.Close(); errClose != nil {
_ = os.Remove(tmpPath)
return "", errClose
}
return tmpPath, nil
}
func (l *FileRequestLogger) writeNonStreamingLog(
w io.Writer,
url, method string,
requestHeaders map[string][]string,
requestBody []byte,
requestBodyPath string,
apiRequest []byte,
apiResponse []byte,
apiResponseErrors []*interfaces.ErrorMessage,
statusCode int,
responseHeaders map[string][]string,
response []byte,
decompressErr error,
) error {
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, time.Now()); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest); errWrite != nil {
return errWrite
}
if errWrite := writeAPIErrorResponses(w, apiResponseErrors); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse); errWrite != nil {
return errWrite
}
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
}
func writeRequestInfoWithBody(
w io.Writer,
url, method string,
headers map[string][]string,
body []byte,
bodyPath string,
timestamp time.Time,
) error {
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, fmt.Sprintf("Version: %s\n", buildinfo.Version)); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, fmt.Sprintf("URL: %s\n", url)); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, "=== HEADERS ===\n"); errWrite != nil {
return errWrite
}
for key, values := range headers {
for _, value := range values {
masked := util.MaskSensitiveHeaderValue(key, value)
if _, errWrite := io.WriteString(w, fmt.Sprintf("%s: %s\n", key, masked)); errWrite != nil {
return errWrite
}
}
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
return errWrite
}
if bodyPath != "" {
bodyFile, errOpen := os.Open(bodyPath)
if errOpen != nil {
return errOpen
}
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil {
_ = bodyFile.Close()
return errCopy
}
if errClose := bodyFile.Close(); errClose != nil {
log.WithError(errClose).Warn("failed to close request body temp file")
}
} else if _, errWrite := w.Write(body); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
return errWrite
}
return nil
}
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte) error {
if len(payload) == 0 {
return nil
}
if bytes.HasPrefix(payload, []byte(sectionPrefix)) {
if _, errWrite := w.Write(payload); errWrite != nil {
return errWrite
}
if !bytes.HasSuffix(payload, []byte("\n")) {
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
}
} else {
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
return errWrite
}
if _, errWrite := w.Write(payload); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
return nil
}
func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMessage) error {
for i := 0; i < len(apiResponseErrors); i++ {
if apiResponseErrors[i] == nil {
continue
}
if _, errWrite := io.WriteString(w, "=== API ERROR RESPONSE ===\n"); errWrite != nil {
return errWrite
}
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
return errWrite
}
if apiResponseErrors[i].Error != nil {
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil {
return errWrite
}
}
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
return errWrite
}
}
return nil
}
func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, responseHeaders map[string][]string, responseReader io.Reader, decompressErr error, trailingNewline bool) error {
if _, errWrite := io.WriteString(w, "=== RESPONSE ===\n"); errWrite != nil {
return errWrite
}
if statusWritten {
if _, errWrite := io.WriteString(w, fmt.Sprintf("Status: %d\n", statusCode)); errWrite != nil {
return errWrite
}
}
if responseHeaders != nil {
for key, values := range responseHeaders {
for _, value := range values {
if _, errWrite := io.WriteString(w, fmt.Sprintf("%s: %s\n", key, value)); errWrite != nil {
return errWrite
}
}
}
}
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
if responseReader != nil {
if _, errCopy := io.Copy(w, responseReader); errCopy != nil {
return errCopy
}
}
if decompressErr != nil {
if _, errWrite := io.WriteString(w, fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", decompressErr)); errWrite != nil {
return errWrite
}
}
if trailingNewline {
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
return errWrite
}
}
return nil
}
// formatLogContent creates the complete log content for non-streaming requests. // formatLogContent creates the complete log content for non-streaming requests.
// //
// Parameters: // Parameters:
@@ -648,13 +912,34 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
} }
// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs. // FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
// It handles asynchronous writing of streaming response chunks to a file. // It spools streaming response chunks to a temporary file to avoid retaining large responses in memory.
// All data is buffered and written in the correct order when Close is called. // The final log file is assembled when Close is called.
type FileStreamingLogWriter struct { type FileStreamingLogWriter struct {
// file is the file where log data is written. // logFilePath is the final log file path.
file *os.File logFilePath string
// chunkChan is a channel for receiving response chunks to buffer. // url is the request URL (masked upstream in middleware).
url string
// method is the HTTP method.
method string
// timestamp is captured when the streaming log is initialized.
timestamp time.Time
// requestHeaders stores the request headers.
requestHeaders map[string][]string
// requestBodyPath is a temporary file path holding the request body.
requestBodyPath string
// responseBodyPath is a temporary file path holding the streaming response body.
responseBodyPath string
// responseBodyFile is the temp file where chunks are appended by the async writer.
responseBodyFile *os.File
// chunkChan is a channel for receiving response chunks to spool.
chunkChan chan []byte chunkChan chan []byte
// closeChan is a channel for signaling when the writer is closed. // closeChan is a channel for signaling when the writer is closed.
@@ -663,9 +948,6 @@ type FileStreamingLogWriter struct {
// errorChan is a channel for reporting errors during writing. // errorChan is a channel for reporting errors during writing.
errorChan chan error errorChan chan error
// bufferedChunks stores the response chunks in order.
bufferedChunks *bytes.Buffer
// responseStatus stores the HTTP status code. // responseStatus stores the HTTP status code.
responseStatus int responseStatus int
@@ -770,85 +1052,115 @@ func (w *FileStreamingLogWriter) Close() error {
close(w.chunkChan) close(w.chunkChan)
} }
// Wait for async writer to finish buffering chunks // Wait for async writer to finish spooling chunks
if w.closeChan != nil { if w.closeChan != nil {
<-w.closeChan <-w.closeChan
w.chunkChan = nil w.chunkChan = nil
} }
if w.file == nil { select {
case errWrite := <-w.errorChan:
w.cleanupTempFiles()
return errWrite
default:
}
if w.logFilePath == "" {
w.cleanupTempFiles()
return nil return nil
} }
// Write all content in the correct order logFile, errOpen := os.OpenFile(w.logFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
var content strings.Builder if errOpen != nil {
w.cleanupTempFiles()
// 1. Write API REQUEST section return fmt.Errorf("failed to create log file: %w", errOpen)
if len(w.apiRequest) > 0 {
if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) {
content.Write(w.apiRequest)
if !bytes.HasSuffix(w.apiRequest, []byte("\n")) {
content.WriteString("\n")
}
} else {
content.WriteString("=== API REQUEST ===\n")
content.Write(w.apiRequest)
content.WriteString("\n")
}
content.WriteString("\n")
} }
// 2. Write API RESPONSE section writeErr := w.writeFinalLog(logFile)
if len(w.apiResponse) > 0 { if errClose := logFile.Close(); errClose != nil {
if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) { log.WithError(errClose).Warn("failed to close request log file")
content.Write(w.apiResponse) if writeErr == nil {
if !bytes.HasSuffix(w.apiResponse, []byte("\n")) { writeErr = errClose
content.WriteString("\n")
}
} else {
content.WriteString("=== API RESPONSE ===\n")
content.Write(w.apiResponse)
content.WriteString("\n")
}
content.WriteString("\n")
}
// 3. Write RESPONSE section (status, headers, buffered chunks)
content.WriteString("=== RESPONSE ===\n")
if w.statusWritten {
content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus))
}
for key, values := range w.responseHeaders {
for _, value := range values {
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
} }
} }
content.WriteString("\n")
// Write buffered response body chunks w.cleanupTempFiles()
if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 { return writeErr
content.Write(w.bufferedChunks.Bytes())
}
// Write the complete content to file
if _, err := w.file.WriteString(content.String()); err != nil {
_ = w.file.Close()
return err
}
return w.file.Close()
} }
// asyncWriter runs in a goroutine to buffer chunks from the channel. // asyncWriter runs in a goroutine to buffer chunks from the channel.
// It continuously reads chunks from the channel and buffers them for later writing. // It continuously reads chunks from the channel and appends them to a temp file for later assembly.
func (w *FileStreamingLogWriter) asyncWriter() { func (w *FileStreamingLogWriter) asyncWriter() {
defer close(w.closeChan) defer close(w.closeChan)
for chunk := range w.chunkChan { for chunk := range w.chunkChan {
if w.bufferedChunks != nil { if w.responseBodyFile == nil {
w.bufferedChunks.Write(chunk) continue
} }
if _, errWrite := w.responseBodyFile.Write(chunk); errWrite != nil {
select {
case w.errorChan <- errWrite:
default:
}
if errClose := w.responseBodyFile.Close(); errClose != nil {
select {
case w.errorChan <- errClose:
default:
}
}
w.responseBodyFile = nil
}
}
if w.responseBodyFile == nil {
return
}
if errClose := w.responseBodyFile.Close(); errClose != nil {
select {
case w.errorChan <- errClose:
default:
}
}
w.responseBodyFile = nil
}
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest); errWrite != nil {
return errWrite
}
if errWrite := writeAPISection(logFile, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse); errWrite != nil {
return errWrite
}
responseBodyFile, errOpen := os.Open(w.responseBodyPath)
if errOpen != nil {
return errOpen
}
defer func() {
if errClose := responseBodyFile.Close(); errClose != nil {
log.WithError(errClose).Warn("failed to close response body temp file")
}
}()
return writeResponseSection(logFile, w.responseStatus, w.statusWritten, w.responseHeaders, responseBodyFile, nil, false)
}
func (w *FileStreamingLogWriter) cleanupTempFiles() {
if w.requestBodyPath != "" {
if errRemove := os.Remove(w.requestBodyPath); errRemove != nil {
log.WithError(errRemove).Warn("failed to remove request body temp file")
}
w.requestBodyPath = ""
}
if w.responseBodyPath != "" {
if errRemove := os.Remove(w.responseBodyPath); errRemove != nil {
log.WithError(errRemove).Warn("failed to remove response body temp file")
}
w.responseBodyPath = ""
} }
} }

View File

@@ -4,6 +4,8 @@ import (
"crypto/rand" "crypto/rand"
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net/url"
"strings"
) )
// GenerateRandomState generates a cryptographically secure random state parameter // GenerateRandomState generates a cryptographically secure random state parameter
@@ -19,3 +21,83 @@ func GenerateRandomState() (string, error) {
} }
return hex.EncodeToString(bytes), nil return hex.EncodeToString(bytes), nil
} }
// OAuthCallback captures the parsed OAuth callback parameters.
type OAuthCallback struct {
Code string
State string
Error string
ErrorDescription string
}
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
// It returns nil when the input is empty.
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
trimmed := strings.TrimSpace(input)
if trimmed == "" {
return nil, nil
}
candidate := trimmed
if !strings.Contains(candidate, "://") {
if strings.HasPrefix(candidate, "?") {
candidate = "http://localhost" + candidate
} else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") {
candidate = "http://" + candidate
} else if strings.Contains(candidate, "=") {
candidate = "http://localhost/?" + candidate
} else {
return nil, fmt.Errorf("invalid callback URL")
}
}
parsedURL, err := url.Parse(candidate)
if err != nil {
return nil, err
}
query := parsedURL.Query()
code := strings.TrimSpace(query.Get("code"))
state := strings.TrimSpace(query.Get("state"))
errCode := strings.TrimSpace(query.Get("error"))
errDesc := strings.TrimSpace(query.Get("error_description"))
if parsedURL.Fragment != "" {
if fragQuery, errFrag := url.ParseQuery(parsedURL.Fragment); errFrag == nil {
if code == "" {
code = strings.TrimSpace(fragQuery.Get("code"))
}
if state == "" {
state = strings.TrimSpace(fragQuery.Get("state"))
}
if errCode == "" {
errCode = strings.TrimSpace(fragQuery.Get("error"))
}
if errDesc == "" {
errDesc = strings.TrimSpace(fragQuery.Get("error_description"))
}
}
}
if code != "" && state == "" && strings.Contains(code, "#") {
parts := strings.SplitN(code, "#", 2)
code = parts[0]
state = parts[1]
}
if errCode == "" && errDesc != "" {
errCode = errDesc
errDesc = ""
}
if code == "" && errCode == "" {
return nil, fmt.Errorf("callback URL missing code")
}
return &OAuthCallback{
Code: code,
State: state,
Error: errCode,
ErrorDescription: errDesc,
}, nil
}

View File

@@ -162,6 +162,21 @@ func GetGeminiModels() []*ModelInfo {
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"}, SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
}, },
{
ID: "gemini-3-flash-preview",
Object: "model",
Created: 1765929600,
OwnedBy: "google",
Type: "gemini",
Name: "models/gemini-3-flash-preview",
Version: "3.0",
DisplayName: "Gemini 3 Flash Preview",
Description: "Gemini 3 Flash Preview",
InputTokenLimit: 1048576,
OutputTokenLimit: 65536,
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
},
{ {
ID: "gemini-3-pro-image-preview", ID: "gemini-3-pro-image-preview",
Object: "model", Object: "model",

View File

@@ -325,8 +325,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model) payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload) payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload) payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
payload = util.ConvertThinkingLevelToBudget(payload, req.Model) payload = util.ConvertThinkingLevelToBudget(payload, req.Model, true)
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload) payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload) payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
payload = fixGeminiImageAspectRatio(req.Model, payload) payload = fixGeminiImageAspectRatio(req.Model, payload)
payload = applyPayloadConfig(e.cfg, req.Model, payload) payload = applyPayloadConfig(e.cfg, req.Model, payload)

View File

@@ -7,6 +7,8 @@ import (
"bufio" "bufio"
"bytes" "bytes"
"context" "context"
"crypto/sha256"
"encoding/binary"
"encoding/json" "encoding/json"
"fmt" "fmt"
"io" "io"
@@ -97,6 +99,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -191,6 +194,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -524,6 +528,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated) translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated) translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
translated = normalizeAntigravityThinking(req.Model, translated) translated = normalizeAntigravityThinking(req.Model, translated)
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
baseURLs := antigravityBaseURLFallbackOrder(auth) baseURLs := antigravityBaseURLFallbackOrder(auth)
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0) httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
@@ -1011,7 +1016,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
// Use the centralized schema cleaner to handle unsupported keywords, // Use the centralized schema cleaner to handle unsupported keywords,
// const->enum conversion, and flattening of types/anyOf. // const->enum conversion, and flattening of types/anyOf.
strJSON = util.CleanJSONSchemaForGemini(strJSON) strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
payload = []byte(strJSON) payload = []byte(strJSON)
} }
@@ -1187,7 +1192,7 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
template, _ = sjson.Set(template, "project", generateProjectID()) template, _ = sjson.Set(template, "project", generateProjectID())
} }
template, _ = sjson.Set(template, "requestId", generateRequestID()) template, _ = sjson.Set(template, "requestId", generateRequestID())
template, _ = sjson.Set(template, "request.sessionId", generateSessionID()) template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
template, _ = sjson.Delete(template, "request.safetySettings") template, _ = sjson.Delete(template, "request.safetySettings")
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED") template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
@@ -1229,6 +1234,23 @@ func generateSessionID() string {
return "-" + strconv.FormatInt(n, 10) return "-" + strconv.FormatInt(n, 10)
} }
func generateStableSessionID(payload []byte) string {
contents := gjson.GetBytes(payload, "request.contents")
if contents.IsArray() {
for _, content := range contents.Array() {
if content.Get("role").String() == "user" {
text := content.Get("parts.0.text").String()
if text != "" {
h := sha256.Sum256([]byte(text))
n := int64(binary.BigEndian.Uint64(h[:8])) & 0x7FFFFFFFFFFFFFFF
return "-" + strconv.FormatInt(n, 10)
}
}
}
}
return generateSessionID()
}
func generateProjectID() string { func generateProjectID() string {
adjectives := []string{"useful", "bright", "swift", "calm", "bold"} adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
nouns := []string{"fuze", "wave", "spark", "flow", "core"} nouns := []string{"fuze", "wave", "spark", "flow", "core"}

View File

@@ -7,15 +7,40 @@ package claude
import ( import (
"bytes" "bytes"
"crypto/sha256"
"encoding/hex"
"strings" "strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common" "github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util" "github.com/router-for-me/CLIProxyAPI/v6/internal/util"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator" // deriveSessionID generates a stable session ID from the request.
// Uses the hash of the first user message to identify the conversation.
func deriveSessionID(rawJSON []byte) string {
messages := gjson.GetBytes(rawJSON, "messages")
if !messages.IsArray() {
return ""
}
for _, msg := range messages.Array() {
if msg.Get("role").String() == "user" {
content := msg.Get("content").String()
if content == "" {
// Try to get text from content array
content = msg.Get("content.0.text").String()
}
if content != "" {
h := sha256.Sum256([]byte(content))
return hex.EncodeToString(h[:16])
}
}
}
return ""
}
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format. // ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
// It extracts the model name, system instruction, message contents, and tool declarations // It extracts the model name, system instruction, message contents, and tool declarations
@@ -37,7 +62,9 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
// - []byte: The transformed request data in Gemini CLI API format // - []byte: The transformed request data in Gemini CLI API format
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte { func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
rawJSON := bytes.Clone(inputRawJSON) rawJSON := bytes.Clone(inputRawJSON)
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
// Derive session ID for signature caching
sessionID := deriveSessionID(rawJSON)
// system instruction // system instruction
systemInstructionJSON := "" systemInstructionJSON := ""
@@ -64,16 +91,19 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// contents // contents
contentsJSON := "[]" contentsJSON := "[]"
hasContents := false hasContents := false
messagesResult := gjson.GetBytes(rawJSON, "messages") messagesResult := gjson.GetBytes(rawJSON, "messages")
if messagesResult.IsArray() { if messagesResult.IsArray() {
messageResults := messagesResult.Array() messageResults := messagesResult.Array()
for i := 0; i < len(messageResults); i++ { numMessages := len(messageResults)
for i := 0; i < numMessages; i++ {
messageResult := messageResults[i] messageResult := messageResults[i]
roleResult := messageResult.Get("role") roleResult := messageResult.Get("role")
if roleResult.Type != gjson.String { if roleResult.Type != gjson.String {
continue continue
} }
role := roleResult.String() originalRole := roleResult.String()
role := originalRole
if role == "assistant" { if role == "assistant" {
role = "model" role = "model"
} }
@@ -82,20 +112,59 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
contentsResult := messageResult.Get("content") contentsResult := messageResult.Get("content")
if contentsResult.IsArray() { if contentsResult.IsArray() {
contentResults := contentsResult.Array() contentResults := contentsResult.Array()
for j := 0; j < len(contentResults); j++ { numContents := len(contentResults)
var currentMessageThinkingSignature string
for j := 0; j < numContents; j++ {
contentResult := contentResults[j] contentResult := contentResults[j]
contentTypeResult := contentResult.Get("type") contentTypeResult := contentResult.Get("type")
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" { if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
prompt := contentResult.Get("thinking").String() // Use GetThinkingText to handle wrapped thinking objects
thinkingText := util.GetThinkingText(contentResult)
signatureResult := contentResult.Get("signature") signatureResult := contentResult.Get("signature")
signature := geminiCLIClaudeThoughtSignature clientSignature := ""
if signatureResult.Exists() { if signatureResult.Exists() && signatureResult.String() != "" {
signature = signatureResult.String() clientSignature = signatureResult.String()
}
// Always try cached signature first (more reliable than client-provided)
// Client may send stale or invalid signatures from different sessions
signature := ""
if sessionID != "" && thinkingText != "" {
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
signature = cachedSig
log.Debugf("Using cached signature for thinking block")
} }
}
// Fallback to client signature only if cache miss and client signature is valid
if signature == "" && cache.HasValidSignature(clientSignature) {
signature = clientSignature
log.Debugf("Using client-provided signature for thinking block")
}
// Store for subsequent tool_use in the same message
if cache.HasValidSignature(signature) {
currentMessageThinkingSignature = signature
}
// Skip trailing unsigned thinking blocks on last assistant message
isUnsigned := !cache.HasValidSignature(signature)
// If unsigned, skip entirely (don't convert to text)
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
// Converting to text would break this requirement
if isUnsigned {
// TypeScript plugin approach: drop unsigned thinking blocks entirely
log.Debugf("Dropping unsigned thinking block (no valid signature)")
continue
}
// Valid signature, send as thought block
partJSON := `{}` partJSON := `{}`
partJSON, _ = sjson.Set(partJSON, "thought", true) partJSON, _ = sjson.Set(partJSON, "thought", true)
if prompt != "" { if thinkingText != "" {
partJSON, _ = sjson.Set(partJSON, "text", prompt) partJSON, _ = sjson.Set(partJSON, "text", thinkingText)
} }
if signature != "" { if signature != "" {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature) partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
@@ -109,24 +178,47 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} }
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON) clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" { } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
// NOTE: Do NOT inject dummy thinking blocks here.
// Antigravity API validates signatures, so dummy values are rejected.
// The TypeScript plugin removes unsigned thinking blocks instead of injecting dummies.
functionName := contentResult.Get("name").String() functionName := contentResult.Get("name").String()
functionArgs := contentResult.Get("input").String() argsResult := contentResult.Get("input")
functionID := contentResult.Get("id").String() functionID := contentResult.Get("id").String()
if gjson.Valid(functionArgs) {
argsResult := gjson.Parse(functionArgs) // Handle both object and string input formats
if argsResult.IsObject() { var argsRaw string
partJSON := `{}` if argsResult.IsObject() {
if !strings.Contains(modelName, "claude") { argsRaw = argsResult.Raw
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", geminiCLIClaudeThoughtSignature) } else if argsResult.Type == gjson.String {
} // Input is a JSON string, parse and validate it
if functionID != "" { parsed := gjson.Parse(argsResult.String())
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID) if parsed.IsObject() {
} argsRaw = parsed.Raw
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsResult.Raw)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
} }
} }
if argsRaw != "" {
partJSON := `{}`
// Use skip_thought_signature_validator for tool calls without valid thinking signature
// This is the approach used in opencode-google-antigravity-auth for Gemini
// and also works for Claude through Antigravity API
const skipSentinel = "skip_thought_signature_validator"
if cache.HasValidSignature(currentMessageThinkingSignature) {
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
} else {
// No valid signature - use skip sentinel to bypass validation
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel)
}
if functionID != "" {
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
}
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw)
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
}
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" { } else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
toolCallID := contentResult.Get("tool_use_id").String() toolCallID := contentResult.Get("tool_use_id").String()
if toolCallID != "" { if toolCallID != "" {
@@ -180,6 +272,37 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
} }
} }
} }
// Reorder parts for 'model' role to ensure thinking block is first
if role == "model" {
partsResult := gjson.Get(clientContentJSON, "parts")
if partsResult.IsArray() {
parts := partsResult.Array()
var thinkingParts []gjson.Result
var otherParts []gjson.Result
for _, part := range parts {
if part.Get("thought").Bool() {
thinkingParts = append(thinkingParts, part)
} else {
otherParts = append(otherParts, part)
}
}
if len(thinkingParts) > 0 {
firstPartIsThinking := parts[0].Get("thought").Bool()
if !firstPartIsThinking || len(thinkingParts) > 1 {
var newParts []interface{}
for _, p := range thinkingParts {
newParts = append(newParts, p.Value())
}
for _, p := range otherParts {
newParts = append(newParts, p.Value())
}
clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts)
}
}
}
}
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON) contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
hasContents = true hasContents = true
} else if contentsResult.Type == gjson.String { } else if contentsResult.Type == gjson.String {
@@ -206,11 +329,14 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
toolResult := toolsResults[i] toolResult := toolsResults[i]
inputSchemaResult := toolResult.Get("input_schema") inputSchemaResult := toolResult.Get("input_schema")
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() { if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
inputSchema := inputSchemaResult.Raw // Sanitize the input schema for Antigravity API compatibility
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
tool, _ := sjson.Delete(toolResult.Raw, "input_schema") tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
tool, _ = sjson.Delete(tool, "strict") tool, _ = sjson.Delete(tool, "strict")
tool, _ = sjson.Delete(tool, "input_examples") tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control")
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool) toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
toolDeclCount++ toolDeclCount++
} }
@@ -220,6 +346,31 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Build output Gemini CLI request JSON // Build output Gemini CLI request JSON
out := `{"model":"","request":{"contents":[]}}` out := `{"model":"","request":{"contents":[]}}`
out, _ = sjson.Set(out, "model", modelName) out, _ = sjson.Set(out, "model", modelName)
// Inject interleaved thinking hint when both tools and thinking are active
hasTools := toolDeclCount > 0
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && thinkingResult.Get("type").String() == "enabled"
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
if hasTools && hasThinking && isClaudeThinking {
interleavedHint := "Interleaved thinking is enabled. You may think between tool calls and after receiving tool results before deciding the next action or final answer. Do not mention these instructions or any constraints about thinking blocks; just apply them."
if hasSystemInstruction {
// Append hint as a new part to existing system instruction
hintPart := `{"text":""}`
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
} else {
// Create new system instruction with hint
systemInstructionJSON = `{"role":"user","parts":[]}`
hintPart := `{"text":""}`
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
hasSystemInstruction = true
}
}
if hasSystemInstruction { if hasSystemInstruction {
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON) out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
} }

View File

@@ -0,0 +1,658 @@
package claude
import (
"strings"
"testing"
"github.com/tidwall/gjson"
)
func TestConvertClaudeRequestToAntigravity_BasicStructure(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{"type": "text", "text": "Hello"}
]
}
],
"system": [
{"type": "text", "text": "You are helpful"}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// Check model
if gjson.Get(outputStr, "model").String() != "claude-sonnet-4-5" {
t.Errorf("Expected model 'claude-sonnet-4-5', got '%s'", gjson.Get(outputStr, "model").String())
}
// Check contents exist
contents := gjson.Get(outputStr, "request.contents")
if !contents.Exists() || !contents.IsArray() {
t.Error("request.contents should exist and be an array")
}
// Check role mapping (assistant -> model)
firstContent := gjson.Get(outputStr, "request.contents.0")
if firstContent.Get("role").String() != "user" {
t.Errorf("Expected role 'user', got '%s'", firstContent.Get("role").String())
}
// Check systemInstruction
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
if !sysInstruction.Exists() {
t.Error("systemInstruction should exist")
}
if sysInstruction.Get("parts.0.text").String() != "You are helpful" {
t.Error("systemInstruction text mismatch")
}
}
func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{"role": "user", "content": [{"type": "text", "text": "Hi"}]},
{"role": "assistant", "content": [{"type": "text", "text": "Hello"}]}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// assistant should be mapped to model
secondContent := gjson.Get(outputStr, "request.contents.1")
if secondContent.Get("role").String() != "model" {
t.Errorf("Expected role 'model' (mapped from 'assistant'), got '%s'", secondContent.Get("role").String())
}
}
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
// Valid signature must be at least 50 characters
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
{"type": "text", "text": "Answer"}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Check thinking block conversion
firstPart := gjson.Get(outputStr, "request.contents.0.parts.0")
if !firstPart.Get("thought").Bool() {
t.Error("thinking block should have thought: true")
}
if firstPart.Get("text").String() != "Let me think..." {
t.Error("thinking text mismatch")
}
if firstPart.Get("thoughtSignature").String() != validSignature {
t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, firstPart.Get("thoughtSignature").String())
}
}
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
// Unsigned thinking blocks should be removed entirely (not converted to text)
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me think..."},
{"type": "text", "text": "Answer"}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Without signature, thinking block should be removed (not converted to text)
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
}
// Only text part should remain
if parts[0].Get("thought").Bool() {
t.Error("Thinking block should be removed, not preserved")
}
if parts[0].Get("text").String() != "Answer" {
t.Errorf("Expected text 'Answer', got '%s'", parts[0].Get("text").String())
}
}
func TestConvertClaudeRequestToAntigravity_ToolDeclarations(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [],
"tools": [
{
"name": "test_tool",
"description": "A test tool",
"input_schema": {
"type": "object",
"properties": {
"name": {"type": "string"}
},
"required": ["name"]
}
}
]
}`)
output := ConvertClaudeRequestToAntigravity("gemini-1.5-pro", inputJSON, false)
outputStr := string(output)
// Check tools structure
tools := gjson.Get(outputStr, "request.tools")
if !tools.Exists() {
t.Error("Tools should exist in output")
}
funcDecl := gjson.Get(outputStr, "request.tools.0.functionDeclarations.0")
if funcDecl.Get("name").String() != "test_tool" {
t.Errorf("Expected tool name 'test_tool', got '%s'", funcDecl.Get("name").String())
}
// Check input_schema renamed to parametersJsonSchema
if funcDecl.Get("parametersJsonSchema").Exists() {
t.Log("parametersJsonSchema exists (expected)")
}
if funcDecl.Get("input_schema").Exists() {
t.Error("input_schema should be removed")
}
}
func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "assistant",
"content": [
{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": "{\"location\": \"Paris\"}"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// Now we expect only 1 part (tool_use), no dummy thinking block injected
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part (tool only, no dummy injection), got %d", len(parts))
}
// Check function call conversion at parts[0]
funcCall := parts[0].Get("functionCall")
if !funcCall.Exists() {
t.Error("functionCall should exist at parts[0]")
}
if funcCall.Get("name").String() != "get_weather" {
t.Errorf("Expected function name 'get_weather', got '%s'", funcCall.Get("name").String())
}
if funcCall.Get("id").String() != "call_123" {
t.Errorf("Expected function id 'call_123', got '%s'", funcCall.Get("id").String())
}
// Verify skip_thought_signature_validator is added (bypass for tools without valid thinking)
expectedSig := "skip_thought_signature_validator"
actualSig := parts[0].Get("thoughtSignature").String()
if actualSig != expectedSig {
t.Errorf("Expected thoughtSignature '%s', got '%s'", expectedSig, actualSig)
}
}
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
{
"type": "tool_use",
"id": "call_123",
"name": "get_weather",
"input": "{\"location\": \"Paris\"}"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Check function call has the signature from the preceding thinking block
part := gjson.Get(outputStr, "request.contents.0.parts.1")
if part.Get("functionCall.name").String() != "get_weather" {
t.Errorf("Expected functionCall, got %s", part.Raw)
}
if part.Get("thoughtSignature").String() != validSignature {
t.Errorf("Expected thoughtSignature '%s' on tool_use, got '%s'", validSignature, part.Get("thoughtSignature").String())
}
}
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
// Case: text block followed by thinking block -> should be reordered to thinking first
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "text", "text": "Here is the plan."},
{"type": "thinking", "thinking": "Planning...", "signature": "` + validSignature + `"}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Verify order: Thinking block MUST be first
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 2 {
t.Fatalf("Expected 2 parts, got %d", len(parts))
}
if !parts[0].Get("thought").Bool() {
t.Error("First part should be thinking block after reordering")
}
if parts[1].Get("text").String() != "Here is the plan." {
t.Error("Second part should be text block")
}
}
func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "tool_result",
"tool_use_id": "get_weather-call-123",
"content": "22C sunny"
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// Check function response conversion
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
if !funcResp.Exists() {
t.Error("functionResponse should exist")
}
if funcResp.Get("id").String() != "get_weather-call-123" {
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
}
}
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
// Note: This test requires the model to be registered in the registry
// with Thinking metadata. If the registry is not populated in test environment,
// thinkingConfig won't be added. We'll test the basic structure only.
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [],
"thinking": {
"type": "enabled",
"budget_tokens": 8000
}
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Check thinking config conversion (only if model supports thinking in registry)
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
if thinkingConfig.Exists() {
if thinkingConfig.Get("thinkingBudget").Int() != 8000 {
t.Errorf("Expected thinkingBudget 8000, got %d", thinkingConfig.Get("thinkingBudget").Int())
}
if !thinkingConfig.Get("include_thoughts").Bool() {
t.Error("include_thoughts should be true")
}
} else {
t.Log("thinkingConfig not present - model may not be registered in test registry")
}
}
func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [
{
"role": "user",
"content": [
{
"type": "image",
"source": {
"type": "base64",
"media_type": "image/png",
"data": "iVBORw0KGgoAAAANSUhEUg=="
}
}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// Check inline data conversion
inlineData := gjson.Get(outputStr, "request.contents.0.parts.0.inlineData")
if !inlineData.Exists() {
t.Error("inlineData should exist")
}
if inlineData.Get("mime_type").String() != "image/png" {
t.Error("mime_type mismatch")
}
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
t.Error("data mismatch")
}
}
func TestConvertClaudeRequestToAntigravity_GenerationConfig(t *testing.T) {
inputJSON := []byte(`{
"model": "claude-3-5-sonnet-20240620",
"messages": [],
"temperature": 0.7,
"top_p": 0.9,
"top_k": 40,
"max_tokens": 2000
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
genConfig := gjson.Get(outputStr, "request.generationConfig")
if genConfig.Get("temperature").Float() != 0.7 {
t.Errorf("Expected temperature 0.7, got %f", genConfig.Get("temperature").Float())
}
if genConfig.Get("topP").Float() != 0.9 {
t.Errorf("Expected topP 0.9, got %f", genConfig.Get("topP").Float())
}
if genConfig.Get("topK").Float() != 40 {
t.Errorf("Expected topK 40, got %f", genConfig.Get("topK").Float())
}
if genConfig.Get("maxOutputTokens").Float() != 2000 {
t.Errorf("Expected maxOutputTokens 2000, got %f", genConfig.Get("maxOutputTokens").Float())
}
}
// ============================================================================
// Trailing Unsigned Thinking Block Removal
// ============================================================================
func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *testing.T) {
// Last assistant message ends with unsigned thinking block - should be removed
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Here is my answer"},
{"type": "thinking", "thinking": "I should think more..."}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// The last part of the last assistant message should NOT be a thinking block
lastMessageParts := gjson.Get(outputStr, "request.contents.1.parts")
if !lastMessageParts.IsArray() {
t.Fatal("Last message should have parts array")
}
parts := lastMessageParts.Array()
if len(parts) == 0 {
t.Fatal("Last message should have at least one part")
}
// The unsigned thinking should be removed, leaving only the text
lastPart := parts[len(parts)-1]
if lastPart.Get("thought").Bool() {
t.Error("Trailing unsigned thinking block should be removed")
}
}
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
// Last assistant message ends with signed thinking block - should be kept
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "user",
"content": [{"type": "text", "text": "Hello"}]
},
{
"role": "assistant",
"content": [
{"type": "text", "text": "Here is my answer"},
{"type": "thinking", "thinking": "Valid thinking...", "signature": "abc123validSignature1234567890123456789012345678901234567890"}
]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// The signed thinking block should be preserved
lastMessageParts := gjson.Get(outputStr, "request.contents.1.parts")
parts := lastMessageParts.Array()
if len(parts) < 2 {
t.Error("Signed thinking block should be preserved")
}
}
func TestConvertClaudeRequestToAntigravity_MiddleUnsignedThinking_Removed(t *testing.T) {
// Middle message has unsigned thinking - should be removed entirely
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [
{
"role": "assistant",
"content": [
{"type": "thinking", "thinking": "Middle thinking..."},
{"type": "text", "text": "Answer"}
]
},
{
"role": "user",
"content": [{"type": "text", "text": "Follow up"}]
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// Unsigned thinking should be removed entirely
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
}
// Only text part should remain
if parts[0].Get("thought").Bool() {
t.Error("Thinking block should be removed, not preserved")
}
if parts[0].Get("text").String() != "Answer" {
t.Errorf("Expected text 'Answer', got '%s'", parts[0].Get("text").String())
}
}
// ============================================================================
// Tool + Thinking System Hint Injection
// ============================================================================
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_HintInjected(t *testing.T) {
// When both tools and thinking are enabled, hint should be injected into system instruction
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [{"type": "text", "text": "You are helpful."}],
"tools": [
{
"name": "get_weather",
"description": "Get weather",
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
}
],
"thinking": {"type": "enabled", "budget_tokens": 8000}
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// System instruction should contain the interleaved thinking hint
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
if !sysInstruction.Exists() {
t.Fatal("systemInstruction should exist")
}
// Check if hint is appended
sysText := sysInstruction.Get("parts").Array()
found := false
for _, part := range sysText {
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
found = true
break
}
}
if !found {
t.Errorf("Interleaved thinking hint should be injected when tools and thinking are both active, got: %v", sysInstruction.Raw)
}
}
func TestConvertClaudeRequestToAntigravity_ToolsOnly_NoHint(t *testing.T) {
// When only tools are present (no thinking), hint should NOT be injected
inputJSON := []byte(`{
"model": "claude-sonnet-4-5",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [{"type": "text", "text": "You are helpful."}],
"tools": [
{
"name": "get_weather",
"description": "Get weather",
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
}
]
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
outputStr := string(output)
// System instruction should NOT contain the hint
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
if sysInstruction.Exists() {
for _, part := range sysInstruction.Get("parts").Array() {
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
t.Error("Hint should NOT be injected when only tools are present (no thinking)")
}
}
}
}
func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
// When only thinking is enabled (no tools), hint should NOT be injected
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"system": [{"type": "text", "text": "You are helpful."}],
"thinking": {"type": "enabled", "budget_tokens": 8000}
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// System instruction should NOT contain the hint (no tools)
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
if sysInstruction.Exists() {
for _, part := range sysInstruction.Get("parts").Array() {
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
t.Error("Hint should NOT be injected when only thinking is present (no tools)")
}
}
}
}
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
// When tools + thinking but no system instruction, should create one with hint
inputJSON := []byte(`{
"model": "claude-sonnet-4-5-thinking",
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
"tools": [
{
"name": "get_weather",
"description": "Get weather",
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
}
],
"thinking": {"type": "enabled", "budget_tokens": 8000}
}`)
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
outputStr := string(output)
// System instruction should be created with hint
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
if !sysInstruction.Exists() {
t.Fatal("systemInstruction should be created when tools + thinking are active")
}
sysText := sysInstruction.Get("parts").Array()
found := false
for _, part := range sysText {
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
found = true
break
}
}
if !found {
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
}
}

View File

@@ -14,7 +14,9 @@ import (
"sync/atomic" "sync/atomic"
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson" "github.com/tidwall/gjson"
"github.com/tidwall/sjson" "github.com/tidwall/sjson"
) )
@@ -37,6 +39,10 @@ type Params struct {
HasSentFinalEvents bool // Indicates if final content/message events have been sent HasSentFinalEvents bool // Indicates if final content/message events have been sent
HasToolUse bool // Indicates if tool use was observed in the stream HasToolUse bool // Indicates if tool use was observed in the stream
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
// Signature caching support
SessionID string // Session ID derived from request for signature caching
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
} }
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers. // toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
@@ -64,6 +70,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
HasFirstResponse: false, HasFirstResponse: false,
ResponseType: 0, ResponseType: 0,
ResponseIndex: 0, ResponseIndex: 0,
SessionID: deriveSessionID(originalRequestRawJSON),
} }
} }
@@ -121,11 +128,20 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
// Process thinking content (internal reasoning) // Process thinking content (internal reasoning)
if partResult.Get("thought").Bool() { if partResult.Get("thought").Bool() {
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" { if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
log.Debug("Branch: signature_delta")
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
params.CurrentThinkingText.Reset()
}
output = output + "event: content_block_delta\n" output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String()) data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data) output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.HasContent = true params.HasContent = true
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state } else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
params.CurrentThinkingText.WriteString(partTextResult.String())
output = output + "event: content_block_delta\n" output = output + "event: content_block_delta\n"
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String()) data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
output = output + fmt.Sprintf("data: %s\n\n\n", data) output = output + fmt.Sprintf("data: %s\n\n\n", data)
@@ -154,6 +170,9 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
output = output + fmt.Sprintf("data: %s\n\n\n", data) output = output + fmt.Sprintf("data: %s\n\n\n", data)
params.ResponseType = 2 // Set state to thinking params.ResponseType = 2 // Set state to thinking
params.HasContent = true params.HasContent = true
// Start accumulating thinking text for signature caching
params.CurrentThinkingText.Reset()
params.CurrentThinkingText.WriteString(partTextResult.String())
} }
} else { } else {
finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason") finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason")

View File

@@ -0,0 +1,316 @@
package claude
import (
"context"
"strings"
"testing"
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
)
// ============================================================================
// Signature Caching Tests
// ============================================================================
func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
cache.ClearSignatureCache("")
// Request with user message - should derive session ID
requestJSON := []byte(`{
"messages": [
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
]
}`)
// First response chunk with thinking
responseJSON := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "Let me think...", "thought": true}]
}
}]
}
}`)
var param any
ctx := context.Background()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, &param)
// Verify session ID was set
params := param.(*Params)
if params.SessionID == "" {
t.Error("SessionID should be derived from request")
}
}
func TestConvertAntigravityResponseToClaude_ThinkingTextAccumulated(t *testing.T) {
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"messages": [{"role": "user", "content": [{"type": "text", "text": "Test"}]}]
}`)
// First thinking chunk
chunk1 := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "First part of thinking...", "thought": true}]
}
}]
}
}`)
// Second thinking chunk (continuation)
chunk2 := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": " Second part of thinking...", "thought": true}]
}
}]
}
}`)
var param any
ctx := context.Background()
// Process first chunk - starts new thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk1, &param)
params := param.(*Params)
if params.CurrentThinkingText.Len() == 0 {
t.Error("Thinking text should be accumulated after first chunk")
}
// Process second chunk - continues thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk2, &param)
text := params.CurrentThinkingText.String()
if !strings.Contains(text, "First part") || !strings.Contains(text, "Second part") {
t.Errorf("Thinking text should accumulate both parts, got: %s", text)
}
}
func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}]
}`)
// Thinking chunk
thinkingChunk := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "My thinking process here", "thought": true}]
}
}]
}
}`)
// Signature chunk
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
signatureChunk := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSignature + `"}]
}
}]
}
}`)
var param any
ctx := context.Background()
// Process thinking chunk
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, &param)
params := param.(*Params)
sessionID := params.SessionID
thinkingText := params.CurrentThinkingText.String()
if sessionID == "" {
t.Fatal("SessionID should be set")
}
if thinkingText == "" {
t.Fatal("Thinking text should be accumulated")
}
// Process signature chunk - should cache the signature
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, &param)
// Verify signature was cached
cachedSig := cache.GetCachedSignature(sessionID, thinkingText)
if cachedSig != validSignature {
t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig)
}
// Verify thinking text was reset after caching
if params.CurrentThinkingText.Len() != 0 {
t.Error("Thinking text should be reset after signature is cached")
}
}
func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T) {
cache.ClearSignatureCache("")
requestJSON := []byte(`{
"messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}]
}`)
validSig1 := "signature1_12345678901234567890123456789012345678901234567"
validSig2 := "signature2_12345678901234567890123456789012345678901234567"
// First thinking block with signature
block1Thinking := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "First thinking block", "thought": true}]
}
}]
}
}`)
block1Sig := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig1 + `"}]
}
}]
}
}`)
// Text content (breaks thinking)
textBlock := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "Regular text output"}]
}
}]
}
}`)
// Second thinking block with signature
block2Thinking := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "Second thinking block", "thought": true}]
}
}]
}
}`)
block2Sig := []byte(`{
"response": {
"candidates": [{
"content": {
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig2 + `"}]
}
}]
}
}`)
var param any
ctx := context.Background()
// Process first thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, &param)
params := param.(*Params)
sessionID := params.SessionID
firstThinkingText := params.CurrentThinkingText.String()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, &param)
// Verify first signature cached
if cache.GetCachedSignature(sessionID, firstThinkingText) != validSig1 {
t.Error("First thinking block signature should be cached")
}
// Process text (transitions out of thinking)
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, textBlock, &param)
// Process second thinking block
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Thinking, &param)
secondThinkingText := params.CurrentThinkingText.String()
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, &param)
// Verify second signature cached
if cache.GetCachedSignature(sessionID, secondThinkingText) != validSig2 {
t.Error("Second thinking block signature should be cached")
}
}
func TestDeriveSessionIDFromRequest(t *testing.T) {
tests := []struct {
name string
input []byte
wantEmpty bool
}{
{
name: "valid user message",
input: []byte(`{"messages": [{"role": "user", "content": "Hello"}]}`),
wantEmpty: false,
},
{
name: "user message with content array",
input: []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}`),
wantEmpty: false,
},
{
name: "no user message",
input: []byte(`{"messages": [{"role": "assistant", "content": "Hi"}]}`),
wantEmpty: true,
},
{
name: "empty messages",
input: []byte(`{"messages": []}`),
wantEmpty: true,
},
{
name: "no messages field",
input: []byte(`{}`),
wantEmpty: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := deriveSessionID(tt.input)
if tt.wantEmpty && result != "" {
t.Errorf("Expected empty session ID, got '%s'", result)
}
if !tt.wantEmpty && result == "" {
t.Error("Expected non-empty session ID")
}
})
}
}
func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) {
input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`)
id1 := deriveSessionID(input)
id2 := deriveSessionID(input)
if id1 != id2 {
t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2)
}
}
func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) {
input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`)
input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`)
id1 := deriveSessionID(input1)
id2 := deriveSessionID(input2)
if id1 == id2 {
t.Error("Different messages should produce different session IDs")
}
}

View File

@@ -98,16 +98,34 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
} }
} }
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool { // Gemini-specific handling: add skip_thought_signature_validator to functionCall parts
// and remove thinking blocks entirely (Gemini doesn't need to preserve them)
const skipSentinel = "skip_thought_signature_validator"
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
if content.Get("role").String() == "model" { if content.Get("role").String() == "model" {
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool { // First pass: collect indices of thinking parts to remove
var thinkingIndicesToRemove []int64
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
// Mark thinking blocks for removal
if part.Get("thought").Bool() {
thinkingIndicesToRemove = append(thinkingIndicesToRemove, partIdx.Int())
}
// Add skip sentinel to functionCall parts
if part.Get("functionCall").Exists() { if part.Get("functionCall").Exists() {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator") existingSig := part.Get("thoughtSignature").String()
} else if part.Get("thoughtSignature").Exists() { if existingSig == "" || len(existingSig) < 50 {
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator") rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
}
} }
return true return true
}) })
// Remove thinking blocks in reverse order to preserve indices
for i := len(thinkingIndicesToRemove) - 1; i >= 0; i-- {
idx := thinkingIndicesToRemove[i]
rawJSON, _ = sjson.DeleteBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d", contentIdx.Int(), idx))
}
} }
return true return true
}) })

View File

@@ -0,0 +1,129 @@
package gemini
import (
"fmt"
"testing"
"github.com/tidwall/gjson"
)
func TestConvertGeminiRequestToAntigravity_PreserveValidSignature(t *testing.T) {
// Valid signature on functionCall should be preserved
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(fmt.Sprintf(`{
"model": "gemini-3-pro-preview",
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "test_tool", "args": {}}, "thoughtSignature": "%s"}
]
}
]
}`, validSignature))
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
outputStr := string(output)
// Check that valid thoughtSignature is preserved
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part, got %d", len(parts))
}
sig := parts[0].Get("thoughtSignature").String()
if sig != validSignature {
t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, sig)
}
}
func TestConvertGeminiRequestToAntigravity_AddSkipSentinelToFunctionCall(t *testing.T) {
// functionCall without signature should get skip_thought_signature_validator
inputJSON := []byte(`{
"model": "gemini-3-pro-preview",
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "test_tool", "args": {}}}
]
}
]
}`)
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
outputStr := string(output)
// Check that skip_thought_signature_validator is added to functionCall
sig := gjson.Get(outputStr, "request.contents.0.parts.0.thoughtSignature").String()
expectedSig := "skip_thought_signature_validator"
if sig != expectedSig {
t.Errorf("Expected skip sentinel '%s', got '%s'", expectedSig, sig)
}
}
func TestConvertGeminiRequestToAntigravity_RemoveThinkingBlocks(t *testing.T) {
// Thinking blocks should be removed entirely for Gemini
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
inputJSON := []byte(fmt.Sprintf(`{
"model": "gemini-3-pro-preview",
"contents": [
{
"role": "model",
"parts": [
{"thought": true, "text": "Thinking...", "thoughtSignature": "%s"},
{"text": "Here is my response"}
]
}
]
}`, validSignature))
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
outputStr := string(output)
// Check that thinking block is removed
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 1 {
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
}
// Only text part should remain
if parts[0].Get("thought").Bool() {
t.Error("Thinking block should be removed for Gemini")
}
if parts[0].Get("text").String() != "Here is my response" {
t.Errorf("Expected text 'Here is my response', got '%s'", parts[0].Get("text").String())
}
}
func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
// Multiple functionCalls should all get skip_thought_signature_validator
inputJSON := []byte(`{
"model": "gemini-3-pro-preview",
"contents": [
{
"role": "model",
"parts": [
{"functionCall": {"name": "tool_one", "args": {"a": "1"}}},
{"functionCall": {"name": "tool_two", "args": {"b": "2"}}}
]
}
]
}`)
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
outputStr := string(output)
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
if len(parts) != 2 {
t.Fatalf("Expected 2 parts, got %d", len(parts))
}
expectedSig := "skip_thought_signature_validator"
for i, part := range parts {
sig := part.Get("thoughtSignature").String()
if sig != expectedSig {
t.Errorf("Part %d: Expected '%s', got '%s'", i, expectedSig, sig)
}
}
}

View File

@@ -39,8 +39,23 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config // Note: OpenAI official fields take precedence over extra_body.google.thinking_config
re := gjson.GetBytes(rawJSON, "reasoning_effort") re := gjson.GetBytes(rawJSON, "reasoning_effort")
hasOfficialThinking := re.Exists() hasOfficialThinking := re.Exists()
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) { if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
out = util.ApplyReasoningEffortToGeminiCLI(out, re.String()) effort := strings.ToLower(strings.TrimSpace(re.String()))
if util.IsGemini3Model(modelName) {
switch effort {
case "none":
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig")
case "auto":
includeThoughts := true
out = util.ApplyGeminiCLIThinkingLevel(out, "", &includeThoughts)
default:
if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok {
out = util.ApplyGeminiCLIThinkingLevel(out, level, nil)
}
}
} else if !util.ModelUsesThinkingLevels(modelName) {
out = util.ApplyReasoningEffortToGeminiCLI(out, effort)
}
} }
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent) // Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)

View File

@@ -95,7 +95,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
} }
} }
// response.created // response.created
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"instructions":""}}` created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
created, _ = sjson.Set(created, "sequence_number", nextSeq()) created, _ = sjson.Set(created, "sequence_number", nextSeq())
created, _ = sjson.Set(created, "response.id", st.ResponseID) created, _ = sjson.Set(created, "response.id", st.ResponseID)
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt) created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
@@ -197,11 +197,11 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
if st.ReasoningActive { if st.ReasoningActive {
if t := d.Get("thinking"); t.Exists() { if t := d.Get("thinking"); t.Exists() {
st.ReasoningBuf.WriteString(t.String()) st.ReasoningBuf.WriteString(t.String())
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}` msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID) msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex) msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
msg, _ = sjson.Set(msg, "text", t.String()) msg, _ = sjson.Set(msg, "delta", t.String())
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg)) out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
} }
} }

View File

@@ -134,6 +134,8 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
tool, _ = sjson.Delete(tool, "strict") tool, _ = sjson.Delete(tool, "strict")
tool, _ = sjson.Delete(tool, "input_examples") tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control")
var toolDeclaration any var toolDeclaration any
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil { if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration) tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)

View File

@@ -127,6 +127,8 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema) tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
tool, _ = sjson.Delete(tool, "strict") tool, _ = sjson.Delete(tool, "strict")
tool, _ = sjson.Delete(tool, "input_examples") tool, _ = sjson.Delete(tool, "input_examples")
tool, _ = sjson.Delete(tool, "type")
tool, _ = sjson.Delete(tool, "cache_control")
var toolDeclaration any var toolDeclaration any
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil { if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration) tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)

View File

@@ -37,12 +37,28 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
// Reasoning effort -> thinkingBudget/include_thoughts // Reasoning effort -> thinkingBudget/include_thoughts
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config // Note: OpenAI official fields take precedence over extra_body.google.thinking_config
// Only convert for models that use numeric budgets (not discrete levels) to avoid // Only apply numeric budgets for models that use budgets (not discrete levels) to avoid
// incorrectly applying thinkingBudget for level-based models like gpt-5. // incorrectly applying thinkingBudget for level-based models like gpt-5. Gemini 3 models
// use thinkingLevel/includeThoughts instead.
re := gjson.GetBytes(rawJSON, "reasoning_effort") re := gjson.GetBytes(rawJSON, "reasoning_effort")
hasOfficialThinking := re.Exists() hasOfficialThinking := re.Exists()
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) { if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
out = util.ApplyReasoningEffortToGemini(out, re.String()) effort := strings.ToLower(strings.TrimSpace(re.String()))
if util.IsGemini3Model(modelName) {
switch effort {
case "none":
out, _ = sjson.DeleteBytes(out, "generationConfig.thinkingConfig")
case "auto":
includeThoughts := true
out = util.ApplyGeminiThinkingLevel(out, "", &includeThoughts)
default:
if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok {
out = util.ApplyGeminiThinkingLevel(out, level, nil)
}
}
} else if !util.ModelUsesThinkingLevels(modelName) {
out = util.ApplyReasoningEffortToGemini(out, effort)
}
} }
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent) // Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)

View File

@@ -117,7 +117,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
st.CreatedAt = time.Now().Unix() st.CreatedAt = time.Now().Unix()
} }
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null}}` created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
created, _ = sjson.Set(created, "sequence_number", nextSeq()) created, _ = sjson.Set(created, "sequence_number", nextSeq())
created, _ = sjson.Set(created, "response.id", st.ResponseID) created, _ = sjson.Set(created, "response.id", st.ResponseID)
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt) created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
@@ -160,11 +160,11 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
} }
if t := part.Get("text"); t.Exists() && t.String() != "" { if t := part.Get("text"); t.Exists() && t.String() != "" {
st.ReasoningBuf.WriteString(t.String()) st.ReasoningBuf.WriteString(t.String())
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}` msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID) msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex) msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
msg, _ = sjson.Set(msg, "text", t.String()) msg, _ = sjson.Set(msg, "delta", t.String())
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg)) out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
} }
return true return true

View File

@@ -143,7 +143,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
st.ReasoningTokens = 0 st.ReasoningTokens = 0
st.UsageSeen = false st.UsageSeen = false
// response.created // response.created
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null}}` created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
created, _ = sjson.Set(created, "sequence_number", nextSeq()) created, _ = sjson.Set(created, "sequence_number", nextSeq())
created, _ = sjson.Set(created, "response.id", st.ResponseID) created, _ = sjson.Set(created, "response.id", st.ResponseID)
created, _ = sjson.Set(created, "response.created_at", st.Created) created, _ = sjson.Set(created, "response.created_at", st.Created)
@@ -216,11 +216,11 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
} }
// Append incremental text to reasoning buffer // Append incremental text to reasoning buffer
st.ReasoningBuf.WriteString(rc.String()) st.ReasoningBuf.WriteString(rc.String())
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}` msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
msg, _ = sjson.Set(msg, "sequence_number", nextSeq()) msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
msg, _ = sjson.Set(msg, "item_id", st.ReasoningID) msg, _ = sjson.Set(msg, "item_id", st.ReasoningID)
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex) msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
msg, _ = sjson.Set(msg, "text", rc.String()) msg, _ = sjson.Set(msg, "delta", rc.String())
out = append(out, emitRespEvent("response.reasoning_summary_text.delta", msg)) out = append(out, emitRespEvent("response.reasoning_summary_text.delta", msg))
} }

View File

@@ -0,0 +1,10 @@
package util
import "strings"
// IsClaudeThinkingModel checks if the model is a Claude thinking model
// that requires the interleaved-thinking beta header.
func IsClaudeThinkingModel(model string) bool {
lower := strings.ToLower(model)
return strings.Contains(lower, "claude") && strings.Contains(lower, "thinking")
}

View File

@@ -0,0 +1,41 @@
package util
import "testing"
func TestIsClaudeThinkingModel(t *testing.T) {
tests := []struct {
name string
model string
expected bool
}{
// Claude thinking models - should return true
{"claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
{"claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
{"Claude-Sonnet-Thinking uppercase", "Claude-Sonnet-4-5-Thinking", true},
{"claude thinking mixed case", "Claude-THINKING-Model", true},
// Non-thinking Claude models - should return false
{"claude-sonnet-4-5 (no thinking)", "claude-sonnet-4-5", false},
{"claude-opus-4-5 (no thinking)", "claude-opus-4-5", false},
{"claude-3-5-sonnet", "claude-3-5-sonnet-20240620", false},
// Non-Claude models - should return false
{"gemini-3-pro-preview", "gemini-3-pro-preview", false},
{"gemini-thinking model", "gemini-3-pro-thinking", false}, // not Claude
{"gpt-4o", "gpt-4o", false},
{"empty string", "", false},
// Edge cases
{"thinking without claude", "thinking-model", false},
{"claude without thinking", "claude-model", false},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
result := IsClaudeThinkingModel(tt.model)
if result != tt.expected {
t.Errorf("IsClaudeThinkingModel(%q) = %v, expected %v", tt.model, result, tt.expected)
}
})
}
}

View File

@@ -12,10 +12,10 @@ import (
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?") var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini/Antigravity API. // CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API.
// It handles unsupported keywords, type flattening, and schema simplification while preserving // It handles unsupported keywords, type flattening, and schema simplification while preserving
// semantic information as description hints. // semantic information as description hints.
func CleanJSONSchemaForGemini(jsonStr string) string { func CleanJSONSchemaForAntigravity(jsonStr string) string {
// Phase 1: Convert and add hints // Phase 1: Convert and add hints
jsonStr = convertRefsToHints(jsonStr) jsonStr = convertRefsToHints(jsonStr)
jsonStr = convertConstToEnum(jsonStr) jsonStr = convertConstToEnum(jsonStr)
@@ -32,6 +32,9 @@ func CleanJSONSchemaForGemini(jsonStr string) string {
jsonStr = removeUnsupportedKeywords(jsonStr) jsonStr = removeUnsupportedKeywords(jsonStr)
jsonStr = cleanupRequiredFields(jsonStr) jsonStr = cleanupRequiredFields(jsonStr)
// Phase 4: Add placeholder for empty object schemas (Claude VALIDATED mode requirement)
jsonStr = addEmptySchemaPlaceholder(jsonStr)
return jsonStr return jsonStr
} }
@@ -105,7 +108,8 @@ func addAdditionalPropertiesHints(jsonStr string) string {
var unsupportedConstraints = []string{ var unsupportedConstraints = []string{
"minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum", "minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum",
"pattern", "minItems", "maxItems", "pattern", "minItems", "maxItems", "format",
"default", "examples", // Claude rejects these in VALIDATED mode
} }
func moveConstraintsToDescription(jsonStr string) string { func moveConstraintsToDescription(jsonStr string) string {
@@ -296,6 +300,7 @@ func flattenTypeArrays(jsonStr string) string {
func removeUnsupportedKeywords(jsonStr string) string { func removeUnsupportedKeywords(jsonStr string) string {
keywords := append(unsupportedConstraints, keywords := append(unsupportedConstraints,
"$schema", "$defs", "definitions", "const", "$ref", "additionalProperties", "$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
"propertyNames", // Gemini doesn't support property name validation
) )
for _, key := range keywords { for _, key := range keywords {
for _, p := range findPaths(jsonStr, key) { for _, p := range findPaths(jsonStr, key) {
@@ -338,6 +343,52 @@ func cleanupRequiredFields(jsonStr string) string {
return jsonStr return jsonStr
} }
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
// Claude VALIDATED mode requires at least one property in tool schemas.
func addEmptySchemaPlaceholder(jsonStr string) string {
// Find all "type" fields
paths := findPaths(jsonStr, "type")
// Process from deepest to shallowest (to handle nested objects properly)
sortByDepth(paths)
for _, p := range paths {
typeVal := gjson.Get(jsonStr, p)
if typeVal.String() != "object" {
continue
}
// Get the parent path (the object containing "type")
parentPath := trimSuffix(p, ".type")
// Check if properties exists and is empty or missing
propsPath := joinPath(parentPath, "properties")
propsVal := gjson.Get(jsonStr, propsPath)
needsPlaceholder := false
if !propsVal.Exists() {
// No properties field at all
needsPlaceholder = true
} else if propsVal.IsObject() && len(propsVal.Map()) == 0 {
// Empty properties object
needsPlaceholder = true
}
if needsPlaceholder {
// Add placeholder "reason" property
reasonPath := joinPath(propsPath, "reason")
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string")
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
// Add to required array
reqPath := joinPath(parentPath, "required")
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
}
}
return jsonStr
}
// --- Helpers --- // --- Helpers ---
func findPaths(jsonStr, field string) []string { func findPaths(jsonStr, field string) []string {

View File

@@ -5,9 +5,11 @@ import (
"reflect" "reflect"
"strings" "strings"
"testing" "testing"
"github.com/tidwall/gjson"
) )
func TestCleanJSONSchemaForGemini_ConstToEnum(t *testing.T) { func TestCleanJSONSchemaForAntigravity_ConstToEnum(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -28,11 +30,11 @@ func TestCleanJSONSchemaForGemini_ConstToEnum(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable(t *testing.T) { func TestCleanJSONSchemaForAntigravity_TypeFlattening_Nullable(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -60,11 +62,11 @@ func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable(t *testing.T) {
"required": ["other"] "required": ["other"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) { func TestCleanJSONSchemaForAntigravity_ConstraintsToDescription(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -81,7 +83,7 @@ func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
// minItems should be REMOVED and moved to description // minItems should be REMOVED and moved to description
if strings.Contains(result, `"minItems"`) { if strings.Contains(result, `"minItems"`) {
@@ -100,7 +102,7 @@ func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_SmartSelection(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -131,11 +133,11 @@ func TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_OneOfFlattening(t *testing.T) { func TestCleanJSONSchemaForAntigravity_OneOfFlattening(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -158,11 +160,11 @@ func TestCleanJSONSchemaForGemini_OneOfFlattening(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_AllOfMerging(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AllOfMerging(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"allOf": [ "allOf": [
@@ -190,11 +192,11 @@ func TestCleanJSONSchemaForGemini_AllOfMerging(t *testing.T) {
"required": ["a", "b"] "required": ["a", "b"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_RefHandling(t *testing.T) { func TestCleanJSONSchemaForAntigravity_RefHandling(t *testing.T) {
input := `{ input := `{
"definitions": { "definitions": {
"User": { "User": {
@@ -210,21 +212,29 @@ func TestCleanJSONSchemaForGemini_RefHandling(t *testing.T) {
} }
}` }`
// After $ref is converted to placeholder object, empty schema placeholder is also added
expected := `{ expected := `{
"type": "object", "type": "object",
"properties": { "properties": {
"customer": { "customer": {
"type": "object", "type": "object",
"description": "See: User" "description": "See: User",
"properties": {
"reason": {
"type": "string",
"description": "Brief explanation of why you are calling this tool"
}
},
"required": ["reason"]
} }
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping(t *testing.T) { func TestCleanJSONSchemaForAntigravity_RefHandling_DescriptionEscaping(t *testing.T) {
input := `{ input := `{
"definitions": { "definitions": {
"User": { "User": {
@@ -243,21 +253,29 @@ func TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping(t *testing.T)
} }
}` }`
// After $ref is converted, empty schema placeholder is also added
expected := `{ expected := `{
"type": "object", "type": "object",
"properties": { "properties": {
"customer": { "customer": {
"type": "object", "type": "object",
"description": "He said \"hi\"\\nsecond line (See: User)" "description": "He said \"hi\"\\nsecond line (See: User)",
"properties": {
"reason": {
"type": "string",
"description": "Brief explanation of why you are calling this tool"
}
},
"required": ["reason"]
} }
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) { func TestCleanJSONSchemaForAntigravity_CyclicRefDefaults(t *testing.T) {
input := `{ input := `{
"definitions": { "definitions": {
"Node": { "Node": {
@@ -270,7 +288,7 @@ func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) {
"$ref": "#/definitions/Node" "$ref": "#/definitions/Node"
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
var resMap map[string]interface{} var resMap map[string]interface{}
json.Unmarshal([]byte(result), &resMap) json.Unmarshal([]byte(result), &resMap)
@@ -285,7 +303,7 @@ func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_RequiredCleanup(t *testing.T) { func TestCleanJSONSchemaForAntigravity_RequiredCleanup(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -304,11 +322,11 @@ func TestCleanJSONSchemaForGemini_RequiredCleanup(t *testing.T) {
"required": ["a", "b"] "required": ["a", "b"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AllOfMerging_DotKeys(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"allOf": [ "allOf": [
@@ -336,11 +354,11 @@ func TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys(t *testing.T) {
"required": ["my.param", "b"] "required": ["my.param", "b"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) { func TestCleanJSONSchemaForAntigravity_PropertyNameCollision(t *testing.T) {
// A tool has an argument named "pattern" - should NOT be treated as a constraint // A tool has an argument named "pattern" - should NOT be treated as a constraint
input := `{ input := `{
"type": "object", "type": "object",
@@ -364,7 +382,7 @@ func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) {
"required": ["pattern"] "required": ["pattern"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
var resMap map[string]interface{} var resMap map[string]interface{}
@@ -375,7 +393,7 @@ func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) { func TestCleanJSONSchemaForAntigravity_DotKeys(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -389,7 +407,7 @@ func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
var resMap map[string]interface{} var resMap map[string]interface{}
if err := json.Unmarshal([]byte(result), &resMap); err != nil { if err := json.Unmarshal([]byte(result), &resMap); err != nil {
@@ -414,7 +432,7 @@ func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AnyOfAlternativeHints(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -428,7 +446,7 @@ func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if !strings.Contains(result, "Accepts:") { if !strings.Contains(result, "Accepts:") {
t.Errorf("Expected alternative types hint, got: %s", result) t.Errorf("Expected alternative types hint, got: %s", result)
@@ -438,7 +456,7 @@ func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) { func TestCleanJSONSchemaForAntigravity_NullableHint(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -450,7 +468,7 @@ func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) {
"required": ["name"] "required": ["name"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if !strings.Contains(result, "(nullable)") { if !strings.Contains(result, "(nullable)") {
t.Errorf("Expected nullable hint, got: %s", result) t.Errorf("Expected nullable hint, got: %s", result)
@@ -460,7 +478,7 @@ func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey(t *testing.T) { func TestCleanJSONSchemaForAntigravity_TypeFlattening_Nullable_DotKey(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -488,11 +506,11 @@ func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey(t *testing.T) {
"required": ["other"] "required": ["other"]
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) { func TestCleanJSONSchemaForAntigravity_EnumHint(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -504,7 +522,7 @@ func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if !strings.Contains(result, "Allowed:") { if !strings.Contains(result, "Allowed:") {
t.Errorf("Expected enum values hint, got: %s", result) t.Errorf("Expected enum values hint, got: %s", result)
@@ -514,7 +532,7 @@ func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_AdditionalPropertiesHint(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AdditionalPropertiesHint(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -523,14 +541,14 @@ func TestCleanJSONSchemaForGemini_AdditionalPropertiesHint(t *testing.T) {
"additionalProperties": false "additionalProperties": false
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if !strings.Contains(result, "No extra properties allowed") { if !strings.Contains(result, "No extra properties allowed") {
t.Errorf("Expected additionalProperties hint, got: %s", result) t.Errorf("Expected additionalProperties hint, got: %s", result)
} }
} }
func TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription(t *testing.T) { func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_PreservesDescription(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -554,11 +572,11 @@ func TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription(t *testin
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
compareJSON(t, expected, result) compareJSON(t, expected, result)
} }
func TestCleanJSONSchemaForGemini_SingleEnumNoHint(t *testing.T) { func TestCleanJSONSchemaForAntigravity_SingleEnumNoHint(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -569,14 +587,14 @@ func TestCleanJSONSchemaForGemini_SingleEnumNoHint(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if strings.Contains(result, "Allowed:") { if strings.Contains(result, "Allowed:") {
t.Errorf("Single value enum should not add Allowed hint, got: %s", result) t.Errorf("Single value enum should not add Allowed hint, got: %s", result)
} }
} }
func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) { func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) {
input := `{ input := `{
"type": "object", "type": "object",
"properties": { "properties": {
@@ -586,7 +604,7 @@ func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) {
} }
}` }`
result := CleanJSONSchemaForGemini(input) result := CleanJSONSchemaForAntigravity(input)
if !strings.Contains(result, "Accepts:") { if !strings.Contains(result, "Accepts:") {
t.Errorf("Expected multiple types hint, got: %s", result) t.Errorf("Expected multiple types hint, got: %s", result)
@@ -596,6 +614,71 @@ func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) {
} }
} }
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) {
// propertyNames is used to validate object property names (e.g., must match a pattern)
// Gemini doesn't support this keyword and will reject requests containing it
input := `{
"type": "object",
"properties": {
"metadata": {
"type": "object",
"propertyNames": {
"pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
},
"additionalProperties": {
"type": "string"
}
}
}
}`
expected := `{
"type": "object",
"properties": {
"metadata": {
"type": "object"
}
}
}`
result := CleanJSONSchemaForGemini(input)
compareJSON(t, expected, result)
// Verify propertyNames is completely removed
if strings.Contains(result, "propertyNames") {
t.Errorf("propertyNames keyword should be removed, got: %s", result)
}
}
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) {
// Test deeply nested propertyNames (as seen in real Claude tool schemas)
input := `{
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object",
"properties": {
"config": {
"type": "object",
"propertyNames": {
"type": "string"
}
}
}
}
}
}
}`
result := CleanJSONSchemaForGemini(input)
if strings.Contains(result, "propertyNames") {
t.Errorf("Nested propertyNames should be removed, got: %s", result)
}
}
func compareJSON(t *testing.T, expectedJSON, actualJSON string) { func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
var expMap, actMap map[string]interface{} var expMap, actMap map[string]interface{}
errExp := json.Unmarshal([]byte(expectedJSON), &expMap) errExp := json.Unmarshal([]byte(expectedJSON), &expMap)
@@ -611,3 +694,190 @@ func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
t.Errorf("JSON mismatch:\nExpected:\n%s\n\nActual:\n%s", string(expBytes), string(actBytes)) t.Errorf("JSON mismatch:\nExpected:\n%s\n\nActual:\n%s", string(expBytes), string(actBytes))
} }
} }
// ============================================================================
// Empty Schema Placeholder Tests
// ============================================================================
func TestCleanJSONSchemaForAntigravity_EmptySchemaPlaceholder(t *testing.T) {
// Empty object schema with no properties should get a placeholder
input := `{
"type": "object"
}`
result := CleanJSONSchemaForAntigravity(input)
// Should have placeholder property added
if !strings.Contains(result, `"reason"`) {
t.Errorf("Empty schema should have 'reason' placeholder property, got: %s", result)
}
if !strings.Contains(result, `"required"`) {
t.Errorf("Empty schema should have 'required' with 'reason', got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_EmptyPropertiesPlaceholder(t *testing.T) {
// Object with empty properties object
input := `{
"type": "object",
"properties": {}
}`
result := CleanJSONSchemaForAntigravity(input)
// Should have placeholder property added
if !strings.Contains(result, `"reason"`) {
t.Errorf("Empty properties should have 'reason' placeholder, got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_NonEmptySchemaUnchanged(t *testing.T) {
// Schema with properties should NOT get placeholder
input := `{
"type": "object",
"properties": {
"name": {"type": "string"}
},
"required": ["name"]
}`
result := CleanJSONSchemaForAntigravity(input)
// Should NOT have placeholder property
if strings.Contains(result, `"reason"`) {
t.Errorf("Non-empty schema should NOT have 'reason' placeholder, got: %s", result)
}
// Original properties should be preserved
if !strings.Contains(result, `"name"`) {
t.Errorf("Original property 'name' should be preserved, got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_NestedEmptySchema(t *testing.T) {
// Nested empty object in items should also get placeholder
input := `{
"type": "object",
"properties": {
"items": {
"type": "array",
"items": {
"type": "object"
}
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
// Nested empty object should also get placeholder
// Check that the nested object has a reason property
parsed := gjson.Parse(result)
nestedProps := parsed.Get("properties.items.items.properties")
if !nestedProps.Exists() || !nestedProps.Get("reason").Exists() {
t.Errorf("Nested empty object should have 'reason' placeholder, got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_EmptySchemaWithDescription(t *testing.T) {
// Empty schema with description should preserve description and add placeholder
input := `{
"type": "object",
"description": "An empty object"
}`
result := CleanJSONSchemaForAntigravity(input)
// Should have both description and placeholder
if !strings.Contains(result, `"An empty object"`) {
t.Errorf("Description should be preserved, got: %s", result)
}
if !strings.Contains(result, `"reason"`) {
t.Errorf("Empty schema should have 'reason' placeholder, got: %s", result)
}
}
// ============================================================================
// Format field handling (ad-hoc patch removal)
// ============================================================================
func TestCleanJSONSchemaForAntigravity_FormatFieldRemoval(t *testing.T) {
// format:"uri" should be removed and added as hint
input := `{
"type": "object",
"properties": {
"url": {
"type": "string",
"format": "uri",
"description": "A URL"
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
// format should be removed
if strings.Contains(result, `"format"`) {
t.Errorf("format field should be removed, got: %s", result)
}
// hint should be added to description
if !strings.Contains(result, "format: uri") {
t.Errorf("format hint should be added to description, got: %s", result)
}
// original description should be preserved
if !strings.Contains(result, "A URL") {
t.Errorf("Original description should be preserved, got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_FormatFieldNoDescription(t *testing.T) {
// format without description should create description with hint
input := `{
"type": "object",
"properties": {
"email": {
"type": "string",
"format": "email"
}
}
}`
result := CleanJSONSchemaForAntigravity(input)
// format should be removed
if strings.Contains(result, `"format"`) {
t.Errorf("format field should be removed, got: %s", result)
}
// hint should be added
if !strings.Contains(result, "format: email") {
t.Errorf("format hint should be added, got: %s", result)
}
}
func TestCleanJSONSchemaForAntigravity_MultipleFormats(t *testing.T) {
// Multiple format fields should all be handled
input := `{
"type": "object",
"properties": {
"url": {"type": "string", "format": "uri"},
"email": {"type": "string", "format": "email"},
"date": {"type": "string", "format": "date-time"}
}
}`
result := CleanJSONSchemaForAntigravity(input)
// All format fields should be removed
if strings.Contains(result, `"format"`) {
t.Errorf("All format fields should be removed, got: %s", result)
}
// All hints should be added
if !strings.Contains(result, "format: uri") {
t.Errorf("uri format hint should be added, got: %s", result)
}
if !strings.Contains(result, "format: email") {
t.Errorf("email format hint should be added, got: %s", result)
}
if !strings.Contains(result, "format: date-time") {
t.Errorf("date-time format hint should be added, got: %s", result)
}
}

View File

@@ -136,6 +136,12 @@ func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool)
updated = rewritten updated = rewritten
} }
} }
if it := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.include_thoughts")
}
if tb := gjson.GetBytes(body, "generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.thinkingBudget")
}
return updated return updated
} }
@@ -167,6 +173,12 @@ func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *boo
updated = rewritten updated = rewritten
} }
} }
if it := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts")
}
if tb := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.thinkingBudget")
}
return updated return updated
} }
@@ -242,7 +254,7 @@ func ThinkingBudgetToGemini3Level(model string, budget int) (string, bool) {
var modelsWithDefaultThinking = map[string]bool{ var modelsWithDefaultThinking = map[string]bool{
"gemini-3-pro-preview": true, "gemini-3-pro-preview": true,
"gemini-3-pro-image-preview": true, "gemini-3-pro-image-preview": true,
"gemini-3-flash-preview": true, // "gemini-3-flash-preview": true,
} }
// ModelHasDefaultThinking returns true if the model should have thinking enabled by default. // ModelHasDefaultThinking returns true if the model should have thinking enabled by default.
@@ -352,8 +364,9 @@ func StripThinkingConfigIfUnsupported(model string, body []byte) []byte {
// NormalizeGeminiThinkingBudget normalizes the thinkingBudget value in a standard Gemini // NormalizeGeminiThinkingBudget normalizes the thinkingBudget value in a standard Gemini
// request body (generationConfig.thinkingConfig.thinkingBudget path). // request body (generationConfig.thinkingConfig.thinkingBudget path).
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation. // For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation,
func NormalizeGeminiThinkingBudget(model string, body []byte) []byte { // unless skipGemini3Check is provided and true.
func NormalizeGeminiThinkingBudget(model string, body []byte, skipGemini3Check ...bool) []byte {
const budgetPath = "generationConfig.thinkingConfig.thinkingBudget" const budgetPath = "generationConfig.thinkingConfig.thinkingBudget"
const levelPath = "generationConfig.thinkingConfig.thinkingLevel" const levelPath = "generationConfig.thinkingConfig.thinkingLevel"
@@ -363,7 +376,8 @@ func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
} }
// For Gemini 3 models, convert thinkingBudget to thinkingLevel // For Gemini 3 models, convert thinkingBudget to thinkingLevel
if IsGemini3Model(model) { skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
if IsGemini3Model(model) && !skipGemini3 {
if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok { if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok {
updated, _ := sjson.SetBytes(body, levelPath, level) updated, _ := sjson.SetBytes(body, levelPath, level)
updated, _ = sjson.DeleteBytes(updated, budgetPath) updated, _ = sjson.DeleteBytes(updated, budgetPath)
@@ -382,8 +396,9 @@ func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
// NormalizeGeminiCLIThinkingBudget normalizes the thinkingBudget value in a Gemini CLI // NormalizeGeminiCLIThinkingBudget normalizes the thinkingBudget value in a Gemini CLI
// request body (request.generationConfig.thinkingConfig.thinkingBudget path). // request body (request.generationConfig.thinkingConfig.thinkingBudget path).
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation. // For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation,
func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte { // unless skipGemini3Check is provided and true.
func NormalizeGeminiCLIThinkingBudget(model string, body []byte, skipGemini3Check ...bool) []byte {
const budgetPath = "request.generationConfig.thinkingConfig.thinkingBudget" const budgetPath = "request.generationConfig.thinkingConfig.thinkingBudget"
const levelPath = "request.generationConfig.thinkingConfig.thinkingLevel" const levelPath = "request.generationConfig.thinkingConfig.thinkingLevel"
@@ -393,7 +408,8 @@ func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte {
} }
// For Gemini 3 models, convert thinkingBudget to thinkingLevel // For Gemini 3 models, convert thinkingBudget to thinkingLevel
if IsGemini3Model(model) { skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
if IsGemini3Model(model) && !skipGemini3 {
if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok { if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok {
updated, _ := sjson.SetBytes(body, levelPath, level) updated, _ := sjson.SetBytes(body, levelPath, level)
updated, _ = sjson.DeleteBytes(updated, budgetPath) updated, _ = sjson.DeleteBytes(updated, budgetPath)
@@ -477,7 +493,7 @@ func ApplyReasoningEffortToGeminiCLI(body []byte, effort string) []byte {
// ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel" // ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel"
// and converts it to "thinkingBudget" for Gemini 2.5 models. // and converts it to "thinkingBudget" for Gemini 2.5 models.
// For Gemini 3 models, preserves thinkingLevel as-is (does not convert). // For Gemini 3 models, preserves thinkingLevel unless skipGemini3Check is provided and true.
// Mappings for Gemini 2.5: // Mappings for Gemini 2.5:
// - "high" -> 32768 // - "high" -> 32768
// - "medium" -> 8192 // - "medium" -> 8192
@@ -485,43 +501,31 @@ func ApplyReasoningEffortToGeminiCLI(body []byte, effort string) []byte {
// - "minimal" -> 512 // - "minimal" -> 512
// //
// It removes "thinkingLevel" after conversion (for Gemini 2.5 only). // It removes "thinkingLevel" after conversion (for Gemini 2.5 only).
func ConvertThinkingLevelToBudget(body []byte, model string) []byte { func ConvertThinkingLevelToBudget(body []byte, model string, skipGemini3Check ...bool) []byte {
levelPath := "generationConfig.thinkingConfig.thinkingLevel" levelPath := "generationConfig.thinkingConfig.thinkingLevel"
res := gjson.GetBytes(body, levelPath) res := gjson.GetBytes(body, levelPath)
if !res.Exists() { if !res.Exists() {
return body return body
} }
// For Gemini 3 models, preserve thinkingLevel - don't convert to budget // For Gemini 3 models, preserve thinkingLevel unless explicitly skipped
if IsGemini3Model(model) { skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
if IsGemini3Model(model) && !skipGemini3 {
return body return body
} }
level := strings.ToLower(res.String()) budget, ok := ThinkingLevelToBudget(res.String())
var budget int if !ok {
switch level {
case "high":
budget = 32768
case "medium":
budget = 8192
case "low":
budget = 1024
case "minimal":
budget = 512
default:
// Unknown level - remove it and let the API use defaults
updated, _ := sjson.DeleteBytes(body, levelPath) updated, _ := sjson.DeleteBytes(body, levelPath)
return updated return updated
} }
// Set budget
budgetPath := "generationConfig.thinkingConfig.thinkingBudget" budgetPath := "generationConfig.thinkingConfig.thinkingBudget"
updated, err := sjson.SetBytes(body, budgetPath, budget) updated, err := sjson.SetBytes(body, budgetPath, budget)
if err != nil { if err != nil {
return body return body
} }
// Remove level
updated, err = sjson.DeleteBytes(updated, levelPath) updated, err = sjson.DeleteBytes(updated, levelPath)
if err != nil { if err != nil {
return body return body
@@ -544,31 +548,18 @@ func ConvertThinkingLevelToBudgetCLI(body []byte, model string) []byte {
return body return body
} }
level := strings.ToLower(res.String()) budget, ok := ThinkingLevelToBudget(res.String())
var budget int if !ok {
switch level {
case "high":
budget = 32768
case "medium":
budget = 8192
case "low":
budget = 1024
case "minimal":
budget = 512
default:
// Unknown level - remove it and let the API use defaults
updated, _ := sjson.DeleteBytes(body, levelPath) updated, _ := sjson.DeleteBytes(body, levelPath)
return updated return updated
} }
// Set budget
budgetPath := "request.generationConfig.thinkingConfig.thinkingBudget" budgetPath := "request.generationConfig.thinkingConfig.thinkingBudget"
updated, err := sjson.SetBytes(body, budgetPath, budget) updated, err := sjson.SetBytes(body, budgetPath, budget)
if err != nil { if err != nil {
return body return body
} }
// Remove level
updated, err = sjson.DeleteBytes(updated, levelPath) updated, err = sjson.DeleteBytes(updated, levelPath)
if err != nil { if err != nil {
return body return body

View File

@@ -160,6 +160,34 @@ func ThinkingEffortToBudget(model, effort string) (int, bool) {
} }
} }
// ThinkingLevelToBudget maps a Gemini thinkingLevel to a numeric thinking budget (tokens).
//
// Mappings:
// - "minimal" -> 512
// - "low" -> 1024
// - "medium" -> 8192
// - "high" -> 32768
//
// Returns false when the level is empty or unsupported.
func ThinkingLevelToBudget(level string) (int, bool) {
if level == "" {
return 0, false
}
normalized := strings.ToLower(strings.TrimSpace(level))
switch normalized {
case "minimal":
return 512, true
case "low":
return 1024, true
case "medium":
return 8192, true
case "high":
return 32768, true
default:
return 0, false
}
}
// ThinkingBudgetToEffort maps a numeric thinking budget (tokens) // ThinkingBudgetToEffort maps a numeric thinking budget (tokens)
// to a reasoning effort level for level-based models. // to a reasoning effort level for level-based models.
// //

View File

@@ -0,0 +1,87 @@
package util
import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// GetThinkingText extracts the thinking text from a content part.
// Handles various formats:
// - Simple string: { "thinking": "text" } or { "text": "text" }
// - Wrapped object: { "thinking": { "text": "text", "cache_control": {...} } }
// - Gemini-style: { "thought": true, "text": "text" }
// Returns the extracted text string.
func GetThinkingText(part gjson.Result) string {
// Try direct text field first (Gemini-style)
if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
return text.String()
}
// Try thinking field
thinkingField := part.Get("thinking")
if !thinkingField.Exists() {
return ""
}
// thinking is a string
if thinkingField.Type == gjson.String {
return thinkingField.String()
}
// thinking is an object with inner text/thinking
if thinkingField.IsObject() {
if inner := thinkingField.Get("text"); inner.Exists() && inner.Type == gjson.String {
return inner.String()
}
if inner := thinkingField.Get("thinking"); inner.Exists() && inner.Type == gjson.String {
return inner.String()
}
}
return ""
}
// GetThinkingTextFromJSON extracts thinking text from a raw JSON string.
func GetThinkingTextFromJSON(jsonStr string) string {
return GetThinkingText(gjson.Parse(jsonStr))
}
// SanitizeThinkingPart normalizes a thinking part to a canonical form.
// Strips cache_control and other non-essential fields.
// Returns the sanitized part as JSON string.
func SanitizeThinkingPart(part gjson.Result) string {
// Gemini-style: { thought: true, text, thoughtSignature }
if part.Get("thought").Bool() {
result := `{"thought":true}`
if text := GetThinkingText(part); text != "" {
result, _ = sjson.Set(result, "text", text)
}
if sig := part.Get("thoughtSignature"); sig.Exists() && sig.Type == gjson.String {
result, _ = sjson.Set(result, "thoughtSignature", sig.String())
}
return result
}
// Anthropic-style: { type: "thinking", thinking, signature }
if part.Get("type").String() == "thinking" || part.Get("thinking").Exists() {
result := `{"type":"thinking"}`
if text := GetThinkingText(part); text != "" {
result, _ = sjson.Set(result, "thinking", text)
}
if sig := part.Get("signature"); sig.Exists() && sig.Type == gjson.String {
result, _ = sjson.Set(result, "signature", sig.String())
}
return result
}
// Not a thinking part, return as-is but strip cache_control
return StripCacheControl(part.Raw)
}
// StripCacheControl removes cache_control and providerOptions from a JSON object.
func StripCacheControl(jsonStr string) string {
result := jsonStr
result, _ = sjson.Delete(result, "cache_control")
result, _ = sjson.Delete(result, "providerOptions")
return result
}

46
sdk/api/options.go Normal file
View File

@@ -0,0 +1,46 @@
// Package api exposes server option helpers for embedding CLIProxyAPI.
//
// It wraps internal server option types so external projects can configure the embedded
// HTTP server without importing internal packages.
package api
import (
"time"
"github.com/gin-gonic/gin"
internalapi "github.com/router-for-me/CLIProxyAPI/v6/internal/api"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
)
// ServerOption customises HTTP server construction.
type ServerOption = internalapi.ServerOption
// WithMiddleware appends additional Gin middleware during server construction.
func WithMiddleware(mw ...gin.HandlerFunc) ServerOption { return internalapi.WithMiddleware(mw...) }
// WithEngineConfigurator allows callers to mutate the Gin engine prior to middleware setup.
func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption {
return internalapi.WithEngineConfigurator(fn)
}
// WithRouterConfigurator appends a callback after default routes are registered.
func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption {
return internalapi.WithRouterConfigurator(fn)
}
// WithLocalManagementPassword stores a runtime-only management password accepted for localhost requests.
func WithLocalManagementPassword(password string) ServerOption {
return internalapi.WithLocalManagementPassword(password)
}
// WithKeepAliveEndpoint enables a keep-alive endpoint with the provided timeout and callback.
func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption {
return internalapi.WithKeepAliveEndpoint(timeout, onTimeout)
}
// WithRequestLoggerFactory customises request logger creation.
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
return internalapi.WithRequestLoggerFactory(factory)
}

View File

@@ -99,11 +99,54 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
fmt.Println("Waiting for antigravity authentication callback...") fmt.Println("Waiting for antigravity authentication callback...")
var cbRes callbackResult var cbRes callbackResult
select { timeoutTimer := time.NewTimer(5 * time.Minute)
case res := <-cbChan: defer timeoutTimer.Stop()
cbRes = res
case <-time.After(5 * time.Minute): var manualPromptTimer *time.Timer
return nil, fmt.Errorf("antigravity: authentication timed out") var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case res := <-cbChan:
cbRes = res
break waitForCallback
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case res := <-cbChan:
cbRes = res
break waitForCallback
default:
}
input, errPrompt := opts.Prompt("Paste the antigravity callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
cbRes = callbackResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
case <-timeoutTimer.C:
return nil, fmt.Errorf("antigravity: authentication timed out")
}
} }
if cbRes.Error != "" { if cbRes.Error != "" {

View File

@@ -98,16 +98,76 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
fmt.Println("Waiting for Claude authentication callback...") fmt.Println("Waiting for Claude authentication callback...")
result, err := oauthServer.WaitForCallback(5 * time.Minute) callbackCh := make(chan *claude.OAuthResult, 1)
if err != nil { callbackErrCh := make(chan error, 1)
if strings.Contains(err.Error(), "timeout") { manualDescription := ""
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
if errWait != nil {
callbackErrCh <- errWait
return
}
callbackCh <- result
}()
var result *claude.OAuthResult
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
}
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
}
return nil, err
default:
}
input, errPrompt := opts.Prompt("Paste the Claude callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
manualDescription = parsed.ErrorDescription
result = &claude.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
} }
return nil, err
} }
if result.Error != "" { if result.Error != "" {
return nil, claude.NewOAuthError(result.Error, "", http.StatusBadRequest) return nil, claude.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest)
} }
if result.State != state { if result.State != state {

View File

@@ -97,16 +97,76 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
fmt.Println("Waiting for Codex authentication callback...") fmt.Println("Waiting for Codex authentication callback...")
result, err := oauthServer.WaitForCallback(5 * time.Minute) callbackCh := make(chan *codex.OAuthResult, 1)
if err != nil { callbackErrCh := make(chan error, 1)
if strings.Contains(err.Error(), "timeout") { manualDescription := ""
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
if errWait != nil {
callbackErrCh <- errWait
return
}
callbackCh <- result
}()
var result *codex.OAuthResult
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
}
return nil, err
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
if strings.Contains(err.Error(), "timeout") {
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
}
return nil, err
default:
}
input, errPrompt := opts.Prompt("Paste the Codex callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
manualDescription = parsed.ErrorDescription
result = &codex.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
} }
return nil, err
} }
if result.Error != "" { if result.Error != "" {
return nil, codex.NewOAuthError(result.Error, "", http.StatusBadRequest) return nil, codex.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest)
} }
if result.State != state { if result.State != state {

View File

@@ -72,7 +72,9 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal) return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal)
} }
if existing, errRead := os.ReadFile(path); errRead == nil { if existing, errRead := os.ReadFile(path); errRead == nil {
if jsonEqual(existing, raw) { // Use metadataEqualIgnoringTimestamps to skip writes when only timestamp fields change.
// This prevents the token refresh loop caused by timestamp/expired/expires_in changes.
if metadataEqualIgnoringTimestamps(existing, raw) {
return path, nil return path, nil
} }
} else if errRead != nil && !os.IsNotExist(errRead) { } else if errRead != nil && !os.IsNotExist(errRead) {
@@ -264,6 +266,8 @@ func (s *FileTokenStore) baseDirSnapshot() string {
return s.baseDir return s.baseDir
} }
// DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata.
// This function is kept for backward compatibility but can cause refresh loops.
func jsonEqual(a, b []byte) bool { func jsonEqual(a, b []byte) bool {
var objA any var objA any
var objB any var objB any
@@ -276,6 +280,32 @@ func jsonEqual(a, b []byte) bool {
return deepEqualJSON(objA, objB) return deepEqualJSON(objA, objB)
} }
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs,
// ignoring fields that change on every refresh but don't affect functionality.
// This prevents unnecessary file writes that would trigger watcher events and
// create refresh loops.
func metadataEqualIgnoringTimestamps(a, b []byte) bool {
var objA, objB map[string]any
if err := json.Unmarshal(a, &objA); err != nil {
return false
}
if err := json.Unmarshal(b, &objB); err != nil {
return false
}
// Fields to ignore: these change on every refresh but don't affect authentication logic.
// - timestamp, expired, expires_in, last_refresh: time-related fields that change on refresh
// - access_token: Google OAuth returns a new access_token on each refresh, this is expected
// and shouldn't trigger file writes (the new token will be fetched again when needed)
ignoredFields := []string{"timestamp", "expired", "expires_in", "last_refresh", "access_token"}
for _, field := range ignoredFields {
delete(objA, field)
delete(objB, field)
}
return deepEqualJSON(objA, objB)
}
func deepEqualJSON(a, b any) bool { func deepEqualJSON(a, b any) bool {
switch valA := a.(type) { switch valA := a.(type) {
case map[string]any: case map[string]any:

View File

@@ -44,7 +44,10 @@ func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
} }
geminiAuth := gemini.NewGeminiAuth() geminiAuth := gemini.NewGeminiAuth()
_, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, opts.NoBrowser) _, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{
NoBrowser: opts.NoBrowser,
Prompt: opts.Prompt,
})
if err != nil { if err != nil {
return nil, fmt.Errorf("gemini authentication failed: %w", err) return nil, fmt.Errorf("gemini authentication failed: %w", err)
} }

View File

@@ -84,9 +84,64 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
fmt.Println("Waiting for iFlow authentication callback...") fmt.Println("Waiting for iFlow authentication callback...")
result, err := oauthServer.WaitForCallback(5 * time.Minute) callbackCh := make(chan *iflow.OAuthResult, 1)
if err != nil { callbackErrCh := make(chan error, 1)
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
go func() {
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
if errWait != nil {
callbackErrCh <- errWait
return
}
callbackCh <- result
}()
var result *iflow.OAuthResult
var manualPromptTimer *time.Timer
var manualPromptC <-chan time.Time
if opts.Prompt != nil {
manualPromptTimer = time.NewTimer(15 * time.Second)
manualPromptC = manualPromptTimer.C
defer manualPromptTimer.Stop()
}
waitForCallback:
for {
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
case <-manualPromptC:
manualPromptC = nil
if manualPromptTimer != nil {
manualPromptTimer.Stop()
}
select {
case result = <-callbackCh:
break waitForCallback
case err = <-callbackErrCh:
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
default:
}
input, errPrompt := opts.Prompt("Paste the iFlow callback URL (or press Enter to keep waiting): ")
if errPrompt != nil {
return nil, errPrompt
}
parsed, errParse := misc.ParseOAuthCallback(input)
if errParse != nil {
return nil, errParse
}
if parsed == nil {
continue
}
result = &iflow.OAuthResult{
Code: parsed.Code,
State: parsed.State,
Error: parsed.Error,
}
break waitForCallback
}
} }
if result.Error != "" { if result.Error != "" {
return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error) return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error)

View File

@@ -7,10 +7,10 @@ import (
"fmt" "fmt"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api" "github.com/router-for-me/CLIProxyAPI/v6/internal/api"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access" sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
) )
// Builder constructs a Service instance with customizable providers. // Builder constructs a Service instance with customizable providers.

View File

@@ -3,8 +3,8 @@ package cliproxy
import ( import (
"context" "context"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
) )
// NewFileTokenClientProvider returns the default token-backed client loader. // NewFileTokenClientProvider returns the default token-backed client loader.

View File

@@ -13,7 +13,6 @@ import (
"time" "time"
"github.com/router-for-me/CLIProxyAPI/v6/internal/api" "github.com/router-for-me/CLIProxyAPI/v6/internal/api"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor" "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage" _ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
@@ -23,6 +22,7 @@ import (
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth" sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage" "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )

View File

@@ -6,9 +6,9 @@ package cliproxy
import ( import (
"context" "context"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
) )
// TokenClientProvider loads clients backed by stored authentication tokens. // TokenClientProvider loads clients backed by stored authentication tokens.

View File

@@ -3,9 +3,9 @@ package cliproxy
import ( import (
"context" "context"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher" "github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
) )
func defaultWatcherFactory(configPath, authDir string, reload func(*config.Config)) (*WatcherWrapper, error) { func defaultWatcherFactory(configPath, authDir string, reload func(*config.Config)) (*WatcherWrapper, error) {

View File

@@ -1,87 +1,59 @@
// Package config provides configuration management for the CLI Proxy API server. // Package config provides the public SDK configuration API.
// It handles loading and parsing YAML configuration files, and provides structured //
// access to application settings including server port, authentication directory, // It re-exports the server configuration types and helpers so external projects can
// debug settings, proxy configuration, and API keys. // embed CLIProxyAPI without importing internal packages.
package config package config
// SDKConfig represents the application's configuration, loaded from a YAML file. import internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
type SDKConfig struct {
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview") type SDKConfig = internalconfig.SDKConfig
// to target prefixed credentials. When false, unprefixed model requests may use prefixed type AccessConfig = internalconfig.AccessConfig
// credentials as well. type AccessProvider = internalconfig.AccessProvider
ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`
// RequestLog enables or disables detailed request logging functionality. type Config = internalconfig.Config
RequestLog bool `yaml:"request-log" json:"request-log"`
// APIKeys is a list of keys for authenticating clients to this proxy server. type TLSConfig = internalconfig.TLSConfig
APIKeys []string `yaml:"api-keys" json:"api-keys"` type RemoteManagement = internalconfig.RemoteManagement
type AmpCode = internalconfig.AmpCode
type PayloadConfig = internalconfig.PayloadConfig
type PayloadRule = internalconfig.PayloadRule
type PayloadModelRule = internalconfig.PayloadModelRule
// Access holds request authentication provider configuration. type GeminiKey = internalconfig.GeminiKey
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"` type CodexKey = internalconfig.CodexKey
} type ClaudeKey = internalconfig.ClaudeKey
type VertexCompatKey = internalconfig.VertexCompatKey
type VertexCompatModel = internalconfig.VertexCompatModel
type OpenAICompatibility = internalconfig.OpenAICompatibility
type OpenAICompatibilityAPIKey = internalconfig.OpenAICompatibilityAPIKey
type OpenAICompatibilityModel = internalconfig.OpenAICompatibilityModel
// AccessConfig groups request authentication providers. type TLS = internalconfig.TLSConfig
type AccessConfig struct {
// Providers lists configured authentication providers.
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
}
// AccessProvider describes a request authentication provider entry.
type AccessProvider struct {
// Name is the instance identifier for the provider.
Name string `yaml:"name" json:"name"`
// Type selects the provider implementation registered via the SDK.
Type string `yaml:"type" json:"type"`
// SDK optionally names a third-party SDK module providing this provider.
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
// APIKeys lists inline keys for providers that require them.
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
// Config passes provider-specific options to the implementation.
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
}
const ( const (
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys. AccessProviderTypeConfigAPIKey = internalconfig.AccessProviderTypeConfigAPIKey
AccessProviderTypeConfigAPIKey = "config-api-key" DefaultAccessProviderName = internalconfig.DefaultAccessProviderName
DefaultPanelGitHubRepository = internalconfig.DefaultPanelGitHubRepository
// DefaultAccessProviderName is applied when no provider name is supplied.
DefaultAccessProviderName = "config-inline"
) )
// ConfigAPIKeyProvider returns the first inline API key provider if present. func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider { return internalconfig.MakeInlineAPIKeyProvider(keys)
if c == nil {
return nil
}
for i := range c.Access.Providers {
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
if c.Access.Providers[i].Name == "" {
c.Access.Providers[i].Name = DefaultAccessProviderName
}
return &c.Access.Providers[i]
}
}
return nil
} }
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration. func LoadConfig(configFile string) (*Config, error) { return internalconfig.LoadConfig(configFile) }
// It returns nil when no keys are supplied.
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider { func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
if len(keys) == 0 { return internalconfig.LoadConfigOptional(configFile, optional)
return nil }
}
provider := &AccessProvider{ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
Name: DefaultAccessProviderName, return internalconfig.SaveConfigPreserveComments(configFile, cfg)
Type: AccessProviderTypeConfigAPIKey, }
APIKeys: append([]string(nil), keys...),
} func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
return provider return internalconfig.SaveConfigPreserveCommentsUpdateNestedScalar(configFile, path, value)
}
func NormalizeCommentIndentation(data []byte) []byte {
return internalconfig.NormalizeCommentIndentation(data)
} }

View File

@@ -0,0 +1,18 @@
// Package logging re-exports request logging primitives for SDK consumers.
package logging
import internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
// RequestLogger defines the interface for logging HTTP requests and responses.
type RequestLogger = internallogging.RequestLogger
// StreamingLogWriter handles real-time logging of streaming response chunks.
type StreamingLogWriter = internallogging.StreamingLogWriter
// FileRequestLogger implements RequestLogger using file-based storage.
type FileRequestLogger = internallogging.FileRequestLogger
// NewFileRequestLogger creates a new file-based request logger.
func NewFileRequestLogger(enabled bool, logsDir string, configDir string) *FileRequestLogger {
return internallogging.NewFileRequestLogger(enabled, logsDir, configDir)
}