mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 21:10:51 +08:00
Compare commits
18 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a7d2f669e7 | ||
|
|
ce569ab36e | ||
|
|
d0aa741d59 | ||
|
|
592f6fc66b | ||
|
|
09ecba6dab | ||
|
|
d6bd6f3fb9 | ||
|
|
92f4278039 | ||
|
|
8ae8a5c296 | ||
|
|
dc804e96fb | ||
|
|
ab76cb3662 | ||
|
|
2965bdadc1 | ||
|
|
40f7061b04 | ||
|
|
8c947cafbe | ||
|
|
717eadf128 | ||
|
|
9e105738fd | ||
|
|
5d806fcefc | ||
|
|
6ae1dd78ed | ||
|
|
43095de162 |
@@ -17,6 +17,7 @@ import (
|
||||
|
||||
"github.com/joho/godotenv"
|
||||
configaccess "github.com/router-for-me/CLIProxyAPI/v6/internal/access/config_access"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/cmd"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
@@ -41,13 +42,16 @@ var (
|
||||
// init initializes the shared logger setup.
|
||||
func init() {
|
||||
logging.SetupBaseLogger()
|
||||
buildinfo.Version = Version
|
||||
buildinfo.Commit = Commit
|
||||
buildinfo.BuildDate = BuildDate
|
||||
}
|
||||
|
||||
// main is the entry point of the application.
|
||||
// It parses command-line flags, loads configuration, and starts the appropriate
|
||||
// service based on the provided flags (login, codex-login, or server mode).
|
||||
func main() {
|
||||
fmt.Printf("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s\n", Version, Commit, BuildDate)
|
||||
fmt.Printf("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s\n", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
|
||||
|
||||
// Command-line flags to control the application's behavior.
|
||||
var login bool
|
||||
@@ -57,6 +61,7 @@ func main() {
|
||||
var iflowLogin bool
|
||||
var noBrowser bool
|
||||
var projectID string
|
||||
var vertexImport string
|
||||
var configPath string
|
||||
var password string
|
||||
|
||||
@@ -69,6 +74,7 @@ func main() {
|
||||
flag.BoolVar(&noBrowser, "no-browser", false, "Don't open browser automatically for OAuth")
|
||||
flag.StringVar(&projectID, "project_id", "", "Project ID (Gemini only, not required)")
|
||||
flag.StringVar(&configPath, "config", DefaultConfigPath, "Configure File Path")
|
||||
flag.StringVar(&vertexImport, "vertex-import", "", "Import Vertex service account key JSON file")
|
||||
flag.StringVar(&password, "password", "", "")
|
||||
|
||||
flag.CommandLine.Usage = func() {
|
||||
@@ -384,7 +390,7 @@ func main() {
|
||||
log.Fatalf("failed to configure log output: %v", err)
|
||||
}
|
||||
|
||||
log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", Version, Commit, BuildDate)
|
||||
log.Infof("CLIProxyAPI Version: %s, Commit: %s, BuiltAt: %s", buildinfo.Version, buildinfo.Commit, buildinfo.BuildDate)
|
||||
|
||||
// Set the log level based on the configuration.
|
||||
util.SetLogLevel(cfg)
|
||||
@@ -417,7 +423,10 @@ func main() {
|
||||
|
||||
// Handle different command modes based on the provided flags.
|
||||
|
||||
if login {
|
||||
if vertexImport != "" {
|
||||
// Handle Vertex service account import
|
||||
cmd.DoVertexImport(cfg, vertexImport)
|
||||
} else if login {
|
||||
// Handle Google/Gemini login
|
||||
cmd.DoLogin(cfg, projectID, options)
|
||||
} else if codexLogin {
|
||||
|
||||
@@ -293,6 +293,9 @@ func (h *Handler) buildAuthFileEntry(auth *coreauth.Auth) gin.H {
|
||||
return nil
|
||||
}
|
||||
runtimeOnly := isRuntimeOnlyAuth(auth)
|
||||
if runtimeOnly && (auth.Disabled || auth.Status == coreauth.StatusDisabled) {
|
||||
return nil
|
||||
}
|
||||
path := strings.TrimSpace(authAttribute(auth, "path"))
|
||||
if path == "" && !runtimeOnly {
|
||||
return nil
|
||||
@@ -505,6 +508,10 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
||||
}
|
||||
}
|
||||
if err = os.Remove(full); err == nil {
|
||||
if errDel := h.deleteTokenRecord(ctx, full); errDel != nil {
|
||||
c.JSON(500, gin.H{"error": errDel.Error()})
|
||||
return
|
||||
}
|
||||
deleted++
|
||||
h.disableAuth(ctx, full)
|
||||
}
|
||||
@@ -531,10 +538,32 @@ func (h *Handler) DeleteAuthFile(c *gin.Context) {
|
||||
}
|
||||
return
|
||||
}
|
||||
if err := h.deleteTokenRecord(ctx, full); err != nil {
|
||||
c.JSON(500, gin.H{"error": err.Error()})
|
||||
return
|
||||
}
|
||||
h.disableAuth(ctx, full)
|
||||
c.JSON(200, gin.H{"status": "ok"})
|
||||
}
|
||||
|
||||
func (h *Handler) authIDForPath(path string) string {
|
||||
path = strings.TrimSpace(path)
|
||||
if path == "" {
|
||||
return ""
|
||||
}
|
||||
if h == nil || h.cfg == nil {
|
||||
return path
|
||||
}
|
||||
authDir := strings.TrimSpace(h.cfg.AuthDir)
|
||||
if authDir == "" {
|
||||
return path
|
||||
}
|
||||
if rel, err := filepath.Rel(authDir, path); err == nil && rel != "" {
|
||||
return rel
|
||||
}
|
||||
return path
|
||||
}
|
||||
|
||||
func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []byte) error {
|
||||
if h.authManager == nil {
|
||||
return nil
|
||||
@@ -563,13 +592,18 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
}
|
||||
lastRefresh, hasLastRefresh := extractLastRefreshTimestamp(metadata)
|
||||
|
||||
authID := h.authIDForPath(path)
|
||||
if authID == "" {
|
||||
authID = path
|
||||
}
|
||||
attr := map[string]string{
|
||||
"path": path,
|
||||
"source": path,
|
||||
}
|
||||
auth := &coreauth.Auth{
|
||||
ID: path,
|
||||
ID: authID,
|
||||
Provider: provider,
|
||||
FileName: filepath.Base(path),
|
||||
Label: label,
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: attr,
|
||||
@@ -580,7 +614,7 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
if hasLastRefresh {
|
||||
auth.LastRefreshedAt = lastRefresh
|
||||
}
|
||||
if existing, ok := h.authManager.GetByID(path); ok {
|
||||
if existing, ok := h.authManager.GetByID(authID); ok {
|
||||
auth.CreatedAt = existing.CreatedAt
|
||||
if !hasLastRefresh {
|
||||
auth.LastRefreshedAt = existing.LastRefreshedAt
|
||||
@@ -595,10 +629,17 @@ func (h *Handler) registerAuthFromFile(ctx context.Context, path string, data []
|
||||
}
|
||||
|
||||
func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||
if h.authManager == nil || id == "" {
|
||||
if h == nil || h.authManager == nil {
|
||||
return
|
||||
}
|
||||
if auth, ok := h.authManager.GetByID(id); ok {
|
||||
authID := h.authIDForPath(id)
|
||||
if authID == "" {
|
||||
authID = strings.TrimSpace(id)
|
||||
}
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
if auth, ok := h.authManager.GetByID(authID); ok {
|
||||
auth.Disabled = true
|
||||
auth.Status = coreauth.StatusDisabled
|
||||
auth.StatusMessage = "removed via management API"
|
||||
@@ -607,9 +648,20 @@ func (h *Handler) disableAuth(ctx context.Context, id string) {
|
||||
}
|
||||
}
|
||||
|
||||
func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (string, error) {
|
||||
if record == nil {
|
||||
return "", fmt.Errorf("token record is nil")
|
||||
func (h *Handler) deleteTokenRecord(ctx context.Context, path string) error {
|
||||
if strings.TrimSpace(path) == "" {
|
||||
return fmt.Errorf("auth path is empty")
|
||||
}
|
||||
store := h.tokenStoreWithBaseDir()
|
||||
if store == nil {
|
||||
return fmt.Errorf("token store unavailable")
|
||||
}
|
||||
return store.Delete(ctx, path)
|
||||
}
|
||||
|
||||
func (h *Handler) tokenStoreWithBaseDir() coreauth.Store {
|
||||
if h == nil {
|
||||
return nil
|
||||
}
|
||||
store := h.tokenStore
|
||||
if store == nil {
|
||||
@@ -621,6 +673,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
||||
dirSetter.SetBaseDir(h.cfg.AuthDir)
|
||||
}
|
||||
}
|
||||
return store
|
||||
}
|
||||
|
||||
func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (string, error) {
|
||||
if record == nil {
|
||||
return "", fmt.Errorf("token record is nil")
|
||||
}
|
||||
store := h.tokenStoreWithBaseDir()
|
||||
if store == nil {
|
||||
return "", fmt.Errorf("token store unavailable")
|
||||
}
|
||||
return store.Save(ctx, record)
|
||||
}
|
||||
|
||||
@@ -968,29 +1031,46 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
}
|
||||
fmt.Println("Authentication successful.")
|
||||
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(requestedProjectID, "ALL") {
|
||||
ts.Auto = false
|
||||
projects, errAll := onboardAllGeminiProjects(ctx, gemClient, &ts)
|
||||
if errAll != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errAll)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
return
|
||||
}
|
||||
if errVerify := ensureGeminiProjectsEnabled(ctx, gemClient, projects); errVerify != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errVerify)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
return
|
||||
}
|
||||
ts.ProjectID = strings.Join(projects, ",")
|
||||
ts.Checked = true
|
||||
} else {
|
||||
if errEnsure := ensureGeminiProjectAndOnboard(ctx, gemClient, &ts, requestedProjectID); errEnsure != nil {
|
||||
log.Errorf("Failed to complete Gemini CLI onboarding: %v", errEnsure)
|
||||
oauthStatus[state] = "Failed to complete Gemini CLI onboarding"
|
||||
return
|
||||
}
|
||||
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Onboarding did not return a project ID")
|
||||
oauthStatus[state] = "Failed to resolve project ID"
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(ts.ProjectID) == "" {
|
||||
log.Error("Onboarding did not return a project ID")
|
||||
oauthStatus[state] = "Failed to resolve project ID"
|
||||
return
|
||||
}
|
||||
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the selected project")
|
||||
oauthStatus[state] = "Cloud AI API not enabled"
|
||||
return
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, gemClient, ts.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Errorf("Failed to verify Cloud AI API status: %v", errCheck)
|
||||
oauthStatus[state] = "Failed to verify Cloud AI API status"
|
||||
return
|
||||
}
|
||||
ts.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Error("Cloud AI API is not enabled for the selected project")
|
||||
oauthStatus[state] = "Cloud AI API not enabled"
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
recordMetadata := map[string]any{
|
||||
@@ -1000,10 +1080,11 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
"checked": ts.Checked,
|
||||
}
|
||||
|
||||
fileName := geminiAuth.CredentialFileName(ts.Email, ts.ProjectID, true)
|
||||
record := &coreauth.Auth{
|
||||
ID: fmt.Sprintf("gemini-%s-%s.json", ts.Email, ts.ProjectID),
|
||||
ID: fileName,
|
||||
Provider: "gemini",
|
||||
FileName: fmt.Sprintf("gemini-%s-%s.json", ts.Email, ts.ProjectID),
|
||||
FileName: fileName,
|
||||
Storage: &ts,
|
||||
Metadata: recordMetadata,
|
||||
}
|
||||
@@ -1396,6 +1477,57 @@ func ensureGeminiProjectAndOnboard(ctx context.Context, httpClient *http.Client,
|
||||
return nil
|
||||
}
|
||||
|
||||
func onboardAllGeminiProjects(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage) ([]string, error) {
|
||||
projects, errProjects := fetchGCPProjects(ctx, httpClient)
|
||||
if errProjects != nil {
|
||||
return nil, fmt.Errorf("fetch project list: %w", errProjects)
|
||||
}
|
||||
if len(projects) == 0 {
|
||||
return nil, fmt.Errorf("no Google Cloud projects available for this account")
|
||||
}
|
||||
activated := make([]string, 0, len(projects))
|
||||
seen := make(map[string]struct{}, len(projects))
|
||||
for _, project := range projects {
|
||||
candidate := strings.TrimSpace(project.ProjectID)
|
||||
if candidate == "" {
|
||||
continue
|
||||
}
|
||||
if _, dup := seen[candidate]; dup {
|
||||
continue
|
||||
}
|
||||
if err := performGeminiCLISetup(ctx, httpClient, storage, candidate); err != nil {
|
||||
return nil, fmt.Errorf("onboard project %s: %w", candidate, err)
|
||||
}
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidate
|
||||
}
|
||||
activated = append(activated, finalID)
|
||||
seen[candidate] = struct{}{}
|
||||
}
|
||||
if len(activated) == 0 {
|
||||
return nil, fmt.Errorf("no Google Cloud projects available for this account")
|
||||
}
|
||||
return activated, nil
|
||||
}
|
||||
|
||||
func ensureGeminiProjectsEnabled(ctx context.Context, httpClient *http.Client, projectIDs []string) error {
|
||||
for _, pid := range projectIDs {
|
||||
trimmed := strings.TrimSpace(pid)
|
||||
if trimmed == "" {
|
||||
continue
|
||||
}
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, trimmed)
|
||||
if errCheck != nil {
|
||||
return fmt.Errorf("project %s: %w", trimmed, errCheck)
|
||||
}
|
||||
if !isChecked {
|
||||
return fmt.Errorf("project %s: Cloud AI API not enabled", trimmed)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func performGeminiCLISetup(ctx context.Context, httpClient *http.Client, storage *geminiAuth.GeminiTokenStorage, requestedProject string) error {
|
||||
metadata := map[string]string{
|
||||
"ideType": "IDE_UNSPECIFIED",
|
||||
|
||||
@@ -28,7 +28,7 @@ func (h *Handler) GetConfigYAML(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
var node yaml.Node
|
||||
if err := yaml.Unmarshal(data, &node); err != nil {
|
||||
if err = yaml.Unmarshal(data, &node); err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "parse_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -41,17 +41,18 @@ func (h *Handler) GetConfigYAML(c *gin.Context) {
|
||||
}
|
||||
|
||||
func WriteConfig(path string, data []byte) error {
|
||||
data = config.NormalizeCommentIndentation(data)
|
||||
f, err := os.OpenFile(path, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0644)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if _, err := f.Write(data); err != nil {
|
||||
f.Close()
|
||||
return err
|
||||
if _, errWrite := f.Write(data); errWrite != nil {
|
||||
_ = f.Close()
|
||||
return errWrite
|
||||
}
|
||||
if err := f.Sync(); err != nil {
|
||||
f.Close()
|
||||
return err
|
||||
if errSync := f.Sync(); errSync != nil {
|
||||
_ = f.Close()
|
||||
return errSync
|
||||
}
|
||||
return f.Close()
|
||||
}
|
||||
@@ -63,7 +64,7 @@ func (h *Handler) PutConfigYAML(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
var cfg config.Config
|
||||
if err := yaml.Unmarshal(body, &cfg); err != nil {
|
||||
if err = yaml.Unmarshal(body, &cfg); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_yaml", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
@@ -75,18 +76,20 @@ func (h *Handler) PutConfigYAML(c *gin.Context) {
|
||||
return
|
||||
}
|
||||
tempFile := tmpFile.Name()
|
||||
if _, err := tmpFile.Write(body); err != nil {
|
||||
tmpFile.Close()
|
||||
os.Remove(tempFile)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()})
|
||||
if _, errWrite := tmpFile.Write(body); errWrite != nil {
|
||||
_ = tmpFile.Close()
|
||||
_ = os.Remove(tempFile)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errWrite.Error()})
|
||||
return
|
||||
}
|
||||
if err := tmpFile.Close(); err != nil {
|
||||
os.Remove(tempFile)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": err.Error()})
|
||||
if errClose := tmpFile.Close(); errClose != nil {
|
||||
_ = os.Remove(tempFile)
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "write_failed", "message": errClose.Error()})
|
||||
return
|
||||
}
|
||||
defer os.Remove(tempFile)
|
||||
defer func() {
|
||||
_ = os.Remove(tempFile)
|
||||
}()
|
||||
_, err = config.LoadConfigOptional(tempFile, false)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusUnprocessableEntity, gin.H{"error": "invalid_config", "message": err.Error()})
|
||||
@@ -153,6 +156,14 @@ func (h *Handler) PutRequestLog(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.RequestLog = v })
|
||||
}
|
||||
|
||||
// Websocket auth
|
||||
func (h *Handler) GetWebsocketAuth(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"ws-auth": h.cfg.WebsocketAuth})
|
||||
}
|
||||
func (h *Handler) PutWebsocketAuth(c *gin.Context) {
|
||||
h.updateBoolField(c, func(v bool) { h.cfg.WebsocketAuth = v })
|
||||
}
|
||||
|
||||
// Request retry
|
||||
func (h *Handler) GetRequestRetry(c *gin.Context) {
|
||||
c.JSON(200, gin.H{"request-retry": h.cfg.RequestRetry})
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/buildinfo"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
@@ -91,6 +92,10 @@ func (h *Handler) Middleware() gin.HandlerFunc {
|
||||
const banDuration = 30 * time.Minute
|
||||
|
||||
return func(c *gin.Context) {
|
||||
c.Header("X-CPA-VERSION", buildinfo.Version)
|
||||
c.Header("X-CPA-COMMIT", buildinfo.Commit)
|
||||
c.Header("X-CPA-BUILD-DATE", buildinfo.BuildDate)
|
||||
|
||||
clientIP := c.ClientIP()
|
||||
localClient := clientIP == "127.0.0.1" || clientIP == "::1"
|
||||
cfg := h.cfg
|
||||
|
||||
156
internal/api/handlers/management/vertex_import.go
Normal file
156
internal/api/handlers/management/vertex_import.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package management
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
)
|
||||
|
||||
// ImportVertexCredential handles uploading a Vertex service account JSON and saving it as an auth record.
|
||||
func (h *Handler) ImportVertexCredential(c *gin.Context) {
|
||||
if h == nil || h.cfg == nil {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "config unavailable"})
|
||||
return
|
||||
}
|
||||
if h.cfg.AuthDir == "" {
|
||||
c.JSON(http.StatusServiceUnavailable, gin.H{"error": "auth directory not configured"})
|
||||
return
|
||||
}
|
||||
|
||||
fileHeader, err := c.FormFile("file")
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "file required"})
|
||||
return
|
||||
}
|
||||
|
||||
file, err := fileHeader.Open()
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
|
||||
return
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
data, err := io.ReadAll(file)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)})
|
||||
return
|
||||
}
|
||||
|
||||
var serviceAccount map[string]any
|
||||
if err := json.Unmarshal(data, &serviceAccount); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
normalizedSA, err := vertex.NormalizeServiceAccountMap(serviceAccount)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid service account", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
serviceAccount = normalizedSA
|
||||
|
||||
projectID := strings.TrimSpace(valueAsString(serviceAccount["project_id"]))
|
||||
if projectID == "" {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "project_id missing"})
|
||||
return
|
||||
}
|
||||
email := strings.TrimSpace(valueAsString(serviceAccount["client_email"]))
|
||||
|
||||
location := strings.TrimSpace(c.PostForm("location"))
|
||||
if location == "" {
|
||||
location = strings.TrimSpace(c.Query("location"))
|
||||
}
|
||||
if location == "" {
|
||||
location = "us-central1"
|
||||
}
|
||||
|
||||
fileName := fmt.Sprintf("vertex-%s.json", sanitizeVertexFilePart(projectID))
|
||||
label := labelForVertex(projectID, email)
|
||||
storage := &vertex.VertexCredentialStorage{
|
||||
ServiceAccount: serviceAccount,
|
||||
ProjectID: projectID,
|
||||
Email: email,
|
||||
Location: location,
|
||||
Type: "vertex",
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"service_account": serviceAccount,
|
||||
"project_id": projectID,
|
||||
"email": email,
|
||||
"location": location,
|
||||
"type": "vertex",
|
||||
"label": label,
|
||||
}
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "vertex",
|
||||
FileName: fileName,
|
||||
Storage: storage,
|
||||
Label: label,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
ctx := context.Background()
|
||||
if reqCtx := c.Request.Context(); reqCtx != nil {
|
||||
ctx = reqCtx
|
||||
}
|
||||
savedPath, err := h.saveTokenRecord(ctx, record)
|
||||
if err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "save_failed", "message": err.Error()})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"status": "ok",
|
||||
"auth-file": savedPath,
|
||||
"project_id": projectID,
|
||||
"email": email,
|
||||
"location": location,
|
||||
})
|
||||
}
|
||||
|
||||
func valueAsString(v any) string {
|
||||
if v == nil {
|
||||
return ""
|
||||
}
|
||||
switch t := v.(type) {
|
||||
case string:
|
||||
return t
|
||||
default:
|
||||
return fmt.Sprint(t)
|
||||
}
|
||||
}
|
||||
|
||||
func sanitizeVertexFilePart(s string) string {
|
||||
out := strings.TrimSpace(s)
|
||||
replacers := []string{"/", "_", "\\", "_", ":", "_", " ", "-"}
|
||||
for i := 0; i < len(replacers); i += 2 {
|
||||
out = strings.ReplaceAll(out, replacers[i], replacers[i+1])
|
||||
}
|
||||
if out == "" {
|
||||
return "vertex"
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func labelForVertex(projectID, email string) string {
|
||||
p := strings.TrimSpace(projectID)
|
||||
e := strings.TrimSpace(email)
|
||||
if p != "" && e != "" {
|
||||
return fmt.Sprintf("%s (%s)", p, e)
|
||||
}
|
||||
if p != "" {
|
||||
return p
|
||||
}
|
||||
if e != "" {
|
||||
return e
|
||||
}
|
||||
return "vertex"
|
||||
}
|
||||
@@ -484,6 +484,9 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/request-log", s.mgmt.GetRequestLog)
|
||||
mgmt.PUT("/request-log", s.mgmt.PutRequestLog)
|
||||
mgmt.PATCH("/request-log", s.mgmt.PutRequestLog)
|
||||
mgmt.GET("/ws-auth", s.mgmt.GetWebsocketAuth)
|
||||
mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth)
|
||||
mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth)
|
||||
|
||||
mgmt.GET("/request-retry", s.mgmt.GetRequestRetry)
|
||||
mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry)
|
||||
@@ -508,6 +511,7 @@ func (s *Server) registerManagementRoutes() {
|
||||
mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)
|
||||
mgmt.POST("/auth-files", s.mgmt.UploadAuthFile)
|
||||
mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile)
|
||||
mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential)
|
||||
|
||||
mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken)
|
||||
mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)
|
||||
@@ -703,7 +707,7 @@ func (s *Server) Stop(ctx context.Context) error {
|
||||
func corsMiddleware() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
c.Header("Access-Control-Allow-Origin", "*")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
|
||||
c.Header("Access-Control-Allow-Headers", "*")
|
||||
|
||||
if c.Request.Method == "OPTIONS" {
|
||||
|
||||
@@ -8,6 +8,7 @@ import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -67,3 +68,20 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// CredentialFileName returns the filename used to persist Gemini CLI credentials.
|
||||
// When projectID represents multiple projects (comma-separated or literal ALL),
|
||||
// the suffix is normalized to "all" and a "gemini-" prefix is enforced to keep
|
||||
// web and CLI generated files consistent.
|
||||
func CredentialFileName(email, projectID string, includeProviderPrefix bool) string {
|
||||
email = strings.TrimSpace(email)
|
||||
project := strings.TrimSpace(projectID)
|
||||
if strings.EqualFold(project, "all") || strings.Contains(project, ",") {
|
||||
return fmt.Sprintf("gemini-%s-all.json", email)
|
||||
}
|
||||
prefix := ""
|
||||
if includeProviderPrefix {
|
||||
prefix = "gemini-"
|
||||
}
|
||||
return fmt.Sprintf("%s%s-%s.json", prefix, email, project)
|
||||
}
|
||||
|
||||
208
internal/auth/vertex/keyutil.go
Normal file
208
internal/auth/vertex/keyutil.go
Normal file
@@ -0,0 +1,208 @@
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"crypto/x509"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// NormalizeServiceAccountJSON normalizes the given JSON-encoded service account payload.
|
||||
// It returns the normalized JSON (with sanitized private_key) or, if normalization fails,
|
||||
// the original bytes and the encountered error.
|
||||
func NormalizeServiceAccountJSON(raw []byte) ([]byte, error) {
|
||||
if len(raw) == 0 {
|
||||
return raw, nil
|
||||
}
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(raw, &payload); err != nil {
|
||||
return raw, err
|
||||
}
|
||||
normalized, err := NormalizeServiceAccountMap(payload)
|
||||
if err != nil {
|
||||
return raw, err
|
||||
}
|
||||
out, err := json.Marshal(normalized)
|
||||
if err != nil {
|
||||
return raw, err
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// NormalizeServiceAccountMap returns a copy of the given service account map with
|
||||
// a sanitized private_key field that is guaranteed to contain a valid RSA PRIVATE KEY PEM block.
|
||||
func NormalizeServiceAccountMap(sa map[string]any) (map[string]any, error) {
|
||||
if sa == nil {
|
||||
return nil, fmt.Errorf("service account payload is empty")
|
||||
}
|
||||
pk, _ := sa["private_key"].(string)
|
||||
if strings.TrimSpace(pk) == "" {
|
||||
return nil, fmt.Errorf("service account missing private_key")
|
||||
}
|
||||
normalized, err := sanitizePrivateKey(pk)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
clone := make(map[string]any, len(sa))
|
||||
for k, v := range sa {
|
||||
clone[k] = v
|
||||
}
|
||||
clone["private_key"] = normalized
|
||||
return clone, nil
|
||||
}
|
||||
|
||||
func sanitizePrivateKey(raw string) (string, error) {
|
||||
pk := strings.ReplaceAll(raw, "\r\n", "\n")
|
||||
pk = strings.ReplaceAll(pk, "\r", "\n")
|
||||
pk = stripANSIEscape(pk)
|
||||
pk = strings.ToValidUTF8(pk, "")
|
||||
pk = strings.TrimSpace(pk)
|
||||
|
||||
normalized := pk
|
||||
if block, _ := pem.Decode([]byte(pk)); block == nil {
|
||||
// Attempt to reconstruct from the textual payload.
|
||||
if reconstructed, err := rebuildPEM(pk); err == nil {
|
||||
normalized = reconstructed
|
||||
} else {
|
||||
return "", fmt.Errorf("private_key is not valid pem: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
block, _ := pem.Decode([]byte(normalized))
|
||||
if block == nil {
|
||||
return "", fmt.Errorf("private_key pem decode failed")
|
||||
}
|
||||
|
||||
rsaBlock, err := ensureRSAPrivateKey(block)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(pem.EncodeToMemory(rsaBlock)), nil
|
||||
}
|
||||
|
||||
func ensureRSAPrivateKey(block *pem.Block) (*pem.Block, error) {
|
||||
if block == nil {
|
||||
return nil, fmt.Errorf("pem block is nil")
|
||||
}
|
||||
|
||||
if block.Type == "RSA PRIVATE KEY" {
|
||||
if _, err := x509.ParsePKCS1PrivateKey(block.Bytes); err != nil {
|
||||
return nil, fmt.Errorf("private_key invalid rsa: %w", err)
|
||||
}
|
||||
return block, nil
|
||||
}
|
||||
|
||||
if block.Type == "PRIVATE KEY" {
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("private_key invalid pkcs8: %w", err)
|
||||
}
|
||||
rsaKey, ok := key.(*rsa.PrivateKey)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("private_key is not an RSA key")
|
||||
}
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
|
||||
}
|
||||
|
||||
// Attempt auto-detection: try PKCS#1 first, then PKCS#8.
|
||||
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
|
||||
}
|
||||
if key, err := x509.ParsePKCS8PrivateKey(block.Bytes); err == nil {
|
||||
if rsaKey, ok := key.(*rsa.PrivateKey); ok {
|
||||
der := x509.MarshalPKCS1PrivateKey(rsaKey)
|
||||
return &pem.Block{Type: "RSA PRIVATE KEY", Bytes: der}, nil
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("private_key uses unsupported format")
|
||||
}
|
||||
|
||||
func rebuildPEM(raw string) (string, error) {
|
||||
kind := "PRIVATE KEY"
|
||||
if strings.Contains(raw, "RSA PRIVATE KEY") {
|
||||
kind = "RSA PRIVATE KEY"
|
||||
}
|
||||
header := "-----BEGIN " + kind + "-----"
|
||||
footer := "-----END " + kind + "-----"
|
||||
start := strings.Index(raw, header)
|
||||
end := strings.Index(raw, footer)
|
||||
if start < 0 || end <= start {
|
||||
return "", fmt.Errorf("missing pem markers")
|
||||
}
|
||||
body := raw[start+len(header) : end]
|
||||
payload := filterBase64(body)
|
||||
if payload == "" {
|
||||
return "", fmt.Errorf("private_key base64 payload empty")
|
||||
}
|
||||
der, err := base64.StdEncoding.DecodeString(payload)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("private_key base64 decode failed: %w", err)
|
||||
}
|
||||
block := &pem.Block{Type: kind, Bytes: der}
|
||||
return string(pem.EncodeToMemory(block)), nil
|
||||
}
|
||||
|
||||
func filterBase64(s string) string {
|
||||
var b strings.Builder
|
||||
for _, r := range s {
|
||||
switch {
|
||||
case r >= 'A' && r <= 'Z':
|
||||
b.WriteRune(r)
|
||||
case r >= 'a' && r <= 'z':
|
||||
b.WriteRune(r)
|
||||
case r >= '0' && r <= '9':
|
||||
b.WriteRune(r)
|
||||
case r == '+' || r == '/' || r == '=':
|
||||
b.WriteRune(r)
|
||||
default:
|
||||
// skip
|
||||
}
|
||||
}
|
||||
return b.String()
|
||||
}
|
||||
|
||||
func stripANSIEscape(s string) string {
|
||||
in := []rune(s)
|
||||
var out []rune
|
||||
for i := 0; i < len(in); i++ {
|
||||
r := in[i]
|
||||
if r != 0x1b {
|
||||
out = append(out, r)
|
||||
continue
|
||||
}
|
||||
if i+1 >= len(in) {
|
||||
continue
|
||||
}
|
||||
next := in[i+1]
|
||||
switch next {
|
||||
case ']':
|
||||
i += 2
|
||||
for i < len(in) {
|
||||
if in[i] == 0x07 {
|
||||
break
|
||||
}
|
||||
if in[i] == 0x1b && i+1 < len(in) && in[i+1] == '\\' {
|
||||
i++
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
case '[':
|
||||
i += 2
|
||||
for i < len(in) {
|
||||
if (in[i] >= 'A' && in[i] <= 'Z') || (in[i] >= 'a' && in[i] <= 'z') {
|
||||
break
|
||||
}
|
||||
i++
|
||||
}
|
||||
default:
|
||||
// skip single ESC
|
||||
}
|
||||
}
|
||||
return string(out)
|
||||
}
|
||||
66
internal/auth/vertex/vertex_credentials.go
Normal file
66
internal/auth/vertex/vertex_credentials.go
Normal file
@@ -0,0 +1,66 @@
|
||||
// Package vertex provides token storage for Google Vertex AI Gemini via service account credentials.
|
||||
// It serialises service account JSON into an auth file that is consumed by the runtime executor.
|
||||
package vertex
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"path/filepath"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// VertexCredentialStorage stores the service account JSON for Vertex AI access.
|
||||
// The content is persisted verbatim under the "service_account" key, together with
|
||||
// helper fields for project, location and email to improve logging and discovery.
|
||||
type VertexCredentialStorage struct {
|
||||
// ServiceAccount holds the parsed service account JSON content.
|
||||
ServiceAccount map[string]any `json:"service_account"`
|
||||
|
||||
// ProjectID is derived from the service account JSON (project_id).
|
||||
ProjectID string `json:"project_id"`
|
||||
|
||||
// Email is the client_email from the service account JSON.
|
||||
Email string `json:"email"`
|
||||
|
||||
// Location optionally sets a default region (e.g., us-central1) for Vertex endpoints.
|
||||
Location string `json:"location,omitempty"`
|
||||
|
||||
// Type is the provider identifier stored alongside credentials. Always "vertex".
|
||||
Type string `json:"type"`
|
||||
}
|
||||
|
||||
// SaveTokenToFile writes the credential payload to the given file path in JSON format.
|
||||
// It ensures the parent directory exists and logs the operation for transparency.
|
||||
func (s *VertexCredentialStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
if s == nil {
|
||||
return fmt.Errorf("vertex credential: storage is nil")
|
||||
}
|
||||
if s.ServiceAccount == nil {
|
||||
return fmt.Errorf("vertex credential: service account content is empty")
|
||||
}
|
||||
// Ensure we tag the file with the provider type.
|
||||
s.Type = "vertex"
|
||||
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0o700); err != nil {
|
||||
return fmt.Errorf("vertex credential: create directory failed: %w", err)
|
||||
}
|
||||
f, err := os.Create(authFilePath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("vertex credential: create file failed: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if errClose := f.Close(); errClose != nil {
|
||||
log.Errorf("vertex credential: failed to close file: %v", errClose)
|
||||
}
|
||||
}()
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err = enc.Encode(s); err != nil {
|
||||
return fmt.Errorf("vertex credential: encode failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
15
internal/buildinfo/buildinfo.go
Normal file
15
internal/buildinfo/buildinfo.go
Normal file
@@ -0,0 +1,15 @@
|
||||
// Package buildinfo exposes compile-time metadata shared across the server.
|
||||
package buildinfo
|
||||
|
||||
// The following variables are overridden via ldflags during release builds.
|
||||
// Defaults cover local development builds.
|
||||
var (
|
||||
// Version is the semantic version or git describe output of the binary.
|
||||
Version = "dev"
|
||||
|
||||
// Commit is the git commit SHA baked into the binary.
|
||||
Commit = "none"
|
||||
|
||||
// BuildDate records when the binary was built in UTC.
|
||||
BuildDate = "unknown"
|
||||
)
|
||||
@@ -96,35 +96,52 @@ func DoLogin(cfg *config.Config, projectID string, options *LoginOptions) {
|
||||
}
|
||||
|
||||
selectedProjectID := promptForProjectSelection(projects, strings.TrimSpace(projectID), promptFn)
|
||||
if strings.TrimSpace(selectedProjectID) == "" {
|
||||
projectSelections, errSelection := resolveProjectSelections(selectedProjectID, projects)
|
||||
if errSelection != nil {
|
||||
log.Fatalf("Invalid project selection: %v", errSelection)
|
||||
return
|
||||
}
|
||||
if len(projectSelections) == 0 {
|
||||
log.Fatal("No project selected; aborting login.")
|
||||
return
|
||||
}
|
||||
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, selectedProjectID); errSetup != nil {
|
||||
var projectErr *projectSelectionRequiredError
|
||||
if errors.As(errSetup, &projectErr) {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
activatedProjects := make([]string, 0, len(projectSelections))
|
||||
for _, candidateID := range projectSelections {
|
||||
log.Infof("Activating project %s", candidateID)
|
||||
if errSetup := performGeminiCLISetup(ctx, httpClient, storage, candidateID); errSetup != nil {
|
||||
var projectErr *projectSelectionRequiredError
|
||||
if errors.As(errSetup, &projectErr) {
|
||||
log.Error("Failed to start user onboarding: A project ID is required.")
|
||||
showProjectSelectionHelp(storage.Email, projects)
|
||||
return
|
||||
}
|
||||
log.Fatalf("Failed to complete user setup: %v", errSetup)
|
||||
return
|
||||
}
|
||||
log.Fatalf("Failed to complete user setup: %v", errSetup)
|
||||
return
|
||||
finalID := strings.TrimSpace(storage.ProjectID)
|
||||
if finalID == "" {
|
||||
finalID = candidateID
|
||||
}
|
||||
activatedProjects = append(activatedProjects, finalID)
|
||||
}
|
||||
|
||||
storage.Auto = false
|
||||
storage.ProjectID = strings.Join(activatedProjects, ",")
|
||||
|
||||
if !storage.Auto && !storage.Checked {
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, storage.ProjectID)
|
||||
if errCheck != nil {
|
||||
log.Fatalf("Failed to check if Cloud AI API is enabled: %v", errCheck)
|
||||
return
|
||||
}
|
||||
storage.Checked = isChecked
|
||||
if !isChecked {
|
||||
log.Fatal("Failed to check if Cloud AI API is enabled. If you encounter an error message, please create an issue.")
|
||||
return
|
||||
for _, pid := range activatedProjects {
|
||||
isChecked, errCheck := checkCloudAPIIsEnabled(ctx, httpClient, pid)
|
||||
if errCheck != nil {
|
||||
log.Fatalf("Failed to check if Cloud AI API is enabled for %s: %v", pid, errCheck)
|
||||
return
|
||||
}
|
||||
if !isChecked {
|
||||
log.Fatalf("Failed to check if Cloud AI API is enabled for project %s. If you encounter an error message, please create an issue.", pid)
|
||||
return
|
||||
}
|
||||
}
|
||||
storage.Checked = true
|
||||
}
|
||||
|
||||
updateAuthRecord(record, storage)
|
||||
@@ -354,10 +371,14 @@ func promptForProjectSelection(projects []interfaces.GCPProjectProjects, presetI
|
||||
defaultIndex = idx
|
||||
}
|
||||
}
|
||||
fmt.Println("Type 'ALL' to onboard every listed project.")
|
||||
|
||||
defaultID := projects[defaultIndex].ProjectID
|
||||
|
||||
if trimmedPreset != "" {
|
||||
if strings.EqualFold(trimmedPreset, "ALL") {
|
||||
return "ALL"
|
||||
}
|
||||
for _, project := range projects {
|
||||
if project.ProjectID == trimmedPreset {
|
||||
return trimmedPreset
|
||||
@@ -367,13 +388,16 @@ func promptForProjectSelection(projects []interfaces.GCPProjectProjects, presetI
|
||||
}
|
||||
|
||||
for {
|
||||
promptMsg := fmt.Sprintf("Enter project ID [%s]: ", defaultID)
|
||||
promptMsg := fmt.Sprintf("Enter project ID [%s] or ALL: ", defaultID)
|
||||
answer, errPrompt := promptFn(promptMsg)
|
||||
if errPrompt != nil {
|
||||
log.Errorf("Project selection prompt failed: %v", errPrompt)
|
||||
return defaultID
|
||||
}
|
||||
answer = strings.TrimSpace(answer)
|
||||
if strings.EqualFold(answer, "ALL") {
|
||||
return "ALL"
|
||||
}
|
||||
if answer == "" {
|
||||
return defaultID
|
||||
}
|
||||
@@ -394,6 +418,52 @@ func promptForProjectSelection(projects []interfaces.GCPProjectProjects, presetI
|
||||
}
|
||||
}
|
||||
|
||||
func resolveProjectSelections(selection string, projects []interfaces.GCPProjectProjects) ([]string, error) {
|
||||
trimmed := strings.TrimSpace(selection)
|
||||
if trimmed == "" {
|
||||
return nil, nil
|
||||
}
|
||||
available := make(map[string]struct{}, len(projects))
|
||||
ordered := make([]string, 0, len(projects))
|
||||
for _, project := range projects {
|
||||
id := strings.TrimSpace(project.ProjectID)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := available[id]; exists {
|
||||
continue
|
||||
}
|
||||
available[id] = struct{}{}
|
||||
ordered = append(ordered, id)
|
||||
}
|
||||
if strings.EqualFold(trimmed, "ALL") {
|
||||
if len(ordered) == 0 {
|
||||
return nil, fmt.Errorf("no projects available for ALL selection")
|
||||
}
|
||||
return append([]string(nil), ordered...), nil
|
||||
}
|
||||
parts := strings.Split(trimmed, ",")
|
||||
selections := make([]string, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, part := range parts {
|
||||
id := strings.TrimSpace(part)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, dup := seen[id]; dup {
|
||||
continue
|
||||
}
|
||||
if len(available) > 0 {
|
||||
if _, ok := available[id]; !ok {
|
||||
return nil, fmt.Errorf("project %s not found in available projects", id)
|
||||
}
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
selections = append(selections, id)
|
||||
}
|
||||
return selections, nil
|
||||
}
|
||||
|
||||
func defaultProjectPrompt() func(string) (string, error) {
|
||||
reader := bufio.NewReader(os.Stdin)
|
||||
return func(prompt string) (string, error) {
|
||||
@@ -495,7 +565,7 @@ func updateAuthRecord(record *cliproxyauth.Auth, storage *gemini.GeminiTokenStor
|
||||
return
|
||||
}
|
||||
|
||||
finalName := fmt.Sprintf("%s-%s.json", storage.Email, storage.ProjectID)
|
||||
finalName := gemini.CredentialFileName(storage.Email, storage.ProjectID, false)
|
||||
|
||||
if record.Metadata == nil {
|
||||
record.Metadata = make(map[string]any)
|
||||
|
||||
123
internal/cmd/vertex_import.go
Normal file
123
internal/cmd/vertex_import.go
Normal file
@@ -0,0 +1,123 @@
|
||||
// Package cmd contains CLI helpers. This file implements importing a Vertex AI
|
||||
// service account JSON into the auth store as a dedicated "vertex" credential.
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// DoVertexImport imports a Google Cloud service account key JSON and persists
|
||||
// it as a "vertex" provider credential. The file content is embedded in the auth
|
||||
// file to allow portable deployment across stores.
|
||||
func DoVertexImport(cfg *config.Config, keyPath string) {
|
||||
if cfg == nil {
|
||||
cfg = &config.Config{}
|
||||
}
|
||||
if resolved, errResolve := util.ResolveAuthDir(cfg.AuthDir); errResolve == nil {
|
||||
cfg.AuthDir = resolved
|
||||
}
|
||||
rawPath := strings.TrimSpace(keyPath)
|
||||
if rawPath == "" {
|
||||
log.Fatalf("vertex-import: missing service account key path")
|
||||
return
|
||||
}
|
||||
data, errRead := os.ReadFile(rawPath)
|
||||
if errRead != nil {
|
||||
log.Fatalf("vertex-import: read file failed: %v", errRead)
|
||||
return
|
||||
}
|
||||
var sa map[string]any
|
||||
if errUnmarshal := json.Unmarshal(data, &sa); errUnmarshal != nil {
|
||||
log.Fatalf("vertex-import: invalid service account json: %v", errUnmarshal)
|
||||
return
|
||||
}
|
||||
// Validate and normalize private_key before saving
|
||||
normalizedSA, errFix := vertex.NormalizeServiceAccountMap(sa)
|
||||
if errFix != nil {
|
||||
log.Fatalf("vertex-import: %v", errFix)
|
||||
return
|
||||
}
|
||||
sa = normalizedSA
|
||||
email, _ := sa["client_email"].(string)
|
||||
projectID, _ := sa["project_id"].(string)
|
||||
if strings.TrimSpace(projectID) == "" {
|
||||
log.Fatalf("vertex-import: project_id missing in service account json")
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(email) == "" {
|
||||
// Keep empty email but warn
|
||||
log.Warn("vertex-import: client_email missing in service account json")
|
||||
}
|
||||
// Default location if not provided by user. Can be edited in the saved file later.
|
||||
location := "us-central1"
|
||||
|
||||
fileName := fmt.Sprintf("vertex-%s.json", sanitizeFilePart(projectID))
|
||||
// Build auth record
|
||||
storage := &vertex.VertexCredentialStorage{
|
||||
ServiceAccount: sa,
|
||||
ProjectID: projectID,
|
||||
Email: email,
|
||||
Location: location,
|
||||
}
|
||||
metadata := map[string]any{
|
||||
"service_account": sa,
|
||||
"project_id": projectID,
|
||||
"email": email,
|
||||
"location": location,
|
||||
"type": "vertex",
|
||||
"label": labelForVertex(projectID, email),
|
||||
}
|
||||
record := &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: "vertex",
|
||||
FileName: fileName,
|
||||
Storage: storage,
|
||||
Metadata: metadata,
|
||||
}
|
||||
|
||||
store := sdkAuth.GetTokenStore()
|
||||
if setter, ok := store.(interface{ SetBaseDir(string) }); ok {
|
||||
setter.SetBaseDir(cfg.AuthDir)
|
||||
}
|
||||
path, errSave := store.Save(context.Background(), record)
|
||||
if errSave != nil {
|
||||
log.Fatalf("vertex-import: save credential failed: %v", errSave)
|
||||
return
|
||||
}
|
||||
fmt.Printf("Vertex credentials imported: %s\n", path)
|
||||
}
|
||||
|
||||
func sanitizeFilePart(s string) string {
|
||||
out := strings.TrimSpace(s)
|
||||
replacers := []string{"/", "_", "\\", "_", ":", "_", " ", "-"}
|
||||
for i := 0; i < len(replacers); i += 2 {
|
||||
out = strings.ReplaceAll(out, replacers[i], replacers[i+1])
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func labelForVertex(projectID, email string) string {
|
||||
p := strings.TrimSpace(projectID)
|
||||
e := strings.TrimSpace(email)
|
||||
if p != "" && e != "" {
|
||||
return fmt.Sprintf("%s (%s)", p, e)
|
||||
}
|
||||
if p != "" {
|
||||
return p
|
||||
}
|
||||
if e != "" {
|
||||
return e
|
||||
}
|
||||
return "vertex"
|
||||
}
|
||||
@@ -5,6 +5,7 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -462,13 +463,19 @@ func SaveConfigPreserveComments(configFile string, cfg *Config) error {
|
||||
return err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
enc := yaml.NewEncoder(f)
|
||||
var buf bytes.Buffer
|
||||
enc := yaml.NewEncoder(&buf)
|
||||
enc.SetIndent(2)
|
||||
if err = enc.Encode(&original); err != nil {
|
||||
_ = enc.Close()
|
||||
return err
|
||||
}
|
||||
return enc.Close()
|
||||
if err = enc.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
data = NormalizeCommentIndentation(buf.Bytes())
|
||||
_, err = f.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
func sanitizeConfigForPersist(cfg *Config) *Config {
|
||||
@@ -518,13 +525,40 @@ func SaveConfigPreserveCommentsUpdateNestedScalar(configFile string, path []stri
|
||||
return err
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
enc := yaml.NewEncoder(f)
|
||||
var buf bytes.Buffer
|
||||
enc := yaml.NewEncoder(&buf)
|
||||
enc.SetIndent(2)
|
||||
if err = enc.Encode(&root); err != nil {
|
||||
_ = enc.Close()
|
||||
return err
|
||||
}
|
||||
return enc.Close()
|
||||
if err = enc.Close(); err != nil {
|
||||
return err
|
||||
}
|
||||
data = NormalizeCommentIndentation(buf.Bytes())
|
||||
_, err = f.Write(data)
|
||||
return err
|
||||
}
|
||||
|
||||
// NormalizeCommentIndentation removes indentation from standalone YAML comment lines to keep them left aligned.
|
||||
func NormalizeCommentIndentation(data []byte) []byte {
|
||||
lines := bytes.Split(data, []byte("\n"))
|
||||
changed := false
|
||||
for i, line := range lines {
|
||||
trimmed := bytes.TrimLeft(line, " \t")
|
||||
if len(trimmed) == 0 || trimmed[0] != '#' {
|
||||
continue
|
||||
}
|
||||
if len(trimmed) == len(line) {
|
||||
continue
|
||||
}
|
||||
lines[i] = append([]byte(nil), trimmed...)
|
||||
changed = true
|
||||
}
|
||||
if !changed {
|
||||
return data
|
||||
}
|
||||
return bytes.Join(lines, []byte("\n"))
|
||||
}
|
||||
|
||||
// getOrCreateMapValue finds the value node for a given key in a mapping node.
|
||||
@@ -766,6 +800,7 @@ func matchSequenceElement(original []*yaml.Node, used []bool, target *yaml.Node)
|
||||
}
|
||||
}
|
||||
}
|
||||
default:
|
||||
}
|
||||
// Fallback to structural equality to preserve nodes lacking explicit identifiers.
|
||||
for i := range original {
|
||||
|
||||
@@ -4,6 +4,7 @@
|
||||
package registry
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sort"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -800,6 +801,9 @@ func (r *ModelRegistry) convertModelToMap(model *ModelInfo, handlerType string)
|
||||
if model.Type != "" {
|
||||
result["type"] = model.Type
|
||||
}
|
||||
if model.Created != 0 {
|
||||
result["created"] = model.Created
|
||||
}
|
||||
return result
|
||||
}
|
||||
}
|
||||
@@ -821,3 +825,47 @@ func (r *ModelRegistry) CleanupExpiredQuotas() {
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
// GetFirstAvailableModel returns the first available model for the given handler type.
|
||||
// It prioritizes models by their creation timestamp (newest first) and checks if they have
|
||||
// available clients that are not suspended or over quota.
|
||||
//
|
||||
// Parameters:
|
||||
// - handlerType: The API handler type (e.g., "openai", "claude", "gemini")
|
||||
//
|
||||
// Returns:
|
||||
// - string: The model ID of the first available model, or empty string if none available
|
||||
// - error: An error if no models are available
|
||||
func (r *ModelRegistry) GetFirstAvailableModel(handlerType string) (string, error) {
|
||||
r.mutex.RLock()
|
||||
defer r.mutex.RUnlock()
|
||||
|
||||
// Get all available models for this handler type
|
||||
models := r.GetAvailableModels(handlerType)
|
||||
if len(models) == 0 {
|
||||
return "", fmt.Errorf("no models available for handler type: %s", handlerType)
|
||||
}
|
||||
|
||||
// Sort models by creation timestamp (newest first)
|
||||
sort.Slice(models, func(i, j int) bool {
|
||||
// Extract created timestamps from map
|
||||
createdI, okI := models[i]["created"].(int64)
|
||||
createdJ, okJ := models[j]["created"].(int64)
|
||||
if !okI || !okJ {
|
||||
return false
|
||||
}
|
||||
return createdI > createdJ
|
||||
})
|
||||
|
||||
// Find the first model with available clients
|
||||
for _, model := range models {
|
||||
if modelID, ok := model["id"].(string); ok {
|
||||
if count := r.GetModelCount(modelID); count > 0 {
|
||||
return modelID, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return "", fmt.Errorf("no available clients for any model in handler type: %s", handlerType)
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/misc"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
@@ -80,7 +81,7 @@ func (e *GeminiCLIExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
}
|
||||
|
||||
projectID := strings.TrimSpace(stringValue(auth.Metadata, "project_id"))
|
||||
projectID := resolveGeminiProjectID(auth)
|
||||
models := cliPreviewFallbackOrder(req.Model)
|
||||
if len(models) == 0 || models[0] != req.Model {
|
||||
models = append([]string{req.Model}, models...)
|
||||
@@ -214,7 +215,7 @@ func (e *GeminiCLIExecutor) ExecuteStream(ctx context.Context, auth *cliproxyaut
|
||||
basePayload = util.StripThinkingConfigIfUnsupported(req.Model, basePayload)
|
||||
basePayload = fixGeminiCLIImageAspectRatio(req.Model, basePayload)
|
||||
|
||||
projectID := strings.TrimSpace(stringValue(auth.Metadata, "project_id"))
|
||||
projectID := resolveGeminiProjectID(auth)
|
||||
|
||||
models := cliPreviewFallbackOrder(req.Model)
|
||||
if len(models) == 0 || models[0] != req.Model {
|
||||
@@ -493,12 +494,13 @@ func (e *GeminiCLIExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth
|
||||
}
|
||||
|
||||
func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth) (oauth2.TokenSource, map[string]any, error) {
|
||||
if auth == nil || auth.Metadata == nil {
|
||||
metadata := geminiOAuthMetadata(auth)
|
||||
if auth == nil || metadata == nil {
|
||||
return nil, nil, fmt.Errorf("gemini-cli auth metadata missing")
|
||||
}
|
||||
|
||||
var base map[string]any
|
||||
if tokenRaw, ok := auth.Metadata["token"].(map[string]any); ok && tokenRaw != nil {
|
||||
if tokenRaw, ok := metadata["token"].(map[string]any); ok && tokenRaw != nil {
|
||||
base = cloneMap(tokenRaw)
|
||||
} else {
|
||||
base = make(map[string]any)
|
||||
@@ -512,16 +514,16 @@ func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *
|
||||
}
|
||||
|
||||
if token.AccessToken == "" {
|
||||
token.AccessToken = stringValue(auth.Metadata, "access_token")
|
||||
token.AccessToken = stringValue(metadata, "access_token")
|
||||
}
|
||||
if token.RefreshToken == "" {
|
||||
token.RefreshToken = stringValue(auth.Metadata, "refresh_token")
|
||||
token.RefreshToken = stringValue(metadata, "refresh_token")
|
||||
}
|
||||
if token.TokenType == "" {
|
||||
token.TokenType = stringValue(auth.Metadata, "token_type")
|
||||
token.TokenType = stringValue(metadata, "token_type")
|
||||
}
|
||||
if token.Expiry.IsZero() {
|
||||
if expiry := stringValue(auth.Metadata, "expiry"); expiry != "" {
|
||||
if expiry := stringValue(metadata, "expiry"); expiry != "" {
|
||||
if ts, err := time.Parse(time.RFC3339, expiry); err == nil {
|
||||
token.Expiry = ts
|
||||
}
|
||||
@@ -550,22 +552,28 @@ func prepareGeminiCLITokenSource(ctx context.Context, cfg *config.Config, auth *
|
||||
}
|
||||
|
||||
func updateGeminiCLITokenMetadata(auth *cliproxyauth.Auth, base map[string]any, tok *oauth2.Token) {
|
||||
if auth == nil || auth.Metadata == nil || tok == nil {
|
||||
if auth == nil || tok == nil {
|
||||
return
|
||||
}
|
||||
if tok.AccessToken != "" {
|
||||
auth.Metadata["access_token"] = tok.AccessToken
|
||||
merged := buildGeminiTokenMap(base, tok)
|
||||
fields := buildGeminiTokenFields(tok, merged)
|
||||
shared := geminicli.ResolveSharedCredential(auth.Runtime)
|
||||
if shared != nil {
|
||||
snapshot := shared.MergeMetadata(fields)
|
||||
if !geminicli.IsVirtual(auth.Runtime) {
|
||||
auth.Metadata = snapshot
|
||||
}
|
||||
return
|
||||
}
|
||||
if tok.TokenType != "" {
|
||||
auth.Metadata["token_type"] = tok.TokenType
|
||||
if auth.Metadata == nil {
|
||||
auth.Metadata = make(map[string]any)
|
||||
}
|
||||
if tok.RefreshToken != "" {
|
||||
auth.Metadata["refresh_token"] = tok.RefreshToken
|
||||
}
|
||||
if !tok.Expiry.IsZero() {
|
||||
auth.Metadata["expiry"] = tok.Expiry.Format(time.RFC3339)
|
||||
for k, v := range fields {
|
||||
auth.Metadata[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
func buildGeminiTokenMap(base map[string]any, tok *oauth2.Token) map[string]any {
|
||||
merged := cloneMap(base)
|
||||
if merged == nil {
|
||||
merged = make(map[string]any)
|
||||
@@ -578,8 +586,51 @@ func updateGeminiCLITokenMetadata(auth *cliproxyauth.Auth, base map[string]any,
|
||||
}
|
||||
}
|
||||
}
|
||||
return merged
|
||||
}
|
||||
|
||||
auth.Metadata["token"] = merged
|
||||
func buildGeminiTokenFields(tok *oauth2.Token, merged map[string]any) map[string]any {
|
||||
fields := make(map[string]any, 5)
|
||||
if tok.AccessToken != "" {
|
||||
fields["access_token"] = tok.AccessToken
|
||||
}
|
||||
if tok.TokenType != "" {
|
||||
fields["token_type"] = tok.TokenType
|
||||
}
|
||||
if tok.RefreshToken != "" {
|
||||
fields["refresh_token"] = tok.RefreshToken
|
||||
}
|
||||
if !tok.Expiry.IsZero() {
|
||||
fields["expiry"] = tok.Expiry.Format(time.RFC3339)
|
||||
}
|
||||
if len(merged) > 0 {
|
||||
fields["token"] = cloneMap(merged)
|
||||
}
|
||||
return fields
|
||||
}
|
||||
|
||||
func resolveGeminiProjectID(auth *cliproxyauth.Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
if runtime := auth.Runtime; runtime != nil {
|
||||
if virtual, ok := runtime.(*geminicli.VirtualCredential); ok && virtual != nil {
|
||||
return strings.TrimSpace(virtual.ProjectID)
|
||||
}
|
||||
}
|
||||
return strings.TrimSpace(stringValue(auth.Metadata, "project_id"))
|
||||
}
|
||||
|
||||
func geminiOAuthMetadata(auth *cliproxyauth.Auth) map[string]any {
|
||||
if auth == nil {
|
||||
return nil
|
||||
}
|
||||
if shared := geminicli.ResolveSharedCredential(auth.Runtime); shared != nil {
|
||||
if snapshot := shared.MetadataSnapshot(); len(snapshot) > 0 {
|
||||
return snapshot
|
||||
}
|
||||
}
|
||||
return auth.Metadata
|
||||
}
|
||||
|
||||
func newHTTPClient(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, timeout time.Duration) *http.Client {
|
||||
|
||||
425
internal/runtime/executor/gemini_vertex_executor.go
Normal file
425
internal/runtime/executor/gemini_vertex_executor.go
Normal file
@@ -0,0 +1,425 @@
|
||||
// Package executor contains provider executors. This file implements the Vertex AI
|
||||
// Gemini executor that talks to Google Vertex AI endpoints using service account
|
||||
// credentials imported by the CLI.
|
||||
package executor
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
vertexauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"golang.org/x/oauth2"
|
||||
"golang.org/x/oauth2/google"
|
||||
)
|
||||
|
||||
const (
|
||||
// vertexAPIVersion aligns with current public Vertex Generative AI API.
|
||||
vertexAPIVersion = "v1"
|
||||
)
|
||||
|
||||
// GeminiVertexExecutor sends requests to Vertex AI Gemini endpoints using service account credentials.
|
||||
type GeminiVertexExecutor struct {
|
||||
cfg *config.Config
|
||||
}
|
||||
|
||||
// NewGeminiVertexExecutor constructs the Vertex executor.
|
||||
func NewGeminiVertexExecutor(cfg *config.Config) *GeminiVertexExecutor {
|
||||
return &GeminiVertexExecutor{cfg: cfg}
|
||||
}
|
||||
|
||||
// Identifier returns provider key for manager routing.
|
||||
func (e *GeminiVertexExecutor) Identifier() string { return "vertex" }
|
||||
|
||||
// PrepareRequest is a no-op for Vertex.
|
||||
func (e *GeminiVertexExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute handles non-streaming requests.
|
||||
func (e *GeminiVertexExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||
if errCreds != nil {
|
||||
return resp, errCreds
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
|
||||
action := "generateContent"
|
||||
if req.Metadata != nil {
|
||||
if a, _ := req.Metadata["action"].(string); a == "countTokens" {
|
||||
action = "countTokens"
|
||||
}
|
||||
}
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, action)
|
||||
if opts.Alt != "" && action != "countTokens" {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if errNewReq != nil {
|
||||
return resp, errNewReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
} else if errTok != nil {
|
||||
log.Errorf("vertex executor: access token error: %v", errTok)
|
||||
return resp, statusErr{code: 500, msg: "internal server error"}
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return resp, errDo
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
err = statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
return resp, err
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return resp, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
reporter.publish(ctx, parseGeminiUsage(data))
|
||||
var param any
|
||||
out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, data, ¶m)
|
||||
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||
return resp, nil
|
||||
}
|
||||
|
||||
// ExecuteStream handles SSE streaming for Vertex.
|
||||
func (e *GeminiVertexExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||
if errCreds != nil {
|
||||
return nil, errCreds
|
||||
}
|
||||
|
||||
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||
defer reporter.trackFailure(ctx, &err)
|
||||
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
body := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), true)
|
||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
body = util.ApplyGeminiThinkingConfig(body, budgetOverride, includeOverride)
|
||||
}
|
||||
body = util.StripThinkingConfigIfUnsupported(req.Model, body)
|
||||
body = fixGeminiImageAspectRatio(req.Model, body)
|
||||
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "streamGenerateContent")
|
||||
if opts.Alt == "" {
|
||||
url = url + "?alt=sse"
|
||||
} else {
|
||||
url = url + fmt.Sprintf("?$alt=%s", opts.Alt)
|
||||
}
|
||||
body, _ = sjson.DeleteBytes(body, "session_id")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(ctx, http.MethodPost, url, bytes.NewReader(body))
|
||||
if errNewReq != nil {
|
||||
return nil, errNewReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
} else if errTok != nil {
|
||||
log.Errorf("vertex executor: access token error: %v", errTok)
|
||||
return nil, statusErr{code: 500, msg: "internal server error"}
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: body,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return nil, errDo
|
||||
}
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
return nil, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
}
|
||||
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
stream = out
|
||||
go func() {
|
||||
defer close(out)
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
scanner := bufio.NewScanner(httpResp.Body)
|
||||
buf := make([]byte, 20_971_520)
|
||||
scanner.Buffer(buf, 20_971_520)
|
||||
var param any
|
||||
for scanner.Scan() {
|
||||
line := scanner.Bytes()
|
||||
appendAPIResponseChunk(ctx, e.cfg, line)
|
||||
if detail, ok := parseGeminiStreamUsage(line); ok {
|
||||
reporter.publish(ctx, detail)
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, bytes.Clone(line), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
}
|
||||
lines := sdktranslator.TranslateStream(ctx, to, from, req.Model, bytes.Clone(opts.OriginalRequest), body, []byte("[DONE]"), ¶m)
|
||||
for i := range lines {
|
||||
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||
}
|
||||
if errScan := scanner.Err(); errScan != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errScan)
|
||||
reporter.publishFailure(ctx)
|
||||
out <- cliproxyexecutor.StreamChunk{Err: errScan}
|
||||
}
|
||||
}()
|
||||
return stream, nil
|
||||
}
|
||||
|
||||
// CountTokens calls Vertex countTokens endpoint.
|
||||
func (e *GeminiVertexExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
projectID, location, saJSON, errCreds := vertexCreds(auth)
|
||||
if errCreds != nil {
|
||||
return cliproxyexecutor.Response{}, errCreds
|
||||
}
|
||||
from := opts.SourceFormat
|
||||
to := sdktranslator.FromString("gemini")
|
||||
translatedReq := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), false)
|
||||
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok && util.ModelSupportsThinking(req.Model) {
|
||||
if budgetOverride != nil {
|
||||
norm := util.NormalizeThinkingBudget(req.Model, *budgetOverride)
|
||||
budgetOverride = &norm
|
||||
}
|
||||
translatedReq = util.ApplyGeminiThinkingConfig(translatedReq, budgetOverride, includeOverride)
|
||||
}
|
||||
translatedReq = util.StripThinkingConfigIfUnsupported(req.Model, translatedReq)
|
||||
translatedReq = fixGeminiImageAspectRatio(req.Model, translatedReq)
|
||||
respCtx := context.WithValue(ctx, "alt", opts.Alt)
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "tools")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "generationConfig")
|
||||
translatedReq, _ = sjson.DeleteBytes(translatedReq, "safetySettings")
|
||||
|
||||
baseURL := vertexBaseURL(location)
|
||||
url := fmt.Sprintf("%s/%s/projects/%s/locations/%s/publishers/google/models/%s:%s", baseURL, vertexAPIVersion, projectID, location, req.Model, "countTokens")
|
||||
|
||||
httpReq, errNewReq := http.NewRequestWithContext(respCtx, http.MethodPost, url, bytes.NewReader(translatedReq))
|
||||
if errNewReq != nil {
|
||||
return cliproxyexecutor.Response{}, errNewReq
|
||||
}
|
||||
httpReq.Header.Set("Content-Type", "application/json")
|
||||
if token, errTok := vertexAccessToken(ctx, e.cfg, auth, saJSON); errTok == nil && token != "" {
|
||||
httpReq.Header.Set("Authorization", "Bearer "+token)
|
||||
} else if errTok != nil {
|
||||
log.Errorf("vertex executor: access token error: %v", errTok)
|
||||
return cliproxyexecutor.Response{}, statusErr{code: 500, msg: "internal server error"}
|
||||
}
|
||||
applyGeminiHeaders(httpReq, auth)
|
||||
|
||||
var authID, authLabel, authType, authValue string
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
authLabel = auth.Label
|
||||
authType, authValue = auth.AccountInfo()
|
||||
}
|
||||
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||
URL: url,
|
||||
Method: http.MethodPost,
|
||||
Headers: httpReq.Header.Clone(),
|
||||
Body: translatedReq,
|
||||
Provider: e.Identifier(),
|
||||
AuthID: authID,
|
||||
AuthLabel: authLabel,
|
||||
AuthType: authType,
|
||||
AuthValue: authValue,
|
||||
})
|
||||
|
||||
httpClient := newProxyAwareHTTPClient(ctx, e.cfg, auth, 0)
|
||||
httpResp, errDo := httpClient.Do(httpReq)
|
||||
if errDo != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errDo)
|
||||
return cliproxyexecutor.Response{}, errDo
|
||||
}
|
||||
defer func() {
|
||||
if errClose := httpResp.Body.Close(); errClose != nil {
|
||||
log.Errorf("vertex executor: close response body error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
recordAPIResponseMetadata(ctx, e.cfg, httpResp.StatusCode, httpResp.Header.Clone())
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
b, _ := io.ReadAll(httpResp.Body)
|
||||
appendAPIResponseChunk(ctx, e.cfg, b)
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), b))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(b)}
|
||||
}
|
||||
data, errRead := io.ReadAll(httpResp.Body)
|
||||
if errRead != nil {
|
||||
recordAPIResponseError(ctx, e.cfg, errRead)
|
||||
return cliproxyexecutor.Response{}, errRead
|
||||
}
|
||||
appendAPIResponseChunk(ctx, e.cfg, data)
|
||||
if httpResp.StatusCode < 200 || httpResp.StatusCode >= 300 {
|
||||
log.Debugf("request error, error status: %d, error body: %s", httpResp.StatusCode, summarizeErrorBody(httpResp.Header.Get("Content-Type"), data))
|
||||
return cliproxyexecutor.Response{}, statusErr{code: httpResp.StatusCode, msg: string(data)}
|
||||
}
|
||||
count := gjson.GetBytes(data, "totalTokens").Int()
|
||||
out := sdktranslator.TranslateTokenCount(ctx, to, from, count, data)
|
||||
return cliproxyexecutor.Response{Payload: []byte(out)}, nil
|
||||
}
|
||||
|
||||
// Refresh is a no-op for service account based credentials.
|
||||
func (e *GeminiVertexExecutor) Refresh(_ context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
// vertexCreds extracts project, location and raw service account JSON from auth metadata.
|
||||
func vertexCreds(a *cliproxyauth.Auth) (projectID, location string, serviceAccountJSON []byte, err error) {
|
||||
if a == nil || a.Metadata == nil {
|
||||
return "", "", nil, fmt.Errorf("vertex executor: missing auth metadata")
|
||||
}
|
||||
if v, ok := a.Metadata["project_id"].(string); ok {
|
||||
projectID = strings.TrimSpace(v)
|
||||
}
|
||||
if projectID == "" {
|
||||
// Some service accounts may use "project"; still prefer standard field
|
||||
if v, ok := a.Metadata["project"].(string); ok {
|
||||
projectID = strings.TrimSpace(v)
|
||||
}
|
||||
}
|
||||
if projectID == "" {
|
||||
return "", "", nil, fmt.Errorf("vertex executor: missing project_id in credentials")
|
||||
}
|
||||
if v, ok := a.Metadata["location"].(string); ok && strings.TrimSpace(v) != "" {
|
||||
location = strings.TrimSpace(v)
|
||||
} else {
|
||||
location = "us-central1"
|
||||
}
|
||||
var sa map[string]any
|
||||
if raw, ok := a.Metadata["service_account"].(map[string]any); ok {
|
||||
sa = raw
|
||||
}
|
||||
if sa == nil {
|
||||
return "", "", nil, fmt.Errorf("vertex executor: missing service_account in credentials")
|
||||
}
|
||||
normalized, errNorm := vertexauth.NormalizeServiceAccountMap(sa)
|
||||
if errNorm != nil {
|
||||
return "", "", nil, fmt.Errorf("vertex executor: %w", errNorm)
|
||||
}
|
||||
saJSON, errMarshal := json.Marshal(normalized)
|
||||
if errMarshal != nil {
|
||||
return "", "", nil, fmt.Errorf("vertex executor: marshal service_account failed: %w", errMarshal)
|
||||
}
|
||||
return projectID, location, saJSON, nil
|
||||
}
|
||||
|
||||
func vertexBaseURL(location string) string {
|
||||
loc := strings.TrimSpace(location)
|
||||
if loc == "" {
|
||||
loc = "us-central1"
|
||||
}
|
||||
return fmt.Sprintf("https://%s-aiplatform.googleapis.com", loc)
|
||||
}
|
||||
|
||||
func vertexAccessToken(ctx context.Context, cfg *config.Config, auth *cliproxyauth.Auth, saJSON []byte) (string, error) {
|
||||
if httpClient := newProxyAwareHTTPClient(ctx, cfg, auth, 0); httpClient != nil {
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, httpClient)
|
||||
}
|
||||
// Use cloud-platform scope for Vertex AI.
|
||||
creds, errCreds := google.CredentialsFromJSON(ctx, saJSON, "https://www.googleapis.com/auth/cloud-platform")
|
||||
if errCreds != nil {
|
||||
return "", fmt.Errorf("vertex executor: parse service account json failed: %w", errCreds)
|
||||
}
|
||||
tok, errTok := creds.TokenSource.Token()
|
||||
if errTok != nil {
|
||||
return "", fmt.Errorf("vertex executor: get access token failed: %w", errTok)
|
||||
}
|
||||
return tok.AccessToken, nil
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/usage"
|
||||
"github.com/tidwall/gjson"
|
||||
@@ -32,7 +31,7 @@ func newUsageReporter(ctx context.Context, provider, model string, auth *cliprox
|
||||
model: model,
|
||||
requestedAt: time.Now(),
|
||||
apiKey: apiKey,
|
||||
source: util.HideAPIKey(resolveUsageSource(auth, apiKey)),
|
||||
source: resolveUsageSource(auth, apiKey),
|
||||
}
|
||||
if auth != nil {
|
||||
reporter.authID = auth.ID
|
||||
@@ -129,6 +128,26 @@ func apiKeyFromContext(ctx context.Context) string {
|
||||
|
||||
func resolveUsageSource(auth *cliproxyauth.Auth, ctxAPIKey string) string {
|
||||
if auth != nil {
|
||||
provider := strings.TrimSpace(auth.Provider)
|
||||
if strings.EqualFold(provider, "gemini-cli") {
|
||||
if id := strings.TrimSpace(auth.ID); id != "" {
|
||||
return id
|
||||
}
|
||||
}
|
||||
if strings.EqualFold(provider, "vertex") {
|
||||
if auth.Metadata != nil {
|
||||
if projectID, ok := auth.Metadata["project_id"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(projectID); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
if project, ok := auth.Metadata["project"].(string); ok {
|
||||
if trimmed := strings.TrimSpace(project); trimmed != "" {
|
||||
return trimmed
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if _, value := auth.AccountInfo(); value != "" {
|
||||
return strings.TrimSpace(value)
|
||||
}
|
||||
|
||||
144
internal/runtime/geminicli/state.go
Normal file
144
internal/runtime/geminicli/state.go
Normal file
@@ -0,0 +1,144 @@
|
||||
package geminicli
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"sync"
|
||||
)
|
||||
|
||||
// SharedCredential keeps canonical OAuth metadata for a multi-project Gemini CLI login.
|
||||
type SharedCredential struct {
|
||||
primaryID string
|
||||
email string
|
||||
metadata map[string]any
|
||||
projectIDs []string
|
||||
mu sync.RWMutex
|
||||
}
|
||||
|
||||
// NewSharedCredential builds a shared credential container for the given primary entry.
|
||||
func NewSharedCredential(primaryID, email string, metadata map[string]any, projectIDs []string) *SharedCredential {
|
||||
return &SharedCredential{
|
||||
primaryID: strings.TrimSpace(primaryID),
|
||||
email: strings.TrimSpace(email),
|
||||
metadata: cloneMap(metadata),
|
||||
projectIDs: cloneStrings(projectIDs),
|
||||
}
|
||||
}
|
||||
|
||||
// PrimaryID returns the owning credential identifier.
|
||||
func (s *SharedCredential) PrimaryID() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.primaryID
|
||||
}
|
||||
|
||||
// Email returns the associated account email.
|
||||
func (s *SharedCredential) Email() string {
|
||||
if s == nil {
|
||||
return ""
|
||||
}
|
||||
return s.email
|
||||
}
|
||||
|
||||
// ProjectIDs returns a snapshot of the configured project identifiers.
|
||||
func (s *SharedCredential) ProjectIDs() []string {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
return cloneStrings(s.projectIDs)
|
||||
}
|
||||
|
||||
// MetadataSnapshot returns a deep copy of the stored OAuth metadata.
|
||||
func (s *SharedCredential) MetadataSnapshot() map[string]any {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return cloneMap(s.metadata)
|
||||
}
|
||||
|
||||
// MergeMetadata merges the provided fields into the shared metadata and returns an updated copy.
|
||||
func (s *SharedCredential) MergeMetadata(values map[string]any) map[string]any {
|
||||
if s == nil {
|
||||
return nil
|
||||
}
|
||||
if len(values) == 0 {
|
||||
return s.MetadataSnapshot()
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if s.metadata == nil {
|
||||
s.metadata = make(map[string]any, len(values))
|
||||
}
|
||||
for k, v := range values {
|
||||
if v == nil {
|
||||
delete(s.metadata, k)
|
||||
continue
|
||||
}
|
||||
s.metadata[k] = v
|
||||
}
|
||||
return cloneMap(s.metadata)
|
||||
}
|
||||
|
||||
// SetProjectIDs updates the stored project identifiers.
|
||||
func (s *SharedCredential) SetProjectIDs(ids []string) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
s.mu.Lock()
|
||||
s.projectIDs = cloneStrings(ids)
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// VirtualCredential tracks a per-project virtual auth entry that reuses a primary credential.
|
||||
type VirtualCredential struct {
|
||||
ProjectID string
|
||||
Parent *SharedCredential
|
||||
}
|
||||
|
||||
// NewVirtualCredential creates a virtual credential descriptor bound to the shared parent.
|
||||
func NewVirtualCredential(projectID string, parent *SharedCredential) *VirtualCredential {
|
||||
return &VirtualCredential{ProjectID: strings.TrimSpace(projectID), Parent: parent}
|
||||
}
|
||||
|
||||
// ResolveSharedCredential returns the shared credential backing the provided runtime payload.
|
||||
func ResolveSharedCredential(runtime any) *SharedCredential {
|
||||
switch typed := runtime.(type) {
|
||||
case *SharedCredential:
|
||||
return typed
|
||||
case *VirtualCredential:
|
||||
return typed.Parent
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
// IsVirtual reports whether the runtime payload represents a virtual credential.
|
||||
func IsVirtual(runtime any) bool {
|
||||
if runtime == nil {
|
||||
return false
|
||||
}
|
||||
_, ok := runtime.(*VirtualCredential)
|
||||
return ok
|
||||
}
|
||||
|
||||
func cloneMap(in map[string]any) map[string]any {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make(map[string]any, len(in))
|
||||
for k, v := range in {
|
||||
out[k] = v
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func cloneStrings(in []string) []string {
|
||||
if len(in) == 0 {
|
||||
return nil
|
||||
}
|
||||
out := make([]string, len(in))
|
||||
copy(out, in)
|
||||
return out
|
||||
}
|
||||
@@ -159,7 +159,6 @@ func ConvertOpenAIRequestToGemini(modelName string, inputRawJSON []byte, _ bool)
|
||||
}
|
||||
}
|
||||
}
|
||||
fmt.Printf("11111")
|
||||
|
||||
for i := 0; i < len(arr); i++ {
|
||||
m := arr[i]
|
||||
|
||||
@@ -85,6 +85,58 @@ func ConvertGeminiRequestToOpenAI(modelName string, inputRawJSON []byte, stream
|
||||
var openAIMessages []interface{}
|
||||
var toolCallIDs []string // Track tool call IDs for matching with tool results
|
||||
|
||||
// System instruction -> OpenAI system message
|
||||
// Gemini may provide `systemInstruction` or `system_instruction`; support both keys.
|
||||
systemInstruction := root.Get("systemInstruction")
|
||||
if !systemInstruction.Exists() {
|
||||
systemInstruction = root.Get("system_instruction")
|
||||
}
|
||||
if systemInstruction.Exists() {
|
||||
parts := systemInstruction.Get("parts")
|
||||
msg := map[string]interface{}{
|
||||
"role": "system",
|
||||
"content": []interface{}{},
|
||||
}
|
||||
|
||||
var aggregatedParts []interface{}
|
||||
|
||||
if parts.Exists() && parts.IsArray() {
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
// Handle text parts
|
||||
if text := part.Get("text"); text.Exists() {
|
||||
formattedText := text.String()
|
||||
aggregatedParts = append(aggregatedParts, map[string]interface{}{
|
||||
"type": "text",
|
||||
"text": formattedText,
|
||||
})
|
||||
}
|
||||
|
||||
// Handle inline data (e.g., images)
|
||||
if inlineData := part.Get("inlineData"); inlineData.Exists() {
|
||||
mimeType := inlineData.Get("mimeType").String()
|
||||
if mimeType == "" {
|
||||
mimeType = "application/octet-stream"
|
||||
}
|
||||
data := inlineData.Get("data").String()
|
||||
imageURL := fmt.Sprintf("data:%s;base64,%s", mimeType, data)
|
||||
|
||||
aggregatedParts = append(aggregatedParts, map[string]interface{}{
|
||||
"type": "image_url",
|
||||
"image_url": map[string]interface{}{
|
||||
"url": imageURL,
|
||||
},
|
||||
})
|
||||
}
|
||||
return true
|
||||
})
|
||||
}
|
||||
|
||||
if len(aggregatedParts) > 0 {
|
||||
msg["content"] = aggregatedParts
|
||||
openAIMessages = append(openAIMessages, msg)
|
||||
}
|
||||
}
|
||||
|
||||
if contents := root.Get("contents"); contents.Exists() && contents.IsArray() {
|
||||
contents.ForEach(func(_, content gjson.Result) bool {
|
||||
role := content.Get("role").String()
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// GetProviderName determines all AI service providers capable of serving a registered model.
|
||||
@@ -59,6 +60,30 @@ func GetProviderName(modelName string) []string {
|
||||
return providers
|
||||
}
|
||||
|
||||
// ResolveAutoModel resolves the "auto" model name to an actual available model.
|
||||
// It uses an empty handler type to get any available model from the registry.
|
||||
//
|
||||
// Parameters:
|
||||
// - modelName: The model name to check (should be "auto")
|
||||
//
|
||||
// Returns:
|
||||
// - string: The resolved model name, or the original if not "auto" or resolution fails
|
||||
func ResolveAutoModel(modelName string) string {
|
||||
if modelName != "auto" {
|
||||
return modelName
|
||||
}
|
||||
|
||||
// Use empty string as handler type to get any available model
|
||||
firstModel, err := registry.GetGlobalRegistry().GetFirstAvailableModel("")
|
||||
if err != nil {
|
||||
log.Warnf("Failed to resolve 'auto' model: %v, falling back to original model name", err)
|
||||
return modelName
|
||||
}
|
||||
|
||||
log.Infof("Resolved 'auto' model to: %s", firstModel)
|
||||
return firstModel
|
||||
}
|
||||
|
||||
// IsOpenAICompatibilityAlias checks if the given model name is an alias
|
||||
// configured for OpenAI compatibility routing.
|
||||
//
|
||||
|
||||
@@ -21,6 +21,7 @@ import (
|
||||
|
||||
"github.com/fsnotify/fsnotify"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/geminicli"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
@@ -41,24 +42,26 @@ type authDirProvider interface {
|
||||
|
||||
// Watcher manages file watching for configuration and authentication files
|
||||
type Watcher struct {
|
||||
configPath string
|
||||
authDir string
|
||||
config *config.Config
|
||||
clientsMutex sync.RWMutex
|
||||
reloadCallback func(*config.Config)
|
||||
watcher *fsnotify.Watcher
|
||||
lastAuthHashes map[string]string
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
currentAuths map[string]*coreauth.Auth
|
||||
dispatchMu sync.Mutex
|
||||
dispatchCond *sync.Cond
|
||||
pendingUpdates map[string]AuthUpdate
|
||||
pendingOrder []string
|
||||
dispatchCancel context.CancelFunc
|
||||
storePersister storePersister
|
||||
mirroredAuthDir string
|
||||
oldConfigYaml []byte
|
||||
configPath string
|
||||
authDir string
|
||||
config *config.Config
|
||||
clientsMutex sync.RWMutex
|
||||
configReloadMu sync.Mutex
|
||||
configReloadTimer *time.Timer
|
||||
reloadCallback func(*config.Config)
|
||||
watcher *fsnotify.Watcher
|
||||
lastAuthHashes map[string]string
|
||||
lastConfigHash string
|
||||
authQueue chan<- AuthUpdate
|
||||
currentAuths map[string]*coreauth.Auth
|
||||
dispatchMu sync.Mutex
|
||||
dispatchCond *sync.Cond
|
||||
pendingUpdates map[string]AuthUpdate
|
||||
pendingOrder []string
|
||||
dispatchCancel context.CancelFunc
|
||||
storePersister storePersister
|
||||
mirroredAuthDir string
|
||||
oldConfigYaml []byte
|
||||
}
|
||||
|
||||
type stableIDGenerator struct {
|
||||
@@ -113,7 +116,8 @@ type AuthUpdate struct {
|
||||
const (
|
||||
// replaceCheckDelay is a short delay to allow atomic replace (rename) to settle
|
||||
// before deciding whether a Remove event indicates a real deletion.
|
||||
replaceCheckDelay = 50 * time.Millisecond
|
||||
replaceCheckDelay = 50 * time.Millisecond
|
||||
configReloadDebounce = 150 * time.Millisecond
|
||||
)
|
||||
|
||||
// NewWatcher creates a new file watcher instance
|
||||
@@ -172,9 +176,19 @@ func (w *Watcher) Start(ctx context.Context) error {
|
||||
// Stop stops the file watcher
|
||||
func (w *Watcher) Stop() error {
|
||||
w.stopDispatch()
|
||||
w.stopConfigReloadTimer()
|
||||
return w.watcher.Close()
|
||||
}
|
||||
|
||||
func (w *Watcher) stopConfigReloadTimer() {
|
||||
w.configReloadMu.Lock()
|
||||
if w.configReloadTimer != nil {
|
||||
w.configReloadTimer.Stop()
|
||||
w.configReloadTimer = nil
|
||||
}
|
||||
w.configReloadMu.Unlock()
|
||||
}
|
||||
|
||||
// SetConfig updates the current configuration
|
||||
func (w *Watcher) SetConfig(cfg *config.Config) {
|
||||
w.clientsMutex.Lock()
|
||||
@@ -463,8 +477,10 @@ func (w *Watcher) processEvents(ctx context.Context) {
|
||||
// handleEvent processes individual file system events
|
||||
func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
// Filter only relevant events: config file or auth-dir JSON files.
|
||||
isConfigEvent := event.Name == w.configPath && (event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create)
|
||||
isAuthJSON := strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json")
|
||||
configOps := fsnotify.Write | fsnotify.Create | fsnotify.Rename
|
||||
isConfigEvent := event.Name == w.configPath && event.Op&configOps != 0
|
||||
authOps := fsnotify.Create | fsnotify.Write | fsnotify.Remove | fsnotify.Rename
|
||||
isAuthJSON := strings.HasPrefix(event.Name, w.authDir) && strings.HasSuffix(event.Name, ".json") && event.Op&authOps != 0
|
||||
if !isConfigEvent && !isAuthJSON {
|
||||
// Ignore unrelated files (e.g., cookie snapshots *.cookie) and other noise.
|
||||
return
|
||||
@@ -476,57 +492,76 @@ func (w *Watcher) handleEvent(event fsnotify.Event) {
|
||||
// Handle config file changes
|
||||
if isConfigEvent {
|
||||
log.Debugf("config file change details - operation: %s, timestamp: %s", event.Op.String(), now.Format("2006-01-02 15:04:05.000"))
|
||||
data, err := os.ReadFile(w.configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read config file for hash check: %v", err)
|
||||
return
|
||||
}
|
||||
if len(data) == 0 {
|
||||
log.Debugf("ignoring empty config file write event")
|
||||
return
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
newHash := hex.EncodeToString(sum[:])
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
currentHash := w.lastConfigHash
|
||||
w.clientsMutex.RUnlock()
|
||||
|
||||
if currentHash != "" && currentHash == newHash {
|
||||
log.Debugf("config file content unchanged (hash match), skipping reload")
|
||||
return
|
||||
}
|
||||
fmt.Printf("config file changed, reloading: %s\n", w.configPath)
|
||||
if w.reloadConfig() {
|
||||
finalHash := newHash
|
||||
if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
|
||||
sumUpdated := sha256.Sum256(updatedData)
|
||||
finalHash = hex.EncodeToString(sumUpdated[:])
|
||||
} else if errRead != nil {
|
||||
log.WithError(errRead).Debug("failed to compute updated config hash after reload")
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
w.lastConfigHash = finalHash
|
||||
w.clientsMutex.Unlock()
|
||||
w.persistConfigAsync()
|
||||
}
|
||||
w.scheduleConfigReload()
|
||||
return
|
||||
}
|
||||
|
||||
// Handle auth directory changes incrementally (.json only)
|
||||
fmt.Printf("auth file changed (%s): %s, processing incrementally\n", event.Op.String(), filepath.Base(event.Name))
|
||||
if event.Op&fsnotify.Create == fsnotify.Create || event.Op&fsnotify.Write == fsnotify.Write {
|
||||
w.addOrUpdateClient(event.Name)
|
||||
} else if event.Op&fsnotify.Remove == fsnotify.Remove {
|
||||
// Atomic replace on some platforms may surface as Remove+Create for the target path.
|
||||
// Wait briefly; if the file exists again, treat as update instead of removal.
|
||||
if event.Op&(fsnotify.Remove|fsnotify.Rename) != 0 {
|
||||
// Atomic replace on some platforms may surface as Rename (or Remove) before the new file is ready.
|
||||
// Wait briefly; if the path exists again, treat as an update instead of removal.
|
||||
time.Sleep(replaceCheckDelay)
|
||||
if _, statErr := os.Stat(event.Name); statErr == nil {
|
||||
// File exists after a short delay; handle as an update.
|
||||
w.addOrUpdateClient(event.Name)
|
||||
return
|
||||
}
|
||||
w.removeClient(event.Name)
|
||||
return
|
||||
}
|
||||
if event.Op&(fsnotify.Create|fsnotify.Write) != 0 {
|
||||
w.addOrUpdateClient(event.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func (w *Watcher) scheduleConfigReload() {
|
||||
w.configReloadMu.Lock()
|
||||
defer w.configReloadMu.Unlock()
|
||||
if w.configReloadTimer != nil {
|
||||
w.configReloadTimer.Stop()
|
||||
}
|
||||
w.configReloadTimer = time.AfterFunc(configReloadDebounce, func() {
|
||||
w.configReloadMu.Lock()
|
||||
w.configReloadTimer = nil
|
||||
w.configReloadMu.Unlock()
|
||||
w.reloadConfigIfChanged()
|
||||
})
|
||||
}
|
||||
|
||||
func (w *Watcher) reloadConfigIfChanged() {
|
||||
data, err := os.ReadFile(w.configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to read config file for hash check: %v", err)
|
||||
return
|
||||
}
|
||||
if len(data) == 0 {
|
||||
log.Debugf("ignoring empty config file write event")
|
||||
return
|
||||
}
|
||||
sum := sha256.Sum256(data)
|
||||
newHash := hex.EncodeToString(sum[:])
|
||||
|
||||
w.clientsMutex.RLock()
|
||||
currentHash := w.lastConfigHash
|
||||
w.clientsMutex.RUnlock()
|
||||
|
||||
if currentHash != "" && currentHash == newHash {
|
||||
log.Debugf("config file content unchanged (hash match), skipping reload")
|
||||
return
|
||||
}
|
||||
fmt.Printf("config file changed, reloading: %s\n", w.configPath)
|
||||
if w.reloadConfig() {
|
||||
finalHash := newHash
|
||||
if updatedData, errRead := os.ReadFile(w.configPath); errRead == nil && len(updatedData) > 0 {
|
||||
sumUpdated := sha256.Sum256(updatedData)
|
||||
finalHash = hex.EncodeToString(sumUpdated[:])
|
||||
} else if errRead != nil {
|
||||
log.WithError(errRead).Debug("failed to compute updated config hash after reload")
|
||||
}
|
||||
w.clientsMutex.Lock()
|
||||
w.lastConfigHash = finalHash
|
||||
w.clientsMutex.Unlock()
|
||||
w.persistConfigAsync()
|
||||
}
|
||||
}
|
||||
|
||||
@@ -995,11 +1030,119 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
}
|
||||
if provider == "gemini-cli" {
|
||||
if virtuals := synthesizeGeminiVirtualAuths(a, metadata, now); len(virtuals) > 0 {
|
||||
out = append(out, a)
|
||||
out = append(out, virtuals...)
|
||||
continue
|
||||
}
|
||||
}
|
||||
out = append(out, a)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func synthesizeGeminiVirtualAuths(primary *coreauth.Auth, metadata map[string]any, now time.Time) []*coreauth.Auth {
|
||||
if primary == nil || metadata == nil {
|
||||
return nil
|
||||
}
|
||||
projects := splitGeminiProjectIDs(metadata)
|
||||
if len(projects) <= 1 {
|
||||
return nil
|
||||
}
|
||||
email, _ := metadata["email"].(string)
|
||||
shared := geminicli.NewSharedCredential(primary.ID, email, metadata, projects)
|
||||
primary.Disabled = true
|
||||
primary.Status = coreauth.StatusDisabled
|
||||
primary.Runtime = shared
|
||||
if primary.Attributes == nil {
|
||||
primary.Attributes = make(map[string]string)
|
||||
}
|
||||
primary.Attributes["gemini_virtual_primary"] = "true"
|
||||
primary.Attributes["virtual_children"] = strings.Join(projects, ",")
|
||||
source := primary.Attributes["source"]
|
||||
authPath := primary.Attributes["path"]
|
||||
originalProvider := primary.Provider
|
||||
if originalProvider == "" {
|
||||
originalProvider = "gemini-cli"
|
||||
}
|
||||
label := primary.Label
|
||||
if label == "" {
|
||||
label = originalProvider
|
||||
}
|
||||
virtuals := make([]*coreauth.Auth, 0, len(projects))
|
||||
for _, projectID := range projects {
|
||||
attrs := map[string]string{
|
||||
"runtime_only": "true",
|
||||
"gemini_virtual_parent": primary.ID,
|
||||
"gemini_virtual_project": projectID,
|
||||
}
|
||||
if source != "" {
|
||||
attrs["source"] = source
|
||||
}
|
||||
if authPath != "" {
|
||||
attrs["path"] = authPath
|
||||
}
|
||||
metadataCopy := map[string]any{
|
||||
"email": email,
|
||||
"project_id": projectID,
|
||||
"virtual": true,
|
||||
"virtual_parent_id": primary.ID,
|
||||
"type": metadata["type"],
|
||||
}
|
||||
proxy := strings.TrimSpace(primary.ProxyURL)
|
||||
if proxy != "" {
|
||||
metadataCopy["proxy_url"] = proxy
|
||||
}
|
||||
virtual := &coreauth.Auth{
|
||||
ID: buildGeminiVirtualID(primary.ID, projectID),
|
||||
Provider: originalProvider,
|
||||
Label: fmt.Sprintf("%s [%s]", label, projectID),
|
||||
Status: coreauth.StatusActive,
|
||||
Attributes: attrs,
|
||||
Metadata: metadataCopy,
|
||||
ProxyURL: primary.ProxyURL,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Runtime: geminicli.NewVirtualCredential(projectID, shared),
|
||||
}
|
||||
virtuals = append(virtuals, virtual)
|
||||
}
|
||||
return virtuals
|
||||
}
|
||||
|
||||
func splitGeminiProjectIDs(metadata map[string]any) []string {
|
||||
raw, _ := metadata["project_id"].(string)
|
||||
trimmed := strings.TrimSpace(raw)
|
||||
if trimmed == "" {
|
||||
return nil
|
||||
}
|
||||
parts := strings.Split(trimmed, ",")
|
||||
result := make([]string, 0, len(parts))
|
||||
seen := make(map[string]struct{}, len(parts))
|
||||
for _, part := range parts {
|
||||
id := strings.TrimSpace(part)
|
||||
if id == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[id]; ok {
|
||||
continue
|
||||
}
|
||||
seen[id] = struct{}{}
|
||||
result = append(result, id)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func buildGeminiVirtualID(baseID, projectID string) string {
|
||||
project := strings.TrimSpace(projectID)
|
||||
if project == "" {
|
||||
project = "project"
|
||||
}
|
||||
replacer := strings.NewReplacer("/", "_", "\\", "_", " ", "_")
|
||||
return fmt.Sprintf("%s::%s", baseID, replacer.Replace(project))
|
||||
}
|
||||
|
||||
// buildCombinedClientMap merges file-based clients with API key clients from the cache.
|
||||
// buildCombinedClientMap removed
|
||||
|
||||
|
||||
@@ -295,11 +295,14 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
}
|
||||
|
||||
func (h *BaseAPIHandler) getRequestDetails(modelName string) (providers []string, normalizedModel string, metadata map[string]any, err *interfaces.ErrorMessage) {
|
||||
providerName, extractedModelName, isDynamic := h.parseDynamicModel(modelName)
|
||||
// Resolve "auto" model to an actual available model first
|
||||
resolvedModelName := util.ResolveAutoModel(modelName)
|
||||
|
||||
providerName, extractedModelName, isDynamic := h.parseDynamicModel(resolvedModelName)
|
||||
|
||||
// First, normalize the model name to handle suffixes like "-thinking-128"
|
||||
// This needs to happen before determining the provider for non-dynamic models.
|
||||
normalizedModel, metadata = normalizeModelMetadata(modelName)
|
||||
normalizedModel, metadata = normalizeModelMetadata(resolvedModelName)
|
||||
|
||||
if isDynamic {
|
||||
providers = []string{providerName}
|
||||
|
||||
@@ -324,6 +324,8 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
||||
switch strings.ToLower(a.Provider) {
|
||||
case "gemini":
|
||||
s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
|
||||
case "vertex":
|
||||
s.coreManager.RegisterExecutor(executor.NewGeminiVertexExecutor(s.cfg))
|
||||
case "gemini-cli":
|
||||
s.coreManager.RegisterExecutor(executor.NewGeminiCLIExecutor(s.cfg))
|
||||
case "aistudio":
|
||||
@@ -602,6 +604,12 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
if a == nil || a.ID == "" {
|
||||
return
|
||||
}
|
||||
if a.Attributes != nil {
|
||||
if v := strings.TrimSpace(a.Attributes["gemini_virtual_primary"]); strings.EqualFold(v, "true") {
|
||||
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||
return
|
||||
}
|
||||
}
|
||||
// Unregister legacy client ID (if present) to avoid double counting
|
||||
if a.Runtime != nil {
|
||||
if idGetter, ok := a.Runtime.(interface{ GetClientID() string }); ok {
|
||||
@@ -619,6 +627,9 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
switch provider {
|
||||
case "gemini":
|
||||
models = registry.GetGeminiModels()
|
||||
case "vertex":
|
||||
// Vertex AI Gemini supports the same model identifiers as Gemini.
|
||||
models = registry.GetGeminiModels()
|
||||
case "gemini-cli":
|
||||
models = registry.GetGeminiCLIModels()
|
||||
case "aistudio":
|
||||
|
||||
Reference in New Issue
Block a user