mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
refactor(config): Implement reconciliation for providers and clients
This commit introduces a reconciliation mechanism for handling configuration updates, significantly improving efficiency and resource management. Previously, reloading the configuration would tear down and recreate all access providers from scratch, regardless of whether their individual configurations had changed. This was inefficient and could disrupt services. The new `sdkaccess.ReconcileProviders` function now compares the old and new configurations to intelligently manage the provider lifecycle: - Unchanged providers are kept. - New providers are created. - Providers removed from the config are closed and discarded. - Providers with updated configurations are gracefully closed and recreated. To support this, a `Close()` method has been added to the `Provider` interface. A similar reconciliation logic has been applied to the client registration state in `state.RegisterClient`. This ensures that model registrations are accurately tracked when a client's configuration is updated, correctly handling added, removed, and unchanged models. Enhanced logging provides visibility into these operations.
This commit is contained in:
@@ -194,7 +194,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
loggerToggle: toggle,
|
loggerToggle: toggle,
|
||||||
configFilePath: configFilePath,
|
configFilePath: configFilePath,
|
||||||
}
|
}
|
||||||
s.applyAccessConfig(cfg)
|
s.applyAccessConfig(nil, cfg)
|
||||||
// Initialize management handler
|
// Initialize management handler
|
||||||
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
||||||
if optionState.localPassword != "" {
|
if optionState.localPassword != "" {
|
||||||
@@ -547,16 +547,23 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) applyAccessConfig(cfg *config.Config) {
|
func (s *Server) applyAccessConfig(oldCfg, newCfg *config.Config) {
|
||||||
if s == nil || s.accessManager == nil {
|
if s == nil || s.accessManager == nil || newCfg == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
providers, err := sdkaccess.BuildProviders(cfg)
|
existing := s.accessManager.Providers()
|
||||||
|
providers, added, updated, removed, err := sdkaccess.ReconcileProviders(oldCfg, newCfg, existing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to update request auth providers: %v", err)
|
log.Errorf("failed to reconcile request auth providers: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.accessManager.SetProviders(providers)
|
s.accessManager.SetProviders(providers)
|
||||||
|
if len(added)+len(updated)+len(removed) > 0 {
|
||||||
|
log.Infof("server request auth providers reconciled (added=%d updated=%d removed=%d)", len(added), len(updated), len(removed))
|
||||||
|
log.Debugf("server request auth provider details - added=%v updated=%v removed=%v", added, updated, removed)
|
||||||
|
} else {
|
||||||
|
log.Debug("server request auth providers unchanged after config update")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateClients updates the server's client list and configuration.
|
// UpdateClients updates the server's client list and configuration.
|
||||||
@@ -566,44 +573,60 @@ func (s *Server) applyAccessConfig(cfg *config.Config) {
|
|||||||
// - clients: The new slice of AI service clients
|
// - clients: The new slice of AI service clients
|
||||||
// - cfg: The new application configuration
|
// - cfg: The new application configuration
|
||||||
func (s *Server) UpdateClients(cfg *config.Config) {
|
func (s *Server) UpdateClients(cfg *config.Config) {
|
||||||
|
oldCfg := s.cfg
|
||||||
|
|
||||||
// Update request logger enabled state if it has changed
|
// Update request logger enabled state if it has changed
|
||||||
if s.requestLogger != nil && s.cfg.RequestLog != cfg.RequestLog {
|
previousRequestLog := false
|
||||||
|
if oldCfg != nil {
|
||||||
|
previousRequestLog = oldCfg.RequestLog
|
||||||
|
}
|
||||||
|
if s.requestLogger != nil && (oldCfg == nil || previousRequestLog != cfg.RequestLog) {
|
||||||
if s.loggerToggle != nil {
|
if s.loggerToggle != nil {
|
||||||
s.loggerToggle(cfg.RequestLog)
|
s.loggerToggle(cfg.RequestLog)
|
||||||
} else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok {
|
} else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||||
toggler.SetEnabled(cfg.RequestLog)
|
toggler.SetEnabled(cfg.RequestLog)
|
||||||
}
|
}
|
||||||
log.Debugf("request logging updated from %t to %t", s.cfg.RequestLog, cfg.RequestLog)
|
if oldCfg != nil {
|
||||||
}
|
log.Debugf("request logging updated from %t to %t", previousRequestLog, cfg.RequestLog)
|
||||||
|
|
||||||
if s.cfg.LoggingToFile != cfg.LoggingToFile {
|
|
||||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
|
||||||
log.Errorf("failed to reconfigure log output: %v", err)
|
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("logging_to_file updated from %t to %t", s.cfg.LoggingToFile, cfg.LoggingToFile)
|
log.Debugf("request logging toggled to %t", cfg.RequestLog)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.cfg == nil || s.cfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
|
if oldCfg != nil && oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||||
|
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||||
|
log.Errorf("failed to reconfigure log output: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
|
||||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||||
if s.cfg != nil {
|
if oldCfg != nil {
|
||||||
log.Debugf("usage_statistics_enabled updated from %t to %t", s.cfg.UsageStatisticsEnabled, cfg.UsageStatisticsEnabled)
|
log.Debugf("usage_statistics_enabled updated from %t to %t", oldCfg.UsageStatisticsEnabled, cfg.UsageStatisticsEnabled)
|
||||||
|
} else {
|
||||||
|
log.Debugf("usage_statistics_enabled toggled to %t", cfg.UsageStatisticsEnabled)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update log level dynamically when debug flag changes
|
// Update log level dynamically when debug flag changes
|
||||||
if s.cfg.Debug != cfg.Debug {
|
if oldCfg == nil || oldCfg.Debug != cfg.Debug {
|
||||||
util.SetLogLevel(cfg)
|
util.SetLogLevel(cfg)
|
||||||
log.Debugf("debug mode updated from %t to %t", s.cfg.Debug, cfg.Debug)
|
if oldCfg != nil {
|
||||||
|
log.Debugf("debug mode updated from %t to %t", oldCfg.Debug, cfg.Debug)
|
||||||
|
} else {
|
||||||
|
log.Debugf("debug mode toggled to %t", cfg.Debug)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.applyAccessConfig(oldCfg, cfg)
|
||||||
s.cfg = cfg
|
s.cfg = cfg
|
||||||
s.handlers.UpdateClients(cfg)
|
s.handlers.UpdateClients(cfg)
|
||||||
if s.mgmt != nil {
|
if s.mgmt != nil {
|
||||||
s.mgmt.SetConfig(cfg)
|
s.mgmt.SetConfig(cfg)
|
||||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||||
}
|
}
|
||||||
s.applyAccessConfig(cfg)
|
|
||||||
|
|
||||||
// Count client sources from configuration and auth directory
|
// Count client sources from configuration and auth directory
|
||||||
authFiles := util.CountAuthFiles(cfg.AuthDir)
|
authFiles := util.CountAuthFiles(cfg.AuthDir)
|
||||||
|
|||||||
@@ -101,58 +101,220 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
|
||||||
// Remove any existing registration for this client
|
|
||||||
r.unregisterClientInternal(clientID)
|
|
||||||
|
|
||||||
provider := strings.ToLower(clientProvider)
|
provider := strings.ToLower(clientProvider)
|
||||||
|
seen := make(map[string]struct{})
|
||||||
modelIDs := make([]string, 0, len(models))
|
modelIDs := make([]string, 0, len(models))
|
||||||
|
newModels := make(map[string]*ModelInfo, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil || model.ID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, exists := seen[model.ID]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[model.ID] = struct{}{}
|
||||||
|
modelIDs = append(modelIDs, model.ID)
|
||||||
|
newModels[model.ID] = model
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(modelIDs) == 0 {
|
||||||
|
// No models supplied; unregister existing client state if present.
|
||||||
|
r.unregisterClientInternal(clientID)
|
||||||
|
delete(r.clientModels, clientID)
|
||||||
|
delete(r.clientProviders, clientID)
|
||||||
|
misc.LogCredentialSeparator()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
for _, model := range models {
|
oldModels, hadExisting := r.clientModels[clientID]
|
||||||
modelIDs = append(modelIDs, model.ID)
|
oldProvider, hadProvider := r.clientProviders[clientID]
|
||||||
|
providerChanged := hadProvider && oldProvider != provider
|
||||||
if existing, exists := r.models[model.ID]; exists {
|
if !hadExisting {
|
||||||
// Model already exists, increment count
|
// Pure addition path.
|
||||||
existing.Count++
|
for _, modelID := range modelIDs {
|
||||||
existing.LastUpdated = now
|
model := newModels[modelID]
|
||||||
if existing.SuspendedClients == nil {
|
r.addModelRegistration(modelID, provider, model, now)
|
||||||
existing.SuspendedClients = make(map[string]string)
|
}
|
||||||
}
|
r.clientModels[clientID] = modelIDs
|
||||||
if provider != "" {
|
if provider != "" {
|
||||||
if existing.Providers == nil {
|
r.clientProviders[clientID] = provider
|
||||||
existing.Providers = make(map[string]int)
|
|
||||||
}
|
|
||||||
existing.Providers[provider]++
|
|
||||||
}
|
|
||||||
log.Debugf("Incremented count for model %s, now %d clients", model.ID, existing.Count)
|
|
||||||
} else {
|
} else {
|
||||||
// New model, create registration
|
delete(r.clientProviders, clientID)
|
||||||
registration := &ModelRegistration{
|
}
|
||||||
Info: model,
|
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(modelIDs))
|
||||||
Count: 1,
|
misc.LogCredentialSeparator()
|
||||||
LastUpdated: now,
|
return
|
||||||
QuotaExceededClients: make(map[string]*time.Time),
|
}
|
||||||
SuspendedClients: make(map[string]string),
|
|
||||||
}
|
oldSet := make(map[string]struct{}, len(oldModels))
|
||||||
if provider != "" {
|
for _, id := range oldModels {
|
||||||
registration.Providers = map[string]int{provider: 1}
|
oldSet[id] = struct{}{}
|
||||||
}
|
}
|
||||||
r.models[model.ID] = registration
|
|
||||||
log.Debugf("Registered new model %s from provider %s", model.ID, clientProvider)
|
added := make([]string, 0)
|
||||||
|
removed := make([]string, 0)
|
||||||
|
for _, id := range modelIDs {
|
||||||
|
if _, exists := oldSet[id]; !exists {
|
||||||
|
added = append(added, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
for _, id := range oldModels {
|
||||||
|
if _, exists := newModels[id]; !exists {
|
||||||
|
removed = append(removed, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.clientModels[clientID] = modelIDs
|
// Handle provider change for overlapping models before modifications.
|
||||||
|
if providerChanged && oldProvider != "" {
|
||||||
|
for _, id := range modelIDs {
|
||||||
|
if reg, ok := r.models[id]; ok && reg.Providers != nil {
|
||||||
|
if count, okProv := reg.Providers[oldProvider]; okProv {
|
||||||
|
if count <= 1 {
|
||||||
|
delete(reg.Providers, oldProvider)
|
||||||
|
} else {
|
||||||
|
reg.Providers[oldProvider] = count - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply removals first to keep counters accurate.
|
||||||
|
for _, id := range removed {
|
||||||
|
r.removeModelRegistration(clientID, id, oldProvider, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply additions.
|
||||||
|
for _, id := range added {
|
||||||
|
model := newModels[id]
|
||||||
|
r.addModelRegistration(id, provider, model, now)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update metadata for models that remain associated with the client.
|
||||||
|
addedSet := make(map[string]struct{}, len(added))
|
||||||
|
for _, id := range added {
|
||||||
|
addedSet[id] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, id := range modelIDs {
|
||||||
|
model := newModels[id]
|
||||||
|
if reg, ok := r.models[id]; ok {
|
||||||
|
reg.Info = cloneModelInfo(model)
|
||||||
|
reg.LastUpdated = now
|
||||||
|
if providerChanged {
|
||||||
|
if _, newlyAdded := addedSet[id]; newlyAdded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if reg.Providers == nil {
|
||||||
|
reg.Providers = make(map[string]int)
|
||||||
|
}
|
||||||
|
reg.Providers[provider]++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update client bookkeeping.
|
||||||
|
if len(modelIDs) > 0 {
|
||||||
|
r.clientModels[clientID] = modelIDs
|
||||||
|
}
|
||||||
if provider != "" {
|
if provider != "" {
|
||||||
r.clientProviders[clientID] = provider
|
r.clientProviders[clientID] = provider
|
||||||
} else {
|
} else {
|
||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
}
|
}
|
||||||
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(models))
|
|
||||||
// Separator at the end of the registration block (acts as boundary to next group)
|
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
||||||
|
// Only metadata (e.g., display name) changed.
|
||||||
|
misc.LogCredentialSeparator()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Reconciled client %s (provider %s) models: +%d, -%d", clientID, provider, len(added), len(removed))
|
||||||
misc.LogCredentialSeparator()
|
misc.LogCredentialSeparator()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *ModelInfo, now time.Time) {
|
||||||
|
if model == nil || modelID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if existing, exists := r.models[modelID]; exists {
|
||||||
|
existing.Count++
|
||||||
|
existing.LastUpdated = now
|
||||||
|
existing.Info = cloneModelInfo(model)
|
||||||
|
if existing.SuspendedClients == nil {
|
||||||
|
existing.SuspendedClients = make(map[string]string)
|
||||||
|
}
|
||||||
|
if provider != "" {
|
||||||
|
if existing.Providers == nil {
|
||||||
|
existing.Providers = make(map[string]int)
|
||||||
|
}
|
||||||
|
existing.Providers[provider]++
|
||||||
|
}
|
||||||
|
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
registration := &ModelRegistration{
|
||||||
|
Info: cloneModelInfo(model),
|
||||||
|
Count: 1,
|
||||||
|
LastUpdated: now,
|
||||||
|
QuotaExceededClients: make(map[string]*time.Time),
|
||||||
|
SuspendedClients: make(map[string]string),
|
||||||
|
}
|
||||||
|
if provider != "" {
|
||||||
|
registration.Providers = map[string]int{provider: 1}
|
||||||
|
}
|
||||||
|
r.models[modelID] = registration
|
||||||
|
log.Debugf("Registered new model %s from provider %s", modelID, provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider string, now time.Time) {
|
||||||
|
registration, exists := r.models[modelID]
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
registration.Count--
|
||||||
|
registration.LastUpdated = now
|
||||||
|
if registration.QuotaExceededClients != nil {
|
||||||
|
delete(registration.QuotaExceededClients, clientID)
|
||||||
|
}
|
||||||
|
if registration.SuspendedClients != nil {
|
||||||
|
delete(registration.SuspendedClients, clientID)
|
||||||
|
}
|
||||||
|
if registration.Count < 0 {
|
||||||
|
registration.Count = 0
|
||||||
|
}
|
||||||
|
if provider != "" && registration.Providers != nil {
|
||||||
|
if count, ok := registration.Providers[provider]; ok {
|
||||||
|
if count <= 1 {
|
||||||
|
delete(registration.Providers, provider)
|
||||||
|
} else {
|
||||||
|
registration.Providers[provider] = count - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("Decremented count for model %s, now %d clients", modelID, registration.Count)
|
||||||
|
if registration.Count <= 0 {
|
||||||
|
delete(r.models, modelID)
|
||||||
|
log.Debugf("Removed model %s as no clients remain", modelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
||||||
|
if model == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
copy := *model
|
||||||
|
if len(model.SupportedGenerationMethods) > 0 {
|
||||||
|
copy.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
||||||
|
}
|
||||||
|
if len(model.SupportedParameters) > 0 {
|
||||||
|
copy.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||||
|
}
|
||||||
|
return ©
|
||||||
|
}
|
||||||
|
|
||||||
// UnregisterClient removes a client and decrements counts for its models
|
// UnregisterClient removes a client and decrements counts for its models
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - clientID: Unique identifier for the client to remove
|
// - clientID: Unique identifier for the client to remove
|
||||||
|
|||||||
@@ -52,6 +52,39 @@ type Watcher struct {
|
|||||||
dispatchCancel context.CancelFunc
|
dispatchCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type stableIDGenerator struct {
|
||||||
|
counters map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStableIDGenerator() *stableIDGenerator {
|
||||||
|
return &stableIDGenerator{counters: make(map[string]int)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *stableIDGenerator) next(kind string, parts ...string) (string, string) {
|
||||||
|
if g == nil {
|
||||||
|
return kind + ":000000000000", "000000000000"
|
||||||
|
}
|
||||||
|
hasher := sha256.New()
|
||||||
|
hasher.Write([]byte(kind))
|
||||||
|
for _, part := range parts {
|
||||||
|
trimmed := strings.TrimSpace(part)
|
||||||
|
hasher.Write([]byte{0})
|
||||||
|
hasher.Write([]byte(trimmed))
|
||||||
|
}
|
||||||
|
digest := hex.EncodeToString(hasher.Sum(nil))
|
||||||
|
if len(digest) < 12 {
|
||||||
|
digest = fmt.Sprintf("%012s", digest)
|
||||||
|
}
|
||||||
|
short := digest[:12]
|
||||||
|
key := kind + ":" + short
|
||||||
|
index := g.counters[key]
|
||||||
|
g.counters[key] = index + 1
|
||||||
|
if index > 0 {
|
||||||
|
short = fmt.Sprintf("%s-%d", short, index)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s:%s", kind, short), short
|
||||||
|
}
|
||||||
|
|
||||||
// AuthUpdateAction represents the type of change detected in auth sources.
|
// AuthUpdateAction represents the type of change detected in auth sources.
|
||||||
type AuthUpdateAction string
|
type AuthUpdateAction string
|
||||||
|
|
||||||
@@ -640,6 +673,7 @@ func (w *Watcher) removeClient(path string) {
|
|||||||
func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||||
out := make([]*coreauth.Auth, 0, 32)
|
out := make([]*coreauth.Auth, 0, 32)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
idGen := newStableIDGenerator()
|
||||||
// Also synthesize auth entries for OpenAI-compatibility providers directly from config
|
// Also synthesize auth entries for OpenAI-compatibility providers directly from config
|
||||||
w.clientsMutex.RLock()
|
w.clientsMutex.RLock()
|
||||||
cfg := w.config
|
cfg := w.config
|
||||||
@@ -647,14 +681,18 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
// Gemini official API keys -> synthesize auths
|
// Gemini official API keys -> synthesize auths
|
||||||
for i := range cfg.GlAPIKey {
|
for i := range cfg.GlAPIKey {
|
||||||
k := cfg.GlAPIKey[i]
|
k := strings.TrimSpace(cfg.GlAPIKey[i])
|
||||||
|
if k == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id, token := idGen.next("gemini:apikey", k)
|
||||||
a := &coreauth.Auth{
|
a := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("gemini:apikey:%d", i),
|
ID: id,
|
||||||
Provider: "gemini",
|
Provider: "gemini",
|
||||||
Label: "gemini-apikey",
|
Label: "gemini-apikey",
|
||||||
Status: coreauth.StatusActive,
|
Status: coreauth.StatusActive,
|
||||||
Attributes: map[string]string{
|
Attributes: map[string]string{
|
||||||
"source": fmt.Sprintf("config:gemini#%d", i),
|
"source": fmt.Sprintf("config:gemini[%s]", token),
|
||||||
"api_key": k,
|
"api_key": k,
|
||||||
},
|
},
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
@@ -665,15 +703,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
// Claude API keys -> synthesize auths
|
// Claude API keys -> synthesize auths
|
||||||
for i := range cfg.ClaudeKey {
|
for i := range cfg.ClaudeKey {
|
||||||
ck := cfg.ClaudeKey[i]
|
ck := cfg.ClaudeKey[i]
|
||||||
|
key := strings.TrimSpace(ck.APIKey)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id, token := idGen.next("claude:apikey", key, ck.BaseURL)
|
||||||
attrs := map[string]string{
|
attrs := map[string]string{
|
||||||
"source": fmt.Sprintf("config:claude#%d", i),
|
"source": fmt.Sprintf("config:claude[%s]", token),
|
||||||
"api_key": ck.APIKey,
|
"api_key": key,
|
||||||
}
|
}
|
||||||
if ck.BaseURL != "" {
|
if ck.BaseURL != "" {
|
||||||
attrs["base_url"] = ck.BaseURL
|
attrs["base_url"] = ck.BaseURL
|
||||||
}
|
}
|
||||||
a := &coreauth.Auth{
|
a := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("claude:apikey:%d", i),
|
ID: id,
|
||||||
Provider: "claude",
|
Provider: "claude",
|
||||||
Label: "claude-apikey",
|
Label: "claude-apikey",
|
||||||
Status: coreauth.StatusActive,
|
Status: coreauth.StatusActive,
|
||||||
@@ -686,15 +729,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
// Codex API keys -> synthesize auths
|
// Codex API keys -> synthesize auths
|
||||||
for i := range cfg.CodexKey {
|
for i := range cfg.CodexKey {
|
||||||
ck := cfg.CodexKey[i]
|
ck := cfg.CodexKey[i]
|
||||||
|
key := strings.TrimSpace(ck.APIKey)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id, token := idGen.next("codex:apikey", key, ck.BaseURL)
|
||||||
attrs := map[string]string{
|
attrs := map[string]string{
|
||||||
"source": fmt.Sprintf("config:codex#%d", i),
|
"source": fmt.Sprintf("config:codex[%s]", token),
|
||||||
"api_key": ck.APIKey,
|
"api_key": key,
|
||||||
}
|
}
|
||||||
if ck.BaseURL != "" {
|
if ck.BaseURL != "" {
|
||||||
attrs["base_url"] = ck.BaseURL
|
attrs["base_url"] = ck.BaseURL
|
||||||
}
|
}
|
||||||
a := &coreauth.Auth{
|
a := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("codex:apikey:%d", i),
|
ID: id,
|
||||||
Provider: "codex",
|
Provider: "codex",
|
||||||
Label: "codex-apikey",
|
Label: "codex-apikey",
|
||||||
Status: coreauth.StatusActive,
|
Status: coreauth.StatusActive,
|
||||||
@@ -710,11 +758,16 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
if providerName == "" {
|
if providerName == "" {
|
||||||
providerName = "openai-compatibility"
|
providerName = "openai-compatibility"
|
||||||
}
|
}
|
||||||
base := compat.BaseURL
|
base := strings.TrimSpace(compat.BaseURL)
|
||||||
for j := range compat.APIKeys {
|
for j := range compat.APIKeys {
|
||||||
key := compat.APIKeys[j]
|
key := strings.TrimSpace(compat.APIKeys[j])
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idKind := fmt.Sprintf("openai-compatibility:%s", providerName)
|
||||||
|
id, token := idGen.next(idKind, key, base)
|
||||||
attrs := map[string]string{
|
attrs := map[string]string{
|
||||||
"source": fmt.Sprintf("config:%s#%d", compat.Name, j),
|
"source": fmt.Sprintf("config:%s[%s]", providerName, token),
|
||||||
"base_url": base,
|
"base_url": base,
|
||||||
"api_key": key,
|
"api_key": key,
|
||||||
"compat_name": compat.Name,
|
"compat_name": compat.Name,
|
||||||
@@ -724,7 +777,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
attrs["models_hash"] = hash
|
attrs["models_hash"] = hash
|
||||||
}
|
}
|
||||||
a := &coreauth.Auth{
|
a := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("openai-compatibility:%s:%d", compat.Name, j),
|
ID: id,
|
||||||
Provider: providerName,
|
Provider: providerName,
|
||||||
Label: compat.Name,
|
Label: compat.Name,
|
||||||
Status: coreauth.StatusActive,
|
Status: coreauth.StatusActive,
|
||||||
|
|||||||
230
sdk/access/reconcile.go
Normal file
230
sdk/access/reconcile.go
Normal file
@@ -0,0 +1,230 @@
|
|||||||
|
package access
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReconcileProviders builds the desired provider list by reusing existing providers when possible
|
||||||
|
// and creating or removing providers only when their configuration changed. It returns the final
|
||||||
|
// ordered provider slice along with the identifiers of providers that were added, updated, or
|
||||||
|
// removed compared to the previous configuration.
|
||||||
|
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []Provider) (result []Provider, added, updated, removed []string, err error) {
|
||||||
|
if newCfg == nil {
|
||||||
|
return nil, nil, nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
existingMap := make(map[string]Provider, len(existing))
|
||||||
|
for _, provider := range existing {
|
||||||
|
if provider == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existingMap[provider.Identifier()] = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
oldCfgMap := accessProviderMap(oldCfg)
|
||||||
|
newEntries := collectProviderEntries(newCfg)
|
||||||
|
|
||||||
|
result = make([]Provider, 0, len(newEntries))
|
||||||
|
finalIDs := make(map[string]struct{}, len(newEntries))
|
||||||
|
|
||||||
|
for _, providerCfg := range newEntries {
|
||||||
|
key := providerIdentifier(providerCfg)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldCfgProvider, ok := oldCfgMap[key]; ok && providerConfigEqual(oldCfgProvider, providerCfg) {
|
||||||
|
if existingProvider, ok := existingMap[key]; ok {
|
||||||
|
result = append(result, existingProvider)
|
||||||
|
finalIDs[key] = struct{}{}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, buildErr := buildProvider(providerCfg, newCfg)
|
||||||
|
if buildErr != nil {
|
||||||
|
return nil, nil, nil, nil, buildErr
|
||||||
|
}
|
||||||
|
if _, ok := oldCfgMap[key]; ok {
|
||||||
|
if _, existed := existingMap[key]; existed {
|
||||||
|
updated = append(updated, key)
|
||||||
|
} else {
|
||||||
|
added = append(added, key)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
added = append(added, key)
|
||||||
|
}
|
||||||
|
result = append(result, provider)
|
||||||
|
finalIDs[key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) == 0 && len(newCfg.APIKeys) > 0 {
|
||||||
|
config.SyncInlineAPIKeys(newCfg, newCfg.APIKeys)
|
||||||
|
if providerCfg := newCfg.ConfigAPIKeyProvider(); providerCfg != nil {
|
||||||
|
key := providerIdentifier(providerCfg)
|
||||||
|
if key != "" {
|
||||||
|
if oldCfgProvider, ok := oldCfgMap[key]; ok && providerConfigEqual(oldCfgProvider, providerCfg) {
|
||||||
|
if existingProvider, ok := existingMap[key]; ok {
|
||||||
|
result = append(result, existingProvider)
|
||||||
|
} else {
|
||||||
|
provider, buildErr := buildProvider(providerCfg, newCfg)
|
||||||
|
if buildErr != nil {
|
||||||
|
return nil, nil, nil, nil, buildErr
|
||||||
|
}
|
||||||
|
if _, existed := existingMap[key]; existed {
|
||||||
|
updated = append(updated, key)
|
||||||
|
} else {
|
||||||
|
added = append(added, key)
|
||||||
|
}
|
||||||
|
result = append(result, provider)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
provider, buildErr := buildProvider(providerCfg, newCfg)
|
||||||
|
if buildErr != nil {
|
||||||
|
return nil, nil, nil, nil, buildErr
|
||||||
|
}
|
||||||
|
if _, ok := oldCfgMap[key]; ok {
|
||||||
|
if _, existed := existingMap[key]; existed {
|
||||||
|
updated = append(updated, key)
|
||||||
|
} else {
|
||||||
|
added = append(added, key)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
added = append(added, key)
|
||||||
|
}
|
||||||
|
result = append(result, provider)
|
||||||
|
}
|
||||||
|
finalIDs[key] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
removedSet := make(map[string]struct{})
|
||||||
|
for id := range existingMap {
|
||||||
|
if _, ok := finalIDs[id]; !ok {
|
||||||
|
removedSet[id] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
removed = make([]string, 0, len(removedSet))
|
||||||
|
for id := range removedSet {
|
||||||
|
removed = append(removed, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(added)
|
||||||
|
sort.Strings(updated)
|
||||||
|
sort.Strings(removed)
|
||||||
|
|
||||||
|
return result, added, updated, removed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func accessProviderMap(cfg *config.Config) map[string]*config.AccessProvider {
|
||||||
|
result := make(map[string]*config.AccessProvider)
|
||||||
|
if cfg == nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
for i := range cfg.Access.Providers {
|
||||||
|
providerCfg := &cfg.Access.Providers[i]
|
||||||
|
if providerCfg.Type == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := providerIdentifier(providerCfg)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result[key] = providerCfg
|
||||||
|
}
|
||||||
|
if len(result) == 0 && len(cfg.APIKeys) > 0 {
|
||||||
|
if provider := cfg.ConfigAPIKeyProvider(); provider != nil {
|
||||||
|
if key := providerIdentifier(provider); key != "" {
|
||||||
|
result[key] = provider
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectProviderEntries(cfg *config.Config) []*config.AccessProvider {
|
||||||
|
entries := make([]*config.AccessProvider, 0, len(cfg.Access.Providers))
|
||||||
|
if cfg == nil {
|
||||||
|
return entries
|
||||||
|
}
|
||||||
|
for i := range cfg.Access.Providers {
|
||||||
|
providerCfg := &cfg.Access.Providers[i]
|
||||||
|
if providerCfg.Type == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if key := providerIdentifier(providerCfg); key != "" {
|
||||||
|
entries = append(entries, providerCfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return entries
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerIdentifier(provider *config.AccessProvider) string {
|
||||||
|
if provider == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if name := strings.TrimSpace(provider.Name); name != "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
typ := strings.TrimSpace(provider.Type)
|
||||||
|
if typ == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.EqualFold(typ, config.AccessProviderTypeConfigAPIKey) {
|
||||||
|
return config.DefaultAccessProviderName
|
||||||
|
}
|
||||||
|
return typ
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerConfigEqual(a, b *config.AccessProvider) bool {
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return a == nil && b == nil
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !stringSetEqual(a.APIKeys, b.APIKeys) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a.Config) != len(b.Config) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringSetEqual(a, b []string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
seen := make(map[string]int, len(a))
|
||||||
|
for _, val := range a {
|
||||||
|
seen[val]++
|
||||||
|
}
|
||||||
|
for _, val := range b {
|
||||||
|
count := seen[val]
|
||||||
|
if count == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if count == 1 {
|
||||||
|
delete(seen, val)
|
||||||
|
} else {
|
||||||
|
seen[val] = count - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(seen) == 0
|
||||||
|
}
|
||||||
@@ -110,12 +110,23 @@ func (s *Service) refreshAccessProviders(cfg *config.Config) {
|
|||||||
if s == nil || s.accessManager == nil || cfg == nil {
|
if s == nil || s.accessManager == nil || cfg == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
providers, err := sdkaccess.BuildProviders(cfg)
|
s.cfgMu.RLock()
|
||||||
|
oldCfg := s.cfg
|
||||||
|
s.cfgMu.RUnlock()
|
||||||
|
|
||||||
|
existing := s.accessManager.Providers()
|
||||||
|
providers, added, updated, removed, err := sdkaccess.ReconcileProviders(oldCfg, cfg, existing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to rebuild request auth providers: %v", err)
|
log.Errorf("failed to reconcile request auth providers: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.accessManager.SetProviders(providers)
|
s.accessManager.SetProviders(providers)
|
||||||
|
if len(added)+len(updated)+len(removed) > 0 {
|
||||||
|
log.Infof("request auth providers reconciled (added=%d updated=%d removed=%d)", len(added), len(updated), len(removed))
|
||||||
|
log.Debugf("provider changes details - added=%v updated=%v removed=%v", added, updated, removed)
|
||||||
|
} else {
|
||||||
|
log.Debug("request auth providers unchanged after config reload")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) ensureAuthUpdateQueue(ctx context.Context) {
|
func (s *Service) ensureAuthUpdateQueue(ctx context.Context) {
|
||||||
|
|||||||
Reference in New Issue
Block a user