mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Add package and constructor documentation for AI Studio, Antigravity, Gemini CLI, Gemini API, and Vertex executors to describe their roles and inputs. Introduce a shared stream scanner buffer constant in the Gemini API executor and reuse it in Gemini CLI and Vertex streaming code so stream handling uses a consistent configuration. Update Refresh implementations for AI Studio, Gemini CLI, Gemini API (API key), and Vertex executors to short‑circuit and simply return the incoming auth object, while keeping Antigravity token renewal as the only executor that performs OAuth refresh. Remove OAuth2-based token refresh logic and related dependencies from the Gemini API executor, since it now operates strictly with API key credentials.
811 lines
32 KiB
Go
811 lines
32 KiB
Go
// Package executor provides runtime execution capabilities for various AI service providers.
|
|
// This file implements the Vertex AI Gemini executor that talks to Google Vertex AI
|
|
// endpoints using service account credentials or API keys.
|
|
package executor
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/google"
|
|
)
|
|
|
|
const (
|
|
// vertexAPIVersion aligns with current public Vertex Generative AI API.
|
|
vertexAPIVersion = "v1"
|
|
)
|
|
|
|
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
|
type GeminiVertexExecutor struct {
|
|
cfg *config.Config
|
|
}
|
|
|
|
// NewGeminiVertexExecutor creates a new Vertex AI Gemini executor instance.
|
|
//
|
|
// Parameters:
|
|
// - cfg: The application configuration
|
|
//
|
|
// Returns:
|
|
// - *GeminiVertexExecutor: A new Vertex AI Gemini executor instance
|
|
func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
|
|
return &GeminiVertexExecutor{cfg: cfg}
|
|
}
|
|
|
|
// Identifier returns the executor identifier.
|
|
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
|
|
|
|
// PrepareRequest prepares the HTTP request for execution (no-op for Vertex).
|
|
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
|
return nil
|
|
}
|
|
|
|
// Execute handles non-streaming requests.
|
|
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
|
// Try API key authentication first
|
|
apiKey, baseURL := vertexAPICreds(auth)
|
|
|
|
// If no API key found, fall back to service account authentication
|
|
if apiKey == "" {
|
|
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
|
if errCreds != nil {
|
|
return resp, errCreds
|
|
}
|
|
return e.executeWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
|
|
}
|
|
|
|
// Use API key authentication
|
|
return e.executeWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
|
}
|
|
|
|
// ExecuteStream handles SSE streaming for Vertex.
|
|
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
|
// Try API key authentication first
|
|
apiKey, baseURL := vertexAPICreds(auth)
|
|
|
|
// If no API key found, fall back to service account authentication
|
|
if apiKey == "" {
|
|
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
|
if errCreds != nil {
|
|
return nil, errCreds
|
|
}
|
|
return e.executeStreamWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
|
|
}
|
|
|
|
// Use API key authentication
|
|
return e.executeStreamWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
|
}
|
|
|
|
// CountTokens calls Vertex countTokens endpoint.
|
|
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
|
// Try API key authentication first
|
|
apiKey, baseURL := vertexAPICreds(auth)
|
|
|
|
// If no API key found, fall back to service account authentication
|
|
if apiKey == "" {
|
|
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
|
if errCreds != nil {
|
|
return cliproxyexecutor.Response{}, errCreds
|
|
}
|
|
return e.countTokensWithServiceAccount(ctx, auth, req, opts, projectID, location, saJSON)
|
|
}
|
|
|
|
// Use API key authentication
|
|
return e.countTokensWithAPIKey(ctx, auth, req, opts, apiKey, baseURL)
|
|
}
|
|
|
|
// countTokensWithServiceAccount handles token counting using service account credentials.
|
|
func (e *GeminiVertexExecutor) countTokensWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (cliproxyexecutor.Response, error) {
|
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini")
|
|
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
|
}
|
|
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
|
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
|
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
|
|
|
baseURL := vertexBaseURL(location)
|
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "countTokens")
|
|
|
|
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
|
if errNewReq != nil {
|
|
return cliproxyexecutor.Response{}, errNewReq
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
|
} else if errTok != nil {
|
|
log.Errorf("vertex executor: access token error: %v", errTok)
|
|
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
|
}
|
|
applyGeminiHeaders(httpReq, auth)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: httpReq.Header.Clone(),
|
|
Body: translatedReq,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
httpResp, errDo := httpClient.Do(httpReq)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
return cliproxyexecutor.Response{}, errDo
|
|
}
|
|
defer func() {
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
b, _ := io.ReadAll(httpResp.Body)
|
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
|
}
|
|
data, errRead := io.ReadAll(httpResp.Body)
|
|
if errRead != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
return cliproxyexecutor.Response{}, errRead
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
|
}
|
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
|
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
|
}
|
|
|
|
// countTokensWithAPIKey handles token counting using API key credentials.
|
|
func (e *GeminiVertexExecutor) countTokensWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (cliproxyexecutor.Response, error) {
|
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini")
|
|
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
|
}
|
|
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
|
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
|
translatedReq, _ = sjson.SetBytes(translatedReq, "model", upstreamModel)
|
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
|
|
|
// For API key auth, use simpler URL format without project/location
|
|
if baseURL == "" {
|
|
baseURL = "https://generativelanguage.googleapis.com"
|
|
}
|
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, req.Model, "countTokens")
|
|
|
|
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
|
if errNewReq != nil {
|
|
return cliproxyexecutor.Response{}, errNewReq
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
if apiKey != "" {
|
|
httpReq.Header.Set("x-goog-api-key", apiKey)
|
|
}
|
|
applyGeminiHeaders(httpReq, auth)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: httpReq.Header.Clone(),
|
|
Body: translatedReq,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
httpResp, errDo := httpClient.Do(httpReq)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
return cliproxyexecutor.Response{}, errDo
|
|
}
|
|
defer func() {
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
b, _ := io.ReadAll(httpResp.Body)
|
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
|
}
|
|
data, errRead := io.ReadAll(httpResp.Body)
|
|
if errRead != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
return cliproxyexecutor.Response{}, errRead
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
|
}
|
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
|
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
|
}
|
|
|
|
// Refresh refreshes the authentication credentials (no-op for Vertex).
|
|
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
|
return auth, nil
|
|
}
|
|
|
|
// executeWithServiceAccount handles authentication using service account credentials.
|
|
// This method contains the original service account authentication logic.
|
|
func (e *GeminiVertexExecutor) executeWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (resp cliproxyexecutor.Response, err error) {
|
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
defer reporter.trackFailure(ctx, &err)
|
|
|
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini")
|
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
|
}
|
|
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
|
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
|
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
|
|
|
action := "generateContent"
|
|
if req.Metadata != nil {
|
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
|
action = "countTokens"
|
|
}
|
|
}
|
|
baseURL := vertexBaseURL(location)
|
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, action)
|
|
if opts.Alt != "" && action != "countTokens" {
|
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
|
}
|
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
|
|
|
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
|
if errNewReq != nil {
|
|
return resp, errNewReq
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
|
} else if errTok != nil {
|
|
log.Errorf("vertex executor: access token error: %v", errTok)
|
|
return resp, statusErr{code: 500, msg: "internal server error"}
|
|
}
|
|
applyGeminiHeaders(httpReq, auth)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: httpReq.Header.Clone(),
|
|
Body: body,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
httpResp, errDo := httpClient.Do(httpReq)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
return resp, errDo
|
|
}
|
|
defer func() {
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
b, _ := io.ReadAll(httpResp.Body)
|
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
|
return resp, err
|
|
}
|
|
data, errRead := io.ReadAll(httpResp.Body)
|
|
if errRead != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
return resp, errRead
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
reporter.publish(ctx, parseGeminiUsage(data))
|
|
var param any
|
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
|
return resp, nil
|
|
}
|
|
|
|
// executeWithAPIKey handles authentication using API key credentials.
|
|
func (e *GeminiVertexExecutor) executeWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (resp cliproxyexecutor.Response, err error) {
|
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
defer reporter.trackFailure(ctx, &err)
|
|
|
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini")
|
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
|
}
|
|
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
|
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
|
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
|
|
|
action := "generateContent"
|
|
if req.Metadata != nil {
|
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
|
action = "countTokens"
|
|
}
|
|
}
|
|
|
|
// For API key auth, use simpler URL format without project/location
|
|
if baseURL == "" {
|
|
baseURL = "https://generativelanguage.googleapis.com"
|
|
}
|
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, action)
|
|
if opts.Alt != "" && action != "countTokens" {
|
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
|
}
|
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
|
|
|
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
|
if errNewReq != nil {
|
|
return resp, errNewReq
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
if apiKey != "" {
|
|
httpReq.Header.Set("x-goog-api-key", apiKey)
|
|
}
|
|
applyGeminiHeaders(httpReq, auth)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: httpReq.Header.Clone(),
|
|
Body: body,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
httpResp, errDo := httpClient.Do(httpReq)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
return resp, errDo
|
|
}
|
|
defer func() {
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
b, _ := io.ReadAll(httpResp.Body)
|
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
|
return resp, err
|
|
}
|
|
data, errRead := io.ReadAll(httpResp.Body)
|
|
if errRead != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
return resp, errRead
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
reporter.publish(ctx, parseGeminiUsage(data))
|
|
var param any
|
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
|
return resp, nil
|
|
}
|
|
|
|
// executeStreamWithServiceAccount handles streaming authentication using service account credentials.
|
|
func (e *GeminiVertexExecutor) executeStreamWithServiceAccount(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, projectID, location string, saJSON []byte) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
defer reporter.trackFailure(ctx, &err)
|
|
|
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini")
|
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
|
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
|
}
|
|
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
|
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
|
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
|
|
|
baseURL := vertexBaseURL(location)
|
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, upstreamModel, "streamGenerateContent")
|
|
if opts.Alt == "" {
|
|
url = url + "?alt=sse"
|
|
} else {
|
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
|
}
|
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
|
|
|
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
|
if errNewReq != nil {
|
|
return nil, errNewReq
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
|
} else if errTok != nil {
|
|
log.Errorf("vertex executor: access token error: %v", errTok)
|
|
return nil, statusErr{code: 500, msg: "internal server error"}
|
|
}
|
|
applyGeminiHeaders(httpReq, auth)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: httpReq.Header.Clone(),
|
|
Body: body,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
httpResp, errDo := httpClient.Do(httpReq)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
return nil, errDo
|
|
}
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
b, _ := io.ReadAll(httpResp.Body)
|
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
|
}
|
|
|
|
out := make(chan cliproxyexecutor.StreamChunk)
|
|
stream = out
|
|
go func() {
|
|
defer close(out)
|
|
defer func() {
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
scanner := bufio.NewScanner(httpResp.Body)
|
|
scanner.Buffer(nil, streamScannerBuffer)
|
|
var param any
|
|
for scanner.Scan() {
|
|
line := scanner.Bytes()
|
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
|
if detail, ok := parseGeminiStreamUsage(line); ok {
|
|
reporter.publish(ctx, detail)
|
|
}
|
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
|
for i := range lines {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
|
}
|
|
}
|
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
|
for i := range lines {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
|
}
|
|
if errScan := scanner.Err(); errScan != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
|
reporter.publishFailure(ctx)
|
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
|
}
|
|
}()
|
|
return stream, nil
|
|
}
|
|
|
|
// executeStreamWithAPIKey handles streaming authentication using API key credentials.
|
|
func (e *GeminiVertexExecutor) executeStreamWithAPIKey(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options, apiKey, baseURL string) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
defer reporter.trackFailure(ctx, &err)
|
|
|
|
upstreamModel := util.ResolveOriginalModel(req.Model, req.Metadata)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini")
|
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
|
if budgetOverride, includeOverride, ok := util.ResolveThinkingConfigFromMetadata(req.Model, req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
|
}
|
|
body = util.ApplyDefaultThinkingIfNeeded(req.Model, body)
|
|
body = util.NormalizeGeminiThinkingBudget(req.Model, body)
|
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
|
body, _ = sjson.SetBytes(body, "model", upstreamModel)
|
|
|
|
// For API key auth, use simpler URL format without project/location
|
|
if baseURL == "" {
|
|
baseURL = "https://generativelanguage.googleapis.com"
|
|
}
|
|
url := fmt.Sprintf("%s/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, upstreamModel, "streamGenerateContent")
|
|
if opts.Alt == "" {
|
|
url = url + "?alt=sse"
|
|
} else {
|
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
|
}
|
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
|
|
|
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
|
if errNewReq != nil {
|
|
return nil, errNewReq
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
if apiKey != "" {
|
|
httpReq.Header.Set("x-goog-api-key", apiKey)
|
|
}
|
|
applyGeminiHeaders(httpReq, auth)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: httpReq.Header.Clone(),
|
|
Body: body,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
httpResp, errDo := httpClient.Do(httpReq)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
return nil, errDo
|
|
}
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
b, _ := io.ReadAll(httpResp.Body)
|
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
|
}
|
|
|
|
out := make(chan cliproxyexecutor.StreamChunk)
|
|
stream = out
|
|
go func() {
|
|
defer close(out)
|
|
defer func() {
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
scanner := bufio.NewScanner(httpResp.Body)
|
|
scanner.Buffer(nil, streamScannerBuffer)
|
|
var param any
|
|
for scanner.Scan() {
|
|
line := scanner.Bytes()
|
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
|
if detail, ok := parseGeminiStreamUsage(line); ok {
|
|
reporter.publish(ctx, detail)
|
|
}
|
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
|
for i := range lines {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
|
}
|
|
}
|
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
|
for i := range lines {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
|
}
|
|
if errScan := scanner.Err(); errScan != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
|
reporter.publishFailure(ctx)
|
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
|
}
|
|
}()
|
|
return stream, nil
|
|
}
|
|
|
|
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
|
func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) {
|
|
if a == nil || a.Metadata == nil {
|
|
return "", "", nil, fmt.Errorf("vertex executor: missing auth metadata")
|
|
}
|
|
if v, ok := a.Metadata["project_id"].(string); ok {
|
|
projectID = strings.TrimSpace(v)
|
|
}
|
|
if projectID == "" {
|
|
// Some service accounts may use "project"; still prefer standard field
|
|
if v, ok := a.Metadata["project"].(string); ok {
|
|
projectID = strings.TrimSpace(v)
|
|
}
|
|
}
|
|
if projectID == "" {
|
|
return "", "", nil, fmt.Errorf("vertex executor: missing project_id in credentials")
|
|
}
|
|
if v, ok := a.Metadata["location"].(string); ok && strings.TrimSpace(v) != "" {
|
|
location = strings.TrimSpace(v)
|
|
} else {
|
|
location = "us-central1"
|
|
}
|
|
var sa map[string]any
|
|
if raw, ok := a.Metadata["service_account"].(map[string]any); ok {
|
|
sa = raw
|
|
}
|
|
if sa == nil {
|
|
return "", "", nil, fmt.Errorf("vertex executor: missing service_account in credentials")
|
|
}
|
|
normalized, errNorm := vertexauth.NormalizeServiceAccountMap(sa)
|
|
if errNorm != nil {
|
|
return "", "", nil, fmt.Errorf("vertex executor: %w", errNorm)
|
|
}
|
|
saJSON, errMarshal := json.Marshal(normalized)
|
|
if errMarshal != nil {
|
|
return "", "", nil, fmt.Errorf("vertex executor: marshal service_account failed: %w", errMarshal)
|
|
}
|
|
return projectID, location, saJSON, nil
|
|
}
|
|
|
|
// vertexAPICreds extracts API key and base URL from auth attributes following the claudeCreds pattern.
|
|
func vertexAPICreds(a *cliproxyauth.Auth) (apiKey, baseURL string) {
|
|
if a == nil {
|
|
return "", ""
|
|
}
|
|
if a.Attributes != nil {
|
|
apiKey = a.Attributes["api_key"]
|
|
baseURL = a.Attributes["base_url"]
|
|
}
|
|
if apiKey == "" && a.Metadata != nil {
|
|
if v, ok := a.Metadata["access_token"].(string); ok {
|
|
apiKey = v
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
func vertexBaseURL(location string) string {
|
|
loc := strings.TrimSpace(location)
|
|
if loc == "" {
|
|
loc = "us-central1"
|
|
}
|
|
return fmt.Sprintf("https://%s-aiplatform.googleapis.com", loc)
|
|
}
|
|
|
|
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
|
|
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
|
}
|
|
// Use cloud-platform scope for Vertex AI.
|
|
creds, errCreds := google.CredentialsFromJSON(ctx, saJSON, "https://www.googleapis.com/auth/cloud-platform")
|
|
if errCreds != nil {
|
|
return "", fmt.Errorf("vertex executor: parse service account json failed: %w", errCreds)
|
|
}
|
|
tok, errTok := creds.TokenSource.Token()
|
|
if errTok != nil {
|
|
return "", fmt.Errorf("vertex executor: get access token failed: %w", errTok)
|
|
}
|
|
return tok.AccessToken, nil
|
|
}
|