diff --git a/internal/runtime/executor/gemini_cli_executor.go b/internal/runtime/executor/gemini_cli_executor.go index 8a70b196..b4a8fc64 100644 --- a/internal/runtime/executor/gemini_cli_executor.go +++ b/internal/runtime/executor/gemini_cli_executor.go @@ -60,7 +60,11 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth from := opts.SourceFormat to := sdktranslator.FromString("gemini-cli") + budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata) basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if hasOverride { + basePayload = util.ApplyGeminiCLIThinkingConfig(basePayload, budgetOverride, includeOverride) + } basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload) action := "generateContent" @@ -149,7 +153,11 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut from := opts.SourceFormat to := sdktranslator.FromString("gemini-cli") + budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata) basePayload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + if hasOverride { + basePayload = util.ApplyGeminiCLIThinkingConfig(basePayload, budgetOverride, includeOverride) + } basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload) projectID := strings.TrimSpace(stringValue(auth.Metadata, "project_id")) @@ -292,8 +300,12 @@ func (e *GeminiCLIExecutor) CountTokens(ctx context.Context, auth *cliproxyauth. 
var lastStatus int var lastBody []byte + budgetOverride, includeOverride, hasOverride := util.GeminiThinkingFromMetadata(req.Metadata) for _, attemptModel := range models { payload := sdktranslator.TranslateRequest(from, to, attemptModel, bytes.Clone(req.Payload), false) + if hasOverride { + payload = util.ApplyGeminiCLIThinkingConfig(payload, budgetOverride, includeOverride) + } payload = deleteJSONField(payload, "project") payload = deleteJSONField(payload, "model") payload = disableGeminiThinkingConfig(payload, attemptModel) diff --git a/internal/runtime/executor/gemini_executor.go b/internal/runtime/executor/gemini_executor.go index 9073e0b6..59a8ae9e 100644 --- a/internal/runtime/executor/gemini_executor.go +++ b/internal/runtime/executor/gemini_executor.go @@ -77,6 +77,9 @@ func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, r from := opts.SourceFormat to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok { + body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) + } body = disableGeminiThinkingConfig(body, req.Model) body = fixGeminiImageAspectRatio(req.Model, body) @@ -136,6 +139,9 @@ func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.A from := opts.SourceFormat to := sdktranslator.FromString("gemini") body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true) + if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok { + body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride) + } body = disableGeminiThinkingConfig(body, req.Model) body = fixGeminiImageAspectRatio(req.Model, body) @@ -208,6 +214,9 @@ func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Aut from := opts.SourceFormat to := 
// Metadata keys under which Gemini thinking overrides travel with a request.
const (
	// GeminiThinkingBudgetMetadataKey holds the thinking budget override.
	GeminiThinkingBudgetMetadataKey = "gemini_thinking_budget"
	// GeminiIncludeThoughtsMetadataKey holds the include-thoughts override.
	GeminiIncludeThoughtsMetadataKey = "gemini_include_thoughts"
	// GeminiOriginalModelMetadataKey holds the model name as originally requested.
	GeminiOriginalModelMetadataKey = "gemini_original_model"
)

// ParseGeminiThinkingSuffix splits a thinking directive off the end of a Gemini
// model name.
//
// Only names starting with "gemini-" are considered. Matching ignores the case
// of ASCII letters. Two suffix forms are recognized:
//
//   - "<base>-nothinking": the budget is 0, or 128 when the name starts with
//     "gemini-2.5-pro", and include-thoughts is false.
//   - "<base>-thinking-<digits>": the budget is the decimal value of the
//     leading run of digits and include-thoughts is nil. Anything after that
//     run of digits is discarded together with the suffix.
//
// It returns the base model name, the optional budget and include-thoughts
// overrides, and whether a suffix was recognized. When nothing is recognized
// (including a budget that does not fit in an int) the model is returned
// unchanged with nil overrides.
func ParseGeminiThinkingSuffix(model string) (string, *int, *bool, bool) {
	const (
		modelPrefix    = "gemini-"
		proPrefix      = "gemini-2.5-pro"
		disabledSuffix = "-nothinking"
		budgetMarker   = "-thinking-"
	)

	// Offsets located in lower are applied to model, so both strings must have
	// the same byte length. strings.ToLower does not guarantee that for
	// non-ASCII or invalid UTF-8 input, which previously produced wrong slices
	// or an out-of-range panic; only ASCII letters are folded here.
	lower := lowerASCIIBytes(model)
	if !strings.HasPrefix(lower, modelPrefix) {
		return model, nil, nil, false
	}

	if strings.HasSuffix(lower, disabledSuffix) {
		budget := 0
		if strings.HasPrefix(lower, proPrefix) {
			// NOTE(review): presumably the smallest budget this model family
			// accepts — confirm against the Gemini thinking documentation.
			budget = 128
		}
		include := false
		return model[:len(model)-len(disabledSuffix)], &budget, &include, true
	}

	idx := strings.LastIndex(lower, budgetMarker)
	if idx == -1 {
		return model, nil, nil, false
	}

	// Consume the leading run of ASCII digits that follows the marker.
	digits := model[idx+len(budgetMarker):]
	end := 0
	for end < len(digits) && digits[end] >= '0' && digits[end] <= '9' {
		end++
	}
	if end == 0 {
		return model, nil, nil, false
	}

	budget, err := strconv.Atoi(digits[:end])
	if err != nil {
		// Only reachable when the digits overflow int.
		return model, nil, nil, false
	}
	return model[:idx], &budget, nil, true
}

