Compare commits

...

11 Commits

Author SHA1 Message Date
Luis Pater
592f6fc66b feat(vertex): add usage source resolution for Vertex projects
Extend `resolveUsageSource` to support Vertex projects by extracting and normalizing `project_id` or `project` from the metadata for accurate source resolution.
2025-11-12 08:43:02 +08:00
Luis Pater
09ecba6dab Merge pull request #237 from TUGOhost/feature/support_auto_model
feat: add auto model resolution and model creation timestamp tracking
2025-11-12 00:03:30 +08:00
Luis Pater
d6bd6f3fb9 feat(vertex, management): enhance token handling and OAuth2 integration
Extend `vertexAccessToken` to support proxy-aware HTTP clients and update calls accordingly for better configurability. Add `deleteTokenRecord` to handle token cleanup, improving management of authentication files.
2025-11-11 23:42:46 +08:00
TUGOhost
92f4278039 feat: add auto model resolution and model creation timestamp tracking
- Add 'created' field to model registry for tracking model creation time
- Implement GetFirstAvailableModel() to find the first available model by newest creation timestamp
- Add ResolveAutoModel() utility function to resolve "auto" model name to actual available model
- Update request handler to resolve "auto" model before processing requests
- Ensures automatic model selection when "auto" is specified as model name

This enables dynamic model selection based on availability and creation time, improving the user experience when no specific model is requested.
2025-11-11 20:30:09 +08:00
Luis Pater
8ae8a5c296 Fixed: #233
feat(management): add auth ID normalization and file-based ID resolution

Introduce `authIDForPath` to standardize ID generation from file paths, improving consistency in authentication handling. Update `registerAuthFromFile` and `disableAuth` to utilize normalized IDs, incorporating relative path resolution and file name extraction where applicable.
2025-11-11 19:23:31 +08:00
Luis Pater
dc804e96fb fix(management): improve error handling and normalize YAML comment indentation
Enhance error management for file operations and clean up temporary files. Add `NormalizeCommentIndentation` function to ensure YAML comments maintain consistent formatting.
2025-11-11 08:37:57 +08:00
Luis Pater
ab76cb3662 feat(management): add Vertex service account import and WebSocket auth management
Introduce an endpoint for importing Vertex service account JSON keys and storing them as authentication records. Add handlers for managing WebSocket authentication configuration.
2025-11-10 20:48:31 +08:00
Luis Pater
2965bdadc1 fix(translator): remove debug print statement from OpenAI Gemini request processing 2025-11-10 18:37:05 +08:00
Luis Pater
40f7061b04 feat(watcher): debounce config reloads to prevent redundant operations
Introduce `scheduleConfigReload` with debounce functionality for config reloads, ensuring efficient handling of frequent changes. Added `stopConfigReloadTimer` for stopping timers during watcher shutdown.
2025-11-10 12:57:40 +08:00
Luis Pater
8c947cafbe Merge branch 'vertex' into dev 2025-11-10 12:24:07 +08:00
Luis Pater
717eadf128 feat(vertex): add support for Vertex AI Gemini authentication and execution
Introduce Vertex AI Gemini integration with support for service account-based authentication, credential storage, and import functionality. Added new executor for Vertex AI requests, including execution and streaming paths, and integrated it into the core manager. Enhanced CLI with `--vertex-import` flag for importing service account keys.
2025-11-10 12:23:51 +08:00
17 changed files with 1303 additions and 84 deletions

View File

