Files
CLIProxyAPI/sdk/cliproxy/auth/selector_test.go
rensumo 5936f9895c feat: implement credential-based round-robin for gemini-cli virtual auths
Changes RoundRobinSelector to use a two-level round-robin when
gemini-cli virtual auths are detected (via the gemini_virtual_parent attribute):
- Level 1: cycle across credential groups (parent accounts)
- Level 2: cycle within each group's project auths

Credential groups start from a random offset (rand.IntN) for fair distribution.
Non-virtual auths, mixed virtual/non-virtual sets, and single-credential
scenarios fall back to flat round-robin.

Adds 3 test cases covering multi-credential grouping, single-parent
fallback, and mixed virtual/non-virtual fallback.
2026-02-21 12:49:48 +08:00

530 lines
14 KiB
Go

package auth
import (
"context"
"encoding/json"
"errors"
"net/http"
"sync"
"testing"
"time"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
func TestFillFirstSelectorPick_Deterministic(t *testing.T) {
t.Parallel()
selector := &FillFirstSelector{}
auths := []*Auth{
{ID: "b"},
{ID: "a"},
{ID: "c"},
}
got, err := selector.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, auths)
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if got == nil {
t.Fatalf("Pick() auth = nil")
}
if got.ID != "a" {
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "a")
}
}
// TestRoundRobinSelectorPick_CyclesDeterministic verifies that repeated picks
// walk the auths in ID order ("a", "b", "c") and wrap around, even though the
// auths are supplied as b, a, c.
func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) {
	t.Parallel()

	rr := &RoundRobinSelector{}
	candidates := []*Auth{{ID: "b"}, {ID: "a"}, {ID: "c"}}
	expected := []string{"a", "b", "c", "a", "b"}

	for n := range expected {
		picked, pickErr := rr.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, candidates)
		if pickErr != nil {
			t.Fatalf("Pick() #%d error = %v", n, pickErr)
		}
		if picked == nil {
			t.Fatalf("Pick() #%d auth = nil", n)
		}
		if got, want := picked.ID, expected[n]; got != want {
			t.Fatalf("Pick() #%d auth.ID = %q, want %q", n, got, want)
		}
	}
}
// TestRoundRobinSelectorPick_PriorityBuckets verifies that round-robin only
// cycles within the highest-priority bucket: "a" and "b" (priority "10")
// alternate, and "c" (priority "0") is never selected.
func TestRoundRobinSelectorPick_PriorityBuckets(t *testing.T) {
	t.Parallel()
	selector := &RoundRobinSelector{}
	auths := []*Auth{
		{ID: "c", Attributes: map[string]string{"priority": "0"}},
		{ID: "a", Attributes: map[string]string{"priority": "10"}},
		{ID: "b", Attributes: map[string]string{"priority": "10"}},
	}
	want := []string{"a", "b", "a", "b"}
	for i, id := range want {
		got, err := selector.Pick(context.Background(), "mixed", "", cliproxyexecutor.Options{}, auths)
		if err != nil {
			t.Fatalf("Pick() #%d error = %v", i, err)
		}
		if got == nil {
			t.Fatalf("Pick() #%d auth = nil", i)
		}
		// Check for a priority violation before the exact-ID comparison. Otherwise
		// the generic mismatch below would fire first and this specific message
		// could never be reached ("c" is never in want).
		if got.ID == "c" {
			t.Fatalf("Pick() #%d unexpectedly selected lower priority auth", i)
		}
		if got.ID != id {
			t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, id)
		}
	}
}
func TestFillFirstSelectorPick_PriorityFallbackCooldown(t *testing.T) {
t.Parallel()
selector := &FillFirstSelector{}
now := time.Now()
model := "test-model"
high := &Auth{
ID: "high",
Attributes: map[string]string{"priority": "10"},
ModelStates: map[string]*ModelState{
model: {
Status: StatusActive,
Unavailable: true,
NextRetryAfter: now.Add(30 * time.Minute),
Quota: QuotaState{
Exceeded: true,
},
},
},
}
low := &Auth{ID: "low", Attributes: map[string]string{"priority": "0"}}
got, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, []*Auth{high, low})
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if got == nil {
t.Fatalf("Pick() auth = nil")
}
if got.ID != "low" {
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
}
}
// TestRoundRobinSelectorPick_Concurrent calls Pick on one shared
// RoundRobinSelector from many goroutines at once; it is most useful when run
// with -race. Each goroutine stops at its first failure, and only the first
// failure recorded across all goroutines is reported.
func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
	const (
		workers   = 32
		perWorker = 100
	)
	rr := &RoundRobinSelector{}
	candidates := []*Auth{{ID: "b"}, {ID: "a"}, {ID: "c"}}

	firstErr := make(chan error, 1)
	// report keeps err only if no earlier failure was recorded; later ones are dropped.
	report := func(err error) {
		select {
		case firstErr <- err:
		default:
		}
	}

	gate := make(chan struct{})
	var wg sync.WaitGroup
	wg.Add(workers)
	for w := 0; w < workers; w++ {
		go func() {
			defer wg.Done()
			<-gate // release every worker at the same moment to maximise contention
			for n := 0; n < perWorker; n++ {
				picked, err := rr.Pick(context.Background(), "gemini", "", cliproxyexecutor.Options{}, candidates)
				switch {
				case err != nil:
					report(err)
					return
				case picked == nil:
					report(errors.New("Pick() returned nil auth"))
					return
				case picked.ID == "":
					report(errors.New("Pick() returned auth with empty ID"))
					return
				}
			}
		}()
	}
	close(gate)
	wg.Wait()

	select {
	case err := <-firstErr:
		t.Fatalf("concurrent Pick() error = %v", err)
	default:
	}
}
// TestSelectorPick_AllCooldownReturnsModelCooldownError checks the error returned
// when every auth is cooling down for the requested model: FillFirstSelector must
// return a *modelCooldownError whose Error() text is a JSON object with an "error"
// member.
//   - For provider "mixed": status 429, a non-empty Retry-After header,
//     code "model_cooldown", and no "provider" field.
//   - For a concrete provider ("gemini"): the "provider" field names it.
func TestSelectorPick_AllCooldownReturnsModelCooldownError(t *testing.T) {
	t.Parallel()

	const model = "test-model"
	next := time.Now().Add(60 * time.Second)

	// coolingAuth builds an auth whose state for model is quota-exceeded until next.
	coolingAuth := func(id string) *Auth {
		return &Auth{
			ID: id,
			ModelStates: map[string]*ModelState{
				model: {
					Status:         StatusActive,
					Unavailable:    true,
					NextRetryAfter: next,
					Quota: QuotaState{
						Exceeded:      true,
						NextRecoverAt: next,
					},
				},
			},
		}
	}
	auths := []*Auth{coolingAuth("a"), coolingAuth("b")}

	// cooldownFor picks with a fresh FillFirstSelector for provider and requires
	// the resulting error to be a *modelCooldownError.
	cooldownFor := func(t *testing.T, provider string) *modelCooldownError {
		t.Helper()
		_, err := (&FillFirstSelector{}).Pick(context.Background(), provider, model, cliproxyexecutor.Options{}, auths)
		if err == nil {
			t.Fatalf("Pick() error = nil")
		}
		var mce *modelCooldownError
		if !errors.As(err, &mce) {
			t.Fatalf("Pick() error = %T, want *modelCooldownError", err)
		}
		return mce
	}

	// errorObject decodes mce.Error() as JSON and returns its "error" member.
	errorObject := func(t *testing.T, mce *modelCooldownError) map[string]any {
		t.Helper()
		var payload map[string]any
		if err := json.Unmarshal([]byte(mce.Error()), &payload); err != nil {
			t.Fatalf("json.Unmarshal(Error()) error = %v", err)
		}
		rawErr, ok := payload["error"].(map[string]any)
		if !ok {
			t.Fatalf("Error() payload missing error object: %v", payload)
		}
		return rawErr
	}

	t.Run("mixed provider redacts provider field", func(t *testing.T) {
		t.Parallel()
		mce := cooldownFor(t, "mixed")
		if code := mce.StatusCode(); code != http.StatusTooManyRequests {
			t.Fatalf("StatusCode() = %d, want %d", code, http.StatusTooManyRequests)
		}
		if mce.Headers().Get("Retry-After") == "" {
			t.Fatalf("Headers().Get(Retry-After) = empty")
		}
		rawErr := errorObject(t, mce)
		if code, _ := rawErr["code"].(string); code != "model_cooldown" {
			t.Fatalf("Error().error.code = %q, want %q", code, "model_cooldown")
		}
		if provider, present := rawErr["provider"]; present {
			t.Fatalf("Error().error.provider exists for mixed provider: %v", provider)
		}
	})

	t.Run("non-mixed provider includes provider field", func(t *testing.T) {
		t.Parallel()
		rawErr := errorObject(t, cooldownFor(t, "gemini"))
		if provider, _ := rawErr["provider"].(string); provider != "gemini" {
			t.Fatalf("Error().error.provider = %q, want %q", provider, "gemini")
		}
	})
}
// TestIsAuthBlockedForModel_UnavailableWithoutNextRetryIsNotBlocked covers a model
// state that is flagged Unavailable with quota exceeded but has a zero
// NextRetryAfter. For such a state, isAuthBlockedForModel must report
// (false, blockReasonNone, zero time).
func TestIsAuthBlockedForModel_UnavailableWithoutNextRetryIsNotBlocked(t *testing.T) {
	t.Parallel()

	checkedAt := time.Now()
	const model = "test-model"
	auth := &Auth{
		ID: "a",
		ModelStates: map[string]*ModelState{
			model: {
				Status:      StatusActive,
				Unavailable: true,
				Quota:       QuotaState{Exceeded: true},
			},
		},
	}

	blocked, reason, next := isAuthBlockedForModel(auth, model, checkedAt)
	switch {
	case blocked:
		t.Fatalf("blocked = true, want false")
	case reason != blockReasonNone:
		t.Fatalf("reason = %v, want %v", reason, blockReasonNone)
	case !next.IsZero():
		t.Fatalf("next = %v, want zero", next)
	}
}
func TestFillFirstSelectorPick_ThinkingSuffixFallsBackToBaseModelState(t *testing.T) {
t.Parallel()
selector := &FillFirstSelector{}
now := time.Now()
baseModel := "test-model"
requestedModel := "test-model(high)"
high := &Auth{
ID: "high",
Attributes: map[string]string{"priority": "10"},
ModelStates: map[string]*ModelState{
baseModel: {
Status: StatusActive,
Unavailable: true,
NextRetryAfter: now.Add(30 * time.Minute),
Quota: QuotaState{
Exceeded: true,
},
},
},
}
low := &Auth{
ID: "low",
Attributes: map[string]string{"priority": "0"},
}
got, err := selector.Pick(context.Background(), "mixed", requestedModel, cliproxyexecutor.Options{}, []*Auth{high, low})
if err != nil {
t.Fatalf("Pick() error = %v", err)
}
if got == nil {
t.Fatalf("Pick() auth = nil")
}
if got.ID != "low" {
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
}
}
// TestRoundRobinSelectorPick_ThinkingSuffixSharesCursor checks that requests
// differing only in their thinking suffix ("(high)" vs "(low)") share one
// round-robin cursor. The second pick therefore continues from the first: "a",
// then "b".
func TestRoundRobinSelectorPick_ThinkingSuffixSharesCursor(t *testing.T) {
	t.Parallel()

	rr := &RoundRobinSelector{}
	candidates := []*Auth{{ID: "b"}, {ID: "a"}}
	pick := func(model string) (*Auth, error) {
		return rr.Pick(context.Background(), "gemini", model, cliproxyexecutor.Options{}, candidates)
	}

	first, firstErr := pick("test-model(high)")
	if firstErr != nil {
		t.Fatalf("Pick() first error = %v", firstErr)
	}
	second, secondErr := pick("test-model(low)")
	if secondErr != nil {
		t.Fatalf("Pick() second error = %v", secondErr)
	}
	if first == nil || second == nil {
		t.Fatalf("Pick() returned nil auth")
	}

	if first.ID != "a" {
		t.Fatalf("Pick() first auth.ID = %q, want %q", first.ID, "a")
	}
	if second.ID != "b" {
		t.Fatalf("Pick() second auth.ID = %q, want %q", second.ID, "b")
	}
}
// TestRoundRobinSelectorPick_CursorKeyCap checks that a RoundRobinSelector built
// with maxKeys = 2 does not keep one cursor per key indefinitely. After picks for
// three distinct models, the assertions below expect exactly one cursor, keyed
// "gemini:m3" (presumably provider:model of the most recent pick).
func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
	t.Parallel()
	selector := &RoundRobinSelector{maxKeys: 2}
	auths := []*Auth{{ID: "a"}}
	// Check every Pick error. Otherwise a failing Pick would show up only as a
	// confusing cursor-count mismatch below.
	for _, model := range []string{"m1", "m2", "m3"} {
		if _, err := selector.Pick(context.Background(), "gemini", model, cliproxyexecutor.Options{}, auths); err != nil {
			t.Fatalf("Pick(%q) error = %v", model, err)
		}
	}
	// Inspect the selector's internal cursor map under its own mutex.
	selector.mu.Lock()
	defer selector.mu.Unlock()
	if selector.cursors == nil {
		t.Fatalf("selector.cursors = nil")
	}
	if len(selector.cursors) != 1 {
		t.Fatalf("len(selector.cursors) = %d, want %d", len(selector.cursors), 1)
	}
	if _, ok := selector.cursors["gemini:m3"]; !ok {
		t.Fatalf("selector.cursors missing key %q", "gemini:m3")
	}
}
// TestRoundRobinSelectorPick_GeminiCLICredentialGrouping checks the two-level
// round-robin used for gemini-cli virtual auths, which are grouped by their
// "gemini_virtual_parent" attribute:
//   - level 1: consecutive picks alternate between credential groups;
//   - level 2: within a group, no project is repeated until every project of
//     that group has been picked.
//
// The starting credential group is randomized, so the test asserts only these
// order-independent properties, not an exact pick sequence.
func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
	t.Parallel()
	selector := &RoundRobinSelector{}
	// Simulate two gemini-cli credentials, each with multiple projects:
	// Credential A (parent = "cred-a.json") has 3 projects
	// Credential B (parent = "cred-b.json") has 2 projects
	auths := []*Auth{
		{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
		{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
		{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
		{ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
		{ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
	}
	// Number of projects per credential group, used as the round-robin window below.
	groupSize := make(map[string]int)
	for _, a := range auths {
		groupSize[a.Attributes["gemini_virtual_parent"]]++
	}

	const numPicks = 6
	picks := make([]string, numPicks)
	parents := make([]string, numPicks)
	for i := 0; i < numPicks; i++ {
		got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
		if err != nil {
			t.Fatalf("Pick() #%d error = %v", i, err)
		}
		if got == nil {
			t.Fatalf("Pick() #%d auth = nil", i)
		}
		picks[i] = got.ID
		parents[i] = got.Attributes["gemini_virtual_parent"]
		if groupSize[parents[i]] == 0 {
			t.Fatalf("Pick() #%d returned %q with unknown parent %q", i, got.ID, parents[i])
		}
	}

	// Level 1: consecutive picks must come from different credential groups.
	for i := 1; i < len(parents); i++ {
		if parents[i] == parents[i-1] {
			t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials",
				i-1, i, parents[i], picks[i-1], picks[i])
		}
	}

	// Level 2: within each group, every run of groupSize consecutive picks must be
	// made of distinct projects. Checking only adjacent pairs would accept a
	// non-round-robin sequence such as a1, a2, a1 for a 3-project group.
	credPicks := map[string][]string{}
	for i, id := range picks {
		credPicks[parents[i]] = append(credPicks[parents[i]], id)
	}
	for parent, ids := range credPicks {
		size := groupSize[parent]
		for i := range ids {
			for j := max(0, i-size+1); j < i; j++ {
				if ids[j] == ids[i] {
					t.Fatalf("Credential %q picked project %q again after %d pick(s); want all %d projects used before a repeat (picks: %q)",
						parent, ids[i], i-j, size, ids)
				}
			}
		}
	}
}
// TestRoundRobinSelectorPick_SingleParentFallsBackToFlat checks the case where all
// virtual auths share one gemini_virtual_parent. Picks must then cycle through the
// auths in ID order exactly like flat round-robin; with a single credential group,
// the two-level scheme adds nothing.
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
	t.Parallel()

	rr := &RoundRobinSelector{}
	// Listed already sorted by ID, so slice index order is the expected cycle order.
	candidates := []*Auth{
		{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
		{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
		{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
	}
	// Three picks walk the group once; the fourth wraps back to the first project.
	order := []int{0, 1, 2, 0}
	expectedIDs := make([]string, len(order))
	for i, idx := range order {
		expectedIDs[i] = candidates[idx].ID
	}

	for i, expectedID := range expectedIDs {
		got, err := rr.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, candidates)
		switch {
		case err != nil:
			t.Fatalf("Pick() #%d error = %v", i, err)
		case got == nil:
			t.Fatalf("Pick() #%d auth = nil", i)
		case got.ID != expectedID:
			t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
		}
	}
}
// TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat checks the
// case where at least one auth lacks gemini_virtual_parent (for example, a regular
// gemini-cli auth next to virtual ones). Credential grouping must then be skipped,
// and picks cycle through all auths in ID order as plain round-robin.
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
	t.Parallel()

	rr := &RoundRobinSelector{}
	// Listed already sorted by ID, so slice index order is the expected cycle order.
	candidates := []*Auth{
		{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
		{ID: "cred-regular.json"}, // no gemini_virtual_parent
	}
	order := []int{0, 1, 0}
	expectedIDs := make([]string, len(order))
	for i, idx := range order {
		expectedIDs[i] = candidates[idx].ID
	}

	for i, expectedID := range expectedIDs {
		got, err := rr.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, candidates)
		switch {
		case err != nil:
			t.Fatalf("Pick() #%d error = %v", i, err)
		case got == nil:
			t.Fatalf("Pick() #%d auth = nil", i)
		case got.ID != expectedID:
			t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
		}
	}
}