mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-13 18:00:51 +08:00
Compare commits
69 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
10f8c795ac | ||
|
|
3e4858a624 | ||
|
|
c84ff42bcd | ||
|
|
8a5db02165 | ||
|
|
d7afb6eb0c | ||
|
|
bbd1fe890a | ||
|
|
f607231efa | ||
|
|
2039062845 | ||
|
|
99478d13a8 | ||
|
|
69d3a80fc3 | ||
|
|
9e268ad103 | ||
|
|
9d9b9e7a0d | ||
|
|
13aa82f3f3 | ||
|
|
05e55d7dc5 | ||
|
|
1b358c931c | ||
|
|
ca09db21ff | ||
|
|
718ff7a73f | ||
|
|
fa70b220e9 | ||
|
|
774f1fbc17 | ||
|
|
cfa8ddb59f | ||
|
|
39597267ae | ||
|
|
393e38f2c0 | ||
|
|
d1220de02d | ||
|
|
13eb5268de | ||
|
|
88798816f2 | ||
|
|
598f0af19b | ||
|
|
a33f5d31fc | ||
|
|
506699fba1 | ||
|
|
68a27772b3 | ||
|
|
de87fb622b | ||
|
|
f27672f6cf | ||
|
|
28420c14e4 | ||
|
|
0bd221ff41 | ||
|
|
5fda6f8ef3 | ||
|
|
9b956f6338 | ||
|
|
09923f654c | ||
|
|
ae7b972649 | ||
|
|
47885e3710 | ||
|
|
4b9a260b37 | ||
|
|
2c743c8f0b | ||
|
|
9f2c278ee6 | ||
|
|
aea337cfe2 | ||
|
|
811f8f8b4f | ||
|
|
27734a23b1 | ||
|
|
1b8e538a77 | ||
|
|
41c2385aca | ||
|
|
d605985f45 | ||
|
|
d52b28b147 | ||
|
|
4afe1f42ca | ||
|
|
7481c0eaa0 | ||
|
|
ffdfad8482 | ||
|
|
6586f08584 | ||
|
|
f49e887fe6 | ||
|
|
a5b3ff11fd | ||
|
|
084558f200 | ||
|
|
b602eae215 | ||
|
|
d02bf9c243 | ||
|
|
26a5f67df2 | ||
|
|
600fd42a83 | ||
|
|
670685139a | ||
|
|
52b6306388 | ||
|
|
521ec6f1b8 | ||
|
|
b0c5d9640a | ||
|
|
ef8e94e992 | ||
|
|
9df96a4bb4 | ||
|
|
28a428ae2f | ||
|
|
b326ec3641 | ||
|
|
fcecbc7d46 | ||
|
|
f4007f53ba |
@@ -27,5 +27,8 @@ config.yaml
|
||||
bin/*
|
||||
.claude/*
|
||||
.vscode/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.bmad/*
|
||||
.agent/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
|
||||
23
.github/workflows/pr-test-build.yml
vendored
Normal file
23
.github/workflows/pr-test-build.yml
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
name: pr-test-build
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
build:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout
|
||||
uses: actions/checkout@v4
|
||||
- name: Set up Go
|
||||
uses: actions/setup-go@v5
|
||||
with:
|
||||
go-version-file: go.mod
|
||||
cache: true
|
||||
- name: Build
|
||||
run: |
|
||||
go build -o test-output ./cmd/server
|
||||
rm -f test-output
|
||||
3
.gitignore
vendored
3
.gitignore
vendored
@@ -30,8 +30,11 @@ GEMINI.md
|
||||
# Tooling metadata
|
||||
.vscode/*
|
||||
.claude/*
|
||||
.gemini/*
|
||||
.serena/*
|
||||
.agent/*
|
||||
.bmad/*
|
||||
_bmad/*
|
||||
|
||||
# macOS
|
||||
.DS_Store
|
||||
|
||||
@@ -405,7 +405,7 @@ func main() {
|
||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||
coreauth.SetQuotaCooldownDisabled(cfg.DisableCooling)
|
||||
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
if err = logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
||||
log.Errorf("failed to configure log output: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -42,12 +42,19 @@ debug: false
|
||||
# When true, write application logs to rotating files instead of stdout
|
||||
logging-to-file: false
|
||||
|
||||
# Maximum total size (MB) of log files under the logs directory. When exceeded, the oldest log
|
||||
# files are deleted until within the limit. Set to 0 to disable.
|
||||
logs-max-total-size-mb: 0
|
||||
|
||||
# When false, disable in-memory usage statistics aggregation
|
||||
usage-statistics-enabled: false
|
||||
|
||||
# Proxy URL. Supports socks5/http/https protocols. Example: socks5://user:pass@192.168.1.1:1080/
|
||||
proxy-url: ""
|
||||
|
||||
# When true, unprefixed model requests only use credentials without a prefix (except when prefix == model name).
|
||||
force-model-prefix: false
|
||||
|
||||
# Number of times to retry a request. Retries will occur if the HTTP response code is 403, 408, 500, 502, 503, or 504.
|
||||
request-retry: 3
|
||||
|
||||
@@ -65,6 +72,7 @@ ws-auth: false
|
||||
# Gemini API keys
|
||||
# gemini-api-key:
|
||||
# - api-key: "AIzaSy...01"
|
||||
# prefix: "test" # optional: require calls like "test/gemini-3-pro-preview" to target this credential
|
||||
# base-url: "https://generativelanguage.googleapis.com"
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -79,6 +87,7 @@ ws-auth: false
|
||||
# Codex API keys
|
||||
# codex-api-key:
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/gpt-5-codex" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom codex API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -93,6 +102,7 @@ ws-auth: false
|
||||
# claude-api-key:
|
||||
# - api-key: "sk-atSM..." # use the official claude API key, no need to set the base url
|
||||
# - api-key: "sk-atSM..."
|
||||
# prefix: "test" # optional: require calls like "test/claude-sonnet-latest" to target this credential
|
||||
# base-url: "https://www.example.com" # use the custom claude API endpoint
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -109,6 +119,7 @@ ws-auth: false
|
||||
# OpenAI compatibility providers
|
||||
# openai-compatibility:
|
||||
# - name: "openrouter" # The name of the provider; it will be used in the user agent and other places.
|
||||
# prefix: "test" # optional: require calls like "test/kimi-k2" to target this provider's credentials
|
||||
# base-url: "https://openrouter.ai/api/v1" # The base URL of the provider.
|
||||
# headers:
|
||||
# X-Custom-Header: "custom-value"
|
||||
@@ -123,6 +134,7 @@ ws-auth: false
|
||||
# Vertex API keys (Vertex-compatible endpoints, use API key + base URL)
|
||||
# vertex-api-key:
|
||||
# - api-key: "vk-123..." # x-goog-api-key header
|
||||
# prefix: "test" # optional: require calls like "test/vertex-pro" to target this credential
|
||||
# base-url: "https://example.com/api" # e.g. https://zenmux.ai/api
|
||||
# proxy-url: "socks5://proxy.example.com:1080" # optional per-key proxy override
|
||||
# headers:
|
||||
|
||||
@@ -23,13 +23,13 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
clipexec "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
|
||||
sdktr "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
)
|
||||
|
||||
|
||||
@@ -36,10 +36,6 @@ import (
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
var (
|
||||
oauthStatus = make(map[string]string)
|
||||
)
|
||||
|
||||
var lastRefreshKeys = []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}
|
||||
|
||||
const (
|
||||
@@ -786,6 +782,8 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "anthropic")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/anthropic/callback")
|
||||
@@ -812,7 +810,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
deadline := time.Now().Add(timeout)
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return nil, fmt.Errorf("timeout waiting for OAuth callback")
|
||||
}
|
||||
data, errRead := os.ReadFile(path)
|
||||
@@ -837,13 +835,13 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
if errStr := resultMap["error"]; errStr != "" {
|
||||
oauthErr := claude.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||
log.Error(claude.GetUserFriendlyMessage(oauthErr))
|
||||
oauthStatus[state] = "Bad request"
|
||||
SetOAuthSessionError(state, "Bad request")
|
||||
return
|
||||
}
|
||||
if resultMap["state"] != state {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, resultMap["state"]))
|
||||
log.Error(claude.GetUserFriendlyMessage(authErr))
|
||||
oauthStatus[state] = "State code error"
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -876,7 +874,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
if errDo != nil {
|
||||
authErr := claude.NewAuthenticationError(claude.ErrCodeExchangeFailed, errDo)
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -887,7 +885,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
oauthStatus[state] = fmt.Sprintf("token exchange failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("token exchange failed with status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
var tResp struct {
|
||||
@@ -900,7 +898,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tResp); errU != nil {
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
return
|
||||
}
|
||||
bundle := &claude.ClaudeAuthBundle{
|
||||
@@ -925,7 +923,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -934,10 +932,9 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use Claude services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -968,6 +965,8 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
state := fmt.Sprintf("gem-%d", time.Now().UnixNano())
|
||||
authURL := conf.AuthCodeURL(state, oauth2.AccessTypeOffline, oauth2.SetAuthURLParam("prompt", "consent"))
|
||||
|
||||
RegisterOAuthSession(state, "gemini")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/google/callback")
|
||||
@@ -996,7 +995,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
oauthStatus[state] = "OAuth flow timed out"
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
@@ -1005,13 +1004,13 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
authCode = m["code"]
|
||||
if authCode == "" {
|
||||
log.Errorf("Authentication failed: code not found")
|
||||
oauthStatus[state] = "Authentication failed: code not found"
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
@@ -1023,7 +1022,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
token, err := conf.Exchange(ctx, authCode)
|
||||
if err != nil {
|
||||
log.Errorf("Failed to exchange token: %v", err)
|
||||
oauthStatus[state] = "Failed to exchange token"
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1034,7 +1033,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, "GET", "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if errNewRequest != nil {
|
||||
log.Errorf("Could not get user info: %v", errNewRequest)
|
||||
oauthStatus[state] = "Could not get user info"
|
||||
SetOAuthSessionError(state, "Could not get user info")
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
@@ -1043,7 +1042,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
resp, errDo := authHTTPClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.Errorf("Failed to execute request: %v", errDo)
|
||||
oauthStatus[state] = "Failed to execute request"
|
||||
SetOAuthSessionError(state, "Failed to execute request")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1055,7 +1054,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
||||
log.Errorf("Get user info request failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Get user info request failed with status %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1064,7 +1063,6 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
fmt.Printf("Authenticated user email: %s\n", email)
|
||||
} else {
|
||||
fmt.Println("Failed to get user email from token")
|
||||
oauthStatus[state] = "Failed to get user email from token"
|
||||
}
|
||||
|
||||
// Marshal/unmarshal oauth2.Token to generic map and enrich fields
|
||||
@@ -1072,7 +1070,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
jsonData, _ := json.Marshal(token)
|
||||
if errUnmarshal := json.Unmarshal(jsonData, &ifToken); errUnmarshal != nil {
|
||||
log.Errorf("Failed to unmarshal token: %v", errUnmarshal)
|
||||
oauthStatus[state] = "Failed to unmarshal token"
|
||||
SetOAuthSessionError(state, "Failed to unmarshal token")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1098,7 +1096,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
gemClient, errGetClient := gemAuth.GetAuthenticatedClient(ctx, &ts, h.cfg, true)
|
||||
if errGetClient != nil {
|
||||
log.Errorf("failed to get authenticated client: %v", errGetClient)
|
||||
oauthStatus[state] = "Failed to get authenticated client"
|
||||
SetOAuthSessionError(state, "Failed to get authenticated client")
|
||||
return
|
||||
}
|
||||
fmt.Println("Authentication successful.")
|
||||
@@ -1108,12 +1106,12 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||
if errAll != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
return
|
||||
}
|
||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
@@ -1121,26 +1119,26 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
SetOAuthSessionError(state, "Failed to complete Gemini CLI onboarding")
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Onboarding did not return a project ID")
|
||||
oauthStatus[state] = "Failed to resolve project ID"
|
||||
SetOAuthSessionError(state, "Failed to resolve project ID")
|
||||
return
|
||||
}
|
||||
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
SetOAuthSessionError(state, "Failed to verify Cloud AI API status")
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the selected project")
|
||||
oauthStatus[state] = "Cloud AI API not enabled"
|
||||
SetOAuthSessionError(state, "Cloud AI API not enabled")
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1163,15 +1161,14 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save token to file: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save token to file"
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
fmt.Printf("You can now use Gemini CLI services through this CLI; token saved to %s\n", savedPath)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1207,6 +1204,8 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
RegisterOAuthSession(state, "codex")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/codex/callback")
|
||||
@@ -1235,7 +1234,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
if time.Now().After(deadline) {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCallbackTimeout, fmt.Errorf("timeout waiting for OAuth callback"))
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
oauthStatus[state] = "Timeout waiting for OAuth callback"
|
||||
SetOAuthSessionError(state, "Timeout waiting for OAuth callback")
|
||||
return
|
||||
}
|
||||
if data, errR := os.ReadFile(waitFile); errR == nil {
|
||||
@@ -1245,12 +1244,12 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
if errStr := m["error"]; errStr != "" {
|
||||
oauthErr := codex.NewOAuthError(errStr, "", http.StatusBadRequest)
|
||||
log.Error(codex.GetUserFriendlyMessage(oauthErr))
|
||||
oauthStatus[state] = "Bad Request"
|
||||
SetOAuthSessionError(state, "Bad Request")
|
||||
return
|
||||
}
|
||||
if m["state"] != state {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrInvalidState, fmt.Errorf("expected %s, got %s", state, m["state"]))
|
||||
oauthStatus[state] = "State code error"
|
||||
SetOAuthSessionError(state, "State code error")
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
return
|
||||
}
|
||||
@@ -1281,14 +1280,14 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
authErr := codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, errDo)
|
||||
oauthStatus[state] = "Failed to exchange authorization code for tokens"
|
||||
SetOAuthSessionError(state, "Failed to exchange authorization code for tokens")
|
||||
log.Errorf("Failed to exchange authorization code for tokens: %v", authErr)
|
||||
return
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
respBody, _ := io.ReadAll(resp.Body)
|
||||
if resp.StatusCode != http.StatusOK {
|
||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed with status %d", resp.StatusCode))
|
||||
log.Errorf("token exchange failed with status %d: %s", resp.StatusCode, string(respBody))
|
||||
return
|
||||
}
|
||||
@@ -1299,7 +1298,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
ExpiresIn int `json:"expires_in"`
|
||||
}
|
||||
if errU := json.Unmarshal(respBody, &tokenResp); errU != nil {
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
log.Errorf("failed to parse token response: %v", errU)
|
||||
return
|
||||
}
|
||||
@@ -1337,7 +1336,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
}
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
return
|
||||
}
|
||||
@@ -1346,10 +1345,9 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use Codex services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1390,6 +1388,8 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
params.Set("state", state)
|
||||
authURL := "https://accounts.google.com/o/oauth2/v2/auth?" + params.Encode()
|
||||
|
||||
RegisterOAuthSession(state, "antigravity")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/antigravity/callback")
|
||||
@@ -1416,7 +1416,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
log.Error("oauth flow timed out")
|
||||
oauthStatus[state] = "OAuth flow timed out"
|
||||
SetOAuthSessionError(state, "OAuth flow timed out")
|
||||
return
|
||||
}
|
||||
if data, errReadFile := os.ReadFile(waitFile); errReadFile == nil {
|
||||
@@ -1425,18 +1425,18 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
_ = os.Remove(waitFile)
|
||||
if errStr := strings.TrimSpace(payload["error"]); errStr != "" {
|
||||
log.Errorf("Authentication failed: %s", errStr)
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
return
|
||||
}
|
||||
if payloadState := strings.TrimSpace(payload["state"]); payloadState != "" && payloadState != state {
|
||||
log.Errorf("Authentication failed: state mismatch")
|
||||
oauthStatus[state] = "Authentication failed: state mismatch"
|
||||
SetOAuthSessionError(state, "Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
authCode = strings.TrimSpace(payload["code"])
|
||||
if authCode == "" {
|
||||
log.Error("Authentication failed: code not found")
|
||||
oauthStatus[state] = "Authentication failed: code not found"
|
||||
SetOAuthSessionError(state, "Authentication failed: code not found")
|
||||
return
|
||||
}
|
||||
break
|
||||
@@ -1455,7 +1455,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
req, errNewRequest := http.NewRequestWithContext(ctx, http.MethodPost, "https://oauth2.googleapis.com/token", strings.NewReader(form.Encode()))
|
||||
if errNewRequest != nil {
|
||||
log.Errorf("Failed to build token request: %v", errNewRequest)
|
||||
oauthStatus[state] = "Failed to build token request"
|
||||
SetOAuthSessionError(state, "Failed to build token request")
|
||||
return
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/x-www-form-urlencoded")
|
||||
@@ -1463,7 +1463,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
resp, errDo := httpClient.Do(req)
|
||||
if errDo != nil {
|
||||
log.Errorf("Failed to execute token request: %v", errDo)
|
||||
oauthStatus[state] = "Failed to exchange token"
|
||||
SetOAuthSessionError(state, "Failed to exchange token")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1475,7 +1475,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
if resp.StatusCode < http.StatusOK || resp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, _ := io.ReadAll(resp.Body)
|
||||
log.Errorf("Antigravity token exchange failed with status %d: %s", resp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("Token exchange failed: %d", resp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("Token exchange failed: %d", resp.StatusCode))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1487,7 +1487,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
}
|
||||
if errDecode := json.NewDecoder(resp.Body).Decode(&tokenResp); errDecode != nil {
|
||||
log.Errorf("Failed to parse token response: %v", errDecode)
|
||||
oauthStatus[state] = "Failed to parse token response"
|
||||
SetOAuthSessionError(state, "Failed to parse token response")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1496,7 +1496,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
infoReq, errInfoReq := http.NewRequestWithContext(ctx, http.MethodGet, "https://www.googleapis.com/oauth2/v1/userinfo?alt=json", nil)
|
||||
if errInfoReq != nil {
|
||||
log.Errorf("Failed to build user info request: %v", errInfoReq)
|
||||
oauthStatus[state] = "Failed to build user info request"
|
||||
SetOAuthSessionError(state, "Failed to build user info request")
|
||||
return
|
||||
}
|
||||
infoReq.Header.Set("Authorization", "Bearer "+tokenResp.AccessToken)
|
||||
@@ -1504,7 +1504,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
infoResp, errInfo := httpClient.Do(infoReq)
|
||||
if errInfo != nil {
|
||||
log.Errorf("Failed to execute user info request: %v", errInfo)
|
||||
oauthStatus[state] = "Failed to execute user info request"
|
||||
SetOAuthSessionError(state, "Failed to execute user info request")
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
@@ -1523,7 +1523,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
} else {
|
||||
bodyBytes, _ := io.ReadAll(infoResp.Body)
|
||||
log.Errorf("User info request failed with status %d: %s", infoResp.StatusCode, string(bodyBytes))
|
||||
oauthStatus[state] = fmt.Sprintf("User info request failed: %d", infoResp.StatusCode)
|
||||
SetOAuthSessionError(state, fmt.Sprintf("User info request failed: %d", infoResp.StatusCode))
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -1571,11 +1571,11 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save token to file: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save token to file"
|
||||
SetOAuthSessionError(state, "Failed to save token to file")
|
||||
return
|
||||
}
|
||||
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
if projectID != "" {
|
||||
fmt.Printf("Using GCP project: %s\n", projectID)
|
||||
@@ -1583,7 +1583,6 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
fmt.Println("You can now use Antigravity services through this CLI")
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1605,11 +1604,13 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
}
|
||||
authURL := deviceFlow.VerificationURIComplete
|
||||
|
||||
RegisterOAuthSession(state, "qwen")
|
||||
|
||||
go func() {
|
||||
fmt.Println("Waiting for authentication...")
|
||||
tokenData, errPollForToken := qwenAuth.PollForToken(deviceFlow.DeviceCode, deviceFlow.CodeVerifier)
|
||||
if errPollForToken != nil {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errPollForToken)
|
||||
return
|
||||
}
|
||||
@@ -1628,16 +1629,15 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
return
|
||||
}
|
||||
|
||||
fmt.Printf("Authentication successful! Token saved to %s\n", savedPath)
|
||||
fmt.Println("You can now use Qwen services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(200, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -1650,6 +1650,8 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
authSvc := iflowauth.NewIFlowAuth(h.cfg)
|
||||
authURL, redirectURI := authSvc.AuthorizationURL(state, iflowauth.CallbackPort)
|
||||
|
||||
RegisterOAuthSession(state, "iflow")
|
||||
|
||||
isWebUI := isWebUIRequest(c)
|
||||
if isWebUI {
|
||||
targetURL, errTarget := h.managementCallbackURL("/iflow/callback")
|
||||
@@ -1676,7 +1678,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
var resultMap map[string]string
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: timeout waiting for callback")
|
||||
return
|
||||
}
|
||||
@@ -1689,26 +1691,26 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
}
|
||||
|
||||
if errStr := strings.TrimSpace(resultMap["error"]); errStr != "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %s\n", errStr)
|
||||
return
|
||||
}
|
||||
if resultState := strings.TrimSpace(resultMap["state"]); resultState != state {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: state mismatch")
|
||||
return
|
||||
}
|
||||
|
||||
code := strings.TrimSpace(resultMap["code"])
|
||||
if code == "" {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Println("Authentication failed: code missing")
|
||||
return
|
||||
}
|
||||
|
||||
tokenData, errExchange := authSvc.ExchangeCodeForTokens(ctx, code, redirectURI)
|
||||
if errExchange != nil {
|
||||
oauthStatus[state] = "Authentication failed"
|
||||
SetOAuthSessionError(state, "Authentication failed")
|
||||
fmt.Printf("Authentication failed: %v\n", errExchange)
|
||||
return
|
||||
}
|
||||
@@ -1730,7 +1732,7 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
|
||||
savedPath, errSave := h.saveTokenRecord(ctx, record)
|
||||
if errSave != nil {
|
||||
oauthStatus[state] = "Failed to save authentication tokens"
|
||||
SetOAuthSessionError(state, "Failed to save authentication tokens")
|
||||
log.Errorf("Failed to save authentication tokens: %v", errSave)
|
||||
return
|
||||
}
|
||||
@@ -1740,10 +1742,9 @@ func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
fmt.Println("API key obtained and saved")
|
||||
}
|
||||
fmt.Println("You can now use iFlow services through this CLI")
|
||||
delete(oauthStatus, state)
|
||||
CompleteOAuthSession(state)
|
||||
}()
|
||||
|
||||
oauthStatus[state] = ""
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok", "url": authURL, "state": state})
|
||||
}
|
||||
|
||||
@@ -2179,16 +2180,24 @@ func checkCloudAPIIsEnabled(ctx context.Context, httpClient *http.Client, projec
|
||||
}
|
||||
|
||||
func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
state := c.Query("state")
|
||||
if err, ok := oauthStatus[state]; ok {
|
||||
if err != "" {
|
||||
c.JSON(200, gin.H{"status": "error", "error": err})
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "wait"})
|
||||
return
|
||||
}
|
||||
} else {
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
state := strings.TrimSpace(c.Query("state"))
|
||||
if state == "" {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
delete(oauthStatus, state)
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
|
||||
_, status, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
return
|
||||
}
|
||||
if status != "" {
|
||||
c.JSON(http.StatusOK, gin.H{"status": "error", "error": status})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||
}
|
||||
|
||||
100
internal/api/handlers/management/oauth_callback.go
Normal file
100
internal/api/handlers/management/oauth_callback.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
type oauthCallbackRequest struct {
|
||||
Provider string `json:"provider"`
|
||||
RedirectURL string `json:"redirect_url"`
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func (h *Handler) PostOAuthCallback(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "handler not initialized"})
|
||||
return
|
||||
}
|
||||
|
||||
var req oauthCallbackRequest
|
||||
if err := c.ShouldBindJSON(&req); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid body"})
|
||||
return
|
||||
}
|
||||
|
||||
canonicalProvider, err := NormalizeOAuthProvider(req.Provider)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "unsupported provider"})
|
||||
return
|
||||
}
|
||||
|
||||
state := strings.TrimSpace(req.State)
|
||||
code := strings.TrimSpace(req.Code)
|
||||
errMsg := strings.TrimSpace(req.Error)
|
||||
|
||||
if rawRedirect := strings.TrimSpace(req.RedirectURL); rawRedirect != "" {
|
||||
u, errParse := url.Parse(rawRedirect)
|
||||
if errParse != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid redirect_url"})
|
||||
return
|
||||
}
|
||||
q := u.Query()
|
||||
if state == "" {
|
||||
state = strings.TrimSpace(q.Get("state"))
|
||||
}
|
||||
if code == "" {
|
||||
code = strings.TrimSpace(q.Get("code"))
|
||||
}
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error"))
|
||||
if errMsg == "" {
|
||||
errMsg = strings.TrimSpace(q.Get("error_description"))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if state == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "state is required"})
|
||||
return
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "invalid state"})
|
||||
return
|
||||
}
|
||||
if code == "" && errMsg == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "code or error is required"})
|
||||
return
|
||||
}
|
||||
|
||||
sessionProvider, sessionStatus, ok := GetOAuthSession(state)
|
||||
if !ok {
|
||||
c.JSON(http.StatusNotFound, gin.H{"status": "error", "error": "unknown or expired state"})
|
||||
return
|
||||
}
|
||||
if sessionStatus != "" {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(sessionProvider, canonicalProvider) {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"status": "error", "error": "provider does not match state"})
|
||||
return
|
||||
}
|
||||
|
||||
if _, errWrite := WriteOAuthCallbackFileForPendingSession(h.cfg.AuthDir, canonicalProvider, state, code, errMsg); errWrite != nil {
|
||||
if errors.Is(errWrite, errOAuthSessionNotPending) {
|
||||
c.JSON(http.StatusConflict, gin.H{"status": "error", "error": "oauth flow is not pending"})
|
||||
return
|
||||
}
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"status": "error", "error": "failed to persist oauth callback"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{"status": "ok"})
|
||||
}
|
||||
258
internal/api/handlers/management/oauth_sessions.go
Normal file
258
internal/api/handlers/management/oauth_sessions.go
Normal file
@@ -0,0 +1,258 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
const (
|
||||
oauthSessionTTL = 10 * time.Minute
|
||||
maxOAuthStateLength = 128
|
||||
)
|
||||
|
||||
var (
|
||||
errInvalidOAuthState = errors.New("invalid oauth state")
|
||||
errUnsupportedOAuthFlow = errors.New("unsupported oauth provider")
|
||||
errOAuthSessionNotPending = errors.New("oauth session is not pending")
|
||||
)
|
||||
|
||||
type oauthSession struct {
|
||||
Provider string
|
||||
Status string
|
||||
CreatedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
}
|
||||
|
||||
type oauthSessionStore struct {
|
||||
mu sync.RWMutex
|
||||
ttl time.Duration
|
||||
sessions map[string]oauthSession
|
||||
}
|
||||
|
||||
func newOAuthSessionStore(ttl time.Duration) *oauthSessionStore {
|
||||
if ttl <= 0 {
|
||||
ttl = oauthSessionTTL
|
||||
}
|
||||
return &oauthSessionStore{
|
||||
ttl: ttl,
|
||||
sessions: make(map[string]oauthSession),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) purgeExpiredLocked(now time.Time) {
|
||||
for state, session := range s.sessions {
|
||||
if !session.ExpiresAt.IsZero() && now.After(session.ExpiresAt) {
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Register(state, provider string) {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if state == "" || provider == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
s.sessions[state] = oauthSession{
|
||||
Provider: provider,
|
||||
Status: "",
|
||||
CreatedAt: now,
|
||||
ExpiresAt: now.Add(s.ttl),
|
||||
}
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) SetError(state, message string) {
|
||||
state = strings.TrimSpace(state)
|
||||
message = strings.TrimSpace(message)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
if message == "" {
|
||||
message = "Authentication failed"
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
session.Status = message
|
||||
session.ExpiresAt = now.Add(s.ttl)
|
||||
s.sessions[state] = session
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Complete(state string) {
|
||||
state = strings.TrimSpace(state)
|
||||
if state == "" {
|
||||
return
|
||||
}
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
delete(s.sessions, state)
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) Get(state string) (oauthSession, bool) {
|
||||
state = strings.TrimSpace(state)
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
return session, ok
|
||||
}
|
||||
|
||||
func (s *oauthSessionStore) IsPending(state, provider string) bool {
|
||||
state = strings.TrimSpace(state)
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
now := time.Now()
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
s.purgeExpiredLocked(now)
|
||||
session, ok := s.sessions[state]
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
if session.Status != "" {
|
||||
return false
|
||||
}
|
||||
if provider == "" {
|
||||
return true
|
||||
}
|
||||
return strings.EqualFold(session.Provider, provider)
|
||||
}
|
||||
|
||||
var oauthSessions = newOAuthSessionStore(oauthSessionTTL)
|
||||
|
||||
func RegisterOAuthSession(state, provider string) { oauthSessions.Register(state, provider) }
|
||||
|
||||
func SetOAuthSessionError(state, message string) { oauthSessions.SetError(state, message) }
|
||||
|
||||
func CompleteOAuthSession(state string) { oauthSessions.Complete(state) }
|
||||
|
||||
func GetOAuthSession(state string) (provider string, status string, ok bool) {
|
||||
session, ok := oauthSessions.Get(state)
|
||||
if !ok {
|
||||
return "", "", false
|
||||
}
|
||||
return session.Provider, session.Status, true
|
||||
}
|
||||
|
||||
func IsOAuthSessionPending(state, provider string) bool {
|
||||
return oauthSessions.IsPending(state, provider)
|
||||
}
|
||||
|
||||
func ValidateOAuthState(state string) error {
|
||||
trimmed := strings.TrimSpace(state)
|
||||
if trimmed == "" {
|
||||
return fmt.Errorf("%w: empty", errInvalidOAuthState)
|
||||
}
|
||||
if len(trimmed) > maxOAuthStateLength {
|
||||
return fmt.Errorf("%w: too long", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "/") || strings.Contains(trimmed, "\\") {
|
||||
return fmt.Errorf("%w: contains path separator", errInvalidOAuthState)
|
||||
}
|
||||
if strings.Contains(trimmed, "..") {
|
||||
return fmt.Errorf("%w: contains '..'", errInvalidOAuthState)
|
||||
}
|
||||
for _, r := range trimmed {
|
||||
switch {
|
||||
case r >= 'a' && r <= 'z':
|
||||
case r >= 'A' && r <= 'Z':
|
||||
case r >= '0' && r <= '9':
|
||||
case r == '-' || r == '_' || r == '.':
|
||||
default:
|
||||
return fmt.Errorf("%w: invalid character", errInvalidOAuthState)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NormalizeOAuthProvider(provider string) (string, error) {
|
||||
switch strings.ToLower(strings.TrimSpace(provider)) {
|
||||
case "anthropic", "claude":
|
||||
return "anthropic", nil
|
||||
case "codex", "openai":
|
||||
return "codex", nil
|
||||
case "gemini", "google":
|
||||
return "gemini", nil
|
||||
case "iflow", "i-flow":
|
||||
return "iflow", nil
|
||||
case "antigravity", "anti-gravity":
|
||||
return "antigravity", nil
|
||||
case "qwen":
|
||||
return "qwen", nil
|
||||
default:
|
||||
return "", errUnsupportedOAuthFlow
|
||||
}
|
||||
}
|
||||
|
||||
type oauthCallbackFilePayload struct {
|
||||
Code string `json:"code"`
|
||||
State string `json:"state"`
|
||||
Error string `json:"error"`
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFile(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
if strings.TrimSpace(authDir) == "" {
|
||||
return "", fmt.Errorf("auth dir is empty")
|
||||
}
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if err := ValidateOAuthState(state); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf(".oauth-%s-%s.oauth", canonicalProvider, state)
|
||||
filePath := filepath.Join(authDir, fileName)
|
||||
payload := oauthCallbackFilePayload{
|
||||
Code: strings.TrimSpace(code),
|
||||
State: strings.TrimSpace(state),
|
||||
Error: strings.TrimSpace(errorMessage),
|
||||
}
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("marshal oauth callback payload: %w", err)
|
||||
}
|
||||
if err := os.WriteFile(filePath, data, 0o600); err != nil {
|
||||
return "", fmt.Errorf("write oauth callback file: %w", err)
|
||||
}
|
||||
return filePath, nil
|
||||
}
|
||||
|
||||
func WriteOAuthCallbackFileForPendingSession(authDir, provider, state, code, errorMessage string) (string, error) {
|
||||
canonicalProvider, err := NormalizeOAuthProvider(provider)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
if !IsOAuthSessionPending(state, canonicalProvider) {
|
||||
return "", errOAuthSessionNotPending
|
||||
}
|
||||
return WriteOAuthCallbackFile(authDir, canonicalProvider, state, code, errorMessage)
|
||||
}
|
||||
@@ -146,6 +146,9 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
m := &AmpModule{enabled: true}
|
||||
ms := NewMultiSourceSecretWithPath("", p, time.Minute)
|
||||
m.secretSource = ms
|
||||
m.lastConfig = &config.AmpCode{
|
||||
UpstreamAPIKey: "old-key",
|
||||
}
|
||||
|
||||
// Warm the cache
|
||||
if _, err := ms.Get(context.Background()); err != nil {
|
||||
@@ -157,7 +160,7 @@ func TestAmpModule_OnConfigUpdated_CacheInvalidation(t *testing.T) {
|
||||
}
|
||||
|
||||
// Update config - should invalidate cache
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x"}}); err != nil {
|
||||
if err := m.OnConfigUpdated(&config.Config{AmpCode: config.AmpCode{UpstreamURL: "http://x", UpstreamAPIKey: "new-key"}}); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
|
||||
@@ -95,6 +95,20 @@ func (m *AmpModule) managementAvailabilityMiddleware() gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
// wrapManagementAuth skips auth for selected management paths while keeping authentication elsewhere.
|
||||
func wrapManagementAuth(auth gin.HandlerFunc, prefixes ...string) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
path := c.Request.URL.Path
|
||||
for _, prefix := range prefixes {
|
||||
if strings.HasPrefix(path, prefix) && (len(path) == len(prefix) || path[len(prefix)] == '/') {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
}
|
||||
auth(c)
|
||||
}
|
||||
}
|
||||
|
||||
// registerManagementRoutes registers Amp management proxy routes
|
||||
// These routes proxy through to the Amp control plane for OAuth, user management, etc.
|
||||
// Uses dynamic middleware and proxy getter for hot-reload support.
|
||||
@@ -109,8 +123,10 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
ampAPI.Use(m.localhostOnlyMiddleware())
|
||||
|
||||
// Apply authentication middleware - requires valid API key in Authorization header
|
||||
var authWithBypass gin.HandlerFunc
|
||||
if auth != nil {
|
||||
ampAPI.Use(auth)
|
||||
authWithBypass = wrapManagementAuth(auth, "/threads", "/auth", "/docs")
|
||||
}
|
||||
|
||||
// Dynamic proxy handler that uses m.getProxy() for hot-reload support
|
||||
@@ -156,10 +172,14 @@ func (m *AmpModule) registerManagementRoutes(engine *gin.Engine, baseHandler *ha
|
||||
// Root-level routes that AMP CLI expects without /api prefix
|
||||
// These need the same security middleware as the /api/* routes (dynamic for hot-reload)
|
||||
rootMiddleware := []gin.HandlerFunc{m.managementAvailabilityMiddleware(), noCORSMiddleware(), m.localhostOnlyMiddleware()}
|
||||
if auth != nil {
|
||||
rootMiddleware = append(rootMiddleware, auth)
|
||||
if authWithBypass != nil {
|
||||
rootMiddleware = append(rootMiddleware, authWithBypass)
|
||||
}
|
||||
engine.GET("/threads", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/threads/*path", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/docs/*path", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
engine.GET("/threads.rss", append(rootMiddleware, proxyHandler)...)
|
||||
engine.GET("/news.rss", append(rootMiddleware, proxyHandler)...)
|
||||
|
||||
@@ -267,7 +287,7 @@ func (m *AmpModule) registerProviderAliases(engine *gin.Engine, baseHandler *han
|
||||
v1betaAmp := provider.Group("/v1beta")
|
||||
{
|
||||
v1betaAmp.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1betaAmp.POST("/models/:action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1betaAmp.POST("/models/*action", fallbackHandler.WrapHandler(geminiHandlers.GeminiHandler))
|
||||
v1betaAmp.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -32,7 +32,9 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
m.setProxy(proxy)
|
||||
|
||||
base := &handlers.BaseAPIHandler{}
|
||||
m.registerManagementRoutes(r, base)
|
||||
m.registerManagementRoutes(r, base, nil)
|
||||
srv := httptest.NewServer(r)
|
||||
defer srv.Close()
|
||||
|
||||
managementPaths := []struct {
|
||||
path string
|
||||
@@ -63,11 +65,17 @@ func TestRegisterManagementRoutes(t *testing.T) {
|
||||
for _, path := range managementPaths {
|
||||
t.Run(path.path, func(t *testing.T) {
|
||||
proxyCalled = false
|
||||
req := httptest.NewRequest(path.method, path.path, nil)
|
||||
w := httptest.NewRecorder()
|
||||
r.ServeHTTP(w, req)
|
||||
req, err := http.NewRequest(path.method, srv.URL+path.path, nil)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to build request: %v", err)
|
||||
}
|
||||
resp, err := http.DefaultClient.Do(req)
|
||||
if err != nil {
|
||||
t.Fatalf("request failed: %v", err)
|
||||
}
|
||||
defer resp.Body.Close()
|
||||
|
||||
if w.Code == http.StatusNotFound {
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
t.Fatalf("route %s not registered", path.path)
|
||||
}
|
||||
if !proxyCalled {
|
||||
|
||||
@@ -230,13 +230,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
envManagementSecret := envAdminPasswordSet && envAdminPassword != ""
|
||||
|
||||
// Create server instance
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s := &Server{
|
||||
engine: engine,
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager, providerNames),
|
||||
handlers: handlers.NewBaseAPIHandlers(&cfg.SDKConfig, authManager),
|
||||
cfg: cfg,
|
||||
accessManager: accessManager,
|
||||
requestLogger: requestLogger,
|
||||
@@ -334,8 +330,8 @@ func (s *Server) setupRoutes() {
|
||||
v1beta.Use(AuthMiddleware(s.accessManager))
|
||||
{
|
||||
v1beta.GET("/models", geminiHandlers.GeminiModels)
|
||||
v1beta.POST("/models/:action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/:action", geminiHandlers.GeminiGetHandler)
|
||||
v1beta.POST("/models/*action", geminiHandlers.GeminiHandler)
|
||||
v1beta.GET("/models/*action", geminiHandlers.GeminiGetHandler)
|
||||
}
|
||||
|
||||
// Root endpoint
|
||||
@@ -358,10 +354,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
// Persist to a temporary file keyed by state
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-anthropic-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "anthropic", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -371,9 +368,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-codex-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "codex", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -383,9 +382,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-gemini-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "gemini", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -395,9 +396,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-iflow-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "iflow", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -407,9 +410,11 @@ func (s *Server) setupRoutes() {
|
||||
code := c.Query("code")
|
||||
state := c.Query("state")
|
||||
errStr := c.Query("error")
|
||||
if errStr == "" {
|
||||
errStr = c.Query("error_description")
|
||||
}
|
||||
if state != "" {
|
||||
file := fmt.Sprintf("%s/.oauth-antigravity-%s.oauth", s.cfg.AuthDir, state)
|
||||
_ = os.WriteFile(file, []byte(fmt.Sprintf(`{"code":"%s","state":"%s","error":"%s"}`, code, state, errStr)), 0o600)
|
||||
_, _ = managementHandlers.WriteOAuthCallbackFileForPendingSession(s.cfg.AuthDir, "antigravity", state, code, errStr)
|
||||
}
|
||||
c.Header("Content-Type", "text/html; charset=utf-8")
|
||||
c.String(http.StatusOK, oauthCallbackSuccessHTML)
|
||||
@@ -581,6 +586,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/qwen-auth-url", s.mgmt.RequestQwenToken)
|
||||
mgmt.GET("/iflow-auth-url", s.mgmt.RequestIFlowToken)
|
||||
mgmt.POST("/iflow-auth-url", s.mgmt.RequestIFlowCookieToken)
|
||||
mgmt.POST("/oauth-callback", s.mgmt.PostOAuthCallback)
|
||||
mgmt.GET("/get-auth-status", s.mgmt.GetAuthStatus)
|
||||
}
|
||||
}
|
||||
@@ -838,11 +844,20 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
}
|
||||
}
|
||||
|
||||
if oldCfg != nil && oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||
if oldCfg == nil || oldCfg.LoggingToFile != cfg.LoggingToFile || oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile, cfg.LogsMaxTotalSizeMB); err != nil {
|
||||
log.Errorf("failed to reconfigure log output: %v", err)
|
||||
} else {
|
||||
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||
if oldCfg == nil {
|
||||
log.Debug("log output configuration refreshed")
|
||||
} else {
|
||||
if oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||
}
|
||||
if oldCfg.LogsMaxTotalSizeMB != cfg.LogsMaxTotalSizeMB {
|
||||
log.Debugf("logs_max_total_size_mb updated from %d to %d", oldCfg.LogsMaxTotalSizeMB, cfg.LogsMaxTotalSizeMB)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -919,12 +934,6 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
||||
// Save YAML snapshot for next comparison
|
||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||
|
||||
providerNames := make([]string, 0, len(cfg.OpenAICompatibility))
|
||||
for _, p := range cfg.OpenAICompatibility {
|
||||
providerNames = append(providerNames, p.Name)
|
||||
}
|
||||
s.handlers.OpenAICompatProviders = providerNames
|
||||
|
||||
s.handlers.UpdateClients(&cfg.SDKConfig)
|
||||
|
||||
if !cfg.RemoteManagement.DisableControlPanel {
|
||||
|
||||
@@ -12,7 +12,6 @@ import (
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
@@ -21,7 +20,7 @@ const DefaultPanelGitHubRepository = "https://github.com/router-for-me/Cli-Proxy
|
||||
|
||||
// Config represents the application's configuration, loaded from a YAML file.
|
||||
type Config struct {
|
||||
config.SDKConfig `yaml:",inline"`
|
||||
SDKConfig `yaml:",inline"`
|
||||
// Host is the network host/interface on which the API server will bind.
|
||||
// Default is empty ("") to bind all interfaces (IPv4 + IPv6). Use "127.0.0.1" or "localhost" for local-only access.
|
||||
Host string `yaml:"host" json:"-"`
|
||||
@@ -43,6 +42,10 @@ type Config struct {
|
||||
// LoggingToFile controls whether application logs are written to rotating files or stdout.
|
||||
LoggingToFile bool `yaml:"logging-to-file" json:"logging-to-file"`
|
||||
|
||||
// LogsMaxTotalSizeMB limits the total size (in MB) of log files under the logs directory.
|
||||
// When exceeded, the oldest log files are deleted until within the limit. Set to 0 to disable.
|
||||
LogsMaxTotalSizeMB int `yaml:"logs-max-total-size-mb" json:"logs-max-total-size-mb"`
|
||||
|
||||
// UsageStatisticsEnabled toggles in-memory usage aggregation; when false, usage data is discarded.
|
||||
UsageStatisticsEnabled bool `yaml:"usage-statistics-enabled" json:"usage-statistics-enabled"`
|
||||
|
||||
@@ -187,6 +190,9 @@ type ClaudeKey struct {
|
||||
// APIKey is the authentication key for accessing Claude API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Claude API endpoint.
|
||||
// If empty, the default Claude API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -219,6 +225,9 @@ type CodexKey struct {
|
||||
// APIKey is the authentication key for accessing Codex API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Codex API endpoint.
|
||||
// If empty, the default Codex API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
@@ -239,6 +248,9 @@ type GeminiKey struct {
|
||||
// APIKey is the authentication key for accessing Gemini API services.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL optionally overrides the Gemini API endpoint.
|
||||
BaseURL string `yaml:"base-url,omitempty" json:"base-url,omitempty"`
|
||||
|
||||
@@ -258,6 +270,9 @@ type OpenAICompatibility struct {
|
||||
// Name is the identifier for this OpenAI compatibility configuration.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the external OpenAI-compatible API endpoint.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
|
||||
@@ -330,6 +345,7 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
// Set defaults before unmarshal so that absent keys keep defaults.
|
||||
cfg.Host = "" // Default empty: binds to all interfaces (IPv4 + IPv6)
|
||||
cfg.LoggingToFile = false
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
cfg.UsageStatisticsEnabled = false
|
||||
cfg.DisableCooling = false
|
||||
cfg.AmpCode.RestrictManagementToLocalhost = false // Default to false: API key auth is sufficient
|
||||
@@ -374,6 +390,10 @@ func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
cfg.RemoteManagement.PanelGitHubRepository = DefaultPanelGitHubRepository
|
||||
}
|
||||
|
||||
if cfg.LogsMaxTotalSizeMB < 0 {
|
||||
cfg.LogsMaxTotalSizeMB = 0
|
||||
}
|
||||
|
||||
// Sync request authentication providers with inline API keys for backwards compatibility.
|
||||
syncInlineAccessProvider(&cfg)
|
||||
|
||||
@@ -422,6 +442,7 @@ func (cfg *Config) SanitizeOpenAICompatibility() {
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
e := cfg.OpenAICompatibility[i]
|
||||
e.Name = strings.TrimSpace(e.Name)
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
if e.BaseURL == "" {
|
||||
@@ -442,6 +463,7 @@ func (cfg *Config) SanitizeCodexKeys() {
|
||||
out := make([]CodexKey, 0, len(cfg.CodexKey))
|
||||
for i := range cfg.CodexKey {
|
||||
e := cfg.CodexKey[i]
|
||||
e.Prefix = normalizeModelPrefix(e.Prefix)
|
||||
e.BaseURL = strings.TrimSpace(e.BaseURL)
|
||||
e.Headers = NormalizeHeaders(e.Headers)
|
||||
e.ExcludedModels = NormalizeExcludedModels(e.ExcludedModels)
|
||||
@@ -460,6 +482,7 @@ func (cfg *Config) SanitizeClaudeKeys() {
|
||||
}
|
||||
for i := range cfg.ClaudeKey {
|
||||
entry := &cfg.ClaudeKey[i]
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
entry.ExcludedModels = NormalizeExcludedModels(entry.ExcludedModels)
|
||||
}
|
||||
@@ -479,6 +502,7 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
entry.ProxyURL = strings.TrimSpace(entry.ProxyURL)
|
||||
entry.Headers = NormalizeHeaders(entry.Headers)
|
||||
@@ -492,6 +516,18 @@ func (cfg *Config) SanitizeGeminiKeys() {
|
||||
cfg.GeminiKey = out
|
||||
}
|
||||
|
||||
func normalizeModelPrefix(prefix string) string {
|
||||
trimmed := strings.TrimSpace(prefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
if strings.Contains(trimmed, "/") {
|
||||
return ""
|
||||
}
|
||||
return trimmed
|
||||
}
|
||||
|
||||
func syncInlineAccessProvider(cfg *Config) {
|
||||
if cfg == nil {
|
||||
return
|
||||
@@ -664,7 +700,7 @@ func sanitizeConfigForPersist(cfg *Config) *Config {
|
||||
}
|
||||
clone := *cfg
|
||||
clone.SDKConfig = cfg.SDKConfig
|
||||
clone.SDKConfig.Access = config.AccessConfig{}
|
||||
clone.SDKConfig.Access = AccessConfig{}
|
||||
return &clone
|
||||
}
|
||||
|
||||
|
||||
87
internal/config/sdk_config.go
Normal file
87
internal/config/sdk_config.go
Normal file
@@ -0,0 +1,87 @@
|
||||
// Package config provides configuration management for the CLI Proxy API server.
|
||||
// It handles loading and parsing YAML configuration files, and provides structured
|
||||
// access to application settings including server port, authentication directory,
|
||||
// debug settings, proxy configuration, and API keys.
|
||||
package config
|
||||
|
||||
// SDKConfig represents the application's configuration, loaded from a YAML file.
|
||||
type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
// ForceModelPrefix requires explicit model prefixes (e.g., "teamA/gemini-3-pro-preview")
|
||||
// to target prefixed credentials. When false, unprefixed model requests may use prefixed
|
||||
// credentials as well.
|
||||
ForceModelPrefix bool `yaml:"force-model-prefix" json:"force-model-prefix"`
|
||||
|
||||
// RequestLog enables or disables detailed request logging functionality.
|
||||
RequestLog bool `yaml:"request-log" json:"request-log"`
|
||||
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
}
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
)
|
||||
|
||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range c.Access.Providers {
|
||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
||||
if c.Access.Providers[i].Name == "" {
|
||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
||||
}
|
||||
return &c.Access.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
}
|
||||
return provider
|
||||
}
|
||||
@@ -13,6 +13,9 @@ type VertexCompatKey struct {
|
||||
// Maps to the x-goog-api-key header.
|
||||
APIKey string `yaml:"api-key" json:"api-key"`
|
||||
|
||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||
|
||||
// BaseURL is the base URL for the Vertex-compatible API endpoint.
|
||||
// The executor will append "/v1/publishers/google/models/{model}:action" to this.
|
||||
// Example: "https://zenmux.ai/api" becomes "https://zenmux.ai/api/v1/publishers/google/models/..."
|
||||
@@ -53,6 +56,7 @@ func (cfg *Config) SanitizeVertexCompatKeys() {
|
||||
if entry.APIKey == "" {
|
||||
continue
|
||||
}
|
||||
entry.Prefix = normalizeModelPrefix(entry.Prefix)
|
||||
entry.BaseURL = strings.TrimSpace(entry.BaseURL)
|
||||
if entry.BaseURL == "" {
|
||||
// BaseURL is required for Vertex API key entries
|
||||
|
||||
@@ -72,39 +72,45 @@ func SetupBaseLogger() {
|
||||
}
|
||||
|
||||
// ConfigureLogOutput switches the global log destination between rotating files and stdout.
|
||||
func ConfigureLogOutput(loggingToFile bool) error {
|
||||
// When logsMaxTotalSizeMB > 0, a background cleaner removes the oldest log files in the logs directory
|
||||
// until the total size is within the limit.
|
||||
func ConfigureLogOutput(loggingToFile bool, logsMaxTotalSizeMB int) error {
|
||||
SetupBaseLogger()
|
||||
|
||||
writerMu.Lock()
|
||||
defer writerMu.Unlock()
|
||||
|
||||
logDir := "logs"
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
}
|
||||
|
||||
protectedPath := ""
|
||||
if loggingToFile {
|
||||
logDir := "logs"
|
||||
if base := util.WritablePath(); base != "" {
|
||||
logDir = filepath.Join(base, "logs")
|
||||
}
|
||||
if err := os.MkdirAll(logDir, 0o755); err != nil {
|
||||
return fmt.Errorf("logging: failed to create log directory: %w", err)
|
||||
}
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
}
|
||||
protectedPath = filepath.Join(logDir, "main.log")
|
||||
logWriter = &lumberjack.Logger{
|
||||
Filename: filepath.Join(logDir, "main.log"),
|
||||
Filename: protectedPath,
|
||||
MaxSize: 10,
|
||||
MaxBackups: 0,
|
||||
MaxAge: 0,
|
||||
Compress: false,
|
||||
}
|
||||
log.SetOutput(logWriter)
|
||||
return nil
|
||||
} else {
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
}
|
||||
log.SetOutput(os.Stdout)
|
||||
}
|
||||
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
}
|
||||
log.SetOutput(os.Stdout)
|
||||
configureLogDirCleanerLocked(logDir, logsMaxTotalSizeMB, protectedPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -112,6 +118,8 @@ func closeLogOutputs() {
|
||||
writerMu.Lock()
|
||||
defer writerMu.Unlock()
|
||||
|
||||
stopLogDirCleanerLocked()
|
||||
|
||||
if logWriter != nil {
|
||||
_ = logWriter.Close()
|
||||
logWriter = nil
|
||||
|
||||
166
internal/logging/log_dir_cleaner.go
Normal file
166
internal/logging/log_dir_cleaner.go
Normal file
@@ -0,0 +1,166 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const logDirCleanerInterval = time.Minute
|
||||
|
||||
var logDirCleanerCancel context.CancelFunc
|
||||
|
||||
func configureLogDirCleanerLocked(logDir string, maxTotalSizeMB int, protectedPath string) {
|
||||
stopLogDirCleanerLocked()
|
||||
|
||||
if maxTotalSizeMB <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
maxBytes := int64(maxTotalSizeMB) * 1024 * 1024
|
||||
if maxBytes <= 0 {
|
||||
return
|
||||
}
|
||||
|
||||
dir := strings.TrimSpace(logDir)
|
||||
if dir == "" {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
logDirCleanerCancel = cancel
|
||||
go runLogDirCleaner(ctx, filepath.Clean(dir), maxBytes, strings.TrimSpace(protectedPath))
|
||||
}
|
||||
|
||||
func stopLogDirCleanerLocked() {
|
||||
if logDirCleanerCancel == nil {
|
||||
return
|
||||
}
|
||||
logDirCleanerCancel()
|
||||
logDirCleanerCancel = nil
|
||||
}
|
||||
|
||||
func runLogDirCleaner(ctx context.Context, logDir string, maxBytes int64, protectedPath string) {
|
||||
ticker := time.NewTicker(logDirCleanerInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
cleanOnce := func() {
|
||||
deleted, errClean := enforceLogDirSizeLimit(logDir, maxBytes, protectedPath)
|
||||
if errClean != nil {
|
||||
log.WithError(errClean).Warn("logging: failed to enforce log directory size limit")
|
||||
return
|
||||
}
|
||||
if deleted > 0 {
|
||||
log.Debugf("logging: removed %d old log file(s) to enforce log directory size limit", deleted)
|
||||
}
|
||||
}
|
||||
|
||||
cleanOnce()
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
cleanOnce()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func enforceLogDirSizeLimit(logDir string, maxBytes int64, protectedPath string) (int, error) {
|
||||
if maxBytes <= 0 {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
dir := strings.TrimSpace(logDir)
|
||||
if dir == "" {
|
||||
return 0, nil
|
||||
}
|
||||
dir = filepath.Clean(dir)
|
||||
|
||||
entries, errRead := os.ReadDir(dir)
|
||||
if errRead != nil {
|
||||
if os.IsNotExist(errRead) {
|
||||
return 0, nil
|
||||
}
|
||||
return 0, errRead
|
||||
}
|
||||
|
||||
protected := strings.TrimSpace(protectedPath)
|
||||
if protected != "" {
|
||||
protected = filepath.Clean(protected)
|
||||
}
|
||||
|
||||
type logFile struct {
|
||||
path string
|
||||
size int64
|
||||
modTime time.Time
|
||||
}
|
||||
|
||||
var (
|
||||
files []logFile
|
||||
total int64
|
||||
)
|
||||
for _, entry := range entries {
|
||||
if entry.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := entry.Name()
|
||||
if !isLogFileName(name) {
|
||||
continue
|
||||
}
|
||||
info, errInfo := entry.Info()
|
||||
if errInfo != nil {
|
||||
continue
|
||||
}
|
||||
if !info.Mode().IsRegular() {
|
||||
continue
|
||||
}
|
||||
path := filepath.Join(dir, name)
|
||||
files = append(files, logFile{
|
||||
path: path,
|
||||
size: info.Size(),
|
||||
modTime: info.ModTime(),
|
||||
})
|
||||
total += info.Size()
|
||||
}
|
||||
|
||||
if total <= maxBytes {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
sort.Slice(files, func(i, j int) bool {
|
||||
return files[i].modTime.Before(files[j].modTime)
|
||||
})
|
||||
|
||||
deleted := 0
|
||||
for _, file := range files {
|
||||
if total <= maxBytes {
|
||||
break
|
||||
}
|
||||
if protected != "" && filepath.Clean(file.path) == protected {
|
||||
continue
|
||||
}
|
||||
if errRemove := os.Remove(file.path); errRemove != nil {
|
||||
log.WithError(errRemove).Warnf("logging: failed to remove old log file: %s", filepath.Base(file.path))
|
||||
continue
|
||||
}
|
||||
total -= file.size
|
||||
deleted++
|
||||
}
|
||||
|
||||
return deleted, nil
|
||||
}
|
||||
|
||||
func isLogFileName(name string) bool {
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return false
|
||||
}
|
||||
lower := strings.ToLower(trimmed)
|
||||
return strings.HasSuffix(lower, ".log") || strings.HasSuffix(lower, ".log.gz")
|
||||
}
|
||||
70
internal/logging/log_dir_cleaner_test.go
Normal file
70
internal/logging/log_dir_cleaner_test.go
Normal file
@@ -0,0 +1,70 @@
|
||||
package logging
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func TestEnforceLogDirSizeLimitDeletesOldest(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
writeLogFile(t, filepath.Join(dir, "old.log"), 60, time.Unix(1, 0))
|
||||
writeLogFile(t, filepath.Join(dir, "mid.log"), 60, time.Unix(2, 0))
|
||||
protected := filepath.Join(dir, "main.log")
|
||||
writeLogFile(t, protected, 60, time.Unix(3, 0))
|
||||
|
||||
deleted, err := enforceLogDirSizeLimit(dir, 120, protected)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("expected 1 deleted file, got %d", deleted)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(filepath.Join(dir, "old.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected old.log to be removed, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, "mid.log")); err != nil {
|
||||
t.Fatalf("expected mid.log to remain, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(protected); err != nil {
|
||||
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnforceLogDirSizeLimitSkipsProtected(t *testing.T) {
|
||||
dir := t.TempDir()
|
||||
|
||||
protected := filepath.Join(dir, "main.log")
|
||||
writeLogFile(t, protected, 200, time.Unix(1, 0))
|
||||
writeLogFile(t, filepath.Join(dir, "other.log"), 50, time.Unix(2, 0))
|
||||
|
||||
deleted, err := enforceLogDirSizeLimit(dir, 100, protected)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if deleted != 1 {
|
||||
t.Fatalf("expected 1 deleted file, got %d", deleted)
|
||||
}
|
||||
|
||||
if _, err := os.Stat(protected); err != nil {
|
||||
t.Fatalf("expected protected main.log to remain, stat error: %v", err)
|
||||
}
|
||||
if _, err := os.Stat(filepath.Join(dir, "other.log")); !os.IsNotExist(err) {
|
||||
t.Fatalf("expected other.log to be removed, stat error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func writeLogFile(t *testing.T, path string, size int, modTime time.Time) {
|
||||
t.Helper()
|
||||
|
||||
data := make([]byte, size)
|
||||
if err := os.WriteFile(path, data, 0o644); err != nil {
|
||||
t.Fatalf("write file: %v", err)
|
||||
}
|
||||
if err := os.Chtimes(path, modTime, modTime); err != nil {
|
||||
t.Fatalf("set times: %v", err)
|
||||
}
|
||||
}
|
||||
@@ -20,6 +20,7 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
|
||||
lastCodexMaxPrompt := ""
|
||||
last51Prompt := ""
|
||||
last52Prompt := ""
|
||||
last52CodexPrompt := ""
|
||||
// lastReviewPrompt := ""
|
||||
for _, entry := range entries {
|
||||
content, _ := codexInstructionsDir.ReadFile("codex_instructions/" + entry.Name())
|
||||
@@ -36,12 +37,16 @@ func CodexInstructionsForModel(modelName, systemInstructions string) (bool, stri
|
||||
last51Prompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "gpt_5_2_prompt.md") {
|
||||
last52Prompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "gpt-5.2-codex_prompt.md") {
|
||||
last52CodexPrompt = string(content)
|
||||
} else if strings.HasPrefix(entry.Name(), "review_prompt.md") {
|
||||
// lastReviewPrompt = string(content)
|
||||
}
|
||||
}
|
||||
if strings.Contains(modelName, "codex-max") {
|
||||
return false, lastCodexMaxPrompt
|
||||
} else if strings.Contains(modelName, "5.2-codex") {
|
||||
return false, last52CodexPrompt
|
||||
} else if strings.Contains(modelName, "codex") {
|
||||
return false, lastCodexPrompt
|
||||
} else if strings.Contains(modelName, "5.1") {
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
You are Codex, based on GPT-5. You are running as a coding agent in the Codex CLI on a user's computer.
|
||||
|
||||
## General
|
||||
|
||||
- When searching for text or files, prefer using `rg` or `rg --files` respectively because `rg` is much faster than alternatives like `grep`. (If the `rg` command is not found, then use alternatives.)
|
||||
|
||||
## Editing constraints
|
||||
|
||||
- Default to ASCII when editing or creating files. Only introduce non-ASCII or other Unicode characters when there is a clear justification and the file already uses them.
|
||||
- Add succinct code comments that explain what is going on if code is not self-explanatory. You should not add comments like "Assigns the value to the variable", but a brief comment might be useful ahead of a complex code block that the user would otherwise have to spend time parsing out. Usage of these comments should be rare.
|
||||
- Try to use apply_patch for single file edits, but it is fine to explore other options to make the edit if it does not work well. Do not use apply_patch for changes that are auto-generated (i.e. generating package.json or running a lint or format command like gofmt) or when scripting is more efficient (such as search and replacing a string across a codebase).
|
||||
- You may be in a dirty git worktree.
|
||||
* NEVER revert existing changes you did not make unless explicitly requested, since these changes were made by the user.
|
||||
* If asked to make a commit or code edits and there are unrelated changes to your work or changes that you didn't make in those files, don't revert those changes.
|
||||
* If the changes are in files you've touched recently, you should read carefully and understand how you can work with the changes rather than reverting them.
|
||||
* If the changes are in unrelated files, just ignore them and don't revert them.
|
||||
- Do not amend a commit unless explicitly requested to do so.
|
||||
- While you are working, you might notice unexpected changes that you didn't make. If this happens, STOP IMMEDIATELY and ask the user how they would like to proceed.
|
||||
- **NEVER** use destructive commands like `git reset --hard` or `git checkout --` unless specifically requested or approved by the user.
|
||||
|
||||
## Plan tool
|
||||
|
||||
When using the planning tool:
|
||||
- Skip using the planning tool for straightforward tasks (roughly the easiest 25%).
|
||||
- Do not make single-step plans.
|
||||
- When you made a plan, update it after having performed one of the sub-tasks that you shared on the plan.
|
||||
|
||||
## Codex CLI harness, sandboxing, and approvals
|
||||
|
||||
The Codex CLI harness supports several different configurations for sandboxing and escalation approvals that the user can choose from.
|
||||
|
||||
Filesystem sandboxing defines which files can be read or written. The options for `sandbox_mode` are:
|
||||
- **read-only**: The sandbox only permits reading files.
|
||||
- **workspace-write**: The sandbox permits reading files, and editing files in `cwd` and `writable_roots`. Editing files in other directories requires approval.
|
||||
- **danger-full-access**: No filesystem sandboxing - all commands are permitted.
|
||||
|
||||
Network sandboxing defines whether network can be accessed without approval. Options for `network_access` are:
|
||||
- **restricted**: Requires approval
|
||||
- **enabled**: No approval needed
|
||||
|
||||
Approvals are your mechanism to get user consent to run shell commands without the sandbox. Possible configuration options for `approval_policy` are
|
||||
- **untrusted**: The harness will escalate most commands for user approval, apart from a limited allowlist of safe "read" commands.
|
||||
- **on-failure**: The harness will allow all commands to run in the sandbox (if enabled), and failures will be escalated to the user for approval to run again without the sandbox.
|
||||
- **on-request**: Commands will be run in the sandbox by default, and you can specify in your tool call if you want to escalate a command to run without sandboxing. (Note that this mode is not always available. If it is, you'll see parameters for it in the `shell` command description.)
|
||||
- **never**: This is a non-interactive mode where you may NEVER ask the user for approval to run commands. Instead, you must always persist and work around constraints to solve the task for the user. You MUST do your utmost best to finish the task and validate your work before yielding. If this mode is paired with `danger-full-access`, take advantage of it to deliver the best outcome for the user. Further, in this mode, your default testing philosophy is overridden: Even if you don't see local patterns for testing, you may add tests and scripts to validate your work. Just remove them before yielding.
|
||||
|
||||
When you are running with `approval_policy == on-request`, and sandboxing enabled, here are scenarios where you'll need to request approval:
|
||||
- You need to run a command that writes to a directory that requires it (e.g. running tests that write to /var)
|
||||
- You need to run a GUI app (e.g., open/xdg-open/osascript) to open browsers or files.
|
||||
- You are running sandboxed and need to run a command that requires network access (e.g. installing packages)
|
||||
- If you run a command that is important to solving the user's query, but it fails because of sandboxing, rerun the command with approval. ALWAYS proceed to use the `sandbox_permissions` and `justification` parameters - do not message the user before requesting approval for the command.
|
||||
- You are about to take a potentially destructive action such as an `rm` or `git reset` that the user did not explicitly ask for
|
||||
- (for all of these, you should weigh alternative paths that do not require approval)
|
||||
|
||||
When `sandbox_mode` is set to read-only, you'll need to request approval for any command that isn't a read.
|
||||
|
||||
You will be told what filesystem sandboxing, network sandboxing, and approval mode are active in a developer or user message. If you are not told about this, assume that you are running with workspace-write, network sandboxing enabled, and approval on-failure.
|
||||
|
||||
Although they introduce friction to the user because your work is paused until the user responds, you should leverage them when necessary to accomplish important work. If the completing the task requires escalated permissions, Do not let these settings or the sandbox deter you from attempting to accomplish the user's task unless it is set to "never", in which case never ask for approvals.
|
||||
|
||||
When requesting approval to execute a command that will require escalated privileges:
|
||||
- Provide the `sandbox_permissions` parameter with the value `"require_escalated"`
|
||||
- Include a short, 1 sentence explanation for why you need escalated permissions in the justification parameter
|
||||
|
||||
## Special user requests
|
||||
|
||||
- If the user makes a simple request (such as asking for the time) which you can fulfill by running a terminal command (such as `date`), you should do so.
|
||||
- If the user asks for a "review", default to a code review mindset: prioritise identifying bugs, risks, behavioural regressions, and missing tests. Findings must be the primary focus of the response - keep summaries or overviews brief and only after enumerating the issues. Present findings first (ordered by severity with file/line references), follow with open questions or assumptions, and offer a change-summary only as a secondary detail. If no findings are discovered, state that explicitly and mention any residual risks or testing gaps.
|
||||
|
||||
## Frontend tasks
|
||||
When doing frontend design tasks, avoid collapsing into "AI slop" or safe, average-looking layouts.
|
||||
Aim for interfaces that feel intentional, bold, and a bit surprising.
|
||||
- Typography: Use expressive, purposeful fonts and avoid default stacks (Inter, Roboto, Arial, system).
|
||||
- Color & Look: Choose a clear visual direction; define CSS variables; avoid purple-on-white defaults. No purple bias or dark mode bias.
|
||||
- Motion: Use a few meaningful animations (page-load, staggered reveals) instead of generic micro-motions.
|
||||
- Background: Don't rely on flat, single-color backgrounds; use gradients, shapes, or subtle patterns to build atmosphere.
|
||||
- Overall: Avoid boilerplate layouts and interchangeable UI patterns. Vary themes, type families, and visual languages across outputs.
|
||||
- Ensure the page loads properly on both desktop and mobile
|
||||
|
||||
Exception: If working within an existing website or design system, preserve the established patterns, structure, and visual language.
|
||||
|
||||
## Presenting your work and final message
|
||||
|
||||
You are producing plain text that will later be styled by the CLI. Follow these rules exactly. Formatting should make results easy to scan, but not feel mechanical. Use judgment to decide how much structure adds value.
|
||||
|
||||
- Default: be very concise; friendly coding teammate tone.
|
||||
- Ask only when needed; suggest ideas; mirror the user's style.
|
||||
- For substantial work, summarize clearly; follow final‑answer formatting.
|
||||
- Skip heavy formatting for simple confirmations.
|
||||
- Don't dump large files you've written; reference paths only.
|
||||
- No "save/copy this file" - User is on the same machine.
|
||||
- Offer logical next steps (tests, commits, build) briefly; add verify steps if you couldn't do something.
|
||||
- For code changes:
|
||||
* Lead with a quick explanation of the change, and then give more details on the context covering where and why a change was made. Do not start this explanation with "summary", just jump right in.
|
||||
* If there are natural next steps the user may want to take, suggest them at the end of your response. Do not make suggestions if there are no natural next steps.
|
||||
* When suggesting multiple options, use numeric lists for the suggestions so the user can quickly respond with a single number.
|
||||
- The user does not command execution outputs. When asked to show the output of a command (e.g. `git show`), relay the important details in your answer or summarize the key lines so the user understands the result.
|
||||
|
||||
### Final answer structure and style guidelines
|
||||
|
||||
- Plain text; CLI handles styling. Use structure only when it helps scanability.
|
||||
- Headers: optional; short Title Case (1-3 words) wrapped in **…**; no blank line before the first bullet; add only if they truly help.
|
||||
- Bullets: use - ; merge related points; keep to one line when possible; 4–6 per list ordered by importance; keep phrasing consistent.
|
||||
- Monospace: backticks for commands/paths/env vars/code ids and inline examples; use for literal keyword bullets; never combine with **.
|
||||
- Code samples or multi-line snippets should be wrapped in fenced code blocks; include an info string as often as possible.
|
||||
- Structure: group related bullets; order sections general → specific → supporting; for subsections, start with a bolded keyword bullet, then items; match complexity to the task.
|
||||
- Tone: collaborative, concise, factual; present tense, active voice; self‑contained; no "above/below"; parallel wording.
|
||||
- Don'ts: no nested bullets/hierarchies; no ANSI codes; don't cram unrelated keywords; keep keyword lists short—wrap/reformat if long; avoid naming formatting styles in answers.
|
||||
- Adaptation: code explanations → precise, structured with code refs; simple tasks → lead with outcome; big changes → logical walkthrough + rationale + next actions; casual one-offs → plain sentences, no headers/bullets.
|
||||
- File References: When referencing files in your response follow the below rules:
|
||||
* Use inline code to make file paths clickable.
|
||||
* Each reference should have a stand alone path. Even if it's the same file.
|
||||
* Accepted: absolute, workspace‑relative, a/ or b/ diff prefixes, or bare filename/suffix.
|
||||
* Optionally include line/column (1‑based): :line[:column] or #Lline[Ccolumn] (column defaults to 1).
|
||||
* Do not use URIs like file://, vscode://, or https://.
|
||||
* Do not provide range of lines
|
||||
* Examples: src/app.ts, src/app.ts:42, b/server/index.js#L10, C:\repo\project\main.rs:12:5
|
||||
@@ -160,7 +160,7 @@ func GetGeminiModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
@@ -175,7 +175,7 @@ func GetGeminiModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -240,7 +240,22 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
@@ -255,7 +270,7 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -317,11 +332,26 @@ func GetGeminiCLIModels() []*ModelInfo {
|
||||
Name: "models/gemini-3-pro-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Pro Preview",
|
||||
Description: "Gemini 3 Pro Preview",
|
||||
Description: "Our most intelligent model with SOTA reasoning and multimodal understanding, and powerful agentic and vibe coding capabilities",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
@@ -387,7 +417,22 @@ func GetAIStudioModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
Created: 1765929600,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3-flash-preview",
|
||||
Version: "3.0",
|
||||
DisplayName: "Gemini 3 Flash Preview",
|
||||
Description: "Our most intelligent model built for speed, combining frontier intelligence with superior search and grounding.",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-pro-latest",
|
||||
@@ -582,6 +627,20 @@ func GetOpenAIModels() []*ModelInfo {
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"none", "low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
{
|
||||
ID: "gpt-5.2-codex",
|
||||
Object: "model",
|
||||
Created: 1765440000,
|
||||
OwnedBy: "openai",
|
||||
Type: "openai",
|
||||
Version: "gpt-5.2",
|
||||
DisplayName: "GPT 5.2 Codex",
|
||||
Description: "Stable version of GPT 5.2 Codex, The best model for coding and agentic tasks across domains.",
|
||||
ContextLength: 400000,
|
||||
MaxCompletionTokens: 128000,
|
||||
SupportedParameters: []string{"tools"},
|
||||
Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high", "xhigh"}},
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
@@ -630,6 +689,13 @@ func GetQwenModels() []*ModelInfo {
|
||||
}
|
||||
}
|
||||
|
||||
// iFlowThinkingSupport is a shared ThinkingSupport configuration for iFlow models
|
||||
// that support thinking mode via chat_template_kwargs.enable_thinking (boolean toggle).
|
||||
// Uses level-based configuration so standard normalization flows apply before conversion.
|
||||
var iFlowThinkingSupport = &ThinkingSupport{
|
||||
Levels: []string{"none", "auto", "minimal", "low", "medium", "high", "xhigh"},
|
||||
}
|
||||
|
||||
// GetIFlowModels returns supported models for iFlow OAuth accounts.
|
||||
func GetIFlowModels() []*ModelInfo {
|
||||
entries := []struct {
|
||||
@@ -645,9 +711,9 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "qwen3-vl-plus", DisplayName: "Qwen3-VL-Plus", Description: "Qwen3 multimodal vision-language", Created: 1758672000},
|
||||
{ID: "qwen3-max-preview", DisplayName: "Qwen3-Max-Preview", Description: "Qwen3 Max preview build", Created: 1757030400},
|
||||
{ID: "kimi-k2-0905", DisplayName: "Kimi-K2-Instruct-0905", Description: "Moonshot Kimi K2 instruct 0905", Created: 1757030400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400},
|
||||
{ID: "glm-4.6", DisplayName: "GLM-4.6", Description: "Zhipu GLM 4.6 general model", Created: 1759190400, Thinking: iFlowThinkingSupport},
|
||||
{ID: "kimi-k2", DisplayName: "Kimi-K2", Description: "Moonshot Kimi K2 general model", Created: 1752192000},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "kimi-k2-thinking", DisplayName: "Kimi-K2-Thinking", Description: "Moonshot Kimi K2 thinking model", Created: 1762387200},
|
||||
{ID: "deepseek-v3.2-chat", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Chat", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2-reasoner", DisplayName: "DeepSeek-V3.2", Description: "DeepSeek V3.2 Reasoner", Created: 1764576000},
|
||||
{ID: "deepseek-v3.2", DisplayName: "DeepSeek-V3.2-Exp", Description: "DeepSeek V3.2 experimental", Created: 1759104000},
|
||||
@@ -655,10 +721,10 @@ func GetIFlowModels() []*ModelInfo {
|
||||
{ID: "deepseek-r1", DisplayName: "DeepSeek-R1", Description: "DeepSeek reasoning model R1", Created: 1737331200},
|
||||
{ID: "deepseek-v3", DisplayName: "DeepSeek-V3-671B", Description: "DeepSeek V3 671B", Created: 1734307200},
|
||||
{ID: "qwen3-32b", DisplayName: "Qwen3-32B", Description: "Qwen3 32B", Created: 1747094400},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "qwen3-235b-a22b-thinking-2507", DisplayName: "Qwen3-235B-A22B-Thinking", Description: "Qwen3 235B A22B Thinking (2507)", Created: 1753401600},
|
||||
{ID: "qwen3-235b-a22b-instruct", DisplayName: "Qwen3-235B-A22B-Instruct", Description: "Qwen3 235B A22B Instruct", Created: 1753401600},
|
||||
{ID: "qwen3-235b", DisplayName: "Qwen3-235B-A22B", Description: "Qwen3 235B A22B", Created: 1753401600},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000, Thinking: &ThinkingSupport{Levels: []string{"low", "medium", "high"}}},
|
||||
{ID: "minimax-m2", DisplayName: "MiniMax-M2", Description: "MiniMax M2", Created: 1758672000},
|
||||
}
|
||||
models := make([]*ModelInfo, 0, len(entries))
|
||||
for _, entry := range entries {
|
||||
@@ -691,8 +757,9 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-2.5-flash": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash"},
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}, Name: "models/gemini-2.5-flash-lite"},
|
||||
"gemini-2.5-computer-use-preview-10-2025": {Name: "models/gemini-2.5-computer-use-preview-10-2025"},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-preview"},
|
||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true}, Name: "models/gemini-3-pro-image-preview"},
|
||||
"gemini-3-pro-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-preview"},
|
||||
"gemini-3-pro-image-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}, Name: "models/gemini-3-pro-image-preview"},
|
||||
"gemini-3-flash-preview": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}, Name: "models/gemini-3-flash-preview"},
|
||||
"gemini-claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gemini-claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 200000, ZeroAllowed: false, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
}
|
||||
|
||||
@@ -323,9 +323,10 @@ func (e *AIStudioExecutor) translateRequest(req cliproxyexecutor.Request, opts c
|
||||
to := sdktranslator.FromString("gemini")
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||
payload = ApplyThinkingMetadata(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadata(req.Model, req.Metadata, payload)
|
||||
payload = util.ApplyDefaultThinkingIfNeeded(req.Model, payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload)
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload)
|
||||
payload = util.ConvertThinkingLevelToBudget(payload, req.Model, true)
|
||||
payload = util.NormalizeGeminiThinkingBudget(req.Model, payload, true)
|
||||
payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
|
||||
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
||||
payload = applyPayloadConfig(e.cfg, req.Model, payload)
|
||||
|
||||
@@ -32,15 +32,16 @@ import (
|
||||
const (
|
||||
antigravityBaseURLDaily = "https://daily-cloudcode-pa.sandbox.googleapis.com"
|
||||
// antigravityBaseURLAutopush = "https://autopush-cloudcode-pa.sandbox.googleapis.com"
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
antigravityBaseURLProd = "https://cloudcode-pa.googleapis.com"
|
||||
antigravityCountTokensPath = "/v1internal:countTokens"
|
||||
antigravityStreamPath = "/v1internal:streamGenerateContent"
|
||||
antigravityGeneratePath = "/v1internal:generateContent"
|
||||
antigravityModelsPath = "/v1internal:fetchAvailableModels"
|
||||
antigravityClientID = "1071006060591-tmhssin2h21lcre235vtolojh4g403ep.apps.googleusercontent.com"
|
||||
antigravityClientSecret = "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"
|
||||
defaultAntigravityAgent = "antigravity/1.11.5 windows/amd64"
|
||||
antigravityAuthType = "antigravity"
|
||||
refreshSkew = 3000 * time.Second
|
||||
)
|
||||
|
||||
var randSource = rand.New(rand.NewSource(time.Now().UnixNano()))
|
||||
@@ -69,6 +70,10 @@ func (e *AntigravityExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Au
|
||||
|
||||
// Execute performs a non-streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
if strings.Contains(req.Model, "claude") {
|
||||
return e.executeClaudeNonStream(ctx, auth, req, opts)
|
||||
}
|
||||
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return resp, errToken
|
||||
@@ -85,6 +90,7 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
|
||||
@@ -160,6 +166,337 @@ func (e *AntigravityExecutor) Execute(ctx context.Context, auth *cliproxyauth.Au
|
||||
return resp, err
|
||||
}
|
||||
|
||||
// executeClaudeNonStream performs a claude non-streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) executeClaudeNonStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return resp, errToken
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
httpReq, errReq := e.buildRequest(ctx, auth, token, req.Model, translated, true, opts.Alt, baseURL)
|
||||
if errReq != nil {
|
||||
err = errReq
|
||||
return resp, err
|
||||
}
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errDo
|
||||
return resp, err
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < http.StatusOK || httpResp.StatusCode >= http.StatusMultipleChoices {
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errRead
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: read error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = errRead
|
||||
return resp, err
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(resp *http.Response) {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := resp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(resp.Body)
|
||||
scanner.Buffer(nil, streamScannerBuffer)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
|
||||
// Filter usage metadata for all models
|
||||
// Only retain usage statistics in the terminal chunk
|
||||
line = FilterSSEUsageMetadata(line)
|
||||
|
||||
payload := jsonPayload(line)
|
||||
if payload == nil {
|
||||
continue
|
||||
}
|
||||
|
||||
if detail, ok := parseAntigravityStreamUsage(payload); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: payload}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
} else {
|
||||
reporter.ensurePublished(ctx)
|
||||
}
|
||||
}(httpResp)
|
||||
|
||||
var buffer bytes.Buffer
|
||||
for chunk := range out {
|
||||
if chunk.Err != nil {
|
||||
return resp, chunk.Err
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
_, _ = buffer.Write(chunk.Payload)
|
||||
_, _ = buffer.Write([]byte("\n"))
|
||||
}
|
||||
}
|
||||
resp = cliproxyexecutor.Response{Payload: e.convertStreamToNonStream(buffer.Bytes())}
|
||||
|
||||
reporter.publish(ctx, parseAntigravityUsage(resp.Payload))
|
||||
var param any
|
||||
converted := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), translated, resp.Payload, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(converted)}
|
||||
reporter.ensurePublished(ctx)
|
||||
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
err = statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
case lastErr != nil:
|
||||
err = lastErr
|
||||
default:
|
||||
err = statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (e *AntigravityExecutor) convertStreamToNonStream(stream []byte) []byte {
|
||||
responseTemplate := ""
|
||||
var traceID string
|
||||
var finishReason string
|
||||
var modelVersion string
|
||||
var responseID string
|
||||
var role string
|
||||
var usageRaw string
|
||||
parts := make([]map[string]interface{}, 0)
|
||||
var pendingKind string
|
||||
var pendingText strings.Builder
|
||||
var pendingThoughtSig string
|
||||
|
||||
flushPending := func() {
|
||||
if pendingKind == "" {
|
||||
return
|
||||
}
|
||||
text := pendingText.String()
|
||||
switch pendingKind {
|
||||
case "text":
|
||||
if strings.TrimSpace(text) == "" {
|
||||
pendingKind = ""
|
||||
pendingText.Reset()
|
||||
pendingThoughtSig = ""
|
||||
return
|
||||
}
|
||||
parts = append(parts, map[string]interface{}{"text": text})
|
||||
case "thought":
|
||||
if strings.TrimSpace(text) == "" && pendingThoughtSig == "" {
|
||||
pendingKind = ""
|
||||
pendingText.Reset()
|
||||
pendingThoughtSig = ""
|
||||
return
|
||||
}
|
||||
part := map[string]interface{}{"thought": true}
|
||||
part["text"] = text
|
||||
if pendingThoughtSig != "" {
|
||||
part["thoughtSignature"] = pendingThoughtSig
|
||||
}
|
||||
parts = append(parts, part)
|
||||
}
|
||||
pendingKind = ""
|
||||
pendingText.Reset()
|
||||
pendingThoughtSig = ""
|
||||
}
|
||||
|
||||
normalizePart := func(partResult gjson.Result) map[string]interface{} {
|
||||
var m map[string]interface{}
|
||||
_ = json.Unmarshal([]byte(partResult.Raw), &m)
|
||||
if m == nil {
|
||||
m = map[string]interface{}{}
|
||||
}
|
||||
sig := partResult.Get("thoughtSignature").String()
|
||||
if sig == "" {
|
||||
sig = partResult.Get("thought_signature").String()
|
||||
}
|
||||
if sig != "" {
|
||||
m["thoughtSignature"] = sig
|
||||
delete(m, "thought_signature")
|
||||
}
|
||||
if inlineData, ok := m["inline_data"]; ok {
|
||||
m["inlineData"] = inlineData
|
||||
delete(m, "inline_data")
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
for _, line := range bytes.Split(stream, []byte("\n")) {
|
||||
trimmed := bytes.TrimSpace(line)
|
||||
if len(trimmed) == 0 || !gjson.ValidBytes(trimmed) {
|
||||
continue
|
||||
}
|
||||
|
||||
root := gjson.ParseBytes(trimmed)
|
||||
responseNode := root.Get("response")
|
||||
if !responseNode.Exists() {
|
||||
if root.Get("candidates").Exists() {
|
||||
responseNode = root
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
responseTemplate = responseNode.Raw
|
||||
|
||||
if traceResult := root.Get("traceId"); traceResult.Exists() && traceResult.String() != "" {
|
||||
traceID = traceResult.String()
|
||||
}
|
||||
|
||||
if roleResult := responseNode.Get("candidates.0.content.role"); roleResult.Exists() {
|
||||
role = roleResult.String()
|
||||
}
|
||||
|
||||
if finishResult := responseNode.Get("candidates.0.finishReason"); finishResult.Exists() && finishResult.String() != "" {
|
||||
finishReason = finishResult.String()
|
||||
}
|
||||
|
||||
if modelResult := responseNode.Get("modelVersion"); modelResult.Exists() && modelResult.String() != "" {
|
||||
modelVersion = modelResult.String()
|
||||
}
|
||||
if responseIDResult := responseNode.Get("responseId"); responseIDResult.Exists() && responseIDResult.String() != "" {
|
||||
responseID = responseIDResult.String()
|
||||
}
|
||||
if usageResult := responseNode.Get("usageMetadata"); usageResult.Exists() {
|
||||
usageRaw = usageResult.Raw
|
||||
} else if usageResult := root.Get("usageMetadata"); usageResult.Exists() {
|
||||
usageRaw = usageResult.Raw
|
||||
}
|
||||
|
||||
if partsResult := responseNode.Get("candidates.0.content.parts"); partsResult.IsArray() {
|
||||
for _, part := range partsResult.Array() {
|
||||
hasFunctionCall := part.Get("functionCall").Exists()
|
||||
hasInlineData := part.Get("inlineData").Exists() || part.Get("inline_data").Exists()
|
||||
sig := part.Get("thoughtSignature").String()
|
||||
if sig == "" {
|
||||
sig = part.Get("thought_signature").String()
|
||||
}
|
||||
text := part.Get("text").String()
|
||||
thought := part.Get("thought").Bool()
|
||||
|
||||
if hasFunctionCall || hasInlineData {
|
||||
flushPending()
|
||||
parts = append(parts, normalizePart(part))
|
||||
continue
|
||||
}
|
||||
|
||||
if thought || part.Get("text").Exists() {
|
||||
kind := "text"
|
||||
if thought {
|
||||
kind = "thought"
|
||||
}
|
||||
if pendingKind != "" && pendingKind != kind {
|
||||
flushPending()
|
||||
}
|
||||
pendingKind = kind
|
||||
pendingText.WriteString(text)
|
||||
if kind == "thought" && sig != "" {
|
||||
pendingThoughtSig = sig
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
flushPending()
|
||||
parts = append(parts, normalizePart(part))
|
||||
}
|
||||
}
|
||||
}
|
||||
flushPending()
|
||||
|
||||
if responseTemplate == "" {
|
||||
responseTemplate = `{"candidates":[{"content":{"role":"model","parts":[]}}]}`
|
||||
}
|
||||
|
||||
partsJSON, _ := json.Marshal(parts)
|
||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "candidates.0.content.parts", string(partsJSON))
|
||||
if role != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.content.role", role)
|
||||
}
|
||||
if finishReason != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "candidates.0.finishReason", finishReason)
|
||||
}
|
||||
if modelVersion != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "modelVersion", modelVersion)
|
||||
}
|
||||
if responseID != "" {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "responseId", responseID)
|
||||
}
|
||||
if usageRaw != "" {
|
||||
responseTemplate, _ = sjson.SetRaw(responseTemplate, "usageMetadata", usageRaw)
|
||||
} else if !gjson.Get(responseTemplate, "usageMetadata").Exists() {
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.promptTokenCount", 0)
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.candidatesTokenCount", 0)
|
||||
responseTemplate, _ = sjson.Set(responseTemplate, "usageMetadata.totalTokenCount", 0)
|
||||
}
|
||||
|
||||
output := `{"response":{},"traceId":""}`
|
||||
output, _ = sjson.SetRaw(output, "response", responseTemplate)
|
||||
if traceID != "" {
|
||||
output, _ = sjson.Set(output, "traceId", traceID)
|
||||
}
|
||||
return []byte(output)
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming request to the Antigravity API.
|
||||
func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
ctx = context.WithValue(ctx, "alt", "")
|
||||
@@ -180,6 +517,7 @@ func (e *AntigravityExecutor) ExecuteStream(ctx context.Context, auth *cliproxya
|
||||
translated := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
|
||||
translated = applyThinkingMetadataCLI(translated, req.Metadata, req.Model)
|
||||
translated = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, translated)
|
||||
translated = util.ApplyDefaultThinkingIfNeededCLI(req.Model, translated)
|
||||
translated = normalizeAntigravityThinking(req.Model, translated)
|
||||
|
||||
@@ -312,9 +650,131 @@ func (e *AntigravityExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Au
|
||||
return updated, nil
|
||||
}
|
||||
|
||||
// CountTokens counts tokens for the given request (not supported for Antigravity).
|
||||
func (e *AntigravityExecutor) CountTokens(context.Context, *cliproxyauth.Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusNotImplemented, msg: "count tokens not supported"}
|
||||
// CountTokens counts tokens for the given request using the Antigravity API.
|
||||
func (e *AntigravityExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
token, updatedAuth, errToken := e.ensureAccessToken(ctx, auth)
|
||||
if errToken != nil {
|
||||
return cliproxyexecutor.Response{}, errToken
|
||||
}
|
||||
if updatedAuth != nil {
|
||||
auth = updatedAuth
|
||||
}
|
||||
if strings.TrimSpace(token) == "" {
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusUnauthorized, msg: "missing access token"}
|
||||
}
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("antigravity")
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
|
||||
baseURLs := antigravityBaseURLFallbackOrder(auth)
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
|
||||
var lastStatus int
|
||||
var lastBody []byte
|
||||
var lastErr error
|
||||
|
||||
for idx, baseURL := range baseURLs {
|
||||
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, payload)
|
||||
payload = normalizeAntigravityThinking(req.Model, payload)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
|
||||
base := strings.TrimSuffix(baseURL, "/")
|
||||
if base == "" {
|
||||
base = buildBaseURL(auth)
|
||||
}
|
||||
|
||||
var requestURL strings.Builder
|
||||
requestURL.WriteString(base)
|
||||
requestURL.WriteString(antigravityCountTokensPath)
|
||||
if opts.Alt != "" {
|
||||
requestURL.WriteString("?$alt=")
|
||||
requestURL.WriteString(url.QueryEscape(opts.Alt))
|
||||
}
|
||||
|
||||
httpReq, errReq := http.NewRequestWithContext(ctx, http.MethodPost, requestURL.String(), bytes.NewReader(payload))
|
||||
if errReq != nil {
|
||||
return cliproxyexecutor.Response{}, errReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
httpReq.Header.Set("User-Agent", resolveUserAgent(auth))
|
||||
httpReq.Header.Set("Accept", "application/json")
|
||||
if host := resolveHost(base); host != "" {
|
||||
httpReq.Host = host
|
||||
}
|
||||
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: requestURL.String(),
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: payload,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
lastStatus = 0
|
||||
lastBody = nil
|
||||
lastErr = errDo
|
||||
if idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: request error on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
bodyBytes, errRead := io.ReadAll(httpResp.Body)
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("antigravity executor: close response body error: %v", errClose)
|
||||
}
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, bodyBytes)
|
||||
|
||||
if httpResp.StatusCode >= http.StatusOK && httpResp.StatusCode < http.StatusMultipleChoices {
|
||||
count := gjson.GetBytes(bodyBytes, "totalTokens").Int()
|
||||
translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, bodyBytes)
|
||||
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||
}
|
||||
|
||||
lastStatus = httpResp.StatusCode
|
||||
lastBody = append([]byte(nil), bodyBytes...)
|
||||
lastErr = nil
|
||||
if httpResp.StatusCode == http.StatusTooManyRequests && idx+1 < len(baseURLs) {
|
||||
log.Debugf("antigravity executor: rate limited on base url %s, retrying with fallback base url: %s", baseURL, baseURLs[idx+1])
|
||||
continue
|
||||
}
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(bodyBytes)}
|
||||
}
|
||||
|
||||
switch {
|
||||
case lastStatus != 0:
|
||||
return cliproxyexecutor.Response{}, statusErr{code: lastStatus, msg: string(lastBody)}
|
||||
case lastErr != nil:
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
default:
|
||||
return cliproxyexecutor.Response{}, statusErr{code: http.StatusServiceUnavailable, msg: "antigravity executor: no base url available"}
|
||||
}
|
||||
}
|
||||
|
||||
// FetchAntigravityModels retrieves available models using the supplied auth.
|
||||
@@ -545,27 +1005,9 @@ func (e *AntigravityExecutor) buildRequest(ctx context.Context, auth *cliproxyau
|
||||
strJSON, _ = util.RenameKey(strJSON, p, p[:len(p)-len("parametersJsonSchema")]+"parameters")
|
||||
}
|
||||
|
||||
strJSON = util.DeleteKey(strJSON, "$schema")
|
||||
strJSON = util.DeleteKey(strJSON, "maxItems")
|
||||
strJSON = util.DeleteKey(strJSON, "minItems")
|
||||
strJSON = util.DeleteKey(strJSON, "minLength")
|
||||
strJSON = util.DeleteKey(strJSON, "maxLength")
|
||||
strJSON = util.DeleteKey(strJSON, "exclusiveMinimum")
|
||||
strJSON = util.DeleteKey(strJSON, "exclusiveMaximum")
|
||||
strJSON = util.DeleteKey(strJSON, "$ref")
|
||||
strJSON = util.DeleteKey(strJSON, "$defs")
|
||||
|
||||
paths = make([]string, 0)
|
||||
util.Walk(gjson.Parse(strJSON), "", "anyOf", &paths)
|
||||
for _, p := range paths {
|
||||
anyOf := gjson.Get(strJSON, p)
|
||||
if anyOf.IsArray() {
|
||||
anyOfItems := anyOf.Array()
|
||||
if len(anyOfItems) > 0 {
|
||||
strJSON, _ = sjson.SetRaw(strJSON, p[:len(p)-len(".anyOf")], anyOfItems[0].Raw)
|
||||
}
|
||||
}
|
||||
}
|
||||
// Use the centralized schema cleaner to handle unsupported keywords,
|
||||
// const->enum conversion, and flattening of types/anyOf.
|
||||
strJSON = util.CleanJSONSchemaForGemini(strJSON)
|
||||
|
||||
payload = []byte(strJSON)
|
||||
}
|
||||
@@ -798,6 +1240,8 @@ func modelName2Alias(modelName string) string {
|
||||
return "gemini-3-pro-image-preview"
|
||||
case "gemini-3-pro-high":
|
||||
return "gemini-3-pro-preview"
|
||||
case "gemini-3-flash":
|
||||
return "gemini-3-flash-preview"
|
||||
case "claude-sonnet-4-5":
|
||||
return "gemini-claude-sonnet-4-5"
|
||||
case "claude-sonnet-4-5-thinking":
|
||||
@@ -819,6 +1263,8 @@ func alias2ModelName(modelName string) string {
|
||||
return "gemini-3-pro-image"
|
||||
case "gemini-3-pro-preview":
|
||||
return "gemini-3-pro-high"
|
||||
case "gemini-3-flash-preview":
|
||||
return "gemini-3-flash"
|
||||
case "gemini-claude-sonnet-4-5":
|
||||
return "claude-sonnet-4-5"
|
||||
case "gemini-claude-sonnet-4-5-thinking":
|
||||
|
||||
@@ -79,6 +79,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
@@ -217,6 +218,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
to := sdktranslator.FromString("gemini-cli")
|
||||
basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
basePayload = applyThinkingMetadataCLI(basePayload, req.Metadata, req.Model)
|
||||
basePayload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, basePayload)
|
||||
basePayload = util.ApplyDefaultThinkingIfNeededCLI(req.Model, basePayload)
|
||||
basePayload = util.NormalizeGeminiCLIThinkingBudget(req.Model, basePayload)
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
@@ -418,6 +420,7 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.
|
||||
for _, attemptModel := range models {
|
||||
payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false)
|
||||
payload = applyThinkingMetadataCLI(payload, req.Metadata, req.Model)
|
||||
payload = util.ApplyGemini3ThinkingLevelFromMetadataCLI(req.Model, req.Metadata, payload)
|
||||
payload = deleteJSONField(payload, "project")
|
||||
payload = deleteJSONField(payload, "model")
|
||||
payload = deleteJSONField(payload, "request.safetySettings")
|
||||
|
||||
@@ -66,6 +66,7 @@ func (e *IFlowExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, re
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return resp, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
body = applyPayloadConfig(e.cfg, req.Model, body)
|
||||
|
||||
endpoint := strings.TrimSuffix(baseURL, "/") + iflowDefaultEndpoint
|
||||
@@ -157,6 +158,7 @@ func (e *IFlowExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au
|
||||
if errValidate := ValidateThinkingConfig(body, upstreamModel); errValidate != nil {
|
||||
return nil, errValidate
|
||||
}
|
||||
body = applyIFlowThinkingConfig(body)
|
||||
// Ensure tools array exists to avoid provider quirks similar to Qwen's behaviour.
|
||||
toolsResult := gjson.GetBytes(body, "tools")
|
||||
if toolsResult.Exists() && toolsResult.IsArray() && len(toolsResult.Array()) == 0 {
|
||||
@@ -442,3 +444,21 @@ func ensureToolsArray(body []byte) []byte {
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// applyIFlowThinkingConfig converts normalized reasoning_effort to iFlow chat_template_kwargs.enable_thinking.
|
||||
// This should be called after NormalizeThinkingConfig has processed the payload.
|
||||
// iFlow only supports boolean enable_thinking, so any non-"none" effort enables thinking.
|
||||
func applyIFlowThinkingConfig(body []byte) []byte {
|
||||
effort := gjson.GetBytes(body, "reasoning_effort")
|
||||
if !effort.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
val := strings.ToLower(strings.TrimSpace(effort.String()))
|
||||
enableThinking := val != "none" && val != ""
|
||||
|
||||
body, _ = sjson.DeleteBytes(body, "reasoning_effort")
|
||||
body, _ = sjson.SetBytes(body, "chat_template_kwargs.enable_thinking", enableThinking)
|
||||
|
||||
return body
|
||||
}
|
||||
|
||||
@@ -72,13 +72,7 @@ func ApplyReasoningEffortMetadata(payload []byte, metadata map[string]any, model
|
||||
// Fallback: numeric thinking_budget suffix for level-based (OpenAI-style) models.
|
||||
if util.ModelUsesThinkingLevels(baseModel) || allowCompat {
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(baseModel, *budget); ok && effort != "" {
|
||||
if *budget == 0 && effort == "none" && util.ModelUsesThinkingLevels(baseModel) {
|
||||
if _, supported := util.NormalizeReasoningEffortLevel(baseModel, effort); !supported {
|
||||
return StripThinkingFields(payload, false)
|
||||
}
|
||||
}
|
||||
|
||||
if effort, ok := util.ThinkingBudgetToEffort(baseModel, *budget); ok && effort != "" {
|
||||
if updated, err := sjson.SetBytes(payload, field, effort); err == nil {
|
||||
return updated
|
||||
}
|
||||
@@ -273,7 +267,7 @@ func StripThinkingFields(payload []byte, effortOnly bool) []byte {
|
||||
"reasoning.effort",
|
||||
}
|
||||
if !effortOnly {
|
||||
fieldsToRemove = append([]string{"reasoning"}, fieldsToRemove...)
|
||||
fieldsToRemove = append([]string{"reasoning", "thinking"}, fieldsToRemove...)
|
||||
}
|
||||
out := payload
|
||||
for _, field := range fieldsToRemove {
|
||||
|
||||
@@ -7,10 +7,8 @@ package claude
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
"strings"
|
||||
|
||||
client "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/translator/gemini/common"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -42,27 +40,30 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
rawJSON = bytes.Replace(rawJSON, []byte(`"url":{"type":"string","format":"uri",`), []byte(`"url":{"type":"string",`), -1)
|
||||
|
||||
// system instruction
|
||||
var systemInstruction *client.Content
|
||||
systemInstructionJSON := ""
|
||||
hasSystemInstruction := false
|
||||
systemResult := gjson.GetBytes(rawJSON, "system")
|
||||
if systemResult.IsArray() {
|
||||
systemResults := systemResult.Array()
|
||||
systemInstruction = &client.Content{Role: "user", Parts: []client.Part{}}
|
||||
systemInstructionJSON = `{"role":"user","parts":[]}`
|
||||
for i := 0; i < len(systemResults); i++ {
|
||||
systemPromptResult := systemResults[i]
|
||||
systemTypePromptResult := systemPromptResult.Get("type")
|
||||
if systemTypePromptResult.Type == gjson.String && systemTypePromptResult.String() == "text" {
|
||||
systemPrompt := systemPromptResult.Get("text").String()
|
||||
systemPart := client.Part{Text: systemPrompt}
|
||||
systemInstruction.Parts = append(systemInstruction.Parts, systemPart)
|
||||
partJSON := `{}`
|
||||
if systemPrompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", systemPrompt)
|
||||
}
|
||||
systemInstructionJSON, _ = sjson.SetRaw(systemInstructionJSON, "parts.-1", partJSON)
|
||||
hasSystemInstruction = true
|
||||
}
|
||||
}
|
||||
if len(systemInstruction.Parts) == 0 {
|
||||
systemInstruction = nil
|
||||
}
|
||||
}
|
||||
|
||||
// contents
|
||||
contents := make([]client.Content, 0)
|
||||
contentsJSON := "[]"
|
||||
hasContents := false
|
||||
messagesResult := gjson.GetBytes(rawJSON, "messages")
|
||||
if messagesResult.IsArray() {
|
||||
messageResults := messagesResult.Array()
|
||||
@@ -76,7 +77,8 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if role == "assistant" {
|
||||
role = "model"
|
||||
}
|
||||
clientContent := client.Content{Role: role, Parts: []client.Part{}}
|
||||
clientContentJSON := `{"role":"","parts":[]}`
|
||||
clientContentJSON, _ = sjson.Set(clientContentJSON, "role", role)
|
||||
contentsResult := messageResult.Get("content")
|
||||
if contentsResult.IsArray() {
|
||||
contentResults := contentsResult.Array()
|
||||
@@ -90,25 +92,39 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
if signatureResult.Exists() {
|
||||
signature = signatureResult.String()
|
||||
}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt, Thought: true, ThoughtSignature: signature})
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.Set(partJSON, "thought", true)
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
}
|
||||
if signature != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", signature)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "text" {
|
||||
prompt := contentResult.Get("text").String()
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{Text: prompt})
|
||||
partJSON := `{}`
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_use" {
|
||||
functionName := contentResult.Get("name").String()
|
||||
functionArgs := contentResult.Get("input").String()
|
||||
functionID := contentResult.Get("id").String()
|
||||
var args map[string]any
|
||||
if err := json.Unmarshal([]byte(functionArgs), &args); err == nil {
|
||||
if strings.Contains(modelName, "claude") {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
|
||||
})
|
||||
} else {
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{
|
||||
FunctionCall: &client.FunctionCall{ID: functionID, Name: functionName, Args: args},
|
||||
ThoughtSignature: geminiCLIClaudeThoughtSignature,
|
||||
})
|
||||
if gjson.Valid(functionArgs) {
|
||||
argsResult := gjson.Parse(functionArgs)
|
||||
if argsResult.IsObject() {
|
||||
partJSON := `{}`
|
||||
if !strings.Contains(modelName, "claude") {
|
||||
partJSON, _ = sjson.Set(partJSON, "thoughtSignature", geminiCLIClaudeThoughtSignature)
|
||||
}
|
||||
if functionID != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.id", functionID)
|
||||
}
|
||||
partJSON, _ = sjson.Set(partJSON, "functionCall.name", functionName)
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionCall.args", argsResult.Raw)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "tool_result" {
|
||||
@@ -117,37 +133,74 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
funcName := toolCallID
|
||||
toolCallIDs := strings.Split(toolCallID, "-")
|
||||
if len(toolCallIDs) > 1 {
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-1], "-")
|
||||
funcName = strings.Join(toolCallIDs[0:len(toolCallIDs)-2], "-")
|
||||
}
|
||||
responseData := contentResult.Get("content").Raw
|
||||
functionResponse := client.FunctionResponse{ID: toolCallID, Name: funcName, Response: map[string]interface{}{"result": responseData}}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{FunctionResponse: &functionResponse})
|
||||
functionResponseResult := contentResult.Get("content")
|
||||
|
||||
functionResponseJSON := `{}`
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "id", toolCallID)
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "name", funcName)
|
||||
|
||||
responseData := ""
|
||||
if functionResponseResult.Type == gjson.String {
|
||||
responseData = functionResponseResult.String()
|
||||
functionResponseJSON, _ = sjson.Set(functionResponseJSON, "response.result", responseData)
|
||||
} else if functionResponseResult.IsArray() {
|
||||
frResults := functionResponseResult.Array()
|
||||
if len(frResults) == 1 {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", frResults[0].Raw)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
}
|
||||
|
||||
} else if functionResponseResult.IsObject() {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
} else {
|
||||
functionResponseJSON, _ = sjson.SetRaw(functionResponseJSON, "response.result", functionResponseResult.Raw)
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "functionResponse", functionResponseJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
} else if contentTypeResult.Type == gjson.String && contentTypeResult.String() == "image" {
|
||||
sourceResult := contentResult.Get("source")
|
||||
if sourceResult.Get("type").String() == "base64" {
|
||||
inlineData := &client.InlineData{
|
||||
MimeType: sourceResult.Get("media_type").String(),
|
||||
Data: sourceResult.Get("data").String(),
|
||||
inlineDataJSON := `{}`
|
||||
if mimeType := sourceResult.Get("media_type").String(); mimeType != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "mime_type", mimeType)
|
||||
}
|
||||
clientContent.Parts = append(clientContent.Parts, client.Part{InlineData: inlineData})
|
||||
if data := sourceResult.Get("data").String(); data != "" {
|
||||
inlineDataJSON, _ = sjson.Set(inlineDataJSON, "data", data)
|
||||
}
|
||||
|
||||
partJSON := `{}`
|
||||
partJSON, _ = sjson.SetRaw(partJSON, "inlineData", inlineDataJSON)
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
}
|
||||
}
|
||||
}
|
||||
contents = append(contents, clientContent)
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
} else if contentsResult.Type == gjson.String {
|
||||
prompt := contentsResult.String()
|
||||
contents = append(contents, client.Content{Role: role, Parts: []client.Part{{Text: prompt}}})
|
||||
partJSON := `{}`
|
||||
if prompt != "" {
|
||||
partJSON, _ = sjson.Set(partJSON, "text", prompt)
|
||||
}
|
||||
clientContentJSON, _ = sjson.SetRaw(clientContentJSON, "parts.-1", partJSON)
|
||||
contentsJSON, _ = sjson.SetRaw(contentsJSON, "-1", clientContentJSON)
|
||||
hasContents = true
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// tools
|
||||
var tools []client.ToolDeclaration
|
||||
toolsJSON := ""
|
||||
toolDeclCount := 0
|
||||
toolsResult := gjson.GetBytes(rawJSON, "tools")
|
||||
if toolsResult.IsArray() {
|
||||
tools = make([]client.ToolDeclaration, 1)
|
||||
tools[0].FunctionDeclarations = make([]any, 0)
|
||||
toolsJSON = `[{"functionDeclarations":[]}]`
|
||||
toolsResults := toolsResult.Array()
|
||||
for i := 0; i < len(toolsResults); i++ {
|
||||
toolResult := toolsResults[i]
|
||||
@@ -158,30 +211,23 @@ func ConvertClaudeRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
tool, _ = sjson.SetRaw(tool, "parametersJsonSchema", inputSchema)
|
||||
tool, _ = sjson.Delete(tool, "strict")
|
||||
tool, _ = sjson.Delete(tool, "input_examples")
|
||||
var toolDeclaration any
|
||||
if err := json.Unmarshal([]byte(tool), &toolDeclaration); err == nil {
|
||||
tools[0].FunctionDeclarations = append(tools[0].FunctionDeclarations, toolDeclaration)
|
||||
}
|
||||
toolsJSON, _ = sjson.SetRaw(toolsJSON, "0.functionDeclarations.-1", tool)
|
||||
toolDeclCount++
|
||||
}
|
||||
}
|
||||
} else {
|
||||
tools = make([]client.ToolDeclaration, 0)
|
||||
}
|
||||
|
||||
// Build output Gemini CLI request JSON
|
||||
out := `{"model":"","request":{"contents":[]}}`
|
||||
out, _ = sjson.Set(out, "model", modelName)
|
||||
if systemInstruction != nil {
|
||||
b, _ := json.Marshal(systemInstruction)
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", string(b))
|
||||
if hasSystemInstruction {
|
||||
out, _ = sjson.SetRaw(out, "request.systemInstruction", systemInstructionJSON)
|
||||
}
|
||||
if len(contents) > 0 {
|
||||
b, _ := json.Marshal(contents)
|
||||
out, _ = sjson.SetRaw(out, "request.contents", string(b))
|
||||
if hasContents {
|
||||
out, _ = sjson.SetRaw(out, "request.contents", contentsJSON)
|
||||
}
|
||||
if len(tools) > 0 && len(tools[0].FunctionDeclarations) > 0 {
|
||||
b, _ := json.Marshal(tools)
|
||||
out, _ = sjson.SetRaw(out, "request.tools", string(b))
|
||||
if toolDeclCount > 0 {
|
||||
out, _ = sjson.SetRaw(out, "request.tools", toolsJSON)
|
||||
}
|
||||
|
||||
// Map Anthropic thinking -> Gemini thinkingBudget/include_thoughts when type==enabled
|
||||
|
||||
@@ -9,7 +9,6 @@ package claude
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync/atomic"
|
||||
@@ -350,24 +349,25 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
}
|
||||
}
|
||||
|
||||
response := map[string]interface{}{
|
||||
"id": root.Get("response.responseId").String(),
|
||||
"type": "message",
|
||||
"role": "assistant",
|
||||
"model": root.Get("response.modelVersion").String(),
|
||||
"content": []interface{}{},
|
||||
"stop_reason": nil,
|
||||
"stop_sequence": nil,
|
||||
"usage": map[string]interface{}{
|
||||
"input_tokens": promptTokens,
|
||||
"output_tokens": outputTokens,
|
||||
},
|
||||
responseJSON := `{"id":"","type":"message","role":"assistant","model":"","content":null,"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":0,"output_tokens":0}}`
|
||||
responseJSON, _ = sjson.Set(responseJSON, "id", root.Get("response.responseId").String())
|
||||
responseJSON, _ = sjson.Set(responseJSON, "model", root.Get("response.modelVersion").String())
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.input_tokens", promptTokens)
|
||||
responseJSON, _ = sjson.Set(responseJSON, "usage.output_tokens", outputTokens)
|
||||
|
||||
contentArrayInitialized := false
|
||||
ensureContentArray := func() {
|
||||
if contentArrayInitialized {
|
||||
return
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content", "[]")
|
||||
contentArrayInitialized = true
|
||||
}
|
||||
|
||||
parts := root.Get("response.candidates.0.content.parts")
|
||||
var contentBlocks []interface{}
|
||||
textBuilder := strings.Builder{}
|
||||
thinkingBuilder := strings.Builder{}
|
||||
thinkingSignature := ""
|
||||
toolIDCounter := 0
|
||||
hasToolCall := false
|
||||
|
||||
@@ -375,28 +375,43 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
if textBuilder.Len() == 0 {
|
||||
return
|
||||
}
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": textBuilder.String(),
|
||||
})
|
||||
ensureContentArray()
|
||||
block := `{"type":"text","text":""}`
|
||||
block, _ = sjson.Set(block, "text", textBuilder.String())
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
textBuilder.Reset()
|
||||
}
|
||||
|
||||
flushThinking := func() {
|
||||
if thinkingBuilder.Len() == 0 {
|
||||
if thinkingBuilder.Len() == 0 && thinkingSignature == "" {
|
||||
return
|
||||
}
|
||||
contentBlocks = append(contentBlocks, map[string]interface{}{
|
||||
"type": "thinking",
|
||||
"thinking": thinkingBuilder.String(),
|
||||
})
|
||||
ensureContentArray()
|
||||
block := `{"type":"thinking","thinking":""}`
|
||||
block, _ = sjson.Set(block, "thinking", thinkingBuilder.String())
|
||||
if thinkingSignature != "" {
|
||||
block, _ = sjson.Set(block, "signature", thinkingSignature)
|
||||
}
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", block)
|
||||
thinkingBuilder.Reset()
|
||||
thinkingSignature = ""
|
||||
}
|
||||
|
||||
if parts.IsArray() {
|
||||
for _, part := range parts.Array() {
|
||||
isThought := part.Get("thought").Bool()
|
||||
if isThought {
|
||||
sig := part.Get("thoughtSignature")
|
||||
if !sig.Exists() {
|
||||
sig = part.Get("thought_signature")
|
||||
}
|
||||
if sig.Exists() && sig.String() != "" {
|
||||
thinkingSignature = sig.String()
|
||||
}
|
||||
}
|
||||
|
||||
if text := part.Get("text"); text.Exists() && text.String() != "" {
|
||||
if part.Get("thought").Bool() {
|
||||
if isThought {
|
||||
flushText()
|
||||
thinkingBuilder.WriteString(text.String())
|
||||
continue
|
||||
@@ -413,21 +428,16 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
|
||||
name := functionCall.Get("name").String()
|
||||
toolIDCounter++
|
||||
toolBlock := map[string]interface{}{
|
||||
"type": "tool_use",
|
||||
"id": fmt.Sprintf("tool_%d", toolIDCounter),
|
||||
"name": name,
|
||||
"input": map[string]interface{}{},
|
||||
toolBlock := `{"type":"tool_use","id":"","name":"","input":{}}`
|
||||
toolBlock, _ = sjson.Set(toolBlock, "id", fmt.Sprintf("tool_%d", toolIDCounter))
|
||||
toolBlock, _ = sjson.Set(toolBlock, "name", name)
|
||||
|
||||
if args := functionCall.Get("args"); args.Exists() && args.Raw != "" && gjson.Valid(args.Raw) {
|
||||
toolBlock, _ = sjson.SetRaw(toolBlock, "input", args.Raw)
|
||||
}
|
||||
|
||||
if args := functionCall.Get("args"); args.Exists() {
|
||||
var parsed interface{}
|
||||
if err := json.Unmarshal([]byte(args.Raw), &parsed); err == nil {
|
||||
toolBlock["input"] = parsed
|
||||
}
|
||||
}
|
||||
|
||||
contentBlocks = append(contentBlocks, toolBlock)
|
||||
ensureContentArray()
|
||||
responseJSON, _ = sjson.SetRaw(responseJSON, "content.-1", toolBlock)
|
||||
continue
|
||||
}
|
||||
}
|
||||
@@ -436,8 +446,6 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
flushThinking()
|
||||
flushText()
|
||||
|
||||
response["content"] = contentBlocks
|
||||
|
||||
stopReason := "end_turn"
|
||||
if hasToolCall {
|
||||
stopReason = "tool_use"
|
||||
@@ -453,19 +461,15 @@ func ConvertAntigravityResponseToClaudeNonStream(_ context.Context, _ string, or
|
||||
}
|
||||
}
|
||||
}
|
||||
response["stop_reason"] = stopReason
|
||||
responseJSON, _ = sjson.Set(responseJSON, "stop_reason", stopReason)
|
||||
|
||||
if usage := response["usage"].(map[string]interface{}); usage["input_tokens"] == int64(0) && usage["output_tokens"] == int64(0) {
|
||||
if promptTokens == 0 && outputTokens == 0 {
|
||||
if usageMeta := root.Get("response.usageMetadata"); !usageMeta.Exists() {
|
||||
delete(response, "usage")
|
||||
responseJSON, _ = sjson.Delete(responseJSON, "usage")
|
||||
}
|
||||
}
|
||||
|
||||
encoded, err := json.Marshal(response)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(encoded)
|
||||
return responseJSON
|
||||
}
|
||||
|
||||
func ClaudeTokenCount(ctx context.Context, count int64) string {
|
||||
|
||||
@@ -39,8 +39,23 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, re.String())
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(re.String()))
|
||||
if util.IsGemini3Model(modelName) {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "request.generationConfig.thinkingConfig")
|
||||
case "auto":
|
||||
includeThoughts := true
|
||||
out = util.ApplyGeminiCLIThinkingLevel(out, "", &includeThoughts)
|
||||
default:
|
||||
if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok {
|
||||
out = util.ApplyGeminiCLIThinkingLevel(out, level, nil)
|
||||
}
|
||||
}
|
||||
} else if !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGeminiCLI(out, effort)
|
||||
}
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
@@ -222,62 +237,61 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if role == "assistant" {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
|
||||
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if !content.Exists() || content.Type == gjson.Null {
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
p++
|
||||
}
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
// Handle non-JSON output gracefully (matches dev branch approach)
|
||||
if resp != "null" {
|
||||
parsed := gjson.Parse(resp)
|
||||
if parsed.Type == gjson.JSON {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw))
|
||||
} else {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", resp)
|
||||
}
|
||||
}
|
||||
pp++
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.id", fid)
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"user","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.id", fid)
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
// Handle non-JSON output gracefully (matches dev branch approach)
|
||||
if resp != "null" {
|
||||
parsed := gjson.Parse(resp)
|
||||
if parsed.Type == gjson.JSON {
|
||||
toolNode, _ = sjson.SetRawBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(parsed.Raw))
|
||||
} else {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", resp)
|
||||
}
|
||||
}
|
||||
pp++
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -361,18 +375,3 @@ func ConvertOpenAIRequestToAntigravity(modelName string, inputRawJSON []byte, _
|
||||
|
||||
// itoa converts int to string without strconv import for few usages.
|
||||
func itoa(i int) string { return fmt.Sprintf("%d", i) }
|
||||
|
||||
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
|
||||
func quoteIfNeeded(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
|
||||
return s
|
||||
}
|
||||
// escape quotes minimally
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
|
||||
@@ -219,15 +219,20 @@ func ConvertClaudeRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
// Convert thinking.budget_tokens to reasoning.effort for level-based models
|
||||
reasoningEffort := "medium" // default
|
||||
if thinking := rootResult.Get("thinking"); thinking.Exists() && thinking.IsObject() {
|
||||
if thinking.Get("type").String() == "enabled" {
|
||||
switch thinking.Get("type").String() {
|
||||
case "enabled":
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
case "disabled":
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, 0); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
}
|
||||
template, _ = sjson.Set(template, "reasoning.effort", reasoningEffort)
|
||||
|
||||
@@ -253,7 +253,7 @@ func ConvertGeminiRequestToCodex(modelName string, inputRawJSON []byte, _ bool)
|
||||
if util.ModelUsesThinkingLevels(modelName) {
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
reasoningEffort = effort
|
||||
}
|
||||
}
|
||||
|
||||
@@ -205,52 +205,52 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if role == "assistant" {
|
||||
p := 0
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
|
||||
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
} else if !content.Exists() || content.Type == gjson.Null {
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
p++
|
||||
}
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiCLIFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "request.contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -334,18 +334,3 @@ func ConvertOpenAIRequestToGeminiCLI(modelName string, inputRawJSON []byte, _ bo
|
||||
|
||||
// itoa converts int to string without strconv import for few usages.
|
||||
func itoa(i int) string { return fmt.Sprintf("%d", i) }
|
||||
|
||||
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
|
||||
func quoteIfNeeded(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
|
||||
return s
|
||||
}
|
||||
// escape quotes minimally
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
|
||||
@@ -179,6 +179,18 @@ func ConvertGeminiResponseToClaude(_ context.Context, _ string, originalRequestR
|
||||
usedTool = true
|
||||
fcName := functionCallResult.Get("name").String()
|
||||
|
||||
// FIX: Handle streaming split/delta where name might be empty in subsequent chunks.
|
||||
// If we are already in tool use mode and name is empty, treat as continuation (delta).
|
||||
if (*param).(*Params).ResponseType == 3 && fcName == "" {
|
||||
if fcArgsResult := functionCallResult.Get("args"); fcArgsResult.Exists() {
|
||||
output = output + "event: content_block_delta\n"
|
||||
data, _ := sjson.Set(fmt.Sprintf(`{"type":"content_block_delta","index":%d,"delta":{"type":"input_json_delta","partial_json":""}}`, (*param).(*Params).ResponseIndex), "delta.partial_json", fcArgsResult.Raw)
|
||||
output = output + fmt.Sprintf("data: %s\n\n\n", data)
|
||||
}
|
||||
// Continue to next part without closing/opening logic
|
||||
continue
|
||||
}
|
||||
|
||||
// Handle state transitions when switching to function calls
|
||||
// Close any existing function call block first
|
||||
if (*param).(*Params).ResponseType == 3 {
|
||||
|
||||
@@ -37,12 +37,28 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// Reasoning effort -> thinkingBudget/include_thoughts
|
||||
// Note: OpenAI official fields take precedence over extra_body.google.thinking_config
|
||||
// Only convert for models that use numeric budgets (not discrete levels) to avoid
|
||||
// incorrectly applying thinkingBudget for level-based models like gpt-5.
|
||||
// Only apply numeric budgets for models that use budgets (not discrete levels) to avoid
|
||||
// incorrectly applying thinkingBudget for level-based models like gpt-5. Gemini 3 models
|
||||
// use thinkingLevel/includeThoughts instead.
|
||||
re := gjson.GetBytes(rawJSON, "reasoning_effort")
|
||||
hasOfficialThinking := re.Exists()
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) && !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGemini(out, re.String())
|
||||
if hasOfficialThinking && util.ModelSupportsThinking(modelName) {
|
||||
effort := strings.ToLower(strings.TrimSpace(re.String()))
|
||||
if util.IsGemini3Model(modelName) {
|
||||
switch effort {
|
||||
case "none":
|
||||
out, _ = sjson.DeleteBytes(out, "generationConfig.thinkingConfig")
|
||||
case "auto":
|
||||
includeThoughts := true
|
||||
out = util.ApplyGeminiThinkingLevel(out, "", &includeThoughts)
|
||||
default:
|
||||
if level, ok := util.ValidateGemini3ThinkingLevel(modelName, effort); ok {
|
||||
out = util.ApplyGeminiThinkingLevel(out, level, nil)
|
||||
}
|
||||
}
|
||||
} else if !util.ModelUsesThinkingLevels(modelName) {
|
||||
out = util.ApplyReasoningEffortToGemini(out, effort)
|
||||
}
|
||||
}
|
||||
|
||||
// Cherry Studio extension extra_body.google.thinking_config (effective only when official fields are absent)
|
||||
@@ -207,15 +223,16 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
} else if role == "assistant" {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
|
||||
if content.Type == gjson.String {
|
||||
// Assistant text -> single model content
|
||||
node := []byte(`{"role":"model","parts":[{"text":""}]}`)
|
||||
node, _ = sjson.SetBytes(node, "parts.0.text", content.String())
|
||||
node, _ = sjson.SetBytes(node, "parts.-1.text", content.String())
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
p++
|
||||
} else if content.IsArray() {
|
||||
// Assistant multimodal content (e.g. text + image) -> single model content with parts
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
for _, item := range content.Array() {
|
||||
switch item.Get("type").String() {
|
||||
case "text":
|
||||
@@ -237,47 +254,45 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
} else if !content.Exists() || content.Type == gjson.Null {
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
node := []byte(`{"role":"model","parts":[]}`)
|
||||
p := 0
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
}
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
// Tool calls -> single model content with functionCall parts
|
||||
tcs := m.Get("tool_calls")
|
||||
if tcs.IsArray() {
|
||||
fIDs := make([]string, 0)
|
||||
for _, tc := range tcs.Array() {
|
||||
if tc.Get("type").String() != "function" {
|
||||
continue
|
||||
}
|
||||
fid := tc.Get("id").String()
|
||||
fname := tc.Get("function.name").String()
|
||||
fargs := tc.Get("function.arguments").String()
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".functionCall.name", fname)
|
||||
node, _ = sjson.SetRawBytes(node, "parts."+itoa(p)+".functionCall.args", []byte(fargs))
|
||||
node, _ = sjson.SetBytes(node, "parts."+itoa(p)+".thoughtSignature", geminiFunctionThoughtSignature)
|
||||
p++
|
||||
if fid != "" {
|
||||
fIDs = append(fIDs, fid)
|
||||
}
|
||||
}
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", node)
|
||||
|
||||
// Append a single tool content combining name + response per function
|
||||
toolNode := []byte(`{"role":"tool","parts":[]}`)
|
||||
pp := 0
|
||||
for _, fid := range fIDs {
|
||||
if name, ok := tcID2Name[fid]; ok {
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.name", name)
|
||||
resp := toolResponses[fid]
|
||||
if resp == "" {
|
||||
resp = "{}"
|
||||
}
|
||||
toolNode, _ = sjson.SetBytes(toolNode, "parts."+itoa(pp)+".functionResponse.response.result", []byte(resp))
|
||||
pp++
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
if pp > 0 {
|
||||
out, _ = sjson.SetRawBytes(out, "contents.-1", toolNode)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -363,18 +378,3 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
|
||||
// itoa converts int to string without strconv import for few usages.
|
||||
func itoa(i int) string { return fmt.Sprintf("%d", i) }
|
||||
|
||||
// quoteIfNeeded ensures a string is valid JSON value (quotes plain text), pass-through for JSON objects/arrays.
|
||||
func quoteIfNeeded(s string) string {
|
||||
s = strings.TrimSpace(s)
|
||||
if s == "" {
|
||||
return "\"\""
|
||||
}
|
||||
if len(s) > 0 && (s[0] == '{' || s[0] == '[') {
|
||||
return s
|
||||
}
|
||||
// escape quotes minimally
|
||||
s = strings.ReplaceAll(s, "\\", "\\\\")
|
||||
s = strings.ReplaceAll(s, "\"", "\\\"")
|
||||
return "\"" + s + "\""
|
||||
}
|
||||
|
||||
@@ -63,10 +63,22 @@ func ConvertClaudeRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
|
||||
// Thinking: Convert Claude thinking.budget_tokens to OpenAI reasoning_effort
|
||||
if thinking := root.Get("thinking"); thinking.Exists() && thinking.IsObject() {
|
||||
if thinkingType := thinking.Get("type"); thinkingType.Exists() && thinkingType.String() == "enabled" {
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
if thinkingType := thinking.Get("type"); thinkingType.Exists() {
|
||||
switch thinkingType.String() {
|
||||
case "enabled":
|
||||
if budgetTokens := thinking.Get("budget_tokens"); budgetTokens.Exists() {
|
||||
budget := int(budgetTokens.Int())
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
} else {
|
||||
// No budget_tokens specified, default to "auto" for enabled thinking
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, -1); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
case "disabled":
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, 0); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -128,9 +128,10 @@ func convertOpenAIStreamingChunkToAnthropic(rawJSON []byte, param *ConvertOpenAI
|
||||
param.CreatedAt = root.Get("created").Int()
|
||||
}
|
||||
|
||||
// Check if this is the first chunk (has role)
|
||||
// Emit message_start on the very first chunk, regardless of whether it has a role field.
|
||||
// Some providers (like Copilot) may send tool_calls in the first chunk without a role field.
|
||||
if delta := root.Get("choices.0.delta"); delta.Exists() {
|
||||
if role := delta.Get("role"); role.Exists() && role.String() == "assistant" && !param.MessageStarted {
|
||||
if !param.MessageStarted {
|
||||
// Send message_start event
|
||||
messageStart := map[string]interface{}{
|
||||
"type": "message_start",
|
||||
|
||||
@@ -83,7 +83,7 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
if thinkingConfig := genConfig.Get("thinkingConfig"); thinkingConfig.Exists() && thinkingConfig.IsObject() {
|
||||
if thinkingBudget := thinkingConfig.Get("thinkingBudget"); thinkingBudget.Exists() {
|
||||
budget := int(thinkingBudget.Int())
|
||||
if effort, ok := util.OpenAIThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
if effort, ok := util.ThinkingBudgetToEffort(modelName, budget); ok && effort != "" {
|
||||
out, _ = sjson.Set(out, "reasoning_effort", effort)
|
||||
}
|
||||
}
|
||||
|
||||
496
internal/util/gemini_schema.go
Normal file
496
internal/util/gemini_schema.go
Normal file
@@ -0,0 +1,496 @@
|
||||
// Package util provides utility functions for the CLI Proxy API server.
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
var gjsonPathKeyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||
|
||||
// CleanJSONSchemaForGemini transforms a JSON schema to be compatible with Gemini/Antigravity API.
|
||||
// It handles unsupported keywords, type flattening, and schema simplification while preserving
|
||||
// semantic information as description hints.
|
||||
func CleanJSONSchemaForGemini(jsonStr string) string {
|
||||
// Phase 1: Convert and add hints
|
||||
jsonStr = convertRefsToHints(jsonStr)
|
||||
jsonStr = convertConstToEnum(jsonStr)
|
||||
jsonStr = addEnumHints(jsonStr)
|
||||
jsonStr = addAdditionalPropertiesHints(jsonStr)
|
||||
jsonStr = moveConstraintsToDescription(jsonStr)
|
||||
|
||||
// Phase 2: Flatten complex structures
|
||||
jsonStr = mergeAllOf(jsonStr)
|
||||
jsonStr = flattenAnyOfOneOf(jsonStr)
|
||||
jsonStr = flattenTypeArrays(jsonStr)
|
||||
|
||||
// Phase 3: Cleanup
|
||||
jsonStr = removeUnsupportedKeywords(jsonStr)
|
||||
jsonStr = cleanupRequiredFields(jsonStr)
|
||||
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// convertRefsToHints converts $ref to description hints (Lazy Hint strategy).
|
||||
func convertRefsToHints(jsonStr string) string {
|
||||
paths := findPaths(jsonStr, "$ref")
|
||||
sortByDepth(paths)
|
||||
|
||||
for _, p := range paths {
|
||||
refVal := gjson.Get(jsonStr, p).String()
|
||||
defName := refVal
|
||||
if idx := strings.LastIndex(refVal, "/"); idx >= 0 {
|
||||
defName = refVal[idx+1:]
|
||||
}
|
||||
|
||||
parentPath := trimSuffix(p, ".$ref")
|
||||
hint := fmt.Sprintf("See: %s", defName)
|
||||
if existing := gjson.Get(jsonStr, descriptionPath(parentPath)).String(); existing != "" {
|
||||
hint = fmt.Sprintf("%s (%s)", existing, hint)
|
||||
}
|
||||
|
||||
replacement := `{"type":"object","description":""}`
|
||||
replacement, _ = sjson.Set(replacement, "description", hint)
|
||||
jsonStr = setRawAt(jsonStr, parentPath, replacement)
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func convertConstToEnum(jsonStr string) string {
|
||||
for _, p := range findPaths(jsonStr, "const") {
|
||||
val := gjson.Get(jsonStr, p)
|
||||
if !val.Exists() {
|
||||
continue
|
||||
}
|
||||
enumPath := trimSuffix(p, ".const") + ".enum"
|
||||
if !gjson.Get(jsonStr, enumPath).Exists() {
|
||||
jsonStr, _ = sjson.Set(jsonStr, enumPath, []interface{}{val.Value()})
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func addEnumHints(jsonStr string) string {
|
||||
for _, p := range findPaths(jsonStr, "enum") {
|
||||
arr := gjson.Get(jsonStr, p)
|
||||
if !arr.IsArray() {
|
||||
continue
|
||||
}
|
||||
items := arr.Array()
|
||||
if len(items) <= 1 || len(items) > 10 {
|
||||
continue
|
||||
}
|
||||
|
||||
var vals []string
|
||||
for _, item := range items {
|
||||
vals = append(vals, item.String())
|
||||
}
|
||||
jsonStr = appendHint(jsonStr, trimSuffix(p, ".enum"), "Allowed: "+strings.Join(vals, ", "))
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func addAdditionalPropertiesHints(jsonStr string) string {
|
||||
for _, p := range findPaths(jsonStr, "additionalProperties") {
|
||||
if gjson.Get(jsonStr, p).Type == gjson.False {
|
||||
jsonStr = appendHint(jsonStr, trimSuffix(p, ".additionalProperties"), "No extra properties allowed")
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
var unsupportedConstraints = []string{
|
||||
"minLength", "maxLength", "exclusiveMinimum", "exclusiveMaximum",
|
||||
"pattern", "minItems", "maxItems",
|
||||
}
|
||||
|
||||
func moveConstraintsToDescription(jsonStr string) string {
|
||||
for _, key := range unsupportedConstraints {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
val := gjson.Get(jsonStr, p)
|
||||
if !val.Exists() || val.IsObject() || val.IsArray() {
|
||||
continue
|
||||
}
|
||||
parentPath := trimSuffix(p, "."+key)
|
||||
if isPropertyDefinition(parentPath) {
|
||||
continue
|
||||
}
|
||||
jsonStr = appendHint(jsonStr, parentPath, fmt.Sprintf("%s: %s", key, val.String()))
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func mergeAllOf(jsonStr string) string {
|
||||
paths := findPaths(jsonStr, "allOf")
|
||||
sortByDepth(paths)
|
||||
|
||||
for _, p := range paths {
|
||||
allOf := gjson.Get(jsonStr, p)
|
||||
if !allOf.IsArray() {
|
||||
continue
|
||||
}
|
||||
parentPath := trimSuffix(p, ".allOf")
|
||||
|
||||
for _, item := range allOf.Array() {
|
||||
if props := item.Get("properties"); props.IsObject() {
|
||||
props.ForEach(func(key, value gjson.Result) bool {
|
||||
destPath := joinPath(parentPath, "properties."+escapeGJSONPathKey(key.String()))
|
||||
jsonStr, _ = sjson.SetRaw(jsonStr, destPath, value.Raw)
|
||||
return true
|
||||
})
|
||||
}
|
||||
if req := item.Get("required"); req.IsArray() {
|
||||
reqPath := joinPath(parentPath, "required")
|
||||
current := getStrings(jsonStr, reqPath)
|
||||
for _, r := range req.Array() {
|
||||
if s := r.String(); !contains(current, s) {
|
||||
current = append(current, s)
|
||||
}
|
||||
}
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, current)
|
||||
}
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func flattenAnyOfOneOf(jsonStr string) string {
|
||||
for _, key := range []string{"anyOf", "oneOf"} {
|
||||
paths := findPaths(jsonStr, key)
|
||||
sortByDepth(paths)
|
||||
|
||||
for _, p := range paths {
|
||||
arr := gjson.Get(jsonStr, p)
|
||||
if !arr.IsArray() || len(arr.Array()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
parentPath := trimSuffix(p, "."+key)
|
||||
parentDesc := gjson.Get(jsonStr, descriptionPath(parentPath)).String()
|
||||
|
||||
items := arr.Array()
|
||||
bestIdx, allTypes := selectBest(items)
|
||||
selected := items[bestIdx].Raw
|
||||
|
||||
if parentDesc != "" {
|
||||
selected = mergeDescriptionRaw(selected, parentDesc)
|
||||
}
|
||||
|
||||
if len(allTypes) > 1 {
|
||||
hint := "Accepts: " + strings.Join(allTypes, " | ")
|
||||
selected = appendHintRaw(selected, hint)
|
||||
}
|
||||
|
||||
jsonStr = setRawAt(jsonStr, parentPath, selected)
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func selectBest(items []gjson.Result) (bestIdx int, types []string) {
|
||||
bestScore := -1
|
||||
for i, item := range items {
|
||||
t := item.Get("type").String()
|
||||
score := 0
|
||||
|
||||
switch {
|
||||
case t == "object" || item.Get("properties").Exists():
|
||||
score, t = 3, orDefault(t, "object")
|
||||
case t == "array" || item.Get("items").Exists():
|
||||
score, t = 2, orDefault(t, "array")
|
||||
case t != "" && t != "null":
|
||||
score = 1
|
||||
default:
|
||||
t = orDefault(t, "null")
|
||||
}
|
||||
|
||||
if t != "" {
|
||||
types = append(types, t)
|
||||
}
|
||||
if score > bestScore {
|
||||
bestScore, bestIdx = score, i
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func flattenTypeArrays(jsonStr string) string {
|
||||
paths := findPaths(jsonStr, "type")
|
||||
sortByDepth(paths)
|
||||
|
||||
nullableFields := make(map[string][]string)
|
||||
|
||||
for _, p := range paths {
|
||||
res := gjson.Get(jsonStr, p)
|
||||
if !res.IsArray() || len(res.Array()) == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
hasNull := false
|
||||
var nonNullTypes []string
|
||||
for _, item := range res.Array() {
|
||||
s := item.String()
|
||||
if s == "null" {
|
||||
hasNull = true
|
||||
} else if s != "" {
|
||||
nonNullTypes = append(nonNullTypes, s)
|
||||
}
|
||||
}
|
||||
|
||||
firstType := "string"
|
||||
if len(nonNullTypes) > 0 {
|
||||
firstType = nonNullTypes[0]
|
||||
}
|
||||
|
||||
jsonStr, _ = sjson.Set(jsonStr, p, firstType)
|
||||
|
||||
parentPath := trimSuffix(p, ".type")
|
||||
if len(nonNullTypes) > 1 {
|
||||
hint := "Accepts: " + strings.Join(nonNullTypes, " | ")
|
||||
jsonStr = appendHint(jsonStr, parentPath, hint)
|
||||
}
|
||||
|
||||
if hasNull {
|
||||
parts := splitGJSONPath(p)
|
||||
if len(parts) >= 3 && parts[len(parts)-3] == "properties" {
|
||||
fieldNameEscaped := parts[len(parts)-2]
|
||||
fieldName := unescapeGJSONPathKey(fieldNameEscaped)
|
||||
objectPath := strings.Join(parts[:len(parts)-3], ".")
|
||||
nullableFields[objectPath] = append(nullableFields[objectPath], fieldName)
|
||||
|
||||
propPath := joinPath(objectPath, "properties."+fieldNameEscaped)
|
||||
jsonStr = appendHint(jsonStr, propPath, "(nullable)")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for objectPath, fields := range nullableFields {
|
||||
reqPath := joinPath(objectPath, "required")
|
||||
req := gjson.Get(jsonStr, reqPath)
|
||||
if !req.IsArray() {
|
||||
continue
|
||||
}
|
||||
|
||||
var filtered []string
|
||||
for _, r := range req.Array() {
|
||||
if !contains(fields, r.String()) {
|
||||
filtered = append(filtered, r.String())
|
||||
}
|
||||
}
|
||||
|
||||
if len(filtered) == 0 {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, reqPath)
|
||||
} else {
|
||||
jsonStr, _ = sjson.Set(jsonStr, reqPath, filtered)
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func removeUnsupportedKeywords(jsonStr string) string {
|
||||
keywords := append(unsupportedConstraints,
|
||||
"$schema", "$defs", "definitions", "const", "$ref", "additionalProperties",
|
||||
)
|
||||
for _, key := range keywords {
|
||||
for _, p := range findPaths(jsonStr, key) {
|
||||
if isPropertyDefinition(trimSuffix(p, "."+key)) {
|
||||
continue
|
||||
}
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func cleanupRequiredFields(jsonStr string) string {
|
||||
for _, p := range findPaths(jsonStr, "required") {
|
||||
parentPath := trimSuffix(p, ".required")
|
||||
propsPath := joinPath(parentPath, "properties")
|
||||
|
||||
req := gjson.Get(jsonStr, p)
|
||||
props := gjson.Get(jsonStr, propsPath)
|
||||
if !req.IsArray() || !props.IsObject() {
|
||||
continue
|
||||
}
|
||||
|
||||
var valid []string
|
||||
for _, r := range req.Array() {
|
||||
key := r.String()
|
||||
if props.Get(escapeGJSONPathKey(key)).Exists() {
|
||||
valid = append(valid, key)
|
||||
}
|
||||
}
|
||||
|
||||
if len(valid) != len(req.Array()) {
|
||||
if len(valid) == 0 {
|
||||
jsonStr, _ = sjson.Delete(jsonStr, p)
|
||||
} else {
|
||||
jsonStr, _ = sjson.Set(jsonStr, p, valid)
|
||||
}
|
||||
}
|
||||
}
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
// --- Helpers ---
|
||||
|
||||
func findPaths(jsonStr, field string) []string {
|
||||
var paths []string
|
||||
Walk(gjson.Parse(jsonStr), "", field, &paths)
|
||||
return paths
|
||||
}
|
||||
|
||||
func sortByDepth(paths []string) {
|
||||
sort.Slice(paths, func(i, j int) bool { return len(paths[i]) > len(paths[j]) })
|
||||
}
|
||||
|
||||
func trimSuffix(path, suffix string) string {
|
||||
if path == strings.TrimPrefix(suffix, ".") {
|
||||
return ""
|
||||
}
|
||||
return strings.TrimSuffix(path, suffix)
|
||||
}
|
||||
|
||||
func joinPath(base, suffix string) string {
|
||||
if base == "" {
|
||||
return suffix
|
||||
}
|
||||
return base + "." + suffix
|
||||
}
|
||||
|
||||
func setRawAt(jsonStr, path, value string) string {
|
||||
if path == "" {
|
||||
return value
|
||||
}
|
||||
result, _ := sjson.SetRaw(jsonStr, path, value)
|
||||
return result
|
||||
}
|
||||
|
||||
func isPropertyDefinition(path string) bool {
|
||||
return path == "properties" || strings.HasSuffix(path, ".properties")
|
||||
}
|
||||
|
||||
func descriptionPath(parentPath string) string {
|
||||
if parentPath == "" || parentPath == "@this" {
|
||||
return "description"
|
||||
}
|
||||
return parentPath + ".description"
|
||||
}
|
||||
|
||||
func appendHint(jsonStr, parentPath, hint string) string {
|
||||
descPath := parentPath + ".description"
|
||||
if parentPath == "" || parentPath == "@this" {
|
||||
descPath = "description"
|
||||
}
|
||||
existing := gjson.Get(jsonStr, descPath).String()
|
||||
if existing != "" {
|
||||
hint = fmt.Sprintf("%s (%s)", existing, hint)
|
||||
}
|
||||
jsonStr, _ = sjson.Set(jsonStr, descPath, hint)
|
||||
return jsonStr
|
||||
}
|
||||
|
||||
func appendHintRaw(jsonRaw, hint string) string {
|
||||
existing := gjson.Get(jsonRaw, "description").String()
|
||||
if existing != "" {
|
||||
hint = fmt.Sprintf("%s (%s)", existing, hint)
|
||||
}
|
||||
jsonRaw, _ = sjson.Set(jsonRaw, "description", hint)
|
||||
return jsonRaw
|
||||
}
|
||||
|
||||
func getStrings(jsonStr, path string) []string {
|
||||
var result []string
|
||||
if arr := gjson.Get(jsonStr, path); arr.IsArray() {
|
||||
for _, r := range arr.Array() {
|
||||
result = append(result, r.String())
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func contains(slice []string, item string) bool {
|
||||
for _, s := range slice {
|
||||
if s == item {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func orDefault(val, def string) string {
|
||||
if val == "" {
|
||||
return def
|
||||
}
|
||||
return val
|
||||
}
|
||||
|
||||
func escapeGJSONPathKey(key string) string {
|
||||
return gjsonPathKeyReplacer.Replace(key)
|
||||
}
|
||||
|
||||
func unescapeGJSONPathKey(key string) string {
|
||||
if !strings.Contains(key, "\\") {
|
||||
return key
|
||||
}
|
||||
var b strings.Builder
|
||||
b.Grow(len(key))
|
||||
for i := 0; i < len(key); i++ {
|
||||
if key[i] == '\\' && i+1 < len(key) {
|
||||
i++
|
||||
b.WriteByte(key[i])
|
||||
continue
|
||||
}
|
||||
b.WriteByte(key[i])
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func splitGJSONPath(path string) []string {
|
||||
if path == "" {
|
||||
return nil
|
||||
}
|
||||
|
||||
parts := make([]string, 0, strings.Count(path, ".")+1)
|
||||
var b strings.Builder
|
||||
b.Grow(len(path))
|
||||
|
||||
for i := 0; i < len(path); i++ {
|
||||
c := path[i]
|
||||
if c == '\\' && i+1 < len(path) {
|
||||
b.WriteByte('\\')
|
||||
i++
|
||||
b.WriteByte(path[i])
|
||||
continue
|
||||
}
|
||||
if c == '.' {
|
||||
parts = append(parts, b.String())
|
||||
b.Reset()
|
||||
continue
|
||||
}
|
||||
b.WriteByte(c)
|
||||
}
|
||||
parts = append(parts, b.String())
|
||||
return parts
|
||||
}
|
||||
|
||||
func mergeDescriptionRaw(schemaRaw, parentDesc string) string {
|
||||
childDesc := gjson.Get(schemaRaw, "description").String()
|
||||
switch {
|
||||
case childDesc == "":
|
||||
schemaRaw, _ = sjson.Set(schemaRaw, "description", parentDesc)
|
||||
return schemaRaw
|
||||
case childDesc == parentDesc:
|
||||
return schemaRaw
|
||||
default:
|
||||
combined := fmt.Sprintf("%s (%s)", parentDesc, childDesc)
|
||||
schemaRaw, _ = sjson.Set(schemaRaw, "description", combined)
|
||||
return schemaRaw
|
||||
}
|
||||
}
|
||||
613
internal/util/gemini_schema_test.go
Normal file
613
internal/util/gemini_schema_test.go
Normal file
@@ -0,0 +1,613 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestCleanJSONSchemaForGemini_ConstToEnum(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"const": "InsightVizNode"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"enum": ["InsightVizNode"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": ["string", "null"]
|
||||
},
|
||||
"other": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["name", "other"]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "(nullable)"
|
||||
},
|
||||
"other": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["other"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_ConstraintsToDescription(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"tags": {
|
||||
"type": "array",
|
||||
"description": "List of tags",
|
||||
"minItems": 1
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"description": "User name",
|
||||
"minLength": 3
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
// minItems should be REMOVED and moved to description
|
||||
if strings.Contains(result, `"minItems"`) {
|
||||
t.Errorf("minItems keyword should be removed")
|
||||
}
|
||||
if !strings.Contains(result, "minItems: 1") {
|
||||
t.Errorf("minItems hint missing in description")
|
||||
}
|
||||
|
||||
// minLength should be moved to description
|
||||
if !strings.Contains(result, "minLength: 3") {
|
||||
t.Errorf("minLength hint missing in description")
|
||||
}
|
||||
if strings.Contains(result, `"minLength":`) || strings.Contains(result, `"minLength" :`) {
|
||||
t.Errorf("minLength keyword should be removed")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AnyOfFlattening_SmartSelection(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"anyOf": [
|
||||
{ "type": "null" },
|
||||
{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": { "type": "string" }
|
||||
}
|
||||
}
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"query": {
|
||||
"type": "object",
|
||||
"description": "Accepts: null | object",
|
||||
"properties": {
|
||||
"kind": { "type": "string" }
|
||||
}
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_OneOfFlattening(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"oneOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "integer" }
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "string",
|
||||
"description": "Accepts: string | integer"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AllOfMerging(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"allOf": [
|
||||
{
|
||||
"properties": {
|
||||
"a": { "type": "string" }
|
||||
},
|
||||
"required": ["a"]
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"b": { "type": "integer" }
|
||||
},
|
||||
"required": ["b"]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": { "type": "string" },
|
||||
"b": { "type": "integer" }
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_RefHandling(t *testing.T) {
|
||||
input := `{
|
||||
"definitions": {
|
||||
"User": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": { "$ref": "#/definitions/User" }
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": {
|
||||
"type": "object",
|
||||
"description": "See: User"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_RefHandling_DescriptionEscaping(t *testing.T) {
|
||||
input := `{
|
||||
"definitions": {
|
||||
"User": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": {
|
||||
"description": "He said \"hi\"\\nsecond line",
|
||||
"$ref": "#/definitions/User"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"customer": {
|
||||
"type": "object",
|
||||
"description": "He said \"hi\"\\nsecond line (See: User)"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_CyclicRefDefaults(t *testing.T) {
|
||||
input := `{
|
||||
"definitions": {
|
||||
"Node": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"child": { "$ref": "#/definitions/Node" }
|
||||
}
|
||||
}
|
||||
},
|
||||
"$ref": "#/definitions/Node"
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
var resMap map[string]interface{}
|
||||
json.Unmarshal([]byte(result), &resMap)
|
||||
|
||||
if resMap["type"] != "object" {
|
||||
t.Errorf("Expected type: object, got: %v", resMap["type"])
|
||||
}
|
||||
|
||||
desc, ok := resMap["description"].(string)
|
||||
if !ok || !strings.Contains(desc, "Node") {
|
||||
t.Errorf("Expected description hint containing 'Node', got: %v", resMap["description"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_RequiredCleanup(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "string"},
|
||||
"b": {"type": "string"}
|
||||
},
|
||||
"required": ["a", "b", "c"]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"a": {"type": "string"},
|
||||
"b": {"type": "string"}
|
||||
},
|
||||
"required": ["a", "b"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AllOfMerging_DotKeys(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"allOf": [
|
||||
{
|
||||
"properties": {
|
||||
"my.param": { "type": "string" }
|
||||
},
|
||||
"required": ["my.param"]
|
||||
},
|
||||
{
|
||||
"properties": {
|
||||
"b": { "type": "integer" }
|
||||
},
|
||||
"required": ["b"]
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"my.param": { "type": "string" },
|
||||
"b": { "type": "integer" }
|
||||
},
|
||||
"required": ["my.param", "b"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_PropertyNameCollision(t *testing.T) {
|
||||
// A tool has an argument named "pattern" - should NOT be treated as a constraint
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The regex pattern"
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"pattern": {
|
||||
"type": "string",
|
||||
"description": "The regex pattern"
|
||||
}
|
||||
},
|
||||
"required": ["pattern"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
|
||||
var resMap map[string]interface{}
|
||||
json.Unmarshal([]byte(result), &resMap)
|
||||
props, _ := resMap["properties"].(map[string]interface{})
|
||||
if _, ok := props["description"]; ok {
|
||||
t.Errorf("Invalid 'description' property injected into properties map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_DotKeys(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"my.param": {
|
||||
"type": "string",
|
||||
"$ref": "#/definitions/MyType"
|
||||
}
|
||||
},
|
||||
"definitions": {
|
||||
"MyType": { "type": "string" }
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
var resMap map[string]interface{}
|
||||
if err := json.Unmarshal([]byte(result), &resMap); err != nil {
|
||||
t.Fatalf("Failed to unmarshal result: %v", err)
|
||||
}
|
||||
|
||||
props, ok := resMap["properties"].(map[string]interface{})
|
||||
if !ok {
|
||||
t.Fatalf("properties missing")
|
||||
}
|
||||
|
||||
if val, ok := props["my.param"]; !ok {
|
||||
t.Fatalf("Key 'my.param' is missing. Result: %s", result)
|
||||
} else {
|
||||
valMap, _ := val.(map[string]interface{})
|
||||
if _, hasRef := valMap["$ref"]; hasRef {
|
||||
t.Errorf("Key 'my.param' still contains $ref")
|
||||
}
|
||||
if _, ok := props["my"]; ok {
|
||||
t.Errorf("Artifact key 'my' created by sjson splitting")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AnyOfAlternativeHints(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"anyOf": [
|
||||
{ "type": "string" },
|
||||
{ "type": "integer" },
|
||||
{ "type": "null" }
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if !strings.Contains(result, "Accepts:") {
|
||||
t.Errorf("Expected alternative types hint, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "string") || !strings.Contains(result, "integer") {
|
||||
t.Errorf("Expected all alternative types in hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_NullableHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": ["string", "null"],
|
||||
"description": "User name"
|
||||
}
|
||||
},
|
||||
"required": ["name"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if !strings.Contains(result, "(nullable)") {
|
||||
t.Errorf("Expected nullable hint, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "User name") {
|
||||
t.Errorf("Expected original description to be preserved, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_TypeFlattening_Nullable_DotKey(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"my.param": {
|
||||
"type": ["string", "null"]
|
||||
},
|
||||
"other": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["my.param", "other"]
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"my.param": {
|
||||
"type": "string",
|
||||
"description": "(nullable)"
|
||||
},
|
||||
"other": {
|
||||
"type": "string"
|
||||
}
|
||||
},
|
||||
"required": ["other"]
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_EnumHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"status": {
|
||||
"type": "string",
|
||||
"enum": ["active", "inactive", "pending"],
|
||||
"description": "Current status"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if !strings.Contains(result, "Allowed:") {
|
||||
t.Errorf("Expected enum values hint, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "active") || !strings.Contains(result, "inactive") {
|
||||
t.Errorf("Expected enum values in hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AdditionalPropertiesHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": { "type": "string" }
|
||||
},
|
||||
"additionalProperties": false
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if !strings.Contains(result, "No extra properties allowed") {
|
||||
t.Errorf("Expected additionalProperties hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_AnyOfFlattening_PreservesDescription(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"description": "Parent desc",
|
||||
"anyOf": [
|
||||
{ "type": "string", "description": "Child desc" },
|
||||
{ "type": "integer" }
|
||||
]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
expected := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"config": {
|
||||
"type": "string",
|
||||
"description": "Parent desc (Child desc) (Accepts: string | integer)"
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
compareJSON(t, expected, result)
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_SingleEnumNoHint(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"kind": {
|
||||
"type": "string",
|
||||
"enum": ["fixed"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if strings.Contains(result, "Allowed:") {
|
||||
t.Errorf("Single value enum should not add Allowed hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCleanJSONSchemaForGemini_MultipleNonNullTypes(t *testing.T) {
|
||||
input := `{
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"value": {
|
||||
"type": ["string", "integer", "boolean"]
|
||||
}
|
||||
}
|
||||
}`
|
||||
|
||||
result := CleanJSONSchemaForGemini(input)
|
||||
|
||||
if !strings.Contains(result, "Accepts:") {
|
||||
t.Errorf("Expected multiple types hint, got: %s", result)
|
||||
}
|
||||
if !strings.Contains(result, "string") || !strings.Contains(result, "integer") || !strings.Contains(result, "boolean") {
|
||||
t.Errorf("Expected all types in hint, got: %s", result)
|
||||
}
|
||||
}
|
||||
|
||||
func compareJSON(t *testing.T, expectedJSON, actualJSON string) {
|
||||
var expMap, actMap map[string]interface{}
|
||||
errExp := json.Unmarshal([]byte(expectedJSON), &expMap)
|
||||
errAct := json.Unmarshal([]byte(actualJSON), &actMap)
|
||||
|
||||
if errExp != nil || errAct != nil {
|
||||
t.Fatalf("JSON Unmarshal error. Exp: %v, Act: %v", errExp, errAct)
|
||||
}
|
||||
|
||||
if !reflect.DeepEqual(expMap, actMap) {
|
||||
expBytes, _ := json.MarshalIndent(expMap, "", " ")
|
||||
actBytes, _ := json.MarshalIndent(actMap, "", " ")
|
||||
t.Errorf("JSON mismatch:\nExpected:\n%s\n\nActual:\n%s", string(expBytes), string(actBytes))
|
||||
}
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package util
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -13,6 +14,44 @@ const (
|
||||
GeminiOriginalModelMetadataKey = "gemini_original_model"
|
||||
)
|
||||
|
||||
// Gemini model family detection patterns
|
||||
var (
|
||||
gemini3Pattern = regexp.MustCompile(`(?i)^gemini[_-]?3[_-]`)
|
||||
gemini3ProPattern = regexp.MustCompile(`(?i)^gemini[_-]?3[_-]pro`)
|
||||
gemini3FlashPattern = regexp.MustCompile(`(?i)^gemini[_-]?3[_-]flash`)
|
||||
gemini25Pattern = regexp.MustCompile(`(?i)^gemini[_-]?2\.5[_-]`)
|
||||
)
|
||||
|
||||
// IsGemini3Model returns true if the model is a Gemini 3 family model.
|
||||
// Gemini 3 models should use thinkingLevel (string) instead of thinkingBudget (number).
|
||||
func IsGemini3Model(model string) bool {
|
||||
return gemini3Pattern.MatchString(model)
|
||||
}
|
||||
|
||||
// IsGemini3ProModel returns true if the model is a Gemini 3 Pro variant.
|
||||
// Gemini 3 Pro supports thinkingLevel: "low", "high" (default: "high")
|
||||
func IsGemini3ProModel(model string) bool {
|
||||
return gemini3ProPattern.MatchString(model)
|
||||
}
|
||||
|
||||
// IsGemini3FlashModel returns true if the model is a Gemini 3 Flash variant.
|
||||
// Gemini 3 Flash supports thinkingLevel: "minimal", "low", "medium", "high" (default: "high")
|
||||
func IsGemini3FlashModel(model string) bool {
|
||||
return gemini3FlashPattern.MatchString(model)
|
||||
}
|
||||
|
||||
// IsGemini25Model returns true if the model is a Gemini 2.5 family model.
|
||||
// Gemini 2.5 models should use thinkingBudget (number).
|
||||
func IsGemini25Model(model string) bool {
|
||||
return gemini25Pattern.MatchString(model)
|
||||
}
|
||||
|
||||
// Gemini3ProThinkingLevels are the valid thinkingLevel values for Gemini 3 Pro models.
|
||||
var Gemini3ProThinkingLevels = []string{"low", "high"}
|
||||
|
||||
// Gemini3FlashThinkingLevels are the valid thinkingLevel values for Gemini 3 Flash models.
|
||||
var Gemini3FlashThinkingLevels = []string{"minimal", "low", "medium", "high"}
|
||||
|
||||
func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool) []byte {
|
||||
if budget == nil && includeThoughts == nil {
|
||||
return body
|
||||
@@ -69,10 +108,141 @@ func ApplyGeminiCLIThinkingConfig(body []byte, budget *int, includeThoughts *boo
|
||||
return updated
|
||||
}
|
||||
|
||||
// ApplyGeminiThinkingLevel applies thinkingLevel config for Gemini 3 models.
|
||||
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// Per Google's documentation, Gemini 3 models should use thinkingLevel instead of thinkingBudget.
|
||||
func ApplyGeminiThinkingLevel(body []byte, level string, includeThoughts *bool) []byte {
|
||||
if level == "" && includeThoughts == nil {
|
||||
return body
|
||||
}
|
||||
updated := body
|
||||
if level != "" {
|
||||
valuePath := "generationConfig.thinkingConfig.thinkingLevel"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, level)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
// Default to including thoughts when a level is set but no explicit include flag is provided.
|
||||
incl := includeThoughts
|
||||
if incl == nil && level != "" {
|
||||
defaultInclude := true
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "generationConfig.thinkingConfig.includeThoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// ApplyGeminiCLIThinkingLevel applies thinkingLevel config for Gemini 3 models.
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// Per Google's documentation, Gemini 3 models should use thinkingLevel instead of thinkingBudget.
|
||||
func ApplyGeminiCLIThinkingLevel(body []byte, level string, includeThoughts *bool) []byte {
|
||||
if level == "" && includeThoughts == nil {
|
||||
return body
|
||||
}
|
||||
updated := body
|
||||
if level != "" {
|
||||
valuePath := "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, level)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
// Default to including thoughts when a level is set but no explicit include flag is provided.
|
||||
incl := includeThoughts
|
||||
if incl == nil && level != "" {
|
||||
defaultInclude := true
|
||||
incl = &defaultInclude
|
||||
}
|
||||
if incl != nil {
|
||||
valuePath := "request.generationConfig.thinkingConfig.includeThoughts"
|
||||
rewritten, err := sjson.SetBytes(updated, valuePath, *incl)
|
||||
if err == nil {
|
||||
updated = rewritten
|
||||
}
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// ValidateGemini3ThinkingLevel validates that the thinkingLevel is valid for the Gemini 3 model variant.
|
||||
// Returns the validated level (normalized to lowercase) and true if valid, or empty string and false if invalid.
|
||||
func ValidateGemini3ThinkingLevel(model, level string) (string, bool) {
|
||||
if level == "" {
|
||||
return "", false
|
||||
}
|
||||
normalized := strings.ToLower(strings.TrimSpace(level))
|
||||
|
||||
var validLevels []string
|
||||
if IsGemini3ProModel(model) {
|
||||
validLevels = Gemini3ProThinkingLevels
|
||||
} else if IsGemini3FlashModel(model) {
|
||||
validLevels = Gemini3FlashThinkingLevels
|
||||
} else if IsGemini3Model(model) {
|
||||
// Unknown Gemini 3 variant - allow all levels as fallback
|
||||
validLevels = Gemini3FlashThinkingLevels
|
||||
} else {
|
||||
return "", false
|
||||
}
|
||||
|
||||
for _, valid := range validLevels {
|
||||
if normalized == valid {
|
||||
return normalized, true
|
||||
}
|
||||
}
|
||||
return "", false
|
||||
}
|
||||
|
||||
// ThinkingBudgetToGemini3Level converts a thinkingBudget to a thinkingLevel for Gemini 3 models.
|
||||
// This provides backward compatibility when thinkingBudget is provided for Gemini 3 models.
|
||||
// Returns the appropriate thinkingLevel and true if conversion is possible.
|
||||
func ThinkingBudgetToGemini3Level(model string, budget int) (string, bool) {
|
||||
if !IsGemini3Model(model) {
|
||||
return "", false
|
||||
}
|
||||
|
||||
// Map budget to level based on Google's documentation
|
||||
// Gemini 3 Pro: "low", "high" (default: "high")
|
||||
// Gemini 3 Flash: "minimal", "low", "medium", "high" (default: "high")
|
||||
switch {
|
||||
case budget == -1:
|
||||
// Dynamic budget maps to "high" (API default)
|
||||
return "high", true
|
||||
case budget == 0:
|
||||
// Zero budget - Gemini 3 doesn't support disabling thinking
|
||||
// Map to lowest available level
|
||||
if IsGemini3FlashModel(model) {
|
||||
return "minimal", true
|
||||
}
|
||||
return "low", true
|
||||
case budget > 0 && budget <= 512:
|
||||
if IsGemini3FlashModel(model) {
|
||||
return "minimal", true
|
||||
}
|
||||
return "low", true
|
||||
case budget <= 1024:
|
||||
return "low", true
|
||||
case budget <= 8192:
|
||||
if IsGemini3FlashModel(model) {
|
||||
return "medium", true
|
||||
}
|
||||
return "low", true // Pro doesn't have medium, use low
|
||||
default:
|
||||
return "high", true
|
||||
}
|
||||
}
|
||||
|
||||
// modelsWithDefaultThinking lists models that should have thinking enabled by default
|
||||
// when no explicit thinkingConfig is provided.
|
||||
var modelsWithDefaultThinking = map[string]bool{
|
||||
"gemini-3-pro-preview": true,
|
||||
"gemini-3-pro-preview": true,
|
||||
"gemini-3-pro-image-preview": true,
|
||||
// "gemini-3-flash-preview": true,
|
||||
}
|
||||
|
||||
// ModelHasDefaultThinking returns true if the model should have thinking enabled by default.
|
||||
@@ -83,6 +253,7 @@ func ModelHasDefaultThinking(model string) bool {
|
||||
// ApplyDefaultThinkingIfNeeded injects default thinkingConfig for models that require it.
|
||||
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
|
||||
// For Gemini 3 models, uses thinkingLevel instead of thinkingBudget per Google's documentation.
|
||||
func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
|
||||
if !ModelHasDefaultThinking(model) {
|
||||
return body
|
||||
@@ -90,14 +261,59 @@ func ApplyDefaultThinkingIfNeeded(model string, body []byte) []byte {
|
||||
if gjson.GetBytes(body, "generationConfig.thinkingConfig").Exists() {
|
||||
return body
|
||||
}
|
||||
// Gemini 3 models use thinkingLevel instead of thinkingBudget
|
||||
if IsGemini3Model(model) {
|
||||
// Don't set a default - let the API use its dynamic default ("high")
|
||||
// Only set includeThoughts
|
||||
updated, _ := sjson.SetBytes(body, "generationConfig.thinkingConfig.includeThoughts", true)
|
||||
return updated
|
||||
}
|
||||
// Gemini 2.5 and other models use thinkingBudget
|
||||
updated, _ := sjson.SetBytes(body, "generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
updated, _ = sjson.SetBytes(updated, "generationConfig.thinkingConfig.include_thoughts", true)
|
||||
return updated
|
||||
}
|
||||
|
||||
// ApplyGemini3ThinkingLevelFromMetadata applies thinkingLevel from metadata for Gemini 3 models.
|
||||
// For standard Gemini API format (generationConfig.thinkingConfig path).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
|
||||
func ApplyGemini3ThinkingLevelFromMetadata(model string, metadata map[string]any, body []byte) []byte {
|
||||
if !IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
effort, ok := ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || effort == "" {
|
||||
return body
|
||||
}
|
||||
// Validate and apply the thinkingLevel
|
||||
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
|
||||
return ApplyGeminiThinkingLevel(body, level, nil)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// ApplyGemini3ThinkingLevelFromMetadataCLI applies thinkingLevel from metadata for Gemini 3 models.
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// This handles the case where reasoning_effort is specified via model name suffix (e.g., model(minimal)).
|
||||
func ApplyGemini3ThinkingLevelFromMetadataCLI(model string, metadata map[string]any, body []byte) []byte {
|
||||
if !IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
effort, ok := ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || effort == "" {
|
||||
return body
|
||||
}
|
||||
// Validate and apply the thinkingLevel
|
||||
if level, valid := ValidateGemini3ThinkingLevel(model, effort); valid {
|
||||
return ApplyGeminiCLIThinkingLevel(body, level, nil)
|
||||
}
|
||||
return body
|
||||
}
|
||||
|
||||
// ApplyDefaultThinkingIfNeededCLI injects default thinkingConfig for models that require it.
|
||||
// For Gemini CLI API format (request.generationConfig.thinkingConfig path).
|
||||
// Returns the modified body if thinkingConfig was added, otherwise returns the original.
|
||||
// For Gemini 3 models, uses thinkingLevel instead of thinkingBudget per Google's documentation.
|
||||
func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
|
||||
if !ModelHasDefaultThinking(model) {
|
||||
return body
|
||||
@@ -105,6 +321,14 @@ func ApplyDefaultThinkingIfNeededCLI(model string, body []byte) []byte {
|
||||
if gjson.GetBytes(body, "request.generationConfig.thinkingConfig").Exists() {
|
||||
return body
|
||||
}
|
||||
// Gemini 3 models use thinkingLevel instead of thinkingBudget
|
||||
if IsGemini3Model(model) {
|
||||
// Don't set a default - let the API use its dynamic default ("high")
|
||||
// Only set includeThoughts
|
||||
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.includeThoughts", true)
|
||||
return updated
|
||||
}
|
||||
// Gemini 2.5 and other models use thinkingBudget
|
||||
updated, _ := sjson.SetBytes(body, "request.generationConfig.thinkingConfig.thinkingBudget", -1)
|
||||
updated, _ = sjson.SetBytes(updated, "request.generationConfig.thinkingConfig.include_thoughts", true)
|
||||
return updated
|
||||
@@ -128,12 +352,31 @@ func StripThinkingConfigIfUnsupported(model string, body []byte) []byte {
|
||||
|
||||
// NormalizeGeminiThinkingBudget normalizes the thinkingBudget value in a standard Gemini
|
||||
// request body (generationConfig.thinkingConfig.thinkingBudget path).
|
||||
func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
|
||||
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation,
|
||||
// unless skipGemini3Check is provided and true.
|
||||
func NormalizeGeminiThinkingBudget(model string, body []byte, skipGemini3Check ...bool) []byte {
|
||||
const budgetPath = "generationConfig.thinkingConfig.thinkingBudget"
|
||||
const levelPath = "generationConfig.thinkingConfig.thinkingLevel"
|
||||
|
||||
budget := gjson.GetBytes(body, budgetPath)
|
||||
if !budget.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// For Gemini 3 models, convert thinkingBudget to thinkingLevel
|
||||
skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
|
||||
if IsGemini3Model(model) && !skipGemini3 {
|
||||
if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok {
|
||||
updated, _ := sjson.SetBytes(body, levelPath, level)
|
||||
updated, _ = sjson.DeleteBytes(updated, budgetPath)
|
||||
return updated
|
||||
}
|
||||
// If conversion fails, just remove the budget (let API use default)
|
||||
updated, _ := sjson.DeleteBytes(body, budgetPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
// For Gemini 2.5 and other models, normalize the budget value
|
||||
normalized := NormalizeThinkingBudget(model, int(budget.Int()))
|
||||
updated, _ := sjson.SetBytes(body, budgetPath, normalized)
|
||||
return updated
|
||||
@@ -141,12 +384,31 @@ func NormalizeGeminiThinkingBudget(model string, body []byte) []byte {
|
||||
|
||||
// NormalizeGeminiCLIThinkingBudget normalizes the thinkingBudget value in a Gemini CLI
|
||||
// request body (request.generationConfig.thinkingConfig.thinkingBudget path).
|
||||
func NormalizeGeminiCLIThinkingBudget(model string, body []byte) []byte {
|
||||
// For Gemini 3 models, converts thinkingBudget to thinkingLevel per Google's documentation,
|
||||
// unless skipGemini3Check is provided and true.
|
||||
func NormalizeGeminiCLIThinkingBudget(model string, body []byte, skipGemini3Check ...bool) []byte {
|
||||
const budgetPath = "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||
const levelPath = "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
|
||||
budget := gjson.GetBytes(body, budgetPath)
|
||||
if !budget.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// For Gemini 3 models, convert thinkingBudget to thinkingLevel
|
||||
skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
|
||||
if IsGemini3Model(model) && !skipGemini3 {
|
||||
if level, ok := ThinkingBudgetToGemini3Level(model, int(budget.Int())); ok {
|
||||
updated, _ := sjson.SetBytes(body, levelPath, level)
|
||||
updated, _ = sjson.DeleteBytes(updated, budgetPath)
|
||||
return updated
|
||||
}
|
||||
// If conversion fails, just remove the budget (let API use default)
|
||||
updated, _ := sjson.DeleteBytes(body, budgetPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
// For Gemini 2.5 and other models, normalize the budget value
|
||||
normalized := NormalizeThinkingBudget(model, int(budget.Int()))
|
||||
updated, _ := sjson.SetBytes(body, budgetPath, normalized)
|
||||
return updated
|
||||
@@ -218,44 +480,74 @@ func ApplyReasoningEffortToGeminiCLI(body []byte, effort string) []byte {
|
||||
}
|
||||
|
||||
// ConvertThinkingLevelToBudget checks for "generationConfig.thinkingConfig.thinkingLevel"
|
||||
// and converts it to "thinkingBudget".
|
||||
// "high" -> 32768
|
||||
// "low" -> 128
|
||||
// It removes "thinkingLevel" after conversion.
|
||||
func ConvertThinkingLevelToBudget(body []byte) []byte {
|
||||
// and converts it to "thinkingBudget" for Gemini 2.5 models.
|
||||
// For Gemini 3 models, preserves thinkingLevel unless skipGemini3Check is provided and true.
|
||||
// Mappings for Gemini 2.5:
|
||||
// - "high" -> 32768
|
||||
// - "medium" -> 8192
|
||||
// - "low" -> 1024
|
||||
// - "minimal" -> 512
|
||||
//
|
||||
// It removes "thinkingLevel" after conversion (for Gemini 2.5 only).
|
||||
func ConvertThinkingLevelToBudget(body []byte, model string, skipGemini3Check ...bool) []byte {
|
||||
levelPath := "generationConfig.thinkingConfig.thinkingLevel"
|
||||
res := gjson.GetBytes(body, levelPath)
|
||||
if !res.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
level := strings.ToLower(res.String())
|
||||
var budget int
|
||||
switch level {
|
||||
case "high":
|
||||
budget = 32768
|
||||
case "low":
|
||||
budget = 128
|
||||
default:
|
||||
// If unknown level, we might just leave it or default.
|
||||
// User only specified high and low. We'll assume we shouldn't touch it if it's something else,
|
||||
// or maybe we should just remove the invalid level?
|
||||
// For safety adhering to strict instructions: "If high... if low...".
|
||||
// If it's something else, the upstream might fail anyway if we leave it,
|
||||
// but let's just delete the level if we processed it.
|
||||
// Actually, let's check if we need to do anything for other values.
|
||||
// For now, only handle high/low.
|
||||
// For Gemini 3 models, preserve thinkingLevel unless explicitly skipped
|
||||
skipGemini3 := len(skipGemini3Check) > 0 && skipGemini3Check[0]
|
||||
if IsGemini3Model(model) && !skipGemini3 {
|
||||
return body
|
||||
}
|
||||
|
||||
// Set budget
|
||||
budget, ok := ThinkingLevelToBudget(res.String())
|
||||
if !ok {
|
||||
updated, _ := sjson.DeleteBytes(body, levelPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
budgetPath := "generationConfig.thinkingConfig.thinkingBudget"
|
||||
updated, err := sjson.SetBytes(body, budgetPath, budget)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
// Remove level
|
||||
updated, err = sjson.DeleteBytes(updated, levelPath)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
return updated
|
||||
}
|
||||
|
||||
// ConvertThinkingLevelToBudgetCLI checks for "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
// and converts it to "thinkingBudget" for Gemini 2.5 models.
|
||||
// For Gemini 3 models, preserves thinkingLevel as-is (does not convert).
|
||||
func ConvertThinkingLevelToBudgetCLI(body []byte, model string) []byte {
|
||||
levelPath := "request.generationConfig.thinkingConfig.thinkingLevel"
|
||||
res := gjson.GetBytes(body, levelPath)
|
||||
if !res.Exists() {
|
||||
return body
|
||||
}
|
||||
|
||||
// For Gemini 3 models, preserve thinkingLevel - don't convert to budget
|
||||
if IsGemini3Model(model) {
|
||||
return body
|
||||
}
|
||||
|
||||
budget, ok := ThinkingLevelToBudget(res.String())
|
||||
if !ok {
|
||||
updated, _ := sjson.DeleteBytes(body, levelPath)
|
||||
return updated
|
||||
}
|
||||
|
||||
budgetPath := "request.generationConfig.thinkingConfig.thinkingBudget"
|
||||
updated, err := sjson.SetBytes(body, budgetPath, budget)
|
||||
if err != nil {
|
||||
return body
|
||||
}
|
||||
|
||||
updated, err = sjson.DeleteBytes(updated, levelPath)
|
||||
if err != nil {
|
||||
return body
|
||||
|
||||
@@ -1,37 +0,0 @@
|
||||
package util
|
||||
|
||||
// OpenAIThinkingBudgetToEffort maps a numeric thinking budget (tokens)
|
||||
// into an OpenAI-style reasoning effort level for level-based models.
|
||||
//
|
||||
// Ranges:
|
||||
// - 0 -> "none"
|
||||
// - -1 -> "auto"
|
||||
// - 1..1024 -> "low"
|
||||
// - 1025..8192 -> "medium"
|
||||
// - 8193..24576 -> "high"
|
||||
// - 24577.. -> highest supported level for the model (defaults to "xhigh")
|
||||
//
|
||||
// Negative values other than -1 are treated as unsupported.
|
||||
func OpenAIThinkingBudgetToEffort(model string, budget int) (string, bool) {
|
||||
switch {
|
||||
case budget == -1:
|
||||
return "auto", true
|
||||
case budget < -1:
|
||||
return "", false
|
||||
case budget == 0:
|
||||
return "none", true
|
||||
case budget > 0 && budget <= 1024:
|
||||
return "low", true
|
||||
case budget <= 8192:
|
||||
return "medium", true
|
||||
case budget <= 24576:
|
||||
return "high", true
|
||||
case budget > 24576:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[len(levels)-1], true
|
||||
}
|
||||
return "xhigh", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
@@ -118,3 +118,111 @@ func IsOpenAICompatibilityModel(model string) bool {
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(info.Type), "openai-compatibility")
|
||||
}
|
||||
|
||||
// ThinkingEffortToBudget maps a reasoning effort level to a numeric thinking budget (tokens),
|
||||
// clamping the result to the model's supported range.
|
||||
//
|
||||
// Mappings (values are normalized to model's supported range):
|
||||
// - "none" -> 0
|
||||
// - "auto" -> -1
|
||||
// - "minimal" -> 512
|
||||
// - "low" -> 1024
|
||||
// - "medium" -> 8192
|
||||
// - "high" -> 24576
|
||||
// - "xhigh" -> 32768
|
||||
//
|
||||
// Returns false when the effort level is empty or unsupported.
|
||||
func ThinkingEffortToBudget(model, effort string) (int, bool) {
|
||||
if effort == "" {
|
||||
return 0, false
|
||||
}
|
||||
normalized, ok := NormalizeReasoningEffortLevel(model, effort)
|
||||
if !ok {
|
||||
normalized = strings.ToLower(strings.TrimSpace(effort))
|
||||
}
|
||||
switch normalized {
|
||||
case "none":
|
||||
return 0, true
|
||||
case "auto":
|
||||
return NormalizeThinkingBudget(model, -1), true
|
||||
case "minimal":
|
||||
return NormalizeThinkingBudget(model, 512), true
|
||||
case "low":
|
||||
return NormalizeThinkingBudget(model, 1024), true
|
||||
case "medium":
|
||||
return NormalizeThinkingBudget(model, 8192), true
|
||||
case "high":
|
||||
return NormalizeThinkingBudget(model, 24576), true
|
||||
case "xhigh":
|
||||
return NormalizeThinkingBudget(model, 32768), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ThinkingLevelToBudget maps a Gemini thinkingLevel to a numeric thinking budget (tokens).
|
||||
//
|
||||
// Mappings:
|
||||
// - "minimal" -> 512
|
||||
// - "low" -> 1024
|
||||
// - "medium" -> 8192
|
||||
// - "high" -> 32768
|
||||
//
|
||||
// Returns false when the level is empty or unsupported.
|
||||
func ThinkingLevelToBudget(level string) (int, bool) {
|
||||
if level == "" {
|
||||
return 0, false
|
||||
}
|
||||
normalized := strings.ToLower(strings.TrimSpace(level))
|
||||
switch normalized {
|
||||
case "minimal":
|
||||
return 512, true
|
||||
case "low":
|
||||
return 1024, true
|
||||
case "medium":
|
||||
return 8192, true
|
||||
case "high":
|
||||
return 32768, true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ThinkingBudgetToEffort maps a numeric thinking budget (tokens)
|
||||
// to a reasoning effort level for level-based models.
|
||||
//
|
||||
// Mappings:
|
||||
// - 0 -> "none" (or lowest supported level if model doesn't support "none")
|
||||
// - -1 -> "auto"
|
||||
// - 1..1024 -> "low"
|
||||
// - 1025..8192 -> "medium"
|
||||
// - 8193..24576 -> "high"
|
||||
// - 24577.. -> highest supported level for the model (defaults to "xhigh")
|
||||
//
|
||||
// Returns false when the budget is unsupported (negative values other than -1).
|
||||
func ThinkingBudgetToEffort(model string, budget int) (string, bool) {
|
||||
switch {
|
||||
case budget == -1:
|
||||
return "auto", true
|
||||
case budget < -1:
|
||||
return "", false
|
||||
case budget == 0:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[0], true
|
||||
}
|
||||
return "none", true
|
||||
case budget > 0 && budget <= 1024:
|
||||
return "low", true
|
||||
case budget <= 8192:
|
||||
return "medium", true
|
||||
case budget <= 24576:
|
||||
return "high", true
|
||||
case budget > 24576:
|
||||
if levels := GetModelThinkingLevels(model); len(levels) > 0 {
|
||||
return levels[len(levels)-1], true
|
||||
}
|
||||
return "xhigh", true
|
||||
default:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -201,36 +201,6 @@ func ReasoningEffortFromMetadata(metadata map[string]any) (string, bool) {
|
||||
return "", true
|
||||
}
|
||||
|
||||
// ThinkingEffortToBudget maps reasoning effort levels to approximate budgets,
|
||||
// clamping the result to the model's supported range.
|
||||
func ThinkingEffortToBudget(model, effort string) (int, bool) {
|
||||
if effort == "" {
|
||||
return 0, false
|
||||
}
|
||||
normalized, ok := NormalizeReasoningEffortLevel(model, effort)
|
||||
if !ok {
|
||||
normalized = strings.ToLower(strings.TrimSpace(effort))
|
||||
}
|
||||
switch normalized {
|
||||
case "none":
|
||||
return 0, true
|
||||
case "auto":
|
||||
return NormalizeThinkingBudget(model, -1), true
|
||||
case "minimal":
|
||||
return NormalizeThinkingBudget(model, 512), true
|
||||
case "low":
|
||||
return NormalizeThinkingBudget(model, 1024), true
|
||||
case "medium":
|
||||
return NormalizeThinkingBudget(model, 8192), true
|
||||
case "high":
|
||||
return NormalizeThinkingBudget(model, 24576), true
|
||||
case "xhigh":
|
||||
return NormalizeThinkingBudget(model, 32768), true
|
||||
default:
|
||||
return 0, false
|
||||
}
|
||||
}
|
||||
|
||||
// ResolveOriginalModel returns the original model name stored in metadata (if present),
|
||||
// otherwise falls back to the provided model.
|
||||
func ResolveOriginalModel(model string, metadata map[string]any) string {
|
||||
|
||||
@@ -6,6 +6,7 @@ package util
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
@@ -28,10 +29,17 @@ func Walk(value gjson.Result, path, field string, paths *[]string) {
|
||||
// For JSON objects and arrays, iterate through each child
|
||||
value.ForEach(func(key, val gjson.Result) bool {
|
||||
var childPath string
|
||||
// Escape special characters for gjson/sjson path syntax
|
||||
// . -> \.
|
||||
// * -> \*
|
||||
// ? -> \?
|
||||
var keyReplacer = strings.NewReplacer(".", "\\.", "*", "\\*", "?", "\\?")
|
||||
safeKey := keyReplacer.Replace(key.String())
|
||||
|
||||
if path == "" {
|
||||
childPath = key.String()
|
||||
childPath = safeKey
|
||||
} else {
|
||||
childPath = path + "." + key.String()
|
||||
childPath = path + "." + safeKey
|
||||
}
|
||||
if key.String() == field {
|
||||
*paths = append(*paths, childPath)
|
||||
|
||||
270
internal/watcher/clients.go
Normal file
270
internal/watcher/clients.go
Normal file
@@ -0,0 +1,270 @@
|
||||
// clients.go implements watcher client lifecycle logic and persistence helpers.
|
||||
// It reloads clients, handles incremental auth file changes, and persists updates when supported.
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (w *Watcher) reloadClients(rescanAuth bool, affectedOAuthProviders []string, forceAuthRefresh bool) {
|
||||
log.Debugf("starting full client load process")
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
cfg := w.config
|
||||
w.clientsMutex.RUnlock()
|
||||
|
||||
if cfg == nil {
|
||||
log.Error("config is nil, cannot reload clients")
|
||||
return
|
||||
}
|
||||
|
||||
if len(affectedOAuthProviders) > 0 {
|
||||
w.clientsMutex.Lock()
|
||||
if w.currentAuths != nil {
|
||||
filtered := make(map[string]*coreauth.Auth, len(w.currentAuths))
|
||||
for id, auth := range w.currentAuths {
|
||||
if auth == nil {
|
||||
continue
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
if _, match := matchProvider(provider, affectedOAuthProviders); match {
|
||||
continue
|
||||
}
|
||||
filtered[id] = auth
|
||||
}
|
||||
w.currentAuths = filtered
|
||||
log.Debugf("applying oauth-excluded-models to providers %v", affectedOAuthProviders)
|
||||
} else {
|
||||
w.currentAuths = nil
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
}
|
||||
|
||||
geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount := BuildAPIKeyClients(cfg)
|
||||
totalAPIKeyClients := geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
||||
log.Debugf("loaded %d API key clients", totalAPIKeyClients)
|
||||
|
||||
var authFileCount int
|
||||
if rescanAuth {
|
||||
authFileCount = w.loadFileClients(cfg)
|
||||
log.Debugf("loaded %d file-based clients", authFileCount)
|
||||
} else {
|
||||
w.clientsMutex.RLock()
|
||||
authFileCount = len(w.lastAuthHashes)
|
||||
w.clientsMutex.RUnlock()
|
||||
log.Debugf("skipping auth directory rescan; retaining %d existing auth files", authFileCount)
|
||||
}
|
||||
|
||||
if rescanAuth {
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
w.lastAuthHashes = make(map[string]string)
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory for hash cache: %v", errResolveAuthDir)
|
||||
} else if resolvedAuthDir != "" {
|
||||
_ = filepath.Walk(resolvedAuthDir, func(path string, info fs.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
|
||||
if data, errReadFile := os.ReadFile(path); errReadFile == nil && len(data) > 0 {
|
||||
sum := sha256.Sum256(data)
|
||||
normalizedPath := w.normalizeAuthPath(path)
|
||||
w.lastAuthHashes[normalizedPath] = hex.EncodeToString(sum[:])
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
}
|
||||
|
||||
totalNewClients := authFileCount + geminiAPIKeyCount + vertexCompatAPIKeyCount + claudeAPIKeyCount + codexAPIKeyCount + openAICompatCount
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback before auth refresh")
|
||||
w.reloadCallback(cfg)
|
||||
}
|
||||
|
||||
w.refreshAuthState(forceAuthRefresh)
|
||||
|
||||
log.Infof("full client load complete - %d clients (%d auth files + %d Gemini API keys + %d Vertex API keys + %d Claude API keys + %d Codex keys + %d OpenAI-compat)",
|
||||
totalNewClients,
|
||||
authFileCount,
|
||||
geminiAPIKeyCount,
|
||||
vertexCompatAPIKeyCount,
|
||||
claudeAPIKeyCount,
|
||||
codexAPIKeyCount,
|
||||
openAICompatCount,
|
||||
)
|
||||
}
|
||||
|
||||
func (w *Watcher) addOrUpdateClient(path string) {
|
||||
data, errRead := os.ReadFile(path)
|
||||
if errRead != nil {
|
||||
log.Errorf("failed to read auth file %s: %v", filepath.Base(path), errRead)
|
||||
return
|
||||
}
|
||||
if len(data) == 0 {
|
||||
log.Debugf("ignoring empty auth file: %s", filepath.Base(path))
|
||||
return
|
||||
}
|
||||
|
||||
sum := sha256.Sum256(data)
|
||||
curHash := hex.EncodeToString(sum[:])
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
if cfg == nil {
|
||||
log.Error("config is nil, cannot add or update client")
|
||||
w.clientsMutex.Unlock()
|
||||
return
|
||||
}
|
||||
if prev, ok := w.lastAuthHashes[normalized]; ok && prev == curHash {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(path))
|
||||
w.clientsMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
w.lastAuthHashes[normalized] = curHash
|
||||
|
||||
w.clientsMutex.Unlock() // Unlock before the callback
|
||||
|
||||
w.refreshAuthState(false)
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after add/update")
|
||||
w.reloadCallback(cfg)
|
||||
}
|
||||
w.persistAuthAsync(fmt.Sprintf("Sync auth %s", filepath.Base(path)), path)
|
||||
}
|
||||
|
||||
func (w *Watcher) removeClient(path string) {
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
w.clientsMutex.Lock()
|
||||
|
||||
cfg := w.config
|
||||
delete(w.lastAuthHashes, normalized)
|
||||
|
||||
w.clientsMutex.Unlock() // Release the lock before the callback
|
||||
|
||||
w.refreshAuthState(false)
|
||||
|
||||
if w.reloadCallback != nil {
|
||||
log.Debugf("triggering server update callback after removal")
|
||||
w.reloadCallback(cfg)
|
||||
}
|
||||
w.persistAuthAsync(fmt.Sprintf("Remove auth %s", filepath.Base(path)), path)
|
||||
}
|
||||
|
||||
func (w *Watcher) loadFileClients(cfg *config.Config) int {
|
||||
authFileCount := 0
|
||||
successfulAuthCount := 0
|
||||
|
||||
authDir, errResolveAuthDir := util.ResolveAuthDir(cfg.AuthDir)
|
||||
if errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory: %v", errResolveAuthDir)
|
||||
return 0
|
||||
}
|
||||
if authDir == "" {
|
||||
return 0
|
||||
}
|
||||
|
||||
errWalk := filepath.Walk(authDir, func(path string, info fs.FileInfo, err error) error {
|
||||
if err != nil {
|
||||
log.Debugf("error accessing path %s: %v", path, err)
|
||||
return err
|
||||
}
|
||||
if !info.IsDir() && strings.HasSuffix(strings.ToLower(info.Name()), ".json") {
|
||||
authFileCount++
|
||||
log.Debugf("processing auth file %d: %s", authFileCount, filepath.Base(path))
|
||||
if data, errCreate := os.ReadFile(path); errCreate == nil && len(data) > 0 {
|
||||
successfulAuthCount++
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
|
||||
if errWalk != nil {
|
||||
log.Errorf("error walking auth directory: %v", errWalk)
|
||||
}
|
||||
log.Debugf("auth directory scan complete - found %d .json files, %d readable", authFileCount, successfulAuthCount)
|
||||
return authFileCount
|
||||
}
|
||||
|
||||
func BuildAPIKeyClients(cfg *config.Config) (int, int, int, int, int) {
|
||||
geminiAPIKeyCount := 0
|
||||
vertexCompatAPIKeyCount := 0
|
||||
claudeAPIKeyCount := 0
|
||||
codexAPIKeyCount := 0
|
||||
openAICompatCount := 0
|
||||
|
||||
if len(cfg.GeminiKey) > 0 {
|
||||
geminiAPIKeyCount += len(cfg.GeminiKey)
|
||||
}
|
||||
if len(cfg.VertexCompatAPIKey) > 0 {
|
||||
vertexCompatAPIKeyCount += len(cfg.VertexCompatAPIKey)
|
||||
}
|
||||
if len(cfg.ClaudeKey) > 0 {
|
||||
claudeAPIKeyCount += len(cfg.ClaudeKey)
|
||||
}
|
||||
if len(cfg.CodexKey) > 0 {
|
||||
codexAPIKeyCount += len(cfg.CodexKey)
|
||||
}
|
||||
if len(cfg.OpenAICompatibility) > 0 {
|
||||
for _, compatConfig := range cfg.OpenAICompatibility {
|
||||
openAICompatCount += len(compatConfig.APIKeyEntries)
|
||||
}
|
||||
}
|
||||
return geminiAPIKeyCount, vertexCompatAPIKeyCount, claudeAPIKeyCount, codexAPIKeyCount, openAICompatCount
|
||||
}
|
||||
|
||||
func (w *Watcher) persistConfigAsync() {
|
||||
if w == nil || w.storePersister == nil {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := w.storePersister.PersistConfig(ctx); err != nil {
|
||||
log.Errorf("failed to persist config change: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (w *Watcher) persistAuthAsync(message string, paths ...string) {
|
||||
if w == nil || w.storePersister == nil {
|
||||
return
|
||||
}
|
||||
filtered := make([]string, 0, len(paths))
|
||||
for _, p := range paths {
|
||||
if trimmed := strings.TrimSpace(p); trimmed != "" {
|
||||
filtered = append(filtered, trimmed)
|
||||
}
|
||||
}
|
||||
if len(filtered) == 0 {
|
||||
return
|
||||
}
|
||||
go func() {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||
defer cancel()
|
||||
if err := w.storePersister.PersistAuthFiles(ctx, message, filtered...); err != nil {
|
||||
log.Errorf("failed to persist auth changes: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
134
internal/watcher/config_reload.go
Normal file
134
internal/watcher/config_reload.go
Normal file
@@ -0,0 +1,134 @@
|
||||
// config_reload.go implements debounced configuration hot reload.
|
||||
// It detects material changes and reloads clients when the config changes.
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (w *Watcher) stopConfigReloadTimer() {
|
||||
w.configReloadMu.Lock()
|
||||
if w.configReloadTimer != nil {
|
||||
w.configReloadTimer.Stop()
|
||||
w.configReloadTimer = nil
|
||||
}
|
||||
w.configReloadMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *Watcher) scheduleConfigReload() {
|
||||
w.configReloadMu.Lock()
|
||||
defer w.configReloadMu.Unlock()
|
||||
if w.configReloadTimer != nil {
|
||||
w.configReloadTimer.Stop()
|
||||
}
|
||||
w.configReloadTimer = time.AfterFunc(configReloadDebounce, func() {
|
||||
w.configReloadMu.Lock()
|
||||
w.configReloadTimer = nil
|
||||
w.configReloadMu.Unlock()
|
||||
w.reloadConfigIfChanged()
|
||||
})
|
||||
}
|
||||
|
||||
func (w *Watcher) reloadConfigIfChanged() {
|
||||
data, err := os.ReadFile(w.configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read config file for hash check: %v", err)
|
||||
return
|
||||
}
|
||||
if len(data) == 0 {
|
||||
log.Debugf("ignoring empty config file write event")
|
||||
return
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
newHash := hex.EncodeToString(sum[:])
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
currentHash := w.lastConfigHash
|
||||
w.clientsMutex.RUnlock()
|
||||
|
||||
if currentHash != "" && currentHash == newHash {
|
||||
log.Debugf("config file content unchanged (hash match), skipping reload")
|
||||
return
|
||||
}
|
||||
log.Infof("config file changed, reloading: %s", w.configPath)
|
||||
if w.reloadConfig() {
|
||||
finalHash := newHash
|
||||
if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
|
||||
sumUpdated := sha256.Sum256(updatedData)
|
||||
finalHash = hex.EncodeToString(sumUpdated[:])
|
||||
} else if errRead != nil {
|
||||
log.WithError(errRead).Debug("failed to compute updated config hash after reload")
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
w.lastConfigHash = finalHash
|
||||
w.clientsMutex.Unlock()
|
||||
w.persistConfigAsync()
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) reloadConfig() bool {
|
||||
log.Debug("=========================== CONFIG RELOAD ============================")
|
||||
log.Debugf("starting config reload from: %s", w.configPath)
|
||||
|
||||
newConfig, errLoadConfig := config.LoadConfig(w.configPath)
|
||||
if errLoadConfig != nil {
|
||||
log.Errorf("failed to reload config: %v", errLoadConfig)
|
||||
return false
|
||||
}
|
||||
|
||||
if w.mirroredAuthDir != "" {
|
||||
newConfig.AuthDir = w.mirroredAuthDir
|
||||
} else {
|
||||
if resolvedAuthDir, errResolveAuthDir := util.ResolveAuthDir(newConfig.AuthDir); errResolveAuthDir != nil {
|
||||
log.Errorf("failed to resolve auth directory from config: %v", errResolveAuthDir)
|
||||
} else {
|
||||
newConfig.AuthDir = resolvedAuthDir
|
||||
}
|
||||
}
|
||||
|
||||
w.clientsMutex.Lock()
|
||||
var oldConfig *config.Config
|
||||
_ = yaml.Unmarshal(w.oldConfigYaml, &oldConfig)
|
||||
w.oldConfigYaml, _ = yaml.Marshal(newConfig)
|
||||
w.config = newConfig
|
||||
w.clientsMutex.Unlock()
|
||||
|
||||
var affectedOAuthProviders []string
|
||||
if oldConfig != nil {
|
||||
_, affectedOAuthProviders = diff.DiffOAuthExcludedModelChanges(oldConfig.OAuthExcludedModels, newConfig.OAuthExcludedModels)
|
||||
}
|
||||
|
||||
util.SetLogLevel(newConfig)
|
||||
if oldConfig != nil && oldConfig.Debug != newConfig.Debug {
|
||||
log.Debugf("log level updated - debug mode changed from %t to %t", oldConfig.Debug, newConfig.Debug)
|
||||
}
|
||||
|
||||
if oldConfig != nil {
|
||||
details := diff.BuildConfigChangeDetails(oldConfig, newConfig)
|
||||
if len(details) > 0 {
|
||||
log.Debugf("config changes detected:")
|
||||
for _, d := range details {
|
||||
log.Debugf(" %s", d)
|
||||
}
|
||||
} else {
|
||||
log.Debugf("no material config field changes detected")
|
||||
}
|
||||
}
|
||||
|
||||
authDirChanged := oldConfig == nil || oldConfig.AuthDir != newConfig.AuthDir
|
||||
forceAuthRefresh := oldConfig != nil && oldConfig.ForceModelPrefix != newConfig.ForceModelPrefix
|
||||
|
||||
log.Infof("config successfully reloaded, triggering client reload")
|
||||
w.reloadClients(authDirChanged, affectedOAuthProviders, forceAuthRefresh)
|
||||
return true
|
||||
}
|
||||
303
internal/watcher/diff/config_diff.go
Normal file
303
internal/watcher/diff/config_diff.go
Normal file
@@ -0,0 +1,303 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"reflect"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// BuildConfigChangeDetails computes a redacted, human-readable list of config changes.
|
||||
// Secrets are never printed; only structural or non-sensitive fields are surfaced.
|
||||
func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
changes := make([]string, 0, 16)
|
||||
if oldCfg == nil || newCfg == nil {
|
||||
return changes
|
||||
}
|
||||
|
||||
// Simple scalars
|
||||
if oldCfg.Port != newCfg.Port {
|
||||
changes = append(changes, fmt.Sprintf("port: %d -> %d", oldCfg.Port, newCfg.Port))
|
||||
}
|
||||
if oldCfg.AuthDir != newCfg.AuthDir {
|
||||
changes = append(changes, fmt.Sprintf("auth-dir: %s -> %s", oldCfg.AuthDir, newCfg.AuthDir))
|
||||
}
|
||||
if oldCfg.Debug != newCfg.Debug {
|
||||
changes = append(changes, fmt.Sprintf("debug: %t -> %t", oldCfg.Debug, newCfg.Debug))
|
||||
}
|
||||
if oldCfg.LoggingToFile != newCfg.LoggingToFile {
|
||||
changes = append(changes, fmt.Sprintf("logging-to-file: %t -> %t", oldCfg.LoggingToFile, newCfg.LoggingToFile))
|
||||
}
|
||||
if oldCfg.UsageStatisticsEnabled != newCfg.UsageStatisticsEnabled {
|
||||
changes = append(changes, fmt.Sprintf("usage-statistics-enabled: %t -> %t", oldCfg.UsageStatisticsEnabled, newCfg.UsageStatisticsEnabled))
|
||||
}
|
||||
if oldCfg.DisableCooling != newCfg.DisableCooling {
|
||||
changes = append(changes, fmt.Sprintf("disable-cooling: %t -> %t", oldCfg.DisableCooling, newCfg.DisableCooling))
|
||||
}
|
||||
if oldCfg.RequestLog != newCfg.RequestLog {
|
||||
changes = append(changes, fmt.Sprintf("request-log: %t -> %t", oldCfg.RequestLog, newCfg.RequestLog))
|
||||
}
|
||||
if oldCfg.RequestRetry != newCfg.RequestRetry {
|
||||
changes = append(changes, fmt.Sprintf("request-retry: %d -> %d", oldCfg.RequestRetry, newCfg.RequestRetry))
|
||||
}
|
||||
if oldCfg.MaxRetryInterval != newCfg.MaxRetryInterval {
|
||||
changes = append(changes, fmt.Sprintf("max-retry-interval: %d -> %d", oldCfg.MaxRetryInterval, newCfg.MaxRetryInterval))
|
||||
}
|
||||
if oldCfg.ProxyURL != newCfg.ProxyURL {
|
||||
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", formatProxyURL(oldCfg.ProxyURL), formatProxyURL(newCfg.ProxyURL)))
|
||||
}
|
||||
if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
|
||||
changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
|
||||
}
|
||||
if oldCfg.ForceModelPrefix != newCfg.ForceModelPrefix {
|
||||
changes = append(changes, fmt.Sprintf("force-model-prefix: %t -> %t", oldCfg.ForceModelPrefix, newCfg.ForceModelPrefix))
|
||||
}
|
||||
|
||||
// Quota-exceeded behavior
|
||||
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-project: %t -> %t", oldCfg.QuotaExceeded.SwitchProject, newCfg.QuotaExceeded.SwitchProject))
|
||||
}
|
||||
if oldCfg.QuotaExceeded.SwitchPreviewModel != newCfg.QuotaExceeded.SwitchPreviewModel {
|
||||
changes = append(changes, fmt.Sprintf("quota-exceeded.switch-preview-model: %t -> %t", oldCfg.QuotaExceeded.SwitchPreviewModel, newCfg.QuotaExceeded.SwitchPreviewModel))
|
||||
}
|
||||
|
||||
// API keys (redacted) and counts
|
||||
if len(oldCfg.APIKeys) != len(newCfg.APIKeys) {
|
||||
changes = append(changes, fmt.Sprintf("api-keys count: %d -> %d", len(oldCfg.APIKeys), len(newCfg.APIKeys)))
|
||||
} else if !reflect.DeepEqual(trimStrings(oldCfg.APIKeys), trimStrings(newCfg.APIKeys)) {
|
||||
changes = append(changes, "api-keys: values updated (count unchanged, redacted)")
|
||||
}
|
||||
if len(oldCfg.GeminiKey) != len(newCfg.GeminiKey) {
|
||||
changes = append(changes, fmt.Sprintf("gemini-api-key count: %d -> %d", len(oldCfg.GeminiKey), len(newCfg.GeminiKey)))
|
||||
} else {
|
||||
for i := range oldCfg.GeminiKey {
|
||||
o := oldCfg.GeminiKey[i]
|
||||
n := newCfg.GeminiKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("gemini[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Claude keys (do not print key material)
|
||||
if len(oldCfg.ClaudeKey) != len(newCfg.ClaudeKey) {
|
||||
changes = append(changes, fmt.Sprintf("claude-api-key count: %d -> %d", len(oldCfg.ClaudeKey), len(newCfg.ClaudeKey)))
|
||||
} else {
|
||||
for i := range oldCfg.ClaudeKey {
|
||||
o := oldCfg.ClaudeKey[i]
|
||||
n := newCfg.ClaudeKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("claude[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Codex keys (do not print key material)
|
||||
if len(oldCfg.CodexKey) != len(newCfg.CodexKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex-api-key count: %d -> %d", len(oldCfg.CodexKey), len(newCfg.CodexKey)))
|
||||
} else {
|
||||
for i := range oldCfg.CodexKey {
|
||||
o := oldCfg.CodexKey[i]
|
||||
n := newCfg.CodexKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].headers: updated", i))
|
||||
}
|
||||
oldExcluded := SummarizeExcludedModels(o.ExcludedModels)
|
||||
newExcluded := SummarizeExcludedModels(n.ExcludedModels)
|
||||
if oldExcluded.hash != newExcluded.hash {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].excluded-models: updated (%d -> %d entries)", i, oldExcluded.count, newExcluded.count))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// AmpCode settings (redacted where needed)
|
||||
oldAmpURL := strings.TrimSpace(oldCfg.AmpCode.UpstreamURL)
|
||||
newAmpURL := strings.TrimSpace(newCfg.AmpCode.UpstreamURL)
|
||||
if oldAmpURL != newAmpURL {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.upstream-url: %s -> %s", oldAmpURL, newAmpURL))
|
||||
}
|
||||
oldAmpKey := strings.TrimSpace(oldCfg.AmpCode.UpstreamAPIKey)
|
||||
newAmpKey := strings.TrimSpace(newCfg.AmpCode.UpstreamAPIKey)
|
||||
switch {
|
||||
case oldAmpKey == "" && newAmpKey != "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: added")
|
||||
case oldAmpKey != "" && newAmpKey == "":
|
||||
changes = append(changes, "ampcode.upstream-api-key: removed")
|
||||
case oldAmpKey != newAmpKey:
|
||||
changes = append(changes, "ampcode.upstream-api-key: updated")
|
||||
}
|
||||
if oldCfg.AmpCode.RestrictManagementToLocalhost != newCfg.AmpCode.RestrictManagementToLocalhost {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.restrict-management-to-localhost: %t -> %t", oldCfg.AmpCode.RestrictManagementToLocalhost, newCfg.AmpCode.RestrictManagementToLocalhost))
|
||||
}
|
||||
oldMappings := SummarizeAmpModelMappings(oldCfg.AmpCode.ModelMappings)
|
||||
newMappings := SummarizeAmpModelMappings(newCfg.AmpCode.ModelMappings)
|
||||
if oldMappings.hash != newMappings.hash {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.model-mappings: updated (%d -> %d entries)", oldMappings.count, newMappings.count))
|
||||
}
|
||||
if oldCfg.AmpCode.ForceModelMappings != newCfg.AmpCode.ForceModelMappings {
|
||||
changes = append(changes, fmt.Sprintf("ampcode.force-model-mappings: %t -> %t", oldCfg.AmpCode.ForceModelMappings, newCfg.AmpCode.ForceModelMappings))
|
||||
}
|
||||
|
||||
if entries, _ := DiffOAuthExcludedModelChanges(oldCfg.OAuthExcludedModels, newCfg.OAuthExcludedModels); len(entries) > 0 {
|
||||
changes = append(changes, entries...)
|
||||
}
|
||||
|
||||
// Remote management (never print the key)
|
||||
if oldCfg.RemoteManagement.AllowRemote != newCfg.RemoteManagement.AllowRemote {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.allow-remote: %t -> %t", oldCfg.RemoteManagement.AllowRemote, newCfg.RemoteManagement.AllowRemote))
|
||||
}
|
||||
if oldCfg.RemoteManagement.DisableControlPanel != newCfg.RemoteManagement.DisableControlPanel {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.disable-control-panel: %t -> %t", oldCfg.RemoteManagement.DisableControlPanel, newCfg.RemoteManagement.DisableControlPanel))
|
||||
}
|
||||
oldPanelRepo := strings.TrimSpace(oldCfg.RemoteManagement.PanelGitHubRepository)
|
||||
newPanelRepo := strings.TrimSpace(newCfg.RemoteManagement.PanelGitHubRepository)
|
||||
if oldPanelRepo != newPanelRepo {
|
||||
changes = append(changes, fmt.Sprintf("remote-management.panel-github-repository: %s -> %s", oldPanelRepo, newPanelRepo))
|
||||
}
|
||||
if oldCfg.RemoteManagement.SecretKey != newCfg.RemoteManagement.SecretKey {
|
||||
switch {
|
||||
case oldCfg.RemoteManagement.SecretKey == "" && newCfg.RemoteManagement.SecretKey != "":
|
||||
changes = append(changes, "remote-management.secret-key: created")
|
||||
case oldCfg.RemoteManagement.SecretKey != "" && newCfg.RemoteManagement.SecretKey == "":
|
||||
changes = append(changes, "remote-management.secret-key: deleted")
|
||||
default:
|
||||
changes = append(changes, "remote-management.secret-key: updated")
|
||||
}
|
||||
}
|
||||
|
||||
// OpenAI compatibility providers (summarized)
|
||||
if compat := DiffOpenAICompatibility(oldCfg.OpenAICompatibility, newCfg.OpenAICompatibility); len(compat) > 0 {
|
||||
changes = append(changes, "openai-compatibility:")
|
||||
for _, c := range compat {
|
||||
changes = append(changes, " "+c)
|
||||
}
|
||||
}
|
||||
|
||||
// Vertex-compatible API keys
|
||||
if len(oldCfg.VertexCompatAPIKey) != len(newCfg.VertexCompatAPIKey) {
|
||||
changes = append(changes, fmt.Sprintf("vertex-api-key count: %d -> %d", len(oldCfg.VertexCompatAPIKey), len(newCfg.VertexCompatAPIKey)))
|
||||
} else {
|
||||
for i := range oldCfg.VertexCompatAPIKey {
|
||||
o := oldCfg.VertexCompatAPIKey[i]
|
||||
n := newCfg.VertexCompatAPIKey[i]
|
||||
if strings.TrimSpace(o.BaseURL) != strings.TrimSpace(n.BaseURL) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].base-url: %s -> %s", i, strings.TrimSpace(o.BaseURL), strings.TrimSpace(n.BaseURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.ProxyURL) != strings.TrimSpace(n.ProxyURL) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].proxy-url: %s -> %s", i, formatProxyURL(o.ProxyURL), formatProxyURL(n.ProxyURL)))
|
||||
}
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].api-key: updated", i))
|
||||
}
|
||||
oldModels := SummarizeVertexModels(o.Models)
|
||||
newModels := SummarizeVertexModels(n.Models)
|
||||
if oldModels.hash != newModels.hash {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].models: updated (%d -> %d entries)", i, oldModels.count, newModels.count))
|
||||
}
|
||||
if !equalStringMap(o.Headers, n.Headers) {
|
||||
changes = append(changes, fmt.Sprintf("vertex[%d].headers: updated", i))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return changes
|
||||
}
|
||||
|
||||
func trimStrings(in []string) []string {
|
||||
out := make([]string, len(in))
|
||||
for i := range in {
|
||||
out[i] = strings.TrimSpace(in[i])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func equalStringMap(a, b map[string]string) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
for k, v := range a {
|
||||
if b[k] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func formatProxyURL(raw string) string {
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return "<none>"
|
||||
}
|
||||
parsed, err := url.Parse(trimmed)
|
||||
if err != nil {
|
||||
return "<redacted>"
|
||||
}
|
||||
host := strings.TrimSpace(parsed.Host)
|
||||
scheme := strings.TrimSpace(parsed.Scheme)
|
||||
if host == "" {
|
||||
// Allow host:port style without scheme.
|
||||
parsed2, err2 := url.Parse("http://" + trimmed)
|
||||
if err2 == nil {
|
||||
host = strings.TrimSpace(parsed2.Host)
|
||||
}
|
||||
scheme = ""
|
||||
}
|
||||
if host == "" {
|
||||
return "<redacted>"
|
||||
}
|
||||
if scheme == "" {
|
||||
return host
|
||||
}
|
||||
return scheme + "://" + host
|
||||
}
|
||||
529
internal/watcher/diff/config_diff_test.go
Normal file
529
internal/watcher/diff/config_diff_test.go
Normal file
@@ -0,0 +1,529 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestBuildConfigChangeDetails(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 8080,
|
||||
AuthDir: "/tmp/auth-old",
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model"}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://old-upstream",
|
||||
ModelMappings: []config.AmpModelMapping{{From: "from-old", To: "to-old"}},
|
||||
RestrictManagementToLocalhost: false,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: false,
|
||||
SecretKey: "old",
|
||||
DisableControlPanel: false,
|
||||
PanelGitHubRepository: "repo-old",
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"providerA": {"m1"},
|
||||
},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "compat-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
newCfg := &config.Config{
|
||||
Port: 9090,
|
||||
AuthDir: "/tmp/auth-new",
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "old", BaseURL: "http://old", ExcludedModels: []string{"old-model", "extra"}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://new-upstream",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{
|
||||
{From: "from-old", To: "to-old"},
|
||||
{From: "from-new", To: "to-new"},
|
||||
},
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: true,
|
||||
SecretKey: "new",
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "repo-new",
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"providerA": {"m1", "m2"},
|
||||
"providerB": {"x"},
|
||||
},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "compat-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}},
|
||||
},
|
||||
{
|
||||
Name: "compat-b",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
|
||||
expectContains(t, details, "port: 8080 -> 9090")
|
||||
expectContains(t, details, "auth-dir: /tmp/auth-old -> /tmp/auth-new")
|
||||
expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "ampcode.upstream-url: http://old-upstream -> http://new-upstream")
|
||||
expectContains(t, details, "ampcode.model-mappings: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "remote-management.allow-remote: false -> true")
|
||||
expectContains(t, details, "remote-management.secret-key: updated")
|
||||
expectContains(t, details, "oauth-excluded-models[providera]: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "oauth-excluded-models[providerb]: added (1 entries)")
|
||||
expectContains(t, details, "openai-compatibility:")
|
||||
expectContains(t, details, " provider added: compat-b (api-keys=1, models=0)")
|
||||
expectContains(t, details, " provider updated: compat-a (models 1 -> 2)")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_NoChanges(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
Port: 8080,
|
||||
}
|
||||
if details := BuildConfigChangeDetails(cfg, cfg); len(details) != 0 {
|
||||
t.Fatalf("expected no change entries, got %v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_GeminiVertexHeadersAndForceMappings(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", BaseURL: "http://v-old", Models: []config.VertexCompatModel{{Name: "m1"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
ForceModelMappings: false,
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"a", "b"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", BaseURL: "http://v-new", Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}},
|
||||
ForceModelMappings: true,
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "gemini[0].headers: updated")
|
||||
expectContains(t, details, "gemini[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, details, "ampcode.model-mappings: updated (1 -> 1 entries)")
|
||||
expectContains(t, details, "ampcode.force-model-mappings: false -> true")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_ModelPrefixes(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Prefix: "old-g", BaseURL: "http://g", ProxyURL: "http://gp"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", Prefix: "old-c", BaseURL: "http://c", ProxyURL: "http://cp"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", Prefix: "old-x", BaseURL: "http://x", ProxyURL: "http://xp"},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", Prefix: "old-v", BaseURL: "http://v", ProxyURL: "http://vp"},
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g1", Prefix: "new-g", BaseURL: "http://g", ProxyURL: "http://gp"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", Prefix: "new-c", BaseURL: "http://c", ProxyURL: "http://cp"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", Prefix: "new-x", BaseURL: "http://x", ProxyURL: "http://xp"},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v1", Prefix: "new-v", BaseURL: "http://v", ProxyURL: "http://vp"},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "gemini[0].prefix: old-g -> new-g")
|
||||
expectContains(t, changes, "claude[0].prefix: old-c -> new-c")
|
||||
expectContains(t, changes, "codex[0].prefix: old-x -> new-x")
|
||||
expectContains(t, changes, "vertex[0].prefix: old-v -> new-v")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_NilSafe(t *testing.T) {
|
||||
if details := BuildConfigChangeDetails(nil, &config.Config{}); len(details) != 0 {
|
||||
t.Fatalf("expected empty change list when old nil, got %v", details)
|
||||
}
|
||||
if details := BuildConfigChangeDetails(&config.Config{}, nil); len(details) != 0 {
|
||||
t.Fatalf("expected empty change list when new nil, got %v", details)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_SecretsAndCounts(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
APIKeys: []string{"a"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "",
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
APIKeys: []string{"a", "b", "c"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "new-key",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "new-secret",
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "api-keys count: 1 -> 3")
|
||||
expectContains(t, details, "ampcode.upstream-api-key: added")
|
||||
expectContains(t, details, "remote-management.secret-key: created")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_FlagsAndKeys(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 1000,
|
||||
AuthDir: "/old",
|
||||
Debug: false,
|
||||
LoggingToFile: false,
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
ClaudeKey: []config.ClaudeKey{{APIKey: "c1"}},
|
||||
CodexKey: []config.CodexKey{{APIKey: "x1"}},
|
||||
AmpCode: config.AmpCode{UpstreamAPIKey: "keep", RestrictManagementToLocalhost: false},
|
||||
RemoteManagement: config.RemoteManagement{DisableControlPanel: false, PanelGitHubRepository: "old/repo", SecretKey: "keep"},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
ProxyURL: "http://old-proxy",
|
||||
APIKeys: []string{"key-1"},
|
||||
ForceModelPrefix: false,
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
Port: 2000,
|
||||
AuthDir: "/new",
|
||||
Debug: true,
|
||||
LoggingToFile: true,
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c1", BaseURL: "http://new", ProxyURL: "http://p", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"a"}},
|
||||
{APIKey: "c2"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x1", BaseURL: "http://x", ProxyURL: "http://px", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"b"}},
|
||||
{APIKey: "x2"},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "new/repo",
|
||||
SecretKey: "",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{" key-1 ", "key-2"},
|
||||
ForceModelPrefix: true,
|
||||
},
|
||||
}
|
||||
|
||||
details := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, details, "debug: false -> true")
|
||||
expectContains(t, details, "logging-to-file: false -> true")
|
||||
expectContains(t, details, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, details, "disable-cooling: false -> true")
|
||||
expectContains(t, details, "request-log: false -> true")
|
||||
expectContains(t, details, "request-retry: 1 -> 2")
|
||||
expectContains(t, details, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, details, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, details, "ws-auth: false -> true")
|
||||
expectContains(t, details, "force-model-prefix: false -> true")
|
||||
expectContains(t, details, "quota-exceeded.switch-project: false -> true")
|
||||
expectContains(t, details, "quota-exceeded.switch-preview-model: false -> true")
|
||||
expectContains(t, details, "api-keys count: 1 -> 2")
|
||||
expectContains(t, details, "claude-api-key count: 1 -> 2")
|
||||
expectContains(t, details, "codex-api-key count: 1 -> 2")
|
||||
expectContains(t, details, "ampcode.restrict-management-to-localhost: false -> true")
|
||||
expectContains(t, details, "ampcode.upstream-api-key: removed")
|
||||
expectContains(t, details, "remote-management.disable-control-panel: false -> true")
|
||||
expectContains(t, details, "remote-management.panel-github-repository: old/repo -> new/repo")
|
||||
expectContains(t, details, "remote-management.secret-key: deleted")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_AllBranches(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
Port: 1,
|
||||
AuthDir: "/a",
|
||||
Debug: false,
|
||||
LoggingToFile: false,
|
||||
UsageStatisticsEnabled: false,
|
||||
DisableCooling: false,
|
||||
RequestRetry: 1,
|
||||
MaxRetryInterval: 1,
|
||||
WebsocketAuth: false,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: false, SwitchPreviewModel: false},
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g-old", BaseURL: "http://g-old", ProxyURL: "http://gp-old", Headers: map[string]string{"A": "1"}},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c-old", BaseURL: "http://c-old", ProxyURL: "http://cp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x-old", BaseURL: "http://x-old", ProxyURL: "http://xp-old", Headers: map[string]string{"H": "1"}, ExcludedModels: []string{"x"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v-old", BaseURL: "http://v-old", ProxyURL: "http://vp-old", Headers: map[string]string{"H": "1"}, Models: []config.VertexCompatModel{{Name: "m1"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://amp-old",
|
||||
UpstreamAPIKey: "old-key",
|
||||
RestrictManagementToLocalhost: false,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "b"}},
|
||||
ForceModelMappings: false,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: false,
|
||||
DisableControlPanel: false,
|
||||
PanelGitHubRepository: "old/repo",
|
||||
SecretKey: "old",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: false,
|
||||
ProxyURL: "http://old-proxy",
|
||||
APIKeys: []string{" keyA "},
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{"p1": {"a"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "prov-old",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
Port: 2,
|
||||
AuthDir: "/b",
|
||||
Debug: true,
|
||||
LoggingToFile: true,
|
||||
UsageStatisticsEnabled: true,
|
||||
DisableCooling: true,
|
||||
RequestRetry: 2,
|
||||
MaxRetryInterval: 3,
|
||||
WebsocketAuth: true,
|
||||
QuotaExceeded: config.QuotaExceeded{SwitchProject: true, SwitchPreviewModel: true},
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "g-new", BaseURL: "http://g-new", ProxyURL: "http://gp-new", Headers: map[string]string{"A": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "c-new", BaseURL: "http://c-new", ProxyURL: "http://cp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "x-new", BaseURL: "http://x-new", ProxyURL: "http://xp-new", Headers: map[string]string{"H": "2"}, ExcludedModels: []string{"x", "y"}},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v-new", BaseURL: "http://v-new", ProxyURL: "http://vp-new", Headers: map[string]string{"H": "2"}, Models: []config.VertexCompatModel{{Name: "m1"}, {Name: "m2"}}},
|
||||
},
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamURL: "http://amp-new",
|
||||
UpstreamAPIKey: "",
|
||||
RestrictManagementToLocalhost: true,
|
||||
ModelMappings: []config.AmpModelMapping{{From: "a", To: "c"}},
|
||||
ForceModelMappings: true,
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
AllowRemote: true,
|
||||
DisableControlPanel: true,
|
||||
PanelGitHubRepository: "new/repo",
|
||||
SecretKey: "",
|
||||
},
|
||||
SDKConfig: sdkconfig.SDKConfig{
|
||||
RequestLog: true,
|
||||
ProxyURL: "http://new-proxy",
|
||||
APIKeys: []string{"keyB"},
|
||||
},
|
||||
OAuthExcludedModels: map[string][]string{"p1": {"b", "c"}, "p2": {"d"}},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "prov-old",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}, {Name: "m2"}},
|
||||
},
|
||||
{
|
||||
Name: "prov-new",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k3"}},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "port: 1 -> 2")
|
||||
expectContains(t, changes, "auth-dir: /a -> /b")
|
||||
expectContains(t, changes, "debug: false -> true")
|
||||
expectContains(t, changes, "logging-to-file: false -> true")
|
||||
expectContains(t, changes, "usage-statistics-enabled: false -> true")
|
||||
expectContains(t, changes, "disable-cooling: false -> true")
|
||||
expectContains(t, changes, "request-retry: 1 -> 2")
|
||||
expectContains(t, changes, "max-retry-interval: 1 -> 3")
|
||||
expectContains(t, changes, "proxy-url: http://old-proxy -> http://new-proxy")
|
||||
expectContains(t, changes, "ws-auth: false -> true")
|
||||
expectContains(t, changes, "quota-exceeded.switch-project: false -> true")
|
||||
expectContains(t, changes, "quota-exceeded.switch-preview-model: false -> true")
|
||||
expectContains(t, changes, "api-keys: values updated (count unchanged, redacted)")
|
||||
expectContains(t, changes, "gemini[0].base-url: http://g-old -> http://g-new")
|
||||
expectContains(t, changes, "gemini[0].proxy-url: http://gp-old -> http://gp-new")
|
||||
expectContains(t, changes, "gemini[0].api-key: updated")
|
||||
expectContains(t, changes, "gemini[0].headers: updated")
|
||||
expectContains(t, changes, "gemini[0].excluded-models: updated (0 -> 2 entries)")
|
||||
expectContains(t, changes, "claude[0].base-url: http://c-old -> http://c-new")
|
||||
expectContains(t, changes, "claude[0].proxy-url: http://cp-old -> http://cp-new")
|
||||
expectContains(t, changes, "claude[0].api-key: updated")
|
||||
expectContains(t, changes, "claude[0].headers: updated")
|
||||
expectContains(t, changes, "claude[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "codex[0].base-url: http://x-old -> http://x-new")
|
||||
expectContains(t, changes, "codex[0].proxy-url: http://xp-old -> http://xp-new")
|
||||
expectContains(t, changes, "codex[0].api-key: updated")
|
||||
expectContains(t, changes, "codex[0].headers: updated")
|
||||
expectContains(t, changes, "codex[0].excluded-models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "vertex[0].base-url: http://v-old -> http://v-new")
|
||||
expectContains(t, changes, "vertex[0].proxy-url: http://vp-old -> http://vp-new")
|
||||
expectContains(t, changes, "vertex[0].api-key: updated")
|
||||
expectContains(t, changes, "vertex[0].models: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "vertex[0].headers: updated")
|
||||
expectContains(t, changes, "ampcode.upstream-url: http://amp-old -> http://amp-new")
|
||||
expectContains(t, changes, "ampcode.upstream-api-key: removed")
|
||||
expectContains(t, changes, "ampcode.restrict-management-to-localhost: false -> true")
|
||||
expectContains(t, changes, "ampcode.model-mappings: updated (1 -> 1 entries)")
|
||||
expectContains(t, changes, "ampcode.force-model-mappings: false -> true")
|
||||
expectContains(t, changes, "oauth-excluded-models[p1]: updated (1 -> 2 entries)")
|
||||
expectContains(t, changes, "oauth-excluded-models[p2]: added (1 entries)")
|
||||
expectContains(t, changes, "remote-management.allow-remote: false -> true")
|
||||
expectContains(t, changes, "remote-management.disable-control-panel: false -> true")
|
||||
expectContains(t, changes, "remote-management.panel-github-repository: old/repo -> new/repo")
|
||||
expectContains(t, changes, "remote-management.secret-key: deleted")
|
||||
expectContains(t, changes, "openai-compatibility:")
|
||||
}
|
||||
|
||||
func TestFormatProxyURL(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
in string
|
||||
want string
|
||||
}{
|
||||
{name: "empty", in: "", want: "<none>"},
|
||||
{name: "invalid", in: "http://[::1", want: "<redacted>"},
|
||||
{name: "fullURLRedactsUserinfoAndPath", in: "http://user:pass@example.com:8080/path?x=1#frag", want: "http://example.com:8080"},
|
||||
{name: "socks5RedactsUserinfoAndPath", in: "socks5://user:pass@192.168.1.1:1080/path?x=1", want: "socks5://192.168.1.1:1080"},
|
||||
{name: "socks5HostPort", in: "socks5://proxy.example.com:1080/", want: "socks5://proxy.example.com:1080"},
|
||||
{name: "hostPortNoScheme", in: "example.com:1234/path?x=1", want: "example.com:1234"},
|
||||
{name: "relativePathRedacted", in: "/just/path", want: "<redacted>"},
|
||||
{name: "schemeAndHost", in: "https://example.com", want: "https://example.com"},
|
||||
}
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
if got := formatProxyURL(tt.in); got != tt.want {
|
||||
t.Fatalf("expected %q, got %q", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_SecretAndUpstreamUpdates(t *testing.T) {
|
||||
oldCfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "old",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "old",
|
||||
},
|
||||
}
|
||||
newCfg := &config.Config{
|
||||
AmpCode: config.AmpCode{
|
||||
UpstreamAPIKey: "new",
|
||||
},
|
||||
RemoteManagement: config.RemoteManagement{
|
||||
SecretKey: "new",
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "ampcode.upstream-api-key: updated")
|
||||
expectContains(t, changes, "remote-management.secret-key: updated")
|
||||
}
|
||||
|
||||
func TestBuildConfigChangeDetails_CountBranches(t *testing.T) {
|
||||
oldCfg := &config.Config{}
|
||||
newCfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{{APIKey: "g"}},
|
||||
ClaudeKey: []config.ClaudeKey{{APIKey: "c"}},
|
||||
CodexKey: []config.CodexKey{{APIKey: "x"}},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "v", BaseURL: "http://v"},
|
||||
},
|
||||
}
|
||||
|
||||
changes := BuildConfigChangeDetails(oldCfg, newCfg)
|
||||
expectContains(t, changes, "gemini-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "claude-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "codex-api-key count: 0 -> 1")
|
||||
expectContains(t, changes, "vertex-api-key count: 0 -> 1")
|
||||
}
|
||||
|
||||
func TestTrimStrings(t *testing.T) {
|
||||
out := trimStrings([]string{" a ", "b", " c"})
|
||||
if len(out) != 3 || out[0] != "a" || out[1] != "b" || out[2] != "c" {
|
||||
t.Fatalf("unexpected trimmed strings: %v", out)
|
||||
}
|
||||
}
|
||||
102
internal/watcher/diff/model_hash.go
Normal file
102
internal/watcher/diff/model_hash.go
Normal file
@@ -0,0 +1,102 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// ComputeOpenAICompatModelsHash returns a stable hash for OpenAI-compat models.
|
||||
// Used to detect model list changes during hot reload.
|
||||
func ComputeOpenAICompatModelsHash(models []config.OpenAICompatibilityModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeVertexCompatModelsHash returns a stable hash for Vertex-compatible models.
|
||||
func ComputeVertexCompatModelsHash(models []config.VertexCompatModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeClaudeModelsHash returns a stable hash for Claude model aliases.
|
||||
func ComputeClaudeModelsHash(models []config.ClaudeModel) string {
|
||||
keys := normalizeModelPairs(func(out func(key string)) {
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
out(strings.ToLower(name) + "|" + strings.ToLower(alias))
|
||||
}
|
||||
})
|
||||
return hashJoined(keys)
|
||||
}
|
||||
|
||||
// ComputeExcludedModelsHash returns a normalized hash for excluded model lists.
|
||||
func ComputeExcludedModelsHash(excluded []string) string {
|
||||
if len(excluded) == 0 {
|
||||
return ""
|
||||
}
|
||||
normalized := make([]string, 0, len(excluded))
|
||||
for _, entry := range excluded {
|
||||
if trimmed := strings.TrimSpace(entry); trimmed != "" {
|
||||
normalized = append(normalized, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
if len(normalized) == 0 {
|
||||
return ""
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
data, _ := json.Marshal(normalized)
|
||||
sum := sha256.Sum256(data)
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
|
||||
func normalizeModelPairs(collect func(out func(key string))) []string {
|
||||
seen := make(map[string]struct{})
|
||||
keys := make([]string, 0)
|
||||
collect(func(key string) {
|
||||
if _, exists := seen[key]; exists {
|
||||
return
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
keys = append(keys, key)
|
||||
})
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
sort.Strings(keys)
|
||||
return keys
|
||||
}
|
||||
|
||||
func hashJoined(keys []string) string {
|
||||
if len(keys) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(strings.Join(keys, "\n")))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
159
internal/watcher/diff/model_hash_test.go
Normal file
159
internal/watcher/diff/model_hash_test.go
Normal file
@@ -0,0 +1,159 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_Deterministic(t *testing.T) {
|
||||
models := []config.OpenAICompatibilityModel{
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
{Name: "gpt-3.5-turbo"},
|
||||
}
|
||||
hash1 := ComputeOpenAICompatModelsHash(models)
|
||||
hash2 := ComputeOpenAICompatModelsHash(models)
|
||||
if hash1 == "" {
|
||||
t.Fatal("hash should not be empty")
|
||||
}
|
||||
if hash1 != hash2 {
|
||||
t.Fatalf("hash should be deterministic, got %s vs %s", hash1, hash2)
|
||||
}
|
||||
changed := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: "gpt-4"}, {Name: "gpt-4.1"}})
|
||||
if hash1 == changed {
|
||||
t.Fatal("hash should change when model list changes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_NormalizesAndDedups(t *testing.T) {
|
||||
a := []config.OpenAICompatibilityModel{
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
{Name: " "},
|
||||
{Name: "GPT-4", Alias: "GPT4"},
|
||||
{Alias: "a1"},
|
||||
}
|
||||
b := []config.OpenAICompatibilityModel{
|
||||
{Alias: "A1"},
|
||||
{Name: "gpt-4", Alias: "gpt4"},
|
||||
}
|
||||
h1 := ComputeOpenAICompatModelsHash(a)
|
||||
h2 := ComputeOpenAICompatModelsHash(b)
|
||||
if h1 == "" || h2 == "" {
|
||||
t.Fatal("expected non-empty hashes for non-empty model sets")
|
||||
}
|
||||
if h1 != h2 {
|
||||
t.Fatalf("expected normalized hashes to match, got %s / %s", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_DifferentInputs(t *testing.T) {
|
||||
models := []config.VertexCompatModel{{Name: "gemini-pro", Alias: "pro"}}
|
||||
hash1 := ComputeVertexCompatModelsHash(models)
|
||||
hash2 := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: "gemini-1.5-pro", Alias: "pro"}})
|
||||
if hash1 == "" || hash2 == "" {
|
||||
t.Fatal("hashes should not be empty for non-empty models")
|
||||
}
|
||||
if hash1 == hash2 {
|
||||
t.Fatal("hash should differ when model content differs")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_IgnoresBlankAndOrder(t *testing.T) {
|
||||
a := []config.VertexCompatModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
{Name: " "},
|
||||
{Name: "M1", Alias: "A1"},
|
||||
}
|
||||
b := []config.VertexCompatModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
}
|
||||
if h1, h2 := ComputeVertexCompatModelsHash(a), ComputeVertexCompatModelsHash(b); h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeClaudeModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil models, got %q", got)
|
||||
}
|
||||
if got := ComputeClaudeModelsHash([]config.ClaudeModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_IgnoresBlankAndDedup(t *testing.T) {
|
||||
a := []config.ClaudeModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
{Name: " "},
|
||||
{Name: "M1", Alias: "A1"},
|
||||
}
|
||||
b := []config.ClaudeModel{
|
||||
{Name: "m1", Alias: "a1"},
|
||||
}
|
||||
if h1, h2 := ComputeClaudeModelsHash(a), ComputeClaudeModelsHash(b); h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected same hash ignoring blanks/dupes, got %q / %q", h1, h2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeExcludedModelsHash_Normalizes(t *testing.T) {
|
||||
hash1 := ComputeExcludedModelsHash([]string{" A ", "b", "a"})
|
||||
hash2 := ComputeExcludedModelsHash([]string{"a", " b", "A"})
|
||||
if hash1 == "" || hash2 == "" {
|
||||
t.Fatal("hash should not be empty for non-empty input")
|
||||
}
|
||||
if hash1 != hash2 {
|
||||
t.Fatalf("hash should be order/space insensitive for same multiset, got %s vs %s", hash1, hash2)
|
||||
}
|
||||
hash3 := ComputeExcludedModelsHash([]string{"c"})
|
||||
if hash1 == hash3 {
|
||||
t.Fatal("hash should differ for different normalized sets")
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeOpenAICompatModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeOpenAICompatModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeOpenAICompatModelsHash([]config.OpenAICompatibilityModel{{Name: " "}, {Alias: ""}}); got != "" {
|
||||
t.Fatalf("expected empty hash for blank models, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeVertexCompatModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeVertexCompatModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeVertexCompatModelsHash([]config.VertexCompatModel{{Name: " "}}); got != "" {
|
||||
t.Fatalf("expected empty hash for blank models, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeExcludedModelsHash_Empty(t *testing.T) {
|
||||
if got := ComputeExcludedModelsHash(nil); got != "" {
|
||||
t.Fatalf("expected empty hash for nil input, got %q", got)
|
||||
}
|
||||
if got := ComputeExcludedModelsHash([]string{}); got != "" {
|
||||
t.Fatalf("expected empty hash for empty slice, got %q", got)
|
||||
}
|
||||
if got := ComputeExcludedModelsHash([]string{" ", ""}); got != "" {
|
||||
t.Fatalf("expected empty hash for whitespace-only entries, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestComputeClaudeModelsHash_Deterministic(t *testing.T) {
|
||||
models := []config.ClaudeModel{{Name: "a", Alias: "A"}, {Name: "b"}}
|
||||
h1 := ComputeClaudeModelsHash(models)
|
||||
h2 := ComputeClaudeModelsHash(models)
|
||||
if h1 == "" || h1 != h2 {
|
||||
t.Fatalf("expected deterministic hash, got %s / %s", h1, h2)
|
||||
}
|
||||
if h3 := ComputeClaudeModelsHash([]config.ClaudeModel{{Name: "a"}}); h3 == h1 {
|
||||
t.Fatalf("expected different hash when models change, got %s", h3)
|
||||
}
|
||||
}
|
||||
151
internal/watcher/diff/oauth_excluded.go
Normal file
151
internal/watcher/diff/oauth_excluded.go
Normal file
@@ -0,0 +1,151 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
type ExcludedModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeExcludedModels normalizes and hashes an excluded-model list.
|
||||
func SummarizeExcludedModels(list []string) ExcludedModelsSummary {
|
||||
if len(list) == 0 {
|
||||
return ExcludedModelsSummary{}
|
||||
}
|
||||
seen := make(map[string]struct{}, len(list))
|
||||
normalized := make([]string, 0, len(list))
|
||||
for _, entry := range list {
|
||||
if trimmed := strings.ToLower(strings.TrimSpace(entry)); trimmed != "" {
|
||||
if _, exists := seen[trimmed]; exists {
|
||||
continue
|
||||
}
|
||||
seen[trimmed] = struct{}{}
|
||||
normalized = append(normalized, trimmed)
|
||||
}
|
||||
}
|
||||
sort.Strings(normalized)
|
||||
return ExcludedModelsSummary{
|
||||
hash: ComputeExcludedModelsHash(normalized),
|
||||
count: len(normalized),
|
||||
}
|
||||
}
|
||||
|
||||
// SummarizeOAuthExcludedModels summarizes OAuth excluded models per provider.
|
||||
func SummarizeOAuthExcludedModels(entries map[string][]string) map[string]ExcludedModelsSummary {
|
||||
if len(entries) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]ExcludedModelsSummary, len(entries))
|
||||
for k, v := range entries {
|
||||
key := strings.ToLower(strings.TrimSpace(k))
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
out[key] = SummarizeExcludedModels(v)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// DiffOAuthExcludedModelChanges compares OAuth excluded models maps.
|
||||
func DiffOAuthExcludedModelChanges(oldMap, newMap map[string][]string) ([]string, []string) {
|
||||
oldSummary := SummarizeOAuthExcludedModels(oldMap)
|
||||
newSummary := SummarizeOAuthExcludedModels(newMap)
|
||||
keys := make(map[string]struct{}, len(oldSummary)+len(newSummary))
|
||||
for k := range oldSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
for k := range newSummary {
|
||||
keys[k] = struct{}{}
|
||||
}
|
||||
changes := make([]string, 0, len(keys))
|
||||
affected := make([]string, 0, len(keys))
|
||||
for key := range keys {
|
||||
oldInfo, okOld := oldSummary[key]
|
||||
newInfo, okNew := newSummary[key]
|
||||
switch {
|
||||
case okOld && !okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: removed", key))
|
||||
affected = append(affected, key)
|
||||
case !okOld && okNew:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: added (%d entries)", key, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
case okOld && okNew && oldInfo.hash != newInfo.hash:
|
||||
changes = append(changes, fmt.Sprintf("oauth-excluded-models[%s]: updated (%d -> %d entries)", key, oldInfo.count, newInfo.count))
|
||||
affected = append(affected, key)
|
||||
}
|
||||
}
|
||||
sort.Strings(changes)
|
||||
sort.Strings(affected)
|
||||
return changes, affected
|
||||
}
|
||||
|
||||
type AmpModelMappingsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeAmpModelMappings hashes Amp model mappings for change detection.
|
||||
func SummarizeAmpModelMappings(mappings []config.AmpModelMapping) AmpModelMappingsSummary {
|
||||
if len(mappings) == 0 {
|
||||
return AmpModelMappingsSummary{}
|
||||
}
|
||||
entries := make([]string, 0, len(mappings))
|
||||
for _, mapping := range mappings {
|
||||
from := strings.TrimSpace(mapping.From)
|
||||
to := strings.TrimSpace(mapping.To)
|
||||
if from == "" && to == "" {
|
||||
continue
|
||||
}
|
||||
entries = append(entries, from+"->"+to)
|
||||
}
|
||||
if len(entries) == 0 {
|
||||
return AmpModelMappingsSummary{}
|
||||
}
|
||||
sort.Strings(entries)
|
||||
sum := sha256.Sum256([]byte(strings.Join(entries, "|")))
|
||||
return AmpModelMappingsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(entries),
|
||||
}
|
||||
}
|
||||
|
||||
type VertexModelsSummary struct {
|
||||
hash string
|
||||
count int
|
||||
}
|
||||
|
||||
// SummarizeVertexModels hashes vertex-compatible models for change detection.
|
||||
func SummarizeVertexModels(models []config.VertexCompatModel) VertexModelsSummary {
|
||||
if len(models) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
names := make([]string, 0, len(models))
|
||||
for _, m := range models {
|
||||
name := strings.TrimSpace(m.Name)
|
||||
alias := strings.TrimSpace(m.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
if alias != "" {
|
||||
name = alias
|
||||
}
|
||||
names = append(names, name)
|
||||
}
|
||||
if len(names) == 0 {
|
||||
return VertexModelsSummary{}
|
||||
}
|
||||
sort.Strings(names)
|
||||
sum := sha256.Sum256([]byte(strings.Join(names, "|")))
|
||||
return VertexModelsSummary{
|
||||
hash: hex.EncodeToString(sum[:]),
|
||||
count: len(names),
|
||||
}
|
||||
}
|
||||
109
internal/watcher/diff/oauth_excluded_test.go
Normal file
109
internal/watcher/diff/oauth_excluded_test.go
Normal file
@@ -0,0 +1,109 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestSummarizeExcludedModels_NormalizesAndDedupes(t *testing.T) {
|
||||
summary := SummarizeExcludedModels([]string{"A", " a ", "B", "b"})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 unique entries, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeExcludedModels(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestDiffOAuthExcludedModelChanges(t *testing.T) {
|
||||
oldMap := map[string][]string{
|
||||
"ProviderA": {"model-1", "model-2"},
|
||||
"providerB": {"x"},
|
||||
}
|
||||
newMap := map[string][]string{
|
||||
"providerA": {"model-1", "model-3"},
|
||||
"providerC": {"y"},
|
||||
}
|
||||
|
||||
changes, affected := DiffOAuthExcludedModelChanges(oldMap, newMap)
|
||||
expectContains(t, changes, "oauth-excluded-models[providera]: updated (2 -> 2 entries)")
|
||||
expectContains(t, changes, "oauth-excluded-models[providerb]: removed")
|
||||
expectContains(t, changes, "oauth-excluded-models[providerc]: added (1 entries)")
|
||||
|
||||
if len(affected) != 3 {
|
||||
t.Fatalf("expected 3 affected providers, got %d", len(affected))
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeAmpModelMappings(t *testing.T) {
|
||||
summary := SummarizeAmpModelMappings([]config.AmpModelMapping{
|
||||
{From: "a", To: "A"},
|
||||
{From: "b", To: "B"},
|
||||
{From: " ", To: " "}, // ignored
|
||||
})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 entries, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeAmpModelMappings(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
if blank := SummarizeAmpModelMappings([]config.AmpModelMapping{{From: " ", To: " "}}); blank.count != 0 || blank.hash != "" {
|
||||
t.Fatalf("expected blank mappings ignored, got %+v", blank)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeOAuthExcludedModels_NormalizesKeys(t *testing.T) {
|
||||
out := SummarizeOAuthExcludedModels(map[string][]string{
|
||||
"ProvA": {"X"},
|
||||
"": {"ignored"},
|
||||
})
|
||||
if len(out) != 1 {
|
||||
t.Fatalf("expected only non-empty key summary, got %d", len(out))
|
||||
}
|
||||
if _, ok := out["prova"]; !ok {
|
||||
t.Fatalf("expected normalized key 'prova', got keys %v", out)
|
||||
}
|
||||
if out["prova"].count != 1 || out["prova"].hash == "" {
|
||||
t.Fatalf("unexpected summary %+v", out["prova"])
|
||||
}
|
||||
if outEmpty := SummarizeOAuthExcludedModels(nil); outEmpty != nil {
|
||||
t.Fatalf("expected nil map for nil input, got %v", outEmpty)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSummarizeVertexModels(t *testing.T) {
|
||||
summary := SummarizeVertexModels([]config.VertexCompatModel{
|
||||
{Name: "m1"},
|
||||
{Name: " ", Alias: "alias"},
|
||||
{}, // ignored
|
||||
})
|
||||
if summary.count != 2 {
|
||||
t.Fatalf("expected 2 vertex models, got %d", summary.count)
|
||||
}
|
||||
if summary.hash == "" {
|
||||
t.Fatal("expected non-empty hash")
|
||||
}
|
||||
if empty := SummarizeVertexModels(nil); empty.count != 0 || empty.hash != "" {
|
||||
t.Fatalf("expected empty summary for nil input, got %+v", empty)
|
||||
}
|
||||
if blank := SummarizeVertexModels([]config.VertexCompatModel{{Name: " "}}); blank.count != 0 || blank.hash != "" {
|
||||
t.Fatalf("expected blank model ignored, got %+v", blank)
|
||||
}
|
||||
}
|
||||
|
||||
func expectContains(t *testing.T, list []string, target string) {
|
||||
t.Helper()
|
||||
for _, entry := range list {
|
||||
if entry == target {
|
||||
return
|
||||
}
|
||||
}
|
||||
t.Fatalf("expected list to contain %q, got %#v", target, list)
|
||||
}
|
||||
183
internal/watcher/diff/openai_compat.go
Normal file
183
internal/watcher/diff/openai_compat.go
Normal file
@@ -0,0 +1,183 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// DiffOpenAICompatibility produces human-readable change descriptions.
|
||||
func DiffOpenAICompatibility(oldList, newList []config.OpenAICompatibility) []string {
|
||||
changes := make([]string, 0)
|
||||
oldMap := make(map[string]config.OpenAICompatibility, len(oldList))
|
||||
oldLabels := make(map[string]string, len(oldList))
|
||||
for idx, entry := range oldList {
|
||||
key, label := openAICompatKey(entry, idx)
|
||||
oldMap[key] = entry
|
||||
oldLabels[key] = label
|
||||
}
|
||||
newMap := make(map[string]config.OpenAICompatibility, len(newList))
|
||||
newLabels := make(map[string]string, len(newList))
|
||||
for idx, entry := range newList {
|
||||
key, label := openAICompatKey(entry, idx)
|
||||
newMap[key] = entry
|
||||
newLabels[key] = label
|
||||
}
|
||||
keySet := make(map[string]struct{}, len(oldMap)+len(newMap))
|
||||
for key := range oldMap {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
for key := range newMap {
|
||||
keySet[key] = struct{}{}
|
||||
}
|
||||
orderedKeys := make([]string, 0, len(keySet))
|
||||
for key := range keySet {
|
||||
orderedKeys = append(orderedKeys, key)
|
||||
}
|
||||
sort.Strings(orderedKeys)
|
||||
for _, key := range orderedKeys {
|
||||
oldEntry, oldOk := oldMap[key]
|
||||
newEntry, newOk := newMap[key]
|
||||
label := oldLabels[key]
|
||||
if label == "" {
|
||||
label = newLabels[key]
|
||||
}
|
||||
switch {
|
||||
case !oldOk:
|
||||
changes = append(changes, fmt.Sprintf("provider added: %s (api-keys=%d, models=%d)", label, countAPIKeys(newEntry), countOpenAIModels(newEntry.Models)))
|
||||
case !newOk:
|
||||
changes = append(changes, fmt.Sprintf("provider removed: %s (api-keys=%d, models=%d)", label, countAPIKeys(oldEntry), countOpenAIModels(oldEntry.Models)))
|
||||
default:
|
||||
if detail := describeOpenAICompatibilityUpdate(oldEntry, newEntry); detail != "" {
|
||||
changes = append(changes, fmt.Sprintf("provider updated: %s %s", label, detail))
|
||||
}
|
||||
}
|
||||
}
|
||||
return changes
|
||||
}
|
||||
|
||||
func describeOpenAICompatibilityUpdate(oldEntry, newEntry config.OpenAICompatibility) string {
|
||||
oldKeyCount := countAPIKeys(oldEntry)
|
||||
newKeyCount := countAPIKeys(newEntry)
|
||||
oldModelCount := countOpenAIModels(oldEntry.Models)
|
||||
newModelCount := countOpenAIModels(newEntry.Models)
|
||||
details := make([]string, 0, 3)
|
||||
if oldKeyCount != newKeyCount {
|
||||
details = append(details, fmt.Sprintf("api-keys %d -> %d", oldKeyCount, newKeyCount))
|
||||
}
|
||||
if oldModelCount != newModelCount {
|
||||
details = append(details, fmt.Sprintf("models %d -> %d", oldModelCount, newModelCount))
|
||||
}
|
||||
if !equalStringMap(oldEntry.Headers, newEntry.Headers) {
|
||||
details = append(details, "headers updated")
|
||||
}
|
||||
if len(details) == 0 {
|
||||
return ""
|
||||
}
|
||||
return "(" + strings.Join(details, ", ") + ")"
|
||||
}
|
||||
|
||||
func countAPIKeys(entry config.OpenAICompatibility) int {
|
||||
count := 0
|
||||
for _, keyEntry := range entry.APIKeyEntries {
|
||||
if strings.TrimSpace(keyEntry.APIKey) != "" {
|
||||
count++
|
||||
}
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func countOpenAIModels(models []config.OpenAICompatibilityModel) int {
|
||||
count := 0
|
||||
for _, model := range models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
count++
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func openAICompatKey(entry config.OpenAICompatibility, index int) (string, string) {
|
||||
name := strings.TrimSpace(entry.Name)
|
||||
if name != "" {
|
||||
return "name:" + name, name
|
||||
}
|
||||
base := strings.TrimSpace(entry.BaseURL)
|
||||
if base != "" {
|
||||
return "base:" + base, base
|
||||
}
|
||||
for _, model := range entry.Models {
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if alias == "" {
|
||||
alias = strings.TrimSpace(model.Name)
|
||||
}
|
||||
if alias != "" {
|
||||
return "alias:" + alias, alias
|
||||
}
|
||||
}
|
||||
sig := openAICompatSignature(entry)
|
||||
if sig == "" {
|
||||
return fmt.Sprintf("index:%d", index), fmt.Sprintf("entry-%d", index+1)
|
||||
}
|
||||
short := sig
|
||||
if len(short) > 8 {
|
||||
short = short[:8]
|
||||
}
|
||||
return "sig:" + sig, "compat-" + short
|
||||
}
|
||||
|
||||
func openAICompatSignature(entry config.OpenAICompatibility) string {
|
||||
var parts []string
|
||||
|
||||
if v := strings.TrimSpace(entry.Name); v != "" {
|
||||
parts = append(parts, "name="+strings.ToLower(v))
|
||||
}
|
||||
if v := strings.TrimSpace(entry.BaseURL); v != "" {
|
||||
parts = append(parts, "base="+v)
|
||||
}
|
||||
|
||||
models := make([]string, 0, len(entry.Models))
|
||||
for _, model := range entry.Models {
|
||||
name := strings.TrimSpace(model.Name)
|
||||
alias := strings.TrimSpace(model.Alias)
|
||||
if name == "" && alias == "" {
|
||||
continue
|
||||
}
|
||||
models = append(models, strings.ToLower(name)+"|"+strings.ToLower(alias))
|
||||
}
|
||||
if len(models) > 0 {
|
||||
sort.Strings(models)
|
||||
parts = append(parts, "models="+strings.Join(models, ","))
|
||||
}
|
||||
|
||||
if len(entry.Headers) > 0 {
|
||||
keys := make([]string, 0, len(entry.Headers))
|
||||
for k := range entry.Headers {
|
||||
if trimmed := strings.TrimSpace(k); trimmed != "" {
|
||||
keys = append(keys, strings.ToLower(trimmed))
|
||||
}
|
||||
}
|
||||
if len(keys) > 0 {
|
||||
sort.Strings(keys)
|
||||
parts = append(parts, "headers="+strings.Join(keys, ","))
|
||||
}
|
||||
}
|
||||
|
||||
// Intentionally exclude API key material; only count non-empty entries.
|
||||
if count := countAPIKeys(entry); count > 0 {
|
||||
parts = append(parts, fmt.Sprintf("api_keys=%d", count))
|
||||
}
|
||||
|
||||
if len(parts) == 0 {
|
||||
return ""
|
||||
}
|
||||
sum := sha256.Sum256([]byte(strings.Join(parts, "|")))
|
||||
return hex.EncodeToString(sum[:])
|
||||
}
|
||||
187
internal/watcher/diff/openai_compat_test.go
Normal file
187
internal/watcher/diff/openai_compat_test.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package diff
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
func TestDiffOpenAICompatibility(t *testing.T) {
|
||||
oldList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-a"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
},
|
||||
},
|
||||
}
|
||||
newList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-a"},
|
||||
{APIKey: "key-b"},
|
||||
},
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: "m2"},
|
||||
},
|
||||
Headers: map[string]string{"X-Test": "1"},
|
||||
},
|
||||
{
|
||||
Name: "provider-b",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-b"}},
|
||||
},
|
||||
}
|
||||
|
||||
changes := DiffOpenAICompatibility(oldList, newList)
|
||||
expectContains(t, changes, "provider added: provider-b (api-keys=1, models=0)")
|
||||
expectContains(t, changes, "provider updated: provider-a (api-keys 1 -> 2, models 1 -> 2, headers updated)")
|
||||
}
|
||||
|
||||
func TestDiffOpenAICompatibility_RemovedAndUnchanged(t *testing.T) {
|
||||
oldList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
}
|
||||
newList := []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "provider-a",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "key-a"}},
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "m1"}},
|
||||
},
|
||||
}
|
||||
if changes := DiffOpenAICompatibility(oldList, newList); len(changes) != 0 {
|
||||
t.Fatalf("expected no changes, got %v", changes)
|
||||
}
|
||||
|
||||
newList = nil
|
||||
changes := DiffOpenAICompatibility(oldList, newList)
|
||||
expectContains(t, changes, "provider removed: provider-a (api-keys=1, models=1)")
|
||||
}
|
||||
|
||||
func TestOpenAICompatKeyFallbacks(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{{Alias: "alias-only"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if key != "base:http://base" || label != "http://base" {
|
||||
t.Fatalf("expected base key, got %s/%s", key, label)
|
||||
}
|
||||
|
||||
entry.BaseURL = ""
|
||||
key, label = openAICompatKey(entry, 1)
|
||||
if key != "alias:alias-only" || label != "alias-only" {
|
||||
t.Fatalf("expected alias fallback, got %s/%s", key, label)
|
||||
}
|
||||
|
||||
entry.Models = nil
|
||||
key, label = openAICompatKey(entry, 2)
|
||||
if key != "index:2" || label != "entry-3" {
|
||||
t.Fatalf("expected index fallback, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKey_UsesName(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{Name: "My-Provider"}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if key != "name:My-Provider" || label != "My-Provider" {
|
||||
t.Fatalf("expected name key, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKey_SignatureFallbackWhenOnlyAPIKeys(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{{APIKey: "k1"}, {APIKey: "k2"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 0)
|
||||
if !strings.HasPrefix(key, "sig:") || !strings.HasPrefix(label, "compat-") {
|
||||
t.Fatalf("expected signature key, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatSignature_EmptyReturnsEmpty(t *testing.T) {
|
||||
if got := openAICompatSignature(config.OpenAICompatibility{}); got != "" {
|
||||
t.Fatalf("expected empty signature, got %q", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatSignature_StableAndNormalized(t *testing.T) {
|
||||
a := config.OpenAICompatibility{
|
||||
Name: " Provider ",
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: " "},
|
||||
{Alias: "A1"},
|
||||
},
|
||||
Headers: map[string]string{
|
||||
"X-Test": "1",
|
||||
" ": "ignored",
|
||||
},
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k1"},
|
||||
{APIKey: " "},
|
||||
},
|
||||
}
|
||||
b := config.OpenAICompatibility{
|
||||
Name: "provider",
|
||||
BaseURL: "http://base",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Alias: "a1"},
|
||||
{Name: "m1"},
|
||||
},
|
||||
Headers: map[string]string{
|
||||
"x-test": "2",
|
||||
},
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "k2"},
|
||||
},
|
||||
}
|
||||
|
||||
sigA := openAICompatSignature(a)
|
||||
sigB := openAICompatSignature(b)
|
||||
if sigA == "" || sigB == "" {
|
||||
t.Fatalf("expected non-empty signatures, got %q / %q", sigA, sigB)
|
||||
}
|
||||
if sigA != sigB {
|
||||
t.Fatalf("expected normalized signatures to match, got %s / %s", sigA, sigB)
|
||||
}
|
||||
|
||||
c := b
|
||||
c.Models = append(c.Models, config.OpenAICompatibilityModel{Name: "m2"})
|
||||
if sigC := openAICompatSignature(c); sigC == sigB {
|
||||
t.Fatalf("expected signature to change when models change, got %s", sigC)
|
||||
}
|
||||
}
|
||||
|
||||
func TestCountOpenAIModelsSkipsBlanks(t *testing.T) {
|
||||
models := []config.OpenAICompatibilityModel{
|
||||
{Name: "m1"},
|
||||
{Name: ""},
|
||||
{Alias: ""},
|
||||
{Name: " "},
|
||||
{Alias: "a1"},
|
||||
}
|
||||
if got := countOpenAIModels(models); got != 2 {
|
||||
t.Fatalf("expected 2 counted models, got %d", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAICompatKeyUsesModelNameWhenAliasEmpty(t *testing.T) {
|
||||
entry := config.OpenAICompatibility{
|
||||
Models: []config.OpenAICompatibilityModel{{Name: "model-name"}},
|
||||
}
|
||||
key, label := openAICompatKey(entry, 5)
|
||||
if key != "alias:model-name" || label != "model-name" {
|
||||
t.Fatalf("expected model-name fallback, got %s/%s", key, label)
|
||||
}
|
||||
}
|
||||
273
internal/watcher/dispatcher.go
Normal file
273
internal/watcher/dispatcher.go
Normal file
@@ -0,0 +1,273 @@
|
||||
// dispatcher.go implements auth update dispatching and queue management.
|
||||
// It batches, deduplicates, and delivers auth updates to registered consumers.
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"reflect"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/synthesizer"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func (w *Watcher) setAuthUpdateQueue(queue chan<- AuthUpdate) {
|
||||
w.clientsMutex.Lock()
|
||||
defer w.clientsMutex.Unlock()
|
||||
w.authQueue = queue
|
||||
if w.dispatchCond == nil {
|
||||
w.dispatchCond = sync.NewCond(&w.dispatchMu)
|
||||
}
|
||||
if w.dispatchCancel != nil {
|
||||
w.dispatchCancel()
|
||||
if w.dispatchCond != nil {
|
||||
w.dispatchMu.Lock()
|
||||
w.dispatchCond.Broadcast()
|
||||
w.dispatchMu.Unlock()
|
||||
}
|
||||
w.dispatchCancel = nil
|
||||
}
|
||||
if queue != nil {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
w.dispatchCancel = cancel
|
||||
go w.dispatchLoop(ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) dispatchRuntimeAuthUpdate(update AuthUpdate) bool {
|
||||
if w == nil {
|
||||
return false
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
if w.runtimeAuths == nil {
|
||||
w.runtimeAuths = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
switch update.Action {
|
||||
case AuthUpdateActionAdd, AuthUpdateActionModify:
|
||||
if update.Auth != nil && update.Auth.ID != "" {
|
||||
clone := update.Auth.Clone()
|
||||
w.runtimeAuths[clone.ID] = clone
|
||||
if w.currentAuths == nil {
|
||||
w.currentAuths = make(map[string]*coreauth.Auth)
|
||||
}
|
||||
w.currentAuths[clone.ID] = clone.Clone()
|
||||
}
|
||||
case AuthUpdateActionDelete:
|
||||
id := update.ID
|
||||
if id == "" && update.Auth != nil {
|
||||
id = update.Auth.ID
|
||||
}
|
||||
if id != "" {
|
||||
delete(w.runtimeAuths, id)
|
||||
if w.currentAuths != nil {
|
||||
delete(w.currentAuths, id)
|
||||
}
|
||||
}
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
if w.getAuthQueue() == nil {
|
||||
return false
|
||||
}
|
||||
w.dispatchAuthUpdates([]AuthUpdate{update})
|
||||
return true
|
||||
}
|
||||
|
||||
func (w *Watcher) refreshAuthState(force bool) {
|
||||
auths := w.SnapshotCoreAuths()
|
||||
w.clientsMutex.Lock()
|
||||
if len(w.runtimeAuths) > 0 {
|
||||
for _, a := range w.runtimeAuths {
|
||||
if a != nil {
|
||||
auths = append(auths, a.Clone())
|
||||
}
|
||||
}
|
||||
}
|
||||
updates := w.prepareAuthUpdatesLocked(auths, force)
|
||||
w.clientsMutex.Unlock()
|
||||
w.dispatchAuthUpdates(updates)
|
||||
}
|
||||
|
||||
func (w *Watcher) prepareAuthUpdatesLocked(auths []*coreauth.Auth, force bool) []AuthUpdate {
|
||||
newState := make(map[string]*coreauth.Auth, len(auths))
|
||||
for _, auth := range auths {
|
||||
if auth == nil || auth.ID == "" {
|
||||
continue
|
||||
}
|
||||
newState[auth.ID] = auth.Clone()
|
||||
}
|
||||
if w.currentAuths == nil {
|
||||
w.currentAuths = newState
|
||||
if w.authQueue == nil {
|
||||
return nil
|
||||
}
|
||||
updates := make([]AuthUpdate, 0, len(newState))
|
||||
for id, auth := range newState {
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
|
||||
}
|
||||
return updates
|
||||
}
|
||||
if w.authQueue == nil {
|
||||
w.currentAuths = newState
|
||||
return nil
|
||||
}
|
||||
updates := make([]AuthUpdate, 0, len(newState)+len(w.currentAuths))
|
||||
for id, auth := range newState {
|
||||
if existing, ok := w.currentAuths[id]; !ok {
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionAdd, ID: id, Auth: auth.Clone()})
|
||||
} else if force || !authEqual(existing, auth) {
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionModify, ID: id, Auth: auth.Clone()})
|
||||
}
|
||||
}
|
||||
for id := range w.currentAuths {
|
||||
if _, ok := newState[id]; !ok {
|
||||
updates = append(updates, AuthUpdate{Action: AuthUpdateActionDelete, ID: id})
|
||||
}
|
||||
}
|
||||
w.currentAuths = newState
|
||||
return updates
|
||||
}
|
||||
|
||||
func (w *Watcher) dispatchAuthUpdates(updates []AuthUpdate) {
|
||||
if len(updates) == 0 {
|
||||
return
|
||||
}
|
||||
queue := w.getAuthQueue()
|
||||
if queue == nil {
|
||||
return
|
||||
}
|
||||
baseTS := time.Now().UnixNano()
|
||||
w.dispatchMu.Lock()
|
||||
if w.pendingUpdates == nil {
|
||||
w.pendingUpdates = make(map[string]AuthUpdate)
|
||||
}
|
||||
for idx, update := range updates {
|
||||
key := w.authUpdateKey(update, baseTS+int64(idx))
|
||||
if _, exists := w.pendingUpdates[key]; !exists {
|
||||
w.pendingOrder = append(w.pendingOrder, key)
|
||||
}
|
||||
w.pendingUpdates[key] = update
|
||||
}
|
||||
if w.dispatchCond != nil {
|
||||
w.dispatchCond.Signal()
|
||||
}
|
||||
w.dispatchMu.Unlock()
|
||||
}
|
||||
|
||||
func (w *Watcher) authUpdateKey(update AuthUpdate, ts int64) string {
|
||||
if update.ID != "" {
|
||||
return update.ID
|
||||
}
|
||||
return fmt.Sprintf("%s:%d", update.Action, ts)
|
||||
}
|
||||
|
||||
func (w *Watcher) dispatchLoop(ctx context.Context) {
|
||||
for {
|
||||
batch, ok := w.nextPendingBatch(ctx)
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
queue := w.getAuthQueue()
|
||||
if queue == nil {
|
||||
if ctx.Err() != nil {
|
||||
return
|
||||
}
|
||||
time.Sleep(10 * time.Millisecond)
|
||||
continue
|
||||
}
|
||||
for _, update := range batch {
|
||||
select {
|
||||
case queue <- update:
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) nextPendingBatch(ctx context.Context) ([]AuthUpdate, bool) {
|
||||
w.dispatchMu.Lock()
|
||||
defer w.dispatchMu.Unlock()
|
||||
for len(w.pendingOrder) == 0 {
|
||||
if ctx.Err() != nil {
|
||||
return nil, false
|
||||
}
|
||||
w.dispatchCond.Wait()
|
||||
if ctx.Err() != nil {
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
batch := make([]AuthUpdate, 0, len(w.pendingOrder))
|
||||
for _, key := range w.pendingOrder {
|
||||
batch = append(batch, w.pendingUpdates[key])
|
||||
delete(w.pendingUpdates, key)
|
||||
}
|
||||
w.pendingOrder = w.pendingOrder[:0]
|
||||
return batch, true
|
||||
}
|
||||
|
||||
func (w *Watcher) getAuthQueue() chan<- AuthUpdate {
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
return w.authQueue
|
||||
}
|
||||
|
||||
func (w *Watcher) stopDispatch() {
|
||||
if w.dispatchCancel != nil {
|
||||
w.dispatchCancel()
|
||||
w.dispatchCancel = nil
|
||||
}
|
||||
w.dispatchMu.Lock()
|
||||
w.pendingOrder = nil
|
||||
w.pendingUpdates = nil
|
||||
if w.dispatchCond != nil {
|
||||
w.dispatchCond.Broadcast()
|
||||
}
|
||||
w.dispatchMu.Unlock()
|
||||
w.clientsMutex.Lock()
|
||||
w.authQueue = nil
|
||||
w.clientsMutex.Unlock()
|
||||
}
|
||||
|
||||
func authEqual(a, b *coreauth.Auth) bool {
|
||||
return reflect.DeepEqual(normalizeAuth(a), normalizeAuth(b))
|
||||
}
|
||||
|
||||
func normalizeAuth(a *coreauth.Auth) *coreauth.Auth {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
clone := a.Clone()
|
||||
clone.CreatedAt = time.Time{}
|
||||
clone.UpdatedAt = time.Time{}
|
||||
clone.LastRefreshedAt = time.Time{}
|
||||
clone.NextRefreshAfter = time.Time{}
|
||||
clone.Runtime = nil
|
||||
clone.Quota.NextRecoverAt = time.Time{}
|
||||
return clone
|
||||
}
|
||||
|
||||
func snapshotCoreAuths(cfg *config.Config, authDir string) []*coreauth.Auth {
|
||||
ctx := &synthesizer.SynthesisContext{
|
||||
Config: cfg,
|
||||
AuthDir: authDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: synthesizer.NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
var out []*coreauth.Auth
|
||||
|
||||
configSynth := synthesizer.NewConfigSynthesizer()
|
||||
if auths, err := configSynth.Synthesize(ctx); err == nil {
|
||||
out = append(out, auths...)
|
||||
}
|
||||
|
||||
fileSynth := synthesizer.NewFileSynthesizer()
|
||||
if auths, err := fileSynth.Synthesize(ctx); err == nil {
|
||||
out = append(out, auths...)
|
||||
}
|
||||
|
||||
return out
|
||||
}
|
||||
194
internal/watcher/events.go
Normal file
194
internal/watcher/events.go
Normal file
@@ -0,0 +1,194 @@
|
||||
// events.go implements fsnotify event handling for config and auth file changes.
|
||||
// It normalizes paths, debounces noisy events, and triggers reload/update logic.
|
||||
package watcher
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func matchProvider(provider string, targets []string) (string, bool) {
|
||||
p := strings.ToLower(strings.TrimSpace(provider))
|
||||
for _, t := range targets {
|
||||
if strings.EqualFold(p, strings.TrimSpace(t)) {
|
||||
return p, true
|
||||
}
|
||||
}
|
||||
return p, false
|
||||
}
|
||||
|
||||
func (w *Watcher) start(ctx context.Context) error {
|
||||
if errAddConfig := w.watcher.Add(w.configPath); errAddConfig != nil {
|
||||
log.Errorf("failed to watch config file %s: %v", w.configPath, errAddConfig)
|
||||
return errAddConfig
|
||||
}
|
||||
log.Debugf("watching config file: %s", w.configPath)
|
||||
|
||||
if errAddAuthDir := w.watcher.Add(w.authDir); errAddAuthDir != nil {
|
||||
log.Errorf("failed to watch auth directory %s: %v", w.authDir, errAddAuthDir)
|
||||
return errAddAuthDir
|
||||
}
|
||||
log.Debugf("watching auth directory: %s", w.authDir)
|
||||
|
||||
go w.processEvents(ctx)
|
||||
|
||||
w.reloadClients(true, nil, false)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *Watcher) processEvents(ctx context.Context) {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case event, ok := <-w.watcher.Events:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
w.handleEvent(event)
|
||||
case errWatch, ok := <-w.watcher.Errors:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
log.Errorf("file watcher error: %v", errWatch)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
// Filter only relevant events: config file or auth-dir JSON files.
|
||||
configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename
|
||||
normalizedName := w.normalizeAuthPath(event.Name)
|
||||
normalizedConfigPath := w.normalizeAuthPath(w.configPath)
|
||||
normalizedAuthDir := w.normalizeAuthPath(w.authDir)
|
||||
isConfigEvent := normalizedName == normalizedConfigPath && event.Op&configOps != 0
|
||||
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
||||
isAuthJSON := strings.HasPrefix(normalizedName, normalizedAuthDir) && strings.HasSuffix(normalizedName, ".json") && event.Op&authOps != 0
|
||||
if !isConfigEvent && !isAuthJSON {
|
||||
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
|
||||
return
|
||||
}
|
||||
|
||||
now := time.Now()
|
||||
log.Debugf("file system event detected: %s %s", event.Op.String(), event.Name)
|
||||
|
||||
// Handle config file changes
|
||||
if isConfigEvent {
|
||||
log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
|
||||
w.scheduleConfigReload()
|
||||
return
|
||||
}
|
||||
|
||||
// Handle auth directory changes incrementally (.json only)
|
||||
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
||||
if w.shouldDebounceRemove(normalizedName, now) {
|
||||
log.Debugf("debouncing remove event for %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
|
||||
// Wait briefly; if the path exists again, treat as an update instead of removal.
|
||||
time.Sleep(replaceCheckDelay)
|
||||
if _, statErr := os.Stat(event.Name); statErr == nil {
|
||||
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||
w.addOrUpdateClient(event.Name)
|
||||
return
|
||||
}
|
||||
if !w.isKnownAuthFile(event.Name) {
|
||||
log.Debugf("ignoring remove for unknown auth file: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||
w.removeClient(event.Name)
|
||||
return
|
||||
}
|
||||
if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
|
||||
if unchanged, errSame := w.authFileUnchanged(event.Name); errSame == nil && unchanged {
|
||||
log.Debugf("auth file unchanged (hash match), skipping reload: %s", filepath.Base(event.Name))
|
||||
return
|
||||
}
|
||||
log.Infof("auth file changed (%s): %s, processing incrementally", event.Op.String(), filepath.Base(event.Name))
|
||||
w.addOrUpdateClient(event.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) authFileUnchanged(path string) (bool, error) {
|
||||
data, errRead := os.ReadFile(path)
|
||||
if errRead != nil {
|
||||
return false, errRead
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return false, nil
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
curHash := hex.EncodeToString(sum[:])
|
||||
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
w.clientsMutex.RLock()
|
||||
prevHash, ok := w.lastAuthHashes[normalized]
|
||||
w.clientsMutex.RUnlock()
|
||||
if ok && prevHash == curHash {
|
||||
return true, nil
|
||||
}
|
||||
return false, nil
|
||||
}
|
||||
|
||||
func (w *Watcher) isKnownAuthFile(path string) bool {
|
||||
normalized := w.normalizeAuthPath(path)
|
||||
w.clientsMutex.RLock()
|
||||
defer w.clientsMutex.RUnlock()
|
||||
_, ok := w.lastAuthHashes[normalized]
|
||||
return ok
|
||||
}
|
||||
|
||||
func (w *Watcher) normalizeAuthPath(path string) string {
|
||||
trimmed := strings.TrimSpace(path)
|
||||
if trimmed == "" {
|
||||
return ""
|
||||
}
|
||||
cleaned := filepath.Clean(trimmed)
|
||||
if runtime.GOOS == "windows" {
|
||||
cleaned = strings.TrimPrefix(cleaned, `\\?\`)
|
||||
cleaned = strings.ToLower(cleaned)
|
||||
}
|
||||
return cleaned
|
||||
}
|
||||
|
||||
func (w *Watcher) shouldDebounceRemove(normalizedPath string, now time.Time) bool {
|
||||
if normalizedPath == "" {
|
||||
return false
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
if w.lastRemoveTimes == nil {
|
||||
w.lastRemoveTimes = make(map[string]time.Time)
|
||||
}
|
||||
if last, ok := w.lastRemoveTimes[normalizedPath]; ok {
|
||||
if now.Sub(last) < authRemoveDebounceWindow {
|
||||
w.clientsMutex.Unlock()
|
||||
return true
|
||||
}
|
||||
}
|
||||
w.lastRemoveTimes[normalizedPath] = now
|
||||
if len(w.lastRemoveTimes) > 128 {
|
||||
cutoff := now.Add(-2 * authRemoveDebounceWindow)
|
||||
for p, t := range w.lastRemoveTimes {
|
||||
if t.Before(cutoff) {
|
||||
delete(w.lastRemoveTimes, p)
|
||||
}
|
||||
}
|
||||
}
|
||||
w.clientsMutex.Unlock()
|
||||
return false
|
||||
}
|
||||
294
internal/watcher/synthesizer/config.go
Normal file
294
internal/watcher/synthesizer/config.go
Normal file
@@ -0,0 +1,294 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// ConfigSynthesizer generates Auth entries from configuration API keys.
|
||||
// It handles Gemini, Claude, Codex, OpenAI-compat, and Vertex-compat providers.
|
||||
type ConfigSynthesizer struct{}
|
||||
|
||||
// NewConfigSynthesizer creates a new ConfigSynthesizer instance.
|
||||
func NewConfigSynthesizer() *ConfigSynthesizer {
|
||||
return &ConfigSynthesizer{}
|
||||
}
|
||||
|
||||
// Synthesize generates Auth entries from config API keys.
|
||||
func (s *ConfigSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error) {
|
||||
out := make([]*coreauth.Auth, 0, 32)
|
||||
if ctx == nil || ctx.Config == nil {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// Gemini API Keys
|
||||
out = append(out, s.synthesizeGeminiKeys(ctx)...)
|
||||
// Claude API Keys
|
||||
out = append(out, s.synthesizeClaudeKeys(ctx)...)
|
||||
// Codex API Keys
|
||||
out = append(out, s.synthesizeCodexKeys(ctx)...)
|
||||
// OpenAI-compat
|
||||
out = append(out, s.synthesizeOpenAICompat(ctx)...)
|
||||
// Vertex-compat
|
||||
out = append(out, s.synthesizeVertexCompat(ctx)...)
|
||||
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// synthesizeGeminiKeys creates Auth entries for Gemini API keys.
|
||||
func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.GeminiKey))
|
||||
for i := range cfg.GeminiKey {
|
||||
entry := cfg.GeminiKey[i]
|
||||
key := strings.TrimSpace(entry.APIKey)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(entry.Prefix)
|
||||
base := strings.TrimSpace(entry.BaseURL)
|
||||
proxyURL := strings.TrimSpace(entry.ProxyURL)
|
||||
id, token := idGen.Next("gemini:apikey", key, base)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:gemini[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
addConfigHeadersToAttrs(entry.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: "gemini",
|
||||
Label: "gemini-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, entry.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeClaudeKeys creates Auth entries for Claude API keys.
|
||||
func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.ClaudeKey))
|
||||
for i := range cfg.ClaudeKey {
|
||||
ck := cfg.ClaudeKey[i]
|
||||
key := strings.TrimSpace(ck.APIKey)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(ck.Prefix)
|
||||
base := strings.TrimSpace(ck.BaseURL)
|
||||
id, token := idGen.Next("claude:apikey", key, base)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:claude[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if base != "" {
|
||||
attrs["base_url"] = base
|
||||
}
|
||||
if hash := diff.ComputeClaudeModelsHash(ck.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(ck.Headers, attrs)
|
||||
proxyURL := strings.TrimSpace(ck.ProxyURL)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: "claude",
|
||||
Label: "claude-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeCodexKeys creates Auth entries for Codex API keys.
|
||||
func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.CodexKey))
|
||||
for i := range cfg.CodexKey {
|
||||
ck := cfg.CodexKey[i]
|
||||
key := strings.TrimSpace(ck.APIKey)
|
||||
if key == "" {
|
||||
continue
|
||||
}
|
||||
prefix := strings.TrimSpace(ck.Prefix)
|
||||
id, token := idGen.Next("codex:apikey", key, ck.BaseURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:codex[%s]", token),
|
||||
"api_key": key,
|
||||
}
|
||||
if ck.BaseURL != "" {
|
||||
attrs["base_url"] = ck.BaseURL
|
||||
}
|
||||
addConfigHeadersToAttrs(ck.Headers, attrs)
|
||||
proxyURL := strings.TrimSpace(ck.ProxyURL)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: "codex",
|
||||
Label: "codex-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, ck.ExcludedModels, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeOpenAICompat creates Auth entries for OpenAI-compatible providers.
|
||||
func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0)
|
||||
for i := range cfg.OpenAICompatibility {
|
||||
compat := &cfg.OpenAICompatibility[i]
|
||||
prefix := strings.TrimSpace(compat.Prefix)
|
||||
providerName := strings.ToLower(strings.TrimSpace(compat.Name))
|
||||
if providerName == "" {
|
||||
providerName = "openai-compatibility"
|
||||
}
|
||||
base := strings.TrimSpace(compat.BaseURL)
|
||||
|
||||
// Handle new APIKeyEntries format (preferred)
|
||||
createdEntries := 0
|
||||
for j := range compat.APIKeyEntries {
|
||||
entry := &compat.APIKeyEntries[j]
|
||||
key := strings.TrimSpace(entry.APIKey)
|
||||
proxyURL := strings.TrimSpace(entry.ProxyURL)
|
||||
idKind := fmt.Sprintf("openai-compatibility:%s", providerName)
|
||||
id, token := idGen.Next(idKind, key, base, proxyURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:%s[%s]", providerName, token),
|
||||
"base_url": base,
|
||||
"compat_name": compat.Name,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: compat.Name,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
out = append(out, a)
|
||||
createdEntries++
|
||||
}
|
||||
// Fallback: create entry without API key if no APIKeyEntries
|
||||
if createdEntries == 0 {
|
||||
idKind := fmt.Sprintf("openai-compatibility:%s", providerName)
|
||||
id, token := idGen.Next(idKind, base)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:%s[%s]", providerName, token),
|
||||
"base_url": base,
|
||||
"compat_name": compat.Name,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: compat.Name,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
out = append(out, a)
|
||||
}
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// synthesizeVertexCompat creates Auth entries for Vertex-compatible providers.
|
||||
func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*coreauth.Auth {
|
||||
cfg := ctx.Config
|
||||
now := ctx.Now
|
||||
idGen := ctx.IDGenerator
|
||||
|
||||
out := make([]*coreauth.Auth, 0, len(cfg.VertexCompatAPIKey))
|
||||
for i := range cfg.VertexCompatAPIKey {
|
||||
compat := &cfg.VertexCompatAPIKey[i]
|
||||
providerName := "vertex"
|
||||
base := strings.TrimSpace(compat.BaseURL)
|
||||
|
||||
key := strings.TrimSpace(compat.APIKey)
|
||||
prefix := strings.TrimSpace(compat.Prefix)
|
||||
proxyURL := strings.TrimSpace(compat.ProxyURL)
|
||||
idKind := "vertex:apikey"
|
||||
id, token := idGen.Next(idKind, key, base, proxyURL)
|
||||
attrs := map[string]string{
|
||||
"source": fmt.Sprintf("config:vertex-apikey[%s]", token),
|
||||
"base_url": base,
|
||||
"provider_key": providerName,
|
||||
}
|
||||
if key != "" {
|
||||
attrs["api_key"] = key
|
||||
}
|
||||
if hash := diff.ComputeVertexCompatModelsHash(compat.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
addConfigHeadersToAttrs(compat.Headers, attrs)
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: providerName,
|
||||
Label: "vertex-apikey",
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
ProxyURL: proxyURL,
|
||||
Attributes: attrs,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, nil, "apikey")
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
613
internal/watcher/synthesizer/config_test.go
Normal file
613
internal/watcher/synthesizer/config_test.go
Normal file
@@ -0,0 +1,613 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestNewConfigSynthesizer(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
if synth == nil {
|
||||
t.Fatal("expected non-nil synthesizer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_Synthesize_NilContext(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
auths, err := synth.Synthesize(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_Synthesize_NilConfig(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: nil,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_GeminiKeys(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
geminiKeys []config.GeminiKey
|
||||
wantLen int
|
||||
validate func(*testing.T, []*coreauth.Auth)
|
||||
}{
|
||||
{
|
||||
name: "single gemini key",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{APIKey: "test-key-123", Prefix: "team-a"},
|
||||
},
|
||||
wantLen: 1,
|
||||
validate: func(t *testing.T, auths []*coreauth.Auth) {
|
||||
if auths[0].Provider != "gemini" {
|
||||
t.Errorf("expected provider gemini, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Prefix != "team-a" {
|
||||
t.Errorf("expected prefix team-a, got %s", auths[0].Prefix)
|
||||
}
|
||||
if auths[0].Label != "gemini-apikey" {
|
||||
t.Errorf("expected label gemini-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Attributes["api_key"] != "test-key-123" {
|
||||
t.Errorf("expected api_key test-key-123, got %s", auths[0].Attributes["api_key"])
|
||||
}
|
||||
if auths[0].Status != coreauth.StatusActive {
|
||||
t.Errorf("expected status active, got %s", auths[0].Status)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "gemini key with base url and proxy",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{
|
||||
APIKey: "api-key",
|
||||
BaseURL: "https://custom.api.com",
|
||||
ProxyURL: "http://proxy.local:8080",
|
||||
Prefix: "custom",
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
validate: func(t *testing.T, auths []*coreauth.Auth) {
|
||||
if auths[0].Attributes["base_url"] != "https://custom.api.com" {
|
||||
t.Errorf("expected base_url https://custom.api.com, got %s", auths[0].Attributes["base_url"])
|
||||
}
|
||||
if auths[0].ProxyURL != "http://proxy.local:8080" {
|
||||
t.Errorf("expected proxy_url http://proxy.local:8080, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "gemini key with headers",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{
|
||||
APIKey: "api-key",
|
||||
Headers: map[string]string{"X-Custom": "value"},
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
validate: func(t *testing.T, auths []*coreauth.Auth) {
|
||||
if auths[0].Attributes["header:X-Custom"] != "value" {
|
||||
t.Errorf("expected header:X-Custom=value, got %s", auths[0].Attributes["header:X-Custom"])
|
||||
}
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty api key skipped",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{APIKey: ""},
|
||||
{APIKey: " "},
|
||||
{APIKey: "valid-key"},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "multiple gemini keys",
|
||||
geminiKeys: []config.GeminiKey{
|
||||
{APIKey: "key-1", Prefix: "a"},
|
||||
{APIKey: "key-2", Prefix: "b"},
|
||||
{APIKey: "key-3", Prefix: "c"},
|
||||
},
|
||||
wantLen: 3,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
GeminiKey: tt.geminiKeys,
|
||||
},
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != tt.wantLen {
|
||||
t.Fatalf("expected %d auths, got %d", tt.wantLen, len(auths))
|
||||
}
|
||||
|
||||
if tt.validate != nil && len(auths) > 0 {
|
||||
tt.validate(t, auths)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_ClaudeKeys(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{
|
||||
APIKey: "sk-ant-api-xxx",
|
||||
Prefix: "main",
|
||||
BaseURL: "https://api.anthropic.com",
|
||||
Models: []config.ClaudeModel{
|
||||
{Name: "claude-3-opus"},
|
||||
{Name: "claude-3-sonnet"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "claude" {
|
||||
t.Errorf("expected provider claude, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "claude-apikey" {
|
||||
t.Errorf("expected label claude-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Prefix != "main" {
|
||||
t.Errorf("expected prefix main, got %s", auths[0].Prefix)
|
||||
}
|
||||
if auths[0].Attributes["api_key"] != "sk-ant-api-xxx" {
|
||||
t.Errorf("expected api_key sk-ant-api-xxx, got %s", auths[0].Attributes["api_key"])
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in attributes")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_ClaudeKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: ""}, // empty, should be skipped
|
||||
{APIKey: " "}, // whitespace, should be skipped
|
||||
{APIKey: "valid-key", Headers: map[string]string{"X-Custom": "value"}},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth (empty keys skipped), got %d", len(auths))
|
||||
}
|
||||
if auths[0].Attributes["header:X-Custom"] != "value" {
|
||||
t.Errorf("expected header:X-Custom=value, got %s", auths[0].Attributes["header:X-Custom"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_CodexKeys(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
CodexKey: []config.CodexKey{
|
||||
{
|
||||
APIKey: "codex-key-123",
|
||||
Prefix: "dev",
|
||||
BaseURL: "https://api.openai.com",
|
||||
ProxyURL: "http://proxy.local",
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "codex" {
|
||||
t.Errorf("expected provider codex, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "codex-apikey" {
|
||||
t.Errorf("expected label codex-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: ""}, // empty, should be skipped
|
||||
{APIKey: " "}, // whitespace, should be skipped
|
||||
{APIKey: "valid-key", Headers: map[string]string{"Authorization": "Bearer xyz"}},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth (empty keys skipped), got %d", len(auths))
|
||||
}
|
||||
if auths[0].Attributes["header:Authorization"] != "Bearer xyz" {
|
||||
t.Errorf("expected header:Authorization=Bearer xyz, got %s", auths[0].Attributes["header:Authorization"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_OpenAICompat(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
compat []config.OpenAICompatibility
|
||||
wantLen int
|
||||
}{
|
||||
{
|
||||
name: "with APIKeyEntries",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "CustomProvider",
|
||||
BaseURL: "https://custom.api.com",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-1"},
|
||||
{APIKey: "key-2"},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "empty APIKeyEntries included (legacy)",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "EmptyKeys",
|
||||
BaseURL: "https://empty.api.com",
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: ""},
|
||||
{APIKey: " "},
|
||||
},
|
||||
},
|
||||
},
|
||||
wantLen: 2,
|
||||
},
|
||||
{
|
||||
name: "without APIKeyEntries (fallback)",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "NoKeyProvider",
|
||||
BaseURL: "https://no-key.api.com",
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
{
|
||||
name: "empty name defaults",
|
||||
compat: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "",
|
||||
BaseURL: "https://default.api.com",
|
||||
},
|
||||
},
|
||||
wantLen: 1,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
OpenAICompatibility: tt.compat,
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != tt.wantLen {
|
||||
t.Fatalf("expected %d auths, got %d", tt.wantLen, len(auths))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_VertexCompat(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{
|
||||
APIKey: "vertex-key-123",
|
||||
BaseURL: "https://vertex.googleapis.com",
|
||||
Prefix: "vertex-prod",
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "vertex" {
|
||||
t.Errorf("expected provider vertex, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "vertex-apikey" {
|
||||
t.Errorf("expected label vertex-apikey, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Prefix != "vertex-prod" {
|
||||
t.Errorf("expected prefix vertex-prod, got %s", auths[0].Prefix)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_VertexCompat_SkipsEmptyAndHeaders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "", BaseURL: "https://vertex.api"}, // empty key creates auth without api_key attr
|
||||
{APIKey: " ", BaseURL: "https://vertex.api"}, // whitespace key creates auth without api_key attr
|
||||
{APIKey: "valid-key", BaseURL: "https://vertex.api", Headers: map[string]string{"X-Vertex": "test"}},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Vertex compat doesn't skip empty keys - it creates auths without api_key attribute
|
||||
if len(auths) != 3 {
|
||||
t.Fatalf("expected 3 auths, got %d", len(auths))
|
||||
}
|
||||
// First two should not have api_key attribute
|
||||
if _, ok := auths[0].Attributes["api_key"]; ok {
|
||||
t.Error("expected first auth to not have api_key attribute")
|
||||
}
|
||||
if _, ok := auths[1].Attributes["api_key"]; ok {
|
||||
t.Error("expected second auth to not have api_key attribute")
|
||||
}
|
||||
// Third should have headers
|
||||
if auths[2].Attributes["header:X-Vertex"] != "test" {
|
||||
t.Errorf("expected header:X-Vertex=test, got %s", auths[2].Attributes["header:X-Vertex"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_OpenAICompat_WithModelsHash(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "TestProvider",
|
||||
BaseURL: "https://test.api.com",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "model-a"},
|
||||
{Name: "model-b"},
|
||||
},
|
||||
APIKeyEntries: []config.OpenAICompatibilityAPIKey{
|
||||
{APIKey: "key-with-models"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in attributes")
|
||||
}
|
||||
if auths[0].Attributes["api_key"] != "key-with-models" {
|
||||
t.Errorf("expected api_key key-with-models, got %s", auths[0].Attributes["api_key"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_OpenAICompat_FallbackWithModels(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{
|
||||
Name: "NoKeyWithModels",
|
||||
BaseURL: "https://nokey.api.com",
|
||||
Models: []config.OpenAICompatibilityModel{
|
||||
{Name: "model-x"},
|
||||
},
|
||||
Headers: map[string]string{"X-API": "header-value"},
|
||||
// No APIKeyEntries - should use fallback path
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in fallback path")
|
||||
}
|
||||
if auths[0].Attributes["header:X-API"] != "header-value" {
|
||||
t.Errorf("expected header:X-API=header-value, got %s", auths[0].Attributes["header:X-API"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_VertexCompat_WithModels(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{
|
||||
APIKey: "vertex-key",
|
||||
BaseURL: "https://vertex.api",
|
||||
Models: []config.VertexCompatModel{
|
||||
{Name: "gemini-pro", Alias: "pro"},
|
||||
{Name: "gemini-ultra", Alias: "ultra"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if _, ok := auths[0].Attributes["models_hash"]; !ok {
|
||||
t.Error("expected models_hash in vertex auth with models")
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_IDStability(t *testing.T) {
|
||||
cfg := &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "stable-key", Prefix: "test"},
|
||||
},
|
||||
}
|
||||
|
||||
// Generate IDs twice with fresh generators
|
||||
synth1 := NewConfigSynthesizer()
|
||||
ctx1 := &SynthesisContext{
|
||||
Config: cfg,
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths1, _ := synth1.Synthesize(ctx1)
|
||||
|
||||
synth2 := NewConfigSynthesizer()
|
||||
ctx2 := &SynthesisContext{
|
||||
Config: cfg,
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths2, _ := synth2.Synthesize(ctx2)
|
||||
|
||||
if auths1[0].ID != auths2[0].ID {
|
||||
t.Errorf("same config should produce same ID: got %q and %q", auths1[0].ID, auths2[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_AllProviders(t *testing.T) {
|
||||
synth := NewConfigSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{
|
||||
GeminiKey: []config.GeminiKey{
|
||||
{APIKey: "gemini-key"},
|
||||
},
|
||||
ClaudeKey: []config.ClaudeKey{
|
||||
{APIKey: "claude-key"},
|
||||
},
|
||||
CodexKey: []config.CodexKey{
|
||||
{APIKey: "codex-key"},
|
||||
},
|
||||
OpenAICompatibility: []config.OpenAICompatibility{
|
||||
{Name: "compat", BaseURL: "https://compat.api"},
|
||||
},
|
||||
VertexCompatAPIKey: []config.VertexCompatKey{
|
||||
{APIKey: "vertex-key", BaseURL: "https://vertex.api"},
|
||||
},
|
||||
},
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 5 {
|
||||
t.Fatalf("expected 5 auths, got %d", len(auths))
|
||||
}
|
||||
|
||||
providers := make(map[string]bool)
|
||||
for _, a := range auths {
|
||||
providers[a.Provider] = true
|
||||
}
|
||||
|
||||
expected := []string{"gemini", "claude", "codex", "compat", "vertex"}
|
||||
for _, p := range expected {
|
||||
if !providers[p] {
|
||||
t.Errorf("expected provider %s not found", p)
|
||||
}
|
||||
}
|
||||
}
|
||||
19
internal/watcher/synthesizer/context.go
Normal file
19
internal/watcher/synthesizer/context.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
)
|
||||
|
||||
// SynthesisContext provides the context needed for auth synthesis.
|
||||
type SynthesisContext struct {
|
||||
// Config is the current configuration
|
||||
Config *config.Config
|
||||
// AuthDir is the directory containing auth files
|
||||
AuthDir string
|
||||
// Now is the current time for timestamps
|
||||
Now time.Time
|
||||
// IDGenerator generates stable IDs for auth entries
|
||||
IDGenerator *StableIDGenerator
|
||||
}
|
||||
224
internal/watcher/synthesizer/file.go
Normal file
224
internal/watcher/synthesizer/file.go
Normal file
@@ -0,0 +1,224 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// FileSynthesizer generates Auth entries from OAuth JSON files.
|
||||
// It handles file-based authentication and Gemini virtual auth generation.
|
||||
type FileSynthesizer struct{}
|
||||
|
||||
// NewFileSynthesizer creates a new FileSynthesizer instance.
|
||||
func NewFileSynthesizer() *FileSynthesizer {
|
||||
return &FileSynthesizer{}
|
||||
}
|
||||
|
||||
// Synthesize generates Auth entries from auth files in the auth directory.
|
||||
func (s *FileSynthesizer) Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error) {
|
||||
out := make([]*coreauth.Auth, 0, 16)
|
||||
if ctx == nil || ctx.AuthDir == "" {
|
||||
return out, nil
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(ctx.AuthDir)
|
||||
if err != nil {
|
||||
// Not an error if directory doesn't exist
|
||||
return out, nil
|
||||
}
|
||||
|
||||
now := ctx.Now
|
||||
cfg := ctx.Config
|
||||
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
name := e.Name()
|
||||
if !strings.HasSuffix(strings.ToLower(name), ".json") {
|
||||
continue
|
||||
}
|
||||
full := filepath.Join(ctx.AuthDir, name)
|
||||
data, errRead := os.ReadFile(full)
|
||||
if errRead != nil || len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
var metadata map[string]any
|
||||
if errUnmarshal := json.Unmarshal(data, &metadata); errUnmarshal != nil {
|
||||
continue
|
||||
}
|
||||
t, _ := metadata["type"].(string)
|
||||
if t == "" {
|
||||
continue
|
||||
}
|
||||
provider := strings.ToLower(t)
|
||||
if provider == "gemini" {
|
||||
provider = "gemini-cli"
|
||||
}
|
||||
label := provider
|
||||
if email, _ := metadata["email"].(string); email != "" {
|
||||
label = email
|
||||
}
|
||||
// Use relative path under authDir as ID to stay consistent with the file-based token store
|
||||
id := full
|
||||
if rel, errRel := filepath.Rel(ctx.AuthDir, full); errRel == nil && rel != "" {
|
||||
id = rel
|
||||
}
|
||||
|
||||
proxyURL := ""
|
||||
if p, ok := metadata["proxy_url"].(string); ok {
|
||||
proxyURL = p
|
||||
}
|
||||
|
||||
prefix := ""
|
||||
if rawPrefix, ok := metadata["prefix"].(string); ok {
|
||||
trimmed := strings.TrimSpace(rawPrefix)
|
||||
trimmed = strings.Trim(trimmed, "/")
|
||||
if trimmed != "" && !strings.Contains(trimmed, "/") {
|
||||
prefix = trimmed
|
||||
}
|
||||
}
|
||||
|
||||
a := &coreauth.Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
Label: label,
|
||||
Prefix: prefix,
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: map[string]string{
|
||||
"source": full,
|
||||
"path": full,
|
||||
},
|
||||
ProxyURL: proxyURL,
|
||||
Metadata: metadata,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
ApplyAuthExcludedModelsMeta(a, cfg, nil, "oauth")
|
||||
if provider == "gemini-cli" {
|
||||
if virtuals := SynthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||
for _, v := range virtuals {
|
||||
ApplyAuthExcludedModelsMeta(v, cfg, nil, "oauth")
|
||||
}
|
||||
out = append(out, a)
|
||||
out = append(out, virtuals...)
|
||||
continue
|
||||
}
|
||||
}
|
||||
out = append(out, a)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// SynthesizeGeminiVirtualAuths creates virtual Auth entries for multi-project Gemini credentials.
|
||||
// It disables the primary auth and creates one virtual auth per project.
|
||||
func SynthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {
|
||||
if primary == nil || metadata == nil {
|
||||
return nil
|
||||
}
|
||||
projects := splitGeminiProjectIDs(metadata)
|
||||
if len(projects) <= 1 {
|
||||
return nil
|
||||
}
|
||||
email, _ := metadata["email"].(string)
|
||||
shared := geminicli.NewSharedCredential(primary.ID, email, metadata, projects)
|
||||
primary.Disabled = true
|
||||
primary.Status = coreauth.StatusDisabled
|
||||
primary.Runtime = shared
|
||||
if primary.Attributes == nil {
|
||||
primary.Attributes = make(map[string]string)
|
||||
}
|
||||
primary.Attributes["gemini_virtual_primary"] = "true"
|
||||
primary.Attributes["virtual_children"] = strings.Join(projects, ",")
|
||||
source := primary.Attributes["source"]
|
||||
authPath := primary.Attributes["path"]
|
||||
originalProvider := primary.Provider
|
||||
if originalProvider == "" {
|
||||
originalProvider = "gemini-cli"
|
||||
}
|
||||
label := primary.Label
|
||||
if label == "" {
|
||||
label = originalProvider
|
||||
}
|
||||
virtuals := make([]*coreauth.Auth, 0, len(projects))
|
||||
for _, projectID := range projects {
|
||||
attrs := map[string]string{
|
||||
"runtime_only": "true",
|
||||
"gemini_virtual_parent": primary.ID,
|
||||
"gemini_virtual_project": projectID,
|
||||
}
|
||||
if source != "" {
|
||||
attrs["source"] = source
|
||||
}
|
||||
if authPath != "" {
|
||||
attrs["path"] = authPath
|
||||
}
|
||||
metadataCopy := map[string]any{
|
||||
"email": email,
|
||||
"project_id": projectID,
|
||||
"virtual": true,
|
||||
"virtual_parent_id": primary.ID,
|
||||
"type": metadata["type"],
|
||||
}
|
||||
proxy := strings.TrimSpace(primary.ProxyURL)
|
||||
if proxy != "" {
|
||||
metadataCopy["proxy_url"] = proxy
|
||||
}
|
||||
virtual := &coreauth.Auth{
|
||||
ID: buildGeminiVirtualID(primary.ID, projectID),
|
||||
Provider: originalProvider,
|
||||
Label: fmt.Sprintf("%s [%s]", label, projectID),
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: attrs,
|
||||
Metadata: metadataCopy,
|
||||
ProxyURL: primary.ProxyURL,
|
||||
Prefix: primary.Prefix,
|
||||
CreatedAt: primary.CreatedAt,
|
||||
UpdatedAt: primary.UpdatedAt,
|
||||
Runtime: geminicli.NewVirtualCredential(projectID, shared),
|
||||
}
|
||||
virtuals = append(virtuals, virtual)
|
||||
}
|
||||
return virtuals
|
||||
}
|
||||
|
||||
// splitGeminiProjectIDs extracts and deduplicates project IDs from metadata.
|
||||
func splitGeminiProjectIDs(metadata map[string]any) []string {
|
||||
raw, _ := metadata["project_id"].(string)
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(trimmed, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, part := range parts {
|
||||
id := strings.TrimSpace(part)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
result = append(result, id)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// buildGeminiVirtualID constructs a virtual auth ID from base ID and project ID.
|
||||
func buildGeminiVirtualID(baseID, projectID string) string {
|
||||
project := strings.TrimSpace(projectID)
|
||||
if project == "" {
|
||||
project = "project"
|
||||
}
|
||||
replacer := strings.NewReplacer("/", "_", "\\", "_", " ", "_")
|
||||
return fmt.Sprintf("%s::%s", baseID, replacer.Replace(project))
|
||||
}
|
||||
612
internal/watcher/synthesizer/file_test.go
Normal file
612
internal/watcher/synthesizer/file_test.go
Normal file
@@ -0,0 +1,612 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestNewFileSynthesizer(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
if synth == nil {
|
||||
t.Fatal("expected non-nil synthesizer")
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_NilContext(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
auths, err := synth.Synthesize(nil)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_EmptyAuthDir(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: "",
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_NonExistentDir(t *testing.T) {
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: "/non/existent/path",
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 0 {
|
||||
t.Fatalf("expected empty auths, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_ValidAuthFile(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a valid auth file
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"email": "test@example.com",
|
||||
"proxy_url": "http://proxy.local",
|
||||
"prefix": "test-prefix",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "claude-auth.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Date(2025, 1, 1, 0, 0, 0, 0, time.UTC),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "claude" {
|
||||
t.Errorf("expected provider claude, got %s", auths[0].Provider)
|
||||
}
|
||||
if auths[0].Label != "test@example.com" {
|
||||
t.Errorf("expected label test@example.com, got %s", auths[0].Label)
|
||||
}
|
||||
if auths[0].Prefix != "test-prefix" {
|
||||
t.Errorf("expected prefix test-prefix, got %s", auths[0].Prefix)
|
||||
}
|
||||
if auths[0].ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
if auths[0].Status != coreauth.StatusActive {
|
||||
t.Errorf("expected status active, got %s", auths[0].Status)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_GeminiProviderMapping(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Gemini type should be mapped to gemini-cli
|
||||
authData := map[string]any{
|
||||
"type": "gemini",
|
||||
"email": "gemini@example.com",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "gemini-auth.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
if auths[0].Provider != "gemini-cli" {
|
||||
t.Errorf("gemini should be mapped to gemini-cli, got %s", auths[0].Provider)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_SkipsInvalidFiles(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create various invalid files
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "not-json.txt"), []byte("text content"), 0644)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "invalid.json"), []byte("not valid json"), 0644)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "empty.json"), []byte(""), 0644)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "no-type.json"), []byte(`{"email": "test@example.com"}`), 0644)
|
||||
|
||||
// Create one valid file
|
||||
validData, _ := json.Marshal(map[string]any{"type": "claude", "email": "valid@example.com"})
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "valid.json"), validData, 0644)
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("only valid auth file should be processed, got %d", len(auths))
|
||||
}
|
||||
if auths[0].Label != "valid@example.com" {
|
||||
t.Errorf("expected label valid@example.com, got %s", auths[0].Label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_SkipsDirectories(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a subdirectory with a json file inside
|
||||
subDir := filepath.Join(tempDir, "subdir.json")
|
||||
err := os.Mkdir(subDir, 0755)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to create subdir: %v", err)
|
||||
}
|
||||
|
||||
// Create a valid file in root
|
||||
validData, _ := json.Marshal(map[string]any{"type": "claude"})
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "valid.json"), validData, 0644)
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_RelativeID(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
authData := map[string]any{"type": "claude"}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "my-auth.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
|
||||
// ID should be relative path
|
||||
if auths[0].ID != "my-auth.json" {
|
||||
t.Errorf("expected ID my-auth.json, got %s", auths[0].ID)
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_PrefixValidation(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
prefix string
|
||||
wantPrefix string
|
||||
}{
|
||||
{"valid prefix", "myprefix", "myprefix"},
|
||||
{"prefix with slashes trimmed", "/myprefix/", "myprefix"},
|
||||
{"prefix with spaces trimmed", " myprefix ", "myprefix"},
|
||||
{"prefix with internal slash rejected", "my/prefix", ""},
|
||||
{"empty prefix", "", ""},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
authData := map[string]any{
|
||||
"type": "claude",
|
||||
"prefix": tt.prefix,
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
_ = os.WriteFile(filepath.Join(tempDir, "auth.json"), data, 0644)
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
if len(auths) != 1 {
|
||||
t.Fatalf("expected 1 auth, got %d", len(auths))
|
||||
}
|
||||
if auths[0].Prefix != tt.wantPrefix {
|
||||
t.Errorf("expected prefix %q, got %q", tt.wantPrefix, auths[0].Prefix)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NilInputs(t *testing.T) {
|
||||
now := time.Now()
|
||||
|
||||
if SynthesizeGeminiVirtualAuths(nil, nil, now) != nil {
|
||||
t.Error("expected nil for nil primary")
|
||||
}
|
||||
if SynthesizeGeminiVirtualAuths(&coreauth.Auth{}, nil, now) != nil {
|
||||
t.Error("expected nil for nil metadata")
|
||||
}
|
||||
if SynthesizeGeminiVirtualAuths(nil, map[string]any{}, now) != nil {
|
||||
t.Error("expected nil for nil primary with metadata")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_SingleProject(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "test-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "single-project",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
if virtuals != nil {
|
||||
t.Error("single project should not create virtuals")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_MultiProject(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
Prefix: "test-prefix",
|
||||
ProxyURL: "http://proxy.local",
|
||||
Attributes: map[string]string{
|
||||
"source": "test-source",
|
||||
"path": "/path/to/auth",
|
||||
},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "project-a, project-b, project-c",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 3 {
|
||||
t.Fatalf("expected 3 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
|
||||
// Check primary is disabled
|
||||
if !primary.Disabled {
|
||||
t.Error("expected primary to be disabled")
|
||||
}
|
||||
if primary.Status != coreauth.StatusDisabled {
|
||||
t.Errorf("expected primary status disabled, got %s", primary.Status)
|
||||
}
|
||||
if primary.Attributes["gemini_virtual_primary"] != "true" {
|
||||
t.Error("expected gemini_virtual_primary=true")
|
||||
}
|
||||
if !strings.Contains(primary.Attributes["virtual_children"], "project-a") {
|
||||
t.Error("expected virtual_children to contain project-a")
|
||||
}
|
||||
|
||||
// Check virtuals
|
||||
projectIDs := []string{"project-a", "project-b", "project-c"}
|
||||
for i, v := range virtuals {
|
||||
if v.Provider != "gemini-cli" {
|
||||
t.Errorf("expected provider gemini-cli, got %s", v.Provider)
|
||||
}
|
||||
if v.Status != coreauth.StatusActive {
|
||||
t.Errorf("expected status active, got %s", v.Status)
|
||||
}
|
||||
if v.Prefix != "test-prefix" {
|
||||
t.Errorf("expected prefix test-prefix, got %s", v.Prefix)
|
||||
}
|
||||
if v.ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", v.ProxyURL)
|
||||
}
|
||||
if v.Attributes["runtime_only"] != "true" {
|
||||
t.Error("expected runtime_only=true")
|
||||
}
|
||||
if v.Attributes["gemini_virtual_parent"] != "primary-id" {
|
||||
t.Errorf("expected gemini_virtual_parent=primary-id, got %s", v.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
if v.Attributes["gemini_virtual_project"] != projectIDs[i] {
|
||||
t.Errorf("expected gemini_virtual_project=%s, got %s", projectIDs[i], v.Attributes["gemini_virtual_project"])
|
||||
}
|
||||
if !strings.Contains(v.Label, "["+projectIDs[i]+"]") {
|
||||
t.Errorf("expected label to contain [%s], got %s", projectIDs[i], v.Label)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_EmptyProviderAndLabel(t *testing.T) {
|
||||
now := time.Now()
|
||||
// Test with empty Provider and Label to cover fallback branches
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "", // empty provider - should default to gemini-cli
|
||||
Label: "", // empty label - should default to provider
|
||||
Attributes: map[string]string{},
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "proj-a, proj-b",
|
||||
"email": "user@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
|
||||
// Check that empty provider defaults to gemini-cli
|
||||
if virtuals[0].Provider != "gemini-cli" {
|
||||
t.Errorf("expected provider gemini-cli (default), got %s", virtuals[0].Provider)
|
||||
}
|
||||
// Check that empty label defaults to provider
|
||||
if !strings.Contains(virtuals[0].Label, "gemini-cli") {
|
||||
t.Errorf("expected label to contain gemini-cli, got %s", virtuals[0].Label)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSynthesizeGeminiVirtualAuths_NilPrimaryAttributes(t *testing.T) {
|
||||
now := time.Now()
|
||||
primary := &coreauth.Auth{
|
||||
ID: "primary-id",
|
||||
Provider: "gemini-cli",
|
||||
Label: "test@example.com",
|
||||
Attributes: nil, // nil attributes
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"project_id": "proj-a, proj-b",
|
||||
"email": "test@example.com",
|
||||
"type": "gemini",
|
||||
}
|
||||
|
||||
virtuals := SynthesizeGeminiVirtualAuths(primary, metadata, now)
|
||||
|
||||
if len(virtuals) != 2 {
|
||||
t.Fatalf("expected 2 virtuals, got %d", len(virtuals))
|
||||
}
|
||||
// Nil attributes should be initialized
|
||||
if primary.Attributes == nil {
|
||||
t.Error("expected primary.Attributes to be initialized")
|
||||
}
|
||||
if primary.Attributes["gemini_virtual_primary"] != "true" {
|
||||
t.Error("expected gemini_virtual_primary=true")
|
||||
}
|
||||
}
|
||||
|
||||
func TestSplitGeminiProjectIDs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
metadata map[string]any
|
||||
want []string
|
||||
}{
|
||||
{
|
||||
name: "single project",
|
||||
metadata: map[string]any{"project_id": "proj-a"},
|
||||
want: []string{"proj-a"},
|
||||
},
|
||||
{
|
||||
name: "multiple projects",
|
||||
metadata: map[string]any{"project_id": "proj-a, proj-b, proj-c"},
|
||||
want: []string{"proj-a", "proj-b", "proj-c"},
|
||||
},
|
||||
{
|
||||
name: "with duplicates",
|
||||
metadata: map[string]any{"project_id": "proj-a, proj-b, proj-a"},
|
||||
want: []string{"proj-a", "proj-b"},
|
||||
},
|
||||
{
|
||||
name: "with empty parts",
|
||||
metadata: map[string]any{"project_id": "proj-a, , proj-b, "},
|
||||
want: []string{"proj-a", "proj-b"},
|
||||
},
|
||||
{
|
||||
name: "empty project_id",
|
||||
metadata: map[string]any{"project_id": ""},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "no project_id",
|
||||
metadata: map[string]any{},
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "whitespace only",
|
||||
metadata: map[string]any{"project_id": " "},
|
||||
want: nil,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := splitGeminiProjectIDs(tt.metadata)
|
||||
if len(got) != len(tt.want) {
|
||||
t.Fatalf("expected %v, got %v", tt.want, got)
|
||||
}
|
||||
for i := range got {
|
||||
if got[i] != tt.want[i] {
|
||||
t.Errorf("expected %v, got %v", tt.want, got)
|
||||
break
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestFileSynthesizer_Synthesize_MultiProjectGemini(t *testing.T) {
|
||||
tempDir := t.TempDir()
|
||||
|
||||
// Create a gemini auth file with multiple projects
|
||||
authData := map[string]any{
|
||||
"type": "gemini",
|
||||
"email": "multi@example.com",
|
||||
"project_id": "project-a, project-b, project-c",
|
||||
}
|
||||
data, _ := json.Marshal(authData)
|
||||
err := os.WriteFile(filepath.Join(tempDir, "gemini-multi.json"), data, 0644)
|
||||
if err != nil {
|
||||
t.Fatalf("failed to write auth file: %v", err)
|
||||
}
|
||||
|
||||
synth := NewFileSynthesizer()
|
||||
ctx := &SynthesisContext{
|
||||
Config: &config.Config{},
|
||||
AuthDir: tempDir,
|
||||
Now: time.Now(),
|
||||
IDGenerator: NewStableIDGenerator(),
|
||||
}
|
||||
|
||||
auths, err := synth.Synthesize(ctx)
|
||||
if err != nil {
|
||||
t.Fatalf("unexpected error: %v", err)
|
||||
}
|
||||
// Should have 4 auths: 1 primary (disabled) + 3 virtuals
|
||||
if len(auths) != 4 {
|
||||
t.Fatalf("expected 4 auths (1 primary + 3 virtuals), got %d", len(auths))
|
||||
}
|
||||
|
||||
// First auth should be the primary (disabled)
|
||||
primary := auths[0]
|
||||
if !primary.Disabled {
|
||||
t.Error("expected primary to be disabled")
|
||||
}
|
||||
if primary.Status != coreauth.StatusDisabled {
|
||||
t.Errorf("expected primary status disabled, got %s", primary.Status)
|
||||
}
|
||||
|
||||
// Remaining auths should be virtuals
|
||||
for i := 1; i < 4; i++ {
|
||||
v := auths[i]
|
||||
if v.Status != coreauth.StatusActive {
|
||||
t.Errorf("expected virtual %d to be active, got %s", i, v.Status)
|
||||
}
|
||||
if v.Attributes["gemini_virtual_parent"] != primary.ID {
|
||||
t.Errorf("expected virtual %d parent to be %s, got %s", i, primary.ID, v.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildGeminiVirtualID(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
baseID string
|
||||
projectID string
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "basic",
|
||||
baseID: "auth.json",
|
||||
projectID: "my-project",
|
||||
want: "auth.json::my-project",
|
||||
},
|
||||
{
|
||||
name: "with slashes",
|
||||
baseID: "path/to/auth.json",
|
||||
projectID: "project/with/slashes",
|
||||
want: "path/to/auth.json::project_with_slashes",
|
||||
},
|
||||
{
|
||||
name: "with spaces",
|
||||
baseID: "auth.json",
|
||||
projectID: "my project",
|
||||
want: "auth.json::my_project",
|
||||
},
|
||||
{
|
||||
name: "empty project",
|
||||
baseID: "auth.json",
|
||||
projectID: "",
|
||||
want: "auth.json::project",
|
||||
},
|
||||
{
|
||||
name: "whitespace project",
|
||||
baseID: "auth.json",
|
||||
projectID: " ",
|
||||
want: "auth.json::project",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
got := buildGeminiVirtualID(tt.baseID, tt.projectID)
|
||||
if got != tt.want {
|
||||
t.Errorf("expected %q, got %q", tt.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
110
internal/watcher/synthesizer/helpers.go
Normal file
110
internal/watcher/synthesizer/helpers.go
Normal file
@@ -0,0 +1,110 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// StableIDGenerator generates stable, deterministic IDs for auth entries.
|
||||
// It uses SHA256 hashing with collision handling via counters.
|
||||
// It is not safe for concurrent use.
|
||||
type StableIDGenerator struct {
|
||||
counters map[string]int
|
||||
}
|
||||
|
||||
// NewStableIDGenerator creates a new StableIDGenerator instance.
|
||||
func NewStableIDGenerator() *StableIDGenerator {
|
||||
return &StableIDGenerator{counters: make(map[string]int)}
|
||||
}
|
||||
|
||||
// Next generates a stable ID based on the kind and parts.
|
||||
// Returns the full ID (kind:hash) and the short hash portion.
|
||||
func (g *StableIDGenerator) Next(kind string, parts ...string) (string, string) {
|
||||
if g == nil {
|
||||
return kind + ":000000000000", "000000000000"
|
||||
}
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(kind))
|
||||
for _, part := range parts {
|
||||
trimmed := strings.TrimSpace(part)
|
||||
hasher.Write([]byte{0})
|
||||
hasher.Write([]byte(trimmed))
|
||||
}
|
||||
digest := hex.EncodeToString(hasher.Sum(nil))
|
||||
if len(digest) < 12 {
|
||||
digest = fmt.Sprintf("%012s", digest)
|
||||
}
|
||||
short := digest[:12]
|
||||
key := kind + ":" + short
|
||||
index := g.counters[key]
|
||||
g.counters[key] = index + 1
|
||||
if index > 0 {
|
||||
short = fmt.Sprintf("%s-%d", short, index)
|
||||
}
|
||||
return fmt.Sprintf("%s:%s", kind, short), short
|
||||
}
|
||||
|
||||
// ApplyAuthExcludedModelsMeta applies excluded models metadata to an auth entry.
|
||||
// It computes a hash of excluded models and sets the auth_kind attribute.
|
||||
func ApplyAuthExcludedModelsMeta(auth *coreauth.Auth, cfg *config.Config, perKey []string, authKind string) {
|
||||
if auth == nil || cfg == nil {
|
||||
return
|
||||
}
|
||||
authKindKey := strings.ToLower(strings.TrimSpace(authKind))
|
||||
seen := make(map[string]struct{})
|
||||
add := func(list []string) {
|
||||
for _, entry := range list {
|
||||
if trimmed := strings.TrimSpace(entry); trimmed != "" {
|
||||
key := strings.ToLower(trimmed)
|
||||
if _, exists := seen[key]; exists {
|
||||
continue
|
||||
}
|
||||
seen[key] = struct{}{}
|
||||
}
|
||||
}
|
||||
}
|
||||
if authKindKey == "apikey" {
|
||||
add(perKey)
|
||||
} else if cfg.OAuthExcludedModels != nil {
|
||||
providerKey := strings.ToLower(strings.TrimSpace(auth.Provider))
|
||||
add(cfg.OAuthExcludedModels[providerKey])
|
||||
}
|
||||
combined := make([]string, 0, len(seen))
|
||||
for k := range seen {
|
||||
combined = append(combined, k)
|
||||
}
|
||||
sort.Strings(combined)
|
||||
hash := diff.ComputeExcludedModelsHash(combined)
|
||||
if auth.Attributes == nil {
|
||||
auth.Attributes = make(map[string]string)
|
||||
}
|
||||
if hash != "" {
|
||||
auth.Attributes["excluded_models_hash"] = hash
|
||||
}
|
||||
if authKind != "" {
|
||||
auth.Attributes["auth_kind"] = authKind
|
||||
}
|
||||
}
|
||||
|
||||
// addConfigHeadersToAttrs adds header configuration to auth attributes.
|
||||
// Headers are prefixed with "header:" in the attributes map.
|
||||
func addConfigHeadersToAttrs(headers map[string]string, attrs map[string]string) {
|
||||
if len(headers) == 0 || attrs == nil {
|
||||
return
|
||||
}
|
||||
for hk, hv := range headers {
|
||||
key := strings.TrimSpace(hk)
|
||||
val := strings.TrimSpace(hv)
|
||||
if key == "" || val == "" {
|
||||
continue
|
||||
}
|
||||
attrs["header:"+key] = val
|
||||
}
|
||||
}
|
||||
264
internal/watcher/synthesizer/helpers_test.go
Normal file
264
internal/watcher/synthesizer/helpers_test.go
Normal file
@@ -0,0 +1,264 @@
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
func TestNewStableIDGenerator(t *testing.T) {
|
||||
gen := NewStableIDGenerator()
|
||||
if gen == nil {
|
||||
t.Fatal("expected non-nil generator")
|
||||
}
|
||||
if gen.counters == nil {
|
||||
t.Fatal("expected non-nil counters map")
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_Next(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
kind string
|
||||
parts []string
|
||||
wantPrefix string
|
||||
}{
|
||||
{
|
||||
name: "basic gemini apikey",
|
||||
kind: "gemini:apikey",
|
||||
parts: []string{"test-key", ""},
|
||||
wantPrefix: "gemini:apikey:",
|
||||
},
|
||||
{
|
||||
name: "claude with base url",
|
||||
kind: "claude:apikey",
|
||||
parts: []string{"sk-ant-xxx", "https://api.anthropic.com"},
|
||||
wantPrefix: "claude:apikey:",
|
||||
},
|
||||
{
|
||||
name: "empty parts",
|
||||
kind: "codex:apikey",
|
||||
parts: []string{},
|
||||
wantPrefix: "codex:apikey:",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
gen := NewStableIDGenerator()
|
||||
id, short := gen.Next(tt.kind, tt.parts...)
|
||||
|
||||
if !strings.Contains(id, tt.wantPrefix) {
|
||||
t.Errorf("expected id to contain %q, got %q", tt.wantPrefix, id)
|
||||
}
|
||||
if short == "" {
|
||||
t.Error("expected non-empty short id")
|
||||
}
|
||||
if len(short) != 12 {
|
||||
t.Errorf("expected short id length 12, got %d", len(short))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_Stability(t *testing.T) {
|
||||
gen1 := NewStableIDGenerator()
|
||||
gen2 := NewStableIDGenerator()
|
||||
|
||||
id1, _ := gen1.Next("gemini:apikey", "test-key", "https://api.example.com")
|
||||
id2, _ := gen2.Next("gemini:apikey", "test-key", "https://api.example.com")
|
||||
|
||||
if id1 != id2 {
|
||||
t.Errorf("same inputs should produce same ID: got %q and %q", id1, id2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_CollisionHandling(t *testing.T) {
|
||||
gen := NewStableIDGenerator()
|
||||
|
||||
id1, short1 := gen.Next("gemini:apikey", "same-key")
|
||||
id2, short2 := gen.Next("gemini:apikey", "same-key")
|
||||
|
||||
if id1 == id2 {
|
||||
t.Error("collision should be handled with suffix")
|
||||
}
|
||||
if short1 == short2 {
|
||||
t.Error("short ids should differ")
|
||||
}
|
||||
if !strings.Contains(short2, "-1") {
|
||||
t.Errorf("second short id should contain -1 suffix, got %q", short2)
|
||||
}
|
||||
}
|
||||
|
||||
func TestStableIDGenerator_NilReceiver(t *testing.T) {
|
||||
var gen *StableIDGenerator = nil
|
||||
id, short := gen.Next("test:kind", "part")
|
||||
|
||||
if id != "test:kind:000000000000" {
|
||||
t.Errorf("expected test:kind:000000000000, got %q", id)
|
||||
}
|
||||
if short != "000000000000" {
|
||||
t.Errorf("expected 000000000000, got %q", short)
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyAuthExcludedModelsMeta(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
auth *coreauth.Auth
|
||||
cfg *config.Config
|
||||
perKey []string
|
||||
authKind string
|
||||
wantHash bool
|
||||
wantKind string
|
||||
}{
|
||||
{
|
||||
name: "apikey with excluded models",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: make(map[string]string),
|
||||
},
|
||||
cfg: &config.Config{},
|
||||
perKey: []string{"model-a", "model-b"},
|
||||
authKind: "apikey",
|
||||
wantHash: true,
|
||||
wantKind: "apikey",
|
||||
},
|
||||
{
|
||||
name: "oauth with provider excluded models",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "claude",
|
||||
Attributes: make(map[string]string),
|
||||
},
|
||||
cfg: &config.Config{
|
||||
OAuthExcludedModels: map[string][]string{
|
||||
"claude": {"claude-2.0"},
|
||||
},
|
||||
},
|
||||
perKey: nil,
|
||||
authKind: "oauth",
|
||||
wantHash: true,
|
||||
wantKind: "oauth",
|
||||
},
|
||||
{
|
||||
name: "nil auth",
|
||||
auth: nil,
|
||||
cfg: &config.Config{},
|
||||
},
|
||||
{
|
||||
name: "nil config",
|
||||
auth: &coreauth.Auth{Provider: "test"},
|
||||
cfg: nil,
|
||||
authKind: "apikey",
|
||||
},
|
||||
{
|
||||
name: "nil attributes initialized",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: nil,
|
||||
},
|
||||
cfg: &config.Config{},
|
||||
perKey: []string{"model-x"},
|
||||
authKind: "apikey",
|
||||
wantHash: true,
|
||||
wantKind: "apikey",
|
||||
},
|
||||
{
|
||||
name: "apikey with duplicate excluded models",
|
||||
auth: &coreauth.Auth{
|
||||
Provider: "gemini",
|
||||
Attributes: make(map[string]string),
|
||||
},
|
||||
cfg: &config.Config{},
|
||||
perKey: []string{"model-a", "MODEL-A", "model-b", "model-a"},
|
||||
authKind: "apikey",
|
||||
wantHash: true,
|
||||
wantKind: "apikey",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
ApplyAuthExcludedModelsMeta(tt.auth, tt.cfg, tt.perKey, tt.authKind)
|
||||
|
||||
if tt.auth != nil && tt.cfg != nil {
|
||||
if tt.wantHash {
|
||||
if _, ok := tt.auth.Attributes["excluded_models_hash"]; !ok {
|
||||
t.Error("expected excluded_models_hash in attributes")
|
||||
}
|
||||
}
|
||||
if tt.wantKind != "" {
|
||||
if got := tt.auth.Attributes["auth_kind"]; got != tt.wantKind {
|
||||
t.Errorf("expected auth_kind=%s, got %s", tt.wantKind, got)
|
||||
}
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestAddConfigHeadersToAttrs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
headers map[string]string
|
||||
attrs map[string]string
|
||||
want map[string]string
|
||||
}{
|
||||
{
|
||||
name: "basic headers",
|
||||
headers: map[string]string{
|
||||
"Authorization": "Bearer token",
|
||||
"X-Custom": "value",
|
||||
},
|
||||
attrs: map[string]string{"existing": "key"},
|
||||
want: map[string]string{
|
||||
"existing": "key",
|
||||
"header:Authorization": "Bearer token",
|
||||
"header:X-Custom": "value",
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "empty headers",
|
||||
headers: map[string]string{},
|
||||
attrs: map[string]string{"existing": "key"},
|
||||
want: map[string]string{"existing": "key"},
|
||||
},
|
||||
{
|
||||
name: "nil headers",
|
||||
headers: nil,
|
||||
attrs: map[string]string{"existing": "key"},
|
||||
want: map[string]string{"existing": "key"},
|
||||
},
|
||||
{
|
||||
name: "nil attrs",
|
||||
headers: map[string]string{"key": "value"},
|
||||
attrs: nil,
|
||||
want: nil,
|
||||
},
|
||||
{
|
||||
name: "skip empty keys and values",
|
||||
headers: map[string]string{
|
||||
"": "value",
|
||||
"key": "",
|
||||
" ": "value",
|
||||
"valid": "valid-value",
|
||||
},
|
||||
attrs: make(map[string]string),
|
||||
want: map[string]string{
|
||||
"header:valid": "valid-value",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addConfigHeadersToAttrs(tt.headers, tt.attrs)
|
||||
if !reflect.DeepEqual(tt.attrs, tt.want) {
|
||||
t.Errorf("expected %v, got %v", tt.want, tt.attrs)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
16
internal/watcher/synthesizer/interface.go
Normal file
16
internal/watcher/synthesizer/interface.go
Normal file
@@ -0,0 +1,16 @@
|
||||
// Package synthesizer provides auth synthesis strategies for the watcher package.
|
||||
// It implements the Strategy pattern to support multiple auth sources:
|
||||
// - ConfigSynthesizer: generates Auth entries from config API keys
|
||||
// - FileSynthesizer: generates Auth entries from OAuth JSON files
|
||||
package synthesizer
|
||||
|
||||
import (
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// AuthSynthesizer defines the interface for generating Auth entries from various sources.
|
||||
type AuthSynthesizer interface {
|
||||
// Synthesize generates Auth entries from the given context.
|
||||
// Returns a slice of Auth pointers and any error encountered.
|
||||
Synthesize(ctx *SynthesisContext) ([]*coreauth.Auth, error)
|
||||
}
|
||||
File diff suppressed because it is too large
Load Diff
1490
internal/watcher/watcher_test.go
Normal file
1490
internal/watcher/watcher_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -84,7 +84,8 @@ func (h *GeminiAPIHandler) GeminiGetHandler(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
switch request.Action {
|
||||
action := strings.TrimPrefix(request.Action, "/")
|
||||
switch action {
|
||||
case "gemini-3-pro-preview":
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"name": "models/gemini-3-pro-preview",
|
||||
@@ -189,7 +190,7 @@ func (h *GeminiAPIHandler) GeminiHandler(c *gin.Context) {
|
||||
})
|
||||
return
|
||||
}
|
||||
action := strings.Split(request.Action, ":")
|
||||
action := strings.Split(strings.TrimPrefix(request.Action, "/"), ":")
|
||||
if len(action) != 2 {
|
||||
c.JSON(http.StatusNotFound, handlers.ErrorResponse{
|
||||
Error: handlers.ErrorDetail{
|
||||
|
||||
@@ -49,9 +49,6 @@ type BaseAPIHandler struct {
|
||||
|
||||
// Cfg holds the current application configuration.
|
||||
Cfg *config.SDKConfig
|
||||
|
||||
// OpenAICompatProviders is a list of provider names for OpenAI compatibility.
|
||||
OpenAICompatProviders []string
|
||||
}
|
||||
|
||||
// NewBaseAPIHandlers creates a new API handlers instance.
|
||||
@@ -63,11 +60,10 @@ type BaseAPIHandler struct {
|
||||
//
|
||||
// Returns:
|
||||
// - *BaseAPIHandler: A new API handlers instance
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager, openAICompatProviders []string) *BaseAPIHandler {
|
||||
func NewBaseAPIHandlers(cfg *config.SDKConfig, authManager *coreauth.Manager) *BaseAPIHandler {
|
||||
return &BaseAPIHandler{
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
OpenAICompatProviders: openAICompatProviders,
|
||||
Cfg: cfg,
|
||||
AuthManager: authManager,
|
||||
}
|
||||
}
|
||||
|
||||
@@ -342,30 +338,19 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
|
||||
providerName, extractedModelName, isDynamic := h.parseDynamicModel(resolvedModelName)
|
||||
|
||||
targetModelName := resolvedModelName
|
||||
if isDynamic {
|
||||
targetModelName = extractedModelName
|
||||
}
|
||||
|
||||
// Normalize the model name to handle dynamic thinking suffixes before determining the provider.
|
||||
normalizedModel, metadata = normalizeModelMetadata(targetModelName)
|
||||
normalizedModel, metadata = normalizeModelMetadata(resolvedModelName)
|
||||
|
||||
if isDynamic {
|
||||
providers = []string{providerName}
|
||||
} else {
|
||||
// For non-dynamic models, use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
// Use the normalizedModel to get the provider name.
|
||||
providers = util.GetProviderName(normalizedModel)
|
||||
if len(providers) == 0 && metadata != nil {
|
||||
if originalRaw, ok := metadata[util.ThinkingOriginalModelMetadataKey]; ok {
|
||||
if originalModel, okStr := originalRaw.(string); okStr {
|
||||
originalModel = strings.TrimSpace(originalModel)
|
||||
if originalModel != "" && !strings.EqualFold(originalModel, normalizedModel) {
|
||||
if altProviders := util.GetProviderName(originalModel); len(altProviders) > 0 {
|
||||
providers = altProviders
|
||||
normalizedModel = originalModel
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -383,30 +368,6 @@ func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string
|
||||
return providers, normalizedModel, metadata, nil
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) parseDynamicModel(modelName string) (providerName, model string, isDynamic bool) {
|
||||
var providerPart, modelPart string
|
||||
for _, sep := range []string{"://"} {
|
||||
if parts := strings.SplitN(modelName, sep, 2); len(parts) == 2 {
|
||||
providerPart = parts[0]
|
||||
modelPart = parts[1]
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if providerPart == "" {
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
// Check if the provider is a configured openai-compatibility provider
|
||||
for _, pName := range h.OpenAICompatProviders {
|
||||
if pName == providerPart {
|
||||
return providerPart, modelPart, true
|
||||
}
|
||||
}
|
||||
|
||||
return "", modelName, false
|
||||
}
|
||||
|
||||
func cloneBytes(src []byte) []byte {
|
||||
if len(src) == 0 {
|
||||
return nil
|
||||
|
||||
46
sdk/api/options.go
Normal file
46
sdk/api/options.go
Normal file
@@ -0,0 +1,46 @@
|
||||
// Package api exposes server option helpers for embedding CLIProxyAPI.
|
||||
//
|
||||
// It wraps internal server option types so external projects can configure the embedded
|
||||
// HTTP server without importing internal packages.
|
||||
package api
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
internalapi "github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/logging"
|
||||
)
|
||||
|
||||
// ServerOption customises HTTP server construction.
|
||||
type ServerOption = internalapi.ServerOption
|
||||
|
||||
// WithMiddleware appends additional Gin middleware during server construction.
|
||||
func WithMiddleware(mw ...gin.HandlerFunc) ServerOption { return internalapi.WithMiddleware(mw...) }
|
||||
|
||||
// WithEngineConfigurator allows callers to mutate the Gin engine prior to middleware setup.
|
||||
func WithEngineConfigurator(fn func(*gin.Engine)) ServerOption {
|
||||
return internalapi.WithEngineConfigurator(fn)
|
||||
}
|
||||
|
||||
// WithRouterConfigurator appends a callback after default routes are registered.
|
||||
func WithRouterConfigurator(fn func(*gin.Engine, *handlers.BaseAPIHandler, *config.Config)) ServerOption {
|
||||
return internalapi.WithRouterConfigurator(fn)
|
||||
}
|
||||
|
||||
// WithLocalManagementPassword stores a runtime-only management password accepted for localhost requests.
|
||||
func WithLocalManagementPassword(password string) ServerOption {
|
||||
return internalapi.WithLocalManagementPassword(password)
|
||||
}
|
||||
|
||||
// WithKeepAliveEndpoint enables a keep-alive endpoint with the provided timeout and callback.
|
||||
func WithKeepAliveEndpoint(timeout time.Duration, onTimeout func()) ServerOption {
|
||||
return internalapi.WithKeepAliveEndpoint(timeout, onTimeout)
|
||||
}
|
||||
|
||||
// WithRequestLoggerFactory customises request logger creation.
|
||||
func WithRequestLoggerFactory(factory func(*config.Config, string) logging.RequestLogger) ServerOption {
|
||||
return internalapi.WithRequestLoggerFactory(factory)
|
||||
}
|
||||
@@ -72,7 +72,9 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
|
||||
return "", fmt.Errorf("auth filestore: marshal metadata failed: %w", errMarshal)
|
||||
}
|
||||
if existing, errRead := os.ReadFile(path); errRead == nil {
|
||||
if jsonEqual(existing, raw) {
|
||||
// Use metadataEqualIgnoringTimestamps to skip writes when only timestamp fields change.
|
||||
// This prevents the token refresh loop caused by timestamp/expired/expires_in changes.
|
||||
if metadataEqualIgnoringTimestamps(existing, raw) {
|
||||
return path, nil
|
||||
}
|
||||
} else if errRead != nil && !os.IsNotExist(errRead) {
|
||||
@@ -264,6 +266,8 @@ func (s *FileTokenStore) baseDirSnapshot() string {
|
||||
return s.baseDir
|
||||
}
|
||||
|
||||
// DEPRECATED: Use metadataEqualIgnoringTimestamps for comparing auth metadata.
|
||||
// This function is kept for backward compatibility but can cause refresh loops.
|
||||
func jsonEqual(a, b []byte) bool {
|
||||
var objA any
|
||||
var objB any
|
||||
@@ -276,6 +280,32 @@ func jsonEqual(a, b []byte) bool {
|
||||
return deepEqualJSON(objA, objB)
|
||||
}
|
||||
|
||||
// metadataEqualIgnoringTimestamps compares two metadata JSON blobs,
|
||||
// ignoring fields that change on every refresh but don't affect functionality.
|
||||
// This prevents unnecessary file writes that would trigger watcher events and
|
||||
// create refresh loops.
|
||||
func metadataEqualIgnoringTimestamps(a, b []byte) bool {
|
||||
var objA, objB map[string]any
|
||||
if err := json.Unmarshal(a, &objA); err != nil {
|
||||
return false
|
||||
}
|
||||
if err := json.Unmarshal(b, &objB); err != nil {
|
||||
return false
|
||||
}
|
||||
|
||||
// Fields to ignore: these change on every refresh but don't affect authentication logic.
|
||||
// - timestamp, expired, expires_in, last_refresh: time-related fields that change on refresh
|
||||
// - access_token: Google OAuth returns a new access_token on each refresh, this is expected
|
||||
// and shouldn't trigger file writes (the new token will be fetched again when needed)
|
||||
ignoredFields := []string{"timestamp", "expired", "expires_in", "last_refresh", "access_token"}
|
||||
for _, field := range ignoredFields {
|
||||
delete(objA, field)
|
||||
delete(objB, field)
|
||||
}
|
||||
|
||||
return deepEqualJSON(objA, objB)
|
||||
}
|
||||
|
||||
func deepEqualJSON(a, b any) bool {
|
||||
switch valA := a.(type) {
|
||||
case map[string]any:
|
||||
|
||||
@@ -363,10 +363,11 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -396,8 +397,10 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
resp, errExec := executor.Execute(execCtx, auth, req, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
@@ -420,10 +423,11 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
@@ -453,8 +457,10 @@ func (m *Manager) executeCountWithProvider(ctx context.Context, provider string,
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, req, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil}
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
@@ -477,10 +483,11 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
if provider == "" {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
routeModel := req.Model
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, routeModel, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
@@ -510,14 +517,16 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, req, opts)
|
||||
execReq := req
|
||||
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: false, Error: rerr}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||
result.RetryAfter = retryAfterFromError(errStream)
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
@@ -535,18 +544,66 @@ func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: false, Error: rerr})
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||
}
|
||||
out <- chunk
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: true})
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteModelForAuth(model string, metadata map[string]any, auth *Auth) (string, map[string]any) {
|
||||
if auth == nil || model == "" {
|
||||
return model, metadata
|
||||
}
|
||||
prefix := strings.TrimSpace(auth.Prefix)
|
||||
if prefix == "" {
|
||||
return model, metadata
|
||||
}
|
||||
needle := prefix + "/"
|
||||
if !strings.HasPrefix(model, needle) {
|
||||
return model, metadata
|
||||
}
|
||||
rewritten := strings.TrimPrefix(model, needle)
|
||||
return rewritten, stripPrefixFromMetadata(metadata, needle)
|
||||
}
|
||||
|
||||
func stripPrefixFromMetadata(metadata map[string]any, needle string) map[string]any {
|
||||
if len(metadata) == 0 || needle == "" {
|
||||
return metadata
|
||||
}
|
||||
keys := []string{
|
||||
util.ThinkingOriginalModelMetadataKey,
|
||||
util.GeminiOriginalModelMetadataKey,
|
||||
}
|
||||
var out map[string]any
|
||||
for _, key := range keys {
|
||||
raw, ok := metadata[key]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
value, okStr := raw.(string)
|
||||
if !okStr || !strings.HasPrefix(value, needle) {
|
||||
continue
|
||||
}
|
||||
if out == nil {
|
||||
out = make(map[string]any, len(metadata))
|
||||
for k, v := range metadata {
|
||||
out[k] = v
|
||||
}
|
||||
}
|
||||
out[key] = strings.TrimPrefix(value, needle)
|
||||
}
|
||||
if out == nil {
|
||||
return metadata
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) normalizeProviders(providers []string) []string {
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
|
||||
@@ -19,6 +19,8 @@ type Auth struct {
|
||||
Index uint64 `json:"-"`
|
||||
// Provider is the upstream provider key (e.g. "gemini", "claude").
|
||||
Provider string `json:"provider"`
|
||||
// Prefix optionally namespaces models for routing (e.g., "teamA/gemini-3-pro-preview").
|
||||
Prefix string `json:"prefix,omitempty"`
|
||||
// FileName stores the relative or absolute path of the backing auth file.
|
||||
FileName string `json:"-"`
|
||||
// Storage holds the token persistence implementation used during login flows.
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"fmt"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
// Builder constructs a Service instance with customizable providers.
|
||||
|
||||
@@ -3,8 +3,8 @@ package cliproxy
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
// NewFileTokenClientProvider returns the default token-backed client loader.
|
||||
|
||||
@@ -13,7 +13,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/api"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
@@ -23,6 +22,7 @@ import (
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
@@ -787,7 +787,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if providerKey == "" {
|
||||
providerKey = "openai-compatibility"
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms)
|
||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, applyModelPrefixes(ms, a.Prefix, s.cfg.ForceModelPrefix))
|
||||
} else {
|
||||
// Ensure stale registrations are cleared when model list becomes empty.
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
@@ -807,7 +807,7 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if key == "" {
|
||||
key = strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
}
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, models)
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, applyModelPrefixes(models, a.Prefix, s.cfg != nil && s.cfg.ForceModelPrefix))
|
||||
return
|
||||
}
|
||||
|
||||
@@ -987,6 +987,48 @@ func applyExcludedModels(models []*ModelInfo, excluded []string) []*ModelInfo {
|
||||
return filtered
|
||||
}
|
||||
|
||||
func applyModelPrefixes(models []*ModelInfo, prefix string, forceModelPrefix bool) []*ModelInfo {
|
||||
trimmedPrefix := strings.TrimSpace(prefix)
|
||||
if trimmedPrefix == "" || len(models) == 0 {
|
||||
return models
|
||||
}
|
||||
|
||||
out := make([]*ModelInfo, 0, len(models)*2)
|
||||
seen := make(map[string]struct{}, len(models)*2)
|
||||
|
||||
addModel := func(model *ModelInfo) {
|
||||
if model == nil {
|
||||
return
|
||||
}
|
||||
id := strings.TrimSpace(model.ID)
|
||||
if id == "" {
|
||||
return
|
||||
}
|
||||
if _, exists := seen[id]; exists {
|
||||
return
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
out = append(out, model)
|
||||
}
|
||||
|
||||
for _, model := range models {
|
||||
if model == nil {
|
||||
continue
|
||||
}
|
||||
baseID := strings.TrimSpace(model.ID)
|
||||
if baseID == "" {
|
||||
continue
|
||||
}
|
||||
if !forceModelPrefix || trimmedPrefix == baseID {
|
||||
addModel(model)
|
||||
}
|
||||
clone := *model
|
||||
clone.ID = trimmedPrefix + "/" + baseID
|
||||
addModel(&clone)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// matchWildcard performs case-insensitive wildcard matching where '*' matches any substring.
|
||||
func matchWildcard(pattern, value string) bool {
|
||||
if pattern == "" {
|
||||
|
||||
@@ -6,9 +6,9 @@ package cliproxy
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
// TokenClientProvider loads clients backed by stored authentication tokens.
|
||||
|
||||
@@ -3,9 +3,9 @@ package cliproxy
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func defaultWatcherFactory(configPath, authDir string, reload func(*config.Config)) (*WatcherWrapper, error) {
|
||||
|
||||
@@ -1,82 +1,59 @@
|
||||
// Package config provides configuration management for the CLI Proxy API server.
|
||||
// It handles loading and parsing YAML configuration files, and provides structured
|
||||
// access to application settings including server port, authentication directory,
|
||||
// debug settings, proxy configuration, and API keys.
|
||||
// Package config provides the public SDK configuration API.
|
||||
//
|
||||
// It re-exports the server configuration types and helpers so external projects can
|
||||
// embed CLIProxyAPI without importing internal packages.
|
||||
package config
|
||||
|
||||
// SDKConfig represents the application's configuration, loaded from a YAML file.
|
||||
type SDKConfig struct {
|
||||
// ProxyURL is the URL of an optional proxy server to use for outbound requests.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
import internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
|
||||
// RequestLog enables or disables detailed request logging functionality.
|
||||
RequestLog bool `yaml:"request-log" json:"request-log"`
|
||||
type SDKConfig = internalconfig.SDKConfig
|
||||
type AccessConfig = internalconfig.AccessConfig
|
||||
type AccessProvider = internalconfig.AccessProvider
|
||||
|
||||
// APIKeys is a list of keys for authenticating clients to this proxy server.
|
||||
APIKeys []string `yaml:"api-keys" json:"api-keys"`
|
||||
type Config = internalconfig.Config
|
||||
|
||||
// Access holds request authentication provider configuration.
|
||||
Access AccessConfig `yaml:"auth,omitempty" json:"auth,omitempty"`
|
||||
}
|
||||
type TLSConfig = internalconfig.TLSConfig
|
||||
type RemoteManagement = internalconfig.RemoteManagement
|
||||
type AmpCode = internalconfig.AmpCode
|
||||
type PayloadConfig = internalconfig.PayloadConfig
|
||||
type PayloadRule = internalconfig.PayloadRule
|
||||
type PayloadModelRule = internalconfig.PayloadModelRule
|
||||
|
||||
// AccessConfig groups request authentication providers.
|
||||
type AccessConfig struct {
|
||||
// Providers lists configured authentication providers.
|
||||
Providers []AccessProvider `yaml:"providers,omitempty" json:"providers,omitempty"`
|
||||
}
|
||||
type GeminiKey = internalconfig.GeminiKey
|
||||
type CodexKey = internalconfig.CodexKey
|
||||
type ClaudeKey = internalconfig.ClaudeKey
|
||||
type VertexCompatKey = internalconfig.VertexCompatKey
|
||||
type VertexCompatModel = internalconfig.VertexCompatModel
|
||||
type OpenAICompatibility = internalconfig.OpenAICompatibility
|
||||
type OpenAICompatibilityAPIKey = internalconfig.OpenAICompatibilityAPIKey
|
||||
type OpenAICompatibilityModel = internalconfig.OpenAICompatibilityModel
|
||||
|
||||
// AccessProvider describes a request authentication provider entry.
|
||||
type AccessProvider struct {
|
||||
// Name is the instance identifier for the provider.
|
||||
Name string `yaml:"name" json:"name"`
|
||||
|
||||
// Type selects the provider implementation registered via the SDK.
|
||||
Type string `yaml:"type" json:"type"`
|
||||
|
||||
// SDK optionally names a third-party SDK module providing this provider.
|
||||
SDK string `yaml:"sdk,omitempty" json:"sdk,omitempty"`
|
||||
|
||||
// APIKeys lists inline keys for providers that require them.
|
||||
APIKeys []string `yaml:"api-keys,omitempty" json:"api-keys,omitempty"`
|
||||
|
||||
// Config passes provider-specific options to the implementation.
|
||||
Config map[string]any `yaml:"config,omitempty" json:"config,omitempty"`
|
||||
}
|
||||
type TLS = internalconfig.TLSConfig
|
||||
|
||||
const (
|
||||
// AccessProviderTypeConfigAPIKey is the built-in provider validating inline API keys.
|
||||
AccessProviderTypeConfigAPIKey = "config-api-key"
|
||||
|
||||
// DefaultAccessProviderName is applied when no provider name is supplied.
|
||||
DefaultAccessProviderName = "config-inline"
|
||||
AccessProviderTypeConfigAPIKey = internalconfig.AccessProviderTypeConfigAPIKey
|
||||
DefaultAccessProviderName = internalconfig.DefaultAccessProviderName
|
||||
DefaultPanelGitHubRepository = internalconfig.DefaultPanelGitHubRepository
|
||||
)
|
||||
|
||||
// ConfigAPIKeyProvider returns the first inline API key provider if present.
|
||||
func (c *SDKConfig) ConfigAPIKeyProvider() *AccessProvider {
|
||||
if c == nil {
|
||||
return nil
|
||||
}
|
||||
for i := range c.Access.Providers {
|
||||
if c.Access.Providers[i].Type == AccessProviderTypeConfigAPIKey {
|
||||
if c.Access.Providers[i].Name == "" {
|
||||
c.Access.Providers[i].Name = DefaultAccessProviderName
|
||||
}
|
||||
return &c.Access.Providers[i]
|
||||
}
|
||||
}
|
||||
return nil
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
return internalconfig.MakeInlineAPIKeyProvider(keys)
|
||||
}
|
||||
|
||||
// MakeInlineAPIKeyProvider constructs an inline API key provider configuration.
|
||||
// It returns nil when no keys are supplied.
|
||||
func MakeInlineAPIKeyProvider(keys []string) *AccessProvider {
|
||||
if len(keys) == 0 {
|
||||
return nil
|
||||
}
|
||||
provider := &AccessProvider{
|
||||
Name: DefaultAccessProviderName,
|
||||
Type: AccessProviderTypeConfigAPIKey,
|
||||
APIKeys: append([]string(nil), keys...),
|
||||
}
|
||||
return provider
|
||||
func LoadConfig(configFile string) (*Config, error) { return internalconfig.LoadConfig(configFile) }
|
||||
|
||||
func LoadConfigOptional(configFile string, optional bool) (*Config, error) {
|
||||
return internalconfig.LoadConfigOptional(configFile, optional)
|
||||
}
|
||||
|
||||
func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
return internalconfig.SaveConfigPreserveComments(configFile, cfg)
|
||||
}
|
||||
|
||||
func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []string, value string) error {
|
||||
return internalconfig.SaveConfigPreserveCommentsUpdateNestedScalar(configFile, path, value)
|
||||
}
|
||||
|
||||
func NormalizeCommentIndentation(data []byte) []byte {
|
||||
return internalconfig.NormalizeCommentIndentation(data)
|
||||
}
|
||||
|
||||
18
sdk/logging/request_logger.go
Normal file
18
sdk/logging/request_logger.go
Normal file
@@ -0,0 +1,18 @@
|
||||
// Package logging re-exports request logging primitives for SDK consumers.
|
||||
package logging
|
||||
|
||||
import internallogging "github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
|
||||
// RequestLogger defines the interface for logging HTTP requests and responses.
|
||||
type RequestLogger = internallogging.RequestLogger
|
||||
|
||||
// StreamingLogWriter handles real-time logging of streaming response chunks.
|
||||
type StreamingLogWriter = internallogging.StreamingLogWriter
|
||||
|
||||
// FileRequestLogger implements RequestLogger using file-based storage.
|
||||
type FileRequestLogger = internallogging.FileRequestLogger
|
||||
|
||||
// NewFileRequestLogger creates a new file-based request logger.
|
||||
func NewFileRequestLogger(enabled bool, logsDir string, configDir string) *FileRequestLogger {
|
||||
return internallogging.NewFileRequestLogger(enabled, logsDir, configDir)
|
||||
}
|
||||
423
test/gemini3_thinking_level_test.go
Normal file
423
test/gemini3_thinking_level_test.go
Normal file
@@ -0,0 +1,423 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
// registerGemini3Models loads Gemini 3 models into the registry for testing.
|
||||
func registerGemini3Models(t *testing.T) func() {
|
||||
t.Helper()
|
||||
reg := registry.GetGlobalRegistry()
|
||||
uid := fmt.Sprintf("gemini3-test-%d", time.Now().UnixNano())
|
||||
reg.RegisterClient(uid+"-gemini", "gemini", registry.GetGeminiModels())
|
||||
reg.RegisterClient(uid+"-aistudio", "aistudio", registry.GetAIStudioModels())
|
||||
return func() {
|
||||
reg.UnregisterClient(uid + "-gemini")
|
||||
reg.UnregisterClient(uid + "-aistudio")
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGemini3Model(t *testing.T) {
|
||||
cases := []struct {
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{"gemini-3-pro-preview", true},
|
||||
{"gemini-3-flash-preview", true},
|
||||
{"gemini_3_pro_preview", true},
|
||||
{"gemini-3-pro", true},
|
||||
{"gemini-3-flash", true},
|
||||
{"GEMINI-3-PRO-PREVIEW", true},
|
||||
{"gemini-2.5-pro", false},
|
||||
{"gemini-2.5-flash", false},
|
||||
{"gpt-5", false},
|
||||
{"claude-sonnet-4-5", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.model, func(t *testing.T) {
|
||||
got := util.IsGemini3Model(cs.model)
|
||||
if got != cs.expected {
|
||||
t.Fatalf("IsGemini3Model(%q) = %v, want %v", cs.model, got, cs.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGemini3ProModel(t *testing.T) {
|
||||
cases := []struct {
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{"gemini-3-pro-preview", true},
|
||||
{"gemini_3_pro_preview", true},
|
||||
{"gemini-3-pro", true},
|
||||
{"GEMINI-3-PRO-PREVIEW", true},
|
||||
{"gemini-3-flash-preview", false},
|
||||
{"gemini-3-flash", false},
|
||||
{"gemini-2.5-pro", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.model, func(t *testing.T) {
|
||||
got := util.IsGemini3ProModel(cs.model)
|
||||
if got != cs.expected {
|
||||
t.Fatalf("IsGemini3ProModel(%q) = %v, want %v", cs.model, got, cs.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsGemini3FlashModel(t *testing.T) {
|
||||
cases := []struct {
|
||||
model string
|
||||
expected bool
|
||||
}{
|
||||
{"gemini-3-flash-preview", true},
|
||||
{"gemini_3_flash_preview", true},
|
||||
{"gemini-3-flash", true},
|
||||
{"GEMINI-3-FLASH-PREVIEW", true},
|
||||
{"gemini-3-pro-preview", false},
|
||||
{"gemini-3-pro", false},
|
||||
{"gemini-2.5-flash", false},
|
||||
{"", false},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.model, func(t *testing.T) {
|
||||
got := util.IsGemini3FlashModel(cs.model)
|
||||
if got != cs.expected {
|
||||
t.Fatalf("IsGemini3FlashModel(%q) = %v, want %v", cs.model, got, cs.expected)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateGemini3ThinkingLevel(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
level string
|
||||
wantOK bool
|
||||
wantVal string
|
||||
}{
|
||||
// Gemini 3 Pro: supports "low", "high"
|
||||
{"pro-low", "gemini-3-pro-preview", "low", true, "low"},
|
||||
{"pro-high", "gemini-3-pro-preview", "high", true, "high"},
|
||||
{"pro-minimal-invalid", "gemini-3-pro-preview", "minimal", false, ""},
|
||||
{"pro-medium-invalid", "gemini-3-pro-preview", "medium", false, ""},
|
||||
|
||||
// Gemini 3 Flash: supports "minimal", "low", "medium", "high"
|
||||
{"flash-minimal", "gemini-3-flash-preview", "minimal", true, "minimal"},
|
||||
{"flash-low", "gemini-3-flash-preview", "low", true, "low"},
|
||||
{"flash-medium", "gemini-3-flash-preview", "medium", true, "medium"},
|
||||
{"flash-high", "gemini-3-flash-preview", "high", true, "high"},
|
||||
|
||||
// Case insensitivity
|
||||
{"flash-LOW-case", "gemini-3-flash-preview", "LOW", true, "low"},
|
||||
{"flash-High-case", "gemini-3-flash-preview", "High", true, "high"},
|
||||
{"pro-HIGH-case", "gemini-3-pro-preview", "HIGH", true, "high"},
|
||||
|
||||
// Invalid levels
|
||||
{"flash-invalid", "gemini-3-flash-preview", "xhigh", false, ""},
|
||||
{"flash-invalid-auto", "gemini-3-flash-preview", "auto", false, ""},
|
||||
{"flash-empty", "gemini-3-flash-preview", "", false, ""},
|
||||
|
||||
// Non-Gemini 3 models
|
||||
{"non-gemini3", "gemini-2.5-pro", "high", false, ""},
|
||||
{"gpt5", "gpt-5", "high", false, ""},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
got, ok := util.ValidateGemini3ThinkingLevel(cs.model, cs.level)
|
||||
if ok != cs.wantOK {
|
||||
t.Fatalf("ValidateGemini3ThinkingLevel(%q, %q) ok = %v, want %v", cs.model, cs.level, ok, cs.wantOK)
|
||||
}
|
||||
if got != cs.wantVal {
|
||||
t.Fatalf("ValidateGemini3ThinkingLevel(%q, %q) = %q, want %q", cs.model, cs.level, got, cs.wantVal)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkingBudgetToGemini3Level(t *testing.T) {
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
budget int
|
||||
wantOK bool
|
||||
wantVal string
|
||||
}{
|
||||
// Gemini 3 Pro: maps to "low" or "high"
|
||||
{"pro-dynamic", "gemini-3-pro-preview", -1, true, "high"},
|
||||
{"pro-zero", "gemini-3-pro-preview", 0, true, "low"},
|
||||
{"pro-small", "gemini-3-pro-preview", 1000, true, "low"},
|
||||
{"pro-medium", "gemini-3-pro-preview", 8000, true, "low"},
|
||||
{"pro-large", "gemini-3-pro-preview", 20000, true, "high"},
|
||||
{"pro-huge", "gemini-3-pro-preview", 50000, true, "high"},
|
||||
|
||||
// Gemini 3 Flash: maps to "minimal", "low", "medium", "high"
|
||||
{"flash-dynamic", "gemini-3-flash-preview", -1, true, "high"},
|
||||
{"flash-zero", "gemini-3-flash-preview", 0, true, "minimal"},
|
||||
{"flash-tiny", "gemini-3-flash-preview", 500, true, "minimal"},
|
||||
{"flash-small", "gemini-3-flash-preview", 1000, true, "low"},
|
||||
{"flash-medium-val", "gemini-3-flash-preview", 8000, true, "medium"},
|
||||
{"flash-large", "gemini-3-flash-preview", 20000, true, "high"},
|
||||
{"flash-huge", "gemini-3-flash-preview", 50000, true, "high"},
|
||||
|
||||
// Non-Gemini 3 models should return false
|
||||
{"gemini25-budget", "gemini-2.5-pro", 8000, false, ""},
|
||||
{"gpt5-budget", "gpt-5", 8000, false, ""},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
got, ok := util.ThinkingBudgetToGemini3Level(cs.model, cs.budget)
|
||||
if ok != cs.wantOK {
|
||||
t.Fatalf("ThinkingBudgetToGemini3Level(%q, %d) ok = %v, want %v", cs.model, cs.budget, ok, cs.wantOK)
|
||||
}
|
||||
if got != cs.wantVal {
|
||||
t.Fatalf("ThinkingBudgetToGemini3Level(%q, %d) = %q, want %q", cs.model, cs.budget, got, cs.wantVal)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyGemini3ThinkingLevelFromMetadata(t *testing.T) {
|
||||
cleanup := registerGemini3Models(t)
|
||||
defer cleanup()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
metadata map[string]any
|
||||
inputBody string
|
||||
wantLevel string
|
||||
wantInclude bool
|
||||
wantNoChange bool
|
||||
}{
|
||||
{
|
||||
name: "flash-minimal-from-suffix",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "minimal"},
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}`,
|
||||
wantLevel: "minimal",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "flash-medium-from-suffix",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "medium"},
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}`,
|
||||
wantLevel: "medium",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "pro-high-from-suffix",
|
||||
model: "gemini-3-pro-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "high"},
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}`,
|
||||
wantLevel: "high",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "no-metadata-no-change",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: nil,
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}`,
|
||||
wantNoChange: true,
|
||||
},
|
||||
{
|
||||
name: "non-gemini3-no-change",
|
||||
model: "gemini-2.5-pro",
|
||||
metadata: map[string]any{"reasoning_effort": "high"},
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"thinkingBudget":-1}}}`,
|
||||
wantNoChange: true,
|
||||
},
|
||||
{
|
||||
name: "invalid-level-no-change",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "xhigh"},
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}`,
|
||||
wantNoChange: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
input := []byte(cs.inputBody)
|
||||
result := util.ApplyGemini3ThinkingLevelFromMetadata(cs.model, cs.metadata, input)
|
||||
|
||||
if cs.wantNoChange {
|
||||
if string(result) != cs.inputBody {
|
||||
t.Fatalf("expected no change, but got: %s", string(result))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
level := gjson.GetBytes(result, "generationConfig.thinkingConfig.thinkingLevel")
|
||||
if !level.Exists() {
|
||||
t.Fatalf("thinkingLevel not set in result: %s", string(result))
|
||||
}
|
||||
if level.String() != cs.wantLevel {
|
||||
t.Fatalf("thinkingLevel = %q, want %q", level.String(), cs.wantLevel)
|
||||
}
|
||||
|
||||
include := gjson.GetBytes(result, "generationConfig.thinkingConfig.includeThoughts")
|
||||
if cs.wantInclude && (!include.Exists() || !include.Bool()) {
|
||||
t.Fatalf("includeThoughts should be true, got: %s", string(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestApplyGemini3ThinkingLevelFromMetadataCLI(t *testing.T) {
|
||||
cleanup := registerGemini3Models(t)
|
||||
defer cleanup()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
metadata map[string]any
|
||||
inputBody string
|
||||
wantLevel string
|
||||
wantInclude bool
|
||||
wantNoChange bool
|
||||
}{
|
||||
{
|
||||
name: "flash-minimal-from-suffix-cli",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "minimal"},
|
||||
inputBody: `{"request":{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}}`,
|
||||
wantLevel: "minimal",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "flash-low-from-suffix-cli",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "low"},
|
||||
inputBody: `{"request":{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}}`,
|
||||
wantLevel: "low",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "pro-low-from-suffix-cli",
|
||||
model: "gemini-3-pro-preview",
|
||||
metadata: map[string]any{"reasoning_effort": "low"},
|
||||
inputBody: `{"request":{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}}`,
|
||||
wantLevel: "low",
|
||||
wantInclude: true,
|
||||
},
|
||||
{
|
||||
name: "no-metadata-no-change-cli",
|
||||
model: "gemini-3-flash-preview",
|
||||
metadata: nil,
|
||||
inputBody: `{"request":{"generationConfig":{"thinkingConfig":{"includeThoughts":true}}}}`,
|
||||
wantNoChange: true,
|
||||
},
|
||||
{
|
||||
name: "non-gemini3-no-change-cli",
|
||||
model: "gemini-2.5-pro",
|
||||
metadata: map[string]any{"reasoning_effort": "high"},
|
||||
inputBody: `{"request":{"generationConfig":{"thinkingConfig":{"thinkingBudget":-1}}}}`,
|
||||
wantNoChange: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
input := []byte(cs.inputBody)
|
||||
result := util.ApplyGemini3ThinkingLevelFromMetadataCLI(cs.model, cs.metadata, input)
|
||||
|
||||
if cs.wantNoChange {
|
||||
if string(result) != cs.inputBody {
|
||||
t.Fatalf("expected no change, but got: %s", string(result))
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
level := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.thinkingLevel")
|
||||
if !level.Exists() {
|
||||
t.Fatalf("thinkingLevel not set in result: %s", string(result))
|
||||
}
|
||||
if level.String() != cs.wantLevel {
|
||||
t.Fatalf("thinkingLevel = %q, want %q", level.String(), cs.wantLevel)
|
||||
}
|
||||
|
||||
include := gjson.GetBytes(result, "request.generationConfig.thinkingConfig.includeThoughts")
|
||||
if cs.wantInclude && (!include.Exists() || !include.Bool()) {
|
||||
t.Fatalf("includeThoughts should be true, got: %s", string(result))
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeGeminiThinkingBudget_Gemini3Conversion(t *testing.T) {
|
||||
cleanup := registerGemini3Models(t)
|
||||
defer cleanup()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
inputBody string
|
||||
wantLevel string
|
||||
wantBudget bool // if true, expect thinkingBudget instead of thinkingLevel
|
||||
}{
|
||||
{
|
||||
name: "gemini3-flash-budget-to-level",
|
||||
model: "gemini-3-flash-preview",
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"thinkingBudget":8000}}}`,
|
||||
wantLevel: "medium",
|
||||
},
|
||||
{
|
||||
name: "gemini3-pro-budget-to-level",
|
||||
model: "gemini-3-pro-preview",
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"thinkingBudget":20000}}}`,
|
||||
wantLevel: "high",
|
||||
},
|
||||
{
|
||||
name: "gemini25-keeps-budget",
|
||||
model: "gemini-2.5-pro",
|
||||
inputBody: `{"generationConfig":{"thinkingConfig":{"thinkingBudget":8000}}}`,
|
||||
wantBudget: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
result := util.NormalizeGeminiThinkingBudget(cs.model, []byte(cs.inputBody))
|
||||
|
||||
if cs.wantBudget {
|
||||
budget := gjson.GetBytes(result, "generationConfig.thinkingConfig.thinkingBudget")
|
||||
if !budget.Exists() {
|
||||
t.Fatalf("thinkingBudget should exist for non-Gemini3 model: %s", string(result))
|
||||
}
|
||||
level := gjson.GetBytes(result, "generationConfig.thinkingConfig.thinkingLevel")
|
||||
if level.Exists() {
|
||||
t.Fatalf("thinkingLevel should not exist for non-Gemini3 model: %s", string(result))
|
||||
}
|
||||
} else {
|
||||
level := gjson.GetBytes(result, "generationConfig.thinkingConfig.thinkingLevel")
|
||||
if !level.Exists() {
|
||||
t.Fatalf("thinkingLevel should exist for Gemini3 model: %s", string(result))
|
||||
}
|
||||
if level.String() != cs.wantLevel {
|
||||
t.Fatalf("thinkingLevel = %q, want %q", level.String(), cs.wantLevel)
|
||||
}
|
||||
budget := gjson.GetBytes(result, "generationConfig.thinkingConfig.thinkingBudget")
|
||||
if budget.Exists() {
|
||||
t.Fatalf("thinkingBudget should be removed for Gemini3 model: %s", string(result))
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -295,7 +295,7 @@ func TestThinkingConversionsAcrossProtocolsAndModels(t *testing.T) {
|
||||
}
|
||||
// Check numeric budget fallback for allowCompat
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.OpenAIThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
if mapped, okMap := util.ThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
@@ -308,7 +308,7 @@ func TestThinkingConversionsAcrossProtocolsAndModels(t *testing.T) {
|
||||
effort, ok := util.ReasoningEffortFromMetadata(metadata)
|
||||
if !ok || strings.TrimSpace(effort) == "" {
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.OpenAIThinkingBudgetToEffort(normalizedModel, *budget); okMap {
|
||||
if mapped, okMap := util.ThinkingBudgetToEffort(normalizedModel, *budget); okMap {
|
||||
effort = mapped
|
||||
ok = true
|
||||
}
|
||||
@@ -336,7 +336,7 @@ func TestThinkingConversionsAcrossProtocolsAndModels(t *testing.T) {
|
||||
return false, "", true
|
||||
}
|
||||
if budget, _, _, matched := util.ThinkingFromMetadata(metadata); matched && budget != nil {
|
||||
if mapped, okMap := util.OpenAIThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
if mapped, okMap := util.ThinkingBudgetToEffort(normalizedModel, *budget); okMap && mapped != "" {
|
||||
mapped = strings.ToLower(strings.TrimSpace(mapped))
|
||||
if normalized, okLevel := util.NormalizeReasoningEffortLevel(normalizedModel, mapped); okLevel {
|
||||
return true, normalized, false
|
||||
@@ -609,7 +609,7 @@ func TestRawPayloadThinkingConversions(t *testing.T) {
|
||||
return true, normalized, false
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.OpenAIThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
if mapped, okM := util.ThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
return true, mapped, false
|
||||
}
|
||||
}
|
||||
@@ -625,7 +625,7 @@ func TestRawPayloadThinkingConversions(t *testing.T) {
|
||||
return false, "", true // invalid level
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.OpenAIThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
if mapped, okM := util.ThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
// Check if the mapped effort is valid for this model
|
||||
if _, validLevel := util.NormalizeReasoningEffortLevel(model, mapped); !validLevel {
|
||||
return true, mapped, true // expect validation error
|
||||
@@ -646,7 +646,7 @@ func TestRawPayloadThinkingConversions(t *testing.T) {
|
||||
return false, "", true
|
||||
}
|
||||
if budget, ok := cs.thinkingParam.(int); ok {
|
||||
if mapped, okM := util.OpenAIThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
if mapped, okM := util.ThinkingBudgetToEffort(model, budget); okM && mapped != "" {
|
||||
// Check if the mapped effort is valid for this model
|
||||
if _, validLevel := util.NormalizeReasoningEffortLevel(model, mapped); !validLevel {
|
||||
return true, mapped, true // expect validation error
|
||||
@@ -721,7 +721,7 @@ func TestRawPayloadThinkingConversions(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestOpenAIThinkingBudgetToEffortRanges(t *testing.T) {
|
||||
func TestThinkingBudgetToEffort(t *testing.T) {
|
||||
cleanup := registerCoreModels(t)
|
||||
defer cleanup()
|
||||
|
||||
@@ -733,7 +733,7 @@ func TestOpenAIThinkingBudgetToEffortRanges(t *testing.T) {
|
||||
ok bool
|
||||
}{
|
||||
{name: "dynamic-auto", model: "gpt-5", budget: -1, want: "auto", ok: true},
|
||||
{name: "zero-none", model: "gpt-5", budget: 0, want: "none", ok: true},
|
||||
{name: "zero-none", model: "gpt-5", budget: 0, want: "minimal", ok: true},
|
||||
{name: "low-min", model: "gpt-5", budget: 1, want: "low", ok: true},
|
||||
{name: "low-max", model: "gpt-5", budget: 1024, want: "low", ok: true},
|
||||
{name: "medium-min", model: "gpt-5", budget: 1025, want: "medium", ok: true},
|
||||
@@ -741,14 +741,14 @@ func TestOpenAIThinkingBudgetToEffortRanges(t *testing.T) {
|
||||
{name: "high-min", model: "gpt-5", budget: 8193, want: "high", ok: true},
|
||||
{name: "high-max", model: "gpt-5", budget: 24576, want: "high", ok: true},
|
||||
{name: "over-max-clamps-to-highest", model: "gpt-5", budget: 64000, want: "high", ok: true},
|
||||
{name: "over-max-xhigh-model", model: "gpt-5.2", budget: 50000, want: "xhigh", ok: true},
|
||||
{name: "over-max-xhigh-model", model: "gpt-5.2", budget: 64000, want: "xhigh", ok: true},
|
||||
{name: "negative-unsupported", model: "gpt-5", budget: -5, want: "", ok: false},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
cs := cs
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
got, ok := util.OpenAIThinkingBudgetToEffort(cs.model, cs.budget)
|
||||
got, ok := util.ThinkingBudgetToEffort(cs.model, cs.budget)
|
||||
if ok != cs.ok {
|
||||
t.Fatalf("ok mismatch for model=%s budget=%d: expect %v got %v", cs.model, cs.budget, cs.ok, ok)
|
||||
}
|
||||
@@ -758,3 +758,41 @@ func TestOpenAIThinkingBudgetToEffortRanges(t *testing.T) {
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestThinkingEffortToBudget(t *testing.T) {
|
||||
cleanup := registerCoreModels(t)
|
||||
defer cleanup()
|
||||
|
||||
cases := []struct {
|
||||
name string
|
||||
model string
|
||||
effort string
|
||||
want int
|
||||
ok bool
|
||||
}{
|
||||
{name: "none", model: "gemini-2.5-pro", effort: "none", want: 0, ok: true},
|
||||
{name: "auto", model: "gemini-2.5-pro", effort: "auto", want: -1, ok: true},
|
||||
{name: "minimal", model: "gemini-2.5-pro", effort: "minimal", want: 512, ok: true},
|
||||
{name: "low", model: "gemini-2.5-pro", effort: "low", want: 1024, ok: true},
|
||||
{name: "medium", model: "gemini-2.5-pro", effort: "medium", want: 8192, ok: true},
|
||||
{name: "high", model: "gemini-2.5-pro", effort: "high", want: 24576, ok: true},
|
||||
{name: "xhigh", model: "gemini-2.5-pro", effort: "xhigh", want: 32768, ok: true},
|
||||
{name: "empty-unsupported", model: "gemini-2.5-pro", effort: "", want: 0, ok: false},
|
||||
{name: "invalid-unsupported", model: "gemini-2.5-pro", effort: "ultra", want: 0, ok: false},
|
||||
{name: "case-insensitive", model: "gemini-2.5-pro", effort: "LOW", want: 1024, ok: true},
|
||||
{name: "case-insensitive-medium", model: "gemini-2.5-pro", effort: "MEDIUM", want: 8192, ok: true},
|
||||
}
|
||||
|
||||
for _, cs := range cases {
|
||||
cs := cs
|
||||
t.Run(cs.name, func(t *testing.T) {
|
||||
got, ok := util.ThinkingEffortToBudget(cs.model, cs.effort)
|
||||
if ok != cs.ok {
|
||||
t.Fatalf("ok mismatch for model=%s effort=%s: expect %v got %v", cs.model, cs.effort, cs.ok, ok)
|
||||
}
|
||||
if got != cs.want {
|
||||
t.Fatalf("value mismatch for model=%s effort=%s: expect %d got %d", cs.model, cs.effort, cs.want, got)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user