@@ -57,6 +57,7 @@ func main() {
var iflowLogin bool
var noBrowser bool
var projectID string
var vertexImport string
var configPath string
var password string
@@ -69,6 +70,7 @@ func main() {
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
flag.StringVar(&password, "password", "", "")
flag.CommandLine.Usage = func() {
@@ -417,7 +419,10 @@ func main() {
// Handle different command modes based on the provided flags.
if login {
if vertexImport != "" {
// Handle Vertex service account import
cmd.DoVertexImport(cfg, vertexImport)
} else if login {
// Handle Google/Gemini login
cmd.DoLogin(cfg, projectID, options)
} else if codexLogin {

View File

@@ -508,6 +508,10 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
}
}
if err = os.Remove(full); err == nil {
if errDel := h.deleteTokenRecord(ctx, full); errDel != nil {
c.JSON(500, gin.H{"error": errDel.Error()})
return
}
deleted++
h.disableAuth(ctx, full)
}
@@ -534,10 +538,32 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
}
return
}
if err := h.deleteTokenRecord(ctx, full); err != nil {
c.JSON(500, gin.H{"error": err.Error()})
return
}
h.disableAuth(ctx, full)
c.JSON(200, gin.H{"status": "ok"})
}
func (h *Handler) authIDForPath(path string) string {
path = strings.TrimSpace(path)
if path == "" {
return ""
}
if h == nil || h.cfg == nil {
return path
}
authDir := strings.TrimSpace(h.cfg.AuthDir)
if authDir == "" {
return path
}
if rel, err := filepath.Rel(authDir, path); err == nil && rel != "" {
return rel
}
return path
}
func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error {
if h.authManager == nil {
return nil
@@ -566,13 +592,18 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
}
lastRefresh, hasLastRefresh := extractLastRefreshTimestamp(metadata)
authID := h.authIDForPath(path)
if authID == "" {
authID = path
}
attr := map[string]string{
"path": path,
"source": path,
}
auth := &coreauth.Auth{
ID: path,
ID: authID,
Provider: provider,
FileName: filepath.Base(path),
Label: label,
Status: coreauth.StatusActive,
Attributes: attr,
@@ -583,7 +614,7 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
if hasLastRefresh {
auth.LastRefreshedAt = lastRefresh
}
if existing, ok := h.authManager.GetByID(path); ok {
if existing, ok := h.authManager.GetByID(authID); ok {
auth.CreatedAt = existing.CreatedAt
if !hasLastRefresh {
auth.LastRefreshedAt = existing.LastRefreshedAt
@@ -598,10 +629,17 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
}
func (h *Handler) disableAuth(ctx context.Context, id string) {
if h.authManager == nil || id == "" {
if h == nil || h.authManager == nil {
return
}
if auth, ok := h.authManager.GetByID(id); ok {
authID := h.authIDForPath(id)
if authID == "" {
authID = strings.TrimSpace(id)
}
if authID == "" {
return
}
if auth, ok := h.authManager.GetByID(authID); ok {
auth.Disabled = true
auth.Status = coreauth.StatusDisabled
auth.StatusMessage = "removed via management API"
@@ -610,9 +648,20 @@ func (h *Handler) disableAuth(ctx context.Context, id string) {
}
}
func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (string, error) {
if record == nil {
return "", fmt.Errorf("token record is nil")
func (h *Handler) deleteTokenRecord(ctx context.Context, path string) error {
if strings.TrimSpace(path) == "" {
return fmt.Errorf("auth path is empty")
}
store := h.tokenStoreWithBaseDir()
if store == nil {
return fmt.Errorf("token store unavailable")
}
return store.Delete(ctx, path)
}
func (h *Handler) tokenStoreWithBaseDir() coreauth.Store {
if h == nil {
return nil
}
store := h.tokenStore
if store == nil {
@@ -624,6 +673,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
dirSetter.SetBaseDir(h.cfg.AuthDir)
}
}
return store
}
func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (string, error) {
if record == nil {
return "", fmt.Errorf("token record is nil")
}
store := h.tokenStoreWithBaseDir()
if store == nil {
return "", fmt.Errorf("token store unavailable")
}
return store.Save(ctx, record)
}

View File

@@ -28,7 +28,7 @@ func (h *Handler) GetConfigYAML(c *gin.Context) {
return
}
var node yaml.Node
if err := yaml.Unmarshal(data, &node); err != nil {
if err = yaml.Unmarshal(data, &node); err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "parse_failed", "message": err.Error()})
return
}
@@ -41,17 +41,18 @@ func (h *Handler) GetConfigYAML(c *gin.Context) {
}
func WriteConfig(path string, data []byte) error {
data = config.NormalizeCommentIndentation(data)
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
if err != nil {
return err
}
if _, err := f.Write(data); err != nil {
f.Close()
return err
if _, errWrite := f.Write(data); errWrite != nil {
_ = f.Close()
return errWrite
}
if err := f.Sync(); err != nil {
f.Close()
return err
if errSync := f.Sync(); errSync != nil {
_ = f.Close()
return errSync
}
return f.Close()
}
@@ -63,7 +64,7 @@ func (h *Handler) PutConfigYAML(c *gin.Context) {
return
}
var cfg config.Config
if err := yaml.Unmarshal(body, &cfg); err != nil {
if err = yaml.Unmarshal(body, &cfg); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": err.Error()})
return
}
@@ -75,18 +76,20 @@ func (h *Handler) PutConfigYAML(c *gin.Context) {
return
}
tempFile := tmpFile.Name()
if _, err := tmpFile.Write(body); err != nil {
tmpFile.Close()
os.Remove(tempFile)
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()})
if _, errWrite := tmpFile.Write(body); errWrite != nil {
_ = tmpFile.Close()
_ = os.Remove(tempFile)
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errWrite.Error()})
return
}
if err := tmpFile.Close(); err != nil {
os.Remove(tempFile)
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()})
if errClose := tmpFile.Close(); errClose != nil {
_ = os.Remove(tempFile)
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errClose.Error()})
return
}
defer os.Remove(tempFile)
defer func() {
_ = os.Remove(tempFile)
}()
_, err = config.LoadConfigOptional(tempFile, false)
if err != nil {
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": "invalid_config", "message": err.Error()})
@@ -153,6 +156,14 @@ func (h *Handler) PutRequestLog(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.RequestLog = v })
}
// Websocket auth
func (h *Handler) GetWebsocketAuth(c *gin.Context) {
c.JSON(200, gin.H{"ws-auth": h.cfg.WebsocketAuth})
}
func (h *Handler) PutWebsocketAuth(c *gin.Context) {
h.updateBoolField(c, func(v bool) { h.cfg.WebsocketAuth = v })
}
// Request retry
func (h *Handler) GetRequestRetry(c *gin.Context) {
c.JSON(200, gin.H{"request-retry": h.cfg.RequestRetry})

View File

@@ -0,0 +1,156 @@
package management
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
"github.com/gin-gonic/gin"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
)
// ImportVertexCredential handles uploading a Vertex service account JSON and saving it as an auth record.
func (h *Handler) ImportVertexCredential(c *gin.Context) {
if h == nil || h.cfg == nil {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "config unavailable"})
return
}
if h.cfg.AuthDir == "" {
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "auth directory not configured"})
return
}
fileHeader, err := c.FormFile("file")
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "file required"})
return
}
file, err := fileHeader.Open()
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
return
}
defer file.Close()
data, err := io.ReadAll(file)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
return
}
var serviceAccount map[string]any
if err := json.Unmarshal(data, &serviceAccount); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json", "message": err.Error()})
return
}
normalizedSA, err := vertex.NormalizeServiceAccountMap(serviceAccount)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid service account", "message": err.Error()})
return
}
serviceAccount = normalizedSA
projectID := strings.TrimSpace(valueAsString(serviceAccount["project_id"]))
if projectID == "" {
c.JSON(http.StatusBadRequest, gin.H{"error": "project_id missing"})
return
}
email := strings.TrimSpace(valueAsString(serviceAccount["client_email"]))
location := strings.TrimSpace(c.PostForm("location"))
if location == "" {
location = strings.TrimSpace(c.Query("location"))
}
if location == "" {
location = "us-central1"
}
fileName := fmt.Sprintf("vertex-%s.json", sanitizeVertexFilePart(projectID))
label := labelForVertex(projectID, email)
storage := &vertex.VertexCredentialStorage{
ServiceAccount: serviceAccount,
ProjectID: projectID,
Email: email,
Location: location,
Type: "vertex",
}
metadata := map[string]any{
"service_account": serviceAccount,
"project_id": projectID,
"email": email,
"location": location,
"type": "vertex",
"label": label,
}
record := &coreauth.Auth{
ID: fileName,
Provider: "vertex",
FileName: fileName,
Storage: storage,
Label: label,
Metadata: metadata,
}
ctx := context.Background()
if reqCtx := c.Request.Context(); reqCtx != nil {
ctx = reqCtx
}
savedPath, err := h.saveTokenRecord(ctx, record)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "save_failed", "message": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{
"status": "ok",
"auth-file": savedPath,
"project_id": projectID,
"email": email,
"location": location,
})
}
func valueAsString(v any) string {
if v == nil {
return ""
}
switch t := v.(type) {
case string:
return t
default:
return fmt.Sprint(t)
}
}
func sanitizeVertexFilePart(s string) string {
out := strings.TrimSpace(s)
replacers := []string{"/", "_", "\\", "_", ":", "_", " ", "-"}
for i := 0; i < len(replacers); i += 2 {
out = strings.ReplaceAll(out, replacers[i], replacers[i+1])
}
if out == "" {
return "vertex"
}
return out
}
func labelForVertex(projectID, email string) string {
p := strings.TrimSpace(projectID)
e := strings.TrimSpace(email)
if p != "" && e != "" {
return fmt.Sprintf("%s (%s)", p, e)
}
if p != "" {
return p
}
if e != "" {
return e
}
return "vertex"
}

