mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 12:20:52 +08:00
feat(vertex): add support for Vertex AI Gemini authentication and execution
Introduce Vertex AI Gemini integration with support for service account-based authentication, credential storage, and import functionality. Added new executor for Vertex AI requests, including execution and streaming paths, and integrated it into the core manager. Enhanced CLI with `--vertex-import` flag for importing service account keys.
This commit is contained in:
208
internal/auth/vertex/keyutil.go
Normal file
208
internal/auth/vertex/keyutil.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NormalizeServiceAccountJSON normalizes the given JSON-encoded service account payload.
|
||||
// It returns the normalized JSON (with sanitized private_key) or, if normalization fails,
|
||||
// the original bytes and the encountered error.
|
||||
func NormalizeServiceAccountJSON(raw []byte) ([]byte, error) {
|
||||
if len(raw) == 0 {
|
||||
return raw, nil
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||
return raw, err
|
||||
}
|
||||
normalized, err := NormalizeServiceAccountMap(payload)
|
||||
if err != nil {
|
||||
return raw, err
|
||||
}
|
||||
out, err := json.Marshal(normalized)
|
||||
if err != nil {
|
||||
return raw, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// NormalizeServiceAccountMap returns a copy of the given service account map with
|
||||
// a sanitized private_key field that is guaranteed to contain a valid RSA PRIVATE KEY PEM block.
|
||||
func NormalizeServiceAccountMap(sa map[string]any) (map[string]any, error) {
|
||||
if sa == nil {
|
||||
return nil, fmt.Errorf("service account payload is empty")
|
||||
}
|
||||
pk, _ := sa["private_key"].(string)
|
||||
if strings.TrimSpace(pk) == "" {
|
||||
return nil, fmt.Errorf("service account missing private_key")
|
||||
}
|
||||
normalized, err := sanitizePrivateKey(pk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clone := make(map[string]any, len(sa))
|
||||
for k, v := range sa {
|
||||
clone[k] = v
|
||||
}
|
||||
clone["private_key"] = normalized
|
||||
return clone, nil
|
||||
}
|
||||
|
||||
func sanitizePrivateKey(raw string) (string, error) {
|
||||
pk := strings.ReplaceAll(raw, "\r\n", "\n")
|
||||
pk = strings.ReplaceAll(pk, "\r", "\n")
|
||||
pk = stripANSIEscape(pk)
|
||||
pk = strings.ToValidUTF8(pk, "")
|
||||
pk = strings.TrimSpace(pk)
|
||||
|
||||
normalized := pk
|
||||
if block, _ := pem.Decode([]byte(pk)); block == nil {
|
||||
// Attempt to reconstruct from the textual payload.
|
||||
if reconstructed, err := rebuildPEM(pk); err == nil {
|
||||
normalized = reconstructed
|
||||
} else {
|
||||
return "", fmt.Errorf("private_key is not valid pem: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
block, _ := pem.Decode([]byte(normalized))
|
||||
if block == nil {
|
||||
return "", fmt.Errorf("private_key pem decode failed")
|
||||
}
|
||||
|
||||
rsaBlock, err := ensureRSAPrivateKey(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(pem.EncodeToMemory(rsaBlock)), nil
|
||||
}
|
||||
|
||||
func ensureRSAPrivateKey(block *pem.Block) (*pem.Block, error) {
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("pem block is nil")
|
||||
}
|
||||
|
||||
if block.Type == "RSA PRIVATE KEY" {
|
||||
if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
|
||||
return nil, fmt.Errorf("private_key invalid rsa: %w", err)
|
||||
}
|
||||
return block, nil
|
||||
}
|
||||
|
||||
if block.Type == "PRIVATE KEY" {
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("private_key invalid pkcs8: %w", err)
|
||||
}
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("private_key is not an RSA key")
|
||||
}
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
|
||||
}
|
||||
|
||||
// Attempt auto-detection: try PKCS#1 first, then PKCS#8.
|
||||
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
|
||||
}
|
||||
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
|
||||
if rsaKey, ok := key.(*rsa.PrivateKey); ok {
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("private_key uses unsupported format")
|
||||
}
|
||||
|
||||
func rebuildPEM(raw string) (string, error) {
|
||||
kind := "PRIVATE KEY"
|
||||
if strings.Contains(raw, "RSA PRIVATE KEY") {
|
||||
kind = "RSA PRIVATE KEY"
|
||||
}
|
||||
header := "-----BEGIN " + kind + "-----"
|
||||
footer := "-----END " + kind + "-----"
|
||||
start := strings.Index(raw, header)
|
||||
end := strings.Index(raw, footer)
|
||||
if start < 0 || end <= start {
|
||||
return "", fmt.Errorf("missing pem markers")
|
||||
}
|
||||
body := raw[start+len(header) : end]
|
||||
payload := filterBase64(body)
|
||||
if payload == "" {
|
||||
return "", fmt.Errorf("private_key base64 payload empty")
|
||||
}
|
||||
der, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("private_key base64 decode failed: %w", err)
|
||||
}
|
||||
block := &pem.Block{Type: kind, Bytes: der}
|
||||
return string(pem.EncodeToMemory(block)), nil
|
||||
}
|
||||
|
||||
func filterBase64(s string) string {
|
||||
var b strings.Builder
|
||||
for _, r := range s {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
b.WriteRune(r)
|
||||
case r >= 'a' && r <= 'z':
|
||||
b.WriteRune(r)
|
||||
case r >= '0' && r <= '9':
|
||||
b.WriteRune(r)
|
||||
case r == '+' || r == '/' || r == '=':
|
||||
b.WriteRune(r)
|
||||
default:
|
||||
// skip
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func stripANSIEscape(s string) string {
|
||||
in := []rune(s)
|
||||
var out []rune
|
||||
for i := 0; i < len(in); i++ {
|
||||
r := in[i]
|
||||
if r != 0x1b {
|
||||
out = append(out, r)
|
||||
continue
|
||||
}
|
||||
if i+1 >= len(in) {
|
||||
continue
|
||||
}
|
||||
next := in[i+1]
|
||||
switch next {
|
||||
case ']':
|
||||
i += 2
|
||||
for i < len(in) {
|
||||
if in[i] == 0x07 {
|
||||
break
|
||||
}
|
||||
if in[i] == 0x1b && i+1 < len(in) && in[i+1] == '\\' {
|
||||
i++
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
case '[':
|
||||
i += 2
|
||||
for i < len(in) {
|
||||
if (in[i] >= 'A' && in[i] <= 'Z') || (in[i] >= 'a' && in[i] <= 'z') {
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
default:
|
||||
// skip single ESC
|
||||
}
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
66
internal/auth/vertex/vertex_credentials.go
Normal file
66
internal/auth/vertex/vertex_credentials.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Package vertex provides token storage for Google Vertex AI Gemini via service account credentials.
|
||||
// It serialises service account JSON into an auth file that is consumed by the runtime executor.
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// VertexCredentialStorage stores the service account JSON for Vertex AI access.
|
||||
// The content is persisted verbatim under the "service_account" key, together with
|
||||
// helper fields for project, location and email to improve logging and discovery.
|
||||
type VertexCredentialStorage struct {
|
||||
// ServiceAccount holds the parsed service account JSON content.
|
||||
ServiceAccount map[string]any `json:"service_account"`
|
||||
|
||||
// ProjectID is derived from the service account JSON (project_id).
|
||||
ProjectID string `json:"project_id"`
|
||||
|
||||
// Email is the client_email from the service account JSON.
|
||||
Email string `json:"email"`
|
||||
|
||||
// Location optionally sets a default region (e.g., us-central1) for Vertex endpoints.
|
||||
Location string `json:"location,omitempty"`
|
||||
|
||||
// Type is the provider identifier stored alongside credentials. Always "vertex".
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
|
||||
// It ensures the parent directory exists and logs the operation for transparency.
|
||||
func (s *VertexCredentialStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
if s == nil {
|
||||
return fmt.Errorf("vertex credential: storage is nil")
|
||||
}
|
||||
if s.ServiceAccount == nil {
|
||||
return fmt.Errorf("vertex credential: service account content is empty")
|
||||
}
|
||||
// Ensure we tag the file with the provider type.
|
||||
s.Type = "vertex"
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0o700); err != nil {
|
||||
return fmt.Errorf("vertex credential: create directory failed: %w", err)
|
||||
}
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("vertex credential: create file failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := f.Close(); errClose != nil {
|
||||
log.Errorf("vertex credential: failed to close file: %v", errClose)
|
||||
}
|
||||
}()
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err = enc.Encode(s); err != nil {
|
||||
return fmt.Errorf("vertex credential: encode failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
Reference in New Issue
Block a user