mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 21:10:51 +08:00
Introduce Vertex AI Gemini integration with support for service-account-based authentication, credential storage, and import functionality. Add a new executor for Vertex AI requests, covering both the execution and streaming paths, and integrate it into the core manager. Enhance the CLI with a `--vertex-import` flag for importing service account keys.
209 lines
5.1 KiB
Go
209 lines
5.1 KiB
Go
package vertex
|
|
|
|
import (
|
|
"crypto/rsa"
|
|
"crypto/x509"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"encoding/pem"
|
|
"fmt"
|
|
"strings"
|
|
)
|
|
|
|
// NormalizeServiceAccountJSON normalizes the given JSON-encoded service account payload.
|
|
// It returns the normalized JSON (with sanitized private_key) or, if normalization fails,
|
|
// the original bytes and the encountered error.
|
|
func NormalizeServiceAccountJSON(raw []byte) ([]byte, error) {
|
|
if len(raw) == 0 {
|
|
return raw, nil
|
|
}
|
|
var payload map[string]any
|
|
if err := json.Unmarshal(raw, &payload); err != nil {
|
|
return raw, err
|
|
}
|
|
normalized, err := NormalizeServiceAccountMap(payload)
|
|
if err != nil {
|
|
return raw, err
|
|
}
|
|
out, err := json.Marshal(normalized)
|
|
if err != nil {
|
|
return raw, err
|
|
}
|
|
return out, nil
|
|
}
|
|
|
|
// NormalizeServiceAccountMap returns a copy of the given service account map with
|
|
// a sanitized private_key field that is guaranteed to contain a valid RSA PRIVATE KEY PEM block.
|
|
func NormalizeServiceAccountMap(sa map[string]any) (map[string]any, error) {
|
|
if sa == nil {
|
|
return nil, fmt.Errorf("service account payload is empty")
|
|
}
|
|
pk, _ := sa["private_key"].(string)
|
|
if strings.TrimSpace(pk) == "" {
|
|
return nil, fmt.Errorf("service account missing private_key")
|
|
}
|
|
normalized, err := sanitizePrivateKey(pk)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
clone := make(map[string]any, len(sa))
|
|
for k, v := range sa {
|
|
clone[k] = v
|
|
}
|
|
clone["private_key"] = normalized
|
|
return clone, nil
|
|
}
|
|
|
|
func sanitizePrivateKey(raw string) (string, error) {
|
|
pk := strings.ReplaceAll(raw, "\r\n", "\n")
|
|
pk = strings.ReplaceAll(pk, "\r", "\n")
|
|
pk = stripANSIEscape(pk)
|
|
pk = strings.ToValidUTF8(pk, "")
|
|
pk = strings.TrimSpace(pk)
|
|
|
|
normalized := pk
|
|
if block, _ := pem.Decode([]byte(pk)); block == nil {
|
|
// Attempt to reconstruct from the textual payload.
|
|
if reconstructed, err := rebuildPEM(pk); err == nil {
|
|
normalized = reconstructed
|
|
} else {
|
|
return "", fmt.Errorf("private_key is not valid pem: %w", err)
|
|
}
|
|
}
|
|
|
|
block, _ := pem.Decode([]byte(normalized))
|
|
if block == nil {
|
|
return "", fmt.Errorf("private_key pem decode failed")
|
|
}
|
|
|
|
rsaBlock, err := ensureRSAPrivateKey(block)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
return string(pem.EncodeToMemory(rsaBlock)), nil
|
|
}
|
|
|
|
// ensureRSAPrivateKey checks that block holds an RSA private key and returns a
// block labelled "RSA PRIVATE KEY".
//
//   - "RSA PRIVATE KEY": the bytes must parse as PKCS#1; the input block itself
//     is returned unchanged.
//   - "PRIVATE KEY": the bytes must parse as PKCS#8 and contain an RSA key,
//     which is re-encoded as PKCS#1 in a new block.
//   - any other label: PKCS#1 is tried first, then PKCS#8; on success the key is
//     re-encoded as PKCS#1 in a new block.
//
// A nil block, unparsable bytes, or a non-RSA key yield an error.
func ensureRSAPrivateKey(block *pem.Block) (*pem.Block, error) {
	if block == nil {
		return nil, fmt.Errorf("pem block is nil")
	}

	// toPKCS1 wraps an RSA key in a fresh PKCS#1 PEM block.
	toPKCS1 := func(key *rsa.PrivateKey) *pem.Block {
		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(key)}
	}

	switch block.Type {
	case "RSA PRIVATE KEY":
		_, err := x509.ParsePKCS1PrivateKey(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("private_key invalid rsa: %w", err)
		}
		return block, nil

	case "PRIVATE KEY":
		parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("private_key invalid pkcs8: %w", err)
		}
		rsaKey, isRSA := parsed.(*rsa.PrivateKey)
		if !isRSA {
			return nil, fmt.Errorf("private_key is not an RSA key")
		}
		return toPKCS1(rsaKey), nil
	}

	// Unrecognized label: probe the bytes, PKCS#1 first, then PKCS#8.
	if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
		return toPKCS1(rsaKey), nil
	}
	if parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
		if rsaKey, isRSA := parsed.(*rsa.PrivateKey); isRSA {
			return toPKCS1(rsaKey), nil
		}
	}
	return nil, fmt.Errorf("private_key uses unsupported format")
}
|
|
|
|
func rebuildPEM(raw string) (string, error) {
|
|
kind := "PRIVATE KEY"
|
|
if strings.Contains(raw, "RSA PRIVATE KEY") {
|
|
kind = "RSA PRIVATE KEY"
|
|
}
|
|
header := "-----BEGIN " + kind + "-----"
|
|
footer := "-----END " + kind + "-----"
|
|
start := strings.Index(raw, header)
|
|
end := strings.Index(raw, footer)
|
|
if start < 0 || end <= start {
|
|
return "", fmt.Errorf("missing pem markers")
|
|
}
|
|
body := raw[start+len(header) : end]
|
|
payload := filterBase64(body)
|
|
if payload == "" {
|
|
return "", fmt.Errorf("private_key base64 payload empty")
|
|
}
|
|
der, err := base64.StdEncoding.DecodeString(payload)
|
|
if err != nil {
|
|
return "", fmt.Errorf("private_key base64 decode failed: %w", err)
|
|
}
|
|
block := &pem.Block{Type: kind, Bytes: der}
|
|
return string(pem.EncodeToMemory(block)), nil
|
|
}
|
|
|
|
// filterBase64 returns s reduced to the characters of the standard base64
// alphabet (A-Z, a-z, 0-9, '+', '/') plus the padding character '='.
//
// A backslash followed by 'n', 'r' or 't' is treated as an escaped whitespace
// sequence and dropped as a unit. Backslash is never part of base64, so such a
// pair can only be an escaping artifact; keeping its letter would inject a
// bogus character into the payload. Any other backslash is dropped on its own.
func filterBase64(s string) string {
	var b strings.Builder
	b.Grow(len(s))
	// Byte-wise iteration is sufficient: every byte of a multi-byte UTF-8
	// sequence is >= 0x80 and therefore never matches the ASCII cases below.
	for i := 0; i < len(s); i++ {
		c := s[i]
		switch {
		case c == '\\':
			if i+1 < len(s) {
				switch s[i+1] {
				case 'n', 'r', 't':
					i++ // skip the escaped letter together with the backslash
				}
			}
		case c >= 'A' && c <= 'Z',
			c >= 'a' && c <= 'z',
			c >= '0' && c <= '9',
			c == '+', c == '/', c == '=':
			b.WriteByte(c)
		}
	}
	return b.String()
}
|
|
|
|
// stripANSIEscape removes terminal escape sequences from s and returns the
// remaining text.
//
// Handled forms, each introduced by ESC (0x1b):
//   - OSC ("ESC ]"): dropped up to and including the terminating BEL (0x07) or
//     "ESC \"; an unterminated sequence consumes the rest of the input.
//   - CSI ("ESC ["): dropped up to and including the final byte, which is any
//     character in the range 0x40-0x7E; an unterminated sequence consumes the
//     rest of the input.
//   - Any other ESC: only the ESC itself is dropped; the following character
//     is kept.
//
// The string is processed rune by rune, so invalid UTF-8 bytes in s come out
// as U+FFFD.
func stripANSIEscape(s string) string {
	in := []rune(s)
	out := make([]rune, 0, len(in))
	for i := 0; i < len(in); i++ {
		r := in[i]
		if r != 0x1b {
			out = append(out, r)
			continue
		}
		if i+1 >= len(in) {
			continue // trailing ESC with nothing after it
		}
		switch in[i+1] {
		case ']':
			i += 2
			for i < len(in) {
				if in[i] == 0x07 {
					break
				}
				if in[i] == 0x1b && i+1 < len(in) && in[i+1] == '\\' {
					i++
					break
				}
				i++
			}
		case '[':
			i += 2
			// Stop on any CSI final byte, not only on letters. Sequences
			// such as the bracketed-paste markers "ESC[200~" end in '~';
			// treating only letters as terminators would keep scanning and
			// swallow the text that follows, up to its first letter.
			for i < len(in) && (in[i] < 0x40 || in[i] > 0x7e) {
				i++
			}
		default:
			// skip single ESC
		}
		// The outer loop's i++ steps past the terminator that ended the scan.
	}
	return string(out)
}
|