View File

@@ -484,6 +484,9 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
mgmt.GET("/ws-auth", s.mgmt.GetWebsocketAuth)
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
@@ -508,6 +511,7 @@ func (s *Server) registerManagementRoutes() {
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)

View File

@@ -0,0 +1,208 @@
package vertex
import (
"crypto/rsa"
"crypto/x509"
"encoding/base64"
"encoding/json"
"encoding/pem"
"fmt"
"strings"
)
// NormalizeServiceAccountJSON normalizes the given JSON-encoded service account payload.
// It returns the normalized JSON (with sanitized private_key) or, if normalization fails,
// the original bytes and the encountered error.
func NormalizeServiceAccountJSON(raw []byte) ([]byte, error) {
if len(raw) == 0 {
return raw, nil
}
var payload map[string]any
if err := json.Unmarshal(raw, &payload); err != nil {
return raw, err
}
normalized, err := NormalizeServiceAccountMap(payload)
if err != nil {
return raw, err
}
out, err := json.Marshal(normalized)
if err != nil {
return raw, err
}
return out, nil
}
// NormalizeServiceAccountMap returns a copy of the given service account map with
// a sanitized private_key field that is guaranteed to contain a valid RSA PRIVATE KEY PEM block.
func NormalizeServiceAccountMap(sa map[string]any) (map[string]any, error) {
if sa == nil {
return nil, fmt.Errorf("service account payload is empty")
}
pk, _ := sa["private_key"].(string)
if strings.TrimSpace(pk) == "" {
return nil, fmt.Errorf("service account missing private_key")
}
normalized, err := sanitizePrivateKey(pk)
if err != nil {
return nil, err
}
clone := make(map[string]any, len(sa))
for k, v := range sa {
clone[k] = v
}
clone["private_key"] = normalized
return clone, nil
}
func sanitizePrivateKey(raw string) (string, error) {
pk := strings.ReplaceAll(raw, "\r\n", "\n")
pk = strings.ReplaceAll(pk, "\r", "\n")
pk = stripANSIEscape(pk)
pk = strings.ToValidUTF8(pk, "")
pk = strings.TrimSpace(pk)
normalized := pk
if block, _ := pem.Decode([]byte(pk)); block == nil {
// Attempt to reconstruct from the textual payload.
if reconstructed, err := rebuildPEM(pk); err == nil {
normalized = reconstructed
} else {
return "", fmt.Errorf("private_key is not valid pem: %w", err)
}
}
block, _ := pem.Decode([]byte(normalized))
if block == nil {
return "", fmt.Errorf("private_key pem decode failed")
}
rsaBlock, err := ensureRSAPrivateKey(block)
if err != nil {
return "", err
}
return string(pem.EncodeToMemory(rsaBlock)), nil
}
func ensureRSAPrivateKey(block *pem.Block) (*pem.Block, error) {
if block == nil {
return nil, fmt.Errorf("pem block is nil")
}
if block.Type == "RSA PRIVATE KEY" {
if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
return nil, fmt.Errorf("private_key invalid rsa: %w", err)
}
return block, nil
}
if block.Type == "PRIVATE KEY" {
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
return nil, fmt.Errorf("private_key invalid pkcs8: %w", err)
}
rsaKey, ok := key.(*rsa.PrivateKey)
if !ok {
return nil, fmt.Errorf("private_key is not an RSA key")
}
der := x509.MarshalPKCS1PrivateKey(rsaKey)
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
}
// Attempt auto-detection: try PKCS#1 first, then PKCS#8.
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
der := x509.MarshalPKCS1PrivateKey(rsaKey)
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
}
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
if rsaKey, ok := key.(*rsa.PrivateKey); ok {
der := x509.MarshalPKCS1PrivateKey(rsaKey)
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
}
}
return nil, fmt.Errorf("private_key uses unsupported format")
}
func rebuildPEM(raw string) (string, error) {
kind := "PRIVATE KEY"
if strings.Contains(raw, "RSA PRIVATE KEY") {
kind = "RSA PRIVATE KEY"
}
header := "-----BEGIN " + kind + "-----"
footer := "-----END " + kind + "-----"
start := strings.Index(raw, header)
end := strings.Index(raw, footer)
if start < 0 || end <= start {
return "", fmt.Errorf("missing pem markers")
}
body := raw[start+len(header) : end]
payload := filterBase64(body)
if payload == "" {
return "", fmt.Errorf("private_key base64 payload empty")
}
der, err := base64.StdEncoding.DecodeString(payload)
if err != nil {
return "", fmt.Errorf("private_key base64 decode failed: %w", err)
}
block := &pem.Block{Type: kind, Bytes: der}
return string(pem.EncodeToMemory(block)), nil
}
func filterBase64(s string) string {
var b strings.Builder
for _, r := range s {
switch {
case r >= 'A' && r <= 'Z':
b.WriteRune(r)
case r >= 'a' && r <= 'z':
b.WriteRune(r)
case r >= '0' && r <= '9':
b.WriteRune(r)
case r == '+' || r == '/' || r == '=':
b.WriteRune(r)
default:
// skip
}
}
return b.String()
}
func stripANSIEscape(s string) string {
in := []rune(s)
var out []rune
for i := 0; i < len(in); i++ {
r := in[i]
if r != 0x1b {
out = append(out, r)
continue
}
if i+1 >= len(in) {
continue
}
next := in[i+1]
switch next {
case ']':
i += 2
for i < len(in) {
if in[i] == 0x07 {
break
}
if in[i] == 0x1b && i+1 < len(in) && in[i+1] == '\\' {
i++
break
}
i++
}
case '[':
i += 2
for i < len(in) {
if (in[i] >= 'A' && in[i] <= 'Z') || (in[i] >= 'a' && in[i] <= 'z') {
break
}
i++
}
default:
// skip single ESC
}
}
return string(out)
}

