mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
feat(selector): add priority support for auth selection
This commit is contained in:
@@ -242,6 +242,10 @@ type ClaudeKey struct {
|
|||||||
// APIKey is the authentication key for accessing Claude API services.
|
// APIKey is the authentication key for accessing Claude API services.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
// Prefix optionally namespaces models for this credential (e.g., "teamA/claude-sonnet-4").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
@@ -280,6 +284,10 @@ type CodexKey struct {
|
|||||||
// APIKey is the authentication key for accessing Codex API services.
|
// APIKey is the authentication key for accessing Codex API services.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
// Prefix optionally namespaces models for this credential (e.g., "teamA/gpt-5-codex").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
@@ -318,6 +326,10 @@ type GeminiKey struct {
|
|||||||
// APIKey is the authentication key for accessing Gemini API services.
|
// APIKey is the authentication key for accessing Gemini API services.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
// Prefix optionally namespaces models for this credential (e.g., "teamA/gemini-3-pro-preview").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
@@ -355,6 +367,10 @@ type OpenAICompatibility struct {
|
|||||||
// Name is the identifier for this OpenAI compatibility configuration.
|
// Name is the identifier for this OpenAI compatibility configuration.
|
||||||
Name string `yaml:"name" json:"name"`
|
Name string `yaml:"name" json:"name"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple providers or credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
// Prefix optionally namespaces model aliases for this provider (e.g., "teamA/kimi-k2").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
|
|||||||
@@ -13,6 +13,10 @@ type VertexCompatKey struct {
|
|||||||
// Maps to the x-goog-api-key header.
|
// Maps to the x-goog-api-key header.
|
||||||
APIKey string `yaml:"api-key" json:"api-key"`
|
APIKey string `yaml:"api-key" json:"api-key"`
|
||||||
|
|
||||||
|
// Priority controls selection preference when multiple credentials match.
|
||||||
|
// Higher values are preferred; defaults to 0.
|
||||||
|
Priority int `yaml:"priority,omitempty" json:"priority,omitempty"`
|
||||||
|
|
||||||
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
// Prefix optionally namespaces model aliases for this credential (e.g., "teamA/vertex-pro").
|
||||||
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
Prefix string `yaml:"prefix,omitempty" json:"prefix,omitempty"`
|
||||||
|
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package synthesizer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher/diff"
|
||||||
@@ -59,6 +60,9 @@ func (s *ConfigSynthesizer) synthesizeGeminiKeys(ctx *SynthesisContext) []*corea
|
|||||||
"source": fmt.Sprintf("config:gemini[%s]", token),
|
"source": fmt.Sprintf("config:gemini[%s]", token),
|
||||||
"api_key": key,
|
"api_key": key,
|
||||||
}
|
}
|
||||||
|
if entry.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(entry.Priority)
|
||||||
|
}
|
||||||
if base != "" {
|
if base != "" {
|
||||||
attrs["base_url"] = base
|
attrs["base_url"] = base
|
||||||
}
|
}
|
||||||
@@ -103,6 +107,9 @@ func (s *ConfigSynthesizer) synthesizeClaudeKeys(ctx *SynthesisContext) []*corea
|
|||||||
"source": fmt.Sprintf("config:claude[%s]", token),
|
"source": fmt.Sprintf("config:claude[%s]", token),
|
||||||
"api_key": key,
|
"api_key": key,
|
||||||
}
|
}
|
||||||
|
if ck.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(ck.Priority)
|
||||||
|
}
|
||||||
if base != "" {
|
if base != "" {
|
||||||
attrs["base_url"] = base
|
attrs["base_url"] = base
|
||||||
}
|
}
|
||||||
@@ -147,6 +154,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
|
|||||||
"source": fmt.Sprintf("config:codex[%s]", token),
|
"source": fmt.Sprintf("config:codex[%s]", token),
|
||||||
"api_key": key,
|
"api_key": key,
|
||||||
}
|
}
|
||||||
|
if ck.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(ck.Priority)
|
||||||
|
}
|
||||||
if ck.BaseURL != "" {
|
if ck.BaseURL != "" {
|
||||||
attrs["base_url"] = ck.BaseURL
|
attrs["base_url"] = ck.BaseURL
|
||||||
}
|
}
|
||||||
@@ -202,6 +212,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
|
|||||||
"compat_name": compat.Name,
|
"compat_name": compat.Name,
|
||||||
"provider_key": providerName,
|
"provider_key": providerName,
|
||||||
}
|
}
|
||||||
|
if compat.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||||
|
}
|
||||||
if key != "" {
|
if key != "" {
|
||||||
attrs["api_key"] = key
|
attrs["api_key"] = key
|
||||||
}
|
}
|
||||||
@@ -233,6 +246,9 @@ func (s *ConfigSynthesizer) synthesizeOpenAICompat(ctx *SynthesisContext) []*cor
|
|||||||
"compat_name": compat.Name,
|
"compat_name": compat.Name,
|
||||||
"provider_key": providerName,
|
"provider_key": providerName,
|
||||||
}
|
}
|
||||||
|
if compat.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||||
|
}
|
||||||
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
if hash := diff.ComputeOpenAICompatModelsHash(compat.Models); hash != "" {
|
||||||
attrs["models_hash"] = hash
|
attrs["models_hash"] = hash
|
||||||
}
|
}
|
||||||
@@ -275,6 +291,9 @@ func (s *ConfigSynthesizer) synthesizeVertexCompat(ctx *SynthesisContext) []*cor
|
|||||||
"base_url": base,
|
"base_url": base,
|
||||||
"provider_key": providerName,
|
"provider_key": providerName,
|
||||||
}
|
}
|
||||||
|
if compat.Priority != 0 {
|
||||||
|
attrs["priority"] = strconv.Itoa(compat.Priority)
|
||||||
|
}
|
||||||
if key != "" {
|
if key != "" {
|
||||||
attrs["api_key"] = key
|
attrs["api_key"] = key
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -271,7 +271,6 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
|||||||
if len(normalized) == 0 {
|
if len(normalized) == 0 {
|
||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
rotated := m.rotateProviders(req.Model, normalized)
|
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
retryTimes, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
attempts := retryTimes + 1
|
||||||
@@ -281,14 +280,12 @@ func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxye
|
|||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
|
||||||
return m.executeWithProvider(execCtx, provider, req, opts)
|
|
||||||
})
|
|
||||||
if errExec == nil {
|
if errExec == nil {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
lastErr = errExec
|
lastErr = errExec
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -309,7 +306,6 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
|||||||
if len(normalized) == 0 {
|
if len(normalized) == 0 {
|
||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
rotated := m.rotateProviders(req.Model, normalized)
|
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
retryTimes, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
attempts := retryTimes + 1
|
||||||
@@ -319,14 +315,12 @@ func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req clip
|
|||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
resp, errExec := m.executeProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (cliproxyexecutor.Response, error) {
|
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
|
||||||
return m.executeCountWithProvider(execCtx, provider, req, opts)
|
|
||||||
})
|
|
||||||
if errExec == nil {
|
if errExec == nil {
|
||||||
return resp, nil
|
return resp, nil
|
||||||
}
|
}
|
||||||
lastErr = errExec
|
lastErr = errExec
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, rotated, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -347,7 +341,6 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
|||||||
if len(normalized) == 0 {
|
if len(normalized) == 0 {
|
||||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
}
|
}
|
||||||
rotated := m.rotateProviders(req.Model, normalized)
|
|
||||||
|
|
||||||
retryTimes, maxWait := m.retrySettings()
|
retryTimes, maxWait := m.retrySettings()
|
||||||
attempts := retryTimes + 1
|
attempts := retryTimes + 1
|
||||||
@@ -357,14 +350,12 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
|||||||
|
|
||||||
var lastErr error
|
var lastErr error
|
||||||
for attempt := 0; attempt < attempts; attempt++ {
|
for attempt := 0; attempt < attempts; attempt++ {
|
||||||
chunks, errStream := m.executeStreamProvidersOnce(ctx, rotated, func(execCtx context.Context, provider string) (<-chan cliproxyexecutor.StreamChunk, error) {
|
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
||||||
return m.executeStreamWithProvider(execCtx, provider, req, opts)
|
|
||||||
})
|
|
||||||
if errStream == nil {
|
if errStream == nil {
|
||||||
return chunks, nil
|
return chunks, nil
|
||||||
}
|
}
|
||||||
lastErr = errStream
|
lastErr = errStream
|
||||||
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, rotated, req.Model, maxWait)
|
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait)
|
||||||
if !shouldRetry {
|
if !shouldRetry {
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -378,6 +369,167 @@ func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cli
|
|||||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
if len(providers) == 0 {
|
||||||
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
|
}
|
||||||
|
routeModel := req.Model
|
||||||
|
tried := make(map[string]struct{})
|
||||||
|
var lastErr error
|
||||||
|
for {
|
||||||
|
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||||
|
if errPick != nil {
|
||||||
|
if lastErr != nil {
|
||||||
|
return cliproxyexecutor.Response{}, lastErr
|
||||||
|
}
|
||||||
|
return cliproxyexecutor.Response{}, errPick
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := logEntryWithRequestID(ctx)
|
||||||
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||||
|
|
||||||
|
tried[auth.ID] = struct{}{}
|
||||||
|
execCtx := ctx
|
||||||
|
if rt := m.roundTripperFor(auth); rt != nil {
|
||||||
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||||
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||||
|
}
|
||||||
|
execReq := req
|
||||||
|
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||||
|
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||||
|
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
||||||
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||||
|
if errExec != nil {
|
||||||
|
result.Error = &Error{Message: errExec.Error()}
|
||||||
|
var se cliproxyexecutor.StatusError
|
||||||
|
if errors.As(errExec, &se) && se != nil {
|
||||||
|
result.Error.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
if ra := retryAfterFromError(errExec); ra != nil {
|
||||||
|
result.RetryAfter = ra
|
||||||
|
}
|
||||||
|
m.MarkResult(execCtx, result)
|
||||||
|
lastErr = errExec
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.MarkResult(execCtx, result)
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
if len(providers) == 0 {
|
||||||
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
|
}
|
||||||
|
routeModel := req.Model
|
||||||
|
tried := make(map[string]struct{})
|
||||||
|
var lastErr error
|
||||||
|
for {
|
||||||
|
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||||
|
if errPick != nil {
|
||||||
|
if lastErr != nil {
|
||||||
|
return cliproxyexecutor.Response{}, lastErr
|
||||||
|
}
|
||||||
|
return cliproxyexecutor.Response{}, errPick
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := logEntryWithRequestID(ctx)
|
||||||
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||||
|
|
||||||
|
tried[auth.ID] = struct{}{}
|
||||||
|
execCtx := ctx
|
||||||
|
if rt := m.roundTripperFor(auth); rt != nil {
|
||||||
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||||
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||||
|
}
|
||||||
|
execReq := req
|
||||||
|
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||||
|
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||||
|
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
||||||
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
||||||
|
if errExec != nil {
|
||||||
|
result.Error = &Error{Message: errExec.Error()}
|
||||||
|
var se cliproxyexecutor.StatusError
|
||||||
|
if errors.As(errExec, &se) && se != nil {
|
||||||
|
result.Error.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
if ra := retryAfterFromError(errExec); ra != nil {
|
||||||
|
result.RetryAfter = ra
|
||||||
|
}
|
||||||
|
m.MarkResult(execCtx, result)
|
||||||
|
lastErr = errExec
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
m.MarkResult(execCtx, result)
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||||
|
if len(providers) == 0 {
|
||||||
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
|
}
|
||||||
|
routeModel := req.Model
|
||||||
|
tried := make(map[string]struct{})
|
||||||
|
var lastErr error
|
||||||
|
for {
|
||||||
|
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
||||||
|
if errPick != nil {
|
||||||
|
if lastErr != nil {
|
||||||
|
return nil, lastErr
|
||||||
|
}
|
||||||
|
return nil, errPick
|
||||||
|
}
|
||||||
|
|
||||||
|
entry := logEntryWithRequestID(ctx)
|
||||||
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||||
|
|
||||||
|
tried[auth.ID] = struct{}{}
|
||||||
|
execCtx := ctx
|
||||||
|
if rt := m.roundTripperFor(auth); rt != nil {
|
||||||
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||||
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||||
|
}
|
||||||
|
execReq := req
|
||||||
|
execReq.Model, execReq.Metadata = rewriteModelForAuth(routeModel, req.Metadata, auth)
|
||||||
|
execReq.Model, execReq.Metadata = m.applyOAuthModelMapping(auth, execReq.Model, execReq.Metadata)
|
||||||
|
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
||||||
|
if errStream != nil {
|
||||||
|
rerr := &Error{Message: errStream.Error()}
|
||||||
|
var se cliproxyexecutor.StatusError
|
||||||
|
if errors.As(errStream, &se) && se != nil {
|
||||||
|
rerr.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
||||||
|
result.RetryAfter = retryAfterFromError(errStream)
|
||||||
|
m.MarkResult(execCtx, result)
|
||||||
|
lastErr = errStream
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||||
|
defer close(out)
|
||||||
|
var failed bool
|
||||||
|
for chunk := range streamChunks {
|
||||||
|
if chunk.Err != nil && !failed {
|
||||||
|
failed = true
|
||||||
|
rerr := &Error{Message: chunk.Err.Error()}
|
||||||
|
var se cliproxyexecutor.StatusError
|
||||||
|
if errors.As(chunk.Err, &se) && se != nil {
|
||||||
|
rerr.HTTPStatus = se.StatusCode()
|
||||||
|
}
|
||||||
|
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
||||||
|
}
|
||||||
|
out <- chunk
|
||||||
|
}
|
||||||
|
if !failed {
|
||||||
|
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
||||||
|
}
|
||||||
|
}(execCtx, auth.Clone(), provider, chunks)
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
if provider == "" {
|
if provider == "" {
|
||||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||||
@@ -1191,6 +1343,77 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
|||||||
return authCopy, executor, nil
|
return authCopy, executor, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||||
|
providerSet := make(map[string]struct{}, len(providers))
|
||||||
|
for _, provider := range providers {
|
||||||
|
p := strings.TrimSpace(strings.ToLower(provider))
|
||||||
|
if p == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
providerSet[p] = struct{}{}
|
||||||
|
}
|
||||||
|
if len(providerSet) == 0 {
|
||||||
|
return nil, nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
candidates := make([]*Auth, 0, len(m.auths))
|
||||||
|
modelKey := strings.TrimSpace(model)
|
||||||
|
registryRef := registry.GetGlobalRegistry()
|
||||||
|
for _, candidate := range m.auths {
|
||||||
|
if candidate == nil || candidate.Disabled {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
|
||||||
|
if providerKey == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := providerSet[providerKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, used := tried[candidate.ID]; used {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := m.executors[providerKey]; !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
candidates = append(candidates, candidate)
|
||||||
|
}
|
||||||
|
if len(candidates) == 0 {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||||
|
}
|
||||||
|
selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates)
|
||||||
|
if errPick != nil {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return nil, nil, "", errPick
|
||||||
|
}
|
||||||
|
if selected == nil {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
||||||
|
}
|
||||||
|
providerKey := strings.TrimSpace(strings.ToLower(selected.Provider))
|
||||||
|
executor, okExecutor := m.executors[providerKey]
|
||||||
|
if !okExecutor {
|
||||||
|
m.mu.RUnlock()
|
||||||
|
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||||
|
}
|
||||||
|
authCopy := selected.Clone()
|
||||||
|
m.mu.RUnlock()
|
||||||
|
if !selected.indexAssigned {
|
||||||
|
m.mu.Lock()
|
||||||
|
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
||||||
|
current.EnsureIndex()
|
||||||
|
authCopy = current.Clone()
|
||||||
|
}
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
return authCopy, executor, providerKey, nil
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
||||||
if m.store == nil || auth == nil {
|
if m.store == nil || auth == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"net/http"
|
"net/http"
|
||||||
"sort"
|
"sort"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -103,13 +104,29 @@ func (e *modelCooldownError) Headers() http.Header {
|
|||||||
return headers
|
return headers
|
||||||
}
|
}
|
||||||
|
|
||||||
func collectAvailable(auths []*Auth, model string, now time.Time) (available []*Auth, cooldownCount int, earliest time.Time) {
|
func authPriority(auth *Auth) int {
|
||||||
available = make([]*Auth, 0, len(auths))
|
if auth == nil || auth.Attributes == nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
raw := strings.TrimSpace(auth.Attributes["priority"])
|
||||||
|
if raw == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
parsed, err := strconv.Atoi(raw)
|
||||||
|
if err != nil {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
||||||
|
available = make(map[int][]*Auth)
|
||||||
for i := 0; i < len(auths); i++ {
|
for i := 0; i < len(auths); i++ {
|
||||||
candidate := auths[i]
|
candidate := auths[i]
|
||||||
blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
|
blocked, reason, next := isAuthBlockedForModel(candidate, model, now)
|
||||||
if !blocked {
|
if !blocked {
|
||||||
available = append(available, candidate)
|
priority := authPriority(candidate)
|
||||||
|
available[priority] = append(available[priority], candidate)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if reason == blockReasonCooldown {
|
if reason == blockReasonCooldown {
|
||||||
@@ -119,9 +136,6 @@ func collectAvailable(auths []*Auth, model string, now time.Time) (available []*
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if len(available) > 1 {
|
|
||||||
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
|
|
||||||
}
|
|
||||||
return available, cooldownCount, earliest
|
return available, cooldownCount, earliest
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,18 +144,35 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
|||||||
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
|
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
|
||||||
}
|
}
|
||||||
|
|
||||||
available, cooldownCount, earliest := collectAvailable(auths, model, now)
|
availableByPriority, cooldownCount, earliest := collectAvailableByPriority(auths, model, now)
|
||||||
if len(available) == 0 {
|
if len(availableByPriority) == 0 {
|
||||||
if cooldownCount == len(auths) && !earliest.IsZero() {
|
if cooldownCount == len(auths) && !earliest.IsZero() {
|
||||||
|
providerForError := provider
|
||||||
|
if providerForError == "mixed" {
|
||||||
|
providerForError = ""
|
||||||
|
}
|
||||||
resetIn := earliest.Sub(now)
|
resetIn := earliest.Sub(now)
|
||||||
if resetIn < 0 {
|
if resetIn < 0 {
|
||||||
resetIn = 0
|
resetIn = 0
|
||||||
}
|
}
|
||||||
return nil, newModelCooldownError(model, provider, resetIn)
|
return nil, newModelCooldownError(model, providerForError, resetIn)
|
||||||
}
|
}
|
||||||
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
|
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
bestPriority := 0
|
||||||
|
found := false
|
||||||
|
for priority := range availableByPriority {
|
||||||
|
if !found || priority > bestPriority {
|
||||||
|
bestPriority = priority
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
available := availableByPriority[bestPriority]
|
||||||
|
if len(available) > 1 {
|
||||||
|
sort.Slice(available, func(i, j int) bool { return available[i].ID < available[j].ID })
|
||||||
|
}
|
||||||
return available, nil
|
return available, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
)
|
)
|
||||||
@@ -56,6 +57,69 @@ func TestRoundRobinSelectorPick_CyclesDeterministic(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRoundRobinSelectorPick_PriorityBuckets(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &RoundRobinSelector{}
|
||||||
|
auths := []*Auth{
|
||||||
|
{ID: "c", Attributes: map[string]string{"priority": "0"}},
|
||||||
|
{ID: "a", Attributes: map[string]string{"priority": "10"}},
|
||||||
|
{ID: "b", Attributes: map[string]string{"priority": "10"}},
|
||||||
|
}
|
||||||
|
|
||||||
|
want := []string{"a", "b", "a", "b"}
|
||||||
|
for i, id := range want {
|
||||||
|
got, err := selector.Pick(context.Background(), "mixed", "", cliproxyexecutor.Options{}, auths)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("Pick() #%d auth = nil", i)
|
||||||
|
}
|
||||||
|
if got.ID != id {
|
||||||
|
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, id)
|
||||||
|
}
|
||||||
|
if got.ID == "c" {
|
||||||
|
t.Fatalf("Pick() #%d unexpectedly selected lower priority auth", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestFillFirstSelectorPick_PriorityFallbackCooldown(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
selector := &FillFirstSelector{}
|
||||||
|
now := time.Now()
|
||||||
|
model := "test-model"
|
||||||
|
|
||||||
|
high := &Auth{
|
||||||
|
ID: "high",
|
||||||
|
Attributes: map[string]string{"priority": "10"},
|
||||||
|
ModelStates: map[string]*ModelState{
|
||||||
|
model: {
|
||||||
|
Status: StatusActive,
|
||||||
|
Unavailable: true,
|
||||||
|
NextRetryAfter: now.Add(30 * time.Minute),
|
||||||
|
Quota: QuotaState{
|
||||||
|
Exceeded: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
low := &Auth{ID: "low", Attributes: map[string]string{"priority": "0"}}
|
||||||
|
|
||||||
|
got, err := selector.Pick(context.Background(), "mixed", model, cliproxyexecutor.Options{}, []*Auth{high, low})
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Pick() error = %v", err)
|
||||||
|
}
|
||||||
|
if got == nil {
|
||||||
|
t.Fatalf("Pick() auth = nil")
|
||||||
|
}
|
||||||
|
if got.ID != "low" {
|
||||||
|
t.Fatalf("Pick() auth.ID = %q, want %q", got.ID, "low")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
|
func TestRoundRobinSelectorPick_Concurrent(t *testing.T) {
|
||||||
selector := &RoundRobinSelector{}
|
selector := &RoundRobinSelector{}
|
||||||
auths := []*Auth{
|
auths := []*Auth{
|
||||||
|
|||||||
Reference in New Issue
Block a user