feat(management): 更新OAuth模型映射的清理逻辑以增强数据安全性

This commit is contained in:
Supra4E8C
2026-01-04 17:57:34 +08:00
parent f0e73efda2
commit cd22c849e2

View File

@@ -705,7 +705,7 @@ func (h *Handler) DeleteOAuthExcludedModels(c *gin.Context) {
// oauth-model-mappings: map[string][]ModelNameMapping // oauth-model-mappings: map[string][]ModelNameMapping
func (h *Handler) GetOAuthModelMappings(c *gin.Context) { func (h *Handler) GetOAuthModelMappings(c *gin.Context) {
c.JSON(200, gin.H{"oauth-model-mappings": normalizeOAuthModelMappings(h.cfg.OAuthModelMappings)}) c.JSON(200, gin.H{"oauth-model-mappings": sanitizedOAuthModelMappings(h.cfg.OAuthModelMappings)})
} }
func (h *Handler) PutOAuthModelMappings(c *gin.Context) { func (h *Handler) PutOAuthModelMappings(c *gin.Context) {
@@ -725,7 +725,7 @@ func (h *Handler) PutOAuthModelMappings(c *gin.Context) {
} }
entries = wrapper.Items entries = wrapper.Items
} }
h.cfg.OAuthModelMappings = normalizeOAuthModelMappings(entries) h.cfg.OAuthModelMappings = sanitizedOAuthModelMappings(entries)
h.persist(c) h.persist(c)
} }
@@ -751,7 +751,8 @@ func (h *Handler) PatchOAuthModelMappings(c *gin.Context) {
return return
} }
normalized := normalizeOAuthModelMappingsList(body.Mappings) normalizedMap := sanitizedOAuthModelMappings(map[string][]config.ModelNameMapping{channel: body.Mappings})
normalized := normalizedMap[channel]
if len(normalized) == 0 { if len(normalized) == 0 {
if h.cfg.OAuthModelMappings == nil { if h.cfg.OAuthModelMappings == nil {
c.JSON(404, gin.H{"error": "channel not found"}) c.JSON(404, gin.H{"error": "channel not found"})
@@ -1041,60 +1042,26 @@ func normalizeVertexCompatKey(entry *config.VertexCompatKey) {
entry.Models = normalized entry.Models = normalized
} }
func normalizeOAuthModelMappingsList(entries []config.ModelNameMapping) []config.ModelNameMapping { func sanitizedOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string][]config.ModelNameMapping {
if len(entries) == 0 { if len(entries) == 0 {
return nil return nil
} }
seenName := make(map[string]struct{}, len(entries)) copied := make(map[string][]config.ModelNameMapping, len(entries))
seenAlias := make(map[string]struct{}, len(entries)) for channel, mappings := range entries {
clean := make([]config.ModelNameMapping, 0, len(entries)) if len(mappings) == 0 {
for _, mapping := range entries {
name := strings.TrimSpace(mapping.Name)
alias := strings.TrimSpace(mapping.Alias)
if name == "" || alias == "" {
continue continue
} }
if strings.EqualFold(name, alias) { copied[channel] = append([]config.ModelNameMapping(nil), mappings...)
continue
} }
nameKey := strings.ToLower(name) if len(copied) == 0 {
aliasKey := strings.ToLower(alias)
if _, ok := seenName[nameKey]; ok {
continue
}
if _, ok := seenAlias[aliasKey]; ok {
continue
}
seenName[nameKey] = struct{}{}
seenAlias[aliasKey] = struct{}{}
clean = append(clean, config.ModelNameMapping{Name: name, Alias: alias, Fork: mapping.Fork})
}
if len(clean) == 0 {
return nil return nil
} }
return clean cfg := config.Config{OAuthModelMappings: copied}
} cfg.SanitizeOAuthModelMappings()
if len(cfg.OAuthModelMappings) == 0 {
func normalizeOAuthModelMappings(entries map[string][]config.ModelNameMapping) map[string][]config.ModelNameMapping {
if len(entries) == 0 {
return nil return nil
} }
out := make(map[string][]config.ModelNameMapping, len(entries)) return cfg.OAuthModelMappings
for rawChannel, mappings := range entries {
channel := strings.ToLower(strings.TrimSpace(rawChannel))
if channel == "" {
continue
}
normalized := normalizeOAuthModelMappingsList(mappings)
if len(normalized) == 0 {
continue
}
out[channel] = normalized
}
if len(out) == 0 {
return nil
}
return out
} }
// GetAmpCode returns the complete ampcode configuration. // GetAmpCode returns the complete ampcode configuration.