View File

@@ -0,0 +1,66 @@
// Package vertex provides token storage for Google Vertex AI Gemini via service account credentials.
// It serialises service account JSON into an auth file that is consumed by the runtime executor.
package vertex
import (
"encoding/json"
"fmt"
"os"
"path/filepath"
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
log "github.com/sirupsen/logrus"
)
// VertexCredentialStorage stores the service account JSON for Vertex AI access.
// The content is persisted verbatim under the "service_account" key, together with
// helper fields for project, location and email to improve logging and discovery.
type VertexCredentialStorage struct {
// ServiceAccount holds the parsed service account JSON content.
ServiceAccount map[string]any `json:"service_account"`
// ProjectID is derived from the service account JSON (project_id).
ProjectID string `json:"project_id"`
// Email is the client_email from the service account JSON.
Email string `json:"email"`
// Location optionally sets a default region (e.g., us-central1) for Vertex endpoints.
Location string `json:"location,omitempty"`
// Type is the provider identifier stored alongside credentials. Always "vertex".
Type string `json:"type"`
}
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
// It ensures the parent directory exists and logs the operation for transparency.
func (s *VertexCredentialStorage) SaveTokenToFile(authFilePath string) error {
misc.LogSavingCredentials(authFilePath)
if s == nil {
return fmt.Errorf("vertex credential: storage is nil")
}
if s.ServiceAccount == nil {
return fmt.Errorf("vertex credential: service account content is empty")
}
// Ensure we tag the file with the provider type.
s.Type = "vertex"
if err := os.MkdirAll(filepath.Dir(authFilePath), 0o700); err != nil {
return fmt.Errorf("vertex credential: create directory failed: %w", err)
}
f, err := os.Create(authFilePath)
if err != nil {
return fmt.Errorf("vertex credential: create file failed: %w", err)
}
defer func() {
if errClose := f.Close(); errClose != nil {
log.Errorf("vertex credential: failed to close file: %v", errClose)
}
}()
enc := json.NewEncoder(f)
enc.SetIndent("", " ")
if err = enc.Encode(s); err != nil {
return fmt.Errorf("vertex credential: encode failed: %w", err)
}
return nil
}

View File

