Merge pull request #69 from router-for-me/reload

Implement minimal incremental updates for models and keys
This commit is contained in:
Luis Pater
2025-09-26 23:06:27 +08:00
committed by GitHub
5 changed files with 640 additions and 71 deletions

View File

@@ -194,7 +194,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
loggerToggle: toggle, loggerToggle: toggle,
configFilePath: configFilePath, configFilePath: configFilePath,
} }
s.applyAccessConfig(cfg) s.applyAccessConfig(nil, cfg)
// Initialize management handler // Initialize management handler
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager) s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
if optionState.localPassword != "" { if optionState.localPassword != "" {
@@ -547,16 +547,23 @@ func corsMiddleware() gin.HandlerFunc {
} }
} }
func (s *Server) applyAccessConfig(cfg *config.Config) { func (s *Server) applyAccessConfig(oldCfg, newCfg *config.Config) {
if s == nil || s.accessManager == nil { if s == nil || s.accessManager == nil || newCfg == nil {
return return
} }
providers, err := sdkaccess.BuildProviders(cfg) existing := s.accessManager.Providers()
providers, added, updated, removed, err := sdkaccess.ReconcileProviders(oldCfg, newCfg, existing)
if err != nil { if err != nil {
log.Errorf("failed to update request auth providers: %v", err) log.Errorf("failed to reconcile request auth providers: %v", err)
return return
} }
s.accessManager.SetProviders(providers) s.accessManager.SetProviders(providers)
if len(added)+len(updated)+len(removed) > 0 {
log.Debugf("auth providers reconciled (added=%d updated=%d removed=%d)", len(added), len(updated), len(removed))
log.Debugf("auth provider changes details - added=%v updated=%v removed=%v", added, updated, removed)
} else {
log.Debug("auth providers unchanged after config update")
}
} }
// UpdateClients updates the server's client list and configuration. // UpdateClients updates the server's client list and configuration.
@@ -566,44 +573,60 @@ func (s *Server) applyAccessConfig(cfg *config.Config) {
// - clients: The new slice of AI service clients // - clients: The new slice of AI service clients
// - cfg: The new application configuration // - cfg: The new application configuration
func (s *Server) UpdateClients(cfg *config.Config) { func (s *Server) UpdateClients(cfg *config.Config) {
oldCfg := s.cfg
// Update request logger enabled state if it has changed // Update request logger enabled state if it has changed
if s.requestLogger != nil && s.cfg.RequestLog != cfg.RequestLog { previousRequestLog := false
if oldCfg != nil {
previousRequestLog = oldCfg.RequestLog
}
if s.requestLogger != nil && (oldCfg == nil || previousRequestLog != cfg.RequestLog) {
if s.loggerToggle != nil { if s.loggerToggle != nil {
s.loggerToggle(cfg.RequestLog) s.loggerToggle(cfg.RequestLog)
} else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok { } else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok {
toggler.SetEnabled(cfg.RequestLog) toggler.SetEnabled(cfg.RequestLog)
} }
log.Debugf("request logging updated from %t to %t", s.cfg.RequestLog, cfg.RequestLog) if oldCfg != nil {
} log.Debugf("request logging updated from %t to %t", previousRequestLog, cfg.RequestLog)
if s.cfg.LoggingToFile != cfg.LoggingToFile {
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
log.Errorf("failed to reconfigure log output: %v", err)
} else { } else {
log.Debugf("logging_to_file updated from %t to %t", s.cfg.LoggingToFile, cfg.LoggingToFile) log.Debugf("request logging toggled to %t", cfg.RequestLog)
} }
} }
if s.cfg == nil || s.cfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled { if oldCfg != nil && oldCfg.LoggingToFile != cfg.LoggingToFile {
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
log.Errorf("failed to reconfigure log output: %v", err)
} else {
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
}
}
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled) usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
if s.cfg != nil { if oldCfg != nil {
log.Debugf("usage_statistics_enabled updated from %t to %t", s.cfg.UsageStatisticsEnabled, cfg.UsageStatisticsEnabled) log.Debugf("usage_statistics_enabled updated from %t to %t", oldCfg.UsageStatisticsEnabled, cfg.UsageStatisticsEnabled)
} else {
log.Debugf("usage_statistics_enabled toggled to %t", cfg.UsageStatisticsEnabled)
} }
} }
// Update log level dynamically when debug flag changes // Update log level dynamically when debug flag changes
if s.cfg.Debug != cfg.Debug { if oldCfg == nil || oldCfg.Debug != cfg.Debug {
util.SetLogLevel(cfg) util.SetLogLevel(cfg)
log.Debugf("debug mode updated from %t to %t", s.cfg.Debug, cfg.Debug) if oldCfg != nil {
log.Debugf("debug mode updated from %t to %t", oldCfg.Debug, cfg.Debug)
} else {
log.Debugf("debug mode toggled to %t", cfg.Debug)
}
} }
s.applyAccessConfig(oldCfg, cfg)
s.cfg = cfg s.cfg = cfg
s.handlers.UpdateClients(cfg) s.handlers.UpdateClients(cfg)
if s.mgmt != nil { if s.mgmt != nil {
s.mgmt.SetConfig(cfg) s.mgmt.SetConfig(cfg)
s.mgmt.SetAuthManager(s.handlers.AuthManager) s.mgmt.SetAuthManager(s.handlers.AuthManager)
} }
s.applyAccessConfig(cfg)
// Count client sources from configuration and auth directory // Count client sources from configuration and auth directory
authFiles := util.CountAuthFiles(cfg.AuthDir) authFiles := util.CountAuthFiles(cfg.AuthDir)