// lowerASCIIBytes returns s with ASCII upper-case letters lowered and every
// other byte left untouched, so byte offsets into the result are also valid
// offsets into s.
func lowerASCIIBytes(s string) string {
	for i := 0; i < len(s); i++ {
		if s[i] < 'A' || s[i] > 'Z' {
			continue
		}
		// First upper-case byte found: copy once and fold the remainder.
		b := []byte(s)
		for j := i; j < len(b); j++ {
			if b[j] >= 'A' && b[j] <= 'Z' {
				b[j] += 'a' - 'A'
			}
		}
		return string(b)
	}
	return s
}
model, nil, nil, false + } + base := model[:idx] + budgetValue := value + return base, &budgetValue, nil, true +} + +func ApplyGeminiThinkingConfig(body []byte, budget *int, includeThoughts *bool) []byte { + if budget == nil && includeThoughts == nil { + return body + } + updated := body + if budget != nil { + valuePath := "generationConfig.thinkingConfig.thinkingBudget" + rewritten, err := sjson.SetBytes(updated, valuePath, *budget) + if err == nil { + updated = rewritten + } + } + if includeThoughts != nil { + valuePath := "generationConfig.thinkingConfig.include_thoughts" + rewritten, err := sjson.SetBytes(updated, valuePath, *includeThoughts) + if err == nil { + updated = rewritten + } + } + return updated +} + +func ApplyGeminiCLIThinkingConfig(body []byte, budget *int, includeThoughts *bool) []byte { + if budget == nil && includeThoughts == nil { + return body + } + updated := body + if budget != nil { + valuePath := "request.generationConfig.thinkingConfig.thinkingBudget" + rewritten, err := sjson.SetBytes(updated, valuePath, *budget) + if err == nil { + updated = rewritten + } + } + if includeThoughts != nil { + valuePath := "request.generationConfig.thinkingConfig.include_thoughts" + rewritten, err := sjson.SetBytes(updated, valuePath, *includeThoughts) + if err == nil { + updated = rewritten + } + } + return updated +} + +func GeminiThinkingFromMetadata(metadata map[string]any) (*int, *bool, bool) { + if len(metadata) == 0 { + return nil, nil, false + } + var ( + budgetPtr *int + includePtr *bool + matched bool + ) + if rawBudget, ok := metadata[GeminiThinkingBudgetMetadataKey]; ok { + switch v := rawBudget.(type) { + case int: + budget := v + budgetPtr = &budget + matched = true + case int32: + budget := int(v) + budgetPtr = &budget + matched = true + case int64: + budget := int(v) + budgetPtr = &budget + matched = true + case float64: + budget := int(v) + budgetPtr = &budget + matched = true + case json.Number: + if val, err := v.Int64(); err == nil { + 
budget := int(val) + budgetPtr = &budget + matched = true + } + } + } + if rawInclude, ok := metadata[GeminiIncludeThoughtsMetadataKey]; ok { + switch v := rawInclude.(type) { + case bool: + include := v + includePtr = &include + matched = true + case string: + if parsed, err := strconv.ParseBool(v); err == nil { + include := parsed + includePtr = &include + matched = true + } + case json.Number: + if val, err := v.Int64(); err == nil { + include := val != 0 + includePtr = &include + matched = true + } + case int: + include := v != 0 + includePtr = &include + matched = true + case int32: + include := v != 0 + includePtr = &include + matched = true + case int64: + include := v != 0 + includePtr = &include + matched = true + case float64: + include := v != 0 + includePtr = &include + matched = true + } + } + return budgetPtr, includePtr, matched +} diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index f9f86fd3..0eb4588a 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -133,20 +133,27 @@ func (h *BaseAPIHandler) GetContextWithCancel(handler interfaces.APIHandler, c * // ExecuteWithAuthManager executes a non-streaming request via the core auth manager. // This path is the only supported execution route. 
func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { - providers := util.GetProviderName(modelName) + normalizedModel, metadata := normalizeModelMetadata(modelName) + providers := util.GetProviderName(normalizedModel) if len(providers) == 0 { return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} } req := coreexecutor.Request{ - Model: modelName, + Model: normalizedModel, Payload: cloneBytes(rawJSON), } + if cloned := cloneMetadata(metadata); cloned != nil { + req.Metadata = cloned + } opts := coreexecutor.Options{ Stream: false, Alt: alt, OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } + if cloned := cloneMetadata(metadata); cloned != nil { + opts.Metadata = cloned + } resp, err := h.AuthManager.Execute(ctx, providers, req, opts) if err != nil { return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err} @@ -157,20 +164,27 @@ func (h *BaseAPIHandler) ExecuteWithAuthManager(ctx context.Context, handlerType // ExecuteCountWithAuthManager executes a non-streaming request via the core auth manager. // This path is the only supported execution route. 
func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) ([]byte, *interfaces.ErrorMessage) { - providers := util.GetProviderName(modelName) + normalizedModel, metadata := normalizeModelMetadata(modelName) + providers := util.GetProviderName(normalizedModel) if len(providers) == 0 { return nil, &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} } req := coreexecutor.Request{ - Model: modelName, + Model: normalizedModel, Payload: cloneBytes(rawJSON), } + if cloned := cloneMetadata(metadata); cloned != nil { + req.Metadata = cloned + } opts := coreexecutor.Options{ Stream: false, Alt: alt, OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } + if cloned := cloneMetadata(metadata); cloned != nil { + opts.Metadata = cloned + } resp, err := h.AuthManager.ExecuteCount(ctx, providers, req, opts) if err != nil { return nil, &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: err} @@ -181,7 +195,8 @@ func (h *BaseAPIHandler) ExecuteCountWithAuthManager(ctx context.Context, handle // ExecuteStreamWithAuthManager executes a streaming request via the core auth manager. // This path is the only supported execution route. 
func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handlerType, modelName string, rawJSON []byte, alt string) (<-chan []byte, <-chan *interfaces.ErrorMessage) { - providers := util.GetProviderName(modelName) + normalizedModel, metadata := normalizeModelMetadata(modelName) + providers := util.GetProviderName(normalizedModel) if len(providers) == 0 { errChan := make(chan *interfaces.ErrorMessage, 1) errChan <- &interfaces.ErrorMessage{StatusCode: http.StatusBadRequest, Error: fmt.Errorf("unknown provider for model %s", modelName)} @@ -189,15 +204,21 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl return nil, errChan } req := coreexecutor.Request{ - Model: modelName, + Model: normalizedModel, Payload: cloneBytes(rawJSON), } + if cloned := cloneMetadata(metadata); cloned != nil { + req.Metadata = cloned + } opts := coreexecutor.Options{ Stream: true, Alt: alt, OriginalRequest: cloneBytes(rawJSON), SourceFormat: sdktranslator.FromString(handlerType), } + if cloned := cloneMetadata(metadata); cloned != nil { + opts.Metadata = cloned + } chunks, err := h.AuthManager.ExecuteStream(ctx, providers, req, opts) if err != nil { errChan := make(chan *interfaces.ErrorMessage, 1) @@ -232,6 +253,34 @@ func cloneBytes(src []byte) []byte { return dst } +func normalizeModelMetadata(modelName string) (string, map[string]any) { + baseModel, budget, include, matched := util.ParseGeminiThinkingSuffix(modelName) + if !matched { + return baseModel, nil + } + metadata := map[string]any{ + util.GeminiOriginalModelMetadataKey: modelName, + } + if budget != nil { + metadata[util.GeminiThinkingBudgetMetadataKey] = *budget + } + if include != nil { + metadata[util.GeminiIncludeThoughtsMetadataKey] = *include + } + return baseModel, metadata +} + +func cloneMetadata(src map[string]any) map[string]any { + if len(src) == 0 { + return nil + } + dst := make(map[string]any, len(src)) + for k, v := range src { + dst[k] = v + } + return dst +} + 
// WriteErrorResponse writes an error message to the response writer using the HTTP status embedded in the message. func (h *BaseAPIHandler) WriteErrorResponse(c *gin.Context, msg *interfaces.ErrorMessage) { status := http.StatusInternalServerError