@@ -0,0 +1,123 @@
// Package cmd contains CLI helpers. This file implements importing a Vertex AI
// service account JSON into the auth store as a dedicated "vertex" credential.
package cmd
import (
"context"
"encoding/json"
"fmt"
"os"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
log "github.com/sirupsen/logrus"
)
// DoVertexImport imports a Google Cloud service account key JSON and persists
// it as a "vertex" provider credential. The file content is embedded in the auth
// file to allow portable deployment across stores.
func DoVertexImport(cfg *config.Config, keyPath string) {
if cfg == nil {
cfg = &config.Config{}
}
if resolved, errResolve := util.ResolveAuthDir(cfg.AuthDir); errResolve == nil {
cfg.AuthDir = resolved
}
rawPath := strings.TrimSpace(keyPath)
if rawPath == "" {
log.Fatalf("vertex-import: missing service account key path")
return
}
data, errRead := os.ReadFile(rawPath)
if errRead != nil {
log.Fatalf("vertex-import: read file failed: %v", errRead)
return
}
var sa map[string]any
if errUnmarshal := json.Unmarshal(data, &sa); errUnmarshal != nil {
log.Fatalf("vertex-import: invalid service account json: %v", errUnmarshal)
return
}
// Validate and normalize private_key before saving
normalizedSA, errFix := vertex.NormalizeServiceAccountMap(sa)
if errFix != nil {
log.Fatalf("vertex-import: %v", errFix)
return
}
sa = normalizedSA
email, _ := sa["client_email"].(string)
projectID, _ := sa["project_id"].(string)
if strings.TrimSpace(projectID) == "" {
log.Fatalf("vertex-import: project_id missing in service account json")
return
}
if strings.TrimSpace(email) == "" {
// Keep empty email but warn
log.Warn("vertex-import: client_email missing in service account json")
}
// Default location if not provided by user. Can be edited in the saved file later.
location := "us-central1"
fileName := fmt.Sprintf("vertex-%s.json", sanitizeFilePart(projectID))
// Build auth record
storage := &vertex.VertexCredentialStorage{
ServiceAccount: sa,
ProjectID: projectID,
Email: email,
Location: location,
}
metadata := map[string]any{
"service_account": sa,
"project_id": projectID,
"email": email,
"location": location,
"type": "vertex",
"label": labelForVertex(projectID, email),
}
record := &coreauth.Auth{
ID: fileName,
Provider: "vertex",
FileName: fileName,
Storage: storage,
Metadata: metadata,
}
store := sdkAuth.GetTokenStore()
if setter, ok := store.(interface{ SetBaseDir(string) }); ok {
setter.SetBaseDir(cfg.AuthDir)
}
path, errSave := store.Save(context.Background(), record)
if errSave != nil {
log.Fatalf("vertex-import: save credential failed: %v", errSave)
return
}
fmt.Printf("Vertex credentials imported: %s\n", path)
}
func sanitizeFilePart(s string) string {
out := strings.TrimSpace(s)
replacers := []string{"/", "_", "\\", "_", ":", "_", " ", "-"}
for i := 0; i < len(replacers); i += 2 {
out = strings.ReplaceAll(out, replacers[i], replacers[i+1])
}
return out
}
func labelForVertex(projectID, email string) string {
p := strings.TrimSpace(projectID)
e := strings.TrimSpace(email)
if p != "" && e != "" {
return fmt.Sprintf("%s (%s)", p, e)
}
if p != "" {
return p
}
if e != "" {
return e
}
return "vertex"
}

View File

@@ -5,6 +5,7 @@
package config
import (
"bytes"
"errors"
"fmt"
"os"
@@ -462,13 +463,19 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
return err
}
defer func() { _ = f.Close() }()
enc := yaml.NewEncoder(f)
var buf bytes.Buffer
enc := yaml.NewEncoder(&buf)
enc.SetIndent(2)
if err = enc.Encode(&original); err != nil {
_ = enc.Close()
return err
}
return enc.Close()
if err = enc.Close(); err != nil {
return err
}
data = NormalizeCommentIndentation(buf.Bytes())
_, err = f.Write(data)
return err
}
func sanitizeConfigForPersist(cfg *Config) *Config {
@@ -518,13 +525,40 @@ func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []stri
return err
}
defer func() { _ = f.Close() }()
enc := yaml.NewEncoder(f)
var buf bytes.Buffer
enc := yaml.NewEncoder(&buf)
enc.SetIndent(2)
if err = enc.Encode(&root); err != nil {
_ = enc.Close()
return err
}
return enc.Close()
if err = enc.Close(); err != nil {
return err
}
data = NormalizeCommentIndentation(buf.Bytes())
_, err = f.Write(data)
return err
}
// NormalizeCommentIndentation removes indentation from standalone YAML comment lines to keep them left aligned.
func NormalizeCommentIndentation(data []byte) []byte {
lines := bytes.Split(data, []byte("\n"))
changed := false
for i, line := range lines {
trimmed := bytes.TrimLeft(line, " \t")
if len(trimmed) == 0 || trimmed[0] != '#' {
continue
}
if len(trimmed) == len(line) {
continue
}
lines[i] = append([]byte(nil), trimmed...)
changed = true
}
if !changed {
return data
}
return bytes.Join(lines, []byte("\n"))
}
// getOrCreateMapValue finds the value node for a given key in a mapping node.
@@ -766,6 +800,7 @@ func matchSequenceElement(original []*yaml.Node, used []bool, target *yaml.Node)
}
}
}
default:
}
// Fallback to structural equality to preserve nodes lacking explicit identifiers.
for i := range original {

View File

@@ -4,6 +4,7 @@
package registry
import (
"fmt"
"sort"
"strings"
"sync"
@@ -800,6 +801,9 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
if model.Type != "" {
result["type"] = model.Type
}
if model.Created != 0 {
result["created"] = model.Created
}
return result
}
}
@@ -821,3 +825,47 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
}
}
}
// GetFirstAvailableModel returns the first available model for the given handler type.
// It prioritizes models by their creation timestamp (newest first) and checks if they have
// available clients that are not suspended or over quota.
//
// Parameters:
// - handlerType: The API handler type (e.g., "openai", "claude", "gemini")
//
// Returns:
// - string: The model ID of the first available model, or empty string if none available
// - error: An error if no models are available
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
r.mutex.RLock()
defer r.mutex.RUnlock()
// Get all available models for this handler type
models := r.GetAvailableModels(handlerType)
if len(models) == 0 {
return "", fmt.Errorf("no models available for handler type: %s", handlerType)
}
// Sort models by creation timestamp (newest first)
sort.Slice(models, func(i, j int) bool {
// Extract created timestamps from map
createdI, okI := models[i]["created"].(int64)
createdJ, okJ := models[j]["created"].(int64)
if !okI || !okJ {
return false
}
return createdI > createdJ
})
// Find the first model with available clients
for _, model := range models {
if modelID, ok := model["id"].(string); ok {
if count := r.GetModelCount(modelID); count > 0 {
return modelID, nil
}
}
}
return "", fmt.Errorf("no available clients for any model in handler type: %s", handlerType)
}

