From ab76cb3662a68ee765e8f85b5461f76fcbd46536 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 10 Nov 2025 20:48:31 +0800 Subject: [PATCH] feat(management): add Vertex service account import and WebSocket auth management Introduce an endpoint for importing Vertex service account JSON keys and storing them as authentication records. Add handlers for managing WebSocket authentication configuration. --- .../api/handlers/management/config_basic.go | 8 + .../api/handlers/management/vertex_import.go | 156 ++++++++++++++++++ internal/api/server.go | 4 + 3 files changed, 168 insertions(+) create mode 100644 internal/api/handlers/management/vertex_import.go diff --git a/internal/api/handlers/management/config_basic.go b/internal/api/handlers/management/config_basic.go index 9a8c2923..3bbfd39a 100644 --- a/internal/api/handlers/management/config_basic.go +++ b/internal/api/handlers/management/config_basic.go @@ -153,6 +153,14 @@ func (h *Handler) PutRequestLog(c *gin.Context) { h.updateBoolField(c, func(v bool) { h.cfg.RequestLog = v }) } +// Websocket auth +func (h *Handler) GetWebsocketAuth(c *gin.Context) { + c.JSON(200, gin.H{"ws-auth": h.cfg.WebsocketAuth}) +} +func (h *Handler) PutWebsocketAuth(c *gin.Context) { + h.updateBoolField(c, func(v bool) { h.cfg.WebsocketAuth = v }) +} + // Request retry func (h *Handler) GetRequestRetry(c *gin.Context) { c.JSON(200, gin.H{"request-retry": h.cfg.RequestRetry}) diff --git a/internal/api/handlers/management/vertex_import.go b/internal/api/handlers/management/vertex_import.go new file mode 100644 index 00000000..bad066a2 --- /dev/null +++ b/internal/api/handlers/management/vertex_import.go @@ -0,0 +1,156 @@ +package management + +import ( + "context" + "encoding/json" + "fmt" + "io" + "net/http" + "strings" + + "github.com/gin-gonic/gin" + "github.com/router-for-me/CLIProxyAPI/v6/internal/auth/vertex" + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" +) + +// ImportVertexCredential handles uploading a Vertex service account JSON and saving it as an auth record. +func (h *Handler) ImportVertexCredential(c *gin.Context) { + if h == nil || h.cfg == nil { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "config unavailable"}) + return + } + if h.cfg.AuthDir == "" { + c.JSON(http.StatusServiceUnavailable, gin.H{"error": "auth directory not configured"}) + return + } + + fileHeader, err := c.FormFile("file") + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "file required"}) + return + } + + file, err := fileHeader.Open() + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)}) + return + } + defer file.Close() + + data, err := io.ReadAll(file) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": fmt.Sprintf("failed to read file: %v", err)}) + return + } + + var serviceAccount map[string]any + if err := json.Unmarshal(data, &serviceAccount); err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid json", "message": err.Error()}) + return + } + + normalizedSA, err := vertex.NormalizeServiceAccountMap(serviceAccount) + if err != nil { + c.JSON(http.StatusBadRequest, gin.H{"error": "invalid service account", "message": err.Error()}) + return + } + serviceAccount = normalizedSA + + projectID := strings.TrimSpace(valueAsString(serviceAccount["project_id"])) + if projectID == "" { + c.JSON(http.StatusBadRequest, gin.H{"error": "project_id missing"}) + return + } + email := strings.TrimSpace(valueAsString(serviceAccount["client_email"])) + + location := strings.TrimSpace(c.PostForm("location")) + if location == "" { + location = strings.TrimSpace(c.Query("location")) + } + if location == "" { + location = "us-central1" + } + + fileName := fmt.Sprintf("vertex-%s.json", sanitizeVertexFilePart(projectID)) + label := labelForVertex(projectID, email) + storage := &vertex.VertexCredentialStorage{ + ServiceAccount: serviceAccount, + ProjectID: projectID, + Email: email, + Location: location, + Type: "vertex", + } + metadata := map[string]any{ + "service_account": serviceAccount, + "project_id": projectID, + "email": email, + "location": location, + "type": "vertex", + "label": label, + } + record := &coreauth.Auth{ + ID: fileName, + Provider: "vertex", + FileName: fileName, + Storage: storage, + Label: label, + Metadata: metadata, + } + + ctx := context.Background() + if reqCtx := c.Request.Context(); reqCtx != nil { + ctx = reqCtx + } + savedPath, err := h.saveTokenRecord(ctx, record) + if err != nil { + c.JSON(http.StatusInternalServerError, gin.H{"error": "save_failed", "message": err.Error()}) + return + } + + c.JSON(http.StatusOK, gin.H{ + "status": "ok", + "auth-file": savedPath, + "project_id": projectID, + "email": email, + "location": location, + }) +} + +func valueAsString(v any) string { + if v == nil { + return "" + } + switch t := v.(type) { + case string: + return t + default: + return fmt.Sprint(t) + } +} + +func sanitizeVertexFilePart(s string) string { + out := strings.TrimSpace(s) + replacers := []string{"/", "_", "\\", "_", ":", "_", " ", "-"} + for i := 0; i < len(replacers); i += 2 { + out = strings.ReplaceAll(out, replacers[i], replacers[i+1]) + } + if out == "" { + return "vertex" + } + return out +} + +func labelForVertex(projectID, email string) string { + p := strings.TrimSpace(projectID) + e := strings.TrimSpace(email) + if p != "" && e != "" { + return fmt.Sprintf("%s (%s)", p, e) + } + if p != "" { + return p + } + if e != "" { + return e + } + return "vertex" +} diff --git a/internal/api/server.go b/internal/api/server.go index 3688ad34..78672f02 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -484,6 +484,9 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/request-log", s.mgmt.GetRequestLog) mgmt.PUT("/request-log", s.mgmt.PutRequestLog) mgmt.PATCH("/request-log", s.mgmt.PutRequestLog) + mgmt.GET("/ws-auth", s.mgmt.GetWebsocketAuth) + mgmt.PUT("/ws-auth", s.mgmt.PutWebsocketAuth) + mgmt.PATCH("/ws-auth", s.mgmt.PutWebsocketAuth) mgmt.GET("/request-retry", s.mgmt.GetRequestRetry) mgmt.PUT("/request-retry", s.mgmt.PutRequestRetry) @@ -508,6 +511,7 @@ func (s *Server) registerManagementRoutes() { mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile) mgmt.POST("/auth-files", s.mgmt.UploadAuthFile) mgmt.DELETE("/auth-files", s.mgmt.DeleteAuthFile) + mgmt.POST("/vertex/import", s.mgmt.ImportVertexCredential) mgmt.GET("/anthropic-auth-url", s.mgmt.RequestAnthropicToken) mgmt.GET("/codex-auth-url", s.mgmt.RequestCodexToken)