View File

@@ -101,58 +101,267 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
r.mutex.Lock() r.mutex.Lock()
defer r.mutex.Unlock() defer r.mutex.Unlock()
// Remove any existing registration for this client
r.unregisterClientInternal(clientID)
provider := strings.ToLower(clientProvider) provider := strings.ToLower(clientProvider)
modelIDs := make([]string, 0, len(models)) uniqueModelIDs := make([]string, 0, len(models))
rawModelIDs := make([]string, 0, len(models))
newModels := make(map[string]*ModelInfo, len(models))
newCounts := make(map[string]int, len(models))
for _, model := range models {
if model == nil || model.ID == "" {
continue
}
rawModelIDs = append(rawModelIDs, model.ID)
newCounts[model.ID]++
if _, exists := newModels[model.ID]; exists {
continue
}
newModels[model.ID] = model
uniqueModelIDs = append(uniqueModelIDs, model.ID)
}
if len(uniqueModelIDs) == 0 {
// No models supplied; unregister existing client state if present.
r.unregisterClientInternal(clientID)
delete(r.clientModels, clientID)
delete(r.clientProviders, clientID)
misc.LogCredentialSeparator()
return
}
now := time.Now() now := time.Now()
for _, model := range models { oldModels, hadExisting := r.clientModels[clientID]
modelIDs = append(modelIDs, model.ID) oldProvider, _ := r.clientProviders[clientID]
providerChanged := oldProvider != provider
if existing, exists := r.models[model.ID]; exists { if !hadExisting {
// Model already exists, increment count // Pure addition path.
existing.Count++ for _, modelID := range rawModelIDs {
existing.LastUpdated = now model := newModels[modelID]
if existing.SuspendedClients == nil { r.addModelRegistration(modelID, provider, model, now)
existing.SuspendedClients = make(map[string]string) }
} r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
if provider != "" { if provider != "" {
if existing.Providers == nil { r.clientProviders[clientID] = provider
existing.Providers = make(map[string]int)
}
existing.Providers[provider]++
}
log.Debugf("Incremented count for model %s, now %d clients", model.ID, existing.Count)
} else { } else {
// New model, create registration delete(r.clientProviders, clientID)
registration := &ModelRegistration{ }
Info: model, log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
Count: 1, misc.LogCredentialSeparator()
LastUpdated: now, return
QuotaExceededClients: make(map[string]*time.Time), }
SuspendedClients: make(map[string]string),
} oldCounts := make(map[string]int, len(oldModels))
if provider != "" { for _, id := range oldModels {
registration.Providers = map[string]int{provider: 1} oldCounts[id]++
} }
r.models[model.ID] = registration
log.Debugf("Registered new model %s from provider %s", model.ID, clientProvider) added := make([]string, 0)
for _, id := range uniqueModelIDs {
if oldCounts[id] == 0 {
added = append(added, id)
} }
} }
r.clientModels[clientID] = modelIDs removed := make([]string, 0)
for id := range oldCounts {
if newCounts[id] == 0 {
removed = append(removed, id)
}
}
// Handle provider change for overlapping models before modifications.
if providerChanged && oldProvider != "" {
for id, newCount := range newCounts {
if newCount == 0 {
continue
}
oldCount := oldCounts[id]
if oldCount == 0 {
continue
}
toRemove := newCount
if oldCount < toRemove {
toRemove = oldCount
}
if reg, ok := r.models[id]; ok && reg.Providers != nil {
if count, okProv := reg.Providers[oldProvider]; okProv {
if count <= toRemove {
delete(reg.Providers, oldProvider)
} else {
reg.Providers[oldProvider] = count - toRemove
}
}
}
}
}
// Apply removals first to keep counters accurate.
for _, id := range removed {
oldCount := oldCounts[id]
for i := 0; i < oldCount; i++ {
r.removeModelRegistration(clientID, id, oldProvider, now)
}
}
for id, oldCount := range oldCounts {
newCount := newCounts[id]
if newCount == 0 || oldCount <= newCount {
continue
}
overage := oldCount - newCount
for i := 0; i < overage; i++ {
r.removeModelRegistration(clientID, id, oldProvider, now)
}
}
// Apply additions.
for id, newCount := range newCounts {
oldCount := oldCounts[id]
if newCount <= oldCount {
continue
}
model := newModels[id]
diff := newCount - oldCount
for i := 0; i < diff; i++ {
r.addModelRegistration(id, provider, model, now)
}
}
// Update metadata for models that remain associated with the client.
addedSet := make(map[string]struct{}, len(added))
for _, id := range added {
addedSet[id] = struct{}{}
}
for _, id := range uniqueModelIDs {
model := newModels[id]
if reg, ok := r.models[id]; ok {
reg.Info = cloneModelInfo(model)
reg.LastUpdated = now
if reg.QuotaExceededClients != nil {
delete(reg.QuotaExceededClients, clientID)
}
if reg.SuspendedClients != nil {
delete(reg.SuspendedClients, clientID)
}
if providerChanged && provider != "" {
if _, newlyAdded := addedSet[id]; newlyAdded {
continue
}
overlapCount := newCounts[id]
if oldCount := oldCounts[id]; oldCount < overlapCount {
overlapCount = oldCount
}
if overlapCount <= 0 {
continue
}
if reg.Providers == nil {
reg.Providers = make(map[string]int)
}
reg.Providers[provider] += overlapCount
}
}
}
// Update client bookkeeping.
if len(rawModelIDs) > 0 {
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
}
if provider != "" { if provider != "" {
r.clientProviders[clientID] = provider r.clientProviders[clientID] = provider
} else { } else {
delete(r.clientProviders, clientID) delete(r.clientProviders, clientID)
} }
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(models))
// Separator at the end of the registration block (acts as boundary to next group) if len(added) == 0 && len(removed) == 0 && !providerChanged {
// Only metadata (e.g., display name) changed; skip separator when no log output.
return
}
log.Debugf("Reconciled client %s (provider %s) models: +%d, -%d", clientID, provider, len(added), len(removed))
misc.LogCredentialSeparator() misc.LogCredentialSeparator()
} }
func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *ModelInfo, now time.Time) {
if model == nil || modelID == "" {
return
}
if existing, exists := r.models[modelID]; exists {
existing.Count++
existing.LastUpdated = now
existing.Info = cloneModelInfo(model)
if existing.SuspendedClients == nil {
existing.SuspendedClients = make(map[string]string)
}
if provider != "" {
if existing.Providers == nil {
existing.Providers = make(map[string]int)
}
existing.Providers[provider]++
}
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
return
}
registration := &ModelRegistration{
Info: cloneModelInfo(model),
Count: 1,
LastUpdated: now,
QuotaExceededClients: make(map[string]*time.Time),
SuspendedClients: make(map[string]string),
}
if provider != "" {
registration.Providers = map[string]int{provider: 1}
}
r.models[modelID] = registration
log.Debugf("Registered new model %s from provider %s", modelID, provider)
}
func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider string, now time.Time) {
registration, exists := r.models[modelID]
if !exists {
return
}
registration.Count--
registration.LastUpdated = now
if registration.QuotaExceededClients != nil {
delete(registration.QuotaExceededClients, clientID)
}
if registration.SuspendedClients != nil {
delete(registration.SuspendedClients, clientID)
}
if registration.Count < 0 {
registration.Count = 0
}
if provider != "" && registration.Providers != nil {
if count, ok := registration.Providers[provider]; ok {
if count <= 1 {
delete(registration.Providers, provider)
} else {
registration.Providers[provider] = count - 1
}
}
}
log.Debugf("Decremented count for model %s, now %d clients", modelID, registration.Count)
if registration.Count <= 0 {
delete(r.models, modelID)
log.Debugf("Removed model %s as no clients remain", modelID)
}
}
func cloneModelInfo(model *ModelInfo) *ModelInfo {
if model == nil {
return nil
}
copy := *model
if len(model.SupportedGenerationMethods) > 0 {
copy.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
}
if len(model.SupportedParameters) > 0 {
copy.SupportedParameters = append([]string(nil), model.SupportedParameters...)
}
return &copy
}
// UnregisterClient removes a client and decrements counts for its models // UnregisterClient removes a client and decrements counts for its models
// Parameters: // Parameters:
// - clientID: Unique identifier for the client to remove // - clientID: Unique identifier for the client to remove