View File

@@ -0,0 +1,425 @@
// Package executor contains provider executors. This file implements the Vertex AI
// Gemini executor that talks to Google Vertex AI endpoints using service account
// credentials imported by the CLI.
package executor
import (
"bufio"
"bytes"
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"strings"
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)
const (
// vertexAPIVersion aligns with current public Vertex Generative AI API.
vertexAPIVersion = "v1"
)
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
type GeminiVertexExecutor struct {
cfg *config.Config
}
// NewGeminiVertexExecutor constructs the Vertex executor.
func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
return &GeminiVertexExecutor{cfg: cfg}
}
// Identifier returns provider key for manager routing.
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
// PrepareRequest is a no-op for Vertex.
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
return nil
}
// Execute handles non-streaming requests.
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
projectID, location, saJSON, errCreds := vertexCreds(auth)
if errCreds != nil {
return resp, errCreds
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
action := "generateContent"
if req.Metadata != nil {
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
action = "countTokens"
}
}
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
if opts.Alt != "" && action != "countTokens" {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
body, _ = sjson.DeleteBytes(body, "session_id")
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if errNewReq != nil {
return resp, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
httpReq.Header.Set("Authorization", "Bearer "+token)
} else if errTok != nil {
log.Errorf("vertex executor: access token error: %v", errTok)
return resp, statusErr{code: 500, msg: "internal server error"}
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return resp, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
return resp, err
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return resp, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
reporter.publish(ctx, parseGeminiUsage(data))
var param any
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, &param)
resp = cliproxyexecutor.Response{Payload: []byte(out)}
return resp, nil
}
// ExecuteStream handles SSE streaming for Vertex.
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
projectID, location, saJSON, errCreds := vertexCreds(auth)
if errCreds != nil {
return nil, errCreds
}
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
defer reporter.trackFailure(ctx, &err)
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
}
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
body = fixGeminiImageAspectRatio(req.Model, body)
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
if opts.Alt == "" {
url = url + "?alt=sse"
} else {
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
}
body, _ = sjson.DeleteBytes(body, "session_id")
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
if errNewReq != nil {
return nil, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
httpReq.Header.Set("Authorization", "Bearer "+token)
} else if errTok != nil {
log.Errorf("vertex executor: access token error: %v", errTok)
return nil, statusErr{code: 500, msg: "internal server error"}
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: body,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return nil, errDo
}
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
out := make(chan cliproxyexecutor.StreamChunk)
stream = out
go func() {
defer close(out)
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
scanner := bufio.NewScanner(httpResp.Body)
buf := make([]byte, 20_971_520)
scanner.Buffer(buf, 20_971_520)
var param any
for scanner.Scan() {
line := scanner.Bytes()
appendAPIResponseChunk(ctx, e.cfg, line)
if detail, ok := parseGeminiStreamUsage(line); ok {
reporter.publish(ctx, detail)
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
}
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), &param)
for i := range lines {
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
}
if errScan := scanner.Err(); errScan != nil {
recordAPIResponseError(ctx, e.cfg, errScan)
reporter.publishFailure(ctx)
out <- cliproxyexecutor.StreamChunk{Err: errScan}
}
}()
return stream, nil
}
// CountTokens calls Vertex countTokens endpoint.
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
projectID, location, saJSON, errCreds := vertexCreds(auth)
if errCreds != nil {
return cliproxyexecutor.Response{}, errCreds
}
from := opts.SourceFormat
to := sdktranslator.FromString("gemini")
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
if budgetOverride != nil {
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
budgetOverride = &norm
}
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
}
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
respCtx := context.WithValue(ctx, "alt", opts.Alt)
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
baseURL := vertexBaseURL(location)
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
if errNewReq != nil {
return cliproxyexecutor.Response{}, errNewReq
}
httpReq.Header.Set("Content-Type", "application/json")
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
httpReq.Header.Set("Authorization", "Bearer "+token)
} else if errTok != nil {
log.Errorf("vertex executor: access token error: %v", errTok)
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
}
applyGeminiHeaders(httpReq, auth)
var authID, authLabel, authType, authValue string
if auth != nil {
authID = auth.ID
authLabel = auth.Label
authType, authValue = auth.AccountInfo()
}
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
URL: url,
Method: http.MethodPost,
Headers: httpReq.Header.Clone(),
Body: translatedReq,
Provider: e.Identifier(),
AuthID: authID,
AuthLabel: authLabel,
AuthType: authType,
AuthValue: authValue,
})
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
httpResp, errDo := httpClient.Do(httpReq)
if errDo != nil {
recordAPIResponseError(ctx, e.cfg, errDo)
return cliproxyexecutor.Response{}, errDo
}
defer func() {
if errClose := httpResp.Body.Close(); errClose != nil {
log.Errorf("vertex executor: close response body error: %v", errClose)
}
}()
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
b, _ := io.ReadAll(httpResp.Body)
appendAPIResponseChunk(ctx, e.cfg, b)
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
}
data, errRead := io.ReadAll(httpResp.Body)
if errRead != nil {
recordAPIResponseError(ctx, e.cfg, errRead)
return cliproxyexecutor.Response{}, errRead
}
appendAPIResponseChunk(ctx, e.cfg, data)
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
}
count := gjson.GetBytes(data, "totalTokens").Int()
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
}
// Refresh is a no-op for service account based credentials.
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
return auth, nil
}
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) {
if a == nil || a.Metadata == nil {
return "", "", nil, fmt.Errorf("vertex executor: missing auth metadata")
}
if v, ok := a.Metadata["project_id"].(string); ok {
projectID = strings.TrimSpace(v)
}
if projectID == "" {
// Some service accounts may use "project"; still prefer standard field
if v, ok := a.Metadata["project"].(string); ok {
projectID = strings.TrimSpace(v)
}
}
if projectID == "" {
return "", "", nil, fmt.Errorf("vertex executor: missing project_id in credentials")
}
if v, ok := a.Metadata["location"].(string); ok && strings.TrimSpace(v) != "" {
location = strings.TrimSpace(v)
} else {
location = "us-central1"
}
var sa map[string]any
if raw, ok := a.Metadata["service_account"].(map[string]any); ok {
sa = raw
}
if sa == nil {
return "", "", nil, fmt.Errorf("vertex executor: missing service_account in credentials")
}
normalized, errNorm := vertexauth.NormalizeServiceAccountMap(sa)
if errNorm != nil {
return "", "", nil, fmt.Errorf("vertex executor: %w", errNorm)
}
saJSON, errMarshal := json.Marshal(normalized)
if errMarshal != nil {
return "", "", nil, fmt.Errorf("vertex executor: marshal service_account failed: %w", errMarshal)
}
return projectID, location, saJSON, nil
}
func vertexBaseURL(location string) string {
loc := strings.TrimSpace(location)
if loc == "" {
loc = "us-central1"
}
return fmt.Sprintf("https://%s-aiplatform.googleapis.com", loc)
}
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
}
// Use cloud-platform scope for Vertex AI.
creds, errCreds := google.CredentialsFromJSON(ctx, saJSON, "https://www.googleapis.com/auth/cloud-platform")
if errCreds != nil {
return "", fmt.Errorf("vertex executor: parse service account json failed: %w", errCreds)
}
tok, errTok := creds.TokenSource.Token()
if errTok != nil {
return "", fmt.Errorf("vertex executor: get access token failed: %w", errTok)
}
return tok.AccessToken, nil
}

