mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 04:10:51 +08:00
Introduce `PayloadConfig` in the configuration to define default and override rules for modifying payload parameters. Implement `applyPayloadConfig` and `applyPayloadConfigWithRoot` to apply these rules across various executors, ensuring consistent parameter handling for different models and protocols. Update all relevant executors to utilize this functionality.
548 lines
19 KiB
Go
548 lines
19 KiB
Go
// Package executor provides runtime execution capabilities for various AI service providers.
|
|
// It includes stateless executors that handle API requests, streaming responses,
|
|
// token counting, and authentication refresh for different AI service providers.
|
|
package executor
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
"time"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/google"
|
|
)
|
|
|
|
// glEndpoint is the default base URL of the Google Generative Language API;
// an auth's "base_url" attribute may override it.
const glEndpoint = "https://generativelanguage.googleapis.com"

// glAPIVersion is the API version path segment placed in front of every
// Gemini model endpoint.
const glAPIVersion = "v1beta"
|
|
|
|
// GeminiExecutor is a stateless executor for the official Gemini API using API keys.
|
|
// It handles both API key and OAuth bearer token authentication, supporting both
|
|
// regular and streaming requests to the Google Generative Language API.
|
|
type GeminiExecutor struct {
|
|
// cfg holds the application configuration.
|
|
cfg *config.Config
|
|
}
|
|
|
|
// NewGeminiExecutor creates a new Gemini executor instance.
|
|
//
|
|
// Parameters:
|
|
// - cfg: The application configuration
|
|
//
|
|
// Returns:
|
|
// - *GeminiExecutor: A new Gemini executor instance
|
|
func NewGeminiExecutor(cfg *config.Config) *GeminiExecutor { return &GeminiExecutor{cfg: cfg} }
|
|
|
|
// Identifier returns the executor identifier for Gemini.
|
|
func (e *GeminiExecutor) Identifier() string { return "gemini" }
|
|
|
|
// PrepareRequest prepares the HTTP request for execution (no-op for Gemini).
|
|
func (e *GeminiExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error { return nil }
|
|
|
|
// Execute sends a single, non-streaming request to the Gemini API and returns
// the upstream reply translated back into the caller's source format.
//
// The payload is translated to Gemini format. Thinking settings from
// req.Metadata are applied when the model supports thinking, and unsupported
// thinking config is stripped. The image aspect-ratio workaround and payload
// config rules are then applied. When req.Metadata["action"] is "countTokens",
// the countTokens endpoint is used instead of generateContent and opts.Alt is
// not forwarded.
//
// Parameters:
//   - ctx: The context for the request
//   - auth: The authentication information (API key or OAuth bearer token)
//   - req: The request to execute
//   - opts: Additional execution options
//
// Returns:
//   - cliproxyexecutor.Response: The translated response payload
//   - error: A statusErr for non-2xx replies, or any request/transport/read error
func (e *GeminiExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
	apiKey, bearer := geminiCreds(auth)

	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
	defer reporter.trackFailure(ctx, &err)

	// Official Gemini API via API key or OAuth bearer.
	sourceFormat := opts.SourceFormat
	targetFormat := sdktranslator.FromString("gemini")
	payload := sdktranslator.TranslateRequest(sourceFormat, targetFormat, req.Model, bytes.Clone(req.Payload), false)

	budget, include, hasThinking := util.GeminiThinkingFromMetadata(req.Metadata)
	if hasThinking && util.ModelSupportsThinking(req.Model) {
		if budget != nil {
			normalized := util.NormalizeThinkingBudget(req.Model, *budget)
			budget = &normalized
		}
		payload = util.ApplyGeminiThinkingConfig(payload, budget, include)
	}
	payload = util.StripThinkingConfigIfUnsupported(req.Model, payload)
	payload = fixGeminiImageAspectRatio(req.Model, payload)
	payload = applyPayloadConfig(e.cfg, req.Model, payload)

	// Reading from a nil metadata map is safe and simply yields "".
	method := "generateContent"
	if requested, _ := req.Metadata["action"].(string); requested == "countTokens" {
		method = "countTokens"
	}

	endpoint := fmt.Sprintf("%s/%s/models/%s:%s", resolveGeminiBaseURL(auth), glAPIVersion, req.Model, method)
	if method != "countTokens" && opts.Alt != "" {
		endpoint += fmt.Sprintf("?$alt=%s", opts.Alt)
	}

	payload, _ = sjson.DeleteBytes(payload, "session_id")

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, endpoint, bytes.NewReader(payload))
	if err != nil {
		return resp, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	switch {
	case apiKey != "":
		httpReq.Header.Set("x-goog-api-key", apiKey)
	case bearer != "":
		httpReq.Header.Set("Authorization", "Bearer "+bearer)
	}
	applyGeminiHeaders(httpReq, auth)

	logEntry := upstreamRequestLog{
		URL:      endpoint,
		Method:   http.MethodPost,
		Headers:  httpReq.Header.Clone(),
		Body:     payload,
		Provider: e.Identifier(),
	}
	if auth != nil {
		logEntry.AuthID = auth.ID
		logEntry.AuthLabel = auth.Label
		logEntry.AuthType, logEntry.AuthValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, logEntry)

	httpResp, err := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0).Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return resp, err
	}
	defer func() {
		if closeErr := httpResp.Body.Close(); closeErr != nil {
			log.Errorf("gemini executor: close response body error: %v", closeErr)
		}
	}()
	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())

	if code := httpResp.StatusCode; code < 200 || code >= 300 {
		errBody, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, errBody)
		log.Debugf("request error, error status: %d, error body: %s", code, summarizeErrorBody(httpResp.Header.Get("Content-Type"), errBody))
		err = statusErr{code: code, msg: string(errBody)}
		return resp, err
	}

	data, err := io.ReadAll(httpResp.Body)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return resp, err
	}
	appendAPIResponseChunk(ctx, e.cfg, data)
	reporter.publish(ctx, parseGeminiUsage(data))

	var translatorState any
	translated := sdktranslator.TranslateNonStream(ctx, targetFormat, sourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), payload, data, &translatorState)
	return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
