diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index a173ed01..cf79e173 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -5,6 +5,7 @@ import ( "encoding/json" "fmt" "math" + "math/rand/v2" "net/http" "sort" "strconv" @@ -248,6 +249,9 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([] } // Pick selects the next available auth for the provider in a round-robin manner. +// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute), +// a two-level round-robin is used: first cycling across credential groups (parent +// accounts), then cycling within each group's project auths. func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { _ = opts now := time.Now() @@ -265,21 +269,87 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o if limit <= 0 { limit = 4096 } - if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit { - s.cursors = make(map[string]int) - } - index := s.cursors[key] + // Check if any available auth has gemini_virtual_parent attribute, + // indicating gemini-cli virtual auths that should use credential-level polling. + groups, parentOrder := groupByVirtualParent(available) + if len(parentOrder) > 1 { + // Two-level round-robin: first select a credential group, then pick within it. + groupKey := key + "::group" + s.ensureCursorKey(groupKey, limit) + if _, exists := s.cursors[groupKey]; !exists { + // Seed with a random initial offset so the starting credential is randomized. + s.cursors[groupKey] = rand.IntN(len(parentOrder)) + } + groupIndex := s.cursors[groupKey] + if groupIndex >= 2_147_483_640 { + groupIndex = 0 + } + s.cursors[groupKey] = groupIndex + 1 + + selectedParent := parentOrder[groupIndex%len(parentOrder)] + group := groups[selectedParent] + + // Second level: round-robin within the selected credential group. + innerKey := key + "::cred:" + selectedParent + s.ensureCursorKey(innerKey, limit) + innerIndex := s.cursors[innerKey] + if innerIndex >= 2_147_483_640 { + innerIndex = 0 + } + s.cursors[innerKey] = innerIndex + 1 + s.mu.Unlock() + return group[innerIndex%len(group)], nil + } + + // Flat round-robin for non-grouped auths (original behavior). + s.ensureCursorKey(key, limit) + index := s.cursors[key] if index >= 2_147_483_640 { index = 0 } - s.cursors[key] = index + 1 s.mu.Unlock() - // log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available)) return available[index%len(available)], nil } +// ensureCursorKey ensures the cursor map has capacity for the given key. +// Must be called with s.mu held. +func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) { + if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit { + s.cursors = make(map[string]int) + } +} + +// groupByVirtualParent groups auths by their gemini_virtual_parent attribute. +// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration. +// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks +// this attribute, nil/nil is returned so the caller falls back to flat round-robin. +func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) { + if len(auths) == 0 { + return nil, nil + } + groups := make(map[string][]*Auth) + for _, a := range auths { + parent := "" + if a.Attributes != nil { + parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"]) + } + if parent == "" { + // Non-virtual auth present; fall back to flat round-robin. + return nil, nil + } + groups[parent] = append(groups[parent], a) + } + // Collect parent IDs in sorted order for stable cursor indexing. + parentOrder := make([]string, 0, len(groups)) + for p := range groups { + parentOrder = append(parentOrder, p) + } + sort.Strings(parentOrder) + return groups, parentOrder +} + // Pick selects the first available auth for the provider in a deterministic manner. func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { _ = opts diff --git a/sdk/cliproxy/auth/selector_test.go b/sdk/cliproxy/auth/selector_test.go index fe1cf15e..79431a9a 100644 --- a/sdk/cliproxy/auth/selector_test.go +++ b/sdk/cliproxy/auth/selector_test.go @@ -402,3 +402,128 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) { t.Fatalf("selector.cursors missing key %q", "gemini:m3") } } + +func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + + // Simulate two gemini-cli credentials, each with multiple projects: + // Credential A (parent = "cred-a.json") has 3 projects + // Credential B (parent = "cred-b.json") has 2 projects + auths := []*Auth{ + {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}}, + {ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}}, + } + + // Two-level round-robin: consecutive picks must alternate between credentials. + // Credential group order is randomized, but within each call the group cursor + // advances by 1, so consecutive picks should cycle through different parents. + picks := make([]string, 6) + parents := make([]string, 6) + for i := 0; i < 6; i++ { + got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + picks[i] = got.ID + parents[i] = got.Attributes["gemini_virtual_parent"] + } + + // Verify property: consecutive picks must alternate between credential groups. + for i := 1; i < len(parents); i++ { + if parents[i] == parents[i-1] { + t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials", + i-1, i, parents[i], picks[i-1], picks[i]) + } + } + + // Verify property: each credential's projects are picked in sequence (round-robin within group). + credPicks := map[string][]string{} + for i, id := range picks { + credPicks[parents[i]] = append(credPicks[parents[i]], id) + } + for parent, ids := range credPicks { + for i := 1; i < len(ids); i++ { + if ids[i] == ids[i-1] { + t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i]) + } + } + } +} + +func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + + // All auths from the same parent - should fall back to flat round-robin + // because there's only one credential group (no benefit from two-level). + auths := []*Auth{ + {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + } + + // With single parent group, parentOrder has length 1, so it uses flat round-robin. + // Sorted by ID: proj-a1, proj-a2, proj-a3 + want := []string{ + "cred-a.json::proj-a1", + "cred-a.json::proj-a2", + "cred-a.json::proj-a3", + "cred-a.json::proj-a1", + } + + for i, expectedID := range want { + got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + if got.ID != expectedID { + t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID) + } + } +} + +func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) { + t.Parallel() + + selector := &RoundRobinSelector{} + + // Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects + // alongside virtual ones). Should fall back to flat round-robin. + auths := []*Auth{ + {ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}}, + {ID: "cred-regular.json"}, // no gemini_virtual_parent + } + + // groupByVirtualParent returns nil when any auth lacks the attribute, + // so flat round-robin is used. Sorted by ID: cred-a.json::proj-a1, cred-regular.json + want := []string{ + "cred-a.json::proj-a1", + "cred-regular.json", + "cred-a.json::proj-a1", + } + + for i, expectedID := range want { + got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths) + if err != nil { + t.Fatalf("Pick() #%d error = %v", i, err) + } + if got == nil { + t.Fatalf("Pick() #%d auth = nil", i) + } + if got.ID != expectedID { + t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID) + } + } +}