View File

@@ -129,6 +129,21 @@ func apiKeyFromContext(ctx context.Context) string {
func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
if auth != nil {
provider := strings.TrimSpace(auth.Provider)
if strings.EqualFold(provider, "vertex") {
if auth.Metadata != nil {
if projectID, ok := auth.Metadata["project_id"].(string); ok {
if trimmed := strings.TrimSpace(projectID); trimmed != "" {
return trimmed
}
}
if project, ok := auth.Metadata["project"].(string); ok {
if trimmed := strings.TrimSpace(project); trimmed != "" {
return trimmed
}
}
}
}
if _, value := auth.AccountInfo(); value != "" {
return strings.TrimSpace(value)
}

View File

@@ -159,7 +159,6 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
}
}
}
fmt.Printf("11111")
for i := 0; i < len(arr); i++ {
m := arr[i]

View File

@@ -9,6 +9,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
log "github.com/sirupsen/logrus"
)
// GetProviderName determines all AI service providers capable of serving a registered model.
@@ -59,6 +60,30 @@ func GetProviderName(modelName string) []string {
return providers
}
// ResolveAutoModel resolves the "auto" model name to an actual available model.
// It uses an empty handler type to get any available model from the registry.
//
// Parameters:
// - modelName: The model name to check (should be "auto")
//
// Returns:
// - string: The resolved model name, or the original if not "auto" or resolution fails
func ResolveAutoModel(modelName string) string {
if modelName != "auto" {
return modelName
}
// Use empty string as handler type to get any available model
firstModel, err := registry.GetGlobalRegistry().GetFirstAvailableModel("")
if err != nil {
log.Warnf("Failed to resolve 'auto' model: %v, falling back to original model name", err)
return modelName
}
log.Infof("Resolved 'auto' model to: %s", firstModel)
return firstModel
}
// IsOpenAICompatibilityAlias checks if the given model name is an alias
// configured for OpenAI compatibility routing.
//

View File

