mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Merge branch 'main' into fix/antigravity-prompt-caching
This commit is contained in:
@@ -405,7 +405,7 @@ func main() {
|
||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
||||
log.Errorf("failed to configure log output: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -42,6 +42,10 @@ debug: false
|
||||
# When true, write application logs to rotating files instead of stdout
|
||||
logging-to-file: false
|
||||
|
||||
# Maximum total size (MB) of log files under the logs directory. When exceeded, the oldest log
|
||||
# files are deleted until within the limit. Set to 0 to disable.
|
||||
logs-max-total-size-mb: 0
|
||||
|
||||
# When false, disable in-memory usage statistics aggregation
|
||||
usage-statistics-enabled: false
|
||||
|
||||
|
||||
@@ -23,13 +23,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
|
||||
sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
|
||||
@@ -36,10 +36,6 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
var (
|
||||
oauthStatus = make(map[string]string)
|
||||
)
|
||||
|
||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||
|
||||
const (
|
||||
@@ -201,6 +197,19 @@ func stopCallbackForwarder(port int) {
|
||||
stopForwarderInstance(port, forwarder)
|
||||
}
|
||||
|
||||
func stopCallbackForwarderInstance(port int, forwarder *callbackForwarder) {
|
||||
if forwarder == nil {
|
||||
return
|
||||
}
|
||||
callbackForwardersMu.Lock()
|
||||
if current := callbackForwarders[port]; current == forwarder {
|
||||
delete(callbackForwarders, port)
|
||||
}
|
||||
callbackForwardersMu.Unlock()
|
||||
|
||||
stopForwarderInstance(port, forwarder)
|
||||
}
|
||||
|
||||
func stopForwarderInstance(port int, forwarder *callbackForwarder) {
|
||||
if forwarder == nil || forwarder.server == nil {
|
||||
return
|
||||
@@ -786,7 +795,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "anthropic")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
|
||||
if errTarget != nil {
|
||||
@@ -794,7 +806,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(anthropicCallbackPort, "anthropic", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start anthropic callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
@@ -803,7 +816,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(anthropicCallbackPort)
|
||||
defer stopCallbackForwarderInstance(anthropicCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
// Helper: wait for callback file
|
||||
@@ -811,8 +824,11 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
waitForFile := func(path string, timeout time.Duration) (map[string]string, error) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "anthropic") {
|
||||
return nil, errOAuthSessionNotPending
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
data, errRead := os.ReadFile(path)
|
||||
@@ -830,6 +846,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
// Wait up to 5 minutes
|
||||
resultMap, errWait := waitForFile(waitFile, 5*time.Minute)
|
||||
if errWait != nil {
|
||||
if errors.Is(errWait, errOAuthSessionNotPending) {
|
||||
return
|
||||
}
|
||||
authErr := claude.NewAuthenticationError(claude.ErrCallbackTimeout, errWait)
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
return
|
||||
@@ -837,13 +856,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
if errStr := resultMap["error"]; errStr != "" {
|
||||
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
||||
oauthStatus[state] = "Bad request"
|
||||
SetOAuthSessionError(state, "Bad request")
|
||||
return
|
||||
}
|
||||
if resultMap["state"] != state {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
oauthStatus[state] = "State code error"
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -876,7 +895,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
if errDo != nil {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -887,7 +906,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
var tResp struct {
|
||||
@@ -900,7 +919,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
return
|
||||
}
|
||||
bundle := &claude.ClaudeAuthBundle{
|
||||
@@ -925,7 +944,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -934,10 +953,10 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use Claude services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("anthropic")
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -968,7 +987,10 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
||||
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
|
||||
RegisterOAuthSession(state, "gemini")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/google/callback")
|
||||
if errTarget != nil {
|
||||
@@ -976,7 +998,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(geminiCallbackPort, "gemini", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start gemini callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
@@ -985,7 +1008,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(geminiCallbackPort)
|
||||
defer stopCallbackForwarderInstance(geminiCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
// Wait for callback file written by server route
|
||||
@@ -994,9 +1017,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var authCode string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "gemini") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
oauthStatus[state] = "OAuth flow timed out"
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
@@ -1005,13 +1031,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
authCode = m["code"]
|
||||
if authCode == "" {
|
||||
log.Errorf("Authentication failed: code not found")
|
||||
oauthStatus[state] = "Authentication failed: code not found"
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
@@ -1023,7 +1049,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
token, err := conf.Exchange(ctx, authCode)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to exchange token: %v", err)
|
||||
oauthStatus[state] = "Failed to exchange token"
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1034,7 +1060,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if errNewRequest != nil {
|
||||
log.Errorf("Could not get user info: %v", errNewRequest)
|
||||
oauthStatus[state] = "Could not get user info"
|
||||
SetOAuthSessionError(state, "Could not get user info")
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -1043,7 +1069,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
resp, errDo := authHTTPClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.Errorf("Failed to execute request: %v", errDo)
|
||||
oauthStatus[state] = "Failed to execute request"
|
||||
SetOAuthSessionError(state, "Failed to execute request")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1055,7 +1081,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1064,7 +1090,6 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
fmt.Printf("Authenticated user email: %s\n", email)
|
||||
} else {
|
||||
fmt.Println("Failed to get user email from token")
|
||||
oauthStatus[state] = "Failed to get user email from token"
|
||||
}
|
||||
|
||||
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
||||
@@ -1072,7 +1097,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
jsonData, _ := json.Marshal(token)
|
||||
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
||||
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
||||
oauthStatus[state] = "Failed to unmarshal token"
|
||||
SetOAuthSessionError(state, "Failed to unmarshal token")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1095,10 +1120,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
|
||||
// Initialize authenticated HTTP client via GeminiAuth to honor proxy settings
|
||||
gemAuth := geminiAuth.NewGeminiAuth()
|
||||
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
||||
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, &geminiAuth.WebLoginOptions{
|
||||
NoBrowser: true,
|
||||
})
|
||||
if errGetClient != nil {
|
||||
log.Errorf("failed to get authenticated client: %v", errGetClient)
|
||||
oauthStatus[state] = "Failed to get authenticated client"
|
||||
SetOAuthSessionError(state, "Failed to get authenticated client")
|
||||
return
|
||||
}
|
||||
fmt.Println("Authentication successful.")
|
||||
@@ -1108,12 +1135,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||
if errAll != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
return
|
||||
}
|
||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
@@ -1121,26 +1148,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Onboarding did not return a project ID")
|
||||
oauthStatus[state] = "Failed to resolve project ID"
|
||||
SetOAuthSessionError(state, "Failed to resolve project ID")
|
||||
return
|
||||
}
|
||||
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the selected project")
|
||||
oauthStatus[state] = "Cloud AI API not enabled"
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1163,15 +1190,15 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save token to file: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save token to file"
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("gemini")
|
||||
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1207,7 +1234,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "codex")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/codex/callback")
|
||||
if errTarget != nil {
|
||||
@@ -1215,7 +1245,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(codexCallbackPort, "codex", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start codex callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
@@ -1224,7 +1255,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(codexCallbackPort)
|
||||
defer stopCallbackForwarderInstance(codexCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
// Wait for callback file
|
||||
@@ -1232,10 +1263,13 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var code string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "codex") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
@@ -1245,12 +1279,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
||||
oauthStatus[state] = "Bad Request"
|
||||
SetOAuthSessionError(state, "Bad Request")
|
||||
return
|
||||
}
|
||||
if m["state"] != state {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
||||
oauthStatus[state] = "State code error"
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
return
|
||||
}
|
||||
@@ -1281,14 +1315,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
return
|
||||
}
|
||||
@@ -1299,7 +1333,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
return
|
||||
}
|
||||
@@ -1337,7 +1371,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
return
|
||||
}
|
||||
@@ -1346,10 +1380,10 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use Codex services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("codex")
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1390,7 +1424,10 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
params.Set("state", state)
|
||||
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
||||
|
||||
RegisterOAuthSession(state, "antigravity")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
|
||||
if errTarget != nil {
|
||||
@@ -1398,7 +1435,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(antigravityCallbackPort, "antigravity", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start antigravity callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "failed to start callback server"})
|
||||
return
|
||||
@@ -1407,16 +1445,19 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(antigravityCallbackPort)
|
||||
defer stopCallbackForwarderInstance(antigravityCallbackPort, forwarder)
|
||||
}
|
||||
|
||||
waitFile := filepath.Join(h.cfg.AuthDir, fmt.Sprintf(".oauth-antigravity-%s.oauth", state))
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var authCode string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "antigravity") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
oauthStatus[state] = "OAuth flow timed out"
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
||||
@@ -1425,18 +1466,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
||||
log.Errorf("Authentication failed: state mismatch")
|
||||
oauthStatus[state] = "Authentication failed: state mismatch"
|
||||
SetOAuthSessionError(state, "Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
authCode = strings.TrimSpace(payload["code"])
|
||||
if authCode == "" {
|
||||
log.Error("Authentication failed: code not found")
|
||||
oauthStatus[state] = "Authentication failed: code not found"
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
@@ -1455,7 +1496,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
||||
if errNewRequest != nil {
|
||||
log.Errorf("Failed to build token request: %v", errNewRequest)
|
||||
oauthStatus[state] = "Failed to build token request"
|
||||
SetOAuthSessionError(state, "Failed to build token request")
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
@@ -1463,7 +1504,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.Errorf("Failed to execute token request: %v", errDo)
|
||||
oauthStatus[state] = "Failed to exchange token"
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1475,7 +1516,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1487,7 +1528,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
}
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
||||
log.Errorf("Failed to parse token response: %v", errDecode)
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1496,7 +1537,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if errInfoReq != nil {
|
||||
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
||||
oauthStatus[state] = "Failed to build user info request"
|
||||
SetOAuthSessionError(state, "Failed to build user info request")
|
||||
return
|
||||
}
|
||||
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
||||
@@ -1504,7 +1545,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
infoResp, errInfo := httpClient.Do(infoReq)
|
||||
if errInfo != nil {
|
||||
log.Errorf("Failed to execute user info request: %v", errInfo)
|
||||
oauthStatus[state] = "Failed to execute user info request"
|
||||
SetOAuthSessionError(state, "Failed to execute user info request")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1523,7 +1564,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
} else {
|
||||
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
||||
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1571,11 +1612,12 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save token to file: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save token to file"
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("antigravity")
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
if projectID != "" {
|
||||
fmt.Printf("Using GCP project: %s\n", projectID)
|
||||
@@ -1583,7 +1625,6 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
fmt.Println("You can now use Antigravity services through this CLI")
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1605,11 +1646,13 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
}
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
|
||||
RegisterOAuthSession(state, "qwen")
|
||||
|
||||
go func() {
|
||||
fmt.Println("Waiting for authentication...")
|
||||
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
||||
if errPollForToken != nil {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
||||
return
|
||||
}
|
||||
@@ -1628,16 +1671,15 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use Qwen services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1650,7 +1692,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
||||
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
|
||||
|
||||
RegisterOAuthSession(state, "iflow")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
var forwarder *callbackForwarder
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
|
||||
if errTarget != nil {
|
||||
@@ -1658,7 +1703,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "callback server unavailable"})
|
||||
return
|
||||
}
|
||||
if _, errStart := startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
|
||||
var errStart error
|
||||
if forwarder, errStart = startCallbackForwarder(iflowauth.CallbackPort, "iflow", targetURL); errStart != nil {
|
||||
log.WithError(errStart).Error("failed to start iflow callback forwarder")
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to start callback server"})
|
||||
return
|
||||
@@ -1667,7 +1713,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
|
||||
go func() {
|
||||
if isWebUI {
|
||||
defer stopCallbackForwarder(iflowauth.CallbackPort)
|
||||
defer stopCallbackForwarderInstance(iflowauth.CallbackPort, forwarder)
|
||||
}
|
||||
fmt.Println("Waiting for authentication...")
|
||||
|
||||
@@ -1675,8 +1721,11 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
deadline := time.Now().Add(5 * time.Minute)
|
||||
var resultMap map[string]string
|
||||
for {
|
||||
if !IsOAuthSessionPending(state, "iflow") {
|
||||
return
|
||||
}
|
||||
if time.Now().After(deadline) {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: timeout waiting for callback")
|
||||
return
|
||||
}
|
||||
@@ -1689,26 +1738,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %s\n", errStr)
|
||||
return
|
||||
}
|
||||
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(resultMap["code"])
|
||||
if code == "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: code missing")
|
||||
return
|
||||
}
|
||||
|
||||
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
||||
if errExchange != nil {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errExchange)
|
||||
return
|
||||
}
|
||||
@@ -1730,7 +1779,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
return
|
||||
}
|
||||
@@ -1740,10 +1789,10 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use iFlow services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
CompleteOAuthSessionsByProvider("iflow")
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -2179,16 +2228,24 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
}
|
||||
|
||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
if err, ok := oauthStatus[state]; ok {
|
||||
if err != "" {
|
||||
c.JSON(200, gin.H{"status": "error", "error": err})
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "wait"})
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if state == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
delete(oauthStatus, state)
|
||||
|
||||
_, status, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
if status != "" {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "error", "error": status})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||
}
|
||||
|
||||
@@ -145,71 +145,74 @@ func (h *Handler) PutGeminiKeys(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchGeminiKey(c *gin.Context) {
|
||||
type geminiKeyPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.GeminiKey `json:"value"`
|
||||
Value *geminiKeyPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
value.APIKey = strings.TrimSpace(value.APIKey)
|
||||
value.BaseURL = strings.TrimSpace(value.BaseURL)
|
||||
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
|
||||
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels)
|
||||
if value.APIKey == "" {
|
||||
// Treat empty API key as delete.
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
||||
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:*body.Index], h.cfg.GeminiKey[*body.Index+1:]...)
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if body.Match != nil {
|
||||
if targetIndex == -1 && body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
if match != "" {
|
||||
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
||||
removed := false
|
||||
for i := range h.cfg.GeminiKey {
|
||||
if !removed && h.cfg.GeminiKey[i].APIKey == match {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.GeminiKey[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.GeminiKey = out
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
if h.cfg.GeminiKey[i].APIKey == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.GeminiKey) {
|
||||
h.cfg.GeminiKey[*body.Index] = value
|
||||
entry := h.cfg.GeminiKey[targetIndex]
|
||||
if body.Value.APIKey != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.APIKey)
|
||||
if trimmed == "" {
|
||||
h.cfg.GeminiKey = append(h.cfg.GeminiKey[:targetIndex], h.cfg.GeminiKey[targetIndex+1:]...)
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
for i := range h.cfg.GeminiKey {
|
||||
if h.cfg.GeminiKey[i].APIKey == match {
|
||||
h.cfg.GeminiKey[i] = value
|
||||
entry.APIKey = trimmed
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
|
||||
}
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.ExcludedModels != nil {
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||
}
|
||||
h.cfg.GeminiKey[targetIndex] = entry
|
||||
h.cfg.SanitizeGeminiKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteGeminiKey(c *gin.Context) {
|
||||
if val := strings.TrimSpace(c.Query("api-key")); val != "" {
|
||||
out := make([]config.GeminiKey, 0, len(h.cfg.GeminiKey))
|
||||
@@ -268,35 +271,70 @@ func (h *Handler) PutClaudeKeys(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchClaudeKey(c *gin.Context) {
|
||||
type claudeKeyPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Models *[]config.ClaudeModel `json:"models"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.ClaudeKey `json:"value"`
|
||||
Value *claudeKeyPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
normalizeClaudeKey(&value)
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.ClaudeKey) {
|
||||
h.cfg.ClaudeKey[*body.Index] = value
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if body.Match != nil {
|
||||
if targetIndex == -1 && body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
for i := range h.cfg.ClaudeKey {
|
||||
if h.cfg.ClaudeKey[i].APIKey == *body.Match {
|
||||
h.cfg.ClaudeKey[i] = value
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
if h.cfg.ClaudeKey[i].APIKey == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
entry := h.cfg.ClaudeKey[targetIndex]
|
||||
if body.Value.APIKey != nil {
|
||||
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
if body.Value.BaseURL != nil {
|
||||
entry.BaseURL = strings.TrimSpace(*body.Value.BaseURL)
|
||||
}
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.ClaudeModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.ExcludedModels != nil {
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||
}
|
||||
normalizeClaudeKey(&entry)
|
||||
h.cfg.ClaudeKey[targetIndex] = entry
|
||||
h.cfg.SanitizeClaudeKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteClaudeKey(c *gin.Context) {
|
||||
if val := c.Query("api-key"); val != "" {
|
||||
out := make([]config.ClaudeKey, 0, len(h.cfg.ClaudeKey))
|
||||
@@ -356,62 +394,73 @@ func (h *Handler) PutOpenAICompat(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchOpenAICompat(c *gin.Context) {
|
||||
type openAICompatPatch struct {
|
||||
Name *string `json:"name"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
APIKeyEntries *[]config.OpenAICompatibilityAPIKey `json:"api-key-entries"`
|
||||
Models *[]config.OpenAICompatibilityModel `json:"models"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
}
|
||||
var body struct {
|
||||
Name *string `json:"name"`
|
||||
Index *int `json:"index"`
|
||||
Value *config.OpenAICompatibility `json:"value"`
|
||||
Value *openAICompatPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
normalizeOpenAICompatibilityEntry(body.Value)
|
||||
// If base-url becomes empty, delete the provider instead of updating
|
||||
if strings.TrimSpace(body.Value.BaseURL) == "" {
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
|
||||
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:*body.Index], h.cfg.OpenAICompatibility[*body.Index+1:]...)
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if body.Name != nil {
|
||||
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
|
||||
removed := false
|
||||
if targetIndex == -1 && body.Name != nil {
|
||||
match := strings.TrimSpace(*body.Name)
|
||||
for i := range h.cfg.OpenAICompatibility {
|
||||
if !removed && h.cfg.OpenAICompatibility[i].Name == *body.Name {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.OpenAICompatibility[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.OpenAICompatibility = out
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
if h.cfg.OpenAICompatibility[i].Name == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.OpenAICompatibility) {
|
||||
h.cfg.OpenAICompatibility[*body.Index] = *body.Value
|
||||
|
||||
entry := h.cfg.OpenAICompatibility[targetIndex]
|
||||
if body.Value.Name != nil {
|
||||
entry.Name = strings.TrimSpace(*body.Value.Name)
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.BaseURL)
|
||||
if trimmed == "" {
|
||||
h.cfg.OpenAICompatibility = append(h.cfg.OpenAICompatibility[:targetIndex], h.cfg.OpenAICompatibility[targetIndex+1:]...)
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Name != nil {
|
||||
for i := range h.cfg.OpenAICompatibility {
|
||||
if h.cfg.OpenAICompatibility[i].Name == *body.Name {
|
||||
h.cfg.OpenAICompatibility[i] = *body.Value
|
||||
entry.BaseURL = trimmed
|
||||
}
|
||||
if body.Value.APIKeyEntries != nil {
|
||||
entry.APIKeyEntries = append([]config.OpenAICompatibilityAPIKey(nil), (*body.Value.APIKeyEntries)...)
|
||||
}
|
||||
if body.Value.Models != nil {
|
||||
entry.Models = append([]config.OpenAICompatibilityModel(nil), (*body.Value.Models)...)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
normalizeOpenAICompatibilityEntry(&entry)
|
||||
h.cfg.OpenAICompatibility[targetIndex] = entry
|
||||
h.cfg.SanitizeOpenAICompatibility()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteOpenAICompat(c *gin.Context) {
|
||||
if name := c.Query("name"); name != "" {
|
||||
out := make([]config.OpenAICompatibility, 0, len(h.cfg.OpenAICompatibility))
|
||||
@@ -563,66 +612,72 @@ func (h *Handler) PutCodexKeys(c *gin.Context) {
|
||||
h.persist(c)
|
||||
}
|
||||
func (h *Handler) PatchCodexKey(c *gin.Context) {
|
||||
type codexKeyPatch struct {
|
||||
APIKey *string `json:"api-key"`
|
||||
Prefix *string `json:"prefix"`
|
||||
BaseURL *string `json:"base-url"`
|
||||
ProxyURL *string `json:"proxy-url"`
|
||||
Headers *map[string]string `json:"headers"`
|
||||
ExcludedModels *[]string `json:"excluded-models"`
|
||||
}
|
||||
var body struct {
|
||||
Index *int `json:"index"`
|
||||
Match *string `json:"match"`
|
||||
Value *config.CodexKey `json:"value"`
|
||||
Value *codexKeyPatch `json:"value"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&body); err != nil || body.Value == nil {
|
||||
c.JSON(400, gin.H{"error": "invalid body"})
|
||||
return
|
||||
}
|
||||
value := *body.Value
|
||||
value.APIKey = strings.TrimSpace(value.APIKey)
|
||||
value.BaseURL = strings.TrimSpace(value.BaseURL)
|
||||
value.ProxyURL = strings.TrimSpace(value.ProxyURL)
|
||||
value.Headers = config.NormalizeHeaders(value.Headers)
|
||||
value.ExcludedModels = config.NormalizeExcludedModels(value.ExcludedModels)
|
||||
// If base-url becomes empty, delete instead of update
|
||||
if value.BaseURL == "" {
|
||||
targetIndex := -1
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
||||
h.cfg.CodexKey = append(h.cfg.CodexKey[:*body.Index], h.cfg.CodexKey[*body.Index+1:]...)
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
targetIndex = *body.Index
|
||||
}
|
||||
if body.Match != nil {
|
||||
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
||||
removed := false
|
||||
if targetIndex == -1 && body.Match != nil {
|
||||
match := strings.TrimSpace(*body.Match)
|
||||
for i := range h.cfg.CodexKey {
|
||||
if !removed && h.cfg.CodexKey[i].APIKey == *body.Match {
|
||||
removed = true
|
||||
continue
|
||||
}
|
||||
out = append(out, h.cfg.CodexKey[i])
|
||||
}
|
||||
if removed {
|
||||
h.cfg.CodexKey = out
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.CodexKey) {
|
||||
h.cfg.CodexKey[*body.Index] = value
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if body.Match != nil {
|
||||
for i := range h.cfg.CodexKey {
|
||||
if h.cfg.CodexKey[i].APIKey == *body.Match {
|
||||
h.cfg.CodexKey[i] = value
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
if h.cfg.CodexKey[i].APIKey == match {
|
||||
targetIndex = i
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
if targetIndex == -1 {
|
||||
c.JSON(404, gin.H{"error": "item not found"})
|
||||
return
|
||||
}
|
||||
|
||||
entry := h.cfg.CodexKey[targetIndex]
|
||||
if body.Value.APIKey != nil {
|
||||
entry.APIKey = strings.TrimSpace(*body.Value.APIKey)
|
||||
}
|
||||
if body.Value.Prefix != nil {
|
||||
entry.Prefix = strings.TrimSpace(*body.Value.Prefix)
|
||||
}
|
||||
if body.Value.BaseURL != nil {
|
||||
trimmed := strings.TrimSpace(*body.Value.BaseURL)
|
||||
if trimmed == "" {
|
||||
h.cfg.CodexKey = append(h.cfg.CodexKey[:targetIndex], h.cfg.CodexKey[targetIndex+1:]...)
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
return
|
||||
}
|
||||
entry.BaseURL = trimmed
|
||||
}
|
||||
if body.Value.ProxyURL != nil {
|
||||
entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL)
|
||||
}
|
||||
if body.Value.Headers != nil {
|
||||
entry.Headers = config.NormalizeHeaders(*body.Value.Headers)
|
||||
}
|
||||
if body.Value.ExcludedModels != nil {
|
||||
entry.ExcludedModels = config.NormalizeExcludedModels(*body.Value.ExcludedModels)
|
||||
}
|
||||
h.cfg.CodexKey[targetIndex] = entry
|
||||
h.cfg.SanitizeCodexKeys()
|
||||
h.persist(c)
|
||||
}
|
||||
|
||||
func (h *Handler) DeleteCodexKey(c *gin.Context) {
|
||||
if val := c.Query("api-key"); val != "" {
|
||||
out := make([]config.CodexKey, 0, len(h.cfg.CodexKey))
|
||||
|
||||
100
internal/api/handlers/management/oauth_callback.go
Normal file
100
internal/api/handlers/management/oauth_callback.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type oauthCallbackRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (h *Handler) PostOAuthCallback(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"})
|
||||
return
|
||||
}
|
||||
|
||||
var req oauthCallbackRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
canonicalProvider, err := NormalizeOAuthProvider(req.Provider)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"})
|
||||
return
|
||||
}
|
||||
|
||||
state := strings.TrimSpace(req.State)
|
||||
code := strings.TrimSpace(req.Code)
|
||||
errMsg := strings.TrimSpace(req.Error)
|
||||
|
||||
if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" {
|
||||
u, errParse := url.Parse(rawRedirect)
|
||||
if errParse != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"})
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
if state == "" {
|
||||
state = strings.TrimSpace(q.Get("state"))
|
||||
}
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(q.Get("code"))
|
||||
}
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error"))
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error_description"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
|
||||
return
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
if code == "" && errMsg == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionProvider, sessionStatus, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"})
|
||||
return
|
||||
}
|
||||
if sessionStatus != "" {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(sessionProvider, canonicalProvider) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil {
|
||||
if errors.Is(errWrite, errOAuthSessionNotPending) {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
283
internal/api/handlers/management/oauth_sessions.go
Normal file
283
internal/api/handlers/management/oauth_sessions.go
Normal file
@@ -0,0 +1,283 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
oauthSessionTTL = 10 * time.Minute
|
||||
maxOAuthStateLength = 128
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidOAuthState = errors.New("invalid oauth state")
|
||||
errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")
|
||||
errOAuthSessionNotPending = errors.New("oauth session is not pending")
|
||||
)
|
||||
|
||||
type oauthSession struct {
|
||||
Provider string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type oauthSessionStore struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
sessions map[string]oauthSession
|
||||
}
|
||||
|
||||
func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
|
||||
if ttl <= 0 {
|
||||
ttl = oauthSessionTTL
|
||||
}
|
||||
return &oauthSessionStore{
|
||||
ttl: ttl,
|
||||
sessions: make(map[string]oauthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
|
||||
for state, session := range s.sessions {
|
||||
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Register(state, provider string) {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if state == "" || provider == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
s.sessions[state] = oauthSession{
|
||||
Provider: provider,
|
||||
Status: "",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(s.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) SetError(state, message string) {
|
||||
state = strings.TrimSpace(state)
|
||||
message = strings.TrimSpace(message)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
if message == "" {
|
||||
message = "Authentication failed"
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
session.Status = message
|
||||
session.ExpiresAt = now.Add(s.ttl)
|
||||
s.sessions[state] = session
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Complete(state string) {
|
||||
state = strings.TrimSpace(state)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) CompleteProvider(provider string) int {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if provider == "" {
|
||||
return 0
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
removed := 0
|
||||
for state, session := range s.sessions {
|
||||
if strings.EqualFold(session.Provider, provider) {
|
||||
delete(s.sessions, state)
|
||||
removed++
|
||||
}
|
||||
}
|
||||
return removed
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
|
||||
state = strings.TrimSpace(state)
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) IsPending(state, provider string) bool {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if session.Status != "" {
|
||||
return false
|
||||
}
|
||||
if provider == "" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(session.Provider, provider)
|
||||
}
|
||||
|
||||
var oauthSessions = newOAuthSessionStore(oauthSessionTTL)
|
||||
|
||||
func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) }
|
||||
|
||||
func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) }
|
||||
|
||||
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
|
||||
|
||||
func CompleteOAuthSessionsByProvider(provider string) int {
|
||||
return oauthSessions.CompleteProvider(provider)
|
||||
}
|
||||
|
||||
func GetOAuthSession(state string) (provider string, status string, ok bool) {
|
||||
session, ok := oauthSessions.Get(state)
|
||||
if !ok {
|
||||
return "", "", false
|
||||
}
|
||||
return session.Provider, session.Status, true
|
||||
}
|
||||
|
||||
func IsOAuthSessionPending(state, provider string) bool {
|
||||
return oauthSessions.IsPending(state, provider)
|
||||
}
|
||||
|
||||
func ValidateOAuthState(state string) error {
|
||||
trimmed := strings.TrimSpace(state)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("%w: empty", errInvalidOAuthState)
|
||||
}
|
||||
if len(trimmed) > maxOAuthStateLength {
|
||||
return fmt.Errorf("%w: too long", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") {
|
||||
return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "..") {
|
||||
return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
|
||||
}
|
||||
for _, r := range trimmed {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
default:
|
||||
return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "anthropic", "claude":
|
||||
return "anthropic", nil
|
||||
case "codex", "openai":
|
||||
return "codex", nil
|
||||
case "gemini", "google":
|
||||
return "gemini", nil
|
||||
case "iflow", "i-flow":
|
||||
return "iflow", nil
|
||||
case "antigravity", "anti-gravity":
|
||||
return "antigravity", nil
|
||||
case "qwen":
|
||||
return "qwen", nil
|
||||
default:
|
||||
return "", errUnsupportedOAuthFlow
|
||||
}
|
||||
}
|
||||
|
||||
type oauthCallbackFilePayload struct {
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
if strings.TrimSpace(authDir) == "" {
|
||||
return "", fmt.Errorf("auth dir is empty")
|
||||
}
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state)
|
||||
filePath := filepath.Join(authDir, fileName)
|
||||
payload := oauthCallbackFilePayload{
|
||||
Code: strings.TrimSpace(code),
|
||||
State: strings.TrimSpace(state),
|
||||
Error: strings.TrimSpace(errorMessage),
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal oauth callback payload: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(filePath, data, 0o600); err != nil {
|
||||
return "", fmt.Errorf("write oauth callback file: %w", err)
|
||||
}
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !IsOAuthSessionPending(state, canonicalProvider) {
|
||||
return "", errOAuthSessionNotPending
|
||||
}
|
||||
return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage)
|
||||
}
|
||||
@@ -134,7 +134,43 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
}
|
||||
|
||||
// Normalize model (handles dynamic thinking suffixes)
|
||||
normalizedModel, _ := util.NormalizeThinkingModel(modelName)
|
||||
normalizedModel, thinkingMetadata := util.NormalizeThinkingModel(modelName)
|
||||
thinkingSuffix := ""
|
||||
if thinkingMetadata != nil && strings.HasPrefix(modelName, normalizedModel) {
|
||||
thinkingSuffix = modelName[len(normalizedModel):]
|
||||
}
|
||||
|
||||
resolveMappedModel := func() (string, []string) {
|
||||
if fh.modelMapper == nil {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
mappedModel := fh.modelMapper.MapModel(modelName)
|
||||
if mappedModel == "" {
|
||||
mappedModel = fh.modelMapper.MapModel(normalizedModel)
|
||||
}
|
||||
mappedModel = strings.TrimSpace(mappedModel)
|
||||
if mappedModel == "" {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
// Preserve dynamic thinking suffix (e.g. "(xhigh)") when mapping applies, unless the target
|
||||
// already specifies its own thinking suffix.
|
||||
if thinkingSuffix != "" {
|
||||
_, mappedThinkingMetadata := util.NormalizeThinkingModel(mappedModel)
|
||||
if mappedThinkingMetadata == nil {
|
||||
mappedModel += thinkingSuffix
|
||||
}
|
||||
}
|
||||
|
||||
mappedBaseModel, _ := util.NormalizeThinkingModel(mappedModel)
|
||||
mappedProviders := util.GetProviderName(mappedBaseModel)
|
||||
if len(mappedProviders) == 0 {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
return mappedModel, mappedProviders
|
||||
}
|
||||
|
||||
// Track resolved model for logging (may change if mapping is applied)
|
||||
resolvedModel := normalizedModel
|
||||
@@ -147,11 +183,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
if forceMappings {
|
||||
// FORCE MODE: Check model mappings FIRST (takes precedence over local API keys)
|
||||
// This allows users to route Amp requests to their preferred OAuth providers
|
||||
if fh.modelMapper != nil {
|
||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||
// Mapping found - check if we have a provider for the mapped model
|
||||
mappedProviders := util.GetProviderName(mappedModel)
|
||||
if len(mappedProviders) > 0 {
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
@@ -161,8 +193,6 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
usedMapping = true
|
||||
providers = mappedProviders
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no mapping applied, check for local providers
|
||||
if !usedMapping {
|
||||
@@ -174,11 +204,7 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
|
||||
if len(providers) == 0 {
|
||||
// No providers configured - check if we have a model mapping
|
||||
if fh.modelMapper != nil {
|
||||
if mappedModel := fh.modelMapper.MapModel(normalizedModel); mappedModel != "" {
|
||||
// Mapping found - check if we have a provider for the mapped model
|
||||
mappedProviders := util.GetProviderName(mappedModel)
|
||||
if len(mappedProviders) > 0 {
|
||||
if mappedModel, mappedProviders := resolveMappedModel(); mappedModel != "" {
|
||||
// Mapping found and provider available - rewrite the model in request body
|
||||
bodyBytes = rewriteModelInRequest(bodyBytes, mappedModel)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
@@ -190,8 +216,6 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no providers available, fallback to ampcode.com
|
||||
if len(providers) == 0 {
|
||||
@@ -222,14 +246,14 @@ func (fh *FallbackHandler) WrapHandler(handler gin.HandlerFunc) gin.HandlerFunc
|
||||
// Log: Model was mapped to another model
|
||||
log.Debugf("amp model mapping: request %s -> %s", normalizedModel, resolvedModel)
|
||||
logAmpRouting(RouteTypeModelMapping, modelName, resolvedModel, providerName, requestPath)
|
||||
rewriter := NewResponseRewriter(c.Writer, normalizedModel)
|
||||
rewriter := NewResponseRewriter(c.Writer, modelName)
|
||||
c.Writer = rewriter
|
||||
// Filter Anthropic-Beta header only for local handling paths
|
||||
filterAntropicBetaHeader(c)
|
||||
c.Request.Body = io.NopCloser(bytes.NewReader(bodyBytes))
|
||||
handler(c)
|
||||
rewriter.Flush()
|
||||
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, normalizedModel)
|
||||
log.Debugf("amp model mapping: response %s -> %s", resolvedModel, modelName)
|
||||
} else if len(providers) > 0 {
|
||||
// Log: Using local provider (free)
|
||||
logAmpRouting(RouteTypeLocalProvider, modelName, resolvedModel, providerName, requestPath)
|
||||
|
||||
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
73
internal/api/modules/amp/fallback_handlers_test.go
Normal file
@@ -0,0 +1,73 @@
|
||||
package amp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/http/httputil"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
)
|
||||
|
||||
func TestFallbackHandler_ModelMapping_PreservesThinkingSuffixAndRewritesResponse(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-amp-fallback", "codex", []*registry.ModelInfo{
|
||||
{ID: "test/gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-amp-fallback")
|
||||
|
||||
mapper := NewModelMapper([]config.AmpModelMapping{
|
||||
{From: "gpt-5.2", To: "test/gpt-5.2"},
|
||||
})
|
||||
|
||||
fallback := NewFallbackHandlerWithMapper(func() *httputil.ReverseProxy { return nil }, mapper, nil)
|
||||
|
||||
handler := func(c *gin.Context) {
|
||||
var req struct {
|
||||
Model string `json:"model"`
|
||||
}
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"model": req.Model,
|
||||
"seen_model": req.Model,
|
||||
})
|
||||
}
|
||||
|
||||
r := gin.New()
|
||||
r.POST("/chat/completions", fallback.WrapHandler(handler))
|
||||
|
||||
reqBody := []byte(`{"model":"gpt-5.2(xhigh)"}`)
|
||||
req := httptest.NewRequest(http.MethodPost, "/chat/completions", bytes.NewReader(reqBody))
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
|
||||
if w.Code != http.StatusOK {
|
||||
t.Fatalf("Expected status 200, got %d", w.Code)
|
||||
}
|
||||
|
||||
var resp struct {
|
||||
Model string `json:"model"`
|
||||
SeenModel string `json:"seen_model"`
|
||||
}
|
||||
if err := json.Unmarshal(w.Body.Bytes(), &resp); err != nil {
|
||||
t.Fatalf("Failed to parse response JSON: %v", err)
|
||||
}
|
||||
|
||||
if resp.Model != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected response model gpt-5.2(xhigh), got %s", resp.Model)
|
||||
}
|
||||
if resp.SeenModel != "test/gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected handler to see test/gpt-5.2(xhigh), got %s", resp.SeenModel)
|
||||
}
|
||||
}
|
||||
@@ -59,7 +59,8 @@ func (m *DefaultModelMapper) MapModel(requestedModel string) string {
|
||||
}
|
||||
|
||||
// Verify target model has available providers
|
||||
providers := util.GetProviderName(targetModel)
|
||||
normalizedTarget, _ := util.NormalizeThinkingModel(targetModel)
|
||||
providers := util.GetProviderName(normalizedTarget)
|
||||
if len(providers) == 0 {
|
||||
log.Debugf("amp model mapping: target model %s has no available providers, skipping mapping", targetModel)
|
||||
return ""
|
||||
|
||||
@@ -71,6 +71,25 @@ func TestModelMapper_MapModel_WithProvider(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_TargetWithThinkingSuffix(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client-thinking", "codex", []*registry.ModelInfo{
|
||||
{ID: "gpt-5.2", OwnedBy: "openai", Type: "codex"},
|
||||
})
|
||||
defer reg.UnregisterClient("test-client-thinking")
|
||||
|
||||
mappings := []config.AmpModelMapping{
|
||||
{From: "gpt-5.2-alias", To: "gpt-5.2(xhigh)"},
|
||||
}
|
||||
|
||||
mapper := NewModelMapper(mappings)
|
||||
|
||||
result := mapper.MapModel("gpt-5.2-alias")
|
||||
if result != "gpt-5.2(xhigh)" {
|
||||
t.Errorf("Expected gpt-5.2(xhigh), got %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestModelMapper_MapModel_CaseInsensitive(t *testing.T) {
|
||||
reg := registry.GetGlobalRegistry()
|
||||
reg.RegisterClient("test-client2", "claude", []*registry.ModelInfo{
|
||||
|
||||
@@ -95,6 +95,20 @@ func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere.
|
||||
func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
auth(c)
|
||||
}
|
||||
}
|
||||
|
||||
// registerManagementRoutes registers Amp management proxy routes
|
||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||
// Uses dynamic middleware and proxy getter for hot-reload support.
|
||||
@@ -109,8 +123,10 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
ampAPI.Use(m.localhostOnlyMiddleware())
|
||||
|
||||
// Apply authentication middleware - requires valid API key in Authorization header
|
||||
var authWithBypass gin.HandlerFunc
|
||||
if auth != nil {
|
||||
ampAPI.Use(auth)
|
||||
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs", "/settings")
|
||||
}
|
||||
|
||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||
@@ -156,10 +172,16 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Root-level routes that AMP CLI expects without /api prefix
|
||||
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
||||
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
||||
if auth != nil {
|
||||
rootMiddleware = append(rootMiddleware, auth)
|
||||
if authWithBypass != nil {
|
||||
rootMiddleware = append(rootMiddleware, authWithBypass)
|
||||
}
|
||||
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/settings", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/settings/*path", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
|
||||
@@ -354,10 +354,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
// Persist to a temporary file keyed by state
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -367,9 +368,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -379,9 +382,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -391,9 +396,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -403,9 +410,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-antigravity-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -577,6 +586,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
}
|
||||
}
|
||||
@@ -834,12 +844,21 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
}
|
||||
}
|
||||
|
||||
if oldCfg != nil && oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
||||
log.Errorf("failed to reconfigure log output: %v", err)
|
||||
} else {
|
||||
if oldCfg == nil {
|
||||
log.Debug("log output configuration refreshed")
|
||||
} else {
|
||||
if oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||
}
|
||||
if oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
log.Debugf("logs_max_total_size_mb updated from %d to %d", oldCfg.LogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -46,6 +47,12 @@ var (
|
||||
type GeminiAuth struct {
|
||||
}
|
||||
|
||||
// WebLoginOptions customizes the interactive OAuth flow.
|
||||
type WebLoginOptions struct {
|
||||
NoBrowser bool
|
||||
Prompt func(string) (string, error)
|
||||
}
|
||||
|
||||
// NewGeminiAuth creates a new instance of GeminiAuth.
|
||||
func NewGeminiAuth() *GeminiAuth {
|
||||
return &GeminiAuth{}
|
||||
@@ -59,12 +66,12 @@ func NewGeminiAuth() *GeminiAuth {
|
||||
// - ctx: The context for the HTTP client
|
||||
// - ts: The Gemini token storage containing authentication tokens
|
||||
// - cfg: The configuration containing proxy settings
|
||||
// - noBrowser: Optional parameter to disable browser opening
|
||||
// - opts: Optional parameters to customize browser and prompt behavior
|
||||
//
|
||||
// Returns:
|
||||
// - *http.Client: An HTTP client configured with authentication
|
||||
// - error: An error if the client configuration fails, nil otherwise
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, noBrowser ...bool) (*http.Client, error) {
|
||||
func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiTokenStorage, cfg *config.Config, opts *WebLoginOptions) (*http.Client, error) {
|
||||
// Configure proxy settings for the HTTP client if a proxy URL is provided.
|
||||
proxyURL, err := url.Parse(cfg.ProxyURL)
|
||||
if err == nil {
|
||||
@@ -109,7 +116,7 @@ func (g *GeminiAuth) GetAuthenticatedClient(ctx context.Context, ts *GeminiToken
|
||||
// If no token is found in storage, initiate the web-based OAuth flow.
|
||||
if ts.Token == nil {
|
||||
fmt.Printf("Could not load token from file, starting OAuth flow.\n")
|
||||
token, err = g.getTokenFromWeb(ctx, conf, noBrowser...)
|
||||
token, err = g.getTokenFromWeb(ctx, conf, opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get token from web: %w", err)
|
||||
}
|
||||
@@ -205,15 +212,15 @@ func (g *GeminiAuth) createTokenStorage(ctx context.Context, config *oauth2.Conf
|
||||
// Parameters:
|
||||
// - ctx: The context for the HTTP client
|
||||
// - config: The OAuth2 configuration
|
||||
// - noBrowser: Optional parameter to disable browser opening
|
||||
// - opts: Optional parameters to customize browser and prompt behavior
|
||||
//
|
||||
// Returns:
|
||||
// - *oauth2.Token: The OAuth2 token obtained from the authorization flow
|
||||
// - error: An error if the token acquisition fails, nil otherwise
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, noBrowser ...bool) (*oauth2.Token, error) {
|
||||
func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config, opts *WebLoginOptions) (*oauth2.Token, error) {
|
||||
// Use a channel to pass the authorization code from the HTTP handler to the main function.
|
||||
codeChan := make(chan string)
|
||||
errChan := make(chan error)
|
||||
codeChan := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
// Create a new HTTP server with its own multiplexer.
|
||||
mux := http.NewServeMux()
|
||||
@@ -223,17 +230,26 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
mux.HandleFunc("/oauth2callback", func(w http.ResponseWriter, r *http.Request) {
|
||||
if err := r.URL.Query().Get("error"); err != "" {
|
||||
_, _ = fmt.Fprintf(w, "Authentication failed: %s", err)
|
||||
errChan <- fmt.Errorf("authentication failed via callback: %s", err)
|
||||
select {
|
||||
case errChan <- fmt.Errorf("authentication failed via callback: %s", err):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
code := r.URL.Query().Get("code")
|
||||
if code == "" {
|
||||
_, _ = fmt.Fprint(w, "Authentication failed: code not found.")
|
||||
errChan <- fmt.Errorf("code not found in callback")
|
||||
select {
|
||||
case errChan <- fmt.Errorf("code not found in callback"):
|
||||
default:
|
||||
}
|
||||
return
|
||||
}
|
||||
_, _ = fmt.Fprint(w, "<html><body><h1>Authentication successful!</h1><p>You can close this window.</p></body></html>")
|
||||
codeChan <- code
|
||||
select {
|
||||
case codeChan <- code:
|
||||
default:
|
||||
}
|
||||
})
|
||||
|
||||
// Start the server in a goroutine.
|
||||
@@ -250,7 +266,12 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
// Open the authorization URL in the user's browser.
|
||||
authURL := config.AuthCodeURL("state-token", oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
|
||||
if len(noBrowser) == 1 && !noBrowser[0] {
|
||||
noBrowser := false
|
||||
if opts != nil {
|
||||
noBrowser = opts.NoBrowser
|
||||
}
|
||||
|
||||
if !noBrowser {
|
||||
fmt.Println("Opening browser for authentication...")
|
||||
|
||||
// Check if browser is available
|
||||
@@ -281,14 +302,61 @@ func (g *GeminiAuth) getTokenFromWeb(ctx context.Context, config *oauth2.Config,
|
||||
|
||||
// Wait for the authorization code or an error.
|
||||
var authCode string
|
||||
timeoutTimer := time.NewTimer(5 * time.Minute)
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts != nil && opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
break waitForCallback
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
case <-time.After(5 * time.Minute): // Timeout
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
select {
|
||||
case code := <-codeChan:
|
||||
authCode = code
|
||||
break waitForCallback
|
||||
case err := <-errChan:
|
||||
return nil, err
|
||||
default:
|
||||
}
|
||||
input, err := opts.Prompt("Paste the Gemini callback URL (or press Enter to keep waiting): ")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
parsed, err := misc.ParseOAuthCallback(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
if parsed.Error != "" {
|
||||
return nil, fmt.Errorf("authentication failed via callback: %s", parsed.Error)
|
||||
}
|
||||
if parsed.Code == "" {
|
||||
return nil, fmt.Errorf("code not found in callback")
|
||||
}
|
||||
authCode = parsed.Code
|
||||
break waitForCallback
|
||||
case <-timeoutTimer.C:
|
||||
return nil, fmt.Errorf("oauth flow timed out")
|
||||
}
|
||||
}
|
||||
|
||||
// Shutdown the server.
|
||||
if err := server.Shutdown(ctx); err != nil {
|
||||
|
||||
164
internal/cache/signature_cache.go
vendored
Normal file
164
internal/cache/signature_cache.go
vendored
Normal file
@@ -0,0 +1,164 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"sort"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// SignatureEntry holds a cached thinking signature with timestamp
|
||||
type SignatureEntry struct {
|
||||
Signature string
|
||||
Timestamp time.Time
|
||||
}
|
||||
|
||||
const (
|
||||
// SignatureCacheTTL is how long signatures are valid
|
||||
SignatureCacheTTL = 1 * time.Hour
|
||||
|
||||
// MaxEntriesPerSession limits memory usage per session
|
||||
MaxEntriesPerSession = 100
|
||||
|
||||
// SignatureTextHashLen is the length of the hash key (16 hex chars = 64-bit key space)
|
||||
SignatureTextHashLen = 16
|
||||
|
||||
// MinValidSignatureLen is the minimum length for a signature to be considered valid
|
||||
MinValidSignatureLen = 50
|
||||
)
|
||||
|
||||
// signatureCache stores signatures by sessionId -> textHash -> SignatureEntry
|
||||
var signatureCache sync.Map
|
||||
|
||||
// sessionCache is the inner map type
|
||||
type sessionCache struct {
|
||||
mu sync.RWMutex
|
||||
entries map[string]SignatureEntry
|
||||
}
|
||||
|
||||
// hashText creates a stable, Unicode-safe key from text content
|
||||
func hashText(text string) string {
|
||||
h := sha256.Sum256([]byte(text))
|
||||
return hex.EncodeToString(h[:])[:SignatureTextHashLen]
|
||||
}
|
||||
|
||||
// getOrCreateSession gets or creates a session cache
|
||||
func getOrCreateSession(sessionID string) *sessionCache {
|
||||
if val, ok := signatureCache.Load(sessionID); ok {
|
||||
return val.(*sessionCache)
|
||||
}
|
||||
sc := &sessionCache{entries: make(map[string]SignatureEntry)}
|
||||
actual, _ := signatureCache.LoadOrStore(sessionID, sc)
|
||||
return actual.(*sessionCache)
|
||||
}
|
||||
|
||||
// CacheSignature stores a thinking signature for a given session and text.
|
||||
// Used for Claude models that require signed thinking blocks in multi-turn conversations.
|
||||
func CacheSignature(sessionID, text, signature string) {
|
||||
if sessionID == "" || text == "" || signature == "" {
|
||||
return
|
||||
}
|
||||
if len(signature) < MinValidSignatureLen {
|
||||
return
|
||||
}
|
||||
|
||||
sc := getOrCreateSession(sessionID)
|
||||
textHash := hashText(text)
|
||||
|
||||
sc.mu.Lock()
|
||||
defer sc.mu.Unlock()
|
||||
|
||||
// Evict expired entries if at capacity
|
||||
if len(sc.entries) >= MaxEntriesPerSession {
|
||||
now := time.Now()
|
||||
for key, entry := range sc.entries {
|
||||
if now.Sub(entry.Timestamp) > SignatureCacheTTL {
|
||||
delete(sc.entries, key)
|
||||
}
|
||||
}
|
||||
// If still at capacity, remove oldest entries
|
||||
if len(sc.entries) >= MaxEntriesPerSession {
|
||||
// Find and remove oldest quarter
|
||||
oldest := make([]struct {
|
||||
key string
|
||||
ts time.Time
|
||||
}, 0, len(sc.entries))
|
||||
for key, entry := range sc.entries {
|
||||
oldest = append(oldest, struct {
|
||||
key string
|
||||
ts time.Time
|
||||
}{key, entry.Timestamp})
|
||||
}
|
||||
// Sort by timestamp (oldest first) using sort.Slice
|
||||
sort.Slice(oldest, func(i, j int) bool {
|
||||
return oldest[i].ts.Before(oldest[j].ts)
|
||||
})
|
||||
|
||||
toRemove := len(oldest) / 4
|
||||
if toRemove < 1 {
|
||||
toRemove = 1
|
||||
}
|
||||
|
||||
for i := 0; i < toRemove; i++ {
|
||||
delete(sc.entries, oldest[i].key)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
sc.entries[textHash] = SignatureEntry{
|
||||
Signature: signature,
|
||||
Timestamp: time.Now(),
|
||||
}
|
||||
}
|
||||
|
||||
// GetCachedSignature retrieves a cached signature for a given session and text.
|
||||
// Returns empty string if not found or expired.
|
||||
func GetCachedSignature(sessionID, text string) string {
|
||||
if sessionID == "" || text == "" {
|
||||
return ""
|
||||
}
|
||||
|
||||
val, ok := signatureCache.Load(sessionID)
|
||||
if !ok {
|
||||
return ""
|
||||
}
|
||||
sc := val.(*sessionCache)
|
||||
|
||||
textHash := hashText(text)
|
||||
|
||||
sc.mu.RLock()
|
||||
entry, exists := sc.entries[textHash]
|
||||
sc.mu.RUnlock()
|
||||
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
|
||||
// Check if expired
|
||||
if time.Since(entry.Timestamp) > SignatureCacheTTL {
|
||||
sc.mu.Lock()
|
||||
delete(sc.entries, textHash)
|
||||
sc.mu.Unlock()
|
||||
return ""
|
||||
}
|
||||
|
||||
return entry.Signature
|
||||
}
|
||||
|
||||
// ClearSignatureCache clears signature cache for a specific session or all sessions.
|
||||
func ClearSignatureCache(sessionID string) {
|
||||
if sessionID != "" {
|
||||
signatureCache.Delete(sessionID)
|
||||
} else {
|
||||
signatureCache.Range(func(key, _ any) bool {
|
||||
signatureCache.Delete(key)
|
||||
return true
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// HasValidSignature checks if a signature is valid (non-empty and long enough)
|
||||
func HasValidSignature(signature string) bool {
|
||||
return signature != "" && len(signature) >= MinValidSignatureLen
|
||||
}
|
||||
216
internal/cache/signature_cache_test.go
vendored
Normal file
216
internal/cache/signature_cache_test.go
vendored
Normal file
@@ -0,0 +1,216 @@
|
||||
package cache
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestCacheSignature_BasicStorageAndRetrieval(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "test-session-1"
|
||||
text := "This is some thinking text content"
|
||||
signature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
|
||||
// Store signature
|
||||
CacheSignature(sessionID, text, signature)
|
||||
|
||||
// Retrieve signature
|
||||
retrieved := GetCachedSignature(sessionID, text)
|
||||
if retrieved != signature {
|
||||
t.Errorf("Expected signature '%s', got '%s'", signature, retrieved)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_DifferentSessions(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
text := "Same text in different sessions"
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature("session-a", text, sig1)
|
||||
CacheSignature("session-b", text, sig2)
|
||||
|
||||
if GetCachedSignature("session-a", text) != sig1 {
|
||||
t.Error("Session-a signature mismatch")
|
||||
}
|
||||
if GetCachedSignature("session-b", text) != sig2 {
|
||||
t.Error("Session-b signature mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_NotFound(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// Non-existent session
|
||||
if got := GetCachedSignature("nonexistent", "some text"); got != "" {
|
||||
t.Errorf("Expected empty string for nonexistent session, got '%s'", got)
|
||||
}
|
||||
|
||||
// Existing session but different text
|
||||
CacheSignature("session-x", "text-a", "sigA12345678901234567890123456789012345678901234567890")
|
||||
if got := GetCachedSignature("session-x", "text-b"); got != "" {
|
||||
t.Errorf("Expected empty string for different text, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_EmptyInputs(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// All empty/invalid inputs should be no-ops
|
||||
CacheSignature("", "text", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("session", "", "sig12345678901234567890123456789012345678901234567890")
|
||||
CacheSignature("session", "text", "")
|
||||
CacheSignature("session", "text", "short") // Too short
|
||||
|
||||
if got := GetCachedSignature("session", "text"); got != "" {
|
||||
t.Errorf("Expected empty after invalid cache attempts, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_ShortSignatureRejected(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "test-short-sig"
|
||||
text := "Some text"
|
||||
shortSig := "abc123" // Less than 50 chars
|
||||
|
||||
CacheSignature(sessionID, text, shortSig)
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != "" {
|
||||
t.Errorf("Short signature should be rejected, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearSignatureCache_SpecificSession(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature("session-1", "text", sig)
|
||||
CacheSignature("session-2", "text", sig)
|
||||
|
||||
ClearSignatureCache("session-1")
|
||||
|
||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
||||
t.Error("session-1 should be cleared")
|
||||
}
|
||||
if got := GetCachedSignature("session-2", "text"); got != sig {
|
||||
t.Error("session-2 should still exist")
|
||||
}
|
||||
}
|
||||
|
||||
func TestClearSignatureCache_AllSessions(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
CacheSignature("session-1", "text", sig)
|
||||
CacheSignature("session-2", "text", sig)
|
||||
|
||||
ClearSignatureCache("")
|
||||
|
||||
if got := GetCachedSignature("session-1", "text"); got != "" {
|
||||
t.Error("session-1 should be cleared")
|
||||
}
|
||||
if got := GetCachedSignature("session-2", "text"); got != "" {
|
||||
t.Error("session-2 should be cleared")
|
||||
}
|
||||
}
|
||||
|
||||
func TestHasValidSignature(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
signature string
|
||||
expected bool
|
||||
}{
|
||||
{"valid long signature", "abc123validSignature1234567890123456789012345678901234567890", true},
|
||||
{"exactly 50 chars", "12345678901234567890123456789012345678901234567890", true},
|
||||
{"49 chars - invalid", "1234567890123456789012345678901234567890123456789", false},
|
||||
{"empty string", "", false},
|
||||
{"short signature", "abc", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := HasValidSignature(tt.signature)
|
||||
if result != tt.expected {
|
||||
t.Errorf("HasValidSignature(%q) = %v, expected %v", tt.signature, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_TextHashCollisionResistance(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "hash-test-session"
|
||||
|
||||
// Different texts should produce different hashes
|
||||
text1 := "First thinking text"
|
||||
text2 := "Second thinking text"
|
||||
sig1 := "signature1_1234567890123456789012345678901234567890123456"
|
||||
sig2 := "signature2_1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature(sessionID, text1, sig1)
|
||||
CacheSignature(sessionID, text2, sig2)
|
||||
|
||||
if GetCachedSignature(sessionID, text1) != sig1 {
|
||||
t.Error("text1 signature mismatch")
|
||||
}
|
||||
if GetCachedSignature(sessionID, text2) != sig2 {
|
||||
t.Error("text2 signature mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_UnicodeText(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "unicode-session"
|
||||
text := "한글 텍스트와 이모지 🎉 그리고 特殊文字"
|
||||
sig := "unicodeSig123456789012345678901234567890123456789012345"
|
||||
|
||||
CacheSignature(sessionID, text, sig)
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
||||
t.Errorf("Unicode text signature retrieval failed, got '%s'", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCacheSignature_Overwrite(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
sessionID := "overwrite-session"
|
||||
text := "Same text"
|
||||
sig1 := "firstSignature12345678901234567890123456789012345678901"
|
||||
sig2 := "secondSignature1234567890123456789012345678901234567890"
|
||||
|
||||
CacheSignature(sessionID, text, sig1)
|
||||
CacheSignature(sessionID, text, sig2) // Overwrite
|
||||
|
||||
if got := GetCachedSignature(sessionID, text); got != sig2 {
|
||||
t.Errorf("Expected overwritten signature '%s', got '%s'", sig2, got)
|
||||
}
|
||||
}
|
||||
|
||||
// Note: TTL expiration test is tricky to test without mocking time
|
||||
// We test the logic path exists but actual expiration would require time manipulation
|
||||
func TestCacheSignature_ExpirationLogic(t *testing.T) {
|
||||
ClearSignatureCache("")
|
||||
|
||||
// This test verifies the expiration check exists
|
||||
// In a real scenario, we'd mock time.Now()
|
||||
sessionID := "expiration-test"
|
||||
text := "text"
|
||||
sig := "validSig1234567890123456789012345678901234567890123456"
|
||||
|
||||
CacheSignature(sessionID, text, sig)
|
||||
|
||||
// Fresh entry should be retrievable
|
||||
if got := GetCachedSignature(sessionID, text); got != sig {
|
||||
t.Errorf("Fresh entry should be retrievable, got '%s'", got)
|
||||
}
|
||||
|
||||
// We can't easily test actual expiration without time mocking
|
||||
// but the logic is verified by the implementation
|
||||
_ = time.Now() // Acknowledge we're not testing time passage
|
||||
}
|
||||
@@ -24,12 +24,17 @@ func DoClaudeLogin(cfg *config.Config, options *LoginOptions) {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "claude", cfg, authOpts)
|
||||
|
||||
@@ -15,11 +15,16 @@ func DoAntigravityLogin(cfg *config.Config, options *LoginOptions) {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
record, savedPath, err := manager.Login(context.Background(), "antigravity", cfg, authOpts)
|
||||
|
||||
@@ -20,13 +20,7 @@ func DoIFlowLogin(cfg *config.Config, options *LoginOptions) {
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = func(prompt string) (string, error) {
|
||||
fmt.Println()
|
||||
fmt.Println(prompt)
|
||||
var value string
|
||||
_, err := fmt.Scanln(&value)
|
||||
return value, err
|
||||
}
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
|
||||
@@ -55,11 +55,22 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
|
||||
ctx := context.Background()
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
trimmedProjectID := strings.TrimSpace(projectID)
|
||||
callbackPrompt := promptFn
|
||||
if trimmedProjectID == "" {
|
||||
callbackPrompt = nil
|
||||
}
|
||||
|
||||
loginOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
ProjectID: strings.TrimSpace(projectID),
|
||||
ProjectID: trimmedProjectID,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
Prompt: callbackPrompt,
|
||||
}
|
||||
|
||||
authenticator := sdkAuth.NewGeminiAuthenticator()
|
||||
@@ -76,7 +87,10 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
}
|
||||
|
||||
geminiAuth := gemini.NewGeminiAuth()
|
||||
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, options.NoBrowser)
|
||||
httpClient, errClient := geminiAuth.GetAuthenticatedClient(ctx, storage, cfg, &gemini.WebLoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Prompt: callbackPrompt,
|
||||
})
|
||||
if errClient != nil {
|
||||
log.Errorf("Gemini authentication failed: %v", errClient)
|
||||
return
|
||||
@@ -90,12 +104,7 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
return
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn)
|
||||
selectedProjectID := promptForProjectSelection(projects, trimmedProjectID, promptFn)
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Errorf("Invalid project selection: %v", errSelection)
|
||||
|
||||
@@ -35,12 +35,17 @@ func DoCodexLogin(cfg *config.Config, options *LoginOptions) {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
Metadata: map[string]string{},
|
||||
Prompt: options.Prompt,
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -21,7 +20,7 @@ const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy
|
||||
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
config.SDKConfig `yaml:",inline"`
|
||||
SDKConfig `yaml:",inline"`
|
||||
// Host is the network host/interface on which the API server will bind.
|
||||
// Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access.
|
||||
Host string `yaml:"host" json:"-"`
|
||||
@@ -43,6 +42,10 @@ type Config struct {
|
||||
// LoggingToFile controls whether application logs are written to rotating files or stdout.
|
||||
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
|
||||
|
||||
// LogsMaxTotalSizeMB limits the total size (in MB) of log files under the logs directory.
|
||||
// When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable.
|
||||
LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"`
|
||||
|
||||
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
|
||||
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
|
||||
|
||||
@@ -342,6 +345,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Set defaults before unmarshal so that absent keys keep defaults.
|
||||
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
|
||||
cfg.LoggingToFile = false
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
@@ -386,6 +390,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
}
|
||||
|
||||
if cfg.LogsMaxTotalSizeMB < 0 {
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
}
|
||||
|
||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||
syncInlineAccessProvider(&cfg)
|
||||
|
||||
@@ -692,7 +700,7 @@ func sanitizeConfigForPersist(cfg *Config) *Config {
|
||||
}
|
||||
clone := *cfg
|
||||
clone.SDKConfig = cfg.SDKConfig
|
||||
clone.SDKConfig.Access = config.AccessConfig{}
|
||||
clone.SDKConfig.Access = AccessConfig{}
|
||||
return &clone
|
||||
}
|
||||
|
||||
|
||||
87
internal/config/sdk_config.go
Normal file
87
internal/config/sdk_config.go
Normal file
@@ -0,0 +1,87 @@
|
||||
// Package config provides configuration management for the CLI Proxy API server.
|
||||
// It handles loading and parsing YAML configuration files, and provides structured
|
||||
// access to application settings including server port, authentication directory,
|
||||
// debug settings, proxy configuration, and API keys.
|
||||
package config
|
||||
|
||||
// SDKConfig represents the application's configuration, loaded from a YAML file.
|
||||
type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||
// credentials as well.
|
||||
ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`
|
||||
|
||||
// RequestLog enables or disables detailed request logging functionality.
|
||||
RequestLog bool `yaml:"request-log" json:"request-log"`
|
||||
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
)
|
||||
|
||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range c.Access.Providers {
|
||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
||||
if c.Access.Providers[i].Name == "" {
|
||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
||||
}
|
||||
return &c.Access.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
}
|
||||
return provider
|
||||
}
|
||||
@@ -72,39 +72,45 @@ func SetupBaseLogger() {
|
||||
}
|
||||
|
||||
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
|
||||
func ConfigureLogOutput(loggingToFile bool) error {
|
||||
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
|
||||
// until the total size is within the limit.
|
||||
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
||||
SetupBaseLogger()
|
||||
|
||||
writerMu.Lock()
|
||||
defer writerMu.Unlock()
|
||||
|
||||
if loggingToFile {
|
||||
logDir := "logs"
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
}
|
||||
|
||||
protectedPath := ""
|
||||
if loggingToFile {
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return fmt.Errorf("logging: failed to create log directory: %w", err)
|
||||
}
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
}
|
||||
protectedPath = filepath.Join(logDir, "main.log")
|
||||
logWriter = &lumberjack.Logger{
|
||||
Filename: filepath.Join(logDir, "main.log"),
|
||||
Filename: protectedPath,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 0,
|
||||
MaxAge: 0,
|
||||
Compress: false,
|
||||
}
|
||||
log.SetOutput(logWriter)
|
||||
return nil
|
||||
}
|
||||
|
||||
} else {
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
}
|
||||
log.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -112,6 +118,8 @@ func closeLogOutputs() {
|
||||
writerMu.Lock()
|
||||
defer writerMu.Unlock()
|
||||
|
||||
stopLogDirCleanerLocked()
|
||||
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
|
||||
166
internal/logging/log_dir_cleaner.go
Normal file
166
internal/logging/log_dir_cleaner.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const logDirCleanerInterval = time.Minute
|
||||
|
||||
var logDirCleanerCancel context.CancelFunc
|
||||
|
||||
func configureLogDirCleanerLocked(logDir string, maxTotalSizeMB int, protectedPath string) {
|
||||
stopLogDirCleanerLocked()
|
||||
|
||||
if maxTotalSizeMB <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
maxBytes := int64(maxTotalSizeMB) * 1024 * 1024
|
||||
if maxBytes <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dir := strings.TrimSpace(logDir)
|
||||
if dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
logDirCleanerCancel = cancel
|
||||
go runLogDirCleaner(ctx, filepath.Clean(dir), maxBytes, strings.TrimSpace(protectedPath))
|
||||
}
|
||||
|
||||
func stopLogDirCleanerLocked() {
|
||||
if logDirCleanerCancel == nil {
|
||||
return
|
||||
}
|
||||
logDirCleanerCancel()
|
||||
logDirCleanerCancel = nil
|
||||
}
|
||||
|
||||
func runLogDirCleaner(ctx context.Context, logDir string, maxBytes int64, protectedPath string) {
|
||||
ticker := time.NewTicker(logDirCleanerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
cleanOnce := func() {
|
||||
deleted, errClean := enforceLogDirSizeLimit(logDir, maxBytes, protectedPath)
|
||||
if errClean != nil {
|
||||
log.WithError(errClean).Warn("logging: failed to enforce log directory size limit")
|
||||
return
|
||||
}
|
||||
if deleted > 0 {
|
||||
log.Debugf("logging: removed %d old log file(s) to enforce log directory size limit", deleted)
|
||||
}
|
||||
}
|
||||
|
||||
cleanOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cleanOnce()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func enforceLogDirSizeLimit(logDir string, maxBytes int64, protectedPath string) (int, error) {
|
||||
if maxBytes <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
dir := strings.TrimSpace(logDir)
|
||||
if dir == "" {
|
||||
return 0, nil
|
||||
}
|
||||
dir = filepath.Clean(dir)
|
||||
|
||||
entries, errRead := os.ReadDir(dir)
|
||||
if errRead != nil {
|
||||
if os.IsNotExist(errRead) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, errRead
|
||||
}
|
||||
|
||||
protected := strings.TrimSpace(protectedPath)
|
||||
if protected != "" {
|
||||
protected = filepath.Clean(protected)
|
||||
}
|
||||
|
||||
type logFile struct {
|
||||
path string
|
||||
size int64
|
||||
modTime time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
files []logFile
|
||||
total int64
|
||||
)
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !isLogFileName(name) {
|
||||
continue
|
||||
}
|
||||
info, errInfo := entry.Info()
|
||||
if errInfo != nil {
|
||||
continue
|
||||
}
|
||||
if !info.Mode().IsRegular() {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dir, name)
|
||||
files = append(files, logFile{
|
||||
path: path,
|
||||
size: info.Size(),
|
||||
modTime: info.ModTime(),
|
||||
})
|
||||
total += info.Size()
|
||||
}
|
||||
|
||||
if total <= maxBytes {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].modTime.Before(files[j].modTime)
|
||||
})
|
||||
|
||||
deleted := 0
|
||||
for _, file := range files {
|
||||
if total <= maxBytes {
|
||||
break
|
||||
}
|
||||
if protected != "" && filepath.Clean(file.path) == protected {
|
||||
continue
|
||||
}
|
||||
if errRemove := os.Remove(file.path); errRemove != nil {
|
||||
log.WithError(errRemove).Warnf("logging: failed to remove old log file: %s", filepath.Base(file.path))
|
||||
continue
|
||||
}
|
||||
total -= file.size
|
||||
deleted++
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func isLogFileName(name string) bool {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
return strings.HasSuffix(lower, ".log") || strings.HasSuffix(lower, ".log.gz")
|
||||
}
|
||||
70
internal/logging/log_dir_cleaner_test.go
Normal file
70
internal/logging/log_dir_cleaner_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEnforceLogDirSizeLimitDeletesOldest(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
writeLogFile(t, filepath.Join(dir, "old.log"), 60, time.Unix(1, 0))
|
||||
writeLogFile(t, filepath.Join(dir, "mid.log"), 60, time.Unix(2, 0))
|
||||
protected := filepath.Join(dir, "main.log")
|
||||
writeLogFile(t, protected, 60, time.Unix(3, 0))
|
||||
|
||||
deleted, err := enforceLogDirSizeLimit(dir, 120, protected)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("expected 1 deleted file, got %d", deleted)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, "old.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected old.log to be removed, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, "mid.log")); err != nil {
|
||||
t.Fatalf("expected mid.log to remain, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(protected); err != nil {
|
||||
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceLogDirSizeLimitSkipsProtected(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
protected := filepath.Join(dir, "main.log")
|
||||
writeLogFile(t, protected, 200, time.Unix(1, 0))
|
||||
writeLogFile(t, filepath.Join(dir, "other.log"), 50, time.Unix(2, 0))
|
||||
|
||||
deleted, err := enforceLogDirSizeLimit(dir, 100, protected)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("expected 1 deleted file, got %d", deleted)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(protected); err != nil {
|
||||
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, "other.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected other.log to be removed, stat error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeLogFile(t *testing.T, path string, size int, modTime time.Time) {
|
||||
t.Helper()
|
||||
|
||||
data := make([]byte, size)
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
if err := os.Chtimes(path, modTime, modTime); err != nil {
|
||||
t.Fatalf("set times: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"regexp"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/andybalholm/brotli"
|
||||
@@ -25,6 +26,8 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
var requestLogID atomic.Uint64
|
||||
|
||||
// RequestLogger defines the interface for logging HTTP requests and responses.
|
||||
// It provides methods for logging both regular and streaming HTTP request/response cycles.
|
||||
type RequestLogger interface {
|
||||
@@ -204,19 +207,52 @@ func (l *FileRequestLogger) logRequest(url, method string, requestHeaders map[st
|
||||
}
|
||||
filePath := filepath.Join(l.logsDir, filename)
|
||||
|
||||
// Decompress response if needed
|
||||
decompressedResponse, err := l.decompressResponse(responseHeaders, response)
|
||||
if err != nil {
|
||||
// If decompression fails, log the error but continue with original response
|
||||
decompressedResponse = append(response, []byte(fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", err))...)
|
||||
requestBodyPath, errTemp := l.writeRequestBodyTempFile(body)
|
||||
if errTemp != nil {
|
||||
log.WithError(errTemp).Warn("failed to create request body temp file, falling back to direct write")
|
||||
}
|
||||
if requestBodyPath != "" {
|
||||
defer func() {
|
||||
if errRemove := os.Remove(requestBodyPath); errRemove != nil {
|
||||
log.WithError(errRemove).Warn("failed to remove request body temp file")
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// Create log content
|
||||
content := l.formatLogContent(url, method, requestHeaders, body, apiRequest, apiResponse, decompressedResponse, statusCode, responseHeaders, apiResponseErrors)
|
||||
responseToWrite, decompressErr := l.decompressResponse(responseHeaders, response)
|
||||
if decompressErr != nil {
|
||||
// If decompression fails, continue with original response and annotate the log output.
|
||||
responseToWrite = response
|
||||
}
|
||||
|
||||
// Write to file
|
||||
if err = os.WriteFile(filePath, []byte(content), 0644); err != nil {
|
||||
return fmt.Errorf("failed to write log file: %w", err)
|
||||
logFile, errOpen := os.OpenFile(filePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if errOpen != nil {
|
||||
return fmt.Errorf("failed to create log file: %w", errOpen)
|
||||
}
|
||||
|
||||
writeErr := l.writeNonStreamingLog(
|
||||
logFile,
|
||||
url,
|
||||
method,
|
||||
requestHeaders,
|
||||
body,
|
||||
requestBodyPath,
|
||||
apiRequest,
|
||||
apiResponse,
|
||||
apiResponseErrors,
|
||||
statusCode,
|
||||
responseHeaders,
|
||||
responseToWrite,
|
||||
decompressErr,
|
||||
)
|
||||
if errClose := logFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request log file")
|
||||
if writeErr == nil {
|
||||
return errClose
|
||||
}
|
||||
}
|
||||
if writeErr != nil {
|
||||
return fmt.Errorf("failed to write log file: %w", writeErr)
|
||||
}
|
||||
|
||||
if force && !l.enabled {
|
||||
@@ -253,26 +289,38 @@ func (l *FileRequestLogger) LogStreamingRequest(url, method string, headers map[
|
||||
filename := l.generateFilename(url)
|
||||
filePath := filepath.Join(l.logsDir, filename)
|
||||
|
||||
// Create and open file
|
||||
file, err := os.Create(filePath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create log file: %w", err)
|
||||
requestHeaders := make(map[string][]string, len(headers))
|
||||
for key, values := range headers {
|
||||
headerValues := make([]string, len(values))
|
||||
copy(headerValues, values)
|
||||
requestHeaders[key] = headerValues
|
||||
}
|
||||
|
||||
// Write initial request information
|
||||
requestInfo := l.formatRequestInfo(url, method, headers, body)
|
||||
if _, err = file.WriteString(requestInfo); err != nil {
|
||||
_ = file.Close()
|
||||
return nil, fmt.Errorf("failed to write request info: %w", err)
|
||||
requestBodyPath, errTemp := l.writeRequestBodyTempFile(body)
|
||||
if errTemp != nil {
|
||||
return nil, fmt.Errorf("failed to create request body temp file: %w", errTemp)
|
||||
}
|
||||
|
||||
responseBodyFile, errCreate := os.CreateTemp(l.logsDir, "response-body-*.tmp")
|
||||
if errCreate != nil {
|
||||
_ = os.Remove(requestBodyPath)
|
||||
return nil, fmt.Errorf("failed to create response body temp file: %w", errCreate)
|
||||
}
|
||||
responseBodyPath := responseBodyFile.Name()
|
||||
|
||||
// Create streaming writer
|
||||
writer := &FileStreamingLogWriter{
|
||||
file: file,
|
||||
logFilePath: filePath,
|
||||
url: url,
|
||||
method: method,
|
||||
timestamp: time.Now(),
|
||||
requestHeaders: requestHeaders,
|
||||
requestBodyPath: requestBodyPath,
|
||||
responseBodyPath: responseBodyPath,
|
||||
responseBodyFile: responseBodyFile,
|
||||
chunkChan: make(chan []byte, 100), // Buffered channel for async writes
|
||||
closeChan: make(chan struct{}),
|
||||
errorChan: make(chan error, 1),
|
||||
bufferedChunks: &bytes.Buffer{},
|
||||
}
|
||||
|
||||
// Start async writer goroutine
|
||||
@@ -323,7 +371,9 @@ func (l *FileRequestLogger) generateFilename(url string) string {
|
||||
timestamp := time.Now().Format("2006-01-02T150405-.000000000")
|
||||
timestamp = strings.Replace(timestamp, ".", "", -1)
|
||||
|
||||
return fmt.Sprintf("%s-%s.log", sanitized, timestamp)
|
||||
id := requestLogID.Add(1)
|
||||
|
||||
return fmt.Sprintf("%s-%s-%d.log", sanitized, timestamp, id)
|
||||
}
|
||||
|
||||
// sanitizeForFilename replaces characters that are not safe for filenames.
|
||||
@@ -405,6 +455,220 @@ func (l *FileRequestLogger) cleanupOldErrorLogs() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) writeRequestBodyTempFile(body []byte) (string, error) {
|
||||
tmpFile, errCreate := os.CreateTemp(l.logsDir, "request-body-*.tmp")
|
||||
if errCreate != nil {
|
||||
return "", errCreate
|
||||
}
|
||||
tmpPath := tmpFile.Name()
|
||||
|
||||
if _, errCopy := io.Copy(tmpFile, bytes.NewReader(body)); errCopy != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", errCopy
|
||||
}
|
||||
if errClose := tmpFile.Close(); errClose != nil {
|
||||
_ = os.Remove(tmpPath)
|
||||
return "", errClose
|
||||
}
|
||||
return tmpPath, nil
|
||||
}
|
||||
|
||||
func (l *FileRequestLogger) writeNonStreamingLog(
|
||||
w io.Writer,
|
||||
url, method string,
|
||||
requestHeaders map[string][]string,
|
||||
requestBody []byte,
|
||||
requestBodyPath string,
|
||||
apiRequest []byte,
|
||||
apiResponse []byte,
|
||||
apiResponseErrors []*interfaces.ErrorMessage,
|
||||
statusCode int,
|
||||
responseHeaders map[string][]string,
|
||||
response []byte,
|
||||
decompressErr error,
|
||||
) error {
|
||||
if errWrite := writeRequestInfoWithBody(w, url, method, requestHeaders, requestBody, requestBodyPath, time.Now()); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API REQUEST ===\n", "=== API REQUEST", apiRequest); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPIErrorResponses(w, apiResponseErrors); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(w, "=== API RESPONSE ===\n", "=== API RESPONSE", apiResponse); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return writeResponseSection(w, statusCode, true, responseHeaders, bytes.NewReader(response), decompressErr, true)
|
||||
}
|
||||
|
||||
func writeRequestInfoWithBody(
|
||||
w io.Writer,
|
||||
url, method string,
|
||||
headers map[string][]string,
|
||||
body []byte,
|
||||
bodyPath string,
|
||||
timestamp time.Time,
|
||||
) error {
|
||||
if _, errWrite := io.WriteString(w, "=== REQUEST INFO ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Version: %s\n", buildinfo.Version)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("URL: %s\n", url)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Method: %s\n", method)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Timestamp: %s\n", timestamp.Format(time.RFC3339Nano))); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "=== HEADERS ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
for key, values := range headers {
|
||||
for _, value := range values {
|
||||
masked := util.MaskSensitiveHeaderValue(key, value)
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("%s: %s\n", key, masked)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "=== REQUEST BODY ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if bodyPath != "" {
|
||||
bodyFile, errOpen := os.Open(bodyPath)
|
||||
if errOpen != nil {
|
||||
return errOpen
|
||||
}
|
||||
if _, errCopy := io.Copy(w, bodyFile); errCopy != nil {
|
||||
_ = bodyFile.Close()
|
||||
return errCopy
|
||||
}
|
||||
if errClose := bodyFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request body temp file")
|
||||
}
|
||||
} else if _, errWrite := w.Write(body); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAPISection(w io.Writer, sectionHeader string, sectionPrefix string, payload []byte) error {
|
||||
if len(payload) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(payload, []byte(sectionPrefix)) {
|
||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if !bytes.HasSuffix(payload, []byte("\n")) {
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if _, errWrite := io.WriteString(w, sectionHeader); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := w.Write(payload); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeAPIErrorResponses(w io.Writer, apiResponseErrors []*interfaces.ErrorMessage) error {
|
||||
for i := 0; i < len(apiResponseErrors); i++ {
|
||||
if apiResponseErrors[i] == nil {
|
||||
continue
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "=== API ERROR RESPONSE ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("HTTP Status: %d\n", apiResponseErrors[i].StatusCode)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if apiResponseErrors[i].Error != nil {
|
||||
if _, errWrite := io.WriteString(w, apiResponseErrors[i].Error.Error()); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
if _, errWrite := io.WriteString(w, "\n\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func writeResponseSection(w io.Writer, statusCode int, statusWritten bool, responseHeaders map[string][]string, responseReader io.Reader, decompressErr error, trailingNewline bool) error {
|
||||
if _, errWrite := io.WriteString(w, "=== RESPONSE ===\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if statusWritten {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("Status: %d\n", statusCode)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if responseHeaders != nil {
|
||||
for key, values := range responseHeaders {
|
||||
for _, value := range values {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("%s: %s\n", key, value)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
if responseReader != nil {
|
||||
if _, errCopy := io.Copy(w, responseReader); errCopy != nil {
|
||||
return errCopy
|
||||
}
|
||||
}
|
||||
if decompressErr != nil {
|
||||
if _, errWrite := io.WriteString(w, fmt.Sprintf("\n[DECOMPRESSION ERROR: %v]", decompressErr)); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
|
||||
if trailingNewline {
|
||||
if _, errWrite := io.WriteString(w, "\n"); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// formatLogContent creates the complete log content for non-streaming requests.
|
||||
//
|
||||
// Parameters:
|
||||
@@ -648,13 +912,34 @@ func (l *FileRequestLogger) formatRequestInfo(url, method string, headers map[st
|
||||
}
|
||||
|
||||
// FileStreamingLogWriter implements StreamingLogWriter for file-based streaming logs.
|
||||
// It handles asynchronous writing of streaming response chunks to a file.
|
||||
// All data is buffered and written in the correct order when Close is called.
|
||||
// It spools streaming response chunks to a temporary file to avoid retaining large responses in memory.
|
||||
// The final log file is assembled when Close is called.
|
||||
type FileStreamingLogWriter struct {
|
||||
// file is the file where log data is written.
|
||||
file *os.File
|
||||
// logFilePath is the final log file path.
|
||||
logFilePath string
|
||||
|
||||
// chunkChan is a channel for receiving response chunks to buffer.
|
||||
// url is the request URL (masked upstream in middleware).
|
||||
url string
|
||||
|
||||
// method is the HTTP method.
|
||||
method string
|
||||
|
||||
// timestamp is captured when the streaming log is initialized.
|
||||
timestamp time.Time
|
||||
|
||||
// requestHeaders stores the request headers.
|
||||
requestHeaders map[string][]string
|
||||
|
||||
// requestBodyPath is a temporary file path holding the request body.
|
||||
requestBodyPath string
|
||||
|
||||
// responseBodyPath is a temporary file path holding the streaming response body.
|
||||
responseBodyPath string
|
||||
|
||||
// responseBodyFile is the temp file where chunks are appended by the async writer.
|
||||
responseBodyFile *os.File
|
||||
|
||||
// chunkChan is a channel for receiving response chunks to spool.
|
||||
chunkChan chan []byte
|
||||
|
||||
// closeChan is a channel for signaling when the writer is closed.
|
||||
@@ -663,9 +948,6 @@ type FileStreamingLogWriter struct {
|
||||
// errorChan is a channel for reporting errors during writing.
|
||||
errorChan chan error
|
||||
|
||||
// bufferedChunks stores the response chunks in order.
|
||||
bufferedChunks *bytes.Buffer
|
||||
|
||||
// responseStatus stores the HTTP status code.
|
||||
responseStatus int
|
||||
|
||||
@@ -770,85 +1052,115 @@ func (w *FileStreamingLogWriter) Close() error {
|
||||
close(w.chunkChan)
|
||||
}
|
||||
|
||||
// Wait for async writer to finish buffering chunks
|
||||
// Wait for async writer to finish spooling chunks
|
||||
if w.closeChan != nil {
|
||||
<-w.closeChan
|
||||
w.chunkChan = nil
|
||||
}
|
||||
|
||||
if w.file == nil {
|
||||
select {
|
||||
case errWrite := <-w.errorChan:
|
||||
w.cleanupTempFiles()
|
||||
return errWrite
|
||||
default:
|
||||
}
|
||||
|
||||
if w.logFilePath == "" {
|
||||
w.cleanupTempFiles()
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write all content in the correct order
|
||||
var content strings.Builder
|
||||
|
||||
// 1. Write API REQUEST section
|
||||
if len(w.apiRequest) > 0 {
|
||||
if bytes.HasPrefix(w.apiRequest, []byte("=== API REQUEST")) {
|
||||
content.Write(w.apiRequest)
|
||||
if !bytes.HasSuffix(w.apiRequest, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== API REQUEST ===\n")
|
||||
content.Write(w.apiRequest)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
logFile, errOpen := os.OpenFile(w.logFilePath, os.O_CREATE|os.O_WRONLY|os.O_TRUNC, 0644)
|
||||
if errOpen != nil {
|
||||
w.cleanupTempFiles()
|
||||
return fmt.Errorf("failed to create log file: %w", errOpen)
|
||||
}
|
||||
|
||||
// 2. Write API RESPONSE section
|
||||
if len(w.apiResponse) > 0 {
|
||||
if bytes.HasPrefix(w.apiResponse, []byte("=== API RESPONSE")) {
|
||||
content.Write(w.apiResponse)
|
||||
if !bytes.HasSuffix(w.apiResponse, []byte("\n")) {
|
||||
content.WriteString("\n")
|
||||
writeErr := w.writeFinalLog(logFile)
|
||||
if errClose := logFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close request log file")
|
||||
if writeErr == nil {
|
||||
writeErr = errClose
|
||||
}
|
||||
} else {
|
||||
content.WriteString("=== API RESPONSE ===\n")
|
||||
content.Write(w.apiResponse)
|
||||
content.WriteString("\n")
|
||||
}
|
||||
content.WriteString("\n")
|
||||
}
|
||||
|
||||
// 3. Write RESPONSE section (status, headers, buffered chunks)
|
||||
content.WriteString("=== RESPONSE ===\n")
|
||||
if w.statusWritten {
|
||||
content.WriteString(fmt.Sprintf("Status: %d\n", w.responseStatus))
|
||||
}
|
||||
|
||||
for key, values := range w.responseHeaders {
|
||||
for _, value := range values {
|
||||
content.WriteString(fmt.Sprintf("%s: %s\n", key, value))
|
||||
}
|
||||
}
|
||||
content.WriteString("\n")
|
||||
|
||||
// Write buffered response body chunks
|
||||
if w.bufferedChunks != nil && w.bufferedChunks.Len() > 0 {
|
||||
content.Write(w.bufferedChunks.Bytes())
|
||||
}
|
||||
|
||||
// Write the complete content to file
|
||||
if _, err := w.file.WriteString(content.String()); err != nil {
|
||||
_ = w.file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
return w.file.Close()
|
||||
w.cleanupTempFiles()
|
||||
return writeErr
|
||||
}
|
||||
|
||||
// asyncWriter runs in a goroutine to buffer chunks from the channel.
|
||||
// It continuously reads chunks from the channel and buffers them for later writing.
|
||||
// It continuously reads chunks from the channel and appends them to a temp file for later assembly.
|
||||
func (w *FileStreamingLogWriter) asyncWriter() {
|
||||
defer close(w.closeChan)
|
||||
|
||||
for chunk := range w.chunkChan {
|
||||
if w.bufferedChunks != nil {
|
||||
w.bufferedChunks.Write(chunk)
|
||||
if w.responseBodyFile == nil {
|
||||
continue
|
||||
}
|
||||
if _, errWrite := w.responseBodyFile.Write(chunk); errWrite != nil {
|
||||
select {
|
||||
case w.errorChan <- errWrite:
|
||||
default:
|
||||
}
|
||||
if errClose := w.responseBodyFile.Close(); errClose != nil {
|
||||
select {
|
||||
case w.errorChan <- errClose:
|
||||
default:
|
||||
}
|
||||
}
|
||||
w.responseBodyFile = nil
|
||||
}
|
||||
}
|
||||
|
||||
if w.responseBodyFile == nil {
|
||||
return
|
||||
}
|
||||
if errClose := w.responseBodyFile.Close(); errClose != nil {
|
||||
select {
|
||||
case w.errorChan <- errClose:
|
||||
default:
|
||||
}
|
||||
}
|
||||
w.responseBodyFile = nil
|
||||
}
|
||||
|
||||
func (w *FileStreamingLogWriter) writeFinalLog(logFile *os.File) error {
|
||||
if errWrite := writeRequestInfoWithBody(logFile, w.url, w.method, w.requestHeaders, nil, w.requestBodyPath, w.timestamp); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API REQUEST ===\n", "=== API REQUEST", w.apiRequest); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
if errWrite := writeAPISection(logFile, "=== API RESPONSE ===\n", "=== API RESPONSE", w.apiResponse); errWrite != nil {
|
||||
return errWrite
|
||||
}
|
||||
|
||||
responseBodyFile, errOpen := os.Open(w.responseBodyPath)
|
||||
if errOpen != nil {
|
||||
return errOpen
|
||||
}
|
||||
defer func() {
|
||||
if errClose := responseBodyFile.Close(); errClose != nil {
|
||||
log.WithError(errClose).Warn("failed to close response body temp file")
|
||||
}
|
||||
}()
|
||||
|
||||
return writeResponseSection(logFile, w.responseStatus, w.statusWritten, w.responseHeaders, responseBodyFile, nil, false)
|
||||
}
|
||||
|
||||
func (w *FileStreamingLogWriter) cleanupTempFiles() {
|
||||
if w.requestBodyPath != "" {
|
||||
if errRemove := os.Remove(w.requestBodyPath); errRemove != nil {
|
||||
log.WithError(errRemove).Warn("failed to remove request body temp file")
|
||||
}
|
||||
w.requestBodyPath = ""
|
||||
}
|
||||
|
||||
if w.responseBodyPath != "" {
|
||||
if errRemove := os.Remove(w.responseBodyPath); errRemove != nil {
|
||||
log.WithError(errRemove).Warn("failed to remove response body temp file")
|
||||
}
|
||||
w.responseBodyPath = ""
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ import (
|
||||
"crypto/rand"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// GenerateRandomState generates a cryptographically secure random state parameter
|
||||
@@ -19,3 +21,83 @@ func GenerateRandomState() (string, error) {
|
||||
}
|
||||
return hex.EncodeToString(bytes), nil
|
||||
}
|
||||
|
||||
// OAuthCallback captures the parsed OAuth callback parameters.
|
||||
type OAuthCallback struct {
|
||||
Code string
|
||||
State string
|
||||
Error string
|
||||
ErrorDescription string
|
||||
}
|
||||
|
||||
// ParseOAuthCallback extracts OAuth parameters from a callback URL.
|
||||
// It returns nil when the input is empty.
|
||||
func ParseOAuthCallback(input string) (*OAuthCallback, error) {
|
||||
trimmed := strings.TrimSpace(input)
|
||||
if trimmed == "" {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
candidate := trimmed
|
||||
if !strings.Contains(candidate, "://") {
|
||||
if strings.HasPrefix(candidate, "?") {
|
||||
candidate = "http://localhost" + candidate
|
||||
} else if strings.ContainsAny(candidate, "/?#") || strings.Contains(candidate, ":") {
|
||||
candidate = "http://" + candidate
|
||||
} else if strings.Contains(candidate, "=") {
|
||||
candidate = "http://localhost/?" + candidate
|
||||
} else {
|
||||
return nil, fmt.Errorf("invalid callback URL")
|
||||
}
|
||||
}
|
||||
|
||||
parsedURL, err := url.Parse(candidate)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query := parsedURL.Query()
|
||||
code := strings.TrimSpace(query.Get("code"))
|
||||
state := strings.TrimSpace(query.Get("state"))
|
||||
errCode := strings.TrimSpace(query.Get("error"))
|
||||
errDesc := strings.TrimSpace(query.Get("error_description"))
|
||||
|
||||
if parsedURL.Fragment != "" {
|
||||
if fragQuery, errFrag := url.ParseQuery(parsedURL.Fragment); errFrag == nil {
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(fragQuery.Get("code"))
|
||||
}
|
||||
if state == "" {
|
||||
state = strings.TrimSpace(fragQuery.Get("state"))
|
||||
}
|
||||
if errCode == "" {
|
||||
errCode = strings.TrimSpace(fragQuery.Get("error"))
|
||||
}
|
||||
if errDesc == "" {
|
||||
errDesc = strings.TrimSpace(fragQuery.Get("error_description"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if code != "" && state == "" && strings.Contains(code, "#") {
|
||||
parts := strings.SplitN(code, "#", 2)
|
||||
code = parts[0]
|
||||
state = parts[1]
|
||||
}
|
||||
|
||||
if errCode == "" && errDesc != "" {
|
||||
errCode = errDesc
|
||||
errDesc = ""
|
||||
}
|
||||
|
||||
if code == "" && errCode == "" {
|
||||
return nil, fmt.Errorf("callback URL missing code")
|
||||
}
|
||||
|
||||
return &OAuthCallback{
|
||||
Code: code,
|
||||
State: state,
|
||||
Error: errCode,
|
||||
ErrorDescription: errDesc,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -162,6 +162,21 @@ func GetGeminiModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Gemini 3 Flash Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
Object: "model",
|
||||
|
||||
@@ -325,8 +325,8 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
|
||||
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload, req.Model)
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload, req.Model, true)
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
|
||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
||||
payload = applyPayloadConfig(e.cfg, req.Model, payload)
|
||||
|
||||
@@ -7,6 +7,8 @@ import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
@@ -97,6 +99,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -191,6 +194,7 @@ func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -524,6 +528,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
translated = applyPayloadConfigWithRoot(e.cfg, req.Model, "antigravity", "request", translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
@@ -1011,7 +1016,7 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
|
||||
// Use the centralized schema cleaner to handle unsupported keywords,
|
||||
// const->enum conversion, and flattening of types/anyOf.
|
||||
strJSON = util.CleanJSONSchemaForGemini(strJSON)
|
||||
strJSON = util.CleanJSONSchemaForAntigravity(strJSON)
|
||||
|
||||
payload = []byte(strJSON)
|
||||
}
|
||||
@@ -1187,7 +1192,7 @@ func geminiToAntigravity(modelName string, payload []byte, projectID string) []b
|
||||
template, _ = sjson.Set(template, "project", generateProjectID())
|
||||
}
|
||||
template, _ = sjson.Set(template, "requestId", generateRequestID())
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateSessionID())
|
||||
template, _ = sjson.Set(template, "request.sessionId", generateStableSessionID(payload))
|
||||
|
||||
template, _ = sjson.Delete(template, "request.safetySettings")
|
||||
template, _ = sjson.Set(template, "request.toolConfig.functionCallingConfig.mode", "VALIDATED")
|
||||
@@ -1229,6 +1234,23 @@ func generateSessionID() string {
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
|
||||
func generateStableSessionID(payload []byte) string {
|
||||
contents := gjson.GetBytes(payload, "request.contents")
|
||||
if contents.IsArray() {
|
||||
for _, content := range contents.Array() {
|
||||
if content.Get("role").String() == "user" {
|
||||
text := content.Get("parts.0.text").String()
|
||||
if text != "" {
|
||||
h := sha256.Sum256([]byte(text))
|
||||
n := int64(binary.BigEndian.Uint64(h[:8])) & 0x7FFFFFFFFFFFFFFF
|
||||
return "-" + strconv.FormatInt(n, 10)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return generateSessionID()
|
||||
}
|
||||
|
||||
func generateProjectID() string {
|
||||
adjectives := []string{"useful", "bright", "swift", "calm", "bold"}
|
||||
nouns := []string{"fuze", "wave", "spark", "flow", "core"}
|
||||
|
||||
@@ -7,15 +7,40 @@ package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// deriveSessionID generates a stable session ID from the request.
|
||||
// Uses the hash of the first user message to identify the conversation.
|
||||
func deriveSessionID(rawJSON []byte) string {
|
||||
messages := gjson.GetBytes(rawJSON, "messages")
|
||||
if !messages.IsArray() {
|
||||
return ""
|
||||
}
|
||||
for _, msg := range messages.Array() {
|
||||
if msg.Get("role").String() == "user" {
|
||||
content := msg.Get("content").String()
|
||||
if content == "" {
|
||||
// Try to get text from content array
|
||||
content = msg.Get("content.0.text").String()
|
||||
}
|
||||
if content != "" {
|
||||
h := sha256.Sum256([]byte(content))
|
||||
return hex.EncodeToString(h[:16])
|
||||
}
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// ConvertClaudeRequestToAntigravity parses and transforms a Claude Code API request into Gemini CLI API format.
|
||||
// It extracts the model name, system instruction, message contents, and tool declarations
|
||||
@@ -37,7 +62,9 @@ const geminiCLIClaudeThoughtSignature = "skip_thought_signature_validator"
|
||||
// - []byte: The transformed request data in Gemini CLI API format
|
||||
func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _ bool) []byte {
|
||||
rawJSON := bytes.Clone(inputRawJSON)
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// Derive session ID for signature caching
|
||||
sessionID := deriveSessionID(rawJSON)
|
||||
|
||||
// system instruction
|
||||
systemInstructionJSON := ""
|
||||
@@ -64,16 +91,19 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// contents
|
||||
contentsJSON := "[]"
|
||||
hasContents := false
|
||||
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
for i := 0; i < len(messageResults); i++ {
|
||||
numMessages := len(messageResults)
|
||||
for i := 0; i < numMessages; i++ {
|
||||
messageResult := messageResults[i]
|
||||
roleResult := messageResult.Get("role")
|
||||
if roleResult.Type != gjson.String {
|
||||
continue
|
||||
}
|
||||
role := roleResult.String()
|
||||
originalRole := roleResult.String()
|
||||
role := originalRole
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
@@ -82,20 +112,59 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
contentsResult := messageResult.Get("content")
|
||||
if contentsResult.IsArray() {
|
||||
contentResults := contentsResult.Array()
|
||||
for j := 0; j < len(contentResults); j++ {
|
||||
numContents := len(contentResults)
|
||||
var currentMessageThinkingSignature string
|
||||
for j := 0; j < numContents; j++ {
|
||||
contentResult := contentResults[j]
|
||||
contentTypeResult := contentResult.Get("type")
|
||||
if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "thinking" {
|
||||
prompt := contentResult.Get("thinking").String()
|
||||
// Use GetThinkingText to handle wrapped thinking objects
|
||||
thinkingText := util.GetThinkingText(contentResult)
|
||||
signatureResult := contentResult.Get("signature")
|
||||
signature := geminiCLIClaudeThoughtSignature
|
||||
if signatureResult.Exists() {
|
||||
signature = signatureResult.String()
|
||||
clientSignature := ""
|
||||
if signatureResult.Exists() && signatureResult.String() != "" {
|
||||
clientSignature = signatureResult.String()
|
||||
}
|
||||
|
||||
// Always try cached signature first (more reliable than client-provided)
|
||||
// Client may send stale or invalid signatures from different sessions
|
||||
signature := ""
|
||||
if sessionID != "" && thinkingText != "" {
|
||||
if cachedSig := cache.GetCachedSignature(sessionID, thinkingText); cachedSig != "" {
|
||||
signature = cachedSig
|
||||
log.Debugf("Using cached signature for thinking block")
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to client signature only if cache miss and client signature is valid
|
||||
if signature == "" && cache.HasValidSignature(clientSignature) {
|
||||
signature = clientSignature
|
||||
log.Debugf("Using client-provided signature for thinking block")
|
||||
}
|
||||
|
||||
// Store for subsequent tool_use in the same message
|
||||
if cache.HasValidSignature(signature) {
|
||||
currentMessageThinkingSignature = signature
|
||||
}
|
||||
|
||||
|
||||
// Skip trailing unsigned thinking blocks on last assistant message
|
||||
isUnsigned := !cache.HasValidSignature(signature)
|
||||
|
||||
// If unsigned, skip entirely (don't convert to text)
|
||||
// Claude requires assistant messages to start with thinking blocks when thinking is enabled
|
||||
// Converting to text would break this requirement
|
||||
if isUnsigned {
|
||||
// TypeScript plugin approach: drop unsigned thinking blocks entirely
|
||||
log.Debugf("Dropping unsigned thinking block (no valid signature)")
|
||||
continue
|
||||
}
|
||||
|
||||
// Valid signature, send as thought block
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.Set(partJSON, "thought", true)
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
if thinkingText != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", thinkingText)
|
||||
}
|
||||
if signature != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
|
||||
@@ -109,24 +178,47 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
// NOTE: Do NOT inject dummy thinking blocks here.
|
||||
// Antigravity API validates signatures, so dummy values are rejected.
|
||||
// The TypeScript plugin removes unsigned thinking blocks instead of injecting dummies.
|
||||
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
argsResult := contentResult.Get("input")
|
||||
functionID := contentResult.Get("id").String()
|
||||
if gjson.Valid(functionArgs) {
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
|
||||
// Handle both object and string input formats
|
||||
var argsRaw string
|
||||
if argsResult.IsObject() {
|
||||
partJSON := `{}`
|
||||
if !strings.Contains(modelName, "claude") {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", geminiCLIClaudeThoughtSignature)
|
||||
argsRaw = argsResult.Raw
|
||||
} else if argsResult.Type == gjson.String {
|
||||
// Input is a JSON string, parse and validate it
|
||||
parsed := gjson.Parse(argsResult.String())
|
||||
if parsed.IsObject() {
|
||||
argsRaw = parsed.Raw
|
||||
}
|
||||
}
|
||||
|
||||
if argsRaw != "" {
|
||||
partJSON := `{}`
|
||||
|
||||
// Use skip_thought_signature_validator for tool calls without valid thinking signature
|
||||
// This is the approach used in opencode-google-antigravity-auth for Gemini
|
||||
// and also works for Claude through Antigravity API
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
if cache.HasValidSignature(currentMessageThinkingSignature) {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", currentMessageThinkingSignature)
|
||||
} else {
|
||||
// No valid signature - use skip sentinel to bypass validation
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", skipSentinel)
|
||||
}
|
||||
|
||||
if functionID != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
|
||||
}
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsResult.Raw)
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsRaw)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
toolCallID := contentResult.Get("tool_use_id").String()
|
||||
if toolCallID != "" {
|
||||
@@ -180,6 +272,37 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Reorder parts for 'model' role to ensure thinking block is first
|
||||
if role == "model" {
|
||||
partsResult := gjson.Get(clientContentJSON, "parts")
|
||||
if partsResult.IsArray() {
|
||||
parts := partsResult.Array()
|
||||
var thinkingParts []gjson.Result
|
||||
var otherParts []gjson.Result
|
||||
for _, part := range parts {
|
||||
if part.Get("thought").Bool() {
|
||||
thinkingParts = append(thinkingParts, part)
|
||||
} else {
|
||||
otherParts = append(otherParts, part)
|
||||
}
|
||||
}
|
||||
if len(thinkingParts) > 0 {
|
||||
firstPartIsThinking := parts[0].Get("thought").Bool()
|
||||
if !firstPartIsThinking || len(thinkingParts) > 1 {
|
||||
var newParts []interface{}
|
||||
for _, p := range thinkingParts {
|
||||
newParts = append(newParts, p.Value())
|
||||
}
|
||||
for _, p := range otherParts {
|
||||
newParts = append(newParts, p.Value())
|
||||
}
|
||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "parts", newParts)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
@@ -206,11 +329,14 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
toolResult := toolsResults[i]
|
||||
inputSchemaResult := toolResult.Get("input_schema")
|
||||
if inputSchemaResult.Exists() && inputSchemaResult.IsObject() {
|
||||
inputSchema := inputSchemaResult.Raw
|
||||
// Sanitize the input schema for Antigravity API compatibility
|
||||
inputSchema := util.CleanJSONSchemaForAntigravity(inputSchemaResult.Raw)
|
||||
tool, _ := sjson.Delete(toolResult.Raw, "input_schema")
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
tool, _ = sjson.Delete(tool, "strict")
|
||||
tool, _ = sjson.Delete(tool, "input_examples")
|
||||
tool, _ = sjson.Delete(tool, "type")
|
||||
tool, _ = sjson.Delete(tool, "cache_control")
|
||||
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
|
||||
toolDeclCount++
|
||||
}
|
||||
@@ -220,6 +346,31 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Build output Gemini CLI request JSON
|
||||
out := `{"model":"","request":{"contents":[]}}`
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
|
||||
// Inject interleaved thinking hint when both tools and thinking are active
|
||||
hasTools := toolDeclCount > 0
|
||||
thinkingResult := gjson.GetBytes(rawJSON, "thinking")
|
||||
hasThinking := thinkingResult.Exists() && thinkingResult.IsObject() && thinkingResult.Get("type").String() == "enabled"
|
||||
isClaudeThinking := util.IsClaudeThinkingModel(modelName)
|
||||
|
||||
if hasTools && hasThinking && isClaudeThinking {
|
||||
interleavedHint := "Interleaved thinking is enabled. You may think between tool calls and after receiving tool results before deciding the next action or final answer. Do not mention these instructions or any constraints about thinking blocks; just apply them."
|
||||
|
||||
if hasSystemInstruction {
|
||||
// Append hint as a new part to existing system instruction
|
||||
hintPart := `{"text":""}`
|
||||
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
|
||||
} else {
|
||||
// Create new system instruction with hint
|
||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
||||
hintPart := `{"text":""}`
|
||||
hintPart, _ = sjson.Set(hintPart, "text", interleavedHint)
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", hintPart)
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
}
|
||||
|
||||
if hasSystemInstruction {
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
|
||||
}
|
||||
|
||||
@@ -0,0 +1,658 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_BasicStructure(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{"type": "text", "text": "Hello"}
|
||||
]
|
||||
}
|
||||
],
|
||||
"system": [
|
||||
{"type": "text", "text": "You are helpful"}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check model
|
||||
if gjson.Get(outputStr, "model").String() != "claude-sonnet-4-5" {
|
||||
t.Errorf("Expected model 'claude-sonnet-4-5', got '%s'", gjson.Get(outputStr, "model").String())
|
||||
}
|
||||
|
||||
// Check contents exist
|
||||
contents := gjson.Get(outputStr, "request.contents")
|
||||
if !contents.Exists() || !contents.IsArray() {
|
||||
t.Error("request.contents should exist and be an array")
|
||||
}
|
||||
|
||||
// Check role mapping (assistant -> model)
|
||||
firstContent := gjson.Get(outputStr, "request.contents.0")
|
||||
if firstContent.Get("role").String() != "user" {
|
||||
t.Errorf("Expected role 'user', got '%s'", firstContent.Get("role").String())
|
||||
}
|
||||
|
||||
// Check systemInstruction
|
||||
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
|
||||
if !sysInstruction.Exists() {
|
||||
t.Error("systemInstruction should exist")
|
||||
}
|
||||
if sysInstruction.Get("parts.0.text").String() != "You are helpful" {
|
||||
t.Error("systemInstruction text mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_RoleMapping(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hi"}]},
|
||||
{"role": "assistant", "content": [{"type": "text", "text": "Hello"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// assistant should be mapped to model
|
||||
secondContent := gjson.Get(outputStr, "request.contents.1")
|
||||
if secondContent.Get("role").String() != "model" {
|
||||
t.Errorf("Expected role 'model' (mapped from 'assistant'), got '%s'", secondContent.Get("role").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlocks(t *testing.T) {
|
||||
// Valid signature must be at least 50 characters
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
||||
{"type": "text", "text": "Answer"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check thinking block conversion
|
||||
firstPart := gjson.Get(outputStr, "request.contents.0.parts.0")
|
||||
if !firstPart.Get("thought").Bool() {
|
||||
t.Error("thinking block should have thought: true")
|
||||
}
|
||||
if firstPart.Get("text").String() != "Let me think..." {
|
||||
t.Error("thinking text mismatch")
|
||||
}
|
||||
if firstPart.Get("thoughtSignature").String() != validSignature {
|
||||
t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, firstPart.Get("thoughtSignature").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingBlockWithoutSignature(t *testing.T) {
|
||||
// Unsigned thinking blocks should be removed entirely (not converted to text)
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me think..."},
|
||||
{"type": "text", "text": "Answer"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Without signature, thinking block should be removed (not converted to text)
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
|
||||
}
|
||||
|
||||
// Only text part should remain
|
||||
if parts[0].Get("thought").Bool() {
|
||||
t.Error("Thinking block should be removed, not preserved")
|
||||
}
|
||||
if parts[0].Get("text").String() != "Answer" {
|
||||
t.Errorf("Expected text 'Answer', got '%s'", parts[0].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolDeclarations(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [],
|
||||
"tools": [
|
||||
{
|
||||
"name": "test_tool",
|
||||
"description": "A test tool",
|
||||
"input_schema": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"}
|
||||
},
|
||||
"required": ["name"]
|
||||
}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("gemini-1.5-pro", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check tools structure
|
||||
tools := gjson.Get(outputStr, "request.tools")
|
||||
if !tools.Exists() {
|
||||
t.Error("Tools should exist in output")
|
||||
}
|
||||
|
||||
funcDecl := gjson.Get(outputStr, "request.tools.0.functionDeclarations.0")
|
||||
if funcDecl.Get("name").String() != "test_tool" {
|
||||
t.Errorf("Expected tool name 'test_tool', got '%s'", funcDecl.Get("name").String())
|
||||
}
|
||||
|
||||
// Check input_schema renamed to parametersJsonSchema
|
||||
if funcDecl.Get("parametersJsonSchema").Exists() {
|
||||
t.Log("parametersJsonSchema exists (expected)")
|
||||
}
|
||||
if funcDecl.Get("input_schema").Exists() {
|
||||
t.Error("input_schema should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolUse(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": "{\"location\": \"Paris\"}"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Now we expect only 1 part (tool_use), no dummy thinking block injected
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part (tool only, no dummy injection), got %d", len(parts))
|
||||
}
|
||||
|
||||
// Check function call conversion at parts[0]
|
||||
funcCall := parts[0].Get("functionCall")
|
||||
if !funcCall.Exists() {
|
||||
t.Error("functionCall should exist at parts[0]")
|
||||
}
|
||||
if funcCall.Get("name").String() != "get_weather" {
|
||||
t.Errorf("Expected function name 'get_weather', got '%s'", funcCall.Get("name").String())
|
||||
}
|
||||
if funcCall.Get("id").String() != "call_123" {
|
||||
t.Errorf("Expected function id 'call_123', got '%s'", funcCall.Get("id").String())
|
||||
}
|
||||
// Verify skip_thought_signature_validator is added (bypass for tools without valid thinking)
|
||||
expectedSig := "skip_thought_signature_validator"
|
||||
actualSig := parts[0].Get("thoughtSignature").String()
|
||||
if actualSig != expectedSig {
|
||||
t.Errorf("Expected thoughtSignature '%s', got '%s'", expectedSig, actualSig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolUse_WithSignature(t *testing.T) {
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Let me think...", "signature": "` + validSignature + `"},
|
||||
{
|
||||
"type": "tool_use",
|
||||
"id": "call_123",
|
||||
"name": "get_weather",
|
||||
"input": "{\"location\": \"Paris\"}"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check function call has the signature from the preceding thinking block
|
||||
part := gjson.Get(outputStr, "request.contents.0.parts.1")
|
||||
if part.Get("functionCall.name").String() != "get_weather" {
|
||||
t.Errorf("Expected functionCall, got %s", part.Raw)
|
||||
}
|
||||
if part.Get("thoughtSignature").String() != validSignature {
|
||||
t.Errorf("Expected thoughtSignature '%s' on tool_use, got '%s'", validSignature, part.Get("thoughtSignature").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ReorderThinking(t *testing.T) {
|
||||
// Case: text block followed by thinking block -> should be reordered to thinking first
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Here is the plan."},
|
||||
{"type": "thinking", "thinking": "Planning...", "signature": "` + validSignature + `"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Verify order: Thinking block MUST be first
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
if !parts[0].Get("thought").Bool() {
|
||||
t.Error("First part should be thinking block after reordering")
|
||||
}
|
||||
if parts[1].Get("text").String() != "Here is the plan." {
|
||||
t.Error("Second part should be text block")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolResult(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "tool_result",
|
||||
"tool_use_id": "get_weather-call-123",
|
||||
"content": "22C sunny"
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check function response conversion
|
||||
funcResp := gjson.Get(outputStr, "request.contents.0.parts.0.functionResponse")
|
||||
if !funcResp.Exists() {
|
||||
t.Error("functionResponse should exist")
|
||||
}
|
||||
if funcResp.Get("id").String() != "get_weather-call-123" {
|
||||
t.Errorf("Expected function id, got '%s'", funcResp.Get("id").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingConfig(t *testing.T) {
|
||||
// Note: This test requires the model to be registered in the registry
|
||||
// with Thinking metadata. If the registry is not populated in test environment,
|
||||
// thinkingConfig won't be added. We'll test the basic structure only.
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [],
|
||||
"thinking": {
|
||||
"type": "enabled",
|
||||
"budget_tokens": 8000
|
||||
}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check thinking config conversion (only if model supports thinking in registry)
|
||||
thinkingConfig := gjson.Get(outputStr, "request.generationConfig.thinkingConfig")
|
||||
if thinkingConfig.Exists() {
|
||||
if thinkingConfig.Get("thinkingBudget").Int() != 8000 {
|
||||
t.Errorf("Expected thinkingBudget 8000, got %d", thinkingConfig.Get("thinkingBudget").Int())
|
||||
}
|
||||
if !thinkingConfig.Get("include_thoughts").Bool() {
|
||||
t.Error("include_thoughts should be true")
|
||||
}
|
||||
} else {
|
||||
t.Log("thinkingConfig not present - model may not be registered in test registry")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ImageContent(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [
|
||||
{
|
||||
"type": "image",
|
||||
"source": {
|
||||
"type": "base64",
|
||||
"media_type": "image/png",
|
||||
"data": "iVBORw0KGgoAAAANSUhEUg=="
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check inline data conversion
|
||||
inlineData := gjson.Get(outputStr, "request.contents.0.parts.0.inlineData")
|
||||
if !inlineData.Exists() {
|
||||
t.Error("inlineData should exist")
|
||||
}
|
||||
if inlineData.Get("mime_type").String() != "image/png" {
|
||||
t.Error("mime_type mismatch")
|
||||
}
|
||||
if !strings.Contains(inlineData.Get("data").String(), "iVBORw0KGgo") {
|
||||
t.Error("data mismatch")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_GenerationConfig(t *testing.T) {
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-3-5-sonnet-20240620",
|
||||
"messages": [],
|
||||
"temperature": 0.7,
|
||||
"top_p": 0.9,
|
||||
"top_k": 40,
|
||||
"max_tokens": 2000
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
genConfig := gjson.Get(outputStr, "request.generationConfig")
|
||||
if genConfig.Get("temperature").Float() != 0.7 {
|
||||
t.Errorf("Expected temperature 0.7, got %f", genConfig.Get("temperature").Float())
|
||||
}
|
||||
if genConfig.Get("topP").Float() != 0.9 {
|
||||
t.Errorf("Expected topP 0.9, got %f", genConfig.Get("topP").Float())
|
||||
}
|
||||
if genConfig.Get("topK").Float() != 40 {
|
||||
t.Errorf("Expected topK 40, got %f", genConfig.Get("topK").Float())
|
||||
}
|
||||
if genConfig.Get("maxOutputTokens").Float() != 2000 {
|
||||
t.Errorf("Expected maxOutputTokens 2000, got %f", genConfig.Get("maxOutputTokens").Float())
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Trailing Unsigned Thinking Block Removal
|
||||
// ============================================================================
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_TrailingUnsignedThinking_Removed(t *testing.T) {
|
||||
// Last assistant message ends with unsigned thinking block - should be removed
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Here is my answer"},
|
||||
{"type": "thinking", "thinking": "I should think more..."}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// The last part of the last assistant message should NOT be a thinking block
|
||||
lastMessageParts := gjson.Get(outputStr, "request.contents.1.parts")
|
||||
if !lastMessageParts.IsArray() {
|
||||
t.Fatal("Last message should have parts array")
|
||||
}
|
||||
parts := lastMessageParts.Array()
|
||||
if len(parts) == 0 {
|
||||
t.Fatal("Last message should have at least one part")
|
||||
}
|
||||
|
||||
// The unsigned thinking should be removed, leaving only the text
|
||||
lastPart := parts[len(parts)-1]
|
||||
if lastPart.Get("thought").Bool() {
|
||||
t.Error("Trailing unsigned thinking block should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_TrailingSignedThinking_Kept(t *testing.T) {
|
||||
// Last assistant message ends with signed thinking block - should be kept
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Hello"}]
|
||||
},
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "text", "text": "Here is my answer"},
|
||||
{"type": "thinking", "thinking": "Valid thinking...", "signature": "abc123validSignature1234567890123456789012345678901234567890"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// The signed thinking block should be preserved
|
||||
lastMessageParts := gjson.Get(outputStr, "request.contents.1.parts")
|
||||
parts := lastMessageParts.Array()
|
||||
if len(parts) < 2 {
|
||||
t.Error("Signed thinking block should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_MiddleUnsignedThinking_Removed(t *testing.T) {
|
||||
// Middle message has unsigned thinking - should be removed entirely
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": [
|
||||
{"type": "thinking", "thinking": "Middle thinking..."},
|
||||
{"type": "text", "text": "Answer"}
|
||||
]
|
||||
},
|
||||
{
|
||||
"role": "user",
|
||||
"content": [{"type": "text", "text": "Follow up"}]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Unsigned thinking should be removed entirely
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
|
||||
}
|
||||
|
||||
// Only text part should remain
|
||||
if parts[0].Get("thought").Bool() {
|
||||
t.Error("Thinking block should be removed, not preserved")
|
||||
}
|
||||
if parts[0].Get("text").String() != "Answer" {
|
||||
t.Errorf("Expected text 'Answer', got '%s'", parts[0].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Tool + Thinking System Hint Injection
|
||||
// ============================================================================
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_HintInjected(t *testing.T) {
|
||||
// When both tools and thinking are enabled, hint should be injected into system instruction
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"system": [{"type": "text", "text": "You are helpful."}],
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
|
||||
}
|
||||
],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 8000}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// System instruction should contain the interleaved thinking hint
|
||||
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
|
||||
if !sysInstruction.Exists() {
|
||||
t.Fatal("systemInstruction should exist")
|
||||
}
|
||||
|
||||
// Check if hint is appended
|
||||
sysText := sysInstruction.Get("parts").Array()
|
||||
found := false
|
||||
for _, part := range sysText {
|
||||
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Interleaved thinking hint should be injected when tools and thinking are both active, got: %v", sysInstruction.Raw)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolsOnly_NoHint(t *testing.T) {
|
||||
// When only tools are present (no thinking), hint should NOT be injected
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"system": [{"type": "text", "text": "You are helpful."}],
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// System instruction should NOT contain the hint
|
||||
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
|
||||
if sysInstruction.Exists() {
|
||||
for _, part := range sysInstruction.Get("parts").Array() {
|
||||
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
|
||||
t.Error("Hint should NOT be injected when only tools are present (no thinking)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ThinkingOnly_NoHint(t *testing.T) {
|
||||
// When only thinking is enabled (no tools), hint should NOT be injected
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"system": [{"type": "text", "text": "You are helpful."}],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 8000}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// System instruction should NOT contain the hint (no tools)
|
||||
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
|
||||
if sysInstruction.Exists() {
|
||||
for _, part := range sysInstruction.Get("parts").Array() {
|
||||
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
|
||||
t.Error("Hint should NOT be injected when only thinking is present (no tools)")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertClaudeRequestToAntigravity_ToolAndThinking_NoExistingSystem(t *testing.T) {
|
||||
// When tools + thinking but no system instruction, should create one with hint
|
||||
inputJSON := []byte(`{
|
||||
"model": "claude-sonnet-4-5-thinking",
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}],
|
||||
"tools": [
|
||||
{
|
||||
"name": "get_weather",
|
||||
"description": "Get weather",
|
||||
"input_schema": {"type": "object", "properties": {"location": {"type": "string"}}}
|
||||
}
|
||||
],
|
||||
"thinking": {"type": "enabled", "budget_tokens": 8000}
|
||||
}`)
|
||||
|
||||
output := ConvertClaudeRequestToAntigravity("claude-sonnet-4-5-thinking", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// System instruction should be created with hint
|
||||
sysInstruction := gjson.Get(outputStr, "request.systemInstruction")
|
||||
if !sysInstruction.Exists() {
|
||||
t.Fatal("systemInstruction should be created when tools + thinking are active")
|
||||
}
|
||||
|
||||
sysText := sysInstruction.Get("parts").Array()
|
||||
found := false
|
||||
for _, part := range sysText {
|
||||
if strings.Contains(part.Get("text").String(), "Interleaved thinking is enabled") {
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
if !found {
|
||||
t.Errorf("Interleaved thinking hint should be in created systemInstruction, got: %v", sysInstruction.Raw)
|
||||
}
|
||||
}
|
||||
@@ -14,7 +14,9 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
@@ -37,6 +39,10 @@ type Params struct {
|
||||
HasSentFinalEvents bool // Indicates if final content/message events have been sent
|
||||
HasToolUse bool // Indicates if tool use was observed in the stream
|
||||
HasContent bool // Tracks whether any content (text, thinking, or tool use) has been output
|
||||
|
||||
// Signature caching support
|
||||
SessionID string // Session ID derived from request for signature caching
|
||||
CurrentThinkingText strings.Builder // Accumulates thinking text for signature caching
|
||||
}
|
||||
|
||||
// toolUseIDCounter provides a process-wide unique counter for tool use identifiers.
|
||||
@@ -64,6 +70,7 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
HasFirstResponse: false,
|
||||
ResponseType: 0,
|
||||
ResponseIndex: 0,
|
||||
SessionID: deriveSessionID(originalRequestRawJSON),
|
||||
}
|
||||
}
|
||||
|
||||
@@ -121,11 +128,20 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
// Process thinking content (internal reasoning)
|
||||
if partResult.Get("thought").Bool() {
|
||||
if thoughtSignature := partResult.Get("thoughtSignature"); thoughtSignature.Exists() && thoughtSignature.String() != "" {
|
||||
log.Debug("Branch: signature_delta")
|
||||
|
||||
if params.SessionID != "" && params.CurrentThinkingText.Len() > 0 {
|
||||
cache.CacheSignature(params.SessionID, params.CurrentThinkingText.String(), thoughtSignature.String())
|
||||
log.Debugf("Cached signature for thinking block (sessionID=%s, textLen=%d)", params.SessionID, params.CurrentThinkingText.Len())
|
||||
params.CurrentThinkingText.Reset()
|
||||
}
|
||||
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"signature_delta","signature":""}}`, params.ResponseIndex), "delta.signature", thoughtSignature.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.HasContent = true
|
||||
} else if params.ResponseType == 2 { // Continue existing thinking block if already in thinking state
|
||||
params.CurrentThinkingText.WriteString(partTextResult.String())
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"thinking_delta","thinking":""}}`, params.ResponseIndex), "delta.thinking", partTextResult.String())
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
@@ -154,6 +170,9 @@ func ConvertAntigravityResponseToClaude(_ context.Context, _ string, originalReq
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
params.ResponseType = 2 // Set state to thinking
|
||||
params.HasContent = true
|
||||
// Start accumulating thinking text for signature caching
|
||||
params.CurrentThinkingText.Reset()
|
||||
params.CurrentThinkingText.WriteString(partTextResult.String())
|
||||
}
|
||||
} else {
|
||||
finishReasonResult := gjson.GetBytes(rawJSON, "response.candidates.0.finishReason")
|
||||
|
||||
@@ -0,0 +1,316 @@
|
||||
package claude
|
||||
|
||||
import (
|
||||
"context"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cache"
|
||||
)
|
||||
|
||||
// ============================================================================
|
||||
// Signature Caching Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestConvertAntigravityResponseToClaude_SessionIDDerived(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
// Request with user message - should derive session ID
|
||||
requestJSON := []byte(`{
|
||||
"messages": [
|
||||
{"role": "user", "content": [{"type": "text", "text": "Hello world"}]}
|
||||
]
|
||||
}`)
|
||||
|
||||
// First response chunk with thinking
|
||||
responseJSON := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "Let me think...", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
var param any
|
||||
ctx := context.Background()
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, responseJSON, ¶m)
|
||||
|
||||
// Verify session ID was set
|
||||
params := param.(*Params)
|
||||
if params.SessionID == "" {
|
||||
t.Error("SessionID should be derived from request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertAntigravityResponseToClaude_ThinkingTextAccumulated(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
requestJSON := []byte(`{
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Test"}]}]
|
||||
}`)
|
||||
|
||||
// First thinking chunk
|
||||
chunk1 := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "First part of thinking...", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
// Second thinking chunk (continuation)
|
||||
chunk2 := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": " Second part of thinking...", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
var param any
|
||||
ctx := context.Background()
|
||||
|
||||
// Process first chunk - starts new thinking block
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk1, ¶m)
|
||||
params := param.(*Params)
|
||||
|
||||
if params.CurrentThinkingText.Len() == 0 {
|
||||
t.Error("Thinking text should be accumulated after first chunk")
|
||||
}
|
||||
|
||||
// Process second chunk - continues thinking block
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, chunk2, ¶m)
|
||||
|
||||
text := params.CurrentThinkingText.String()
|
||||
if !strings.Contains(text, "First part") || !strings.Contains(text, "Second part") {
|
||||
t.Errorf("Thinking text should accumulate both parts, got: %s", text)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertAntigravityResponseToClaude_SignatureCached(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
requestJSON := []byte(`{
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Cache test"}]}]
|
||||
}`)
|
||||
|
||||
// Thinking chunk
|
||||
thinkingChunk := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "My thinking process here", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
// Signature chunk
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
signatureChunk := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSignature + `"}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
var param any
|
||||
ctx := context.Background()
|
||||
|
||||
// Process thinking chunk
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, thinkingChunk, ¶m)
|
||||
params := param.(*Params)
|
||||
sessionID := params.SessionID
|
||||
thinkingText := params.CurrentThinkingText.String()
|
||||
|
||||
if sessionID == "" {
|
||||
t.Fatal("SessionID should be set")
|
||||
}
|
||||
if thinkingText == "" {
|
||||
t.Fatal("Thinking text should be accumulated")
|
||||
}
|
||||
|
||||
// Process signature chunk - should cache the signature
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, signatureChunk, ¶m)
|
||||
|
||||
// Verify signature was cached
|
||||
cachedSig := cache.GetCachedSignature(sessionID, thinkingText)
|
||||
if cachedSig != validSignature {
|
||||
t.Errorf("Expected cached signature '%s', got '%s'", validSignature, cachedSig)
|
||||
}
|
||||
|
||||
// Verify thinking text was reset after caching
|
||||
if params.CurrentThinkingText.Len() != 0 {
|
||||
t.Error("Thinking text should be reset after signature is cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertAntigravityResponseToClaude_MultipleThinkingBlocks(t *testing.T) {
|
||||
cache.ClearSignatureCache("")
|
||||
|
||||
requestJSON := []byte(`{
|
||||
"messages": [{"role": "user", "content": [{"type": "text", "text": "Multi block test"}]}]
|
||||
}`)
|
||||
|
||||
validSig1 := "signature1_12345678901234567890123456789012345678901234567"
|
||||
validSig2 := "signature2_12345678901234567890123456789012345678901234567"
|
||||
|
||||
// First thinking block with signature
|
||||
block1Thinking := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "First thinking block", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
block1Sig := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig1 + `"}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
// Text content (breaks thinking)
|
||||
textBlock := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "Regular text output"}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
// Second thinking block with signature
|
||||
block2Thinking := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "Second thinking block", "thought": true}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
block2Sig := []byte(`{
|
||||
"response": {
|
||||
"candidates": [{
|
||||
"content": {
|
||||
"parts": [{"text": "", "thought": true, "thoughtSignature": "` + validSig2 + `"}]
|
||||
}
|
||||
}]
|
||||
}
|
||||
}`)
|
||||
|
||||
var param any
|
||||
ctx := context.Background()
|
||||
|
||||
// Process first thinking block
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Thinking, ¶m)
|
||||
params := param.(*Params)
|
||||
sessionID := params.SessionID
|
||||
firstThinkingText := params.CurrentThinkingText.String()
|
||||
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block1Sig, ¶m)
|
||||
|
||||
// Verify first signature cached
|
||||
if cache.GetCachedSignature(sessionID, firstThinkingText) != validSig1 {
|
||||
t.Error("First thinking block signature should be cached")
|
||||
}
|
||||
|
||||
// Process text (transitions out of thinking)
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, textBlock, ¶m)
|
||||
|
||||
// Process second thinking block
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Thinking, ¶m)
|
||||
secondThinkingText := params.CurrentThinkingText.String()
|
||||
|
||||
ConvertAntigravityResponseToClaude(ctx, "claude-sonnet-4-5-thinking", requestJSON, requestJSON, block2Sig, ¶m)
|
||||
|
||||
// Verify second signature cached
|
||||
if cache.GetCachedSignature(sessionID, secondThinkingText) != validSig2 {
|
||||
t.Error("Second thinking block signature should be cached")
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
input []byte
|
||||
wantEmpty bool
|
||||
}{
|
||||
{
|
||||
name: "valid user message",
|
||||
input: []byte(`{"messages": [{"role": "user", "content": "Hello"}]}`),
|
||||
wantEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "user message with content array",
|
||||
input: []byte(`{"messages": [{"role": "user", "content": [{"type": "text", "text": "Hello"}]}]}`),
|
||||
wantEmpty: false,
|
||||
},
|
||||
{
|
||||
name: "no user message",
|
||||
input: []byte(`{"messages": [{"role": "assistant", "content": "Hi"}]}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "empty messages",
|
||||
input: []byte(`{"messages": []}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
{
|
||||
name: "no messages field",
|
||||
input: []byte(`{}`),
|
||||
wantEmpty: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := deriveSessionID(tt.input)
|
||||
if tt.wantEmpty && result != "" {
|
||||
t.Errorf("Expected empty session ID, got '%s'", result)
|
||||
}
|
||||
if !tt.wantEmpty && result == "" {
|
||||
t.Error("Expected non-empty session ID")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest_Deterministic(t *testing.T) {
|
||||
input := []byte(`{"messages": [{"role": "user", "content": "Same message"}]}`)
|
||||
|
||||
id1 := deriveSessionID(input)
|
||||
id2 := deriveSessionID(input)
|
||||
|
||||
if id1 != id2 {
|
||||
t.Errorf("Session ID should be deterministic: '%s' != '%s'", id1, id2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDeriveSessionIDFromRequest_DifferentMessages(t *testing.T) {
|
||||
input1 := []byte(`{"messages": [{"role": "user", "content": "Message A"}]}`)
|
||||
input2 := []byte(`{"messages": [{"role": "user", "content": "Message B"}]}`)
|
||||
|
||||
id1 := deriveSessionID(input1)
|
||||
id2 := deriveSessionID(input2)
|
||||
|
||||
if id1 == id2 {
|
||||
t.Error("Different messages should produce different session IDs")
|
||||
}
|
||||
}
|
||||
@@ -98,16 +98,34 @@ func ConvertGeminiRequestToAntigravity(_ string, inputRawJSON []byte, _ bool) []
|
||||
}
|
||||
}
|
||||
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(key, content gjson.Result) bool {
|
||||
// Gemini-specific handling: add skip_thought_signature_validator to functionCall parts
|
||||
// and remove thinking blocks entirely (Gemini doesn't need to preserve them)
|
||||
const skipSentinel = "skip_thought_signature_validator"
|
||||
|
||||
gjson.GetBytes(rawJSON, "request.contents").ForEach(func(contentIdx, content gjson.Result) bool {
|
||||
if content.Get("role").String() == "model" {
|
||||
content.Get("parts").ForEach(func(partKey, part gjson.Result) bool {
|
||||
// First pass: collect indices of thinking parts to remove
|
||||
var thinkingIndicesToRemove []int64
|
||||
content.Get("parts").ForEach(func(partIdx, part gjson.Result) bool {
|
||||
// Mark thinking blocks for removal
|
||||
if part.Get("thought").Bool() {
|
||||
thinkingIndicesToRemove = append(thinkingIndicesToRemove, partIdx.Int())
|
||||
}
|
||||
// Add skip sentinel to functionCall parts
|
||||
if part.Get("functionCall").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
} else if part.Get("thoughtSignature").Exists() {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", key.Int(), partKey.Int()), "skip_thought_signature_validator")
|
||||
existingSig := part.Get("thoughtSignature").String()
|
||||
if existingSig == "" || len(existingSig) < 50 {
|
||||
rawJSON, _ = sjson.SetBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d.thoughtSignature", contentIdx.Int(), partIdx.Int()), skipSentinel)
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
// Remove thinking blocks in reverse order to preserve indices
|
||||
for i := len(thinkingIndicesToRemove) - 1; i >= 0; i-- {
|
||||
idx := thinkingIndicesToRemove[i]
|
||||
rawJSON, _ = sjson.DeleteBytes(rawJSON, fmt.Sprintf("request.contents.%d.parts.%d", contentIdx.Int(), idx))
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -0,0 +1,129 @@
|
||||
package gemini
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_PreserveValidSignature(t *testing.T) {
|
||||
// Valid signature on functionCall should be preserved
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(fmt.Sprintf(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "test_tool", "args": {}}, "thoughtSignature": "%s"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`, validSignature))
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check that valid thoughtSignature is preserved
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part, got %d", len(parts))
|
||||
}
|
||||
|
||||
sig := parts[0].Get("thoughtSignature").String()
|
||||
if sig != validSignature {
|
||||
t.Errorf("Expected thoughtSignature '%s', got '%s'", validSignature, sig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_AddSkipSentinelToFunctionCall(t *testing.T) {
|
||||
// functionCall without signature should get skip_thought_signature_validator
|
||||
inputJSON := []byte(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "test_tool", "args": {}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check that skip_thought_signature_validator is added to functionCall
|
||||
sig := gjson.Get(outputStr, "request.contents.0.parts.0.thoughtSignature").String()
|
||||
expectedSig := "skip_thought_signature_validator"
|
||||
if sig != expectedSig {
|
||||
t.Errorf("Expected skip sentinel '%s', got '%s'", expectedSig, sig)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_RemoveThinkingBlocks(t *testing.T) {
|
||||
// Thinking blocks should be removed entirely for Gemini
|
||||
validSignature := "abc123validSignature1234567890123456789012345678901234567890"
|
||||
inputJSON := []byte(fmt.Sprintf(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"thought": true, "text": "Thinking...", "thoughtSignature": "%s"},
|
||||
{"text": "Here is my response"}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`, validSignature))
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
// Check that thinking block is removed
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 1 {
|
||||
t.Fatalf("Expected 1 part (thinking removed), got %d", len(parts))
|
||||
}
|
||||
|
||||
// Only text part should remain
|
||||
if parts[0].Get("thought").Bool() {
|
||||
t.Error("Thinking block should be removed for Gemini")
|
||||
}
|
||||
if parts[0].Get("text").String() != "Here is my response" {
|
||||
t.Errorf("Expected text 'Here is my response', got '%s'", parts[0].Get("text").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestConvertGeminiRequestToAntigravity_ParallelFunctionCalls(t *testing.T) {
|
||||
// Multiple functionCalls should all get skip_thought_signature_validator
|
||||
inputJSON := []byte(`{
|
||||
"model": "gemini-3-pro-preview",
|
||||
"contents": [
|
||||
{
|
||||
"role": "model",
|
||||
"parts": [
|
||||
{"functionCall": {"name": "tool_one", "args": {"a": "1"}}},
|
||||
{"functionCall": {"name": "tool_two", "args": {"b": "2"}}}
|
||||
]
|
||||
}
|
||||
]
|
||||
}`)
|
||||
|
||||
output := ConvertGeminiRequestToAntigravity("gemini-3-pro-preview", inputJSON, false)
|
||||
outputStr := string(output)
|
||||
|
||||
parts := gjson.Get(outputStr, "request.contents.0.parts").Array()
|
||||
if len(parts) != 2 {
|
||||
t.Fatalf("Expected 2 parts, got %d", len(parts))
|
||||
}
|
||||
|
||||
expectedSig := "skip_thought_signature_validator"
|
||||
for i, part := range parts {
|
||||
sig := part.Get("thoughtSignature").String()
|
||||
if sig != expectedSig {
|
||||
t.Errorf("Part %d: Expected '%s', got '%s'", i, expectedSig, sig)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -39,8 +39,23 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, re.String())
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(re.String()))
|
||||
if util.IsGemini3Model(modelName) {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig")
|
||||
case "auto":
|
||||
includeThoughts := true
|
||||
out = util.ApplyGeminiCLIThinkingLevel(out, "", &includeThoughts)
|
||||
default:
|
||||
if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok {
|
||||
out = util.ApplyGeminiCLIThinkingLevel(out, level, nil)
|
||||
}
|
||||
}
|
||||
} else if !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, effort)
|
||||
}
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
|
||||
@@ -95,7 +95,7 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
}
|
||||
}
|
||||
// response.created
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"instructions":""}}`
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
|
||||
created, _ = sjson.Set(created, "sequence_number", nextSeq())
|
||||
created, _ = sjson.Set(created, "response.id", st.ResponseID)
|
||||
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
|
||||
@@ -197,11 +197,11 @@ func ConvertClaudeResponseToOpenAIResponses(ctx context.Context, modelName strin
|
||||
if st.ReasoningActive {
|
||||
if t := d.Get("thinking"); t.Exists() {
|
||||
st.ReasoningBuf.WriteString(t.String())
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
|
||||
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
|
||||
msg, _ = sjson.Set(msg, "text", t.String())
|
||||
msg, _ = sjson.Set(msg, "delta", t.String())
|
||||
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
|
||||
}
|
||||
}
|
||||
|
||||
@@ -134,6 +134,8 @@ func ConvertClaudeRequestToCLI(modelName string, inputRawJSON []byte, _ bool) []
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
tool, _ = sjson.Delete(tool, "strict")
|
||||
tool, _ = sjson.Delete(tool, "input_examples")
|
||||
tool, _ = sjson.Delete(tool, "type")
|
||||
tool, _ = sjson.Delete(tool, "cache_control")
|
||||
var toolDeclaration any
|
||||
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
|
||||
|
||||
@@ -127,6 +127,8 @@ func ConvertClaudeRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
tool, _ = sjson.Delete(tool, "strict")
|
||||
tool, _ = sjson.Delete(tool, "input_examples")
|
||||
tool, _ = sjson.Delete(tool, "type")
|
||||
tool, _ = sjson.Delete(tool, "cache_control")
|
||||
var toolDeclaration any
|
||||
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
|
||||
|
||||
@@ -37,12 +37,28 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Reasoning effort -> thinkingBudget/include_thoughts
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
// Only convert for models that use numeric budgets (not discrete levels) to avoid
|
||||
// incorrectly applying thinkingBudget for level-based models like gpt-5.
|
||||
// Only apply numeric budgets for models that use budgets (not discrete levels) to avoid
|
||||
// incorrectly applying thinkingBudget for level-based models like gpt-5. Gemini 3 models
|
||||
// use thinkingLevel/includeThoughts instead.
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGemini(out, re.String())
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(re.String()))
|
||||
if util.IsGemini3Model(modelName) {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "generationConfig.thinkingConfig")
|
||||
case "auto":
|
||||
includeThoughts := true
|
||||
out = util.ApplyGeminiThinkingLevel(out, "", &includeThoughts)
|
||||
default:
|
||||
if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok {
|
||||
out = util.ApplyGeminiThinkingLevel(out, level, nil)
|
||||
}
|
||||
}
|
||||
} else if !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGemini(out, effort)
|
||||
}
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
|
||||
@@ -117,7 +117,7 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
st.CreatedAt = time.Now().Unix()
|
||||
}
|
||||
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null}}`
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
|
||||
created, _ = sjson.Set(created, "sequence_number", nextSeq())
|
||||
created, _ = sjson.Set(created, "response.id", st.ResponseID)
|
||||
created, _ = sjson.Set(created, "response.created_at", st.CreatedAt)
|
||||
@@ -160,11 +160,11 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
}
|
||||
if t := part.Get("text"); t.Exists() && t.String() != "" {
|
||||
st.ReasoningBuf.WriteString(t.String())
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.ReasoningItemID)
|
||||
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
|
||||
msg, _ = sjson.Set(msg, "text", t.String())
|
||||
msg, _ = sjson.Set(msg, "delta", t.String())
|
||||
out = append(out, emitEvent("response.reasoning_summary_text.delta", msg))
|
||||
}
|
||||
return true
|
||||
|
||||
@@ -143,7 +143,7 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
st.ReasoningTokens = 0
|
||||
st.UsageSeen = false
|
||||
// response.created
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null}}`
|
||||
created := `{"type":"response.created","sequence_number":0,"response":{"id":"","object":"response","created_at":0,"status":"in_progress","background":false,"error":null,"output":[]}}`
|
||||
created, _ = sjson.Set(created, "sequence_number", nextSeq())
|
||||
created, _ = sjson.Set(created, "response.id", st.ResponseID)
|
||||
created, _ = sjson.Set(created, "response.created_at", st.Created)
|
||||
@@ -216,11 +216,11 @@ func ConvertOpenAIChatCompletionsResponseToOpenAIResponses(ctx context.Context,
|
||||
}
|
||||
// Append incremental text to reasoning buffer
|
||||
st.ReasoningBuf.WriteString(rc.String())
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"text":""}`
|
||||
msg := `{"type":"response.reasoning_summary_text.delta","sequence_number":0,"item_id":"","output_index":0,"summary_index":0,"delta":""}`
|
||||
msg, _ = sjson.Set(msg, "sequence_number", nextSeq())
|
||||
msg, _ = sjson.Set(msg, "item_id", st.ReasoningID)
|
||||
msg, _ = sjson.Set(msg, "output_index", st.ReasoningIndex)
|
||||
msg, _ = sjson.Set(msg, "text", rc.String())
|
||||
msg, _ = sjson.Set(msg, "delta", rc.String())
|
||||
out = append(out, emitRespEvent("response.reasoning_summary_text.delta", msg))
|
||||
}
|
||||
|
||||
|
||||
10
internal/util/claude_model.go
Normal file
10
internal/util/claude_model.go
Normal file
@@ -0,0 +1,10 @@
|
||||
package util
|
||||
|
||||
import "strings"
|
||||
|
||||
// IsClaudeThinkingModel checks if the model is a Claude thinking model
|
||||
// that requires the interleaved-thinking beta header.
|
||||
func IsClaudeThinkingModel(model string) bool {
|
||||
lower := strings.ToLower(model)
|
||||
return strings.Contains(lower, "claude") && strings.Contains(lower, "thinking")
|
||||
}
|
||||
41
internal/util/claude_model_test.go
Normal file
41
internal/util/claude_model_test.go
Normal file
@@ -0,0 +1,41 @@
|
||||
package util
|
||||
|
||||
import "testing"
|
||||
|
||||
func TestIsClaudeThinkingModel(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
// Claude thinking models - should return true
|
||||
{"claude-sonnet-4-5-thinking", "claude-sonnet-4-5-thinking", true},
|
||||
{"claude-opus-4-5-thinking", "claude-opus-4-5-thinking", true},
|
||||
{"Claude-Sonnet-Thinking uppercase", "Claude-Sonnet-4-5-Thinking", true},
|
||||
{"claude thinking mixed case", "Claude-THINKING-Model", true},
|
||||
|
||||
// Non-thinking Claude models - should return false
|
||||
{"claude-sonnet-4-5 (no thinking)", "claude-sonnet-4-5", false},
|
||||
{"claude-opus-4-5 (no thinking)", "claude-opus-4-5", false},
|
||||
{"claude-3-5-sonnet", "claude-3-5-sonnet-20240620", false},
|
||||
|
||||
// Non-Claude models - should return false
|
||||
{"gemini-3-pro-preview", "gemini-3-pro-preview", false},
|
||||
{"gemini-thinking model", "gemini-3-pro-thinking", false}, // not Claude
|
||||
{"gpt-4o", "gpt-4o", false},
|
||||
{"empty string", "", false},
|
||||
|
||||
// Edge cases
|
||||
{"thinking without claude", "thinking-model", false},
|
||||
{"claude without thinking", "claude-model", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := IsClaudeThinkingModel(tt.model)
|
||||
if result != tt.expected {
|
||||
t.Errorf("IsClaudeThinkingModel(%q) = %v, expected %v", tt.model, result, tt.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -12,10 +12,10 @@ import (
|
||||
|
||||
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||
|
||||
// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini/Antigravity API.
|
||||
// CleanJSONSchemaForAntigravity transforms a JSON schema to be compatible with Antigravity API.
|
||||
// It handles unsupported keywords, type flattening, and schema simplification while preserving
|
||||
// semantic information as description hints.
|
||||
func CleanJSONSchemaForGemini(jsonStr string) string {
|
||||
func CleanJSONSchemaForAntigravity(jsonStr string) string {
|
||||
// Phase 1: Convert and add hints
|
||||
jsonStr = convertRefsToHints(jsonStr)
|
||||
jsonStr = convertConstToEnum(jsonStr)
|
||||
@@ -32,6 +32,9 @@ func CleanJSONSchemaForGemini(jsonStr string) string {
|
||||
jsonStr = removeUnsupportedKeywords(jsonStr)
|
||||
jsonStr = cleanupRequiredFields(jsonStr)
|
||||
|
||||
// Phase 4: Add placeholder for empty object schemas (Claude VALIDATED mode requirement)
|
||||
jsonStr = addEmptySchemaPlaceholder(jsonStr)
|
||||
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
@@ -105,7 +108,8 @@ func addAdditionalPropertiesHints(jsonStr string) string {
|
||||
|
||||
var unsupportedConstraints = []string{
|
||||
"minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum",
|
||||
"pattern", "minItems", "maxItems",
|
||||
"pattern", "minItems", "maxItems", "format",
|
||||
"default", "examples", // Claude rejects these in VALIDATED mode
|
||||
}
|
||||
|
||||
func moveConstraintsToDescription(jsonStr string) string {
|
||||
@@ -296,6 +300,7 @@ func flattenTypeArrays(jsonStr string) string {
|
||||
func removeUnsupportedKeywords(jsonStr string) string {
|
||||
keywords := append(unsupportedConstraints,
|
||||
"$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
|
||||
"propertyNames", // Gemini doesn't support property name validation
|
||||
)
|
||||
for _, key := range keywords {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
@@ -338,6 +343,52 @@ func cleanupRequiredFields(jsonStr string) string {
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// addEmptySchemaPlaceholder adds a placeholder "reason" property to empty object schemas.
|
||||
// Claude VALIDATED mode requires at least one property in tool schemas.
|
||||
func addEmptySchemaPlaceholder(jsonStr string) string {
|
||||
// Find all "type" fields
|
||||
paths := findPaths(jsonStr, "type")
|
||||
|
||||
// Process from deepest to shallowest (to handle nested objects properly)
|
||||
sortByDepth(paths)
|
||||
|
||||
for _, p := range paths {
|
||||
typeVal := gjson.Get(jsonStr, p)
|
||||
if typeVal.String() != "object" {
|
||||
continue
|
||||
}
|
||||
|
||||
// Get the parent path (the object containing "type")
|
||||
parentPath := trimSuffix(p, ".type")
|
||||
|
||||
// Check if properties exists and is empty or missing
|
||||
propsPath := joinPath(parentPath, "properties")
|
||||
propsVal := gjson.Get(jsonStr, propsPath)
|
||||
|
||||
needsPlaceholder := false
|
||||
if !propsVal.Exists() {
|
||||
// No properties field at all
|
||||
needsPlaceholder = true
|
||||
} else if propsVal.IsObject() && len(propsVal.Map()) == 0 {
|
||||
// Empty properties object
|
||||
needsPlaceholder = true
|
||||
}
|
||||
|
||||
if needsPlaceholder {
|
||||
// Add placeholder "reason" property
|
||||
reasonPath := joinPath(propsPath, "reason")
|
||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".type", "string")
|
||||
jsonStr, _ = sjson.Set(jsonStr, reasonPath+".description", "Brief explanation of why you are calling this tool")
|
||||
|
||||
// Add to required array
|
||||
reqPath := joinPath(parentPath, "required")
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, []string{"reason"})
|
||||
}
|
||||
}
|
||||
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func findPaths(jsonStr, field string) []string {
|
||||
|
||||
@@ -5,9 +5,11 @@ import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestCleanJSONSchemaForGemini_ConstToEnum(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_ConstToEnum(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -28,11 +30,11 @@ func TestCleanJSONSchemaForGemini_ConstToEnum(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_TypeFlattening_Nullable(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -60,11 +62,11 @@ func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable(t *testing.T) {
|
||||
"required": ["other"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_ConstraintsToDescription(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -81,7 +83,7 @@ func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
// minItems should be REMOVED and moved to description
|
||||
if strings.Contains(result, `"minItems"`) {
|
||||
@@ -100,7 +102,7 @@ func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_SmartSelection(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -131,11 +133,11 @@ func TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_OneOfFlattening(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_OneOfFlattening(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -158,11 +160,11 @@ func TestCleanJSONSchemaForGemini_OneOfFlattening(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AllOfMerging(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_AllOfMerging(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"allOf": [
|
||||
@@ -190,11 +192,11 @@ func TestCleanJSONSchemaForGemini_AllOfMerging(t *testing.T) {
|
||||
"required": ["a", "b"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_RefHandling(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_RefHandling(t *testing.T) {
|
||||
input := `{
|
||||
"definitions": {
|
||||
"User": {
|
||||
@@ -210,21 +212,29 @@ func TestCleanJSONSchemaForGemini_RefHandling(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
// After $ref is converted to placeholder object, empty schema placeholder is also added
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": {
|
||||
"type": "object",
|
||||
"description": "See: User"
|
||||
"description": "See: User",
|
||||
"properties": {
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Brief explanation of why you are calling this tool"
|
||||
}
|
||||
},
|
||||
"required": ["reason"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_RefHandling_DescriptionEscaping(t *testing.T) {
|
||||
input := `{
|
||||
"definitions": {
|
||||
"User": {
|
||||
@@ -243,21 +253,29 @@ func TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping(t *testing.T)
|
||||
}
|
||||
}`
|
||||
|
||||
// After $ref is converted, empty schema placeholder is also added
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": {
|
||||
"type": "object",
|
||||
"description": "He said \"hi\"\\nsecond line (See: User)"
|
||||
"description": "He said \"hi\"\\nsecond line (See: User)",
|
||||
"properties": {
|
||||
"reason": {
|
||||
"type": "string",
|
||||
"description": "Brief explanation of why you are calling this tool"
|
||||
}
|
||||
},
|
||||
"required": ["reason"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_CyclicRefDefaults(t *testing.T) {
|
||||
input := `{
|
||||
"definitions": {
|
||||
"Node": {
|
||||
@@ -270,7 +288,7 @@ func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) {
|
||||
"$ref": "#/definitions/Node"
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
var resMap map[string]interface{}
|
||||
json.Unmarshal([]byte(result), &resMap)
|
||||
@@ -285,7 +303,7 @@ func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_RequiredCleanup(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_RequiredCleanup(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -304,11 +322,11 @@ func TestCleanJSONSchemaForGemini_RequiredCleanup(t *testing.T) {
|
||||
"required": ["a", "b"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_AllOfMerging_DotKeys(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"allOf": [
|
||||
@@ -336,11 +354,11 @@ func TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys(t *testing.T) {
|
||||
"required": ["my.param", "b"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_PropertyNameCollision(t *testing.T) {
|
||||
// A tool has an argument named "pattern" - should NOT be treated as a constraint
|
||||
input := `{
|
||||
"type": "object",
|
||||
@@ -364,7 +382,7 @@ func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) {
|
||||
"required": ["pattern"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
|
||||
var resMap map[string]interface{}
|
||||
@@ -375,7 +393,7 @@ func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_DotKeys(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -389,7 +407,7 @@ func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
var resMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(result), &resMap); err != nil {
|
||||
@@ -414,7 +432,7 @@ func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_AnyOfAlternativeHints(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -428,7 +446,7 @@ func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
if !strings.Contains(result, "Accepts:") {
|
||||
t.Errorf("Expected alternative types hint, got: %s", result)
|
||||
@@ -438,7 +456,7 @@ func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_NullableHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -450,7 +468,7 @@ func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) {
|
||||
"required": ["name"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
if !strings.Contains(result, "(nullable)") {
|
||||
t.Errorf("Expected nullable hint, got: %s", result)
|
||||
@@ -460,7 +478,7 @@ func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_TypeFlattening_Nullable_DotKey(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -488,11 +506,11 @@ func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey(t *testing.T) {
|
||||
"required": ["other"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_EnumHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -504,7 +522,7 @@ func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
if !strings.Contains(result, "Allowed:") {
|
||||
t.Errorf("Expected enum values hint, got: %s", result)
|
||||
@@ -514,7 +532,7 @@ func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AdditionalPropertiesHint(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_AdditionalPropertiesHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -523,14 +541,14 @@ func TestCleanJSONSchemaForGemini_AdditionalPropertiesHint(t *testing.T) {
|
||||
"additionalProperties": false
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
if !strings.Contains(result, "No extra properties allowed") {
|
||||
t.Errorf("Expected additionalProperties hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_AnyOfFlattening_PreservesDescription(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -554,11 +572,11 @@ func TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription(t *testin
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_SingleEnumNoHint(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_SingleEnumNoHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -569,14 +587,14 @@ func TestCleanJSONSchemaForGemini_SingleEnumNoHint(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
if strings.Contains(result, "Allowed:") {
|
||||
t.Errorf("Single value enum should not add Allowed hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) {
|
||||
func TestCleanJSONSchemaForAntigravity_MultipleNonNullTypes(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
@@ -586,7 +604,7 @@ func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) {
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
if !strings.Contains(result, "Accepts:") {
|
||||
t.Errorf("Expected multiple types hint, got: %s", result)
|
||||
@@ -596,6 +614,71 @@ func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval(t *testing.T) {
|
||||
// propertyNames is used to validate object property names (e.g., must match a pattern)
|
||||
// Gemini doesn't support this keyword and will reject requests containing it
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {
|
||||
"type": "object",
|
||||
"propertyNames": {
|
||||
"pattern": "^[a-zA-Z_][a-zA-Z0-9_]*$"
|
||||
},
|
||||
"additionalProperties": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"metadata": {
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
|
||||
// Verify propertyNames is completely removed
|
||||
if strings.Contains(result, "propertyNames") {
|
||||
t.Errorf("propertyNames keyword should be removed, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_PropertyNamesRemoval_Nested(t *testing.T) {
|
||||
// Test deeply nested propertyNames (as seen in real Claude tool schemas)
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "object",
|
||||
"propertyNames": {
|
||||
"type": "string"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if strings.Contains(result, "propertyNames") {
|
||||
t.Errorf("Nested propertyNames should be removed, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
|
||||
var expMap, actMap map[string]interface{}
|
||||
errExp := json.Unmarshal([]byte(expectedJSON), &expMap)
|
||||
@@ -611,3 +694,190 @@ func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
|
||||
t.Errorf("JSON mismatch:\nExpected:\n%s\n\nActual:\n%s", string(expBytes), string(actBytes))
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Empty Schema Placeholder Tests
|
||||
// ============================================================================
|
||||
|
||||
func TestCleanJSONSchemaForAntigravity_EmptySchemaPlaceholder(t *testing.T) {
|
||||
// Empty object schema with no properties should get a placeholder
|
||||
input := `{
|
||||
"type": "object"
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
// Should have placeholder property added
|
||||
if !strings.Contains(result, `"reason"`) {
|
||||
t.Errorf("Empty schema should have 'reason' placeholder property, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, `"required"`) {
|
||||
t.Errorf("Empty schema should have 'required' with 'reason', got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForAntigravity_EmptyPropertiesPlaceholder(t *testing.T) {
|
||||
// Object with empty properties object
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
// Should have placeholder property added
|
||||
if !strings.Contains(result, `"reason"`) {
|
||||
t.Errorf("Empty properties should have 'reason' placeholder, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForAntigravity_NonEmptySchemaUnchanged(t *testing.T) {
|
||||
// Schema with properties should NOT get placeholder
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {"type": "string"}
|
||||
},
|
||||
"required": ["name"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
// Should NOT have placeholder property
|
||||
if strings.Contains(result, `"reason"`) {
|
||||
t.Errorf("Non-empty schema should NOT have 'reason' placeholder, got: %s", result)
|
||||
}
|
||||
// Original properties should be preserved
|
||||
if !strings.Contains(result, `"name"`) {
|
||||
t.Errorf("Original property 'name' should be preserved, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForAntigravity_NestedEmptySchema(t *testing.T) {
|
||||
// Nested empty object in items should also get placeholder
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"items": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "object"
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
// Nested empty object should also get placeholder
|
||||
// Check that the nested object has a reason property
|
||||
parsed := gjson.Parse(result)
|
||||
nestedProps := parsed.Get("properties.items.items.properties")
|
||||
if !nestedProps.Exists() || !nestedProps.Get("reason").Exists() {
|
||||
t.Errorf("Nested empty object should have 'reason' placeholder, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForAntigravity_EmptySchemaWithDescription(t *testing.T) {
|
||||
// Empty schema with description should preserve description and add placeholder
|
||||
input := `{
|
||||
"type": "object",
|
||||
"description": "An empty object"
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
// Should have both description and placeholder
|
||||
if !strings.Contains(result, `"An empty object"`) {
|
||||
t.Errorf("Description should be preserved, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, `"reason"`) {
|
||||
t.Errorf("Empty schema should have 'reason' placeholder, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
// ============================================================================
|
||||
// Format field handling (ad-hoc patch removal)
|
||||
// ============================================================================
|
||||
|
||||
func TestCleanJSONSchemaForAntigravity_FormatFieldRemoval(t *testing.T) {
|
||||
// format:"uri" should be removed and added as hint
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {
|
||||
"type": "string",
|
||||
"format": "uri",
|
||||
"description": "A URL"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
// format should be removed
|
||||
if strings.Contains(result, `"format"`) {
|
||||
t.Errorf("format field should be removed, got: %s", result)
|
||||
}
|
||||
// hint should be added to description
|
||||
if !strings.Contains(result, "format: uri") {
|
||||
t.Errorf("format hint should be added to description, got: %s", result)
|
||||
}
|
||||
// original description should be preserved
|
||||
if !strings.Contains(result, "A URL") {
|
||||
t.Errorf("Original description should be preserved, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForAntigravity_FormatFieldNoDescription(t *testing.T) {
|
||||
// format without description should create description with hint
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"email": {
|
||||
"type": "string",
|
||||
"format": "email"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
// format should be removed
|
||||
if strings.Contains(result, `"format"`) {
|
||||
t.Errorf("format field should be removed, got: %s", result)
|
||||
}
|
||||
// hint should be added
|
||||
if !strings.Contains(result, "format: email") {
|
||||
t.Errorf("format hint should be added, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForAntigravity_MultipleFormats(t *testing.T) {
|
||||
// Multiple format fields should all be handled
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"url": {"type": "string", "format": "uri"},
|
||||
"email": {"type": "string", "format": "email"},
|
||||
"date": {"type": "string", "format": "date-time"}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForAntigravity(input)
|
||||
|
||||
// All format fields should be removed
|
||||
if strings.Contains(result, `"format"`) {
|
||||
t.Errorf("All format fields should be removed, got: %s", result)
|
||||
}
|
||||
// All hints should be added
|
||||
if !strings.Contains(result, "format: uri") {
|
||||
t.Errorf("uri format hint should be added, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "format: email") {
|
||||
t.Errorf("email format hint should be added, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "format: date-time") {
|
||||
t.Errorf("date-time format hint should be added, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,6 +136,12 @@ func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool)
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
if it := gjson.GetBytes(body, "generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.include_thoughts")
|
||||
}
|
||||
if tb := gjson.GetBytes(body, "generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "generationConfig.thinkingConfig.thinkingBudget")
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
@@ -167,6 +173,12 @@ func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *boo
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
if it := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.include_thoughts"); it.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts")
|
||||
}
|
||||
if tb := gjson.GetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget"); tb.Exists() {
|
||||
updated, _ = sjson.DeleteBytes(updated, "request.generationConfig.thinkingConfig.thinkingBudget")
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
@@ -242,7 +254,7 @@ func ThinkingBudgetToGemini3Level(model string, budget int) (string, bool) {
|
||||
var modelsWithDefaultThinking = map[string]bool{
|
||||
"gemini-3-pro-preview": true,
|
||||
"gemini-3-pro-image-preview": true,
|
||||
"gemini-3-flash-preview": true,
|
||||
// "gemini-3-flash-preview": true,
|
||||
}
|
||||
|
||||
// ModelHasDefaultThinking returns true if the model should have thinking enabled by default.
|
||||
@@ -352,8 +364,9 @@ func StripThinkingConfigIfUnsupported(model string, body []byte) []byte {
|
||||
|
||||
// NormalizeGeminiThinkingBudget normalizes the thinkingBudget value in a standard Gemini
|
||||
// request body (generationConfig.thinkingConfig.thinkingBudget path).
|
||||
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation.
|
||||
func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
|
||||
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation,
|
||||
// unless skipGemini3Check is provided and true.
|
||||
func NormalizeGeminiThinkingBudget(model string, body []byte, skipGemini3Check ...bool) []byte {
|
||||
const budgetPath = "generationConfig.thinkingConfig.thinkingBudget"
|
||||
const levelPath = "generationConfig.thinkingConfig.thinkingLevel"
|
||||
|
||||
@@ -363,7 +376,8 @@ func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
|
||||
}
|
||||
|
||||
// For Gemini 3 models, convert thinkingBudget to thinkingLevel
|
||||
if IsGemini3Model(model) {
|
||||
skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
|
||||
if IsGemini3Model(model) && !skipGemini3 {
|
||||
if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok {
|
||||
updated, _ := sjson.SetBytes(body, levelPath, level)
|
||||
updated, _ = sjson.DeleteBytes(updated, budgetPath)
|
||||
@@ -382,8 +396,9 @@ func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
|
||||
|
||||
// NormalizeGeminiCLIThinkingBudget normalizes the thinkingBudget value in a Gemini CLI
|
||||
// request body (request.generationConfig.thinkingConfig.thinkingBudget path).
|
||||
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation.
|
||||
func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte {
|
||||
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation,
|
||||
// unless skipGemini3Check is provided and true.
|
||||
func NormalizeGeminiCLIThinkingBudget(model string, body []byte, skipGemini3Check ...bool) []byte {
|
||||
const budgetPath = "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||
const levelPath = "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
|
||||
@@ -393,7 +408,8 @@ func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte {
|
||||
}
|
||||
|
||||
// For Gemini 3 models, convert thinkingBudget to thinkingLevel
|
||||
if IsGemini3Model(model) {
|
||||
skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
|
||||
if IsGemini3Model(model) && !skipGemini3 {
|
||||
if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok {
|
||||
updated, _ := sjson.SetBytes(body, levelPath, level)
|
||||
updated, _ = sjson.DeleteBytes(updated, budgetPath)
|
||||
@@ -477,7 +493,7 @@ func ApplyReasoningEffortToGeminiCLI(body []byte, effort string) []byte {
|
||||
|
||||
// ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel"
|
||||
// and converts it to "thinkingBudget" for Gemini 2.5 models.
|
||||
// For Gemini 3 models, preserves thinkingLevel as-is (does not convert).
|
||||
// For Gemini 3 models, preserves thinkingLevel unless skipGemini3Check is provided and true.
|
||||
// Mappings for Gemini 2.5:
|
||||
// - "high" -> 32768
|
||||
// - "medium" -> 8192
|
||||
@@ -485,43 +501,31 @@ func ApplyReasoningEffortToGeminiCLI(body []byte, effort string) []byte {
|
||||
// - "minimal" -> 512
|
||||
//
|
||||
// It removes "thinkingLevel" after conversion (for Gemini 2.5 only).
|
||||
func ConvertThinkingLevelToBudget(body []byte, model string) []byte {
|
||||
func ConvertThinkingLevelToBudget(body []byte, model string, skipGemini3Check ...bool) []byte {
|
||||
levelPath := "generationConfig.thinkingConfig.thinkingLevel"
|
||||
res := gjson.GetBytes(body, levelPath)
|
||||
if !res.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// For Gemini 3 models, preserve thinkingLevel - don't convert to budget
|
||||
if IsGemini3Model(model) {
|
||||
// For Gemini 3 models, preserve thinkingLevel unless explicitly skipped
|
||||
skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
|
||||
if IsGemini3Model(model) && !skipGemini3 {
|
||||
return body
|
||||
}
|
||||
|
||||
level := strings.ToLower(res.String())
|
||||
var budget int
|
||||
switch level {
|
||||
case "high":
|
||||
budget = 32768
|
||||
case "medium":
|
||||
budget = 8192
|
||||
case "low":
|
||||
budget = 1024
|
||||
case "minimal":
|
||||
budget = 512
|
||||
default:
|
||||
// Unknown level - remove it and let the API use defaults
|
||||
budget, ok := ThinkingLevelToBudget(res.String())
|
||||
if !ok {
|
||||
updated, _ := sjson.DeleteBytes(body, levelPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
// Set budget
|
||||
budgetPath := "generationConfig.thinkingConfig.thinkingBudget"
|
||||
updated, err := sjson.SetBytes(body, budgetPath, budget)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// Remove level
|
||||
updated, err = sjson.DeleteBytes(updated, levelPath)
|
||||
if err != nil {
|
||||
return body
|
||||
@@ -544,31 +548,18 @@ func ConvertThinkingLevelToBudgetCLI(body []byte, model string) []byte {
|
||||
return body
|
||||
}
|
||||
|
||||
level := strings.ToLower(res.String())
|
||||
var budget int
|
||||
switch level {
|
||||
case "high":
|
||||
budget = 32768
|
||||
case "medium":
|
||||
budget = 8192
|
||||
case "low":
|
||||
budget = 1024
|
||||
case "minimal":
|
||||
budget = 512
|
||||
default:
|
||||
// Unknown level - remove it and let the API use defaults
|
||||
budget, ok := ThinkingLevelToBudget(res.String())
|
||||
if !ok {
|
||||
updated, _ := sjson.DeleteBytes(body, levelPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
// Set budget
|
||||
budgetPath := "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||
updated, err := sjson.SetBytes(body, budgetPath, budget)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// Remove level
|
||||
updated, err = sjson.DeleteBytes(updated, levelPath)
|
||||
if err != nil {
|
||||
return body
|
||||
|
||||
@@ -160,6 +160,34 @@ func ThinkingEffortToBudget(model, effort string) (int, bool) {
|
||||
}
|
||||
}
|
||||
|
||||
// ThinkingLevelToBudget maps a Gemini thinkingLevel to a numeric thinking budget (tokens).
|
||||
//
|
||||
// Mappings:
|
||||
// - "minimal" -> 512
|
||||
// - "low" -> 1024
|
||||
// - "medium" -> 8192
|
||||
// - "high" -> 32768
|
||||
//
|
||||
// Returns false when the level is empty or unsupported.
|
||||
func ThinkingLevelToBudget(level string) (int, bool) {
|
||||
if level == "" {
|
||||
return 0, false
|
||||
}
|
||||
normalized := strings.ToLower(strings.TrimSpace(level))
|
||||
switch normalized {
|
||||
case "minimal":
|
||||
return 512, true
|
||||
case "low":
|
||||
return 1024, true
|
||||
case "medium":
|
||||
return 8192, true
|
||||
case "high":
|
||||
return 32768, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ThinkingBudgetToEffort maps a numeric thinking budget (tokens)
|
||||
// to a reasoning effort level for level-based models.
|
||||
//
|
||||
|
||||
87
internal/util/thinking_text.go
Normal file
87
internal/util/thinking_text.go
Normal file
@@ -0,0 +1,87 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// GetThinkingText extracts the thinking text from a content part.
|
||||
// Handles various formats:
|
||||
// - Simple string: { "thinking": "text" } or { "text": "text" }
|
||||
// - Wrapped object: { "thinking": { "text": "text", "cache_control": {...} } }
|
||||
// - Gemini-style: { "thought": true, "text": "text" }
|
||||
// Returns the extracted text string.
|
||||
func GetThinkingText(part gjson.Result) string {
|
||||
// Try direct text field first (Gemini-style)
|
||||
if text := part.Get("text"); text.Exists() && text.Type == gjson.String {
|
||||
return text.String()
|
||||
}
|
||||
|
||||
// Try thinking field
|
||||
thinkingField := part.Get("thinking")
|
||||
if !thinkingField.Exists() {
|
||||
return ""
|
||||
}
|
||||
|
||||
// thinking is a string
|
||||
if thinkingField.Type == gjson.String {
|
||||
return thinkingField.String()
|
||||
}
|
||||
|
||||
// thinking is an object with inner text/thinking
|
||||
if thinkingField.IsObject() {
|
||||
if inner := thinkingField.Get("text"); inner.Exists() && inner.Type == gjson.String {
|
||||
return inner.String()
|
||||
}
|
||||
if inner := thinkingField.Get("thinking"); inner.Exists() && inner.Type == gjson.String {
|
||||
return inner.String()
|
||||
}
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// GetThinkingTextFromJSON extracts thinking text from a raw JSON string.
|
||||
func GetThinkingTextFromJSON(jsonStr string) string {
|
||||
return GetThinkingText(gjson.Parse(jsonStr))
|
||||
}
|
||||
|
||||
// SanitizeThinkingPart normalizes a thinking part to a canonical form.
|
||||
// Strips cache_control and other non-essential fields.
|
||||
// Returns the sanitized part as JSON string.
|
||||
func SanitizeThinkingPart(part gjson.Result) string {
|
||||
// Gemini-style: { thought: true, text, thoughtSignature }
|
||||
if part.Get("thought").Bool() {
|
||||
result := `{"thought":true}`
|
||||
if text := GetThinkingText(part); text != "" {
|
||||
result, _ = sjson.Set(result, "text", text)
|
||||
}
|
||||
if sig := part.Get("thoughtSignature"); sig.Exists() && sig.Type == gjson.String {
|
||||
result, _ = sjson.Set(result, "thoughtSignature", sig.String())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Anthropic-style: { type: "thinking", thinking, signature }
|
||||
if part.Get("type").String() == "thinking" || part.Get("thinking").Exists() {
|
||||
result := `{"type":"thinking"}`
|
||||
if text := GetThinkingText(part); text != "" {
|
||||
result, _ = sjson.Set(result, "thinking", text)
|
||||
}
|
||||
if sig := part.Get("signature"); sig.Exists() && sig.Type == gjson.String {
|
||||
result, _ = sjson.Set(result, "signature", sig.String())
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// Not a thinking part, return as-is but strip cache_control
|
||||
return StripCacheControl(part.Raw)
|
||||
}
|
||||
|
||||
// StripCacheControl removes cache_control and providerOptions from a JSON object.
|
||||
func StripCacheControl(jsonStr string) string {
|
||||
result := jsonStr
|
||||
result, _ = sjson.Delete(result, "cache_control")
|
||||
result, _ = sjson.Delete(result, "providerOptions")
|
||||
return result
|
||||
}
|
||||
46
sdk/api/options.go
Normal file
46
sdk/api/options.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// Package api exposes server option helpers for embedding CLIProxyAPI.
|
||||
//
|
||||
// It wraps internal server option types so external projects can configure the embedded
|
||||
// HTTP server without importing internal packages.
|
||||
package api
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internalapi "github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
|
||||
)
|
||||
|
||||
// ServerOption customises HTTP server construction.
|
||||
type ServerOption = internalapi.ServerOption
|
||||
|
||||
// WithMiddleware appends additional Gin middleware during server construction.
|
||||
func WithMiddleware(mw ...gin.HandlerFunc) ServerOption { return internalapi.WithMiddleware(mw...) }
|
||||
|
||||
// WithEngineConfigurator allows callers to mutate the Gin engine prior to middleware setup.
|
||||
func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption {
|
||||
return internalapi.WithEngineConfigurator(fn)
|
||||
}
|
||||
|
||||
// WithRouterConfigurator appends a callback after default routes are registered.
|
||||
func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption {
|
||||
return internalapi.WithRouterConfigurator(fn)
|
||||
}
|
||||
|
||||
// WithLocalManagementPassword stores a runtime-only management password accepted for localhost requests.
|
||||
func WithLocalManagementPassword(password string) ServerOption {
|
||||
return internalapi.WithLocalManagementPassword(password)
|
||||
}
|
||||
|
||||
// WithKeepAliveEndpoint enables a keep-alive endpoint with the provided timeout and callback.
|
||||
func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption {
|
||||
return internalapi.WithKeepAliveEndpoint(timeout, onTimeout)
|
||||
}
|
||||
|
||||
// WithRequestLoggerFactory customises request logger creation.
|
||||
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
|
||||
return internalapi.WithRequestLoggerFactory(factory)
|
||||
}
|
||||
@@ -99,12 +99,55 @@ func (AntigravityAuthenticator) Login(ctx context.Context, cfg *config.Config, o
|
||||
fmt.Println("Waiting for antigravity authentication callback...")
|
||||
|
||||
var cbRes callbackResult
|
||||
timeoutTimer := time.NewTimer(5 * time.Minute)
|
||||
defer timeoutTimer.Stop()
|
||||
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case res := <-cbChan:
|
||||
cbRes = res
|
||||
case <-time.After(5 * time.Minute):
|
||||
break waitForCallback
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
select {
|
||||
case res := <-cbChan:
|
||||
cbRes = res
|
||||
break waitForCallback
|
||||
default:
|
||||
}
|
||||
input, errPrompt := opts.Prompt("Paste the antigravity callback URL (or press Enter to keep waiting): ")
|
||||
if errPrompt != nil {
|
||||
return nil, errPrompt
|
||||
}
|
||||
parsed, errParse := misc.ParseOAuthCallback(input)
|
||||
if errParse != nil {
|
||||
return nil, errParse
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
cbRes = callbackResult{
|
||||
Code: parsed.Code,
|
||||
State: parsed.State,
|
||||
Error: parsed.Error,
|
||||
}
|
||||
break waitForCallback
|
||||
case <-timeoutTimer.C:
|
||||
return nil, fmt.Errorf("antigravity: authentication timed out")
|
||||
}
|
||||
}
|
||||
|
||||
if cbRes.Error != "" {
|
||||
return nil, fmt.Errorf("antigravity: authentication failed: %s", cbRes.Error)
|
||||
|
||||
@@ -98,16 +98,76 @@ func (a *ClaudeAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
||||
|
||||
fmt.Println("Waiting for Claude authentication callback...")
|
||||
|
||||
result, err := oauthServer.WaitForCallback(5 * time.Minute)
|
||||
if err != nil {
|
||||
callbackCh := make(chan *claude.OAuthResult, 1)
|
||||
callbackErrCh := make(chan error, 1)
|
||||
manualDescription := ""
|
||||
|
||||
go func() {
|
||||
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
|
||||
if errWait != nil {
|
||||
callbackErrCh <- errWait
|
||||
return
|
||||
}
|
||||
callbackCh <- result
|
||||
}()
|
||||
|
||||
var result *claude.OAuthResult
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case result = <-callbackCh:
|
||||
break waitForCallback
|
||||
case err = <-callbackErrCh:
|
||||
if strings.Contains(err.Error(), "timeout") {
|
||||
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
|
||||
}
|
||||
return nil, err
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
select {
|
||||
case result = <-callbackCh:
|
||||
break waitForCallback
|
||||
case err = <-callbackErrCh:
|
||||
if strings.Contains(err.Error(), "timeout") {
|
||||
return nil, claude.NewAuthenticationError(claude.ErrCallbackTimeout, err)
|
||||
}
|
||||
return nil, err
|
||||
default:
|
||||
}
|
||||
input, errPrompt := opts.Prompt("Paste the Claude callback URL (or press Enter to keep waiting): ")
|
||||
if errPrompt != nil {
|
||||
return nil, errPrompt
|
||||
}
|
||||
parsed, errParse := misc.ParseOAuthCallback(input)
|
||||
if errParse != nil {
|
||||
return nil, errParse
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
manualDescription = parsed.ErrorDescription
|
||||
result = &claude.OAuthResult{
|
||||
Code: parsed.Code,
|
||||
State: parsed.State,
|
||||
Error: parsed.Error,
|
||||
}
|
||||
break waitForCallback
|
||||
}
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return nil, claude.NewOAuthError(result.Error, "", http.StatusBadRequest)
|
||||
return nil, claude.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if result.State != state {
|
||||
|
||||
@@ -97,16 +97,76 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
|
||||
fmt.Println("Waiting for Codex authentication callback...")
|
||||
|
||||
result, err := oauthServer.WaitForCallback(5 * time.Minute)
|
||||
if err != nil {
|
||||
callbackCh := make(chan *codex.OAuthResult, 1)
|
||||
callbackErrCh := make(chan error, 1)
|
||||
manualDescription := ""
|
||||
|
||||
go func() {
|
||||
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
|
||||
if errWait != nil {
|
||||
callbackErrCh <- errWait
|
||||
return
|
||||
}
|
||||
callbackCh <- result
|
||||
}()
|
||||
|
||||
var result *codex.OAuthResult
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case result = <-callbackCh:
|
||||
break waitForCallback
|
||||
case err = <-callbackErrCh:
|
||||
if strings.Contains(err.Error(), "timeout") {
|
||||
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
|
||||
}
|
||||
return nil, err
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
select {
|
||||
case result = <-callbackCh:
|
||||
break waitForCallback
|
||||
case err = <-callbackErrCh:
|
||||
if strings.Contains(err.Error(), "timeout") {
|
||||
return nil, codex.NewAuthenticationError(codex.ErrCallbackTimeout, err)
|
||||
}
|
||||
return nil, err
|
||||
default:
|
||||
}
|
||||
input, errPrompt := opts.Prompt("Paste the Codex callback URL (or press Enter to keep waiting): ")
|
||||
if errPrompt != nil {
|
||||
return nil, errPrompt
|
||||
}
|
||||
parsed, errParse := misc.ParseOAuthCallback(input)
|
||||
if errParse != nil {
|
||||
return nil, errParse
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
manualDescription = parsed.ErrorDescription
|
||||
result = &codex.OAuthResult{
|
||||
Code: parsed.Code,
|
||||
State: parsed.State,
|
||||
Error: parsed.Error,
|
||||
}
|
||||
break waitForCallback
|
||||
}
|
||||
}
|
||||
|
||||
if result.Error != "" {
|
||||
return nil, codex.NewOAuthError(result.Error, "", http.StatusBadRequest)
|
||||
return nil, codex.NewOAuthError(result.Error, manualDescription, http.StatusBadRequest)
|
||||
}
|
||||
|
||||
if result.State != state {
|
||||
|
||||
@@ -72,7 +72,9 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
|
||||
return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal)
|
||||
}
|
||||
if existing, errRead := os.ReadFile(path); errRead == nil {
|
||||
if jsonEqual(existing, raw) {
|
||||
// Use metadataEqualIgnoringTimestamps to skip writes when only timestamp fields change.
|
||||
// This prevents the token refresh loop caused by timestamp/expired/expires_in changes.
|
||||
if metadataEqualIgnoringTimestamps(existing, raw) {
|
||||
return path, nil
|
||||
}
|
||||
} else if errRead != nil && !os.IsNotExist(errRead) {
|
||||
@@ -264,6 +266,8 @@ func (s *FileTokenStore) baseDirSnapshot() string {
|
||||
return s.baseDir
|
||||
}
|
||||
|
||||
// DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata.
|
||||
// This function is kept for backward compatibility but can cause refresh loops.
|
||||
func jsonEqual(a, b []byte) bool {
|
||||
var objA any
|
||||
var objB any
|
||||
@@ -276,6 +280,32 @@ func jsonEqual(a, b []byte) bool {
|
||||
return deepEqualJSON(objA, objB)
|
||||
}
|
||||
|
||||
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs,
|
||||
// ignoring fields that change on every refresh but don't affect functionality.
|
||||
// This prevents unnecessary file writes that would trigger watcher events and
|
||||
// create refresh loops.
|
||||
func metadataEqualIgnoringTimestamps(a, b []byte) bool {
|
||||
var objA, objB map[string]any
|
||||
if err := json.Unmarshal(a, &objA); err != nil {
|
||||
return false
|
||||
}
|
||||
if err := json.Unmarshal(b, &objB); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Fields to ignore: these change on every refresh but don't affect authentication logic.
|
||||
// - timestamp, expired, expires_in, last_refresh: time-related fields that change on refresh
|
||||
// - access_token: Google OAuth returns a new access_token on each refresh, this is expected
|
||||
// and shouldn't trigger file writes (the new token will be fetched again when needed)
|
||||
ignoredFields := []string{"timestamp", "expired", "expires_in", "last_refresh", "access_token"}
|
||||
for _, field := range ignoredFields {
|
||||
delete(objA, field)
|
||||
delete(objB, field)
|
||||
}
|
||||
|
||||
return deepEqualJSON(objA, objB)
|
||||
}
|
||||
|
||||
func deepEqualJSON(a, b any) bool {
|
||||
switch valA := a.(type) {
|
||||
case map[string]any:
|
||||
|
||||
@@ -44,7 +44,10 @@ func (a *GeminiAuthenticator) Login(ctx context.Context, cfg *config.Config, opt
|
||||
}
|
||||
|
||||
geminiAuth := gemini.NewGeminiAuth()
|
||||
_, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, opts.NoBrowser)
|
||||
_, err := geminiAuth.GetAuthenticatedClient(ctx, &ts, cfg, &gemini.WebLoginOptions{
|
||||
NoBrowser: opts.NoBrowser,
|
||||
Prompt: opts.Prompt,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("gemini authentication failed: %w", err)
|
||||
}
|
||||
|
||||
@@ -84,9 +84,64 @@ func (a *IFlowAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
|
||||
fmt.Println("Waiting for iFlow authentication callback...")
|
||||
|
||||
result, err := oauthServer.WaitForCallback(5 * time.Minute)
|
||||
if err != nil {
|
||||
callbackCh := make(chan *iflow.OAuthResult, 1)
|
||||
callbackErrCh := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
result, errWait := oauthServer.WaitForCallback(5 * time.Minute)
|
||||
if errWait != nil {
|
||||
callbackErrCh <- errWait
|
||||
return
|
||||
}
|
||||
callbackCh <- result
|
||||
}()
|
||||
|
||||
var result *iflow.OAuthResult
|
||||
var manualPromptTimer *time.Timer
|
||||
var manualPromptC <-chan time.Time
|
||||
if opts.Prompt != nil {
|
||||
manualPromptTimer = time.NewTimer(15 * time.Second)
|
||||
manualPromptC = manualPromptTimer.C
|
||||
defer manualPromptTimer.Stop()
|
||||
}
|
||||
|
||||
waitForCallback:
|
||||
for {
|
||||
select {
|
||||
case result = <-callbackCh:
|
||||
break waitForCallback
|
||||
case err = <-callbackErrCh:
|
||||
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
|
||||
case <-manualPromptC:
|
||||
manualPromptC = nil
|
||||
if manualPromptTimer != nil {
|
||||
manualPromptTimer.Stop()
|
||||
}
|
||||
select {
|
||||
case result = <-callbackCh:
|
||||
break waitForCallback
|
||||
case err = <-callbackErrCh:
|
||||
return nil, fmt.Errorf("iflow auth: callback wait failed: %w", err)
|
||||
default:
|
||||
}
|
||||
input, errPrompt := opts.Prompt("Paste the iFlow callback URL (or press Enter to keep waiting): ")
|
||||
if errPrompt != nil {
|
||||
return nil, errPrompt
|
||||
}
|
||||
parsed, errParse := misc.ParseOAuthCallback(input)
|
||||
if errParse != nil {
|
||||
return nil, errParse
|
||||
}
|
||||
if parsed == nil {
|
||||
continue
|
||||
}
|
||||
result = &iflow.OAuthResult{
|
||||
Code: parsed.Code,
|
||||
State: parsed.State,
|
||||
Error: parsed.Error,
|
||||
}
|
||||
break waitForCallback
|
||||
}
|
||||
}
|
||||
if result.Error != "" {
|
||||
return nil, fmt.Errorf("iflow auth: provider returned error %s", result.Error)
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
// Builder constructs a Service instance with customizable providers.
|
||||
|
||||
@@ -3,8 +3,8 @@ package cliproxy
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
// NewFileTokenClientProvider returns the default token-backed client loader.
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
@@ -23,6 +22,7 @@ import (
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
|
||||
@@ -6,9 +6,9 @@ package cliproxy
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
// TokenClientProvider loads clients backed by stored authentication tokens.
|
||||
|
||||
@@ -3,9 +3,9 @@ package cliproxy
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func defaultWatcherFactory(configPath, authDir string, reload func(*config.Config)) (*WatcherWrapper, error) {
|
||||
|
||||
@@ -1,87 +1,59 @@
|
||||
// Package config provides configuration management for the CLI Proxy API server.
|
||||
// It handles loading and parsing YAML configuration files, and provides structured
|
||||
// access to application settings including server port, authentication directory,
|
||||
// debug settings, proxy configuration, and API keys.
|
||||
// Package config provides the public SDK configuration API.
|
||||
//
|
||||
// It re-exports the server configuration types and helpers so external projects can
|
||||
// embed CLIProxyAPI without importing internal packages.
|
||||
package config
|
||||
|
||||
// SDKConfig represents the application's configuration, loaded from a YAML file.
|
||||
type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
import internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
|
||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||
// credentials as well.
|
||||
ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`
|
||||
type SDKConfig = internalconfig.SDKConfig
|
||||
type AccessConfig = internalconfig.AccessConfig
|
||||
type AccessProvider = internalconfig.AccessProvider
|
||||
|
||||
// RequestLog enables or disables detailed request logging functionality.
|
||||
RequestLog bool `yaml:"request-log" json:"request-log"`
|
||||
type Config = internalconfig.Config
|
||||
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
type TLSConfig = internalconfig.TLSConfig
|
||||
type RemoteManagement = internalconfig.RemoteManagement
|
||||
type AmpCode = internalconfig.AmpCode
|
||||
type PayloadConfig = internalconfig.PayloadConfig
|
||||
type PayloadRule = internalconfig.PayloadRule
|
||||
type PayloadModelRule = internalconfig.PayloadModelRule
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
}
|
||||
type GeminiKey = internalconfig.GeminiKey
|
||||
type CodexKey = internalconfig.CodexKey
|
||||
type ClaudeKey = internalconfig.ClaudeKey
|
||||
type VertexCompatKey = internalconfig.VertexCompatKey
|
||||
type VertexCompatModel = internalconfig.VertexCompatModel
|
||||
type OpenAICompatibility = internalconfig.OpenAICompatibility
|
||||
type OpenAICompatibilityAPIKey = internalconfig.OpenAICompatibilityAPIKey
|
||||
type OpenAICompatibilityModel = internalconfig.OpenAICompatibilityModel
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
type TLS = internalconfig.TLSConfig
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
AccessProviderTypeConfigAPIKey = internalconfig.AccessProviderTypeConfigAPIKey
|
||||
DefaultAccessProviderName = internalconfig.DefaultAccessProviderName
|
||||
DefaultPanelGitHubRepository = internalconfig.DefaultPanelGitHubRepository
|
||||
)
|
||||
|
||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range c.Access.Providers {
|
||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
||||
if c.Access.Providers[i].Name == "" {
|
||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
||||
}
|
||||
return &c.Access.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
return internalconfig.MakeInlineAPIKeyProvider(keys)
|
||||
}
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
func LoadConfig(configFile string) (*Config, error) { return internalconfig.LoadConfig(configFile) }
|
||||
|
||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
return internalconfig.LoadConfigOptional(configFile, optional)
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
|
||||
func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
return internalconfig.SaveConfigPreserveComments(configFile, cfg)
|
||||
}
|
||||
return provider
|
||||
|
||||
func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
|
||||
return internalconfig.SaveConfigPreserveCommentsUpdateNestedScalar(configFile, path, value)
|
||||
}
|
||||
|
||||
func NormalizeCommentIndentation(data []byte) []byte {
|
||||
return internalconfig.NormalizeCommentIndentation(data)
|
||||
}
|
||||
|
||||
18
sdk/logging/request_logger.go
Normal file
18
sdk/logging/request_logger.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Package logging re-exports request logging primitives for SDK consumers.
|
||||
package logging
|
||||
|
||||
import internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
|
||||
// RequestLogger defines the interface for logging HTTP requests and responses.
|
||||
type RequestLogger = internallogging.RequestLogger
|
||||
|
||||
// StreamingLogWriter handles real-time logging of streaming response chunks.
|
||||
type StreamingLogWriter = internallogging.StreamingLogWriter
|
||||
|
||||
// FileRequestLogger implements RequestLogger using file-based storage.
|
||||
type FileRequestLogger = internallogging.FileRequestLogger
|
||||
|
||||
// NewFileRequestLogger creates a new file-based request logger.
|
||||
func NewFileRequestLogger(enabled bool, logsDir string, configDir string) *FileRequestLogger {
|
||||
return internallogging.NewFileRequestLogger(enabled, logsDir, configDir)
|
||||
}
|
||||
Reference in New Issue
Block a user