View File

@@ -52,6 +52,39 @@ type Watcher struct {
dispatchCancel context.CancelFunc dispatchCancel context.CancelFunc
} }
type stableIDGenerator struct {
counters map[string]int
}
func newStableIDGenerator() *stableIDGenerator {
return &stableIDGenerator{counters: make(map[string]int)}
}
func (g *stableIDGenerator) next(kind string, parts ...string) (string, string) {
if g == nil {
return kind + ":000000000000", "000000000000"
}
hasher := sha256.New()
hasher.Write([]byte(kind))
for _, part := range parts {
trimmed := strings.TrimSpace(part)
hasher.Write([]byte{0})
hasher.Write([]byte(trimmed))
}
digest := hex.EncodeToString(hasher.Sum(nil))
if len(digest) < 12 {
digest = fmt.Sprintf("%012s", digest)
}
short := digest[:12]
key := kind + ":" + short
index := g.counters[key]
g.counters[key] = index + 1
if index > 0 {
short = fmt.Sprintf("%s-%d", short, index)
}
return fmt.Sprintf("%s:%s", kind, short), short
}
// AuthUpdateAction represents the type of change detected in auth sources. // AuthUpdateAction represents the type of change detected in auth sources.
type AuthUpdateAction string type AuthUpdateAction string
@@ -640,6 +673,7 @@ func (w *Watcher) removeClient(path string) {
func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth { func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
out := make([]*coreauth.Auth, 0, 32) out := make([]*coreauth.Auth, 0, 32)
now := time.Now() now := time.Now()
idGen := newStableIDGenerator()
// Also synthesize auth entries for OpenAI-compatibility providers directly from config // Also synthesize auth entries for OpenAI-compatibility providers directly from config
w.clientsMutex.RLock() w.clientsMutex.RLock()
cfg := w.config cfg := w.config
@@ -647,14 +681,18 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
if cfg != nil { if cfg != nil {
// Gemini official API keys -> synthesize auths // Gemini official API keys -> synthesize auths
for i := range cfg.GlAPIKey { for i := range cfg.GlAPIKey {
k := cfg.GlAPIKey[i] k := strings.TrimSpace(cfg.GlAPIKey[i])
if k == "" {
continue
}
id, token := idGen.next("gemini:apikey", k)
a := &coreauth.Auth{ a := &coreauth.Auth{
ID: fmt.Sprintf("gemini:apikey:%d", i), ID: id,
Provider: "gemini", Provider: "gemini",
Label: "gemini-apikey", Label: "gemini-apikey",
Status: coreauth.StatusActive, Status: coreauth.StatusActive,
Attributes: map[string]string{ Attributes: map[string]string{
"source": fmt.Sprintf("config:gemini#%d", i), "source": fmt.Sprintf("config:gemini[%s]", token),
"api_key": k, "api_key": k,
}, },
CreatedAt: now, CreatedAt: now,
@@ -665,15 +703,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
// Claude API keys -> synthesize auths // Claude API keys -> synthesize auths
for i := range cfg.ClaudeKey { for i := range cfg.ClaudeKey {
ck := cfg.ClaudeKey[i] ck := cfg.ClaudeKey[i]
key := strings.TrimSpace(ck.APIKey)
if key == "" {
continue
}
id, token := idGen.next("claude:apikey", key, ck.BaseURL)
attrs := map[string]string{ attrs := map[string]string{
"source": fmt.Sprintf("config:claude#%d", i), "source": fmt.Sprintf("config:claude[%s]", token),
"api_key": ck.APIKey, "api_key": key,
} }
if ck.BaseURL != "" { if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL attrs["base_url"] = ck.BaseURL
} }
a := &coreauth.Auth{ a := &coreauth.Auth{
ID: fmt.Sprintf("claude:apikey:%d", i), ID: id,
Provider: "claude", Provider: "claude",
Label: "claude-apikey", Label: "claude-apikey",
Status: coreauth.StatusActive, Status: coreauth.StatusActive,
@@ -686,15 +729,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
// Codex API keys -> synthesize auths // Codex API keys -> synthesize auths
for i := range cfg.CodexKey { for i := range cfg.CodexKey {
ck := cfg.CodexKey[i] ck := cfg.CodexKey[i]
key := strings.TrimSpace(ck.APIKey)
if key == "" {
continue
}
id, token := idGen.next("codex:apikey", key, ck.BaseURL)
attrs := map[string]string{ attrs := map[string]string{
"source": fmt.Sprintf("config:codex#%d", i), "source": fmt.Sprintf("config:codex[%s]", token),
"api_key": ck.APIKey, "api_key": key,
} }
if ck.BaseURL != "" { if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL attrs["base_url"] = ck.BaseURL
} }
a := &coreauth.Auth{ a := &coreauth.Auth{
ID: fmt.Sprintf("codex:apikey:%d", i), ID: id,
Provider: "codex", Provider: "codex",
Label: "codex-apikey", Label: "codex-apikey",
Status: coreauth.StatusActive, Status: coreauth.StatusActive,
@@ -710,11 +758,16 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
if providerName == "" { if providerName == "" {
providerName = "openai-compatibility" providerName = "openai-compatibility"
} }
base := compat.BaseURL base := strings.TrimSpace(compat.BaseURL)
for j := range compat.APIKeys { for j := range compat.APIKeys {
key := compat.APIKeys[j] key := strings.TrimSpace(compat.APIKeys[j])
if key == "" {
continue
}
idKind := fmt.Sprintf("openai-compatibility:%s", providerName)
id, token := idGen.next(idKind, key, base)
attrs := map[string]string{ attrs := map[string]string{
"source": fmt.Sprintf("config:%s#%d", compat.Name, j), "source": fmt.Sprintf("config:%s[%s]", providerName, token),
"base_url": base, "base_url": base,
"api_key": key, "api_key": key,
"compat_name": compat.Name, "compat_name": compat.Name,
@@ -724,7 +777,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
attrs["models_hash"] = hash attrs["models_hash"] = hash
} }
a := &coreauth.Auth{ a := &coreauth.Auth{
ID: fmt.Sprintf("openai-compatibility:%s:%d", compat.Name, j), ID: id,
Provider: providerName, Provider: providerName,
Label: compat.Name, Label: compat.Name,
Status: coreauth.StatusActive, Status: coreauth.StatusActive,

252
sdk/access/reconcile.go Normal file
View File

@@ -0,0 +1,252 @@
package access
import (
"reflect"
"sort"
"strings"
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
)
// ReconcileProviders builds the desired provider list by reusing existing providers when possible
// and creating or removing providers only when their configuration changed. It returns the final
// ordered provider slice along with the identifiers of providers that were added, updated, or
// removed compared to the previous configuration.
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []Provider) (result []Provider, added, updated, removed []string, err error) {
if newCfg == nil {
return nil, nil, nil, nil, nil
}
existingMap := make(map[string]Provider, len(existing))
for _, provider := range existing {
if provider == nil {
continue
}
existingMap[provider.Identifier()] = provider
}
oldCfgMap := accessProviderMap(oldCfg)
newEntries := collectProviderEntries(newCfg)
result = make([]Provider, 0, len(newEntries))
finalIDs := make(map[string]struct{}, len(newEntries))
isInlineProvider := func(id string) bool {
return strings.EqualFold(id, config.DefaultAccessProviderName)
}
appendChange := func(list *[]string, id string) {
if isInlineProvider(id) {
return
}
*list = append(*list, id)
}
for _, providerCfg := range newEntries {
key := providerIdentifier(providerCfg)
if key == "" {
continue
}
if oldCfgProvider, ok := oldCfgMap[key]; ok {
isAliased := oldCfgProvider == providerCfg
if !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
if existingProvider, okExisting := existingMap[key]; okExisting {
result = append(result, existingProvider)
finalIDs[key] = struct{}{}
continue
}
}
}
provider, buildErr := buildProvider(providerCfg, newCfg)
if buildErr != nil {
return nil, nil, nil, nil, buildErr
}
if _, ok := oldCfgMap[key]; ok {
if _, existed := existingMap[key]; existed {
appendChange(&updated, key)
} else {
appendChange(&added, key)
}
} else {
appendChange(&added, key)
}
result = append(result, provider)
finalIDs[key] = struct{}{}
}
if len(result) == 0 && len(newCfg.APIKeys) > 0 {
config.SyncInlineAPIKeys(newCfg, newCfg.APIKeys)
if providerCfg := newCfg.ConfigAPIKeyProvider(); providerCfg != nil {
key := providerIdentifier(providerCfg)
if key != "" {
if oldCfgProvider, ok := oldCfgMap[key]; ok {
isAliased := oldCfgProvider == providerCfg
if !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
if existingProvider, okExisting := existingMap[key]; okExisting {
result = append(result, existingProvider)
} else {
provider, buildErr := buildProvider(providerCfg, newCfg)
if buildErr != nil {
return nil, nil, nil, nil, buildErr
}
if _, existed := existingMap[key]; existed {
appendChange(&updated, key)
} else {
appendChange(&added, key)
}
result = append(result, provider)
}
} else {
provider, buildErr := buildProvider(providerCfg, newCfg)
if buildErr != nil {
return nil, nil, nil, nil, buildErr
}
if _, existed := existingMap[key]; existed {
appendChange(&updated, key)
} else {
appendChange(&added, key)
}
result = append(result, provider)
}
} else {
provider, buildErr := buildProvider(providerCfg, newCfg)
if buildErr != nil {
return nil, nil, nil, nil, buildErr
}
appendChange(&added, key)
result = append(result, provider)
}
finalIDs[key] = struct{}{}
}
}
}
removedSet := make(map[string]struct{})
for id := range existingMap {
if _, ok := finalIDs[id]; !ok {
if isInlineProvider(id) {
continue
}
removedSet[id] = struct{}{}
}
}
removed = make([]string, 0, len(removedSet))
for id := range removedSet {
removed = append(removed, id)
}
sort.Strings(added)
sort.Strings(updated)
sort.Strings(removed)
return result, added, updated, removed, nil
}
func accessProviderMap(cfg *config.Config) map[string]*config.AccessProvider {
result := make(map[string]*config.AccessProvider)
if cfg == nil {
return result
}
for i := range cfg.Access.Providers {
providerCfg := &cfg.Access.Providers[i]
if providerCfg.Type == "" {
continue
}
key := providerIdentifier(providerCfg)
if key == "" {
continue
}
result[key] = providerCfg
}
if len(result) == 0 && len(cfg.APIKeys) > 0 {
if provider := cfg.ConfigAPIKeyProvider(); provider != nil {
if key := providerIdentifier(provider); key != "" {
result[key] = provider
}
}
}
return result
}
func collectProviderEntries(cfg *config.Config) []*config.AccessProvider {
entries := make([]*config.AccessProvider, 0, len(cfg.Access.Providers))
if cfg == nil {
return entries
}
for i := range cfg.Access.Providers {
providerCfg := &cfg.Access.Providers[i]
if providerCfg.Type == "" {
continue
}
if key := providerIdentifier(providerCfg); key != "" {
entries = append(entries, providerCfg)
}
}
return entries
}
func providerIdentifier(provider *config.AccessProvider) string {
if provider == nil {
return ""
}
if name := strings.TrimSpace(provider.Name); name != "" {
return name
}
typ := strings.TrimSpace(provider.Type)
if typ == "" {
return ""
}
if strings.EqualFold(typ, config.AccessProviderTypeConfigAPIKey) {
return config.DefaultAccessProviderName
}
return typ
}
func providerConfigEqual(a, b *config.AccessProvider) bool {
if a == nil || b == nil {
return a == nil && b == nil
}
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
return false
}
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
return false
}
if !stringSetEqual(a.APIKeys, b.APIKeys) {
return false
}
if len(a.Config) != len(b.Config) {
return false
}
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
return false
}
return true
}
func stringSetEqual(a, b []string) bool {
if len(a) != len(b) {
return false
}
if len(a) == 0 {
return true
}
seen := make(map[string]int, len(a))
for _, val := range a {
seen[val]++
}
for _, val := range b {
count := seen[val]
if count == 0 {
return false
}
if count == 1 {
delete(seen, val)
} else {
seen[val] = count - 1
}
}
return len(seen) == 0
}

View File

@@ -110,12 +110,23 @@ func (s *Service) refreshAccessProviders(cfg *config.Config) {
if s == nil || s.accessManager == nil || cfg == nil { if s == nil || s.accessManager == nil || cfg == nil {
return return
} }
providers, err := sdkaccess.BuildProviders(cfg) s.cfgMu.RLock()
oldCfg := s.cfg
s.cfgMu.RUnlock()
existing := s.accessManager.Providers()
providers, added, updated, removed, err := sdkaccess.ReconcileProviders(oldCfg, cfg, existing)
if err != nil { if err != nil {
log.Errorf("failed to rebuild request auth providers: %v", err) log.Errorf("failed to reconcile request auth providers: %v", err)
return return
} }
s.accessManager.SetProviders(providers) s.accessManager.SetProviders(providers)
if len(added)+len(updated)+len(removed) > 0 {
log.Debugf("auth providers reconciled (added=%d updated=%d removed=%d)", len(added), len(updated), len(removed))
log.Debugf("auth provider changes details - added=%v updated=%v removed=%v", added, updated, removed)
} else {
log.Debug("auth providers unchanged after config reload")
}
} }
func (s *Service) ensureAuthUpdateQueue(ctx context.Context) { func (s *Service) ensureAuthUpdateQueue(ctx context.Context) {
@@ -497,22 +508,35 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
if s.cfg != nil { if s.cfg != nil {
providerKey := provider providerKey := provider
compatName := strings.TrimSpace(a.Provider) compatName := strings.TrimSpace(a.Provider)
isCompatAuth := false
if strings.EqualFold(providerKey, "openai-compatibility") { if strings.EqualFold(providerKey, "openai-compatibility") {
isCompatAuth = true
if a.Attributes != nil { if a.Attributes != nil {
if v := strings.TrimSpace(a.Attributes["compat_name"]); v != "" { if v := strings.TrimSpace(a.Attributes["compat_name"]); v != "" {
compatName = v compatName = v
} }
if v := strings.TrimSpace(a.Attributes["provider_key"]); v != "" { if v := strings.TrimSpace(a.Attributes["provider_key"]); v != "" {
providerKey = strings.ToLower(v) providerKey = strings.ToLower(v)
isCompatAuth = true
} }
} }
if providerKey == "openai-compatibility" && compatName != "" { if providerKey == "openai-compatibility" && compatName != "" {
providerKey = strings.ToLower(compatName) providerKey = strings.ToLower(compatName)
} }
} else if a.Attributes != nil {
if v := strings.TrimSpace(a.Attributes["compat_name"]); v != "" {
compatName = v
isCompatAuth = true
}
if v := strings.TrimSpace(a.Attributes["provider_key"]); v != "" {
providerKey = strings.ToLower(v)
isCompatAuth = true
}
} }
for i := range s.cfg.OpenAICompatibility { for i := range s.cfg.OpenAICompatibility {
compat := &s.cfg.OpenAICompatibility[i] compat := &s.cfg.OpenAICompatibility[i]
if strings.EqualFold(compat.Name, compatName) { if strings.EqualFold(compat.Name, compatName) {
isCompatAuth = true
// Convert compatibility models to registry models // Convert compatibility models to registry models
ms := make([]*ModelInfo, 0, len(compat.Models)) ms := make([]*ModelInfo, 0, len(compat.Models))
for j := range compat.Models { for j := range compat.Models {
@@ -532,10 +556,18 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
providerKey = "openai-compatibility" providerKey = "openai-compatibility"
} }
GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms) GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms)
} else {
// Ensure stale registrations are cleared when model list becomes empty.
GlobalModelRegistry().UnregisterClient(a.ID)
} }
return return
} }
} }
if isCompatAuth {
// No matching provider found or models removed entirely; drop any prior registration.
GlobalModelRegistry().UnregisterClient(a.ID)
return
}
} }
} }
if len(models) > 0 { if len(models) > 0 {