mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-28 08:36:09 +08:00
Merge pull request #1663 from rensumo/main
feat: implement credential-based round-robin for gemini-cli
This commit is contained in:
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -248,6 +249,9 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
||||
}
|
||||
|
||||
// Pick selects the next available auth for the provider in a round-robin manner.
|
||||
// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute),
|
||||
// a two-level round-robin is used: first cycling across credential groups (parent
|
||||
// accounts), then cycling within each group's project auths.
|
||||
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
_ = opts
|
||||
now := time.Now()
|
||||
@@ -265,21 +269,87 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
|
||||
if limit <= 0 {
|
||||
limit = 4096
|
||||
}
|
||||
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
||||
s.cursors = make(map[string]int)
|
||||
}
|
||||
index := s.cursors[key]
|
||||
|
||||
// Check if any available auth has gemini_virtual_parent attribute,
|
||||
// indicating gemini-cli virtual auths that should use credential-level polling.
|
||||
groups, parentOrder := groupByVirtualParent(available)
|
||||
if len(parentOrder) > 1 {
|
||||
// Two-level round-robin: first select a credential group, then pick within it.
|
||||
groupKey := key + "::group"
|
||||
s.ensureCursorKey(groupKey, limit)
|
||||
if _, exists := s.cursors[groupKey]; !exists {
|
||||
// Seed with a random initial offset so the starting credential is randomized.
|
||||
s.cursors[groupKey] = rand.IntN(len(parentOrder))
|
||||
}
|
||||
groupIndex := s.cursors[groupKey]
|
||||
if groupIndex >= 2_147_483_640 {
|
||||
groupIndex = 0
|
||||
}
|
||||
s.cursors[groupKey] = groupIndex + 1
|
||||
|
||||
selectedParent := parentOrder[groupIndex%len(parentOrder)]
|
||||
group := groups[selectedParent]
|
||||
|
||||
// Second level: round-robin within the selected credential group.
|
||||
innerKey := key + "::cred:" + selectedParent
|
||||
s.ensureCursorKey(innerKey, limit)
|
||||
innerIndex := s.cursors[innerKey]
|
||||
if innerIndex >= 2_147_483_640 {
|
||||
innerIndex = 0
|
||||
}
|
||||
s.cursors[innerKey] = innerIndex + 1
|
||||
s.mu.Unlock()
|
||||
return group[innerIndex%len(group)], nil
|
||||
}
|
||||
|
||||
// Flat round-robin for non-grouped auths (original behavior).
|
||||
s.ensureCursorKey(key, limit)
|
||||
index := s.cursors[key]
|
||||
if index >= 2_147_483_640 {
|
||||
index = 0
|
||||
}
|
||||
|
||||
s.cursors[key] = index + 1
|
||||
s.mu.Unlock()
|
||||
// log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available))
|
||||
return available[index%len(available)], nil
|
||||
}
|
||||
|
||||
// ensureCursorKey ensures the cursor map has capacity for the given key.
|
||||
// Must be called with s.mu held.
|
||||
func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) {
|
||||
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
||||
s.cursors = make(map[string]int)
|
||||
}
|
||||
}
|
||||
|
||||
// groupByVirtualParent groups auths by their gemini_virtual_parent attribute.
|
||||
// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration.
|
||||
// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks
|
||||
// this attribute, nil/nil is returned so the caller falls back to flat round-robin.
|
||||
func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) {
|
||||
if len(auths) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
groups := make(map[string][]*Auth)
|
||||
for _, a := range auths {
|
||||
parent := ""
|
||||
if a.Attributes != nil {
|
||||
parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
if parent == "" {
|
||||
// Non-virtual auth present; fall back to flat round-robin.
|
||||
return nil, nil
|
||||
}
|
||||
groups[parent] = append(groups[parent], a)
|
||||
}
|
||||
// Collect parent IDs in sorted order for stable cursor indexing.
|
||||
parentOrder := make([]string, 0, len(groups))
|
||||
for p := range groups {
|
||||
parentOrder = append(parentOrder, p)
|
||||
}
|
||||
sort.Strings(parentOrder)
|
||||
return groups, parentOrder
|
||||
}
|
||||
|
||||
// Pick selects the first available auth for the provider in a deterministic manner.
|
||||
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
_ = opts
|
||||
|
||||
@@ -402,3 +402,128 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
|
||||
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
|
||||
// Simulate two gemini-cli credentials, each with multiple projects:
|
||||
// Credential A (parent = "cred-a.json") has 3 projects
|
||||
// Credential B (parent = "cred-b.json") has 2 projects
|
||||
auths := []*Auth{
|
||||
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
|
||||
{ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
|
||||
}
|
||||
|
||||
// Two-level round-robin: consecutive picks must alternate between credentials.
|
||||
// Credential group order is randomized, but within each call the group cursor
|
||||
// advances by 1, so consecutive picks should cycle through different parents.
|
||||
picks := make([]string, 6)
|
||||
parents := make([]string, 6)
|
||||
for i := 0; i < 6; i++ {
|
||||
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() #%d auth = nil", i)
|
||||
}
|
||||
picks[i] = got.ID
|
||||
parents[i] = got.Attributes["gemini_virtual_parent"]
|
||||
}
|
||||
|
||||
// Verify property: consecutive picks must alternate between credential groups.
|
||||
for i := 1; i < len(parents); i++ {
|
||||
if parents[i] == parents[i-1] {
|
||||
t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials",
|
||||
i-1, i, parents[i], picks[i-1], picks[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify property: each credential's projects are picked in sequence (round-robin within group).
|
||||
credPicks := map[string][]string{}
|
||||
for i, id := range picks {
|
||||
credPicks[parents[i]] = append(credPicks[parents[i]], id)
|
||||
}
|
||||
for parent, ids := range credPicks {
|
||||
for i := 1; i < len(ids); i++ {
|
||||
if ids[i] == ids[i-1] {
|
||||
t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
|
||||
// All auths from the same parent - should fall back to flat round-robin
|
||||
// because there's only one credential group (no benefit from two-level).
|
||||
auths := []*Auth{
|
||||
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
}
|
||||
|
||||
// With single parent group, parentOrder has length 1, so it uses flat round-robin.
|
||||
// Sorted by ID: proj-a1, proj-a2, proj-a3
|
||||
want := []string{
|
||||
"cred-a.json::proj-a1",
|
||||
"cred-a.json::proj-a2",
|
||||
"cred-a.json::proj-a3",
|
||||
"cred-a.json::proj-a1",
|
||||
}
|
||||
|
||||
for i, expectedID := range want {
|
||||
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() #%d auth = nil", i)
|
||||
}
|
||||
if got.ID != expectedID {
|
||||
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
|
||||
// Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects
|
||||
// alongside virtual ones). Should fall back to flat round-robin.
|
||||
auths := []*Auth{
|
||||
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-regular.json"}, // no gemini_virtual_parent
|
||||
}
|
||||
|
||||
// groupByVirtualParent returns nil when any auth lacks the attribute,
|
||||
// so flat round-robin is used. Sorted by ID: cred-a.json::proj-a1, cred-regular.json
|
||||
want := []string{
|
||||
"cred-a.json::proj-a1",
|
||||
"cred-regular.json",
|
||||
"cred-a.json::proj-a1",
|
||||
}
|
||||
|
||||
for i, expectedID := range want {
|
||||
got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() #%d auth = nil", i)
|
||||
}
|
||||
if got.ID != expectedID {
|
||||
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user