From ac59023abb72d7c6c6931d0854320448cb6cf816 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Tue, 23 Sep 2025 02:27:51 +0800 Subject: [PATCH] feat(executor): add `CountTokens` support across all executors - Introduced `CountTokens` method to Codex, Claude, Gemini, Qwen, OpenAI-compatible, and other executors. - Implemented `ExecuteCount` in `AuthManager` for token counting via provider round-robin. - Updated handlers to leverage `ExecuteCountWithAuthManager` for streamlined token counting. - Added fallback and error handling logic for token counting requests. --- .../api/handlers/gemini/gemini_handlers.go | 34 ++------- internal/api/handlers/handlers.go | 27 +++++++ internal/runtime/executor/claude_executor.go | 4 + internal/runtime/executor/client_executor.go | 4 + internal/runtime/executor/codex_executor.go | 5 ++ .../runtime/executor/gemini_cli_executor.go | 75 +++++++++++++++++++ internal/runtime/executor/gemini_executor.go | 4 + .../runtime/executor/gemini_web_executor.go | 4 + .../executor/openai_compat_executor.go | 4 + internal/runtime/executor/qwen_executor.go | 4 + sdk/cliproxy/auth/manager.go | 73 ++++++++++++++++++ 11 files changed, 210 insertions(+), 28 deletions(-) diff --git a/internal/api/handlers/gemini/gemini_handlers.go b/internal/api/handlers/gemini/gemini_handlers.go index 1fab54ba..3208160c 100644 --- a/internal/api/handlers/gemini/gemini_handlers.go +++ b/internal/api/handlers/gemini/gemini_handlers.go @@ -17,9 +17,6 @@ import ( . "github.com/router-for-me/CLIProxyAPI/v6/internal/constant" "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" - "github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor" - coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" - sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" ) // GeminiAPIHandler contains the handlers for Gemini API endpoints. 
@@ -226,35 +223,16 @@ func (h *GeminiAPIHandler) handleStreamGenerateContent(c *gin.Context, modelName // - rawJSON: The raw JSON request body containing the content to count func (h *GeminiAPIHandler) handleCountTokens(c *gin.Context, modelName string, rawJSON []byte) { c.Header("Content-Type", "application/json") - alt := h.GetAlt(c) cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) - defer func() { cliCancel() }() - - // Execute via AuthManager with action=countTokens - req := coreexecutor.Request{ - Model: modelName, - Payload: rawJSON, - Metadata: map[string]any{ - "action": "countTokens", - }, - } - opts := coreexecutor.Options{ - Stream: false, - Alt: alt, - OriginalRequest: rawJSON, - SourceFormat: sdktranslator.FromString(h.HandlerType()), - } - resp, err := h.AuthManager.Execute(cliCtx, []string{"gemini"}, req, opts) - if err != nil { - if msg, ok := executor.UnwrapError(err); ok { - h.WriteErrorResponse(c, msg) - return - } - h.WriteErrorResponse(c, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err}) + resp, errMsg := h.ExecuteCountWithAuthManager(cliCtx, h.HandlerType(), modelName, rawJSON, alt) + if errMsg != nil { + h.WriteErrorResponse(c, errMsg) + cliCancel(errMsg.Error) return } - _, _ = c.Writer.Write(resp.Payload) + _, _ = c.Writer.Write(resp) + cliCancel() } // handleGenerateContent handles non-streaming content generation requests for Gemini models. diff --git a/internal/api/handlers/handlers.go b/internal/api/handlers/handlers.go index 85bfbfbd..84509306 100644 --- a/internal/api/handlers/handlers.go +++ b/internal/api/handlers/handlers.go @@ -158,6 +158,33 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType return cloneBytes(resp.Payload), nil } +// ExecuteCountWithAuthManager executes a non-streaming token-count request via the core auth manager. +// This path is the only supported execution route. 
+func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { + providers := util.GetProviderName(modelName, h.Cfg) + if len(providers) == 0 { + return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} + } + req := coreexecutor.Request{ + Model: modelName, + Payload: cloneBytes(rawJSON), + } + opts := coreexecutor.Options{ + Stream: false, + Alt: alt, + OriginalRequest: cloneBytes(rawJSON), + SourceFormat: sdktranslator.FromString(handlerType), + } + resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts) + if err != nil { + if msg, ok := executor.UnwrapError(err); ok { + return nil, msg + } + return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err} + } + return cloneBytes(resp.Payload), nil +} + // ExecuteStreamWithAuthManager executes a streaming request via the core auth manager. // This path is the only supported execution route. 
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) { diff --git a/internal/runtime/executor/claude_executor.go b/internal/runtime/executor/claude_executor.go index 984acc69..53ba280c 100644 --- a/internal/runtime/executor/claude_executor.go +++ b/internal/runtime/executor/claude_executor.go @@ -142,6 +142,10 @@ func (e *ClaudeExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A return out, nil } +func (e *ClaudeExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented") +} + func (e *ClaudeExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { log.Debugf("claude executor: refresh called") if auth == nil { diff --git a/internal/runtime/executor/client_executor.go b/internal/runtime/executor/client_executor.go index 6a986383..a14d6459 100644 --- a/internal/runtime/executor/client_executor.go +++ b/internal/runtime/executor/client_executor.go @@ -100,6 +100,10 @@ func (a *ClientAdapter) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au return out, nil } +func (e *ClientAdapter) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{Payload: []byte{}}, nil +} + // Refresh delegates to the legacy client's refresh logic when available. 
func (a *ClientAdapter) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { client, _, err := resolveLegacyClient(auth) diff --git a/internal/runtime/executor/codex_executor.go b/internal/runtime/executor/codex_executor.go index 3d455f8a..3de770f2 100644 --- a/internal/runtime/executor/codex_executor.go +++ b/internal/runtime/executor/codex_executor.go @@ -4,6 +4,7 @@ import ( "bufio" "bytes" "context" + "fmt" "io" "net/http" "strings" @@ -192,6 +193,10 @@ func (e *CodexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Au return out, nil } +func (e *CodexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented") +} + func (e *CodexExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { log.Debugf("codex executor: refresh called") if auth == nil { diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index a2836053..a4b24b20 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -260,6 +260,81 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut return nil, statusErr{code: lastStatus, msg: string(lastBody)} } +func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + tokenSource, baseTokenData, err := prepareGeminiCLITokenSource(ctx, auth) + if err != nil { + return cliproxyexecutor.Response{}, err + } + + from := opts.SourceFormat + to := sdktranslator.FromString("gemini-cli") + + models := cliPreviewFallbackOrder(req.Model) + if len(models) == 0 || models[0] != req.Model { + models = append([]string{req.Model}, 
models...) + } + + httpClient := newHTTPClient(ctx, 0) + respCtx := context.WithValue(ctx, "alt", opts.Alt) + + var lastStatus int + var lastBody []byte + + for _, attemptModel := range models { + payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false) + payload = deleteJSONField(payload, "project") + payload = deleteJSONField(payload, "model") + + tok, errTok := tokenSource.Token() + if errTok != nil { + return cliproxyexecutor.Response{}, errTok + } + updateGeminiCLITokenMetadata(auth, baseTokenData, tok) + + url := fmt.Sprintf("%s/%s:%s", codeAssistEndpoint, codeAssistVersion, "countTokens") + if opts.Alt != "" { + url = url + fmt.Sprintf("?$alt=%s", opts.Alt) + } + + recordAPIRequest(ctx, e.cfg, payload) + reqHTTP, errReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(payload)) + if errReq != nil { + return cliproxyexecutor.Response{}, errReq + } + reqHTTP.Header.Set("Content-Type", "application/json") + reqHTTP.Header.Set("Authorization", "Bearer "+tok.AccessToken) + applyGeminiCLIHeaders(reqHTTP) + reqHTTP.Header.Set("Accept", "application/json") + + resp, errDo := httpClient.Do(reqHTTP) + if errDo != nil { + return cliproxyexecutor.Response{}, errDo + } + data, _ := io.ReadAll(resp.Body) + _ = resp.Body.Close() + appendAPIResponseChunk(ctx, e.cfg, data) + if resp.StatusCode >= 200 && resp.StatusCode < 300 { + var param any + translated := sdktranslator.TranslateNonStream(respCtx, to, from, attemptModel, bytes.Clone(opts.OriginalRequest), payload, data, &param) + return cliproxyexecutor.Response{Payload: []byte(translated)}, nil + } + lastStatus = resp.StatusCode + lastBody = data + if resp.StatusCode == 429 { + continue + } + break + } + + if len(lastBody) > 0 { + appendAPIResponseChunk(ctx, e.cfg, lastBody) + } + if lastStatus == 0 { + lastStatus = 429 + } + return cliproxyexecutor.Response{}, statusErr{code: lastStatus, msg: string(lastBody)} +} + func (e *GeminiCLIExecutor) Refresh(ctx 
context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { log.Debugf("gemini cli executor: refresh called") _ = ctx diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 220b318c..d096d9ce 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -167,6 +167,10 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A return out, nil } +func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented") +} + func (e *GeminiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { log.Debugf("gemini executor: refresh called") // OAuth bearer token refresh for official Gemini API. diff --git a/internal/runtime/executor/gemini_web_executor.go b/internal/runtime/executor/gemini_web_executor.go index 34bd4c02..612bb954 100644 --- a/internal/runtime/executor/gemini_web_executor.go +++ b/internal/runtime/executor/gemini_web_executor.go @@ -106,6 +106,10 @@ func (e *GeminiWebExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut return out, nil } +func (e *GeminiWebExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented") +} + func (e *GeminiWebExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { log.Debugf("gemini web executor: refresh called") state, err := e.stateFor(auth) diff --git a/internal/runtime/executor/openai_compat_executor.go b/internal/runtime/executor/openai_compat_executor.go index 81fee651..cab852f8 100644 --- 
a/internal/runtime/executor/openai_compat_executor.go +++ b/internal/runtime/executor/openai_compat_executor.go @@ -155,6 +155,10 @@ func (e *OpenAICompatExecutor) ExecuteStream(ctx context.Context, auth *cliproxy return out, nil } +func (e *OpenAICompatExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented") +} + // Refresh is a no-op for API-key based compatibility providers. func (e *OpenAICompatExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { log.Debugf("openai compat executor: refresh called") diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index 1a3a6d2d..cba45d93 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -150,6 +150,10 @@ func (e *QwenExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Aut return out, nil } +func (e *QwenExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{Payload: []byte{}}, fmt.Errorf("not implemented") +} + func (e *QwenExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { log.Debugf("qwen executor: refresh called") if auth == nil { diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index ee866ae4..0071e86a 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -27,6 +27,8 @@ type ProviderExecutor interface { ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) // Refresh attempts to refresh provider credentials and returns the updated auth state. 
Refresh(ctx context.Context, auth *Auth) (*Auth, error) + // CountTokens returns the token count for the given request. + CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) } // RefreshEvaluator allows runtime state to override refresh decisions. @@ -215,6 +217,30 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"} } +// ExecuteCount performs a non-streaming token-count execution using the configured selector and executor. +// It supports multiple providers for the same model and round-robins the starting provider per model. +func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + normalized := m.normalizeProviders(providers) + if len(normalized) == 0 { + return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"} + } + rotated := m.rotateProviders(req.Model, normalized) + defer m.advanceProviderCursor(req.Model, normalized) + + var lastErr error + for _, provider := range rotated { + resp, errExec := m.executeCountWithProvider(ctx, provider, req, opts) + if errExec == nil { + return resp, nil + } + lastErr = errExec + } + if lastErr != nil { + return cliproxyexecutor.Response{}, lastErr + } + return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"} +} + // ExecuteStream performs a streaming execution using the configured selector and executor. // It supports multiple providers for the same model and round-robins the starting provider per model. 
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { @@ -286,6 +312,53 @@ func (m *Manager) executeWithProvider(ctx context.Context, provider string, req } } +func (m *Manager) executeCountWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + if provider == "" { + return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"} + } + tried := make(map[string]struct{}) + var lastErr error + for { + auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried) + if errPick != nil { + if lastErr != nil { + return cliproxyexecutor.Response{}, lastErr + } + return cliproxyexecutor.Response{}, errPick + } + + accountType, accountInfo := auth.AccountInfo() + if accountType == "api_key" { + log.Debugf("Use API key %s for model %s", util.HideAPIKey(accountInfo), req.Model) + } else if accountType == "oauth" { + log.Debugf("Use OAuth %s for model %s", accountInfo, req.Model) + } else if accountType == "cookie" { + log.Debugf("Use Cookie %s for model %s", util.HideAPIKey(accountInfo), req.Model) + } + + tried[auth.ID] = struct{}{} + execCtx := ctx + if rt := m.roundTripperFor(auth); rt != nil { + execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt) + execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt) + } + resp, errExec := executor.CountTokens(execCtx, auth, req, opts) + result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil} + if errExec != nil { + result.Error = &Error{Message: errExec.Error()} + var se cliproxyexecutor.StatusError + if errors.As(errExec, &se) && se != nil { + result.Error.HTTPStatus = se.StatusCode() + } + m.MarkResult(execCtx, result) + lastErr = errExec + continue + } + m.MarkResult(execCtx, result) + 
return resp, nil + } +} + func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { if provider == "" { return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}