mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 21:10:51 +08:00
Merge pull request #69 from router-for-me/reload
Implement minimal incremental updates for models and keys
This commit is contained in:
@@ -194,7 +194,7 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
loggerToggle: toggle,
|
loggerToggle: toggle,
|
||||||
configFilePath: configFilePath,
|
configFilePath: configFilePath,
|
||||||
}
|
}
|
||||||
s.applyAccessConfig(cfg)
|
s.applyAccessConfig(nil, cfg)
|
||||||
// Initialize management handler
|
// Initialize management handler
|
||||||
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
s.mgmt = managementHandlers.NewHandler(cfg, configFilePath, authManager)
|
||||||
if optionState.localPassword != "" {
|
if optionState.localPassword != "" {
|
||||||
@@ -547,16 +547,23 @@ func corsMiddleware() gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) applyAccessConfig(cfg *config.Config) {
|
func (s *Server) applyAccessConfig(oldCfg, newCfg *config.Config) {
|
||||||
if s == nil || s.accessManager == nil {
|
if s == nil || s.accessManager == nil || newCfg == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
providers, err := sdkaccess.BuildProviders(cfg)
|
existing := s.accessManager.Providers()
|
||||||
|
providers, added, updated, removed, err := sdkaccess.ReconcileProviders(oldCfg, newCfg, existing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to update request auth providers: %v", err)
|
log.Errorf("failed to reconcile request auth providers: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.accessManager.SetProviders(providers)
|
s.accessManager.SetProviders(providers)
|
||||||
|
if len(added)+len(updated)+len(removed) > 0 {
|
||||||
|
log.Debugf("auth providers reconciled (added=%d updated=%d removed=%d)", len(added), len(updated), len(removed))
|
||||||
|
log.Debugf("auth provider changes details - added=%v updated=%v removed=%v", added, updated, removed)
|
||||||
|
} else {
|
||||||
|
log.Debug("auth providers unchanged after config update")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateClients updates the server's client list and configuration.
|
// UpdateClients updates the server's client list and configuration.
|
||||||
@@ -566,44 +573,60 @@ func (s *Server) applyAccessConfig(cfg *config.Config) {
|
|||||||
// - clients: The new slice of AI service clients
|
// - clients: The new slice of AI service clients
|
||||||
// - cfg: The new application configuration
|
// - cfg: The new application configuration
|
||||||
func (s *Server) UpdateClients(cfg *config.Config) {
|
func (s *Server) UpdateClients(cfg *config.Config) {
|
||||||
|
oldCfg := s.cfg
|
||||||
|
|
||||||
// Update request logger enabled state if it has changed
|
// Update request logger enabled state if it has changed
|
||||||
if s.requestLogger != nil && s.cfg.RequestLog != cfg.RequestLog {
|
previousRequestLog := false
|
||||||
|
if oldCfg != nil {
|
||||||
|
previousRequestLog = oldCfg.RequestLog
|
||||||
|
}
|
||||||
|
if s.requestLogger != nil && (oldCfg == nil || previousRequestLog != cfg.RequestLog) {
|
||||||
if s.loggerToggle != nil {
|
if s.loggerToggle != nil {
|
||||||
s.loggerToggle(cfg.RequestLog)
|
s.loggerToggle(cfg.RequestLog)
|
||||||
} else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok {
|
} else if toggler, ok := s.requestLogger.(interface{ SetEnabled(bool) }); ok {
|
||||||
toggler.SetEnabled(cfg.RequestLog)
|
toggler.SetEnabled(cfg.RequestLog)
|
||||||
}
|
}
|
||||||
log.Debugf("request logging updated from %t to %t", s.cfg.RequestLog, cfg.RequestLog)
|
if oldCfg != nil {
|
||||||
}
|
log.Debugf("request logging updated from %t to %t", previousRequestLog, cfg.RequestLog)
|
||||||
|
|
||||||
if s.cfg.LoggingToFile != cfg.LoggingToFile {
|
|
||||||
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
|
||||||
log.Errorf("failed to reconfigure log output: %v", err)
|
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("logging_to_file updated from %t to %t", s.cfg.LoggingToFile, cfg.LoggingToFile)
|
log.Debugf("request logging toggled to %t", cfg.RequestLog)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.cfg == nil || s.cfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
|
if oldCfg != nil && oldCfg.LoggingToFile != cfg.LoggingToFile {
|
||||||
|
if err := logging.ConfigureLogOutput(cfg.LoggingToFile); err != nil {
|
||||||
|
log.Errorf("failed to reconfigure log output: %v", err)
|
||||||
|
} else {
|
||||||
|
log.Debugf("logging_to_file updated from %t to %t", oldCfg.LoggingToFile, cfg.LoggingToFile)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldCfg == nil || oldCfg.UsageStatisticsEnabled != cfg.UsageStatisticsEnabled {
|
||||||
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
usage.SetStatisticsEnabled(cfg.UsageStatisticsEnabled)
|
||||||
if s.cfg != nil {
|
if oldCfg != nil {
|
||||||
log.Debugf("usage_statistics_enabled updated from %t to %t", s.cfg.UsageStatisticsEnabled, cfg.UsageStatisticsEnabled)
|
log.Debugf("usage_statistics_enabled updated from %t to %t", oldCfg.UsageStatisticsEnabled, cfg.UsageStatisticsEnabled)
|
||||||
|
} else {
|
||||||
|
log.Debugf("usage_statistics_enabled toggled to %t", cfg.UsageStatisticsEnabled)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Update log level dynamically when debug flag changes
|
// Update log level dynamically when debug flag changes
|
||||||
if s.cfg.Debug != cfg.Debug {
|
if oldCfg == nil || oldCfg.Debug != cfg.Debug {
|
||||||
util.SetLogLevel(cfg)
|
util.SetLogLevel(cfg)
|
||||||
log.Debugf("debug mode updated from %t to %t", s.cfg.Debug, cfg.Debug)
|
if oldCfg != nil {
|
||||||
|
log.Debugf("debug mode updated from %t to %t", oldCfg.Debug, cfg.Debug)
|
||||||
|
} else {
|
||||||
|
log.Debugf("debug mode toggled to %t", cfg.Debug)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.applyAccessConfig(oldCfg, cfg)
|
||||||
s.cfg = cfg
|
s.cfg = cfg
|
||||||
s.handlers.UpdateClients(cfg)
|
s.handlers.UpdateClients(cfg)
|
||||||
if s.mgmt != nil {
|
if s.mgmt != nil {
|
||||||
s.mgmt.SetConfig(cfg)
|
s.mgmt.SetConfig(cfg)
|
||||||
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
s.mgmt.SetAuthManager(s.handlers.AuthManager)
|
||||||
}
|
}
|
||||||
s.applyAccessConfig(cfg)
|
|
||||||
|
|
||||||
// Count client sources from configuration and auth directory
|
// Count client sources from configuration and auth directory
|
||||||
authFiles := util.CountAuthFiles(cfg.AuthDir)
|
authFiles := util.CountAuthFiles(cfg.AuthDir)
|
||||||
|
|||||||
@@ -101,58 +101,267 @@ func (r *ModelRegistry) RegisterClient(clientID, clientProvider string, models [
|
|||||||
r.mutex.Lock()
|
r.mutex.Lock()
|
||||||
defer r.mutex.Unlock()
|
defer r.mutex.Unlock()
|
||||||
|
|
||||||
// Remove any existing registration for this client
|
|
||||||
r.unregisterClientInternal(clientID)
|
|
||||||
|
|
||||||
provider := strings.ToLower(clientProvider)
|
provider := strings.ToLower(clientProvider)
|
||||||
modelIDs := make([]string, 0, len(models))
|
uniqueModelIDs := make([]string, 0, len(models))
|
||||||
|
rawModelIDs := make([]string, 0, len(models))
|
||||||
|
newModels := make(map[string]*ModelInfo, len(models))
|
||||||
|
newCounts := make(map[string]int, len(models))
|
||||||
|
for _, model := range models {
|
||||||
|
if model == nil || model.ID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rawModelIDs = append(rawModelIDs, model.ID)
|
||||||
|
newCounts[model.ID]++
|
||||||
|
if _, exists := newModels[model.ID]; exists {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
newModels[model.ID] = model
|
||||||
|
uniqueModelIDs = append(uniqueModelIDs, model.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(uniqueModelIDs) == 0 {
|
||||||
|
// No models supplied; unregister existing client state if present.
|
||||||
|
r.unregisterClientInternal(clientID)
|
||||||
|
delete(r.clientModels, clientID)
|
||||||
|
delete(r.clientProviders, clientID)
|
||||||
|
misc.LogCredentialSeparator()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
|
||||||
for _, model := range models {
|
oldModels, hadExisting := r.clientModels[clientID]
|
||||||
modelIDs = append(modelIDs, model.ID)
|
oldProvider, _ := r.clientProviders[clientID]
|
||||||
|
providerChanged := oldProvider != provider
|
||||||
if existing, exists := r.models[model.ID]; exists {
|
if !hadExisting {
|
||||||
// Model already exists, increment count
|
// Pure addition path.
|
||||||
existing.Count++
|
for _, modelID := range rawModelIDs {
|
||||||
existing.LastUpdated = now
|
model := newModels[modelID]
|
||||||
if existing.SuspendedClients == nil {
|
r.addModelRegistration(modelID, provider, model, now)
|
||||||
existing.SuspendedClients = make(map[string]string)
|
}
|
||||||
}
|
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
|
||||||
if provider != "" {
|
if provider != "" {
|
||||||
if existing.Providers == nil {
|
r.clientProviders[clientID] = provider
|
||||||
existing.Providers = make(map[string]int)
|
|
||||||
}
|
|
||||||
existing.Providers[provider]++
|
|
||||||
}
|
|
||||||
log.Debugf("Incremented count for model %s, now %d clients", model.ID, existing.Count)
|
|
||||||
} else {
|
} else {
|
||||||
// New model, create registration
|
delete(r.clientProviders, clientID)
|
||||||
registration := &ModelRegistration{
|
}
|
||||||
Info: model,
|
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(rawModelIDs))
|
||||||
Count: 1,
|
misc.LogCredentialSeparator()
|
||||||
LastUpdated: now,
|
return
|
||||||
QuotaExceededClients: make(map[string]*time.Time),
|
}
|
||||||
SuspendedClients: make(map[string]string),
|
|
||||||
}
|
oldCounts := make(map[string]int, len(oldModels))
|
||||||
if provider != "" {
|
for _, id := range oldModels {
|
||||||
registration.Providers = map[string]int{provider: 1}
|
oldCounts[id]++
|
||||||
}
|
}
|
||||||
r.models[model.ID] = registration
|
|
||||||
log.Debugf("Registered new model %s from provider %s", model.ID, clientProvider)
|
added := make([]string, 0)
|
||||||
|
for _, id := range uniqueModelIDs {
|
||||||
|
if oldCounts[id] == 0 {
|
||||||
|
added = append(added, id)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.clientModels[clientID] = modelIDs
|
removed := make([]string, 0)
|
||||||
|
for id := range oldCounts {
|
||||||
|
if newCounts[id] == 0 {
|
||||||
|
removed = append(removed, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle provider change for overlapping models before modifications.
|
||||||
|
if providerChanged && oldProvider != "" {
|
||||||
|
for id, newCount := range newCounts {
|
||||||
|
if newCount == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
oldCount := oldCounts[id]
|
||||||
|
if oldCount == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
toRemove := newCount
|
||||||
|
if oldCount < toRemove {
|
||||||
|
toRemove = oldCount
|
||||||
|
}
|
||||||
|
if reg, ok := r.models[id]; ok && reg.Providers != nil {
|
||||||
|
if count, okProv := reg.Providers[oldProvider]; okProv {
|
||||||
|
if count <= toRemove {
|
||||||
|
delete(reg.Providers, oldProvider)
|
||||||
|
} else {
|
||||||
|
reg.Providers[oldProvider] = count - toRemove
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply removals first to keep counters accurate.
|
||||||
|
for _, id := range removed {
|
||||||
|
oldCount := oldCounts[id]
|
||||||
|
for i := 0; i < oldCount; i++ {
|
||||||
|
r.removeModelRegistration(clientID, id, oldProvider, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, oldCount := range oldCounts {
|
||||||
|
newCount := newCounts[id]
|
||||||
|
if newCount == 0 || oldCount <= newCount {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
overage := oldCount - newCount
|
||||||
|
for i := 0; i < overage; i++ {
|
||||||
|
r.removeModelRegistration(clientID, id, oldProvider, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply additions.
|
||||||
|
for id, newCount := range newCounts {
|
||||||
|
oldCount := oldCounts[id]
|
||||||
|
if newCount <= oldCount {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
model := newModels[id]
|
||||||
|
diff := newCount - oldCount
|
||||||
|
for i := 0; i < diff; i++ {
|
||||||
|
r.addModelRegistration(id, provider, model, now)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update metadata for models that remain associated with the client.
|
||||||
|
addedSet := make(map[string]struct{}, len(added))
|
||||||
|
for _, id := range added {
|
||||||
|
addedSet[id] = struct{}{}
|
||||||
|
}
|
||||||
|
for _, id := range uniqueModelIDs {
|
||||||
|
model := newModels[id]
|
||||||
|
if reg, ok := r.models[id]; ok {
|
||||||
|
reg.Info = cloneModelInfo(model)
|
||||||
|
reg.LastUpdated = now
|
||||||
|
if reg.QuotaExceededClients != nil {
|
||||||
|
delete(reg.QuotaExceededClients, clientID)
|
||||||
|
}
|
||||||
|
if reg.SuspendedClients != nil {
|
||||||
|
delete(reg.SuspendedClients, clientID)
|
||||||
|
}
|
||||||
|
if providerChanged && provider != "" {
|
||||||
|
if _, newlyAdded := addedSet[id]; newlyAdded {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
overlapCount := newCounts[id]
|
||||||
|
if oldCount := oldCounts[id]; oldCount < overlapCount {
|
||||||
|
overlapCount = oldCount
|
||||||
|
}
|
||||||
|
if overlapCount <= 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if reg.Providers == nil {
|
||||||
|
reg.Providers = make(map[string]int)
|
||||||
|
}
|
||||||
|
reg.Providers[provider] += overlapCount
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Update client bookkeeping.
|
||||||
|
if len(rawModelIDs) > 0 {
|
||||||
|
r.clientModels[clientID] = append([]string(nil), rawModelIDs...)
|
||||||
|
}
|
||||||
if provider != "" {
|
if provider != "" {
|
||||||
r.clientProviders[clientID] = provider
|
r.clientProviders[clientID] = provider
|
||||||
} else {
|
} else {
|
||||||
delete(r.clientProviders, clientID)
|
delete(r.clientProviders, clientID)
|
||||||
}
|
}
|
||||||
log.Debugf("Registered client %s from provider %s with %d models", clientID, clientProvider, len(models))
|
|
||||||
// Separator at the end of the registration block (acts as boundary to next group)
|
if len(added) == 0 && len(removed) == 0 && !providerChanged {
|
||||||
|
// Only metadata (e.g., display name) changed; skip separator when no log output.
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Reconciled client %s (provider %s) models: +%d, -%d", clientID, provider, len(added), len(removed))
|
||||||
misc.LogCredentialSeparator()
|
misc.LogCredentialSeparator()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) addModelRegistration(modelID, provider string, model *ModelInfo, now time.Time) {
|
||||||
|
if model == nil || modelID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if existing, exists := r.models[modelID]; exists {
|
||||||
|
existing.Count++
|
||||||
|
existing.LastUpdated = now
|
||||||
|
existing.Info = cloneModelInfo(model)
|
||||||
|
if existing.SuspendedClients == nil {
|
||||||
|
existing.SuspendedClients = make(map[string]string)
|
||||||
|
}
|
||||||
|
if provider != "" {
|
||||||
|
if existing.Providers == nil {
|
||||||
|
existing.Providers = make(map[string]int)
|
||||||
|
}
|
||||||
|
existing.Providers[provider]++
|
||||||
|
}
|
||||||
|
log.Debugf("Incremented count for model %s, now %d clients", modelID, existing.Count)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
registration := &ModelRegistration{
|
||||||
|
Info: cloneModelInfo(model),
|
||||||
|
Count: 1,
|
||||||
|
LastUpdated: now,
|
||||||
|
QuotaExceededClients: make(map[string]*time.Time),
|
||||||
|
SuspendedClients: make(map[string]string),
|
||||||
|
}
|
||||||
|
if provider != "" {
|
||||||
|
registration.Providers = map[string]int{provider: 1}
|
||||||
|
}
|
||||||
|
r.models[modelID] = registration
|
||||||
|
log.Debugf("Registered new model %s from provider %s", modelID, provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ModelRegistry) removeModelRegistration(clientID, modelID, provider string, now time.Time) {
|
||||||
|
registration, exists := r.models[modelID]
|
||||||
|
if !exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
registration.Count--
|
||||||
|
registration.LastUpdated = now
|
||||||
|
if registration.QuotaExceededClients != nil {
|
||||||
|
delete(registration.QuotaExceededClients, clientID)
|
||||||
|
}
|
||||||
|
if registration.SuspendedClients != nil {
|
||||||
|
delete(registration.SuspendedClients, clientID)
|
||||||
|
}
|
||||||
|
if registration.Count < 0 {
|
||||||
|
registration.Count = 0
|
||||||
|
}
|
||||||
|
if provider != "" && registration.Providers != nil {
|
||||||
|
if count, ok := registration.Providers[provider]; ok {
|
||||||
|
if count <= 1 {
|
||||||
|
delete(registration.Providers, provider)
|
||||||
|
} else {
|
||||||
|
registration.Providers[provider] = count - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
log.Debugf("Decremented count for model %s, now %d clients", modelID, registration.Count)
|
||||||
|
if registration.Count <= 0 {
|
||||||
|
delete(r.models, modelID)
|
||||||
|
log.Debugf("Removed model %s as no clients remain", modelID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func cloneModelInfo(model *ModelInfo) *ModelInfo {
|
||||||
|
if model == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
copy := *model
|
||||||
|
if len(model.SupportedGenerationMethods) > 0 {
|
||||||
|
copy.SupportedGenerationMethods = append([]string(nil), model.SupportedGenerationMethods...)
|
||||||
|
}
|
||||||
|
if len(model.SupportedParameters) > 0 {
|
||||||
|
copy.SupportedParameters = append([]string(nil), model.SupportedParameters...)
|
||||||
|
}
|
||||||
|
return ©
|
||||||
|
}
|
||||||
|
|
||||||
// UnregisterClient removes a client and decrements counts for its models
|
// UnregisterClient removes a client and decrements counts for its models
|
||||||
// Parameters:
|
// Parameters:
|
||||||
// - clientID: Unique identifier for the client to remove
|
// - clientID: Unique identifier for the client to remove
|
||||||
|
|||||||
@@ -52,6 +52,39 @@ type Watcher struct {
|
|||||||
dispatchCancel context.CancelFunc
|
dispatchCancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type stableIDGenerator struct {
|
||||||
|
counters map[string]int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newStableIDGenerator() *stableIDGenerator {
|
||||||
|
return &stableIDGenerator{counters: make(map[string]int)}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *stableIDGenerator) next(kind string, parts ...string) (string, string) {
|
||||||
|
if g == nil {
|
||||||
|
return kind + ":000000000000", "000000000000"
|
||||||
|
}
|
||||||
|
hasher := sha256.New()
|
||||||
|
hasher.Write([]byte(kind))
|
||||||
|
for _, part := range parts {
|
||||||
|
trimmed := strings.TrimSpace(part)
|
||||||
|
hasher.Write([]byte{0})
|
||||||
|
hasher.Write([]byte(trimmed))
|
||||||
|
}
|
||||||
|
digest := hex.EncodeToString(hasher.Sum(nil))
|
||||||
|
if len(digest) < 12 {
|
||||||
|
digest = fmt.Sprintf("%012s", digest)
|
||||||
|
}
|
||||||
|
short := digest[:12]
|
||||||
|
key := kind + ":" + short
|
||||||
|
index := g.counters[key]
|
||||||
|
g.counters[key] = index + 1
|
||||||
|
if index > 0 {
|
||||||
|
short = fmt.Sprintf("%s-%d", short, index)
|
||||||
|
}
|
||||||
|
return fmt.Sprintf("%s:%s", kind, short), short
|
||||||
|
}
|
||||||
|
|
||||||
// AuthUpdateAction represents the type of change detected in auth sources.
|
// AuthUpdateAction represents the type of change detected in auth sources.
|
||||||
type AuthUpdateAction string
|
type AuthUpdateAction string
|
||||||
|
|
||||||
@@ -640,6 +673,7 @@ func (w *Watcher) removeClient(path string) {
|
|||||||
func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
||||||
out := make([]*coreauth.Auth, 0, 32)
|
out := make([]*coreauth.Auth, 0, 32)
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
|
idGen := newStableIDGenerator()
|
||||||
// Also synthesize auth entries for OpenAI-compatibility providers directly from config
|
// Also synthesize auth entries for OpenAI-compatibility providers directly from config
|
||||||
w.clientsMutex.RLock()
|
w.clientsMutex.RLock()
|
||||||
cfg := w.config
|
cfg := w.config
|
||||||
@@ -647,14 +681,18 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
if cfg != nil {
|
if cfg != nil {
|
||||||
// Gemini official API keys -> synthesize auths
|
// Gemini official API keys -> synthesize auths
|
||||||
for i := range cfg.GlAPIKey {
|
for i := range cfg.GlAPIKey {
|
||||||
k := cfg.GlAPIKey[i]
|
k := strings.TrimSpace(cfg.GlAPIKey[i])
|
||||||
|
if k == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id, token := idGen.next("gemini:apikey", k)
|
||||||
a := &coreauth.Auth{
|
a := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("gemini:apikey:%d", i),
|
ID: id,
|
||||||
Provider: "gemini",
|
Provider: "gemini",
|
||||||
Label: "gemini-apikey",
|
Label: "gemini-apikey",
|
||||||
Status: coreauth.StatusActive,
|
Status: coreauth.StatusActive,
|
||||||
Attributes: map[string]string{
|
Attributes: map[string]string{
|
||||||
"source": fmt.Sprintf("config:gemini#%d", i),
|
"source": fmt.Sprintf("config:gemini[%s]", token),
|
||||||
"api_key": k,
|
"api_key": k,
|
||||||
},
|
},
|
||||||
CreatedAt: now,
|
CreatedAt: now,
|
||||||
@@ -665,15 +703,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
// Claude API keys -> synthesize auths
|
// Claude API keys -> synthesize auths
|
||||||
for i := range cfg.ClaudeKey {
|
for i := range cfg.ClaudeKey {
|
||||||
ck := cfg.ClaudeKey[i]
|
ck := cfg.ClaudeKey[i]
|
||||||
|
key := strings.TrimSpace(ck.APIKey)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id, token := idGen.next("claude:apikey", key, ck.BaseURL)
|
||||||
attrs := map[string]string{
|
attrs := map[string]string{
|
||||||
"source": fmt.Sprintf("config:claude#%d", i),
|
"source": fmt.Sprintf("config:claude[%s]", token),
|
||||||
"api_key": ck.APIKey,
|
"api_key": key,
|
||||||
}
|
}
|
||||||
if ck.BaseURL != "" {
|
if ck.BaseURL != "" {
|
||||||
attrs["base_url"] = ck.BaseURL
|
attrs["base_url"] = ck.BaseURL
|
||||||
}
|
}
|
||||||
a := &coreauth.Auth{
|
a := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("claude:apikey:%d", i),
|
ID: id,
|
||||||
Provider: "claude",
|
Provider: "claude",
|
||||||
Label: "claude-apikey",
|
Label: "claude-apikey",
|
||||||
Status: coreauth.StatusActive,
|
Status: coreauth.StatusActive,
|
||||||
@@ -686,15 +729,20 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
// Codex API keys -> synthesize auths
|
// Codex API keys -> synthesize auths
|
||||||
for i := range cfg.CodexKey {
|
for i := range cfg.CodexKey {
|
||||||
ck := cfg.CodexKey[i]
|
ck := cfg.CodexKey[i]
|
||||||
|
key := strings.TrimSpace(ck.APIKey)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
id, token := idGen.next("codex:apikey", key, ck.BaseURL)
|
||||||
attrs := map[string]string{
|
attrs := map[string]string{
|
||||||
"source": fmt.Sprintf("config:codex#%d", i),
|
"source": fmt.Sprintf("config:codex[%s]", token),
|
||||||
"api_key": ck.APIKey,
|
"api_key": key,
|
||||||
}
|
}
|
||||||
if ck.BaseURL != "" {
|
if ck.BaseURL != "" {
|
||||||
attrs["base_url"] = ck.BaseURL
|
attrs["base_url"] = ck.BaseURL
|
||||||
}
|
}
|
||||||
a := &coreauth.Auth{
|
a := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("codex:apikey:%d", i),
|
ID: id,
|
||||||
Provider: "codex",
|
Provider: "codex",
|
||||||
Label: "codex-apikey",
|
Label: "codex-apikey",
|
||||||
Status: coreauth.StatusActive,
|
Status: coreauth.StatusActive,
|
||||||
@@ -710,11 +758,16 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
if providerName == "" {
|
if providerName == "" {
|
||||||
providerName = "openai-compatibility"
|
providerName = "openai-compatibility"
|
||||||
}
|
}
|
||||||
base := compat.BaseURL
|
base := strings.TrimSpace(compat.BaseURL)
|
||||||
for j := range compat.APIKeys {
|
for j := range compat.APIKeys {
|
||||||
key := compat.APIKeys[j]
|
key := strings.TrimSpace(compat.APIKeys[j])
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
idKind := fmt.Sprintf("openai-compatibility:%s", providerName)
|
||||||
|
id, token := idGen.next(idKind, key, base)
|
||||||
attrs := map[string]string{
|
attrs := map[string]string{
|
||||||
"source": fmt.Sprintf("config:%s#%d", compat.Name, j),
|
"source": fmt.Sprintf("config:%s[%s]", providerName, token),
|
||||||
"base_url": base,
|
"base_url": base,
|
||||||
"api_key": key,
|
"api_key": key,
|
||||||
"compat_name": compat.Name,
|
"compat_name": compat.Name,
|
||||||
@@ -724,7 +777,7 @@ func (w *Watcher) SnapshotCoreAuths() []*coreauth.Auth {
|
|||||||
attrs["models_hash"] = hash
|
attrs["models_hash"] = hash
|
||||||
}
|
}
|
||||||
a := &coreauth.Auth{
|
a := &coreauth.Auth{
|
||||||
ID: fmt.Sprintf("openai-compatibility:%s:%d", compat.Name, j),
|
ID: id,
|
||||||
Provider: providerName,
|
Provider: providerName,
|
||||||
Label: compat.Name,
|
Label: compat.Name,
|
||||||
Status: coreauth.StatusActive,
|
Status: coreauth.StatusActive,
|
||||||
|
|||||||
252
sdk/access/reconcile.go
Normal file
252
sdk/access/reconcile.go
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
package access
|
||||||
|
|
||||||
|
import (
|
||||||
|
"reflect"
|
||||||
|
"sort"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ReconcileProviders builds the desired provider list by reusing existing providers when possible
|
||||||
|
// and creating or removing providers only when their configuration changed. It returns the final
|
||||||
|
// ordered provider slice along with the identifiers of providers that were added, updated, or
|
||||||
|
// removed compared to the previous configuration.
|
||||||
|
func ReconcileProviders(oldCfg, newCfg *config.Config, existing []Provider) (result []Provider, added, updated, removed []string, err error) {
|
||||||
|
if newCfg == nil {
|
||||||
|
return nil, nil, nil, nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
existingMap := make(map[string]Provider, len(existing))
|
||||||
|
for _, provider := range existing {
|
||||||
|
if provider == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
existingMap[provider.Identifier()] = provider
|
||||||
|
}
|
||||||
|
|
||||||
|
oldCfgMap := accessProviderMap(oldCfg)
|
||||||
|
newEntries := collectProviderEntries(newCfg)
|
||||||
|
|
||||||
|
result = make([]Provider, 0, len(newEntries))
|
||||||
|
finalIDs := make(map[string]struct{}, len(newEntries))
|
||||||
|
|
||||||
|
isInlineProvider := func(id string) bool {
|
||||||
|
return strings.EqualFold(id, config.DefaultAccessProviderName)
|
||||||
|
}
|
||||||
|
appendChange := func(list *[]string, id string) {
|
||||||
|
if isInlineProvider(id) {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
*list = append(*list, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, providerCfg := range newEntries {
|
||||||
|
key := providerIdentifier(providerCfg)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||||
|
isAliased := oldCfgProvider == providerCfg
|
||||||
|
if !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
|
||||||
|
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||||
|
result = append(result, existingProvider)
|
||||||
|
finalIDs[key] = struct{}{}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
provider, buildErr := buildProvider(providerCfg, newCfg)
|
||||||
|
if buildErr != nil {
|
||||||
|
return nil, nil, nil, nil, buildErr
|
||||||
|
}
|
||||||
|
if _, ok := oldCfgMap[key]; ok {
|
||||||
|
if _, existed := existingMap[key]; existed {
|
||||||
|
appendChange(&updated, key)
|
||||||
|
} else {
|
||||||
|
appendChange(&added, key)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
appendChange(&added, key)
|
||||||
|
}
|
||||||
|
result = append(result, provider)
|
||||||
|
finalIDs[key] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(result) == 0 && len(newCfg.APIKeys) > 0 {
|
||||||
|
config.SyncInlineAPIKeys(newCfg, newCfg.APIKeys)
|
||||||
|
if providerCfg := newCfg.ConfigAPIKeyProvider(); providerCfg != nil {
|
||||||
|
key := providerIdentifier(providerCfg)
|
||||||
|
if key != "" {
|
||||||
|
if oldCfgProvider, ok := oldCfgMap[key]; ok {
|
||||||
|
isAliased := oldCfgProvider == providerCfg
|
||||||
|
if !isAliased && providerConfigEqual(oldCfgProvider, providerCfg) {
|
||||||
|
if existingProvider, okExisting := existingMap[key]; okExisting {
|
||||||
|
result = append(result, existingProvider)
|
||||||
|
} else {
|
||||||
|
provider, buildErr := buildProvider(providerCfg, newCfg)
|
||||||
|
if buildErr != nil {
|
||||||
|
return nil, nil, nil, nil, buildErr
|
||||||
|
}
|
||||||
|
if _, existed := existingMap[key]; existed {
|
||||||
|
appendChange(&updated, key)
|
||||||
|
} else {
|
||||||
|
appendChange(&added, key)
|
||||||
|
}
|
||||||
|
result = append(result, provider)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
provider, buildErr := buildProvider(providerCfg, newCfg)
|
||||||
|
if buildErr != nil {
|
||||||
|
return nil, nil, nil, nil, buildErr
|
||||||
|
}
|
||||||
|
if _, existed := existingMap[key]; existed {
|
||||||
|
appendChange(&updated, key)
|
||||||
|
} else {
|
||||||
|
appendChange(&added, key)
|
||||||
|
}
|
||||||
|
result = append(result, provider)
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
provider, buildErr := buildProvider(providerCfg, newCfg)
|
||||||
|
if buildErr != nil {
|
||||||
|
return nil, nil, nil, nil, buildErr
|
||||||
|
}
|
||||||
|
appendChange(&added, key)
|
||||||
|
result = append(result, provider)
|
||||||
|
}
|
||||||
|
finalIDs[key] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
removedSet := make(map[string]struct{})
|
||||||
|
for id := range existingMap {
|
||||||
|
if _, ok := finalIDs[id]; !ok {
|
||||||
|
if isInlineProvider(id) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
removedSet[id] = struct{}{}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
removed = make([]string, 0, len(removedSet))
|
||||||
|
for id := range removedSet {
|
||||||
|
removed = append(removed, id)
|
||||||
|
}
|
||||||
|
|
||||||
|
sort.Strings(added)
|
||||||
|
sort.Strings(updated)
|
||||||
|
sort.Strings(removed)
|
||||||
|
|
||||||
|
return result, added, updated, removed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func accessProviderMap(cfg *config.Config) map[string]*config.AccessProvider {
|
||||||
|
result := make(map[string]*config.AccessProvider)
|
||||||
|
if cfg == nil {
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
for i := range cfg.Access.Providers {
|
||||||
|
providerCfg := &cfg.Access.Providers[i]
|
||||||
|
if providerCfg.Type == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := providerIdentifier(providerCfg)
|
||||||
|
if key == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
result[key] = providerCfg
|
||||||
|
}
|
||||||
|
if len(result) == 0 && len(cfg.APIKeys) > 0 {
|
||||||
|
if provider := cfg.ConfigAPIKeyProvider(); provider != nil {
|
||||||
|
if key := providerIdentifier(provider); key != "" {
|
||||||
|
result[key] = provider
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return result
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectProviderEntries(cfg *config.Config) []*config.AccessProvider {
|
||||||
|
entries := make([]*config.AccessProvider, 0, len(cfg.Access.Providers))
|
||||||
|
if cfg == nil {
|
||||||
|
return entries
|
||||||
|
}
|
||||||
|
for i := range cfg.Access.Providers {
|
||||||
|
providerCfg := &cfg.Access.Providers[i]
|
||||||
|
if providerCfg.Type == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if key := providerIdentifier(providerCfg); key != "" {
|
||||||
|
entries = append(entries, providerCfg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return entries
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerIdentifier(provider *config.AccessProvider) string {
|
||||||
|
if provider == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if name := strings.TrimSpace(provider.Name); name != "" {
|
||||||
|
return name
|
||||||
|
}
|
||||||
|
typ := strings.TrimSpace(provider.Type)
|
||||||
|
if typ == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
if strings.EqualFold(typ, config.AccessProviderTypeConfigAPIKey) {
|
||||||
|
return config.DefaultAccessProviderName
|
||||||
|
}
|
||||||
|
return typ
|
||||||
|
}
|
||||||
|
|
||||||
|
func providerConfigEqual(a, b *config.AccessProvider) bool {
|
||||||
|
if a == nil || b == nil {
|
||||||
|
return a == nil && b == nil
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(a.Type), strings.TrimSpace(b.Type)) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(a.SDK) != strings.TrimSpace(b.SDK) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if !stringSetEqual(a.APIKeys, b.APIKeys) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a.Config) != len(b.Config) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a.Config) > 0 && !reflect.DeepEqual(a.Config, b.Config) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func stringSetEqual(a, b []string) bool {
|
||||||
|
if len(a) != len(b) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(a) == 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
seen := make(map[string]int, len(a))
|
||||||
|
for _, val := range a {
|
||||||
|
seen[val]++
|
||||||
|
}
|
||||||
|
for _, val := range b {
|
||||||
|
count := seen[val]
|
||||||
|
if count == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if count == 1 {
|
||||||
|
delete(seen, val)
|
||||||
|
} else {
|
||||||
|
seen[val] = count - 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return len(seen) == 0
|
||||||
|
}
|
||||||
@@ -110,12 +110,23 @@ func (s *Service) refreshAccessProviders(cfg *config.Config) {
|
|||||||
if s == nil || s.accessManager == nil || cfg == nil {
|
if s == nil || s.accessManager == nil || cfg == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
providers, err := sdkaccess.BuildProviders(cfg)
|
s.cfgMu.RLock()
|
||||||
|
oldCfg := s.cfg
|
||||||
|
s.cfgMu.RUnlock()
|
||||||
|
|
||||||
|
existing := s.accessManager.Providers()
|
||||||
|
providers, added, updated, removed, err := sdkaccess.ReconcileProviders(oldCfg, cfg, existing)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to rebuild request auth providers: %v", err)
|
log.Errorf("failed to reconcile request auth providers: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
s.accessManager.SetProviders(providers)
|
s.accessManager.SetProviders(providers)
|
||||||
|
if len(added)+len(updated)+len(removed) > 0 {
|
||||||
|
log.Debugf("auth providers reconciled (added=%d updated=%d removed=%d)", len(added), len(updated), len(removed))
|
||||||
|
log.Debugf("auth provider changes details - added=%v updated=%v removed=%v", added, updated, removed)
|
||||||
|
} else {
|
||||||
|
log.Debug("auth providers unchanged after config reload")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) ensureAuthUpdateQueue(ctx context.Context) {
|
func (s *Service) ensureAuthUpdateQueue(ctx context.Context) {
|
||||||
@@ -497,22 +508,35 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
if s.cfg != nil {
|
if s.cfg != nil {
|
||||||
providerKey := provider
|
providerKey := provider
|
||||||
compatName := strings.TrimSpace(a.Provider)
|
compatName := strings.TrimSpace(a.Provider)
|
||||||
|
isCompatAuth := false
|
||||||
if strings.EqualFold(providerKey, "openai-compatibility") {
|
if strings.EqualFold(providerKey, "openai-compatibility") {
|
||||||
|
isCompatAuth = true
|
||||||
if a.Attributes != nil {
|
if a.Attributes != nil {
|
||||||
if v := strings.TrimSpace(a.Attributes["compat_name"]); v != "" {
|
if v := strings.TrimSpace(a.Attributes["compat_name"]); v != "" {
|
||||||
compatName = v
|
compatName = v
|
||||||
}
|
}
|
||||||
if v := strings.TrimSpace(a.Attributes["provider_key"]); v != "" {
|
if v := strings.TrimSpace(a.Attributes["provider_key"]); v != "" {
|
||||||
providerKey = strings.ToLower(v)
|
providerKey = strings.ToLower(v)
|
||||||
|
isCompatAuth = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if providerKey == "openai-compatibility" && compatName != "" {
|
if providerKey == "openai-compatibility" && compatName != "" {
|
||||||
providerKey = strings.ToLower(compatName)
|
providerKey = strings.ToLower(compatName)
|
||||||
}
|
}
|
||||||
|
} else if a.Attributes != nil {
|
||||||
|
if v := strings.TrimSpace(a.Attributes["compat_name"]); v != "" {
|
||||||
|
compatName = v
|
||||||
|
isCompatAuth = true
|
||||||
|
}
|
||||||
|
if v := strings.TrimSpace(a.Attributes["provider_key"]); v != "" {
|
||||||
|
providerKey = strings.ToLower(v)
|
||||||
|
isCompatAuth = true
|
||||||
|
}
|
||||||
}
|
}
|
||||||
for i := range s.cfg.OpenAICompatibility {
|
for i := range s.cfg.OpenAICompatibility {
|
||||||
compat := &s.cfg.OpenAICompatibility[i]
|
compat := &s.cfg.OpenAICompatibility[i]
|
||||||
if strings.EqualFold(compat.Name, compatName) {
|
if strings.EqualFold(compat.Name, compatName) {
|
||||||
|
isCompatAuth = true
|
||||||
// Convert compatibility models to registry models
|
// Convert compatibility models to registry models
|
||||||
ms := make([]*ModelInfo, 0, len(compat.Models))
|
ms := make([]*ModelInfo, 0, len(compat.Models))
|
||||||
for j := range compat.Models {
|
for j := range compat.Models {
|
||||||
@@ -532,10 +556,18 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
providerKey = "openai-compatibility"
|
providerKey = "openai-compatibility"
|
||||||
}
|
}
|
||||||
GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms)
|
GlobalModelRegistry().RegisterClient(a.ID, providerKey, ms)
|
||||||
|
} else {
|
||||||
|
// Ensure stale registrations are cleared when model list becomes empty.
|
||||||
|
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if isCompatAuth {
|
||||||
|
// No matching provider found or models removed entirely; drop any prior registration.
|
||||||
|
GlobalModelRegistry().UnregisterClient(a.ID)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(models) > 0 {
|
if len(models) > 0 {
|
||||||
|
|||||||
Reference in New Issue
Block a user