Mirror of https://github.com/router-for-me/CLIProxyAPI.git, synced 2026-02-02 20:40:52 +08:00
feat(vertex): add support for Vertex AI Gemini authentication and execution
Introduce Vertex AI Gemini integration with support for service account-based authentication, credential storage, and import functionality. Add a new executor for Vertex AI requests, covering both the non-streaming and streaming paths, and register it with the core manager. Extend the CLI with a `--vertex-import` flag for importing service account keys.
@@ -57,6 +57,7 @@ func main() {
 	var iflowLogin bool
 	var noBrowser bool
 	var projectID string
+	var vertexImport string
 	var configPath string
 	var password string
 
@@ -69,6 +70,7 @@ func main() {
 	flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
 	flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
 	flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
+	flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
 	flag.StringVar(&password, "password", "", "")
 
 	flag.CommandLine.Usage = func() {
@@ -417,7 +419,10 @@ func main() {
 
 	// Handle different command modes based on the provided flags.
 
-	if login {
+	if vertexImport != "" {
+		// Handle Vertex service account import
+		cmd.DoVertexImport(cfg, vertexImport)
+	} else if login {
 		// Handle Google/Gemini login
 		cmd.DoLogin(cfg, projectID, options)
 	} else if codexLogin {
internal/auth/vertex/keyutil.go (new file, 208 lines)
@@ -0,0 +1,208 @@
package vertex

import (
	"crypto/rsa"
	"crypto/x509"
	"encoding/base64"
	"encoding/json"
	"encoding/pem"
	"fmt"
	"strings"
)

// NormalizeServiceAccountJSON normalizes the given JSON-encoded service account payload.
// It returns the normalized JSON (with sanitized private_key) or, if normalization fails,
// the original bytes and the encountered error.
func NormalizeServiceAccountJSON(raw []byte) ([]byte, error) {
	if len(raw) == 0 {
		return raw, nil
	}
	var payload map[string]any
	if err := json.Unmarshal(raw, &payload); err != nil {
		return raw, err
	}
	normalized, err := NormalizeServiceAccountMap(payload)
	if err != nil {
		return raw, err
	}
	out, err := json.Marshal(normalized)
	if err != nil {
		return raw, err
	}
	return out, nil
}

// NormalizeServiceAccountMap returns a copy of the given service account map with
// a sanitized private_key field that is guaranteed to contain a valid RSA PRIVATE KEY PEM block.
func NormalizeServiceAccountMap(sa map[string]any) (map[string]any, error) {
	if sa == nil {
		return nil, fmt.Errorf("service account payload is empty")
	}
	pk, _ := sa["private_key"].(string)
	if strings.TrimSpace(pk) == "" {
		return nil, fmt.Errorf("service account missing private_key")
	}
	normalized, err := sanitizePrivateKey(pk)
	if err != nil {
		return nil, err
	}
	clone := make(map[string]any, len(sa))
	for k, v := range sa {
		clone[k] = v
	}
	clone["private_key"] = normalized
	return clone, nil
}

func sanitizePrivateKey(raw string) (string, error) {
	pk := strings.ReplaceAll(raw, "\r\n", "\n")
	pk = strings.ReplaceAll(pk, "\r", "\n")
	pk = stripANSIEscape(pk)
	pk = strings.ToValidUTF8(pk, "")
	pk = strings.TrimSpace(pk)

	normalized := pk
	if block, _ := pem.Decode([]byte(pk)); block == nil {
		// Attempt to reconstruct from the textual payload.
		if reconstructed, err := rebuildPEM(pk); err == nil {
			normalized = reconstructed
		} else {
			return "", fmt.Errorf("private_key is not valid pem: %w", err)
		}
	}

	block, _ := pem.Decode([]byte(normalized))
	if block == nil {
		return "", fmt.Errorf("private_key pem decode failed")
	}

	rsaBlock, err := ensureRSAPrivateKey(block)
	if err != nil {
		return "", err
	}
	return string(pem.EncodeToMemory(rsaBlock)), nil
}

func ensureRSAPrivateKey(block *pem.Block) (*pem.Block, error) {
	if block == nil {
		return nil, fmt.Errorf("pem block is nil")
	}

	if block.Type == "RSA PRIVATE KEY" {
		if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
			return nil, fmt.Errorf("private_key invalid rsa: %w", err)
		}
		return block, nil
	}

	if block.Type == "PRIVATE KEY" {
		key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("private_key invalid pkcs8: %w", err)
		}
		rsaKey, ok := key.(*rsa.PrivateKey)
		if !ok {
			return nil, fmt.Errorf("private_key is not an RSA key")
		}
		der := x509.MarshalPKCS1PrivateKey(rsaKey)
		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
	}

	// Attempt auto-detection: try PKCS#1 first, then PKCS#8.
	if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
		der := x509.MarshalPKCS1PrivateKey(rsaKey)
		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
	}
	if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
		if rsaKey, ok := key.(*rsa.PrivateKey); ok {
			der := x509.MarshalPKCS1PrivateKey(rsaKey)
			return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
		}
	}
	return nil, fmt.Errorf("private_key uses unsupported format")
}

func rebuildPEM(raw string) (string, error) {
	kind := "PRIVATE KEY"
	if strings.Contains(raw, "RSA PRIVATE KEY") {
		kind = "RSA PRIVATE KEY"
	}
	header := "-----BEGIN " + kind + "-----"
	footer := "-----END " + kind + "-----"
	start := strings.Index(raw, header)
	end := strings.Index(raw, footer)
	if start < 0 || end <= start {
		return "", fmt.Errorf("missing pem markers")
	}
	body := raw[start+len(header) : end]
	payload := filterBase64(body)
	if payload == "" {
		return "", fmt.Errorf("private_key base64 payload empty")
	}
	der, err := base64.StdEncoding.DecodeString(payload)
	if err != nil {
		return "", fmt.Errorf("private_key base64 decode failed: %w", err)
	}
	block := &pem.Block{Type: kind, Bytes: der}
	return string(pem.EncodeToMemory(block)), nil
}

func filterBase64(s string) string {
	var b strings.Builder
	for _, r := range s {
		switch {
		case r >= 'A' && r <= 'Z':
			b.WriteRune(r)
		case r >= 'a' && r <= 'z':
			b.WriteRune(r)
		case r >= '0' && r <= '9':
			b.WriteRune(r)
		case r == '+' || r == '/' || r == '=':
			b.WriteRune(r)
		default:
			// skip
		}
	}
	return b.String()
}

func stripANSIEscape(s string) string {
	in := []rune(s)
	var out []rune
	for i := 0; i < len(in); i++ {
		r := in[i]
		if r != 0x1b {
			out = append(out, r)
			continue
		}
		if i+1 >= len(in) {
			continue
		}
		next := in[i+1]
		switch next {
		case ']':
			i += 2
			for i < len(in) {
				if in[i] == 0x07 {
					break
				}
				if in[i] == 0x1b && i+1 < len(in) && in[i+1] == '\\' {
					i++
					break
				}
				i++
			}
		case '[':
			i += 2
			for i < len(in) {
				if (in[i] >= 'A' && in[i] <= 'Z') || (in[i] >= 'a' && in[i] <= 'z') {
					break
				}
				i++
			}
		default:
			// skip single ESC
		}
	}
	return string(out)
}
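Not part of the commit: a minimal usage sketch of the key-sanitizing helper above. The key file path is illustrative, and since the package lives under internal/ it is only importable from code inside this repository.

package main

import (
	"fmt"
	"os"

	vertex "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
)

func main() {
	// Illustrative path; replace with a real downloaded service account key.
	raw, err := os.ReadFile("service-account.json")
	if err != nil {
		fmt.Fprintln(os.Stderr, "read key:", err)
		os.Exit(1)
	}
	// NormalizeServiceAccountJSON re-encodes the payload with a sanitized,
	// PEM-valid RSA private_key; on failure it returns the original bytes and the error.
	normalized, err := vertex.NormalizeServiceAccountJSON(raw)
	if err != nil {
		fmt.Fprintln(os.Stderr, "normalize:", err)
		os.Exit(1)
	}
	fmt.Printf("normalized %d bytes of service account JSON\n", len(normalized))
}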
internal/auth/vertex/vertex_credentials.go (new file, 66 lines)
@@ -0,0 +1,66 @@
// Package vertex provides token storage for Google Vertex AI Gemini via service account credentials.
// It serialises service account JSON into an auth file that is consumed by the runtime executor.
package vertex

import (
	"encoding/json"
	"fmt"
	"os"
	"path/filepath"

	"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
	log "github.com/sirupsen/logrus"
)

// VertexCredentialStorage stores the service account JSON for Vertex AI access.
// The content is persisted verbatim under the "service_account" key, together with
// helper fields for project, location and email to improve logging and discovery.
type VertexCredentialStorage struct {
	// ServiceAccount holds the parsed service account JSON content.
	ServiceAccount map[string]any `json:"service_account"`

	// ProjectID is derived from the service account JSON (project_id).
	ProjectID string `json:"project_id"`

	// Email is the client_email from the service account JSON.
	Email string `json:"email"`

	// Location optionally sets a default region (e.g., us-central1) for Vertex endpoints.
	Location string `json:"location,omitempty"`

	// Type is the provider identifier stored alongside credentials. Always "vertex".
	Type string `json:"type"`
}

// SaveTokenToFile writes the credential payload to the given file path in JSON format.
// It ensures the parent directory exists and logs the operation for transparency.
func (s *VertexCredentialStorage) SaveTokenToFile(authFilePath string) error {
	misc.LogSavingCredentials(authFilePath)
	if s == nil {
		return fmt.Errorf("vertex credential: storage is nil")
	}
	if s.ServiceAccount == nil {
		return fmt.Errorf("vertex credential: service account content is empty")
	}
	// Ensure we tag the file with the provider type.
	s.Type = "vertex"

	if err := os.MkdirAll(filepath.Dir(authFilePath), 0o700); err != nil {
		return fmt.Errorf("vertex credential: create directory failed: %w", err)
	}
	f, err := os.Create(authFilePath)
	if err != nil {
		return fmt.Errorf("vertex credential: create file failed: %w", err)
	}
	defer func() {
		if errClose := f.Close(); errClose != nil {
			log.Errorf("vertex credential: failed to close file: %v", errClose)
		}
	}()
	enc := json.NewEncoder(f)
	enc.SetIndent("", " ")
	if err = enc.Encode(s); err != nil {
		return fmt.Errorf("vertex credential: encode failed: %w", err)
	}
	return nil
}
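Not part of the commit: a minimal sketch of persisting a credential with VertexCredentialStorage directly. The service account values and output path are illustrative; the importer in the next file normally goes through the SDK token store instead of writing the file itself.

package main

import (
	"log"

	vertex "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
)

func main() {
	// Illustrative values; in the commit these come from the imported key JSON.
	sa := map[string]any{
		"type":         "service_account",
		"project_id":   "my-project",
		"client_email": "svc@my-project.iam.gserviceaccount.com",
		"private_key":  "-----BEGIN PRIVATE KEY-----\n...\n-----END PRIVATE KEY-----\n",
	}
	storage := &vertex.VertexCredentialStorage{
		ServiceAccount: sa,
		ProjectID:      "my-project",
		Email:          "svc@my-project.iam.gserviceaccount.com",
		Location:       "us-central1",
	}
	// SaveTokenToFile creates the parent directory, tags Type as "vertex",
	// and writes indented JSON to the given path (path is illustrative).
	if err := storage.SaveTokenToFile("auths/vertex-my-project.json"); err != nil {
		log.Fatalf("save credential: %v", err)
	}
}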
internal/cmd/vertex_import.go (new file, 123 lines)
@@ -0,0 +1,123 @@
// Package cmd contains CLI helpers. This file implements importing a Vertex AI
// service account JSON into the auth store as a dedicated "vertex" credential.
package cmd

import (
	"context"
	"encoding/json"
	"fmt"
	"os"
	"strings"

	"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
	sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
	coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	log "github.com/sirupsen/logrus"
)

// DoVertexImport imports a Google Cloud service account key JSON and persists
// it as a "vertex" provider credential. The file content is embedded in the auth
// file to allow portable deployment across stores.
func DoVertexImport(cfg *config.Config, keyPath string) {
	if cfg == nil {
		cfg = &config.Config{}
	}
	if resolved, errResolve := util.ResolveAuthDir(cfg.AuthDir); errResolve == nil {
		cfg.AuthDir = resolved
	}
	rawPath := strings.TrimSpace(keyPath)
	if rawPath == "" {
		log.Fatalf("vertex-import: missing service account key path")
		return
	}
	data, errRead := os.ReadFile(rawPath)
	if errRead != nil {
		log.Fatalf("vertex-import: read file failed: %v", errRead)
		return
	}
	var sa map[string]any
	if errUnmarshal := json.Unmarshal(data, &sa); errUnmarshal != nil {
		log.Fatalf("vertex-import: invalid service account json: %v", errUnmarshal)
		return
	}
	// Validate and normalize private_key before saving
	normalizedSA, errFix := vertex.NormalizeServiceAccountMap(sa)
	if errFix != nil {
		log.Fatalf("vertex-import: %v", errFix)
		return
	}
	sa = normalizedSA
	email, _ := sa["client_email"].(string)
	projectID, _ := sa["project_id"].(string)
	if strings.TrimSpace(projectID) == "" {
		log.Fatalf("vertex-import: project_id missing in service account json")
		return
	}
	if strings.TrimSpace(email) == "" {
		// Keep empty email but warn
		log.Warn("vertex-import: client_email missing in service account json")
	}
	// Default location if not provided by user. Can be edited in the saved file later.
	location := "us-central1"

	fileName := fmt.Sprintf("vertex-%s.json", sanitizeFilePart(projectID))
	// Build auth record
	storage := &vertex.VertexCredentialStorage{
		ServiceAccount: sa,
		ProjectID:      projectID,
		Email:          email,
		Location:       location,
	}
	metadata := map[string]any{
		"service_account": sa,
		"project_id":      projectID,
		"email":           email,
		"location":        location,
		"type":            "vertex",
		"label":           labelForVertex(projectID, email),
	}
	record := &coreauth.Auth{
		ID:       fileName,
		Provider: "vertex",
		FileName: fileName,
		Storage:  storage,
		Metadata: metadata,
	}

	store := sdkAuth.GetTokenStore()
	if setter, ok := store.(interface{ SetBaseDir(string) }); ok {
		setter.SetBaseDir(cfg.AuthDir)
	}
	path, errSave := store.Save(context.Background(), record)
	if errSave != nil {
		log.Fatalf("vertex-import: save credential failed: %v", errSave)
		return
	}
	fmt.Printf("Vertex credentials imported: %s\n", path)
}

func sanitizeFilePart(s string) string {
	out := strings.TrimSpace(s)
	replacers := []string{"/", "_", "\\", "_", ":", "_", " ", "-"}
	for i := 0; i < len(replacers); i += 2 {
		out = strings.ReplaceAll(out, replacers[i], replacers[i+1])
	}
	return out
}

func labelForVertex(projectID, email string) string {
	p := strings.TrimSpace(projectID)
	e := strings.TrimSpace(email)
	if p != "" && e != "" {
		return fmt.Sprintf("%s (%s)", p, e)
	}
	if p != "" {
		return p
	}
	if e != "" {
		return e
	}
	return "vertex"
}
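Not part of the commit: a sketch of driving the importer programmatically, equivalent to starting the binary with `--vertex-import`. The auth directory and key path are illustrative; DoVertexImport resolves the directory, validates the key, and writes vertex-<project_id>.json through the SDK token store, exiting via log.Fatalf on any validation or save error.

package main

import (
	"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)

func main() {
	// AuthDir is illustrative; it is resolved by util.ResolveAuthDir inside the importer.
	cfg := &config.Config{AuthDir: "~/.cli-proxy-api"}
	cmd.DoVertexImport(cfg, "/path/to/service-account.json")
}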
internal/runtime/executor/gemini_vertex_executor.go (new file, 421 lines)
@@ -0,0 +1,421 @@
// Package executor contains provider executors. This file implements the Vertex AI
// Gemini executor that talks to Google Vertex AI endpoints using service account
// credentials imported by the CLI.
package executor

import (
	"bufio"
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"io"
	"net/http"
	"strings"

	vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
	cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
	sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
	log "github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
	"golang.org/x/oauth2/google"
)

const (
	// vertexAPIVersion aligns with current public Vertex Generative AI API.
	vertexAPIVersion = "v1"
)

// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
type GeminiVertexExecutor struct {
	cfg *config.Config
}

// NewGeminiVertexExecutor constructs the Vertex executor.
func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
	return &GeminiVertexExecutor{cfg: cfg}
}

// Identifier returns provider key for manager routing.
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }

// PrepareRequest is a no-op for Vertex.
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
	return nil
}

// Execute handles non-streaming requests.
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
	projectID, location, saJSON, errCreds := vertexCreds(auth)
	if errCreds != nil {
		return resp, errCreds
	}

	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
	defer reporter.trackFailure(ctx, &err)

	from := opts.SourceFormat
	to := sdktranslator.FromString("gemini")
	body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
	if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
		if budgetOverride != nil {
			norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
			budgetOverride = &norm
		}
		body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
	}
	body = util.StripThinkingConfigIfUnsupported(req.Model, body)
	body = fixGeminiImageAspectRatio(req.Model, body)

	action := "generateContent"
	if req.Metadata != nil {
		if a, _ := req.Metadata["action"].(string); a == "countTokens" {
			action = "countTokens"
		}
	}
	baseURL := vertexBaseURL(location)
	url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
	if opts.Alt != "" && action != "countTokens" {
		url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
	}
	body, _ = sjson.DeleteBytes(body, "session_id")

	httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if errNewReq != nil {
		return resp, errNewReq
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if token, errTok := vertexAccessToken(ctx, saJSON); errTok == nil && token != "" {
		httpReq.Header.Set("Authorization", "Bearer "+token)
	} else if errTok != nil {
		log.Errorf("vertex executor: access token error: %v", errTok)
		return resp, statusErr{code: 500, msg: "internal server error"}
	}
	applyGeminiHeaders(httpReq, auth)

	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      body,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})

	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, errDo := httpClient.Do(httpReq)
	if errDo != nil {
		recordAPIResponseError(ctx, e.cfg, errDo)
		return resp, errDo
	}
	defer func() {
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("vertex executor: close response body error: %v", errClose)
		}
	}()
	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		b, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, b)
		log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
		err = statusErr{code: httpResp.StatusCode, msg: string(b)}
		return resp, err
	}
	data, errRead := io.ReadAll(httpResp.Body)
	if errRead != nil {
		recordAPIResponseError(ctx, e.cfg, errRead)
		return resp, errRead
	}
	appendAPIResponseChunk(ctx, e.cfg, data)
	reporter.publish(ctx, parseGeminiUsage(data))
	var param any
	out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
	resp = cliproxyexecutor.Response{Payload: []byte(out)}
	return resp, nil
}

// ExecuteStream handles SSE streaming for Vertex.
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
	projectID, location, saJSON, errCreds := vertexCreds(auth)
	if errCreds != nil {
		return nil, errCreds
	}

	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
	defer reporter.trackFailure(ctx, &err)

	from := opts.SourceFormat
	to := sdktranslator.FromString("gemini")
	body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
	if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
		if budgetOverride != nil {
			norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
			budgetOverride = &norm
		}
		body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
	}
	body = util.StripThinkingConfigIfUnsupported(req.Model, body)
	body = fixGeminiImageAspectRatio(req.Model, body)

	baseURL := vertexBaseURL(location)
	url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
	if opts.Alt == "" {
		url = url + "?alt=sse"
	} else {
		url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
	}
	body, _ = sjson.DeleteBytes(body, "session_id")

	httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
	if errNewReq != nil {
		return nil, errNewReq
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if token, errTok := vertexAccessToken(ctx, saJSON); errTok == nil && token != "" {
		httpReq.Header.Set("Authorization", "Bearer "+token)
	} else if errTok != nil {
		log.Errorf("vertex executor: access token error: %v", errTok)
		return nil, statusErr{code: 500, msg: "internal server error"}
	}
	applyGeminiHeaders(httpReq, auth)

	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      body,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})

	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, errDo := httpClient.Do(httpReq)
	if errDo != nil {
		recordAPIResponseError(ctx, e.cfg, errDo)
		return nil, errDo
	}
	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		b, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, b)
		log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("vertex executor: close response body error: %v", errClose)
		}
		return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
	}

	out := make(chan cliproxyexecutor.StreamChunk)
	stream = out
	go func() {
		defer close(out)
		defer func() {
			if errClose := httpResp.Body.Close(); errClose != nil {
				log.Errorf("vertex executor: close response body error: %v", errClose)
			}
		}()
		scanner := bufio.NewScanner(httpResp.Body)
		buf := make([]byte, 20_971_520)
		scanner.Buffer(buf, 20_971_520)
		var param any
		for scanner.Scan() {
			line := scanner.Bytes()
			appendAPIResponseChunk(ctx, e.cfg, line)
			if detail, ok := parseGeminiStreamUsage(line); ok {
				reporter.publish(ctx, detail)
			}
			lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
			for i := range lines {
				out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
			}
		}
		lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), &param)
		for i := range lines {
			out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
		}
		if errScan := scanner.Err(); errScan != nil {
			recordAPIResponseError(ctx, e.cfg, errScan)
			reporter.publishFailure(ctx)
			out <- cliproxyexecutor.StreamChunk{Err: errScan}
		}
	}()
	return stream, nil
}

// CountTokens calls Vertex countTokens endpoint.
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	projectID, location, saJSON, errCreds := vertexCreds(auth)
	if errCreds != nil {
		return cliproxyexecutor.Response{}, errCreds
	}
	from := opts.SourceFormat
	to := sdktranslator.FromString("gemini")
	translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
	if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
		if budgetOverride != nil {
			norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
			budgetOverride = &norm
		}
		translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
	}
	translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
	translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
	respCtx := context.WithValue(ctx, "alt", opts.Alt)
	translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
	translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
	translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")

	baseURL := vertexBaseURL(location)
	url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")

	httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
	if errNewReq != nil {
		return cliproxyexecutor.Response{}, errNewReq
	}
	httpReq.Header.Set("Content-Type", "application/json")
	if token, errTok := vertexAccessToken(ctx, saJSON); errTok == nil && token != "" {
		httpReq.Header.Set("Authorization", "Bearer "+token)
	} else if errTok != nil {
		log.Errorf("vertex executor: access token error: %v", errTok)
		return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
	}
	applyGeminiHeaders(httpReq, auth)

	var authID, authLabel, authType, authValue string
	if auth != nil {
		authID = auth.ID
		authLabel = auth.Label
		authType, authValue = auth.AccountInfo()
	}
	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
		URL:       url,
		Method:    http.MethodPost,
		Headers:   httpReq.Header.Clone(),
		Body:      translatedReq,
		Provider:  e.Identifier(),
		AuthID:    authID,
		AuthLabel: authLabel,
		AuthType:  authType,
		AuthValue: authValue,
	})

	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
	httpResp, errDo := httpClient.Do(httpReq)
	if errDo != nil {
		recordAPIResponseError(ctx, e.cfg, errDo)
		return cliproxyexecutor.Response{}, errDo
	}
	defer func() {
		if errClose := httpResp.Body.Close(); errClose != nil {
			log.Errorf("vertex executor: close response body error: %v", errClose)
		}
	}()
	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		b, _ := io.ReadAll(httpResp.Body)
		appendAPIResponseChunk(ctx, e.cfg, b)
		log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
		return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
	}
	data, errRead := io.ReadAll(httpResp.Body)
	if errRead != nil {
		recordAPIResponseError(ctx, e.cfg, errRead)
		return cliproxyexecutor.Response{}, errRead
	}
	appendAPIResponseChunk(ctx, e.cfg, data)
	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
		log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
		return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
	}
	count := gjson.GetBytes(data, "totalTokens").Int()
	out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
	return cliproxyexecutor.Response{Payload: []byte(out)}, nil
}

// Refresh is a no-op for service account based credentials.
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
	return auth, nil
}

// vertexCreds extracts project, location and raw service account JSON from auth metadata.
func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) {
	if a == nil || a.Metadata == nil {
		return "", "", nil, fmt.Errorf("vertex executor: missing auth metadata")
	}
	if v, ok := a.Metadata["project_id"].(string); ok {
		projectID = strings.TrimSpace(v)
	}
	if projectID == "" {
		// Some service accounts may use "project"; still prefer standard field
		if v, ok := a.Metadata["project"].(string); ok {
			projectID = strings.TrimSpace(v)
		}
	}
	if projectID == "" {
		return "", "", nil, fmt.Errorf("vertex executor: missing project_id in credentials")
	}
	if v, ok := a.Metadata["location"].(string); ok && strings.TrimSpace(v) != "" {
		location = strings.TrimSpace(v)
	} else {
		location = "us-central1"
	}
	var sa map[string]any
	if raw, ok := a.Metadata["service_account"].(map[string]any); ok {
		sa = raw
	}
	if sa == nil {
		return "", "", nil, fmt.Errorf("vertex executor: missing service_account in credentials")
	}
	normalized, errNorm := vertexauth.NormalizeServiceAccountMap(sa)
	if errNorm != nil {
		return "", "", nil, fmt.Errorf("vertex executor: %w", errNorm)
	}
	saJSON, errMarshal := json.Marshal(normalized)
	if errMarshal != nil {
		return "", "", nil, fmt.Errorf("vertex executor: marshal service_account failed: %w", errMarshal)
	}
	return projectID, location, saJSON, nil
}

func vertexBaseURL(location string) string {
	loc := strings.TrimSpace(location)
	if loc == "" {
		loc = "us-central1"
	}
	return fmt.Sprintf("https://%s-aiplatform.googleapis.com", loc)
}

func vertexAccessToken(ctx context.Context, saJSON []byte) (string, error) {
	// Use cloud-platform scope for Vertex AI.
	creds, errCreds := google.CredentialsFromJSON(ctx, saJSON, "https://www.googleapis.com/auth/cloud-platform")
	if errCreds != nil {
		return "", fmt.Errorf("vertex executor: parse service account json failed: %w", errCreds)
	}
	tok, errTok := creds.TokenSource.Token()
	if errTok != nil {
		return "", fmt.Errorf("vertex executor: get access token failed: %w", errTok)
	}
	return tok.AccessToken, nil
}
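Not part of the commit: a small sketch of the regional endpoint shape the executor assembles, mirroring vertexBaseURL plus the URL construction in Execute. The project, location and model values are illustrative; the real code derives them from the imported credential and the incoming request.

package main

import "fmt"

func main() {
	// Shape: https://{location}-aiplatform.googleapis.com/v1/projects/{project}/locations/{location}/publishers/google/models/{model}:{action}
	project := "my-project"     // illustrative
	location := "us-central1"   // executor default when none is stored
	model := "gemini-2.0-flash" // illustrative model identifier
	action := "generateContent" // or "countTokens" / "streamGenerateContent"

	base := fmt.Sprintf("https://%s-aiplatform.googleapis.com", location)
	url := fmt.Sprintf("%s/v1/projects/%s/locations/%s/publishers/google/models/%s:%s",
		base, project, location, model, action)
	fmt.Println(url)
}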
@@ -324,6 +324,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
 	switch strings.ToLower(a.Provider) {
 	case "gemini":
 		s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
+	case "vertex":
+		s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg))
 	case "gemini-cli":
 		s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg))
 	case "aistudio":
@@ -619,6 +621,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
 	switch provider {
 	case "gemini":
 		models = registry.GetGeminiModels()
+	case "vertex":
+		// Vertex AI Gemini supports the same model identifiers as Gemini.
+		models = registry.GetGeminiModels()
 	case "gemini-cli":
 		models = registry.GetGeminiCLIModels()
 	case "aistudio":