From b84ccc6e7aeba5f3b4d3ea96829273b1734ea8c0 Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Mon, 22 Dec 2025 22:52:23 +0800 Subject: [PATCH] feat: add unit tests for routing strategies and implement dynamic selector updates Added comprehensive tests for `FillFirstSelector` and `RoundRobinSelector` to ensure proper behavior, including deterministic, cyclical, and concurrent scenarios. Introduced dynamic routing strategy updates in `service.go`, normalizing strategies and seamlessly switching between `fill-first` and `round-robin`. Updated `Manager` to support selector changes via the new `SetSelector` method. --- sdk/cliproxy/auth/manager.go | 12 +++ sdk/cliproxy/auth/selector.go | 6 +- sdk/cliproxy/auth/selector_test.go | 113 +++++++++++++++++++++++++++++ sdk/cliproxy/service.go | 31 ++++++++ 4 files changed, 159 insertions(+), 3 deletions(-) create mode 100644 sdk/cliproxy/auth/selector_test.go diff --git a/sdk/cliproxy/auth/manager.go b/sdk/cliproxy/auth/manager.go index c345cd15..38d4c0fa 100644 --- a/sdk/cliproxy/auth/manager.go +++ b/sdk/cliproxy/auth/manager.go @@ -135,6 +135,18 @@ func NewManager(store Store, selector Selector, hook Hook) *Manager { } } +func (m *Manager) SetSelector(selector Selector) { + if m == nil { + return + } + if selector == nil { + selector = &RoundRobinSelector{} + } + m.mu.Lock() + m.selector = selector + m.mu.Unlock() +} + // SetStore swaps the underlying persistence store. func (m *Manager) SetStore(store Store) { m.mu.Lock() diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index b1f4d5fe..d7e120c5 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -149,9 +149,6 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([] func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { _ = ctx _ = opts - if s.cursors == nil { - s.cursors = make(map[string]int) - } now := time.Now() available, err := getAvailableAuths(auths, provider, model, now) if err != nil { @@ -159,6 +156,9 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o } key := provider + ":" + model s.mu.Lock() + if s.cursors == nil { + s.cursors = make(map[string]int) + } index := s.cursors[key] if index >= 2_147_483_640 { diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go new file mode 100644 index 00000000..f4beed03 --- /dev/null +++ b/sdk/cliproxy/auth/selector_test.go @@ -0,0 +1,113 @@ +package auth + +import ( + "context" + "errors" + "sync" + "testing" + + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +func TestFillFirstSelectorPick_Deterministic(t *testing.T) { + t.Parallel() + + selector := &FillFirstSelector{} + auths := []*Auth{ + {ID: "b"}, + {ID: "a"}, + {ID: "c"}, + } + + got, err := selector.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() error = %v", err) + } + if got == nil { + t.Fatalf("Pick() auth = nil") + } + if got.ID != "a" { + t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "a") + } +} + +func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + auths := []*Auth{ + {ID: "b"}, + {ID: "a"}, + {ID: "c"}, + } + + want := []string{"a", "b", "c", "a", "b"} + for i, id := range want { + got, err := selector.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + if got.ID != id { + t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, id) + } + } +} + +func TestRoundRobinSelectorPick_Concurrent(t *testing.T) { + selector := &RoundRobinSelector{} + auths := []*Auth{ + {ID: "b"}, + {ID: "a"}, + {ID: "c"}, + } + + start := make(chan struct{}) + var wg sync.WaitGroup + errCh := make(chan error, 1) + + goroutines := 32 + iterations := 100 + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-start + for j := 0; j < iterations; j++ { + got, err := selector.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, auths) + if err != nil { + select { + case errCh <- err: + default: + } + return + } + if got == nil { + select { + case errCh <- errors.New("Pick() returned nil auth"): + default: + } + return + } + if got.ID == "" { + select { + case errCh <- errors.New("Pick() returned auth with empty ID"): + default: + } + return + } + } + }() + } + + close(start) + wg.Wait() + + select { + case err := <-errCh: + t.Fatalf("concurrent Pick() error = %v", err) + default: + } +} diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index e4cd9e5d..a699ca61 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -506,6 +506,13 @@ func (s *Service) Run(ctx context.Context) error { var watcherWrapper *WatcherWrapper reloadCallback := func(newCfg *config.Config) { + previousStrategy := "" + s.cfgMu.RLock() + if s.cfg != nil { + previousStrategy = strings.ToLower(strings.TrimSpace(s.cfg.Routing.Strategy)) + } + s.cfgMu.RUnlock() + if newCfg == nil { s.cfgMu.RLock() newCfg = s.cfg @@ -514,6 +521,30 @@ func (s *Service) Run(ctx context.Context) error { if newCfg == nil { return } + + nextStrategy := strings.ToLower(strings.TrimSpace(newCfg.Routing.Strategy)) + normalizeStrategy := func(strategy string) string { + switch strategy { + case "fill-first", "fillfirst", "ff": + return "fill-first" + default: + return "round-robin" + } + } + previousStrategy = normalizeStrategy(previousStrategy) + nextStrategy = normalizeStrategy(nextStrategy) + if s.coreManager != nil && previousStrategy != nextStrategy { + var selector coreauth.Selector + switch nextStrategy { + case "fill-first": + selector = &coreauth.FillFirstSelector{} + default: + selector = &coreauth.RoundRobinSelector{} + } + s.coreManager.SetSelector(selector) + log.Infof("routing strategy updated to %s", nextStrategy) + } + s.applyRetryConfig(newCfg) if s.server != nil { s.server.UpdateClients(newCfg)