|
|
|
|
// ExecuteStream performs a streaming request to the Gemini
// streamGenerateContent endpoint. It returns a channel of chunks translated
// back into opts.SourceFormat.
//
// The request body is prepared the same way as in Execute: translation,
// thinking config, the image aspect-ratio workaround and payload config rules.
// If opts.Alt is empty the stream is requested with alt=sse; otherwise
// opts.Alt is forwarded as $alt.
//
// The returned channel is closed once the upstream body has been consumed.
// After the scan loop, a final "[DONE]" translation is emitted. A scanner error
// is then reported as a last chunk carrying Err. If ctx is cancelled while a
// chunk is waiting to be delivered, the goroutine stops. It does not block
// forever on a consumer that has stopped reading.
//
// Parameters:
//   - ctx: The context for the request
//   - auth: The authentication information (API key or OAuth bearer token)
//   - req: The request to execute
//   - opts: Additional execution options
//
// Returns:
//   - <-chan cliproxyexecutor.StreamChunk: The translated stream chunks
//   - error: A statusErr for non-2xx replies, or a request/transport error
func (e *GeminiExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
	// maxStreamLineSize is the largest single stream line the scanner accepts (20 MiB).
	const maxStreamLineSize = 20_971_520

	apiKey, bearer := geminiCreds(auth)

	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
	defer reporter.trackFailure(ctx, &err)

	from := opts.SourceFormat
	to := sdktranslator.FromString("gemini")
	body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
	if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
		if budgetOverride != nil {
			norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
			budgetOverride = &norm
		}
		body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
	}
	body = util.StripThinkingConfigIfUnsupported(req.Model, body)
	body = fixGeminiImageAspectRatio(req.Model, body)
	body = applyPayloadConfig(e.cfg, req.Model, body)

	baseURL := resolveGeminiBaseURL(auth)
	url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, req.Model, "streamGenerateContent")
	if opts.Alt == "" {
		url += "?alt=sse"
	} else {
		url += fmt.Sprintf("?$alt=%s", opts.Alt)
	}

	body, _ = sjson.DeleteBytes(body, "session_id")

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if err != nil {
		return nil, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	// Prefer the API key. Like Execute, only send a bearer token when one is
	// actually available, rather than an empty "Bearer " header.
	if apiKey != "" {
		httpReq.Header.Set("x-goog-api-key", apiKey)
	} else if bearer != "" {
		httpReq.Header.Set("Authorization", "Bearer "+bearer)
	}
	applyGeminiHeaders(httpReq, auth)
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      body,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})

	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, err := httpClient.Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return nil, err
	}
	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		b, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, b)
		log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("gemini executor: close response body error: %v", errClose)
		}
		err = statusErr{code: httpResp.StatusCode, msg: string(b)}
		return nil, err
	}

	out := make(chan cliproxyexecutor.StreamChunk)
	stream = out
	go func() {
		defer close(out)
		defer func() {
			if errClose := httpResp.Body.Close(); errClose != nil {
				log.Errorf("gemini executor: close response body error: %v", errClose)
			}
		}()

		// emit hands chunk to the consumer. It returns false if ctx is
		// cancelled first, so the goroutine (and the response body) cannot
		// stay blocked on an abandoned channel.
		emit := func(chunk cliproxyexecutor.StreamChunk) bool {
			select {
			case out <- chunk:
				return true
			case <-ctx.Done():
				return false
			}
		}

		scanner := bufio.NewScanner(httpResp.Body)
		// Start small and let the scanner grow its buffer on demand up to
		// maxStreamLineSize, instead of allocating the full limit per stream.
		scanner.Buffer(make([]byte, 0, 64*1024), maxStreamLineSize)
		var param any
		for scanner.Scan() {
			line := scanner.Bytes()
			appendAPIResponseChunk(ctx, e.cfg, line)
			if detail, ok := parseGeminiStreamUsage(line); ok {
				reporter.publish(ctx, detail)
			}
			// scanner.Bytes is reused by the next Scan, so hand the translator a copy.
			lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
			for i := range lines {
				if !emit(cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}) {
					return
				}
			}
		}
		lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), &param)
		for i := range lines {
			if !emit(cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}) {
				return
			}
		}
		if errScan := scanner.Err(); errScan != nil {
			recordAPIResponseError(ctx, e.cfg, errScan)
			reporter.publishFailure(ctx)
			emit(cliproxyexecutor.StreamChunk{Err: errScan})
		}
	}()
	return stream, nil
}
|
|
|
|
// CountTokens calls the Gemini countTokens endpoint for the translated request
// and returns the reported totalTokens, translated back into the caller's
// source format.
//
// The request is prepared like Execute's, except that payload config rules are
// not applied. The top-level "tools", "generationConfig" and "safetySettings"
// fields are removed before sending.
//
// Parameters:
//   - ctx: The context for the request
//   - auth: The authentication information (API key or OAuth bearer token)
//   - req: The request whose tokens should be counted
//   - opts: Additional execution options; opts.Alt is passed to the
//     token-count translator via the context
//
// Returns:
//   - cliproxyexecutor.Response: The translated token-count payload
//   - error: A statusErr for non-2xx replies, or any request/transport/read error
func (e *GeminiExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	apiKey, bearer := geminiCreds(auth)

	from := opts.SourceFormat
	to := sdktranslator.FromString("gemini")
	translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
	if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
		if budgetOverride != nil {
			norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
			budgetOverride = &norm
		}
		translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
	}
	translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
	translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
	// respCtx carries opts.Alt under the plain string key "alt" for
	// TranslateTokenCount. The key is kept as-is because translators
	// presumably look it up by that exact value (NOTE(review): a typed key
	// would be safer, but needs coordinated changes on the reading side).
	respCtx := context.WithValue(ctx, "alt", opts.Alt)
	// Only the content matters for counting; drop generation-related fields.
	translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
	translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
	translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")

	baseURL := resolveGeminiBaseURL(auth)
	url := fmt.Sprintf("%s/%s/models/%s:%s", baseURL, glAPIVersion, req.Model, "countTokens")

	httpReq, err := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(translatedReq))
	if err != nil {
		return cliproxyexecutor.Response{}, err
	}
	httpReq.Header.Set("Content-Type", "application/json")
	// Prefer the API key. Like Execute, only send a bearer token when one is
	// actually available, rather than an empty "Bearer " header.
	if apiKey != "" {
		httpReq.Header.Set("x-goog-api-key", apiKey)
	} else if bearer != "" {
		httpReq.Header.Set("Authorization", "Bearer "+bearer)
	}
	applyGeminiHeaders(httpReq, auth)
	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      translatedReq,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})

	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	resp, err := httpClient.Do(httpReq)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return cliproxyexecutor.Response{}, err
	}
	defer func() {
		if errClose := resp.Body.Close(); errClose != nil {
			log.Errorf("gemini executor: close response body error: %v", errClose)
		}
	}()
	recordAPIResponseMetadata(ctx, e.cfg, resp.StatusCode, resp.Header.Clone())

	data, err := io.ReadAll(resp.Body)
	if err != nil {
		recordAPIResponseError(ctx, e.cfg, err)
		return cliproxyexecutor.Response{}, err
	}
	appendAPIResponseChunk(ctx, e.cfg, data)
	if resp.StatusCode < 200 || resp.StatusCode >= 300 {
		log.Debugf("request error, error status: %d, error body: %s", resp.StatusCode, summarizeErrorBody(resp.Header.Get("Content-Type"), data))
		return cliproxyexecutor.Response{}, statusErr{code: resp.StatusCode, msg: string(data)}
	}

	count := gjson.GetBytes(data, "totalTokens").Int()
	translated := sdktranslator.TranslateTokenCount(respCtx, to, from, count, data)
	return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
}
|
|
|
|
func (e *GeminiExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
|
log.Debugf("gemini executor: refresh called")
|
|
// OAuth bearer token refresh for official Gemini API.
|
|
if auth == nil {
|
|
return nil, fmt.Errorf("gemini executor: auth is nil")
|
|
}
|
|
if auth.Metadata == nil {
|
|
return auth, nil
|
|
}
|
|
// Token data is typically nested under "token" map in Gemini files.
|
|
tokenMap, _ := auth.Metadata["token"].(map[string]any)
|
|
var refreshToken, accessToken, clientID, clientSecret, tokenURI, expiryStr string
|
|
if tokenMap != nil {
|
|
if v, ok := tokenMap["refresh_token"].(string); ok {
|
|
refreshToken = v
|
|
}
|
|
if v, ok := tokenMap["access_token"].(string); ok {
|
|
accessToken = v
|
|
}
|
|
if v, ok := tokenMap["client_id"].(string); ok {
|
|
clientID = v
|
|
}
|
|
if v, ok := tokenMap["client_secret"].(string); ok {
|
|
clientSecret = v
|
|
}
|
|
if v, ok := tokenMap["token_uri"].(string); ok {
|
|
tokenURI = v
|
|
}
|
|
if v, ok := tokenMap["expiry"].(string); ok {
|
|
expiryStr = v
|
|
}
|
|
} else {
|
|
// Fallback to top-level keys if present
|
|
if v, ok := auth.Metadata["refresh_token"].(string); ok {
|
|
refreshToken = v
|
|
}
|
|
if v, ok := auth.Metadata["access_token"].(string); ok {
|
|
accessToken = v
|
|
}
|
|
if v, ok := auth.Metadata["client_id"].(string); ok {
|
|
clientID = v
|
|
}
|
|
if v, ok := auth.Metadata["client_secret"].(string); ok {
|
|
clientSecret = v
|
|
}
|
|
if v, ok := auth.Metadata["token_uri"].(string); ok {
|
|
tokenURI = v
|
|
}
|
|
if v, ok := auth.Metadata["expiry"].(string); ok {
|
|
expiryStr = v
|
|
}
|
|
}
|
|
if refreshToken == "" {
|
|
// Nothing to do for API key or cookie based entries
|
|
return auth, nil
|
|
}
|
|
|
|
// Prepare oauth2 config; default to Google endpoints
|
|
endpoint := google.Endpoint
|
|
if tokenURI != "" {
|
|
endpoint.TokenURL = tokenURI
|
|
}
|
|
conf := &oauth2.Config{ClientID: clientID, ClientSecret: clientSecret, Endpoint: endpoint}
|
|
|
|
// Ensure proxy-aware HTTP client for token refresh
|
|
httpClient := util.SetProxy(&e.cfg.SDKConfig, &http.Client{})
|
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
|
|
|
// Build base token
|
|
tok := &oauth2.Token{AccessToken: accessToken, RefreshToken: refreshToken}
|
|
if t, err := time.Parse(time.RFC3339, expiryStr); err == nil {
|
|
tok.Expiry = t
|
|
}
|
|
newTok, err := conf.TokenSource(ctx, tok).Token()
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
// Persist back to metadata; prefer nested token map if present
|
|
if tokenMap == nil {
|
|
tokenMap = make(map[string]any)
|
|
}
|
|
tokenMap["access_token"] = newTok.AccessToken
|
|
tokenMap["refresh_token"] = newTok.RefreshToken
|
|
tokenMap["expiry"] = newTok.Expiry.Format(time.RFC3339)
|
|
if clientID != "" {
|
|
tokenMap["client_id"] = clientID
|
|
}
|
|
if clientSecret != "" {
|
|
tokenMap["client_secret"] = clientSecret
|
|
}
|
|
if tokenURI != "" {
|
|
tokenMap["token_uri"] = tokenURI
|
|
}
|
|
auth.Metadata["token"] = tokenMap
|
|
|
|
// Also mirror top-level access_token for compatibility if previously present
|
|
if _, ok := auth.Metadata["access_token"]; ok {
|
|
auth.Metadata["access_token"] = newTok.AccessToken
|
|
}
|
|
return auth, nil
|
|
}
|
|
|
|
func geminiCreds(a *cliproxyauth.Auth) (apiKey, bearer string) {
|
|
if a == nil {
|
|
return "", ""
|
|
}
|
|
if a.Attributes != nil {
|
|
if v := a.Attributes["api_key"]; v != "" {
|
|
apiKey = v
|
|
}
|
|
}
|
|
if a.Metadata != nil {
|
|
// GeminiTokenStorage.Token is a map that may contain access_token
|
|
if v, ok := a.Metadata["access_token"].(string); ok && v != "" {
|
|
bearer = v
|
|
}
|
|
if token, ok := a.Metadata["token"].(map[string]any); ok && token != nil {
|
|
if v, ok2 := token["access_token"].(string); ok2 && v != "" {
|
|
bearer = v
|
|
}
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// resolveGeminiBaseURL picks the API base URL for auth. It returns the
// "base_url" attribute with surrounding whitespace and trailing slashes
// removed, or glEndpoint when auth has no such attribute or trimming leaves
// nothing.
func resolveGeminiBaseURL(auth *cliproxyauth.Auth) string {
	if auth == nil || auth.Attributes == nil {
		return glEndpoint
	}
	custom := strings.TrimRight(strings.TrimSpace(auth.Attributes["base_url"]), "/")
	if custom == "" {
		return glEndpoint
	}
	return custom
}
|
|
|
|
func applyGeminiHeaders(req *http.Request, auth *cliproxyauth.Auth) {
|
|
var attrs map[string]string
|
|
if auth != nil {
|
|
attrs = auth.Attributes
|
|
}
|
|
util.ApplyCustomHeadersFromAttrs(req, attrs)
|
|
}
|
|
|
|
func fixGeminiImageAspectRatio(modelName string, rawJSON []byte) []byte {
|
|
if modelName == "gemini-2.5-flash-image-preview" {
|
|
aspectRatioResult := gjson.GetBytes(rawJSON, "generationConfig.imageConfig.aspectRatio")
|
|
if aspectRatioResult.Exists() {
|
|
contents := gjson.GetBytes(rawJSON, "contents")
|
|
contentArray := contents.Array()
|
|
if len(contentArray) > 0 {
|
|
hasInlineData := false
|
|
loopContent:
|
|
for i := 0; i < len(contentArray); i++ {
|
|
parts := contentArray[i].Get("parts").Array()
|
|
for j := 0; j < len(parts); j++ {
|
|
if parts[j].Get("inlineData").Exists() {
|
|
hasInlineData = true
|
|
break loopContent
|
|
}
|
|
}
|
|
}
|
|
|
|
if !hasInlineData {
|
|
emptyImageBase64ed, _ := util.CreateWhiteImageBase64(aspectRatioResult.String())
|
|
emptyImagePart := `{"inlineData":{"mime_type":"image/png","data":""}}`
|
|
emptyImagePart, _ = sjson.Set(emptyImagePart, "inlineData.data", emptyImageBase64ed)
|
|
newPartsJson := `[]`
|
|
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", `{"text": "Based on the following requirements, create an image within the uploaded picture. The new content *MUST* completely cover the entire area of the original picture, maintaining its exact proportions, and *NO* blank areas should appear."}`)
|
|
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", emptyImagePart)
|
|
|
|
parts := contentArray[0].Get("parts").Array()
|
|
for j := 0; j < len(parts); j++ {
|
|
newPartsJson, _ = sjson.SetRaw(newPartsJson, "-1", parts[j].Raw)
|
|
}
|
|
|
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "contents.0.parts", []byte(newPartsJson))
|
|
rawJSON, _ = sjson.SetRawBytes(rawJSON, "generationConfig.responseModalities", []byte(`["IMAGE", "TEXT"]`))
|
|
}
|
|
}
|
|
rawJSON, _ = sjson.DeleteBytes(rawJSON, "generationConfig.imageConfig")
|
|
}
|
|
}
|
|
return rawJSON
|
|
}
|