diff --git a/cmd/server/main.go b/cmd/server/main.go index 78259928..0a941b7c 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -57,6 +57,7 @@ func main() { var iflowLogin bool var noBrowser bool var projectID string + var vertexImport string var configPath string var password string @@ -69,6 +70,7 @@ func main() { flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth") flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)") flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path") + flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file") flag.StringVar(&password, "password", "", "") flag.CommandLine.Usage = func() { @@ -417,7 +419,10 @@ func main() { // Handle different command modes based on the provided flags. - if login { + if vertexImport != "" { + // Handle Vertex service account import + cmd.DoVertexImport(cfg, vertexImport) + } else if login { // Handle Google/Gemini login cmd.DoLogin(cfg, projectID, options) } else if codexLogin { diff --git a/internal/auth/vertex/keyutil.go b/internal/auth/vertex/keyutil.go new file mode 100644 index 00000000..a10ade17 --- /dev/null +++ b/internal/auth/vertex/keyutil.go @@ -0,0 +1,208 @@ +package vertex + +import ( + "crypto/rsa" + "crypto/x509" + "encoding/base64" + "encoding/json" + "encoding/pem" + "fmt" + "strings" +) + +// NormalizeServiceAccountJSON normalizes the given JSON-encoded service account payload. +// It returns the normalized JSON (with sanitized private_key) or, if normalization fails, +// the original bytes and the encountered error. 
+func NormalizeServiceAccountJSON(raw []byte) ([]byte, error) {
+	// Empty input is passed through untouched: there is nothing to normalize.
+	if len(raw) == 0 {
+		return raw, nil
+	}
+	var payload map[string]any
+	if err := json.Unmarshal(raw, &payload); err != nil {
+		return raw, err
+	}
+	normalized, err := NormalizeServiceAccountMap(payload)
+	if err != nil {
+		return raw, err
+	}
+	out, err := json.Marshal(normalized)
+	if err != nil {
+		return raw, err
+	}
+	return out, nil
+}
+
+// NormalizeServiceAccountMap returns a copy of the given service account map with
+// a sanitized private_key field that is guaranteed to contain a valid RSA PRIVATE KEY PEM block.
+//
+// The copy is shallow: top-level entries are copied into a new map and only
+// "private_key" is replaced, so the caller's map is never modified. An error is
+// returned when sa is nil, when private_key is absent, not a string or blank,
+// or when the key cannot be sanitized.
+func NormalizeServiceAccountMap(sa map[string]any) (map[string]any, error) {
+	if sa == nil {
+		return nil, fmt.Errorf("service account payload is empty")
+	}
+	// A missing or non-string private_key yields "" here and is rejected below.
+	pk, _ := sa["private_key"].(string)
+	if strings.TrimSpace(pk) == "" {
+		return nil, fmt.Errorf("service account missing private_key")
+	}
+	normalized, err := sanitizePrivateKey(pk)
+	if err != nil {
+		return nil, err
+	}
+	clone := make(map[string]any, len(sa))
+	for k, v := range sa {
+		clone[k] = v
+	}
+	clone["private_key"] = normalized
+	return clone, nil
+}
+
+// sanitizePrivateKey cleans up a textual private key and re-encodes it as a
+// PKCS#1 "RSA PRIVATE KEY" PEM block.
+//
+// Line endings are normalized to "\n", terminal escape sequences and invalid
+// UTF-8 are removed and surrounding whitespace is trimmed. If the result still
+// does not decode as PEM, the block is rebuilt from its BEGIN/END markers.
+func sanitizePrivateKey(raw string) (string, error) {
+	pk := strings.ReplaceAll(raw, "\r\n", "\n")
+	pk = strings.ReplaceAll(pk, "\r", "\n")
+	pk = stripANSIEscape(pk)
+	pk = strings.ToValidUTF8(pk, "")
+	pk = strings.TrimSpace(pk)
+
+	normalized := pk
+	if block, _ := pem.Decode([]byte(pk)); block == nil {
+		// Attempt to reconstruct from the textual payload.
+		if reconstructed, err := rebuildPEM(pk); err == nil {
+			normalized = reconstructed
+		} else {
+			return "", fmt.Errorf("private_key is not valid pem: %w", err)
+		}
+	}
+
+	block, _ := pem.Decode([]byte(normalized))
+	if block == nil {
+		return "", fmt.Errorf("private_key pem decode failed")
+	}
+
+	rsaBlock, err := ensureRSAPrivateKey(block)
+	if err != nil {
+		return "", err
+	}
+	return string(pem.EncodeToMemory(rsaBlock)), nil
+}
+
+// ensureRSAPrivateKey validates that block holds an RSA private key and returns
+// it as a PKCS#1 "RSA PRIVATE KEY" block.
+//
+// A block already typed "RSA PRIVATE KEY" is parsed for validation and returned
+// as is. A "PRIVATE KEY" block is parsed as PKCS#8 and converted. For any other
+// block type the bytes are tried as PKCS#1 first and then as PKCS#8. Non-RSA
+// keys and unparsable payloads produce an error.
+func ensureRSAPrivateKey(block *pem.Block) (*pem.Block, error) {
+	if block == nil {
+		return nil, fmt.Errorf("pem block is nil")
+	}
+
+	if block.Type == "RSA PRIVATE KEY" {
+		if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
+			return nil, fmt.Errorf("private_key invalid rsa: %w", err)
+		}
+		return block, nil
+	}
+
+	if block.Type == "PRIVATE KEY" {
+		key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
+		if err != nil {
+			return nil, fmt.Errorf("private_key invalid pkcs8: %w", err)
+		}
+		rsaKey, ok := key.(*rsa.PrivateKey)
+		if !ok {
+			return nil, fmt.Errorf("private_key is not an RSA key")
+		}
+		der := x509.MarshalPKCS1PrivateKey(rsaKey)
+		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
+	}
+
+	// Attempt auto-detection: try PKCS#1 first, then PKCS#8.
+	if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
+		der := x509.MarshalPKCS1PrivateKey(rsaKey)
+		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
+	}
+	if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
+		if rsaKey, ok := key.(*rsa.PrivateKey); ok {
+			der := x509.MarshalPKCS1PrivateKey(rsaKey)
+			return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
+		}
+	}
+	return nil, fmt.Errorf("private_key uses unsupported format")
+}
+
+// rebuildPEM reconstructs a PEM block from text whose BEGIN/END markers are
+// present but whose body is malformed (for example stray characters or broken
+// line wrapping).
+//
+// The block type is "RSA PRIVATE KEY" when that phrase occurs anywhere in raw
+// and "PRIVATE KEY" otherwise. Everything between the markers that is not a
+// standard base64 character is dropped before decoding.
+func rebuildPEM(raw string) (string, error) {
+	kind := "PRIVATE KEY"
+	if strings.Contains(raw, "RSA PRIVATE KEY") {
+		kind = "RSA PRIVATE KEY"
+	}
+	header := "-----BEGIN " + kind + "-----"
+	footer := "-----END " + kind + "-----"
+	start := strings.Index(raw, header)
+	end := strings.Index(raw, footer)
+	if start < 0 || end <= start {
+		return "", fmt.Errorf("missing pem markers")
+	}
+	body := raw[start+len(header) : end]
+	payload := filterBase64(body)
+	if payload == "" {
+		return "", fmt.Errorf("private_key base64 payload empty")
+	}
+	der, err := base64.StdEncoding.DecodeString(payload)
+	if err != nil {
+		return "", fmt.Errorf("private_key base64 decode failed: %w", err)
+	}
+	block := &pem.Block{Type: kind, Bytes: der}
+	return string(pem.EncodeToMemory(block)), nil
+}
+
+// filterBase64 returns s with every character removed that is not part of the
+// standard base64 alphabet (A-Z, a-z, 0-9, '+', '/') or the '=' padding.
+func filterBase64(s string) string {
+	var b strings.Builder
+	for _, r := range s {
+		switch {
+		case r >= 'A' && r <= 'Z':
+			b.WriteRune(r)
+		case r >= 'a' && r <= 'z':
+			b.WriteRune(r)
+		case r >= '0' && r <= '9':
+			b.WriteRune(r)
+		case r == '+' || r == '/' || r == '=':
+			b.WriteRune(r)
+		default:
+			// skip
+		}
+	}
+	return b.String()
+}
+
+// stripANSIEscape removes terminal escape sequences from s.
+//
+// Three forms are recognised after an ESC (0x1b):
+//   - OSC ("ESC ]"): skipped up to and including BEL (0x07) or "ESC \".
+//   - CSI ("ESC ["): skipped up to and including the final byte, which is any
+//     character in the range 0x40-0x7E.
+//   - anything else: only the ESC itself is dropped.
+//
+// An unterminated sequence consumes the rest of the input.
+func stripANSIEscape(s string) string {
+	in := []rune(s)
+	var out []rune
+	for i := 0; i < len(in); i++ {
+		r := in[i]
+		if r != 0x1b {
+			out = append(out, r)
+			continue
+		}
+		if i+1 >= len(in) {
+			// Trailing lone ESC: drop it.
+			continue
+		}
+		next := in[i+1]
+		switch next {
+		case ']':
+			i += 2
+			for i < len(in) {
+				if in[i] == 0x07 {
+					break
+				}
+				if in[i] == 0x1b && i+1 < len(in) && in[i+1] == '\\' {
+					i++
+					break
+				}
+				i++
+			}
+		case '[':
+			i += 2
+			for i < len(in) {
+				// A CSI sequence ends at its final byte (0x40-0x7E). Matching
+				// only letters would run past sequences such as the
+				// bracketed-paste markers "ESC[200~" / "ESC[201~" and eat the
+				// start of the key that follows them.
+				if in[i] >= 0x40 && in[i] <= 0x7e {
+					break
+				}
+				i++
+			}
+		default:
+			// skip single ESC
+		}
+	}
+	return string(out)
+}
diff --git a/internal/auth/vertex/vertex_credentials.go b/internal/auth/vertex/vertex_credentials.go
new file mode 100644
index 00000000..4853d340
--- /dev/null
+++ b/internal/auth/vertex/vertex_credentials.go
@@ -0,0 +1,66 @@
+// Package vertex provides token storage for Google Vertex AI Gemini via service account credentials.
+// It serialises service account JSON into an auth file that is consumed by the runtime executor.
+package vertex
+
+import (
+	"encoding/json"
+	"fmt"
+	"os"
+	"path/filepath"
+
+	"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
+	log "github.com/sirupsen/logrus"
+)
+
+// VertexCredentialStorage stores the service account JSON for Vertex AI access.
+// The content is persisted verbatim under the "service_account" key, together with
+// helper fields for project, location and email to improve logging and discovery.
+type VertexCredentialStorage struct {
+	// ServiceAccount holds the parsed service account JSON content.
+	ServiceAccount map[string]any `json:"service_account"`
+
+	// ProjectID is derived from the service account JSON (project_id).
+	ProjectID string `json:"project_id"`
+
+	// Email is the client_email from the service account JSON.
+	Email string `json:"email"`
+
+	// Location optionally sets a default region (e.g., us-central1) for Vertex endpoints.
+	Location string `json:"location,omitempty"`
+
+	// Type is the provider identifier stored alongside credentials. Always "vertex".
+	Type string `json:"type"`
+}
+
+// SaveTokenToFile writes the credential payload to the given file path in JSON format.
+// It ensures the parent directory exists and logs the operation for transparency.
+func (s *VertexCredentialStorage) SaveTokenToFile(authFilePath string) (err error) {
+	misc.LogSavingCredentials(authFilePath)
+	if s == nil {
+		return fmt.Errorf("vertex credential: storage is nil")
+	}
+	if s.ServiceAccount == nil {
+		return fmt.Errorf("vertex credential: service account content is empty")
+	}
+	// Ensure we tag the file with the provider type.
+	s.Type = "vertex"
+
+	if errMkdir := os.MkdirAll(filepath.Dir(authFilePath), 0o700); errMkdir != nil {
+		return fmt.Errorf("vertex credential: create directory failed: %w", errMkdir)
+	}
+	// The payload embeds the service account private key, so a newly created
+	// file is restricted to its owner (0600) instead of the 0666-before-umask
+	// mode that os.Create would request. The mode argument only applies when
+	// the file is created; an existing file is truncated and keeps its mode.
+	f, errOpen := os.OpenFile(authFilePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o600)
+	if errOpen != nil {
+		return fmt.Errorf("vertex credential: create file failed: %w", errOpen)
+	}
+	defer func() {
+		errClose := f.Close()
+		if errClose == nil {
+			return
+		}
+		if err == nil {
+			// Nothing failed before the close, so report the close error to
+			// the caller rather than claiming the save succeeded.
+			err = fmt.Errorf("vertex credential: close file failed: %w", errClose)
+			return
+		}
+		// An earlier error is already being returned; just record this one.
+		log.Errorf("vertex credential: failed to close file: %v", errClose)
+	}()
+	enc := json.NewEncoder(f)
+	enc.SetIndent("", " ")
+	if errEncode := enc.Encode(s); errEncode != nil {
+		return fmt.Errorf("vertex credential: encode failed: %w", errEncode)
+	}
+	return nil
+}
diff --git a/internal/cmd/vertex_import.go b/internal/cmd/vertex_import.go
new file mode 100644
index 00000000..ebb32d0c
--- /dev/null
+++ b/internal/cmd/vertex_import.go
@@ -0,0 +1,123 @@
+// Package cmd contains CLI helpers. This file implements importing a Vertex AI
+// service account JSON into the auth store as a dedicated "vertex" credential.
+package cmd
+
+import (
+	"context"
+	"encoding/json"
+	"fmt"
+	"os"
+	"strings"
+
+	"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
+	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+	"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+	sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
+	coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+	log "github.com/sirupsen/logrus"
+)
+
+// DoVertexImport imports a Google Cloud service account key JSON and persists
+// it as a "vertex" provider credential. The file content is embedded in the auth
+// file to allow portable deployment across stores.
+func DoVertexImport(cfg *config.Config, keyPath string) {
+	if cfg == nil {
+		cfg = &config.Config{}
+	}
+	// Resolving the auth directory is best effort: on failure the configured
+	// value is used unchanged.
+	authDir, errDir := util.ResolveAuthDir(cfg.AuthDir)
+	if errDir == nil {
+		cfg.AuthDir = authDir
+	}
+
+	keyFile := strings.TrimSpace(keyPath)
+	if keyFile == "" {
+		log.Fatalf("vertex-import: missing service account key path")
+		return
+	}
+	content, errFile := os.ReadFile(keyFile)
+	if errFile != nil {
+		log.Fatalf("vertex-import: read file failed: %v", errFile)
+		return
+	}
+	var parsed map[string]any
+	if errJSON := json.Unmarshal(content, &parsed); errJSON != nil {
+		log.Fatalf("vertex-import: invalid service account json: %v", errJSON)
+		return
+	}
+
+	// The private key is validated and normalized before anything is stored.
+	account, errKey := vertex.NormalizeServiceAccountMap(parsed)
+	if errKey != nil {
+		log.Fatalf("vertex-import: %v", errKey)
+		return
+	}
+
+	projectID, _ := account["project_id"].(string)
+	if strings.TrimSpace(projectID) == "" {
+		log.Fatalf("vertex-import: project_id missing in service account json")
+		return
+	}
+	email, _ := account["client_email"].(string)
+	if strings.TrimSpace(email) == "" {
+		// An empty email is tolerated; the operator is only warned.
+		log.Warn("vertex-import: client_email missing in service account json")
+	}
+
+	// The region is not taken from user input here; this default can be
+	// edited in the saved file afterwards.
+	const location = "us-central1"
+
+	// The same name serves as both the record ID and the file name.
+	fileName := fmt.Sprintf("vertex-%s.json", sanitizeFilePart(projectID))
+	record := &coreauth.Auth{
+		ID:       fileName,
+		Provider: "vertex",
+		FileName: fileName,
+		Storage: &vertex.VertexCredentialStorage{
+			ServiceAccount: account,
+			ProjectID:      projectID,
+			Email:          email,
+			Location:       location,
+		},
+		Metadata: map[string]any{
+			"service_account": account,
+			"project_id":      projectID,
+			"email":           email,
+			"location":        location,
+			"type":            "vertex",
+			"label":           labelForVertex(projectID, email),
+		},
+	}
+
+	store := sdkAuth.GetTokenStore()
+	// Stores that expose SetBaseDir are pointed at the configured auth directory.
+	if dirSetter, ok := store.(interface{ SetBaseDir(string) }); ok {
+		dirSetter.SetBaseDir(cfg.AuthDir)
+	}
+	savedPath, errSave := store.Save(context.Background(), record)
+	if errSave != nil {
+		log.Fatalf("vertex-import: save credential failed: %v", errSave)
+		return
+	}
+	fmt.Printf("Vertex credentials imported: %s\n", savedPath)
+}
+
+// sanitizeFilePart trims s and rewrites characters that are awkward in file
+// names: '/', '\' and ':' become '_', and a space becomes '-'.
+func sanitizeFilePart(s string) string {
+	replacer := strings.NewReplacer(
+		"/", "_",
+		"\\", "_",
+		":", "_",
+		" ", "-",
+	)
+	return replacer.Replace(strings.TrimSpace(s))
+}
+
+// labelForVertex builds a display label from the trimmed project ID and email:
+// "project (email)" when both are set, whichever one is set otherwise, and
+// "vertex" when both are empty.
+func labelForVertex(projectID, email string) string {
+	project := strings.TrimSpace(projectID)
+	account := strings.TrimSpace(email)
+	switch {
+	case project != "" && account != "":
+		return fmt.Sprintf("%s (%s)", project, account)
+	case project != "":
+		return project
+	case account != "":
+		return account
+	default:
+		return "vertex"
+	}
+}
diff --git a/internal/runtime/executor/gemini_vertex_executor.go b/internal/runtime/executor/gemini_vertex_executor.go
new file mode 100644
index 00000000..4e606390
--- /dev/null
+++ b/internal/runtime/executor/gemini_vertex_executor.go
@@ -0,0 +1,421 @@
+// Package executor contains provider executors. This file implements the Vertex AI
+// Gemini executor that talks to Google Vertex AI endpoints using service account
+// credentials imported by the CLI.
+package executor
+
+import (
+	"bufio"
+	"bytes"
+	"context"
+	"encoding/json"
+	"fmt"
+	"io"
+	"net/http"
+	"strings"
+
+	vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
+	"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
+	"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
+	cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
+	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
+	sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
+	log "github.com/sirupsen/logrus"
+	"github.com/tidwall/gjson"
+	"github.com/tidwall/sjson"
+	"golang.org/x/oauth2/google"
+)
+
+// vertexAPIVersion is the version path segment of Vertex request URLs; it
+// tracks the current public Vertex Generative AI API.
+const vertexAPIVersion = "v1"
+
+// GeminiVertexExecutor issues requests against Vertex AI Gemini endpoints on
+// behalf of service account credentials.
+type GeminiVertexExecutor struct {
+	// cfg is the configuration supplied to NewGeminiVertexExecutor.
+	cfg *config.Config
+}
+
+// NewGeminiVertexExecutor returns a Vertex executor bound to cfg.
+func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
+	exec := new(GeminiVertexExecutor)
+	exec.cfg = cfg
+	return exec
+}
+
+// Identifier reports the provider key ("vertex") used for manager routing.
+func (e *GeminiVertexExecutor) Identifier() string {
+	const providerKey = "vertex"
+	return providerKey
+}
+
+// PrepareRequest does nothing for Vertex and always returns nil.
+func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
+	return nil
+}
+
+// Execute handles non-streaming requests.
+func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
+	// Project, location and service account JSON all come from the auth metadata.
+	projectID, location, saJSON, errCreds := vertexCreds(auth)
+	if errCreds != nil {
+		return resp, errCreds
+	}
+
+	// The deferred call is given a pointer to the named result err, which is
+	// why err is a named return value.
+	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
+	defer reporter.trackFailure(ctx, &err)
+
+	// Translate the caller's payload to the Gemini request format, then apply
+	// the thinking-config and image aspect ratio adjustments.
+	from := opts.SourceFormat
+	to := sdktranslator.FromString("gemini")
+	body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
+	if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
+		if budgetOverride != nil {
+			norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
+			budgetOverride = &norm
+		}
+		body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
+	}
+	body = util.StripThinkingConfigIfUnsupported(req.Model, body)
+	body = fixGeminiImageAspectRatio(req.Model, body)
+
+	// Request metadata may switch the call from generateContent to countTokens.
+	action := "generateContent"
+	if req.Metadata != nil {
+		if a, _ := req.Metadata["action"].(string); a == "countTokens" {
+			action = "countTokens"
+		}
+	}
+	baseURL := vertexBaseURL(location)
+	url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
+	if opts.Alt != "" && action != "countTokens" {
+		url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
+	}
+	// session_id is removed from the body before it is sent upstream; a
+	// failure to delete it is ignored.
+	body, _ = sjson.DeleteBytes(body, "session_id")
+
+	httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+	if errNewReq != nil {
+		return resp, errNewReq
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+	// A token error is logged in detail but surfaced to the caller only as a
+	// generic 500. An empty token with no error leaves the Authorization
+	// header unset.
+	if token, errTok := vertexAccessToken(ctx, saJSON); errTok == nil && token != "" {
+		httpReq.Header.Set("Authorization", "Bearer "+token)
+	} else if errTok != nil {
+		log.Errorf("vertex executor: access token error: %v", errTok)
+		return resp, statusErr{code: 500, msg: "internal server error"}
+	}
+	applyGeminiHeaders(httpReq, auth)
+
+	var authID, authLabel, authType, authValue string
+	if auth != nil {
+		authID = auth.ID
+		authLabel = auth.Label
+		authType, authValue = auth.AccountInfo()
+	}
+	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
+		URL:       url,
+		Method:    http.MethodPost,
+		Headers:   httpReq.Header.Clone(),
+		Body:      body,
+		Provider:  e.Identifier(),
+		AuthID:    authID,
+		AuthLabel: authLabel,
+		AuthType:  authType,
+		AuthValue: authValue,
+	})
+
+	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
+	httpResp, errDo := httpClient.Do(httpReq)
+	if errDo != nil {
+		recordAPIResponseError(ctx, e.cfg, errDo)
+		return resp, errDo
+	}
+	defer func() {
+		if errClose := httpResp.Body.Close(); errClose != nil {
+			log.Errorf("vertex executor: close response body error: %v", errClose)
+		}
+	}()
+	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
+	// Non-2xx responses are returned as a statusErr carrying the upstream body.
+	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
+		b, _ := io.ReadAll(httpResp.Body)
+		appendAPIResponseChunk(ctx, e.cfg, b)
+		log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
+		err = statusErr{code: httpResp.StatusCode, msg: string(b)}
+		return resp, err
+	}
+	data, errRead := io.ReadAll(httpResp.Body)
+	if errRead != nil {
+		recordAPIResponseError(ctx, e.cfg, errRead)
+		return resp, errRead
+	}
+	appendAPIResponseChunk(ctx, e.cfg, data)
+	reporter.publish(ctx, parseGeminiUsage(data))
+	// param is handed to the translator by address.
+	var param any
+	out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
+	resp = cliproxyexecutor.Response{Payload: []byte(out)}
+	return resp, nil
+}
+
+// ExecuteStream handles SSE streaming for Vertex.
+func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
+	projectID, location, saJSON, errCreds := vertexCreds(auth)
+	if errCreds != nil {
+		return nil, errCreds
+	}
+
+	// The deferred call is given a pointer to the named result err.
+	reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
+	defer reporter.trackFailure(ctx, &err)
+
+	from := opts.SourceFormat
+	to := sdktranslator.FromString("gemini")
+	body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
+	if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
+		if budgetOverride != nil {
+			norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
+			budgetOverride = &norm
+		}
+		body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
+	}
+	body = util.StripThinkingConfigIfUnsupported(req.Model, body)
+	body = fixGeminiImageAspectRatio(req.Model, body)
+
+	baseURL := vertexBaseURL(location)
+	url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
+	// SSE is requested unless the caller asked for a specific alt format.
+	if opts.Alt == "" {
+		url = url + "?alt=sse"
+	} else {
+		url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
+	}
+	body, _ = sjson.DeleteBytes(body, "session_id")
+
+	httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
+	if errNewReq != nil {
+		return nil, errNewReq
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+	// A token error is logged in detail but surfaced only as a generic 500.
+	if token, errTok := vertexAccessToken(ctx, saJSON); errTok == nil && token != "" {
+		httpReq.Header.Set("Authorization", "Bearer "+token)
+	} else if errTok != nil {
+		log.Errorf("vertex executor: access token error: %v", errTok)
+		return nil, statusErr{code: 500, msg: "internal server error"}
+	}
+	applyGeminiHeaders(httpReq, auth)
+
+	var authID, authLabel, authType, authValue string
+	if auth != nil {
+		authID = auth.ID
+		authLabel = auth.Label
+		authType, authValue = auth.AccountInfo()
+	}
+	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
+		URL:       url,
+		Method:    http.MethodPost,
+		Headers:   httpReq.Header.Clone(),
+		Body:      body,
+		Provider:  e.Identifier(),
+		AuthID:    authID,
+		AuthLabel: authLabel,
+		AuthType:  authType,
+		AuthValue: authValue,
+	})
+
+	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
+	httpResp, errDo := httpClient.Do(httpReq)
+	if errDo != nil {
+		recordAPIResponseError(ctx, e.cfg, errDo)
+		return nil, errDo
+	}
+	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
+	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
+		b, _ := io.ReadAll(httpResp.Body)
+		appendAPIResponseChunk(ctx, e.cfg, b)
+		log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
+		// No defer is registered on this path, so the body is closed here.
+		if errClose := httpResp.Body.Close(); errClose != nil {
+			log.Errorf("vertex executor: close response body error: %v", errClose)
+		}
+		return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
+	}
+
+	out := make(chan cliproxyexecutor.StreamChunk)
+	stream = out
+	// The goroutine owns the response body and the channel: it closes both
+	// when the upstream stream ends.
+	go func() {
+		defer close(out)
+		defer func() {
+			if errClose := httpResp.Body.Close(); errClose != nil {
+				log.Errorf("vertex executor: close response body error: %v", errClose)
+			}
+		}()
+		scanner := bufio.NewScanner(httpResp.Body)
+		// Allow lines of up to 20 MiB but let the scanner grow its buffer on
+		// demand instead of allocating the full 20 MiB for every stream.
+		scanner.Buffer(nil, 20_971_520)
+		var param any
+		for scanner.Scan() {
+			line := scanner.Bytes()
+			appendAPIResponseChunk(ctx, e.cfg, line)
+			if detail, ok := parseGeminiStreamUsage(line); ok {
+				reporter.publish(ctx, detail)
+			}
+			// scanner.Bytes is only valid until the next Scan, hence the clone.
+			lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
+			for i := range lines {
+				out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
+			}
+		}
+		// A final "[DONE]" marker is passed through the translator once the
+		// upstream stream has ended.
+		lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), &param)
+		for i := range lines {
+			out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
+		}
+		if errScan := scanner.Err(); errScan != nil {
+			recordAPIResponseError(ctx, e.cfg, errScan)
+			reporter.publishFailure(ctx)
+			out <- cliproxyexecutor.StreamChunk{Err: errScan}
+		}
+	}()
+	return stream, nil
+}
+
+// CountTokens calls Vertex countTokens endpoint.
+//
+// The request is translated like a generation request, then "tools",
+// "generationConfig" and "safetySettings" are removed before it is posted.
+// The "totalTokens" field of the response is passed to the token-count
+// translator. Non-2xx responses are returned as a statusErr carrying the
+// upstream body.
+func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
+	projectID, location, saJSON, errCreds := vertexCreds(auth)
+	if errCreds != nil {
+		return cliproxyexecutor.Response{}, errCreds
+	}
+	from := opts.SourceFormat
+	to := sdktranslator.FromString("gemini")
+	translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
+	if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
+		if budgetOverride != nil {
+			norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
+			budgetOverride = &norm
+		}
+		translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
+	}
+	translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
+	translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
+	// NOTE(review): respCtx carries "alt" but is only used for the HTTP
+	// request; the translator below receives ctx. Confirm which context the
+	// translator is meant to see.
+	respCtx := context.WithValue(ctx, "alt", opts.Alt)
+	translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
+	translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
+	translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
+
+	baseURL := vertexBaseURL(location)
+	url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
+
+	httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
+	if errNewReq != nil {
+		return cliproxyexecutor.Response{}, errNewReq
+	}
+	httpReq.Header.Set("Content-Type", "application/json")
+	if token, errTok := vertexAccessToken(ctx, saJSON); errTok == nil && token != "" {
+		httpReq.Header.Set("Authorization", "Bearer "+token)
+	} else if errTok != nil {
+		log.Errorf("vertex executor: access token error: %v", errTok)
+		return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
+	}
+	applyGeminiHeaders(httpReq, auth)
+
+	var authID, authLabel, authType, authValue string
+	if auth != nil {
+		authID = auth.ID
+		authLabel = auth.Label
+		authType, authValue = auth.AccountInfo()
+	}
+	recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
+		URL:       url,
+		Method:    http.MethodPost,
+		Headers:   httpReq.Header.Clone(),
+		Body:      translatedReq,
+		Provider:  e.Identifier(),
+		AuthID:    authID,
+		AuthLabel: authLabel,
+		AuthType:  authType,
+		AuthValue: authValue,
+	})
+
+	httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
+	httpResp, errDo := httpClient.Do(httpReq)
+	if errDo != nil {
+		recordAPIResponseError(ctx, e.cfg, errDo)
+		return cliproxyexecutor.Response{}, errDo
+	}
+	defer func() {
+		if errClose := httpResp.Body.Close(); errClose != nil {
+			log.Errorf("vertex executor: close response body error: %v", errClose)
+		}
+	}()
+	recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
+	// This is the only status check needed: every non-2xx response returns here.
+	if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
+		b, _ := io.ReadAll(httpResp.Body)
+		appendAPIResponseChunk(ctx, e.cfg, b)
+		log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
+		return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
+	}
+	data, errRead := io.ReadAll(httpResp.Body)
+	if errRead != nil {
+		recordAPIResponseError(ctx, e.cfg, errRead)
+		return cliproxyexecutor.Response{}, errRead
+	}
+	appendAPIResponseChunk(ctx, e.cfg, data)
+	count := gjson.GetBytes(data, "totalTokens").Int()
+	out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
+	return cliproxyexecutor.Response{Payload: []byte(out)}, nil
+}
+
+// Refresh is a no-op for service account based credentials.
+// It returns auth unchanged.
+func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
+	return auth, nil
+}
+
+// vertexCreds extracts project, location and raw service account JSON from auth metadata.
+//
+// The project comes from "project_id", falling back to "project". The location
+// defaults to "us-central1" when absent or blank. The "service_account" map is
+// normalized and re-marshalled on every call. An error is returned when the
+// auth or its metadata is nil, or when the project or service account is
+// missing or invalid.
+func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) {
+	if a == nil || a.Metadata == nil {
+		return "", "", nil, fmt.Errorf("vertex executor: missing auth metadata")
+	}
+	if v, ok := a.Metadata["project_id"].(string); ok {
+		projectID = strings.TrimSpace(v)
+	}
+	if projectID == "" {
+		// Some service accounts may use "project"; still prefer standard field
+		if v, ok := a.Metadata["project"].(string); ok {
+			projectID = strings.TrimSpace(v)
+		}
+	}
+	if projectID == "" {
+		return "", "", nil, fmt.Errorf("vertex executor: missing project_id in credentials")
+	}
+	if v, ok := a.Metadata["location"].(string); ok && strings.TrimSpace(v) != "" {
+		location = strings.TrimSpace(v)
+	} else {
+		location = "us-central1"
+	}
+	// Only a map-typed "service_account" entry is accepted; any other type
+	// leaves sa nil and is reported as missing.
+	var sa map[string]any
+	if raw, ok := a.Metadata["service_account"].(map[string]any); ok {
+		sa = raw
+	}
+	if sa == nil {
+		return "", "", nil, fmt.Errorf("vertex executor: missing service_account in credentials")
+	}
+	normalized, errNorm := vertexauth.NormalizeServiceAccountMap(sa)
+	if errNorm != nil {
+		return "", "", nil, fmt.Errorf("vertex executor: %w", errNorm)
+	}
+	saJSON, errMarshal := json.Marshal(normalized)
+	if errMarshal != nil {
+		return "", "", nil, fmt.Errorf("vertex executor: marshal service_account failed: %w", errMarshal)
+	}
+	return projectID, location, saJSON, nil
+}
+
+// vertexBaseURL returns the regional Vertex endpoint
+// "https://<location>-aiplatform.googleapis.com", using "us-central1" when
+// location is blank.
+func vertexBaseURL(location string) string {
+	loc := strings.TrimSpace(location)
+	if loc == "" {
+		loc = "us-central1"
+	}
+	return fmt.Sprintf("https://%s-aiplatform.googleapis.com", loc)
+}
+
+// vertexAccessToken obtains an access token for the given service account JSON.
+// Nothing is cached here: every call parses the credentials and asks the
+// resulting token source for a token.
+func vertexAccessToken(ctx context.Context, saJSON []byte) (string, error) {
+	// Use cloud-platform scope for Vertex AI.
+	creds, errCreds := google.CredentialsFromJSON(ctx, saJSON, "https://www.googleapis.com/auth/cloud-platform")
+	if errCreds != nil {
+		return "", fmt.Errorf("vertex executor: parse service account json failed: %w", errCreds)
+	}
+	tok, errTok := creds.TokenSource.Token()
+	if errTok != nil {
+		return "", fmt.Errorf("vertex executor: get access token failed: %w", errTok)
+	}
+	return tok.AccessToken, nil
+}
diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go
index 84d15ffe..57ad3295 100644
--- a/sdk/cliproxy/service.go
+++ b/sdk/cliproxy/service.go
@@ -324,6 +324,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
 	switch strings.ToLower(a.Provider) {
 	case "gemini":
 		s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
+	case "vertex":
+		s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg))
 	case "gemini-cli":
 		s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg))
 	case "aistudio":
@@ -619,6 +621,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
 	switch provider {
 	case "gemini":
 		models = registry.GetGeminiModels()
+	case "vertex":
+		// Vertex AI Gemini supports the same model identifiers as Gemini.
+		models = registry.GetGeminiModels()
 	case "gemini-cli":
 		models = registry.GetGeminiCLIModels()
 	case "aistudio":