diff --git a/internal/api/handlers/management/config_lists.go b/internal/api/handlers/management/config_lists.go index e3636fd8..fb0a7655 100644 --- a/internal/api/handlers/management/config_lists.go +++ b/internal/api/handlers/management/config_lists.go @@ -487,6 +487,137 @@ func (h *Handler) DeleteOpenAICompat(c *gin.Context) { c.JSON(400, gin.H{"error": "missing name or index"}) } +// vertex-api-key: []VertexCompatKey +func (h *Handler) GetVertexCompatKeys(c *gin.Context) { + c.JSON(200, gin.H{"vertex-api-key": h.cfg.VertexCompatAPIKey}) +} +func (h *Handler) PutVertexCompatKeys(c *gin.Context) { + data, err := c.GetRawData() + if err != nil { + c.JSON(400, gin.H{"error": "failed to read body"}) + return + } + var arr []config.VertexCompatKey + if err = json.Unmarshal(data, &arr); err != nil { + var obj struct { + Items []config.VertexCompatKey `json:"items"` + } + if err2 := json.Unmarshal(data, &obj); err2 != nil || len(obj.Items) == 0 { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + arr = obj.Items + } + for i := range arr { + normalizeVertexCompatKey(&arr[i]) + } + h.cfg.VertexCompatAPIKey = arr + h.cfg.SanitizeVertexCompatKeys() + h.persist(c) +} +func (h *Handler) PatchVertexCompatKey(c *gin.Context) { + type vertexCompatPatch struct { + APIKey *string `json:"api-key"` + Prefix *string `json:"prefix"` + BaseURL *string `json:"base-url"` + ProxyURL *string `json:"proxy-url"` + Headers *map[string]string `json:"headers"` + Models *[]config.VertexCompatModel `json:"models"` + } + var body struct { + Index *int `json:"index"` + Match *string `json:"match"` + Value *vertexCompatPatch `json:"value"` + } + if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil || body.Value == nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + targetIndex := -1 + if body.Index != nil && *body.Index >= 0 && *body.Index < len(h.cfg.VertexCompatAPIKey) { + targetIndex = *body.Index + } + if targetIndex == -1 && body.Match != nil { + match := strings.TrimSpace(*body.Match) + if match != "" { + for i := range h.cfg.VertexCompatAPIKey { + if h.cfg.VertexCompatAPIKey[i].APIKey == match { + targetIndex = i + break + } + } + } + } + if targetIndex == -1 { + c.JSON(404, gin.H{"error": "item not found"}) + return + } + + entry := h.cfg.VertexCompatAPIKey[targetIndex] + if body.Value.APIKey != nil { + trimmed := strings.TrimSpace(*body.Value.APIKey) + if trimmed == "" { + h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...) + h.cfg.SanitizeVertexCompatKeys() + h.persist(c) + return + } + entry.APIKey = trimmed + } + if body.Value.Prefix != nil { + entry.Prefix = strings.TrimSpace(*body.Value.Prefix) + } + if body.Value.BaseURL != nil { + trimmed := strings.TrimSpace(*body.Value.BaseURL) + if trimmed == "" { + h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:targetIndex], h.cfg.VertexCompatAPIKey[targetIndex+1:]...) + h.cfg.SanitizeVertexCompatKeys() + h.persist(c) + return + } + entry.BaseURL = trimmed + } + if body.Value.ProxyURL != nil { + entry.ProxyURL = strings.TrimSpace(*body.Value.ProxyURL) + } + if body.Value.Headers != nil { + entry.Headers = config.NormalizeHeaders(*body.Value.Headers) + } + if body.Value.Models != nil { + entry.Models = append([]config.VertexCompatModel(nil), (*body.Value.Models)...) + } + normalizeVertexCompatKey(&entry) + h.cfg.VertexCompatAPIKey[targetIndex] = entry + h.cfg.SanitizeVertexCompatKeys() + h.persist(c) +} + +func (h *Handler) DeleteVertexCompatKey(c *gin.Context) { + if val := strings.TrimSpace(c.Query("api-key")); val != "" { + out := make([]config.VertexCompatKey, 0, len(h.cfg.VertexCompatAPIKey)) + for _, v := range h.cfg.VertexCompatAPIKey { + if v.APIKey != val { + out = append(out, v) + } + } + h.cfg.VertexCompatAPIKey = out + h.cfg.SanitizeVertexCompatKeys() + h.persist(c) + return + } + if idxStr := c.Query("index"); idxStr != "" { + var idx int + _, errScan := fmt.Sscanf(idxStr, "%d", &idx) + if errScan == nil && idx >= 0 && idx < len(h.cfg.VertexCompatAPIKey) { + h.cfg.VertexCompatAPIKey = append(h.cfg.VertexCompatAPIKey[:idx], h.cfg.VertexCompatAPIKey[idx+1:]...) + h.cfg.SanitizeVertexCompatKeys() + h.persist(c) + return + } + } + c.JSON(400, gin.H{"error": "missing api-key or index"}) +} + // oauth-excluded-models: map[string][]string func (h *Handler) GetOAuthExcludedModels(c *gin.Context) { c.JSON(200, gin.H{"oauth-excluded-models": config.NormalizeOAuthExcludedModels(h.cfg.OAuthExcludedModels)}) @@ -572,6 +703,102 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) { h.persist(c) } +// oauth-model-mappings: map[string][]ModelNameMapping +func (h *Handler) GetOAuthModelMappings(c *gin.Context) { + c.JSON(200, gin.H{"oauth-model-mappings": normalizeOAuthModelMappings(h.cfg.OAuthModelMappings)}) +} + +func (h *Handler) PutOAuthModelMappings(c *gin.Context) { + data, err := c.GetRawData() + if err != nil { + c.JSON(400, gin.H{"error": "failed to read body"}) + return + } + var entries map[string][]config.ModelNameMapping + if err = json.Unmarshal(data, &entries); err != nil { + var wrapper struct { + Items map[string][]config.ModelNameMapping `json:"items"` + } + if err2 := json.Unmarshal(data, &wrapper); err2 != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + entries = wrapper.Items + } + h.cfg.OAuthModelMappings = normalizeOAuthModelMappings(entries) + h.persist(c) +} + +func (h *Handler) PatchOAuthModelMappings(c *gin.Context) { + var body struct { + Provider *string `json:"provider"` + Channel *string `json:"channel"` + Mappings []config.ModelNameMapping `json:"mappings"` + } + if errBindJSON := c.ShouldBindJSON(&body); errBindJSON != nil { + c.JSON(400, gin.H{"error": "invalid body"}) + return + } + channelRaw := "" + if body.Channel != nil { + channelRaw = *body.Channel + } else if body.Provider != nil { + channelRaw = *body.Provider + } + channel := strings.ToLower(strings.TrimSpace(channelRaw)) + if channel == "" { + c.JSON(400, gin.H{"error": "invalid channel"}) + return + } + + normalized := normalizeOAuthModelMappingsList(body.Mappings) + if len(normalized) == 0 { + if h.cfg.OAuthModelMappings == nil { + c.JSON(404, gin.H{"error": "channel not found"}) + return + } + if _, ok := h.cfg.OAuthModelMappings[channel]; !ok { + c.JSON(404, gin.H{"error": "channel not found"}) + return + } + delete(h.cfg.OAuthModelMappings, channel) + if len(h.cfg.OAuthModelMappings) == 0 { + h.cfg.OAuthModelMappings = nil + } + h.persist(c) + return + } + if h.cfg.OAuthModelMappings == nil { + h.cfg.OAuthModelMappings = make(map[string][]config.ModelNameMapping) + } + h.cfg.OAuthModelMappings[channel] = normalized + h.persist(c) +} + +func (h *Handler) DeleteOAuthModelMappings(c *gin.Context) { + channel := strings.ToLower(strings.TrimSpace(c.Query("channel"))) + if channel == "" { + channel = strings.ToLower(strings.TrimSpace(c.Query("provider"))) + } + if channel == "" { + c.JSON(400, gin.H{"error": "missing channel"}) + return + } + if h.cfg.OAuthModelMappings == nil { + c.JSON(404, gin.H{"error": "channel not found"}) + return + } + if _, ok := h.cfg.OAuthModelMappings[channel]; !ok { + c.JSON(404, gin.H{"error": "channel not found"}) + return + } + delete(h.cfg.OAuthModelMappings, channel) + if len(h.cfg.OAuthModelMappings) == 0 { + h.cfg.OAuthModelMappings = nil + } + h.persist(c) +} + // codex-api-key: []CodexKey func (h *Handler) GetCodexKeys(c *gin.Context) { c.JSON(200, gin.H{"codex-api-key": h.cfg.CodexKey}) @@ -789,6 +1016,87 @@ func normalizeCodexKey(entry *config.CodexKey) { entry.Models = normalized } +func normalizeVertexCompatKey(entry *config.VertexCompatKey) { + if entry == nil { + return + } + entry.APIKey = strings.TrimSpace(entry.APIKey) + entry.Prefix = strings.TrimSpace(entry.Prefix) + entry.BaseURL = strings.TrimSpace(entry.BaseURL) + entry.ProxyURL = strings.TrimSpace(entry.ProxyURL) + entry.Headers = config.NormalizeHeaders(entry.Headers) + if len(entry.Models) == 0 { + return + } + normalized := make([]config.VertexCompatModel, 0, len(entry.Models)) + for i := range entry.Models { + model := entry.Models[i] + model.Name = strings.TrimSpace(model.Name) + model.Alias = strings.TrimSpace(model.Alias) + if model.Name == "" || model.Alias == "" { + continue + } + normalized = append(normalized, model) + } + entry.Models = normalized +} + +func normalizeOAuthModelMappingsList(entries []config.ModelNameMapping) []config.ModelNameMapping { + if len(entries) == 0 { + return nil + } + seenName := make(map[string]struct{}, len(entries)) + seenAlias := make(map[string]struct{}, len(entries)) + clean := make([]config.ModelNameMapping, 0, len(entries)) + for _, mapping := range entries { + name := strings.TrimSpace(mapping.Name) + alias := strings.TrimSpace(mapping.Alias) + if name == "" || alias == "" { + continue + } + if strings.EqualFold(name, alias) { + continue + } + nameKey := strings.ToLower(name) + aliasKey := strings.ToLower(alias) + if _, ok := seenName[nameKey]; ok { + continue + } + if _, ok := seenAlias[aliasKey]; ok { + continue + } + seenName[nameKey] = struct{}{} + seenAlias[aliasKey] = struct{}{} + clean = append(clean, config.ModelNameMapping{Name: name, Alias: alias, Fork: mapping.Fork}) + } + if len(clean) == 0 { + return nil + } + return clean +} + +func normalizeOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string][]config.ModelNameMapping { + if len(entries) == 0 { + return nil + } + out := make(map[string][]config.ModelNameMapping, len(entries)) + for rawChannel, mappings := range entries { + channel := strings.ToLower(strings.TrimSpace(rawChannel)) + if channel == "" { + continue + } + normalized := normalizeOAuthModelMappingsList(mappings) + if len(normalized) == 0 { + continue + } + out[channel] = normalized + } + if len(out) == 0 { + return nil + } + return out +} + // GetAmpCode returns the complete ampcode configuration. func (h *Handler) GetAmpCode(c *gin.Context) { if h == nil || h.cfg == nil { diff --git a/internal/api/server.go b/internal/api/server.go index 3119bbb9..05bb2fee 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -591,11 +591,21 @@ func (s *Server) registerManagementRoutes() { mgmt.PATCH("/openai-compatibility", s.mgmt.PatchOpenAICompat) mgmt.DELETE("/openai-compatibility", s.mgmt.DeleteOpenAICompat) + mgmt.GET("/vertex-api-key", s.mgmt.GetVertexCompatKeys) + mgmt.PUT("/vertex-api-key", s.mgmt.PutVertexCompatKeys) + mgmt.PATCH("/vertex-api-key", s.mgmt.PatchVertexCompatKey) + mgmt.DELETE("/vertex-api-key", s.mgmt.DeleteVertexCompatKey) + mgmt.GET("/oauth-excluded-models", s.mgmt.GetOAuthExcludedModels) mgmt.PUT("/oauth-excluded-models", s.mgmt.PutOAuthExcludedModels) mgmt.PATCH("/oauth-excluded-models", s.mgmt.PatchOAuthExcludedModels) mgmt.DELETE("/oauth-excluded-models", s.mgmt.DeleteOAuthExcludedModels) + mgmt.GET("/oauth-model-mappings", s.mgmt.GetOAuthModelMappings) + mgmt.PUT("/oauth-model-mappings", s.mgmt.PutOAuthModelMappings) + mgmt.PATCH("/oauth-model-mappings", s.mgmt.PatchOAuthModelMappings) + mgmt.DELETE("/oauth-model-mappings", s.mgmt.DeleteOAuthModelMappings) + mgmt.GET("/auth-files", s.mgmt.ListAuthFiles) mgmt.GET("/auth-files/models", s.mgmt.GetAuthFileModels) mgmt.GET("/auth-files/download", s.mgmt.DownloadAuthFile)