Files
CLIProxyAPI/internal/auth/vertex/keyutil.go
Luis Pater 717eadf128 feat(vertex): add support for Vertex AI Gemini authentication and execution
Introduce Vertex AI Gemini integration with support for service account-based authentication, credential storage, and import functionality. Added new executor for Vertex AI requests, including execution and streaming paths, and integrated it into the core manager. Enhanced CLI with `--vertex-import` flag for importing service account keys.
2025-11-10 12:23:51 +08:00

209 lines
5.1 KiB
Go

package vertex
import (
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"strings"
)
// NormalizeServiceAccountJSON decodes a JSON service account document, runs it
// through NormalizeServiceAccountMap, and re-encodes the result.
//
// Empty input is returned unchanged with a nil error. If decoding,
// normalization, or re-encoding fails, the untouched input bytes are returned
// together with the error so the caller still has the original payload.
func NormalizeServiceAccountJSON(raw []byte) ([]byte, error) {
	if len(raw) == 0 {
		return raw, nil
	}

	var decoded map[string]any
	err := json.Unmarshal(raw, &decoded)
	if err != nil {
		return raw, err
	}

	sanitized, err := NormalizeServiceAccountMap(decoded)
	if err != nil {
		return raw, err
	}

	encoded, err := json.Marshal(sanitized)
	if err != nil {
		return raw, err
	}
	return encoded, nil
}
// NormalizeServiceAccountMap builds a shallow copy of sa whose "private_key"
// entry has been replaced by the output of sanitizePrivateKey.
//
// It returns an error when sa is nil, when "private_key" is absent, not a
// string, or blank, or when the key cannot be sanitized. The input map is
// never modified.
func NormalizeServiceAccountMap(sa map[string]any) (map[string]any, error) {
	if sa == nil {
		return nil, fmt.Errorf("service account payload is empty")
	}

	// A missing or non-string value yields "" here and is rejected below.
	rawKey, _ := sa["private_key"].(string)
	if len(strings.TrimSpace(rawKey)) == 0 {
		return nil, fmt.Errorf("service account missing private_key")
	}

	cleanKey, err := sanitizePrivateKey(rawKey)
	if err != nil {
		return nil, err
	}

	result := make(map[string]any, len(sa))
	for field, value := range sa {
		if field == "private_key" {
			value = cleanKey
		}
		result[field] = value
	}
	return result, nil
}
// sanitizePrivateKey cleans up a textual private key and returns it as a
// PEM-encoded "RSA PRIVATE KEY" block.
//
// The input has its line endings normalized to "\n", invalid UTF-8 bytes and
// ANSI escape sequences removed, and surrounding whitespace trimmed. If the
// result does not decode as PEM, rebuildPEM is used to reconstruct a block
// from the text. The decoded block is then validated and converted by
// ensureRSAPrivateKey.
//
// An error is returned when no PEM block can be obtained or when the block
// does not hold a usable RSA key.
func sanitizePrivateKey(raw string) (string, error) {
	pk := strings.ReplaceAll(raw, "\r\n", "\n")
	pk = strings.ReplaceAll(pk, "\r", "\n")
	// Drop invalid UTF-8 before stripping escapes: stripANSIEscape converts
	// through []rune, which turns every invalid byte into U+FFFD, so running
	// ToValidUTF8 afterwards would find nothing left to remove.
	pk = strings.ToValidUTF8(pk, "")
	pk = stripANSIEscape(pk)
	pk = strings.TrimSpace(pk)

	block, _ := pem.Decode([]byte(pk))
	if block == nil {
		// Attempt to reconstruct from the textual payload.
		rebuilt, err := rebuildPEM(pk)
		if err != nil {
			return "", fmt.Errorf("private_key is not valid pem: %w", err)
		}
		block, _ = pem.Decode([]byte(rebuilt))
		if block == nil {
			return "", fmt.Errorf("private_key pem decode failed")
		}
	}

	rsaBlock, err := ensureRSAPrivateKey(block)
	if err != nil {
		return "", err
	}
	return string(pem.EncodeToMemory(rsaBlock)), nil
}
// ensureRSAPrivateKey validates that block holds an RSA private key and
// returns it as a PKCS#1 "RSA PRIVATE KEY" block.
//
//   - "RSA PRIVATE KEY": the bytes must parse as PKCS#1; the same block is
//     returned unchanged.
//   - "PRIVATE KEY": the bytes must parse as PKCS#8 and contain an RSA key,
//     which is re-encoded as PKCS#1.
//   - any other type: PKCS#1 is tried first, then PKCS#8; a successfully parsed
//     RSA key is re-encoded as PKCS#1.
//
// An error is returned for a nil block, unparsable bytes, or a non-RSA key.
func ensureRSAPrivateKey(block *pem.Block) (*pem.Block, error) {
	if block == nil {
		return nil, fmt.Errorf("pem block is nil")
	}

	// asPKCS1 wraps an RSA key in a fresh PKCS#1 PEM block.
	asPKCS1 := func(k *rsa.PrivateKey) *pem.Block {
		return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(k)}
	}

	switch block.Type {
	case "RSA PRIVATE KEY":
		_, err := x509.ParsePKCS1PrivateKey(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("private_key invalid rsa: %w", err)
		}
		return block, nil

	case "PRIVATE KEY":
		parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes)
		if err != nil {
			return nil, fmt.Errorf("private_key invalid pkcs8: %w", err)
		}
		if k, isRSA := parsed.(*rsa.PrivateKey); isRSA {
			return asPKCS1(k), nil
		}
		return nil, fmt.Errorf("private_key is not an RSA key")
	}

	// Unrecognized label: probe the DER bytes, PKCS#1 first, then PKCS#8.
	if k, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
		return asPKCS1(k), nil
	}
	if parsed, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
		if k, isRSA := parsed.(*rsa.PrivateKey); isRSA {
			return asPKCS1(k), nil
		}
	}
	return nil, fmt.Errorf("private_key uses unsupported format")
}
// rebuildPEM reconstructs a well-formed PEM block from text that contains PEM
// BEGIN/END markers but does not decode as PEM (for example because line
// breaks were lost or replaced).
//
// The block type is "RSA PRIVATE KEY" when that phrase occurs anywhere in raw
// and "PRIVATE KEY" otherwise. Everything between the matching BEGIN marker and
// the first END marker that follows it is reduced to base64 characters,
// decoded, and re-encoded as a fresh PEM block.
//
// An error is returned when the markers are missing, the payload is empty, or
// the payload is not valid base64.
func rebuildPEM(raw string) (string, error) {
	kind := "PRIVATE KEY"
	if strings.Contains(raw, "RSA PRIVATE KEY") {
		kind = "RSA PRIVATE KEY"
	}
	header := "-----BEGIN " + kind + "-----"
	footer := "-----END " + kind + "-----"

	start := strings.Index(raw, header)
	if start < 0 {
		return "", fmt.Errorf("missing pem markers")
	}
	// Look for the footer only after the header. Searching the whole string
	// can match a footer that overlaps the header's trailing dashes (e.g.
	// "-----BEGIN PRIVATE KEY-----END PRIVATE KEY-----"), which would make the
	// body slice bounds inverted and panic.
	rest := raw[start+len(header):]
	end := strings.Index(rest, footer)
	if end < 0 {
		return "", fmt.Errorf("missing pem markers")
	}
	body := rest[:end]

	// A backslash is never part of base64, so a literal "\n", "\r" or "\t" in
	// the body is an escaped whitespace sequence. Remove it as a unit;
	// otherwise the filter below would keep the trailing letter and corrupt
	// the payload.
	body = strings.NewReplacer(`\r`, "", `\n`, "", `\t`, "").Replace(body)

	payload := filterBase64(body)
	if payload == "" {
		return "", fmt.Errorf("private_key base64 payload empty")
	}
	der, err := base64.StdEncoding.DecodeString(payload)
	if err != nil {
		return "", fmt.Errorf("private_key base64 decode failed: %w", err)
	}
	block := &pem.Block{Type: kind, Bytes: der}
	return string(pem.EncodeToMemory(block)), nil
}
// filterBase64 returns s with every rune that is not part of the standard
// base64 alphabet (A-Z, a-z, 0-9, '+', '/') or the padding character '='
// removed.
func filterBase64(s string) string {
	keep := func(r rune) rune {
		isUpper := 'A' <= r && r <= 'Z'
		isLower := 'a' <= r && r <= 'z'
		isDigit := '0' <= r && r <= '9'
		isSymbol := r == '+' || r == '/' || r == '='
		if isUpper || isLower || isDigit || isSymbol {
			return r
		}
		// A negative result tells strings.Map to drop the rune.
		return -1
	}
	return strings.Map(keep, s)
}
// stripANSIEscape returns s with terminal escape sequences removed.
//
// Three forms are handled, each introduced by ESC (0x1b):
//   - OSC ("ESC ]"): everything up to and including the terminating BEL (0x07)
//     or ST ("ESC \") is dropped.
//   - CSI ("ESC ["): everything up to and including the final byte is dropped.
//     The final byte is any character in the range 0x40-0x7E, so sequences
//     such as the bracketed-paste markers "ESC[200~" end at '~' rather than
//     consuming following text up to the next letter.
//   - Any other ESC: only the ESC character itself is dropped.
//
// An unterminated sequence consumes the remainder of the input. The string is
// processed as runes, so invalid UTF-8 bytes come back as U+FFFD.
func stripANSIEscape(s string) string {
	in := []rune(s)
	out := make([]rune, 0, len(in))
	for i := 0; i < len(in); i++ {
		r := in[i]
		if r != 0x1b {
			out = append(out, r)
			continue
		}
		if i+1 >= len(in) {
			// Trailing lone ESC: drop it.
			continue
		}
		switch in[i+1] {
		case ']':
			// OSC: skip to BEL or ST; the loop's i++ steps past the terminator.
			i += 2
			for i < len(in) {
				if in[i] == 0x07 {
					break
				}
				if in[i] == 0x1b && i+1 < len(in) && in[i+1] == '\\' {
					i++
					break
				}
				i++
			}
		case '[':
			// CSI: skip parameter/intermediate bytes until the final byte
			// (0x40-0x7E); the loop's i++ then steps past the final byte.
			i += 2
			for i < len(in) && (in[i] < 0x40 || in[i] > 0x7e) {
				i++
			}
		default:
			// Unknown escape: drop only the ESC; the next rune is kept.
		}
	}
	return string(out)
}