@@ -41,24 +41,26 @@ type authDirProvider interface {
// Watcher manages file watching for configuration and authentication files
type Watcher struct {
configPath string
authDir string
config *config.Config
clientsMutex sync.RWMutex
reloadCallback func(*config.Config)
watcher *fsnotify.Watcher
lastAuthHashes map[string]string
lastConfigHash string
authQueue chan<- AuthUpdate
currentAuths map[string]*coreauth.Auth
dispatchMu sync.Mutex
dispatchCond *sync.Cond
pendingUpdates map[string]AuthUpdate
pendingOrder []string
dispatchCancel context.CancelFunc
storePersister storePersister
mirroredAuthDir string
oldConfigYaml []byte
configPath string
authDir string
config *config.Config
clientsMutex sync.RWMutex
configReloadMu sync.Mutex
configReloadTimer *time.Timer
reloadCallback func(*config.Config)
watcher *fsnotify.Watcher
lastAuthHashes map[string]string
lastConfigHash string
authQueue chan<- AuthUpdate
currentAuths map[string]*coreauth.Auth
dispatchMu sync.Mutex
dispatchCond *sync.Cond
pendingUpdates map[string]AuthUpdate
pendingOrder []string
dispatchCancel context.CancelFunc
storePersister storePersister
mirroredAuthDir string
oldConfigYaml []byte
}
type stableIDGenerator struct {
@@ -113,7 +115,8 @@ type AuthUpdate struct {
const (
// replaceCheckDelay is a short delay to allow atomic replace (rename) to settle
// before deciding whether a Remove event indicates a real deletion.
replaceCheckDelay = 50 * time.Millisecond
replaceCheckDelay = 50 * time.Millisecond
configReloadDebounce = 150 * time.Millisecond
)
// NewWatcher creates a new file watcher instance
@@ -172,9 +175,19 @@ func (w *Watcher) Start(ctx context.Context) error {
// Stop stops the file watcher
func (w *Watcher) Stop() error {
w.stopDispatch()
w.stopConfigReloadTimer()
return w.watcher.Close()
}
func (w *Watcher) stopConfigReloadTimer() {
w.configReloadMu.Lock()
if w.configReloadTimer != nil {
w.configReloadTimer.Stop()
w.configReloadTimer = nil
}
w.configReloadMu.Unlock()
}
// SetConfig updates the current configuration
func (w *Watcher) SetConfig(cfg *config.Config) {
w.clientsMutex.Lock()
@@ -476,40 +489,7 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
// Handle config file changes
if isConfigEvent {
log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
data, err := os.ReadFile(w.configPath)
if err != nil {
log.Errorf("failed to read config file for hash check: %v", err)
return
}
if len(data) == 0 {
log.Debugf("ignoring empty config file write event")
return
}
sum := sha256.Sum256(data)
newHash := hex.EncodeToString(sum[:])
w.clientsMutex.RLock()
currentHash := w.lastConfigHash
w.clientsMutex.RUnlock()
if currentHash != "" && currentHash == newHash {
log.Debugf("config file content unchanged (hash match), skipping reload")
return
}
fmt.Printf("config file changed, reloading: %s\n", w.configPath)
if w.reloadConfig() {
finalHash := newHash
if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
sumUpdated := sha256.Sum256(updatedData)
finalHash = hex.EncodeToString(sumUpdated[:])
} else if errRead != nil {
log.WithError(errRead).Debug("failed to compute updated config hash after reload")
}
w.clientsMutex.Lock()
w.lastConfigHash = finalHash
w.clientsMutex.Unlock()
w.persistConfigAsync()
}
w.scheduleConfigReload()
return
}
@@ -530,6 +510,57 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
}
}
func (w *Watcher) scheduleConfigReload() {
w.configReloadMu.Lock()
defer w.configReloadMu.Unlock()
if w.configReloadTimer != nil {
w.configReloadTimer.Stop()
}
w.configReloadTimer = time.AfterFunc(configReloadDebounce, func() {
w.configReloadMu.Lock()
w.configReloadTimer = nil
w.configReloadMu.Unlock()
w.reloadConfigIfChanged()
})
}
func (w *Watcher) reloadConfigIfChanged() {
data, err := os.ReadFile(w.configPath)
if err != nil {
log.Errorf("failed to read config file for hash check: %v", err)
return
}
if len(data) == 0 {
log.Debugf("ignoring empty config file write event")
return
}
sum := sha256.Sum256(data)
newHash := hex.EncodeToString(sum[:])
w.clientsMutex.RLock()
currentHash := w.lastConfigHash
w.clientsMutex.RUnlock()
if currentHash != "" && currentHash == newHash {
log.Debugf("config file content unchanged (hash match), skipping reload")
return
}
fmt.Printf("config file changed, reloading: %s\n", w.configPath)
if w.reloadConfig() {
finalHash := newHash
if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
sumUpdated := sha256.Sum256(updatedData)
finalHash = hex.EncodeToString(sumUpdated[:])
} else if errRead != nil {
log.WithError(errRead).Debug("failed to compute updated config hash after reload")
}
w.clientsMutex.Lock()
w.lastConfigHash = finalHash
w.clientsMutex.Unlock()
w.persistConfigAsync()
}
}
// reloadConfig reloads the configuration and triggers a full reload
func (w *Watcher) reloadConfig() bool {
log.Debug("=========================== CONFIG RELOAD ============================")

View File

@@ -295,11 +295,14 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
}
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
// Resolve "auto" model to an actual available model first
resolvedModelName := util.ResolveAutoModel(modelName)
providerName, extractedModelName, isDynamic := h.parseDynamicModel(resolvedModelName)
// First, normalize the model name to handle suffixes like "-thinking-128"
// This needs to happen before determining the provider for non-dynamic models.
normalizedModel, metadata = normalizeModelMetadata(modelName)
normalizedModel, metadata = normalizeModelMetadata(resolvedModelName)
if isDynamic {
providers = []string{providerName}

View File

@@ -324,6 +324,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
switch strings.ToLower(a.Provider) {
case "gemini":
s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
case "vertex":
s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg))
case "gemini-cli":
s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg))
case "aistudio":
@@ -619,6 +621,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
switch provider {
case "gemini":
models = registry.GetGeminiModels()
case "vertex":
// Vertex AI Gemini supports the same model identifiers as Gemini.
models = registry.GetGeminiModels()
case "gemini-cli":
models = registry.GetGeminiCLIModels()
case "aistudio":