mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 04:10:51 +08:00
Introduce `PayloadConfig` in the configuration to define default and override rules for modifying payload parameters. Implement `applyPayloadConfig` and `applyPayloadConfigWithRoot` to apply these rules across various executors, ensuring consistent parameter handling for different models and protocols. Update all relevant executors to utilize this functionality.
428 lines
16 KiB
Go
428 lines
16 KiB
Go
// Package executor contains provider executors. This file implements the Vertex AI
// Gemini executor that talks to Google Vertex AI endpoints using service account
// credentials imported by the CLI.
|
|
package executor
|
|
|
|
import (
|
|
"bufio"
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"fmt"
|
|
"io"
|
|
"net/http"
|
|
"strings"
|
|
|
|
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
|
log "github.com/sirupsen/logrus"
|
|
"github.com/tidwall/gjson"
|
|
"github.com/tidwall/sjson"
|
|
"golang.org/x/oauth2"
|
|
"golang.org/x/oauth2/google"
|
|
)
|
|
|
|
// vertexAPIVersion is the API version path segment used in every Vertex
// request URL built in this file; it aligns with the current public Vertex
// Generative AI API.
const vertexAPIVersion = "v1"
|
|
|
|
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
|
type GeminiVertexExecutor struct {
|
|
cfg *config.Config
|
|
}
|
|
|
|
// NewGeminiVertexExecutor constructs the Vertex executor.
|
|
func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
|
|
return &GeminiVertexExecutor{cfg: cfg}
|
|
}
|
|
|
|
// Identifier returns provider key for manager routing.
|
|
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
|
|
|
|
// PrepareRequest is a no-op for Vertex.
|
|
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
|
return nil
|
|
}
|
|
|
|
// Execute handles non-streaming requests.
|
|
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
|
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
|
if errCreds != nil {
|
|
return resp, errCreds
|
|
}
|
|
|
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
defer reporter.trackFailure(ctx, &err)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini")
|
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
|
}
|
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
|
|
|
action := "generateContent"
|
|
if req.Metadata != nil {
|
|
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
|
action = "countTokens"
|
|
}
|
|
}
|
|
baseURL := vertexBaseURL(location)
|
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
|
|
if opts.Alt != "" && action != "countTokens" {
|
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
|
}
|
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
|
|
|
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
|
if errNewReq != nil {
|
|
return resp, errNewReq
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
|
} else if errTok != nil {
|
|
log.Errorf("vertex executor: access token error: %v", errTok)
|
|
return resp, statusErr{code: 500, msg: "internal server error"}
|
|
}
|
|
applyGeminiHeaders(httpReq, auth)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: httpReq.Header.Clone(),
|
|
Body: body,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
httpResp, errDo := httpClient.Do(httpReq)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
return resp, errDo
|
|
}
|
|
defer func() {
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
b, _ := io.ReadAll(httpResp.Body)
|
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
|
return resp, err
|
|
}
|
|
data, errRead := io.ReadAll(httpResp.Body)
|
|
if errRead != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
return resp, errRead
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
reporter.publish(ctx, parseGeminiUsage(data))
|
|
var param any
|
|
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
|
return resp, nil
|
|
}
|
|
|
|
// ExecuteStream handles SSE streaming for Vertex.
|
|
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
|
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
|
if errCreds != nil {
|
|
return nil, errCreds
|
|
}
|
|
|
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
|
defer reporter.trackFailure(ctx, &err)
|
|
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini")
|
|
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
|
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
|
}
|
|
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
|
body = fixGeminiImageAspectRatio(req.Model, body)
|
|
body = applyPayloadConfig(e.cfg, req.Model, body)
|
|
|
|
baseURL := vertexBaseURL(location)
|
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
|
|
if opts.Alt == "" {
|
|
url = url + "?alt=sse"
|
|
} else {
|
|
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
|
}
|
|
body, _ = sjson.DeleteBytes(body, "session_id")
|
|
|
|
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
|
if errNewReq != nil {
|
|
return nil, errNewReq
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
|
} else if errTok != nil {
|
|
log.Errorf("vertex executor: access token error: %v", errTok)
|
|
return nil, statusErr{code: 500, msg: "internal server error"}
|
|
}
|
|
applyGeminiHeaders(httpReq, auth)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: httpReq.Header.Clone(),
|
|
Body: body,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
httpResp, errDo := httpClient.Do(httpReq)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
return nil, errDo
|
|
}
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
b, _ := io.ReadAll(httpResp.Body)
|
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
|
}
|
|
|
|
out := make(chan cliproxyexecutor.StreamChunk)
|
|
stream = out
|
|
go func() {
|
|
defer close(out)
|
|
defer func() {
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
scanner := bufio.NewScanner(httpResp.Body)
|
|
buf := make([]byte, 20_971_520)
|
|
scanner.Buffer(buf, 20_971_520)
|
|
var param any
|
|
for scanner.Scan() {
|
|
line := scanner.Bytes()
|
|
appendAPIResponseChunk(ctx, e.cfg, line)
|
|
if detail, ok := parseGeminiStreamUsage(line); ok {
|
|
reporter.publish(ctx, detail)
|
|
}
|
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
|
for i := range lines {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
|
}
|
|
}
|
|
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
|
for i := range lines {
|
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
|
}
|
|
if errScan := scanner.Err(); errScan != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errScan)
|
|
reporter.publishFailure(ctx)
|
|
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
|
}
|
|
}()
|
|
return stream, nil
|
|
}
|
|
|
|
// CountTokens calls Vertex countTokens endpoint.
|
|
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
|
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
|
if errCreds != nil {
|
|
return cliproxyexecutor.Response{}, errCreds
|
|
}
|
|
from := opts.SourceFormat
|
|
to := sdktranslator.FromString("gemini")
|
|
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
|
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
|
if budgetOverride != nil {
|
|
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
|
budgetOverride = &norm
|
|
}
|
|
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
|
}
|
|
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
|
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
|
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
|
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
|
|
|
baseURL := vertexBaseURL(location)
|
|
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
|
|
|
|
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
|
if errNewReq != nil {
|
|
return cliproxyexecutor.Response{}, errNewReq
|
|
}
|
|
httpReq.Header.Set("Content-Type", "application/json")
|
|
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
|
httpReq.Header.Set("Authorization", "Bearer "+token)
|
|
} else if errTok != nil {
|
|
log.Errorf("vertex executor: access token error: %v", errTok)
|
|
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
|
}
|
|
applyGeminiHeaders(httpReq, auth)
|
|
|
|
var authID, authLabel, authType, authValue string
|
|
if auth != nil {
|
|
authID = auth.ID
|
|
authLabel = auth.Label
|
|
authType, authValue = auth.AccountInfo()
|
|
}
|
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
|
URL: url,
|
|
Method: http.MethodPost,
|
|
Headers: httpReq.Header.Clone(),
|
|
Body: translatedReq,
|
|
Provider: e.Identifier(),
|
|
AuthID: authID,
|
|
AuthLabel: authLabel,
|
|
AuthType: authType,
|
|
AuthValue: authValue,
|
|
})
|
|
|
|
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
|
httpResp, errDo := httpClient.Do(httpReq)
|
|
if errDo != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errDo)
|
|
return cliproxyexecutor.Response{}, errDo
|
|
}
|
|
defer func() {
|
|
if errClose := httpResp.Body.Close(); errClose != nil {
|
|
log.Errorf("vertex executor: close response body error: %v", errClose)
|
|
}
|
|
}()
|
|
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
b, _ := io.ReadAll(httpResp.Body)
|
|
appendAPIResponseChunk(ctx, e.cfg, b)
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
|
}
|
|
data, errRead := io.ReadAll(httpResp.Body)
|
|
if errRead != nil {
|
|
recordAPIResponseError(ctx, e.cfg, errRead)
|
|
return cliproxyexecutor.Response{}, errRead
|
|
}
|
|
appendAPIResponseChunk(ctx, e.cfg, data)
|
|
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
|
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
|
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
|
}
|
|
count := gjson.GetBytes(data, "totalTokens").Int()
|
|
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
|
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
|
}
|
|
|
|
// Refresh is a no-op for service account based credentials.
|
|
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
|
return auth, nil
|
|
}
|
|
|
|
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
|
func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) {
|
|
if a == nil || a.Metadata == nil {
|
|
return "", "", nil, fmt.Errorf("vertex executor: missing auth metadata")
|
|
}
|
|
if v, ok := a.Metadata["project_id"].(string); ok {
|
|
projectID = strings.TrimSpace(v)
|
|
}
|
|
if projectID == "" {
|
|
// Some service accounts may use "project"; still prefer standard field
|
|
if v, ok := a.Metadata["project"].(string); ok {
|
|
projectID = strings.TrimSpace(v)
|
|
}
|
|
}
|
|
if projectID == "" {
|
|
return "", "", nil, fmt.Errorf("vertex executor: missing project_id in credentials")
|
|
}
|
|
if v, ok := a.Metadata["location"].(string); ok && strings.TrimSpace(v) != "" {
|
|
location = strings.TrimSpace(v)
|
|
} else {
|
|
location = "us-central1"
|
|
}
|
|
var sa map[string]any
|
|
if raw, ok := a.Metadata["service_account"].(map[string]any); ok {
|
|
sa = raw
|
|
}
|
|
if sa == nil {
|
|
return "", "", nil, fmt.Errorf("vertex executor: missing service_account in credentials")
|
|
}
|
|
normalized, errNorm := vertexauth.NormalizeServiceAccountMap(sa)
|
|
if errNorm != nil {
|
|
return "", "", nil, fmt.Errorf("vertex executor: %w", errNorm)
|
|
}
|
|
saJSON, errMarshal := json.Marshal(normalized)
|
|
if errMarshal != nil {
|
|
return "", "", nil, fmt.Errorf("vertex executor: marshal service_account failed: %w", errMarshal)
|
|
}
|
|
return projectID, location, saJSON, nil
|
|
}
|
|
|
|
// vertexBaseURL returns the scheme-and-host prefix of the Vertex AI endpoint
// for location.
//
// The location is whitespace-trimmed. An empty location defaults to
// "us-central1". The "global" location maps to the unprefixed host
// "https://aiplatform.googleapis.com", because the global endpoint has no
// "<location>-" host prefix. Every other location maps to
// "https://<location>-aiplatform.googleapis.com".
func vertexBaseURL(location string) string {
	loc := strings.TrimSpace(location)
	switch loc {
	case "":
		loc = "us-central1"
	case "global":
		return "https://aiplatform.googleapis.com"
	}
	return fmt.Sprintf("https://%s-aiplatform.googleapis.com", loc)
}
|
|
|
|
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
|
|
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
|
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
|
}
|
|
// Use cloud-platform scope for Vertex AI.
|
|
creds, errCreds := google.CredentialsFromJSON(ctx, saJSON, "https://www.googleapis.com/auth/cloud-platform")
|
|
if errCreds != nil {
|
|
return "", fmt.Errorf("vertex executor: parse service account json failed: %w", errCreds)
|
|
}
|
|
tok, errTok := creds.TokenSource.Token()
|
|
if errTok != nil {
|
|
return "", fmt.Errorf("vertex executor: get access token failed: %w", errTok)
|
|
}
|
|
return tok.AccessToken, nil
|
|
}
|