mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 13:00:52 +08:00
Introduce `WithSkipPersist` to disable persistence during Manager Update/Register calls, preventing write-back loops caused by redundant file writes. Add corresponding tests and integrate with existing file store and conductor logic.
2210 lines
61 KiB
Go
2210 lines
61 KiB
Go
package auth
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"errors"
|
|
"io"
|
|
"net/http"
|
|
"path/filepath"
|
|
"strconv"
|
|
"strings"
|
|
"sync"
|
|
"sync/atomic"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
internalconfig "github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
log "github.com/sirupsen/logrus"
|
|
)
|
|
|
|
// ProviderExecutor defines the contract required by Manager to execute provider calls.
// Implementations are handed to Manager.RegisterExecutor, which files them under the
// value returned by Identifier.
type ProviderExecutor interface {
	// Identifier returns the provider key handled by this executor.
	Identifier() string
	// Execute handles non-streaming execution and returns the provider response payload.
	Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
	// ExecuteStream handles streaming execution and returns a channel of provider chunks.
	// A chunk whose Err is non-nil is treated by Manager as a failure of the whole stream.
	ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error)
	// Refresh attempts to refresh provider credentials and returns the updated auth state.
	Refresh(ctx context.Context, auth *Auth) (*Auth, error)
	// CountTokens returns the token count for the given request.
	CountTokens(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
	// HttpRequest injects provider credentials into the supplied HTTP request and executes it.
	// Callers must close the response body when non-nil.
	HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
}
|
|
|
|
// RefreshEvaluator allows runtime state to override refresh decisions.
// ShouldRefresh reports whether auth should be refreshed at the instant now.
type RefreshEvaluator interface {
	ShouldRefresh(now time.Time, auth *Auth) bool
}
|
|
|
|
// Timing knobs for credential refresh scheduling and quota backoff.
// NOTE(review): their consumers are outside this part of the file; the
// per-constant descriptions below are inferred from the names — confirm.
const (
	// presumably the polling period of the auto-refresh loop
	refreshCheckInterval = 5 * time.Second
	// presumably the delay before re-checking a refresh that is still pending
	refreshPendingBackoff = time.Minute
	// presumably the delay applied after a failed refresh attempt
	refreshFailureBackoff = 5 * time.Minute
	// presumably the starting and maximum values of the quota cooldown backoff
	quotaBackoffBase = time.Second
	quotaBackoffMax  = 30 * time.Minute
)
|
|
|
|
// quotaCooldownDisabled is the process-wide switch written by SetQuotaCooldownDisabled.
// It is an atomic.Bool so it can be toggled without taking any Manager lock.
var quotaCooldownDisabled atomic.Bool
|
|
|
|
// SetQuotaCooldownDisabled toggles quota cooldown scheduling globally.
|
|
func SetQuotaCooldownDisabled(disable bool) {
|
|
quotaCooldownDisabled.Store(disable)
|
|
}
|
|
|
|
// Result captures execution outcome used to adjust auth state.
// Manager builds one per executor attempt and hands it to MarkResult.
type Result struct {
	// AuthID references the auth that produced this result.
	AuthID string
	// Provider is copied for convenience when emitting hooks.
	Provider string
	// Model is the upstream model identifier used for the request.
	// NOTE(review): the execute paths in this file set it to the route model
	// (the caller's requested model before alias rewriting) — confirm intent.
	Model string
	// Success marks whether the execution succeeded.
	Success bool
	// RetryAfter carries a provider supplied retry hint (e.g. 429 retryDelay).
	RetryAfter *time.Duration
	// Error describes the failure when Success is false.
	Error *Error
}
|
|
|
|
// Selector chooses an auth candidate for execution.
// Pick returns one entry from auths for the given provider and model, or an error
// when none is suitable. NewManager and SetSelector default to RoundRobinSelector.
type Selector interface {
	Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error)
}
|
|
|
|
// Hook captures lifecycle callbacks for observing auth changes.
// Callbacks receive clones, so mutating the argument does not affect Manager state.
type Hook interface {
	// OnAuthRegistered fires when a new auth is registered (called from Manager.Register).
	OnAuthRegistered(ctx context.Context, auth *Auth)
	// OnAuthUpdated fires when an existing auth changes state (called from Manager.Update).
	OnAuthUpdated(ctx context.Context, auth *Auth)
	// OnResult fires when execution result is recorded.
	OnResult(ctx context.Context, result Result)
}
|
|
|
|
// NoopHook provides optional hook defaults.
// It satisfies Hook with empty methods and is what NewManager installs when no
// hook is supplied; embed it to override only the callbacks you need.
type NoopHook struct{}

// OnAuthRegistered implements Hook and does nothing.
func (NoopHook) OnAuthRegistered(context.Context, *Auth) {}

// OnAuthUpdated implements Hook and does nothing.
func (NoopHook) OnAuthUpdated(context.Context, *Auth) {}

// OnResult implements Hook and does nothing.
func (NoopHook) OnResult(context.Context, Result) {}
|
|
|
|
// Manager orchestrates auth lifecycle, selection, execution, and persistence.
// Construct it with NewManager; the zero value has nil maps and is not usable.
type Manager struct {
	// store is the persistence backend read by Load; it can be swapped via SetStore.
	store Store
	// executors maps a provider key to its executor (see RegisterExecutor).
	executors map[string]ProviderExecutor
	// selector picks an auth candidate; RoundRobinSelector when none is supplied.
	selector Selector
	// hook receives lifecycle callbacks; NoopHook when none is supplied.
	hook Hook
	// mu guards the maps and swappable dependencies of this struct; the
	// setters in this file take it before writing.
	mu sync.RWMutex
	// auths holds manager-owned clones of registered auths, keyed by Auth.ID.
	auths map[string]*Auth
	// providerOffsets tracks per-model provider rotation state for multi-provider routing.
	providerOffsets map[string]int

	// Retry controls request retry behavior.
	requestRetry     atomic.Int32
	maxRetryInterval atomic.Int64

	// oauthModelAlias stores global OAuth model alias mappings (alias -> upstream name) keyed by channel.
	oauthModelAlias atomic.Value

	// apiKeyModelAlias caches resolved model alias mappings for API-key auths.
	// Keyed by auth.ID, value is alias(lower) -> upstream model (including suffix).
	apiKeyModelAlias atomic.Value

	// runtimeConfig stores the latest application config for request-time decisions.
	// It is initialized in NewManager; never Load() before first Store().
	runtimeConfig atomic.Value

	// Optional HTTP RoundTripper provider injected by host.
	rtProvider RoundTripperProvider

	// Auto refresh state
	refreshCancel context.CancelFunc
}
|
|
|
|
// NewManager constructs a manager with optional custom selector and hook.
|
|
func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
|
if selector == nil {
|
|
selector = &RoundRobinSelector{}
|
|
}
|
|
if hook == nil {
|
|
hook = NoopHook{}
|
|
}
|
|
manager := &Manager{
|
|
store: store,
|
|
executors: make(map[string]ProviderExecutor),
|
|
selector: selector,
|
|
hook: hook,
|
|
auths: make(map[string]*Auth),
|
|
providerOffsets: make(map[string]int),
|
|
}
|
|
// atomic.Value requires non-nil initial value.
|
|
manager.runtimeConfig.Store(&internalconfig.Config{})
|
|
manager.apiKeyModelAlias.Store(apiKeyModelAliasTable(nil))
|
|
return manager
|
|
}
|
|
|
|
func (m *Manager) SetSelector(selector Selector) {
|
|
if m == nil {
|
|
return
|
|
}
|
|
if selector == nil {
|
|
selector = &RoundRobinSelector{}
|
|
}
|
|
m.mu.Lock()
|
|
m.selector = selector
|
|
m.mu.Unlock()
|
|
}
|
|
|
|
// SetStore swaps the underlying persistence store.
|
|
func (m *Manager) SetStore(store Store) {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.store = store
|
|
}
|
|
|
|
// SetRoundTripperProvider register a provider that returns a per-auth RoundTripper.
|
|
func (m *Manager) SetRoundTripperProvider(p RoundTripperProvider) {
|
|
m.mu.Lock()
|
|
m.rtProvider = p
|
|
m.mu.Unlock()
|
|
}
|
|
|
|
// SetConfig updates the runtime config snapshot used by request-time helpers.
|
|
// Callers should provide the latest config on reload so per-credential alias mapping stays in sync.
|
|
func (m *Manager) SetConfig(cfg *internalconfig.Config) {
|
|
if m == nil {
|
|
return
|
|
}
|
|
if cfg == nil {
|
|
cfg = &internalconfig.Config{}
|
|
}
|
|
m.runtimeConfig.Store(cfg)
|
|
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
|
}
|
|
|
|
func (m *Manager) lookupAPIKeyUpstreamModel(authID, requestedModel string) string {
|
|
if m == nil {
|
|
return ""
|
|
}
|
|
authID = strings.TrimSpace(authID)
|
|
if authID == "" {
|
|
return ""
|
|
}
|
|
requestedModel = strings.TrimSpace(requestedModel)
|
|
if requestedModel == "" {
|
|
return ""
|
|
}
|
|
table, _ := m.apiKeyModelAlias.Load().(apiKeyModelAliasTable)
|
|
if table == nil {
|
|
return ""
|
|
}
|
|
byAlias := table[authID]
|
|
if len(byAlias) == 0 {
|
|
return ""
|
|
}
|
|
key := strings.ToLower(thinking.ParseSuffix(requestedModel).ModelName)
|
|
if key == "" {
|
|
key = strings.ToLower(requestedModel)
|
|
}
|
|
resolved := strings.TrimSpace(byAlias[key])
|
|
if resolved == "" {
|
|
return ""
|
|
}
|
|
// Preserve thinking suffix from the client's requested model unless config already has one.
|
|
requestResult := thinking.ParseSuffix(requestedModel)
|
|
if thinking.ParseSuffix(resolved).HasSuffix {
|
|
return resolved
|
|
}
|
|
if requestResult.HasSuffix && requestResult.RawSuffix != "" {
|
|
return resolved + "(" + requestResult.RawSuffix + ")"
|
|
}
|
|
return resolved
|
|
|
|
}
|
|
|
|
func (m *Manager) rebuildAPIKeyModelAliasFromRuntimeConfig() {
|
|
if m == nil {
|
|
return
|
|
}
|
|
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
|
|
if cfg == nil {
|
|
cfg = &internalconfig.Config{}
|
|
}
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.rebuildAPIKeyModelAliasLocked(cfg)
|
|
}
|
|
|
|
func (m *Manager) rebuildAPIKeyModelAliasLocked(cfg *internalconfig.Config) {
|
|
if m == nil {
|
|
return
|
|
}
|
|
if cfg == nil {
|
|
cfg = &internalconfig.Config{}
|
|
}
|
|
|
|
out := make(apiKeyModelAliasTable)
|
|
for _, auth := range m.auths {
|
|
if auth == nil {
|
|
continue
|
|
}
|
|
if strings.TrimSpace(auth.ID) == "" {
|
|
continue
|
|
}
|
|
kind, _ := auth.AccountInfo()
|
|
if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
|
|
continue
|
|
}
|
|
|
|
byAlias := make(map[string]string)
|
|
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
|
switch provider {
|
|
case "gemini":
|
|
if entry := resolveGeminiAPIKeyConfig(cfg, auth); entry != nil {
|
|
compileAPIKeyModelAliasForModels(byAlias, entry.Models)
|
|
}
|
|
case "claude":
|
|
if entry := resolveClaudeAPIKeyConfig(cfg, auth); entry != nil {
|
|
compileAPIKeyModelAliasForModels(byAlias, entry.Models)
|
|
}
|
|
case "codex":
|
|
if entry := resolveCodexAPIKeyConfig(cfg, auth); entry != nil {
|
|
compileAPIKeyModelAliasForModels(byAlias, entry.Models)
|
|
}
|
|
case "vertex":
|
|
if entry := resolveVertexAPIKeyConfig(cfg, auth); entry != nil {
|
|
compileAPIKeyModelAliasForModels(byAlias, entry.Models)
|
|
}
|
|
default:
|
|
// OpenAI-compat uses config selection from auth.Attributes.
|
|
providerKey := ""
|
|
compatName := ""
|
|
if auth.Attributes != nil {
|
|
providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
|
|
compatName = strings.TrimSpace(auth.Attributes["compat_name"])
|
|
}
|
|
if compatName != "" || strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
|
|
if entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider); entry != nil {
|
|
compileAPIKeyModelAliasForModels(byAlias, entry.Models)
|
|
}
|
|
}
|
|
}
|
|
|
|
if len(byAlias) > 0 {
|
|
out[auth.ID] = byAlias
|
|
}
|
|
}
|
|
|
|
m.apiKeyModelAlias.Store(out)
|
|
}
|
|
|
|
func compileAPIKeyModelAliasForModels[T interface {
|
|
GetName() string
|
|
GetAlias() string
|
|
}](out map[string]string, models []T) {
|
|
if out == nil {
|
|
return
|
|
}
|
|
for i := range models {
|
|
alias := strings.TrimSpace(models[i].GetAlias())
|
|
name := strings.TrimSpace(models[i].GetName())
|
|
if alias == "" || name == "" {
|
|
continue
|
|
}
|
|
aliasKey := strings.ToLower(thinking.ParseSuffix(alias).ModelName)
|
|
if aliasKey == "" {
|
|
aliasKey = strings.ToLower(alias)
|
|
}
|
|
// Config priority: first alias wins.
|
|
if _, exists := out[aliasKey]; exists {
|
|
continue
|
|
}
|
|
out[aliasKey] = name
|
|
// Also allow direct lookup by upstream name (case-insensitive), so lookups on already-upstream
|
|
// models remain a cheap no-op.
|
|
nameKey := strings.ToLower(thinking.ParseSuffix(name).ModelName)
|
|
if nameKey == "" {
|
|
nameKey = strings.ToLower(name)
|
|
}
|
|
if nameKey != "" {
|
|
if _, exists := out[nameKey]; !exists {
|
|
out[nameKey] = name
|
|
}
|
|
}
|
|
// Preserve config suffix priority by seeding a base-name lookup when name already has suffix.
|
|
nameResult := thinking.ParseSuffix(name)
|
|
if nameResult.HasSuffix {
|
|
baseKey := strings.ToLower(strings.TrimSpace(nameResult.ModelName))
|
|
if baseKey != "" {
|
|
if _, exists := out[baseKey]; !exists {
|
|
out[baseKey] = name
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// SetRetryConfig updates retry attempts and cooldown wait interval.
|
|
func (m *Manager) SetRetryConfig(retry int, maxRetryInterval time.Duration) {
|
|
if m == nil {
|
|
return
|
|
}
|
|
if retry < 0 {
|
|
retry = 0
|
|
}
|
|
if maxRetryInterval < 0 {
|
|
maxRetryInterval = 0
|
|
}
|
|
m.requestRetry.Store(int32(retry))
|
|
m.maxRetryInterval.Store(maxRetryInterval.Nanoseconds())
|
|
}
|
|
|
|
// RegisterExecutor registers a provider executor with the manager.
|
|
func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
|
|
if executor == nil {
|
|
return
|
|
}
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
m.executors[executor.Identifier()] = executor
|
|
}
|
|
|
|
// UnregisterExecutor removes the executor associated with the provider key.
|
|
func (m *Manager) UnregisterExecutor(provider string) {
|
|
provider = strings.ToLower(strings.TrimSpace(provider))
|
|
if provider == "" {
|
|
return
|
|
}
|
|
m.mu.Lock()
|
|
delete(m.executors, provider)
|
|
m.mu.Unlock()
|
|
}
|
|
|
|
// Register inserts a new auth entry into the manager.
|
|
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
|
if auth == nil {
|
|
return nil, nil
|
|
}
|
|
if auth.ID == "" {
|
|
auth.ID = uuid.NewString()
|
|
}
|
|
auth.EnsureIndex()
|
|
m.mu.Lock()
|
|
m.auths[auth.ID] = auth.Clone()
|
|
m.mu.Unlock()
|
|
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
|
_ = m.persist(ctx, auth)
|
|
m.hook.OnAuthRegistered(ctx, auth.Clone())
|
|
return auth.Clone(), nil
|
|
}
|
|
|
|
// Update replaces an existing auth entry and notifies hooks.
|
|
func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
|
if auth == nil || auth.ID == "" {
|
|
return nil, nil
|
|
}
|
|
m.mu.Lock()
|
|
if existing, ok := m.auths[auth.ID]; ok && existing != nil && !auth.indexAssigned && auth.Index == "" {
|
|
auth.Index = existing.Index
|
|
auth.indexAssigned = existing.indexAssigned
|
|
}
|
|
auth.EnsureIndex()
|
|
m.auths[auth.ID] = auth.Clone()
|
|
m.mu.Unlock()
|
|
m.rebuildAPIKeyModelAliasFromRuntimeConfig()
|
|
_ = m.persist(ctx, auth)
|
|
m.hook.OnAuthUpdated(ctx, auth.Clone())
|
|
return auth.Clone(), nil
|
|
}
|
|
|
|
// Load resets manager state from the backing store.
|
|
func (m *Manager) Load(ctx context.Context) error {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
if m.store == nil {
|
|
return nil
|
|
}
|
|
items, err := m.store.List(ctx)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
m.auths = make(map[string]*Auth, len(items))
|
|
for _, auth := range items {
|
|
if auth == nil || auth.ID == "" {
|
|
continue
|
|
}
|
|
auth.EnsureIndex()
|
|
m.auths[auth.ID] = auth.Clone()
|
|
}
|
|
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
|
|
if cfg == nil {
|
|
cfg = &internalconfig.Config{}
|
|
}
|
|
m.rebuildAPIKeyModelAliasLocked(cfg)
|
|
return nil
|
|
}
|
|
|
|
// Execute performs a non-streaming execution using the configured selector and executor.
|
|
// It supports multiple providers for the same model and round-robins the starting provider per model.
|
|
func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
|
normalized := m.normalizeProviders(providers)
|
|
if len(normalized) == 0 {
|
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
|
}
|
|
|
|
retryTimes, maxWait := m.retrySettings()
|
|
attempts := retryTimes + 1
|
|
if attempts < 1 {
|
|
attempts = 1
|
|
}
|
|
|
|
var lastErr error
|
|
for attempt := 0; attempt < attempts; attempt++ {
|
|
resp, errExec := m.executeMixedOnce(ctx, normalized, req, opts)
|
|
if errExec == nil {
|
|
return resp, nil
|
|
}
|
|
lastErr = errExec
|
|
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
|
if !shouldRetry {
|
|
break
|
|
}
|
|
if errWait := waitForCooldown(ctx, wait); errWait != nil {
|
|
return cliproxyexecutor.Response{}, errWait
|
|
}
|
|
}
|
|
if lastErr != nil {
|
|
return cliproxyexecutor.Response{}, lastErr
|
|
}
|
|
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
|
|
}
|
|
|
|
// ExecuteCount performs a non-streaming execution using the configured selector and executor.
|
|
// It supports multiple providers for the same model and round-robins the starting provider per model.
|
|
func (m *Manager) ExecuteCount(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
|
normalized := m.normalizeProviders(providers)
|
|
if len(normalized) == 0 {
|
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
|
}
|
|
|
|
retryTimes, maxWait := m.retrySettings()
|
|
attempts := retryTimes + 1
|
|
if attempts < 1 {
|
|
attempts = 1
|
|
}
|
|
|
|
var lastErr error
|
|
for attempt := 0; attempt < attempts; attempt++ {
|
|
resp, errExec := m.executeCountMixedOnce(ctx, normalized, req, opts)
|
|
if errExec == nil {
|
|
return resp, nil
|
|
}
|
|
lastErr = errExec
|
|
wait, shouldRetry := m.shouldRetryAfterError(errExec, attempt, attempts, normalized, req.Model, maxWait)
|
|
if !shouldRetry {
|
|
break
|
|
}
|
|
if errWait := waitForCooldown(ctx, wait); errWait != nil {
|
|
return cliproxyexecutor.Response{}, errWait
|
|
}
|
|
}
|
|
if lastErr != nil {
|
|
return cliproxyexecutor.Response{}, lastErr
|
|
}
|
|
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
|
|
}
|
|
|
|
// ExecuteStream performs a streaming execution using the configured selector and executor.
|
|
// It supports multiple providers for the same model and round-robins the starting provider per model.
|
|
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
|
normalized := m.normalizeProviders(providers)
|
|
if len(normalized) == 0 {
|
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
|
}
|
|
|
|
retryTimes, maxWait := m.retrySettings()
|
|
attempts := retryTimes + 1
|
|
if attempts < 1 {
|
|
attempts = 1
|
|
}
|
|
|
|
var lastErr error
|
|
for attempt := 0; attempt < attempts; attempt++ {
|
|
chunks, errStream := m.executeStreamMixedOnce(ctx, normalized, req, opts)
|
|
if errStream == nil {
|
|
return chunks, nil
|
|
}
|
|
lastErr = errStream
|
|
wait, shouldRetry := m.shouldRetryAfterError(errStream, attempt, attempts, normalized, req.Model, maxWait)
|
|
if !shouldRetry {
|
|
break
|
|
}
|
|
if errWait := waitForCooldown(ctx, wait); errWait != nil {
|
|
return nil, errWait
|
|
}
|
|
}
|
|
if lastErr != nil {
|
|
return nil, lastErr
|
|
}
|
|
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
|
}
|
|
|
|
func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
|
if len(providers) == 0 {
|
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
|
}
|
|
routeModel := req.Model
|
|
opts = ensureRequestedModelMetadata(opts, routeModel)
|
|
tried := make(map[string]struct{})
|
|
var lastErr error
|
|
for {
|
|
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
|
if errPick != nil {
|
|
if lastErr != nil {
|
|
return cliproxyexecutor.Response{}, lastErr
|
|
}
|
|
return cliproxyexecutor.Response{}, errPick
|
|
}
|
|
|
|
entry := logEntryWithRequestID(ctx)
|
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
|
|
|
tried[auth.ID] = struct{}{}
|
|
execCtx := ctx
|
|
if rt := m.roundTripperFor(auth); rt != nil {
|
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
|
}
|
|
execReq := req
|
|
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
|
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
|
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
|
resp, errExec := executor.Execute(execCtx, auth, execReq, opts)
|
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
|
if errExec != nil {
|
|
if errCtx := execCtx.Err(); errCtx != nil {
|
|
return cliproxyexecutor.Response{}, errCtx
|
|
}
|
|
result.Error = &Error{Message: errExec.Error()}
|
|
var se cliproxyexecutor.StatusError
|
|
if errors.As(errExec, &se) && se != nil {
|
|
result.Error.HTTPStatus = se.StatusCode()
|
|
}
|
|
if ra := retryAfterFromError(errExec); ra != nil {
|
|
result.RetryAfter = ra
|
|
}
|
|
m.MarkResult(execCtx, result)
|
|
lastErr = errExec
|
|
continue
|
|
}
|
|
m.MarkResult(execCtx, result)
|
|
return resp, nil
|
|
}
|
|
}
|
|
|
|
func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
|
if len(providers) == 0 {
|
|
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
|
}
|
|
routeModel := req.Model
|
|
opts = ensureRequestedModelMetadata(opts, routeModel)
|
|
tried := make(map[string]struct{})
|
|
var lastErr error
|
|
for {
|
|
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
|
if errPick != nil {
|
|
if lastErr != nil {
|
|
return cliproxyexecutor.Response{}, lastErr
|
|
}
|
|
return cliproxyexecutor.Response{}, errPick
|
|
}
|
|
|
|
entry := logEntryWithRequestID(ctx)
|
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
|
|
|
tried[auth.ID] = struct{}{}
|
|
execCtx := ctx
|
|
if rt := m.roundTripperFor(auth); rt != nil {
|
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
|
}
|
|
execReq := req
|
|
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
|
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
|
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
|
resp, errExec := executor.CountTokens(execCtx, auth, execReq, opts)
|
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: errExec == nil}
|
|
if errExec != nil {
|
|
if errCtx := execCtx.Err(); errCtx != nil {
|
|
return cliproxyexecutor.Response{}, errCtx
|
|
}
|
|
result.Error = &Error{Message: errExec.Error()}
|
|
var se cliproxyexecutor.StatusError
|
|
if errors.As(errExec, &se) && se != nil {
|
|
result.Error.HTTPStatus = se.StatusCode()
|
|
}
|
|
if ra := retryAfterFromError(errExec); ra != nil {
|
|
result.RetryAfter = ra
|
|
}
|
|
m.MarkResult(execCtx, result)
|
|
lastErr = errExec
|
|
continue
|
|
}
|
|
m.MarkResult(execCtx, result)
|
|
return resp, nil
|
|
}
|
|
}
|
|
|
|
func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
|
if len(providers) == 0 {
|
|
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
|
}
|
|
routeModel := req.Model
|
|
opts = ensureRequestedModelMetadata(opts, routeModel)
|
|
tried := make(map[string]struct{})
|
|
var lastErr error
|
|
for {
|
|
auth, executor, provider, errPick := m.pickNextMixed(ctx, providers, routeModel, opts, tried)
|
|
if errPick != nil {
|
|
if lastErr != nil {
|
|
return nil, lastErr
|
|
}
|
|
return nil, errPick
|
|
}
|
|
|
|
entry := logEntryWithRequestID(ctx)
|
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
|
|
|
tried[auth.ID] = struct{}{}
|
|
execCtx := ctx
|
|
if rt := m.roundTripperFor(auth); rt != nil {
|
|
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
|
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
|
}
|
|
execReq := req
|
|
execReq.Model = rewriteModelForAuth(routeModel, auth)
|
|
execReq.Model = m.applyOAuthModelAlias(auth, execReq.Model)
|
|
execReq.Model = m.applyAPIKeyModelAlias(auth, execReq.Model)
|
|
chunks, errStream := executor.ExecuteStream(execCtx, auth, execReq, opts)
|
|
if errStream != nil {
|
|
if errCtx := execCtx.Err(); errCtx != nil {
|
|
return nil, errCtx
|
|
}
|
|
rerr := &Error{Message: errStream.Error()}
|
|
var se cliproxyexecutor.StatusError
|
|
if errors.As(errStream, &se) && se != nil {
|
|
rerr.HTTPStatus = se.StatusCode()
|
|
}
|
|
result := Result{AuthID: auth.ID, Provider: provider, Model: routeModel, Success: false, Error: rerr}
|
|
result.RetryAfter = retryAfterFromError(errStream)
|
|
m.MarkResult(execCtx, result)
|
|
lastErr = errStream
|
|
continue
|
|
}
|
|
out := make(chan cliproxyexecutor.StreamChunk)
|
|
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
|
defer close(out)
|
|
var failed bool
|
|
for chunk := range streamChunks {
|
|
if chunk.Err != nil && !failed {
|
|
failed = true
|
|
rerr := &Error{Message: chunk.Err.Error()}
|
|
var se cliproxyexecutor.StatusError
|
|
if errors.As(chunk.Err, &se) && se != nil {
|
|
rerr.HTTPStatus = se.StatusCode()
|
|
}
|
|
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: false, Error: rerr})
|
|
}
|
|
out <- chunk
|
|
}
|
|
if !failed {
|
|
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: routeModel, Success: true})
|
|
}
|
|
}(execCtx, auth.Clone(), provider, chunks)
|
|
return out, nil
|
|
}
|
|
}
|
|
|
|
func ensureRequestedModelMetadata(opts cliproxyexecutor.Options, requestedModel string) cliproxyexecutor.Options {
|
|
requestedModel = strings.TrimSpace(requestedModel)
|
|
if requestedModel == "" {
|
|
return opts
|
|
}
|
|
if hasRequestedModelMetadata(opts.Metadata) {
|
|
return opts
|
|
}
|
|
if len(opts.Metadata) == 0 {
|
|
opts.Metadata = map[string]any{cliproxyexecutor.RequestedModelMetadataKey: requestedModel}
|
|
return opts
|
|
}
|
|
meta := make(map[string]any, len(opts.Metadata)+1)
|
|
for k, v := range opts.Metadata {
|
|
meta[k] = v
|
|
}
|
|
meta[cliproxyexecutor.RequestedModelMetadataKey] = requestedModel
|
|
opts.Metadata = meta
|
|
return opts
|
|
}
|
|
|
|
func hasRequestedModelMetadata(meta map[string]any) bool {
|
|
if len(meta) == 0 {
|
|
return false
|
|
}
|
|
raw, ok := meta[cliproxyexecutor.RequestedModelMetadataKey]
|
|
if !ok || raw == nil {
|
|
return false
|
|
}
|
|
switch v := raw.(type) {
|
|
case string:
|
|
return strings.TrimSpace(v) != ""
|
|
case []byte:
|
|
return strings.TrimSpace(string(v)) != ""
|
|
default:
|
|
return false
|
|
}
|
|
}
|
|
|
|
func rewriteModelForAuth(model string, auth *Auth) string {
|
|
if auth == nil || model == "" {
|
|
return model
|
|
}
|
|
prefix := strings.TrimSpace(auth.Prefix)
|
|
if prefix == "" {
|
|
return model
|
|
}
|
|
needle := prefix + "/"
|
|
if !strings.HasPrefix(model, needle) {
|
|
return model
|
|
}
|
|
return strings.TrimPrefix(model, needle)
|
|
}
|
|
|
|
func (m *Manager) applyAPIKeyModelAlias(auth *Auth, requestedModel string) string {
|
|
if m == nil || auth == nil {
|
|
return requestedModel
|
|
}
|
|
|
|
kind, _ := auth.AccountInfo()
|
|
if !strings.EqualFold(strings.TrimSpace(kind), "api_key") {
|
|
return requestedModel
|
|
}
|
|
|
|
requestedModel = strings.TrimSpace(requestedModel)
|
|
if requestedModel == "" {
|
|
return requestedModel
|
|
}
|
|
|
|
// Fast path: lookup per-auth mapping table (keyed by auth.ID).
|
|
if resolved := m.lookupAPIKeyUpstreamModel(auth.ID, requestedModel); resolved != "" {
|
|
return resolved
|
|
}
|
|
|
|
// Slow path: scan config for the matching credential entry and resolve alias.
|
|
// This acts as a safety net if mappings are stale or auth.ID is missing.
|
|
cfg, _ := m.runtimeConfig.Load().(*internalconfig.Config)
|
|
if cfg == nil {
|
|
cfg = &internalconfig.Config{}
|
|
}
|
|
|
|
provider := strings.ToLower(strings.TrimSpace(auth.Provider))
|
|
upstreamModel := ""
|
|
switch provider {
|
|
case "gemini":
|
|
upstreamModel = resolveUpstreamModelForGeminiAPIKey(cfg, auth, requestedModel)
|
|
case "claude":
|
|
upstreamModel = resolveUpstreamModelForClaudeAPIKey(cfg, auth, requestedModel)
|
|
case "codex":
|
|
upstreamModel = resolveUpstreamModelForCodexAPIKey(cfg, auth, requestedModel)
|
|
case "vertex":
|
|
upstreamModel = resolveUpstreamModelForVertexAPIKey(cfg, auth, requestedModel)
|
|
default:
|
|
upstreamModel = resolveUpstreamModelForOpenAICompatAPIKey(cfg, auth, requestedModel)
|
|
}
|
|
|
|
// Return upstream model if found, otherwise return requested model.
|
|
if upstreamModel != "" {
|
|
return upstreamModel
|
|
}
|
|
return requestedModel
|
|
}
|
|
|
|
// APIKeyConfigEntry is a generic interface for API key configurations.
// resolveAPIKeyConfig matches entries against an auth's "api_key" and
// "base_url" attributes through these accessors.
type APIKeyConfigEntry interface {
	GetAPIKey() string
	GetBaseURL() string
}
|
|
|
|
func resolveAPIKeyConfig[T APIKeyConfigEntry](entries []T, auth *Auth) *T {
|
|
if auth == nil || len(entries) == 0 {
|
|
return nil
|
|
}
|
|
attrKey, attrBase := "", ""
|
|
if auth.Attributes != nil {
|
|
attrKey = strings.TrimSpace(auth.Attributes["api_key"])
|
|
attrBase = strings.TrimSpace(auth.Attributes["base_url"])
|
|
}
|
|
for i := range entries {
|
|
entry := &entries[i]
|
|
cfgKey := strings.TrimSpace((*entry).GetAPIKey())
|
|
cfgBase := strings.TrimSpace((*entry).GetBaseURL())
|
|
if attrKey != "" && attrBase != "" {
|
|
if strings.EqualFold(cfgKey, attrKey) && strings.EqualFold(cfgBase, attrBase) {
|
|
return entry
|
|
}
|
|
continue
|
|
}
|
|
if attrKey != "" && strings.EqualFold(cfgKey, attrKey) {
|
|
if cfgBase == "" || strings.EqualFold(cfgBase, attrBase) {
|
|
return entry
|
|
}
|
|
}
|
|
if attrKey == "" && attrBase != "" && strings.EqualFold(cfgBase, attrBase) {
|
|
return entry
|
|
}
|
|
}
|
|
if attrKey != "" {
|
|
for i := range entries {
|
|
entry := &entries[i]
|
|
if strings.EqualFold(strings.TrimSpace((*entry).GetAPIKey()), attrKey) {
|
|
return entry
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func resolveGeminiAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.GeminiKey {
|
|
if cfg == nil {
|
|
return nil
|
|
}
|
|
return resolveAPIKeyConfig(cfg.GeminiKey, auth)
|
|
}
|
|
|
|
func resolveClaudeAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.ClaudeKey {
|
|
if cfg == nil {
|
|
return nil
|
|
}
|
|
return resolveAPIKeyConfig(cfg.ClaudeKey, auth)
|
|
}
|
|
|
|
func resolveCodexAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.CodexKey {
|
|
if cfg == nil {
|
|
return nil
|
|
}
|
|
return resolveAPIKeyConfig(cfg.CodexKey, auth)
|
|
}
|
|
|
|
func resolveVertexAPIKeyConfig(cfg *internalconfig.Config, auth *Auth) *internalconfig.VertexCompatKey {
|
|
if cfg == nil {
|
|
return nil
|
|
}
|
|
return resolveAPIKeyConfig(cfg.VertexCompatAPIKey, auth)
|
|
}
|
|
|
|
func resolveUpstreamModelForGeminiAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
|
|
entry := resolveGeminiAPIKeyConfig(cfg, auth)
|
|
if entry == nil {
|
|
return ""
|
|
}
|
|
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
|
}
|
|
|
|
func resolveUpstreamModelForClaudeAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
|
|
entry := resolveClaudeAPIKeyConfig(cfg, auth)
|
|
if entry == nil {
|
|
return ""
|
|
}
|
|
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
|
}
|
|
|
|
func resolveUpstreamModelForCodexAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
|
|
entry := resolveCodexAPIKeyConfig(cfg, auth)
|
|
if entry == nil {
|
|
return ""
|
|
}
|
|
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
|
}
|
|
|
|
func resolveUpstreamModelForVertexAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
|
|
entry := resolveVertexAPIKeyConfig(cfg, auth)
|
|
if entry == nil {
|
|
return ""
|
|
}
|
|
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
|
}
|
|
|
|
func resolveUpstreamModelForOpenAICompatAPIKey(cfg *internalconfig.Config, auth *Auth, requestedModel string) string {
|
|
providerKey := ""
|
|
compatName := ""
|
|
if auth != nil && len(auth.Attributes) > 0 {
|
|
providerKey = strings.TrimSpace(auth.Attributes["provider_key"])
|
|
compatName = strings.TrimSpace(auth.Attributes["compat_name"])
|
|
}
|
|
if compatName == "" && !strings.EqualFold(strings.TrimSpace(auth.Provider), "openai-compatibility") {
|
|
return ""
|
|
}
|
|
entry := resolveOpenAICompatConfig(cfg, providerKey, compatName, auth.Provider)
|
|
if entry == nil {
|
|
return ""
|
|
}
|
|
return resolveModelAliasFromConfigModels(requestedModel, asModelAliasEntries(entry.Models))
|
|
}
|
|
|
|
// apiKeyModelAliasTable is a two-level string lookup table.
// NOTE(review): the meaning of the outer and inner keys is not visible in this
// section. Presumably it is keyed by an API-key identity and then by model
// alias; confirm against its users.
type apiKeyModelAliasTable map[string]map[string]string
|
|
|
|
func resolveOpenAICompatConfig(cfg *internalconfig.Config, providerKey, compatName, authProvider string) *internalconfig.OpenAICompatibility {
|
|
if cfg == nil {
|
|
return nil
|
|
}
|
|
candidates := make([]string, 0, 3)
|
|
if v := strings.TrimSpace(compatName); v != "" {
|
|
candidates = append(candidates, v)
|
|
}
|
|
if v := strings.TrimSpace(providerKey); v != "" {
|
|
candidates = append(candidates, v)
|
|
}
|
|
if v := strings.TrimSpace(authProvider); v != "" {
|
|
candidates = append(candidates, v)
|
|
}
|
|
for i := range cfg.OpenAICompatibility {
|
|
compat := &cfg.OpenAICompatibility[i]
|
|
for _, candidate := range candidates {
|
|
if candidate != "" && strings.EqualFold(strings.TrimSpace(candidate), compat.Name) {
|
|
return compat
|
|
}
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func asModelAliasEntries[T interface {
|
|
GetName() string
|
|
GetAlias() string
|
|
}](models []T) []modelAliasEntry {
|
|
if len(models) == 0 {
|
|
return nil
|
|
}
|
|
out := make([]modelAliasEntry, 0, len(models))
|
|
for i := range models {
|
|
out = append(out, models[i])
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (m *Manager) normalizeProviders(providers []string) []string {
|
|
if len(providers) == 0 {
|
|
return nil
|
|
}
|
|
result := make([]string, 0, len(providers))
|
|
seen := make(map[string]struct{}, len(providers))
|
|
for _, provider := range providers {
|
|
p := strings.TrimSpace(strings.ToLower(provider))
|
|
if p == "" {
|
|
continue
|
|
}
|
|
if _, ok := seen[p]; ok {
|
|
continue
|
|
}
|
|
seen[p] = struct{}{}
|
|
result = append(result, p)
|
|
}
|
|
return result
|
|
}
|
|
|
|
func (m *Manager) retrySettings() (int, time.Duration) {
|
|
if m == nil {
|
|
return 0, 0
|
|
}
|
|
return int(m.requestRetry.Load()), time.Duration(m.maxRetryInterval.Load())
|
|
}
|
|
|
|
func (m *Manager) closestCooldownWait(providers []string, model string) (time.Duration, bool) {
|
|
if m == nil || len(providers) == 0 {
|
|
return 0, false
|
|
}
|
|
now := time.Now()
|
|
providerSet := make(map[string]struct{}, len(providers))
|
|
for i := range providers {
|
|
key := strings.TrimSpace(strings.ToLower(providers[i]))
|
|
if key == "" {
|
|
continue
|
|
}
|
|
providerSet[key] = struct{}{}
|
|
}
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
var (
|
|
found bool
|
|
minWait time.Duration
|
|
)
|
|
for _, auth := range m.auths {
|
|
if auth == nil {
|
|
continue
|
|
}
|
|
providerKey := strings.TrimSpace(strings.ToLower(auth.Provider))
|
|
if _, ok := providerSet[providerKey]; !ok {
|
|
continue
|
|
}
|
|
blocked, reason, next := isAuthBlockedForModel(auth, model, now)
|
|
if !blocked || next.IsZero() || reason == blockReasonDisabled {
|
|
continue
|
|
}
|
|
wait := next.Sub(now)
|
|
if wait < 0 {
|
|
continue
|
|
}
|
|
if !found || wait < minWait {
|
|
minWait = wait
|
|
found = true
|
|
}
|
|
}
|
|
return minWait, found
|
|
}
|
|
|
|
func (m *Manager) shouldRetryAfterError(err error, attempt, maxAttempts int, providers []string, model string, maxWait time.Duration) (time.Duration, bool) {
|
|
if err == nil || attempt >= maxAttempts-1 {
|
|
return 0, false
|
|
}
|
|
if maxWait <= 0 {
|
|
return 0, false
|
|
}
|
|
if status := statusCodeFromError(err); status == http.StatusOK {
|
|
return 0, false
|
|
}
|
|
wait, found := m.closestCooldownWait(providers, model)
|
|
if !found || wait > maxWait {
|
|
return 0, false
|
|
}
|
|
return wait, true
|
|
}
|
|
|
|
// waitForCooldown sleeps for wait, or returns ctx.Err() early if ctx is done
// first. Non-positive durations return nil immediately.
func waitForCooldown(ctx context.Context, wait time.Duration) error {
	if wait > 0 {
		t := time.NewTimer(wait)
		defer t.Stop()
		select {
		case <-t.C:
		case <-ctx.Done():
			return ctx.Err()
		}
	}
	return nil
}
|
|
|
|
// MarkResult records an execution result and notifies hooks.
|
|
func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
|
if result.AuthID == "" {
|
|
return
|
|
}
|
|
|
|
shouldResumeModel := false
|
|
shouldSuspendModel := false
|
|
suspendReason := ""
|
|
clearModelQuota := false
|
|
setModelQuota := false
|
|
|
|
m.mu.Lock()
|
|
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
|
|
now := time.Now()
|
|
|
|
if result.Success {
|
|
if result.Model != "" {
|
|
state := ensureModelState(auth, result.Model)
|
|
resetModelState(state, now)
|
|
updateAggregatedAvailability(auth, now)
|
|
if !hasModelError(auth, now) {
|
|
auth.LastError = nil
|
|
auth.StatusMessage = ""
|
|
auth.Status = StatusActive
|
|
}
|
|
auth.UpdatedAt = now
|
|
shouldResumeModel = true
|
|
clearModelQuota = true
|
|
} else {
|
|
clearAuthStateOnSuccess(auth, now)
|
|
}
|
|
} else {
|
|
if result.Model != "" {
|
|
state := ensureModelState(auth, result.Model)
|
|
state.Unavailable = true
|
|
state.Status = StatusError
|
|
state.UpdatedAt = now
|
|
if result.Error != nil {
|
|
state.LastError = cloneError(result.Error)
|
|
state.StatusMessage = result.Error.Message
|
|
auth.LastError = cloneError(result.Error)
|
|
auth.StatusMessage = result.Error.Message
|
|
}
|
|
|
|
statusCode := statusCodeFromResult(result.Error)
|
|
switch statusCode {
|
|
case 401:
|
|
next := now.Add(30 * time.Minute)
|
|
state.NextRetryAfter = next
|
|
suspendReason = "unauthorized"
|
|
shouldSuspendModel = true
|
|
case 402, 403:
|
|
next := now.Add(30 * time.Minute)
|
|
state.NextRetryAfter = next
|
|
suspendReason = "payment_required"
|
|
shouldSuspendModel = true
|
|
case 404:
|
|
next := now.Add(12 * time.Hour)
|
|
state.NextRetryAfter = next
|
|
suspendReason = "not_found"
|
|
shouldSuspendModel = true
|
|
case 429:
|
|
var next time.Time
|
|
backoffLevel := state.Quota.BackoffLevel
|
|
if result.RetryAfter != nil {
|
|
next = now.Add(*result.RetryAfter)
|
|
} else {
|
|
cooldown, nextLevel := nextQuotaCooldown(backoffLevel)
|
|
if cooldown > 0 {
|
|
next = now.Add(cooldown)
|
|
}
|
|
backoffLevel = nextLevel
|
|
}
|
|
state.NextRetryAfter = next
|
|
state.Quota = QuotaState{
|
|
Exceeded: true,
|
|
Reason: "quota",
|
|
NextRecoverAt: next,
|
|
BackoffLevel: backoffLevel,
|
|
}
|
|
suspendReason = "quota"
|
|
shouldSuspendModel = true
|
|
setModelQuota = true
|
|
case 408, 500, 502, 503, 504:
|
|
if quotaCooldownDisabled.Load() {
|
|
state.NextRetryAfter = time.Time{}
|
|
} else {
|
|
next := now.Add(1 * time.Minute)
|
|
state.NextRetryAfter = next
|
|
}
|
|
default:
|
|
state.NextRetryAfter = time.Time{}
|
|
}
|
|
|
|
auth.Status = StatusError
|
|
auth.UpdatedAt = now
|
|
updateAggregatedAvailability(auth, now)
|
|
} else {
|
|
applyAuthFailureState(auth, result.Error, result.RetryAfter, now)
|
|
}
|
|
}
|
|
|
|
_ = m.persist(ctx, auth)
|
|
}
|
|
m.mu.Unlock()
|
|
|
|
if clearModelQuota && result.Model != "" {
|
|
registry.GetGlobalRegistry().ClearModelQuotaExceeded(result.AuthID, result.Model)
|
|
}
|
|
if setModelQuota && result.Model != "" {
|
|
registry.GetGlobalRegistry().SetModelQuotaExceeded(result.AuthID, result.Model)
|
|
}
|
|
if shouldResumeModel {
|
|
registry.GetGlobalRegistry().ResumeClientModel(result.AuthID, result.Model)
|
|
} else if shouldSuspendModel {
|
|
registry.GetGlobalRegistry().SuspendClientModel(result.AuthID, result.Model, suspendReason)
|
|
}
|
|
|
|
m.hook.OnResult(ctx, result)
|
|
}
|
|
|
|
func ensureModelState(auth *Auth, model string) *ModelState {
|
|
if auth == nil || model == "" {
|
|
return nil
|
|
}
|
|
if auth.ModelStates == nil {
|
|
auth.ModelStates = make(map[string]*ModelState)
|
|
}
|
|
if state, ok := auth.ModelStates[model]; ok && state != nil {
|
|
return state
|
|
}
|
|
state := &ModelState{Status: StatusActive}
|
|
auth.ModelStates[model] = state
|
|
return state
|
|
}
|
|
|
|
func resetModelState(state *ModelState, now time.Time) {
|
|
if state == nil {
|
|
return
|
|
}
|
|
state.Unavailable = false
|
|
state.Status = StatusActive
|
|
state.StatusMessage = ""
|
|
state.NextRetryAfter = time.Time{}
|
|
state.LastError = nil
|
|
state.Quota = QuotaState{}
|
|
state.UpdatedAt = now
|
|
}
|
|
|
|
func updateAggregatedAvailability(auth *Auth, now time.Time) {
|
|
if auth == nil || len(auth.ModelStates) == 0 {
|
|
return
|
|
}
|
|
allUnavailable := true
|
|
earliestRetry := time.Time{}
|
|
quotaExceeded := false
|
|
quotaRecover := time.Time{}
|
|
maxBackoffLevel := 0
|
|
for _, state := range auth.ModelStates {
|
|
if state == nil {
|
|
continue
|
|
}
|
|
stateUnavailable := false
|
|
if state.Status == StatusDisabled {
|
|
stateUnavailable = true
|
|
} else if state.Unavailable {
|
|
if state.NextRetryAfter.IsZero() {
|
|
stateUnavailable = true
|
|
} else if state.NextRetryAfter.After(now) {
|
|
stateUnavailable = true
|
|
if earliestRetry.IsZero() || state.NextRetryAfter.Before(earliestRetry) {
|
|
earliestRetry = state.NextRetryAfter
|
|
}
|
|
} else {
|
|
state.Unavailable = false
|
|
state.NextRetryAfter = time.Time{}
|
|
}
|
|
}
|
|
if !stateUnavailable {
|
|
allUnavailable = false
|
|
}
|
|
if state.Quota.Exceeded {
|
|
quotaExceeded = true
|
|
if quotaRecover.IsZero() || (!state.Quota.NextRecoverAt.IsZero() && state.Quota.NextRecoverAt.Before(quotaRecover)) {
|
|
quotaRecover = state.Quota.NextRecoverAt
|
|
}
|
|
if state.Quota.BackoffLevel > maxBackoffLevel {
|
|
maxBackoffLevel = state.Quota.BackoffLevel
|
|
}
|
|
}
|
|
}
|
|
auth.Unavailable = allUnavailable
|
|
if allUnavailable {
|
|
auth.NextRetryAfter = earliestRetry
|
|
} else {
|
|
auth.NextRetryAfter = time.Time{}
|
|
}
|
|
if quotaExceeded {
|
|
auth.Quota.Exceeded = true
|
|
auth.Quota.Reason = "quota"
|
|
auth.Quota.NextRecoverAt = quotaRecover
|
|
auth.Quota.BackoffLevel = maxBackoffLevel
|
|
} else {
|
|
auth.Quota.Exceeded = false
|
|
auth.Quota.Reason = ""
|
|
auth.Quota.NextRecoverAt = time.Time{}
|
|
auth.Quota.BackoffLevel = 0
|
|
}
|
|
}
|
|
|
|
func hasModelError(auth *Auth, now time.Time) bool {
|
|
if auth == nil || len(auth.ModelStates) == 0 {
|
|
return false
|
|
}
|
|
for _, state := range auth.ModelStates {
|
|
if state == nil {
|
|
continue
|
|
}
|
|
if state.LastError != nil {
|
|
return true
|
|
}
|
|
if state.Status == StatusError {
|
|
if state.Unavailable && (state.NextRetryAfter.IsZero() || state.NextRetryAfter.After(now)) {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
func clearAuthStateOnSuccess(auth *Auth, now time.Time) {
|
|
if auth == nil {
|
|
return
|
|
}
|
|
auth.Unavailable = false
|
|
auth.Status = StatusActive
|
|
auth.StatusMessage = ""
|
|
auth.Quota.Exceeded = false
|
|
auth.Quota.Reason = ""
|
|
auth.Quota.NextRecoverAt = time.Time{}
|
|
auth.Quota.BackoffLevel = 0
|
|
auth.LastError = nil
|
|
auth.NextRetryAfter = time.Time{}
|
|
auth.UpdatedAt = now
|
|
}
|
|
|
|
func cloneError(err *Error) *Error {
|
|
if err == nil {
|
|
return nil
|
|
}
|
|
return &Error{
|
|
Code: err.Code,
|
|
Message: err.Message,
|
|
Retryable: err.Retryable,
|
|
HTTPStatus: err.HTTPStatus,
|
|
}
|
|
}
|
|
|
|
// statusCodeFromError extracts an HTTP status code from err. It searches the
// wrap chain (errors.As) for a value with a StatusCode() int method, and
// returns 0 when err is nil or no such value exists.
func statusCodeFromError(err error) int {
	type statusCoder interface {
		StatusCode() int
	}
	var coder statusCoder
	if err != nil && errors.As(err, &coder) && coder != nil {
		return coder.StatusCode()
	}
	return 0
}
|
|
|
|
// retryAfterFromError extracts a retry-after hint from err. It searches the
// wrap chain for a value exposing RetryAfter() *time.Duration.
//
// The hint is returned as a fresh copy, so callers cannot mutate the
// provider's value. It returns nil when err is nil, no provider is found, or
// the provider reports no hint.
func retryAfterFromError(err error) *time.Duration {
	if err == nil {
		return nil
	}
	type retryAfterProvider interface {
		RetryAfter() *time.Duration
	}
	// errors.As (as used by statusCodeFromError) keeps the hint visible when an
	// executor error has been wrapped with fmt.Errorf("...: %w", err). A plain
	// type assertion would only inspect the outermost error.
	var rap retryAfterProvider
	if !errors.As(err, &rap) || rap == nil {
		return nil
	}
	retryAfter := rap.RetryAfter()
	if retryAfter == nil {
		return nil
	}
	val := *retryAfter
	return &val
}
|
|
|
|
func statusCodeFromResult(err *Error) int {
|
|
if err == nil {
|
|
return 0
|
|
}
|
|
return err.StatusCode()
|
|
}
|
|
|
|
func applyAuthFailureState(auth *Auth, resultErr *Error, retryAfter *time.Duration, now time.Time) {
|
|
if auth == nil {
|
|
return
|
|
}
|
|
auth.Unavailable = true
|
|
auth.Status = StatusError
|
|
auth.UpdatedAt = now
|
|
if resultErr != nil {
|
|
auth.LastError = cloneError(resultErr)
|
|
if resultErr.Message != "" {
|
|
auth.StatusMessage = resultErr.Message
|
|
}
|
|
}
|
|
statusCode := statusCodeFromResult(resultErr)
|
|
switch statusCode {
|
|
case 401:
|
|
auth.StatusMessage = "unauthorized"
|
|
auth.NextRetryAfter = now.Add(30 * time.Minute)
|
|
case 402, 403:
|
|
auth.StatusMessage = "payment_required"
|
|
auth.NextRetryAfter = now.Add(30 * time.Minute)
|
|
case 404:
|
|
auth.StatusMessage = "not_found"
|
|
auth.NextRetryAfter = now.Add(12 * time.Hour)
|
|
case 429:
|
|
auth.StatusMessage = "quota exhausted"
|
|
auth.Quota.Exceeded = true
|
|
auth.Quota.Reason = "quota"
|
|
var next time.Time
|
|
if retryAfter != nil {
|
|
next = now.Add(*retryAfter)
|
|
} else {
|
|
cooldown, nextLevel := nextQuotaCooldown(auth.Quota.BackoffLevel)
|
|
if cooldown > 0 {
|
|
next = now.Add(cooldown)
|
|
}
|
|
auth.Quota.BackoffLevel = nextLevel
|
|
}
|
|
auth.Quota.NextRecoverAt = next
|
|
auth.NextRetryAfter = next
|
|
case 408, 500, 502, 503, 504:
|
|
auth.StatusMessage = "transient upstream error"
|
|
if quotaCooldownDisabled.Load() {
|
|
auth.NextRetryAfter = time.Time{}
|
|
} else {
|
|
auth.NextRetryAfter = now.Add(1 * time.Minute)
|
|
}
|
|
default:
|
|
if auth.StatusMessage == "" {
|
|
auth.StatusMessage = "request failed"
|
|
}
|
|
}
|
|
}
|
|
|
|
// nextQuotaCooldown returns the next cooldown duration and updated backoff level for repeated quota errors.
|
|
func nextQuotaCooldown(prevLevel int) (time.Duration, int) {
|
|
if prevLevel < 0 {
|
|
prevLevel = 0
|
|
}
|
|
if quotaCooldownDisabled.Load() {
|
|
return 0, prevLevel
|
|
}
|
|
cooldown := quotaBackoffBase * time.Duration(1<<prevLevel)
|
|
if cooldown < quotaBackoffBase {
|
|
cooldown = quotaBackoffBase
|
|
}
|
|
if cooldown >= quotaBackoffMax {
|
|
return quotaBackoffMax, prevLevel
|
|
}
|
|
return cooldown, prevLevel + 1
|
|
}
|
|
|
|
// List returns all auth entries currently known by the manager.
|
|
func (m *Manager) List() []*Auth {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
list := make([]*Auth, 0, len(m.auths))
|
|
for _, auth := range m.auths {
|
|
list = append(list, auth.Clone())
|
|
}
|
|
return list
|
|
}
|
|
|
|
// GetByID retrieves an auth entry by its ID.
|
|
|
|
func (m *Manager) GetByID(id string) (*Auth, bool) {
|
|
if id == "" {
|
|
return nil, false
|
|
}
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
auth, ok := m.auths[id]
|
|
if !ok {
|
|
return nil, false
|
|
}
|
|
return auth.Clone(), true
|
|
}
|
|
|
|
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
|
m.mu.RLock()
|
|
executor, okExecutor := m.executors[provider]
|
|
if !okExecutor {
|
|
m.mu.RUnlock()
|
|
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
|
|
}
|
|
candidates := make([]*Auth, 0, len(m.auths))
|
|
modelKey := strings.TrimSpace(model)
|
|
// Always use base model name (without thinking suffix) for auth matching.
|
|
if modelKey != "" {
|
|
parsed := thinking.ParseSuffix(modelKey)
|
|
if parsed.ModelName != "" {
|
|
modelKey = strings.TrimSpace(parsed.ModelName)
|
|
}
|
|
}
|
|
registryRef := registry.GetGlobalRegistry()
|
|
for _, candidate := range m.auths {
|
|
if candidate.Provider != provider || candidate.Disabled {
|
|
continue
|
|
}
|
|
if _, used := tried[candidate.ID]; used {
|
|
continue
|
|
}
|
|
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
|
|
continue
|
|
}
|
|
candidates = append(candidates, candidate)
|
|
}
|
|
if len(candidates) == 0 {
|
|
m.mu.RUnlock()
|
|
return nil, nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
|
}
|
|
selected, errPick := m.selector.Pick(ctx, provider, model, opts, candidates)
|
|
if errPick != nil {
|
|
m.mu.RUnlock()
|
|
return nil, nil, errPick
|
|
}
|
|
if selected == nil {
|
|
m.mu.RUnlock()
|
|
return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
|
}
|
|
authCopy := selected.Clone()
|
|
m.mu.RUnlock()
|
|
if !selected.indexAssigned {
|
|
m.mu.Lock()
|
|
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
|
current.EnsureIndex()
|
|
authCopy = current.Clone()
|
|
}
|
|
m.mu.Unlock()
|
|
}
|
|
return authCopy, executor, nil
|
|
}
|
|
|
|
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
|
providerSet := make(map[string]struct{}, len(providers))
|
|
for _, provider := range providers {
|
|
p := strings.TrimSpace(strings.ToLower(provider))
|
|
if p == "" {
|
|
continue
|
|
}
|
|
providerSet[p] = struct{}{}
|
|
}
|
|
if len(providerSet) == 0 {
|
|
return nil, nil, "", &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
|
}
|
|
|
|
m.mu.RLock()
|
|
candidates := make([]*Auth, 0, len(m.auths))
|
|
modelKey := strings.TrimSpace(model)
|
|
// Always use base model name (without thinking suffix) for auth matching.
|
|
if modelKey != "" {
|
|
parsed := thinking.ParseSuffix(modelKey)
|
|
if parsed.ModelName != "" {
|
|
modelKey = strings.TrimSpace(parsed.ModelName)
|
|
}
|
|
}
|
|
registryRef := registry.GetGlobalRegistry()
|
|
for _, candidate := range m.auths {
|
|
if candidate == nil || candidate.Disabled {
|
|
continue
|
|
}
|
|
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
|
|
if providerKey == "" {
|
|
continue
|
|
}
|
|
if _, ok := providerSet[providerKey]; !ok {
|
|
continue
|
|
}
|
|
if _, used := tried[candidate.ID]; used {
|
|
continue
|
|
}
|
|
if _, ok := m.executors[providerKey]; !ok {
|
|
continue
|
|
}
|
|
if modelKey != "" && registryRef != nil && !registryRef.ClientSupportsModel(candidate.ID, modelKey) {
|
|
continue
|
|
}
|
|
candidates = append(candidates, candidate)
|
|
}
|
|
if len(candidates) == 0 {
|
|
m.mu.RUnlock()
|
|
return nil, nil, "", &Error{Code: "auth_not_found", Message: "no auth available"}
|
|
}
|
|
selected, errPick := m.selector.Pick(ctx, "mixed", model, opts, candidates)
|
|
if errPick != nil {
|
|
m.mu.RUnlock()
|
|
return nil, nil, "", errPick
|
|
}
|
|
if selected == nil {
|
|
m.mu.RUnlock()
|
|
return nil, nil, "", &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
|
}
|
|
providerKey := strings.TrimSpace(strings.ToLower(selected.Provider))
|
|
executor, okExecutor := m.executors[providerKey]
|
|
if !okExecutor {
|
|
m.mu.RUnlock()
|
|
return nil, nil, "", &Error{Code: "executor_not_found", Message: "executor not registered"}
|
|
}
|
|
authCopy := selected.Clone()
|
|
m.mu.RUnlock()
|
|
if !selected.indexAssigned {
|
|
m.mu.Lock()
|
|
if current := m.auths[authCopy.ID]; current != nil && !current.indexAssigned {
|
|
current.EnsureIndex()
|
|
authCopy = current.Clone()
|
|
}
|
|
m.mu.Unlock()
|
|
}
|
|
return authCopy, executor, providerKey, nil
|
|
}
|
|
|
|
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
|
if m.store == nil || auth == nil {
|
|
return nil
|
|
}
|
|
if shouldSkipPersist(ctx) {
|
|
return nil
|
|
}
|
|
if auth.Attributes != nil {
|
|
if v := strings.ToLower(strings.TrimSpace(auth.Attributes["runtime_only"])); v == "true" {
|
|
return nil
|
|
}
|
|
}
|
|
// Skip persistence when metadata is absent (e.g., runtime-only auths).
|
|
if auth.Metadata == nil {
|
|
return nil
|
|
}
|
|
_, err := m.store.Save(ctx, auth)
|
|
return err
|
|
}
|
|
|
|
// StartAutoRefresh launches a background loop that evaluates auth freshness
|
|
// every few seconds and triggers refresh operations when required.
|
|
// Only one loop is kept alive; starting a new one cancels the previous run.
|
|
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
|
|
if interval <= 0 || interval > refreshCheckInterval {
|
|
interval = refreshCheckInterval
|
|
} else {
|
|
interval = refreshCheckInterval
|
|
}
|
|
if m.refreshCancel != nil {
|
|
m.refreshCancel()
|
|
m.refreshCancel = nil
|
|
}
|
|
ctx, cancel := context.WithCancel(parent)
|
|
m.refreshCancel = cancel
|
|
go func() {
|
|
ticker := time.NewTicker(interval)
|
|
defer ticker.Stop()
|
|
m.checkRefreshes(ctx)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return
|
|
case <-ticker.C:
|
|
m.checkRefreshes(ctx)
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
// StopAutoRefresh cancels the background refresh loop, if running.
|
|
func (m *Manager) StopAutoRefresh() {
|
|
if m.refreshCancel != nil {
|
|
m.refreshCancel()
|
|
m.refreshCancel = nil
|
|
}
|
|
}
|
|
|
|
func (m *Manager) checkRefreshes(ctx context.Context) {
|
|
// log.Debugf("checking refreshes")
|
|
now := time.Now()
|
|
snapshot := m.snapshotAuths()
|
|
for _, a := range snapshot {
|
|
typ, _ := a.AccountInfo()
|
|
if typ != "api_key" {
|
|
if !m.shouldRefresh(a, now) {
|
|
continue
|
|
}
|
|
log.Debugf("checking refresh for %s, %s, %s", a.Provider, a.ID, typ)
|
|
|
|
if exec := m.executorFor(a.Provider); exec == nil {
|
|
continue
|
|
}
|
|
if !m.markRefreshPending(a.ID, now) {
|
|
continue
|
|
}
|
|
go m.refreshAuth(ctx, a.ID)
|
|
}
|
|
}
|
|
}
|
|
|
|
func (m *Manager) snapshotAuths() []*Auth {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
out := make([]*Auth, 0, len(m.auths))
|
|
for _, a := range m.auths {
|
|
out = append(out, a.Clone())
|
|
}
|
|
return out
|
|
}
|
|
|
|
func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
|
|
if a == nil || a.Disabled {
|
|
return false
|
|
}
|
|
if !a.NextRefreshAfter.IsZero() && now.Before(a.NextRefreshAfter) {
|
|
return false
|
|
}
|
|
if evaluator, ok := a.Runtime.(RefreshEvaluator); ok && evaluator != nil {
|
|
return evaluator.ShouldRefresh(now, a)
|
|
}
|
|
|
|
lastRefresh := a.LastRefreshedAt
|
|
if lastRefresh.IsZero() {
|
|
if ts, ok := authLastRefreshTimestamp(a); ok {
|
|
lastRefresh = ts
|
|
}
|
|
}
|
|
|
|
expiry, hasExpiry := a.ExpirationTime()
|
|
|
|
if interval := authPreferredInterval(a); interval > 0 {
|
|
if hasExpiry && !expiry.IsZero() {
|
|
if !expiry.After(now) {
|
|
return true
|
|
}
|
|
if expiry.Sub(now) <= interval {
|
|
return true
|
|
}
|
|
}
|
|
if lastRefresh.IsZero() {
|
|
return true
|
|
}
|
|
return now.Sub(lastRefresh) >= interval
|
|
}
|
|
|
|
provider := strings.ToLower(a.Provider)
|
|
lead := ProviderRefreshLead(provider, a.Runtime)
|
|
if lead == nil {
|
|
return false
|
|
}
|
|
if *lead <= 0 {
|
|
if hasExpiry && !expiry.IsZero() {
|
|
return now.After(expiry)
|
|
}
|
|
return false
|
|
}
|
|
if hasExpiry && !expiry.IsZero() {
|
|
return time.Until(expiry) <= *lead
|
|
}
|
|
if !lastRefresh.IsZero() {
|
|
return now.Sub(lastRefresh) >= *lead
|
|
}
|
|
return true
|
|
}
|
|
|
|
func authPreferredInterval(a *Auth) time.Duration {
|
|
if a == nil {
|
|
return 0
|
|
}
|
|
if d := durationFromMetadata(a.Metadata, "refresh_interval_seconds", "refreshIntervalSeconds", "refresh_interval", "refreshInterval"); d > 0 {
|
|
return d
|
|
}
|
|
if d := durationFromAttributes(a.Attributes, "refresh_interval_seconds", "refreshIntervalSeconds", "refresh_interval", "refreshInterval"); d > 0 {
|
|
return d
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func durationFromMetadata(meta map[string]any, keys ...string) time.Duration {
|
|
if len(meta) == 0 {
|
|
return 0
|
|
}
|
|
for _, key := range keys {
|
|
if val, ok := meta[key]; ok {
|
|
if dur := parseDurationValue(val); dur > 0 {
|
|
return dur
|
|
}
|
|
}
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func durationFromAttributes(attrs map[string]string, keys ...string) time.Duration {
|
|
if len(attrs) == 0 {
|
|
return 0
|
|
}
|
|
for _, key := range keys {
|
|
if val, ok := attrs[key]; ok {
|
|
if dur := parseDurationString(val); dur > 0 {
|
|
return dur
|
|
}
|
|
}
|
|
}
|
|
return 0
|
|
}
|
|
|
|
func parseDurationValue(val any) time.Duration {
|
|
switch v := val.(type) {
|
|
case time.Duration:
|
|
if v <= 0 {
|
|
return 0
|
|
}
|
|
return v
|
|
case int:
|
|
if v <= 0 {
|
|
return 0
|
|
}
|
|
return time.Duration(v) * time.Second
|
|
case int32:
|
|
if v <= 0 {
|
|
return 0
|
|
}
|
|
return time.Duration(v) * time.Second
|
|
case int64:
|
|
if v <= 0 {
|
|
return 0
|
|
}
|
|
return time.Duration(v) * time.Second
|
|
case uint:
|
|
if v == 0 {
|
|
return 0
|
|
}
|
|
return time.Duration(v) * time.Second
|
|
case uint32:
|
|
if v == 0 {
|
|
return 0
|
|
}
|
|
return time.Duration(v) * time.Second
|
|
case uint64:
|
|
if v == 0 {
|
|
return 0
|
|
}
|
|
return time.Duration(v) * time.Second
|
|
case float32:
|
|
if v <= 0 {
|
|
return 0
|
|
}
|
|
return time.Duration(float64(v) * float64(time.Second))
|
|
case float64:
|
|
if v <= 0 {
|
|
return 0
|
|
}
|
|
return time.Duration(v * float64(time.Second))
|
|
case json.Number:
|
|
if i, err := v.Int64(); err == nil {
|
|
if i <= 0 {
|
|
return 0
|
|
}
|
|
return time.Duration(i) * time.Second
|
|
}
|
|
if f, err := v.Float64(); err == nil && f > 0 {
|
|
return time.Duration(f * float64(time.Second))
|
|
}
|
|
case string:
|
|
return parseDurationString(v)
|
|
}
|
|
return 0
|
|
}
|
|
|
|
// parseDurationString parses raw as a Go duration string (e.g. "90s", "1h").
// Failing that, it parses raw as a plain number of seconds (e.g. "30", "1.5").
// Surrounding whitespace is ignored. Empty, unparsable or non-positive input
// yields 0.
func parseDurationString(raw string) time.Duration {
	trimmed := strings.TrimSpace(raw)
	if trimmed == "" {
		return 0
	}
	dur, err := time.ParseDuration(trimmed)
	if err == nil && dur > 0 {
		return dur
	}
	secs, err := strconv.ParseFloat(trimmed, 64)
	if err == nil && secs > 0 {
		return time.Duration(secs * float64(time.Second))
	}
	return 0
}
|
|
|
|
func authLastRefreshTimestamp(a *Auth) (time.Time, bool) {
|
|
if a == nil {
|
|
return time.Time{}, false
|
|
}
|
|
if a.Metadata != nil {
|
|
if ts, ok := lookupMetadataTime(a.Metadata, "last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"); ok {
|
|
return ts, true
|
|
}
|
|
}
|
|
if a.Attributes != nil {
|
|
for _, key := range []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} {
|
|
if val := strings.TrimSpace(a.Attributes[key]); val != "" {
|
|
if ts, ok := parseTimeValue(val); ok {
|
|
return ts, true
|
|
}
|
|
}
|
|
}
|
|
}
|
|
return time.Time{}, false
|
|
}
|
|
|
|
func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) {
|
|
for _, key := range keys {
|
|
if val, ok := meta[key]; ok {
|
|
if ts, ok1 := parseTimeValue(val); ok1 {
|
|
return ts, true
|
|
}
|
|
}
|
|
}
|
|
return time.Time{}, false
|
|
}
|
|
|
|
func (m *Manager) markRefreshPending(id string, now time.Time) bool {
|
|
m.mu.Lock()
|
|
defer m.mu.Unlock()
|
|
auth, ok := m.auths[id]
|
|
if !ok || auth == nil || auth.Disabled {
|
|
return false
|
|
}
|
|
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
|
|
return false
|
|
}
|
|
auth.NextRefreshAfter = now.Add(refreshPendingBackoff)
|
|
m.auths[id] = auth
|
|
return true
|
|
}
|
|
|
|
func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
m.mu.RLock()
|
|
auth := m.auths[id]
|
|
var exec ProviderExecutor
|
|
if auth != nil {
|
|
exec = m.executors[auth.Provider]
|
|
}
|
|
m.mu.RUnlock()
|
|
if auth == nil || exec == nil {
|
|
return
|
|
}
|
|
cloned := auth.Clone()
|
|
updated, err := exec.Refresh(ctx, cloned)
|
|
if err != nil && errors.Is(err, context.Canceled) {
|
|
log.Debugf("refresh canceled for %s, %s", auth.Provider, auth.ID)
|
|
return
|
|
}
|
|
log.Debugf("refreshed %s, %s, %v", auth.Provider, auth.ID, err)
|
|
now := time.Now()
|
|
if err != nil {
|
|
m.mu.Lock()
|
|
if current := m.auths[id]; current != nil {
|
|
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
|
current.LastError = &Error{Message: err.Error()}
|
|
m.auths[id] = current
|
|
}
|
|
m.mu.Unlock()
|
|
return
|
|
}
|
|
if updated == nil {
|
|
updated = cloned
|
|
}
|
|
// Preserve runtime created by the executor during Refresh.
|
|
// If executor didn't set one, fall back to the previous runtime.
|
|
if updated.Runtime == nil {
|
|
updated.Runtime = auth.Runtime
|
|
}
|
|
updated.LastRefreshedAt = now
|
|
updated.NextRefreshAfter = time.Time{}
|
|
updated.LastError = nil
|
|
updated.UpdatedAt = now
|
|
_, _ = m.Update(ctx, updated)
|
|
}
|
|
|
|
func (m *Manager) executorFor(provider string) ProviderExecutor {
|
|
m.mu.RLock()
|
|
defer m.mu.RUnlock()
|
|
return m.executors[provider]
|
|
}
|
|
|
|
// roundTripperContextKey is an unexported context key type, so its keys cannot
// collide with context keys defined in other packages.
// NOTE(review): its use is not visible in this section. Presumably it carries
// an http.RoundTripper in a context; confirm.
type roundTripperContextKey struct{}
|
|
|
|
// roundTripperFor retrieves an HTTP RoundTripper for the given auth if a provider is registered.
|
|
func (m *Manager) roundTripperFor(auth *Auth) http.RoundTripper {
|
|
m.mu.RLock()
|
|
p := m.rtProvider
|
|
m.mu.RUnlock()
|
|
if p == nil || auth == nil {
|
|
return nil
|
|
}
|
|
return p.RoundTripperFor(auth)
|
|
}
|
|
|
|
// RoundTripperProvider defines a minimal provider of per-auth HTTP transports.
// The manager consults its registered provider through roundTripperFor.
type RoundTripperProvider interface {
	// RoundTripperFor returns the HTTP transport to use for requests made on
	// behalf of auth. NOTE(review): whether a nil return means "use the
	// default transport" is up to callers; confirm.
	RoundTripperFor(auth *Auth) http.RoundTripper
}
|
|
|
|
// RequestPreparer is an optional interface that provider executors can implement
// to mutate outbound HTTP requests with provider credentials.
// InjectCredentials detects it with a type assertion on the executor
// registered for an auth.
type RequestPreparer interface {
	// PrepareRequest applies auth's credentials to req (e.g. by adding
	// headers) and returns an error when they cannot be applied.
	PrepareRequest(req *http.Request, auth *Auth) error
}
|
|
|
|
func executorKeyFromAuth(auth *Auth) string {
|
|
if auth == nil {
|
|
return ""
|
|
}
|
|
if auth.Attributes != nil {
|
|
providerKey := strings.TrimSpace(auth.Attributes["provider_key"])
|
|
compatName := strings.TrimSpace(auth.Attributes["compat_name"])
|
|
if compatName != "" {
|
|
if providerKey == "" {
|
|
providerKey = compatName
|
|
}
|
|
return strings.ToLower(providerKey)
|
|
}
|
|
}
|
|
return strings.ToLower(strings.TrimSpace(auth.Provider))
|
|
}
|
|
|
|
// logEntryWithRequestID returns a logrus entry with request_id field if available in context.
|
|
func logEntryWithRequestID(ctx context.Context) *log.Entry {
|
|
if ctx == nil {
|
|
return log.NewEntry(log.StandardLogger())
|
|
}
|
|
if reqID := logging.GetRequestID(ctx); reqID != "" {
|
|
return log.WithField("request_id", reqID)
|
|
}
|
|
return log.NewEntry(log.StandardLogger())
|
|
}
|
|
|
|
func debugLogAuthSelection(entry *log.Entry, auth *Auth, provider string, model string) {
|
|
if !log.IsLevelEnabled(log.DebugLevel) {
|
|
return
|
|
}
|
|
if entry == nil || auth == nil {
|
|
return
|
|
}
|
|
accountType, accountInfo := auth.AccountInfo()
|
|
proxyInfo := auth.ProxyInfo()
|
|
suffix := ""
|
|
if proxyInfo != "" {
|
|
suffix = " " + proxyInfo
|
|
}
|
|
switch accountType {
|
|
case "api_key":
|
|
entry.Debugf("Use API key %s for model %s%s", util.HideAPIKey(accountInfo), model, suffix)
|
|
case "oauth":
|
|
ident := formatOauthIdentity(auth, provider, accountInfo)
|
|
entry.Debugf("Use OAuth %s for model %s%s", ident, model, suffix)
|
|
}
|
|
}
|
|
|
|
func formatOauthIdentity(auth *Auth, provider string, accountInfo string) string {
|
|
if auth == nil {
|
|
return ""
|
|
}
|
|
// Prefer the auth's provider when available.
|
|
providerName := strings.TrimSpace(auth.Provider)
|
|
if providerName == "" {
|
|
providerName = strings.TrimSpace(provider)
|
|
}
|
|
// Only log the basename to avoid leaking host paths.
|
|
// FileName may be unset for some auth backends; fall back to ID.
|
|
authFile := strings.TrimSpace(auth.FileName)
|
|
if authFile == "" {
|
|
authFile = strings.TrimSpace(auth.ID)
|
|
}
|
|
if authFile != "" {
|
|
authFile = filepath.Base(authFile)
|
|
}
|
|
parts := make([]string, 0, 3)
|
|
if providerName != "" {
|
|
parts = append(parts, "provider="+providerName)
|
|
}
|
|
if authFile != "" {
|
|
parts = append(parts, "auth_file="+authFile)
|
|
}
|
|
if len(parts) == 0 {
|
|
return accountInfo
|
|
}
|
|
return strings.Join(parts, " ")
|
|
}
|
|
|
|
// InjectCredentials delegates per-provider HTTP request preparation when supported.
|
|
// If the registered executor for the auth provider implements RequestPreparer,
|
|
// it will be invoked to modify the request (e.g., add headers).
|
|
func (m *Manager) InjectCredentials(req *http.Request, authID string) error {
|
|
if req == nil || authID == "" {
|
|
return nil
|
|
}
|
|
m.mu.RLock()
|
|
a := m.auths[authID]
|
|
var exec ProviderExecutor
|
|
if a != nil {
|
|
exec = m.executors[executorKeyFromAuth(a)]
|
|
}
|
|
m.mu.RUnlock()
|
|
if a == nil || exec == nil {
|
|
return nil
|
|
}
|
|
if p, ok := exec.(RequestPreparer); ok && p != nil {
|
|
return p.PrepareRequest(req, a)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// PrepareHttpRequest injects provider credentials into the supplied HTTP request.
|
|
func (m *Manager) PrepareHttpRequest(ctx context.Context, auth *Auth, req *http.Request) error {
|
|
if m == nil {
|
|
return &Error{Code: "provider_not_found", Message: "manager is nil"}
|
|
}
|
|
if auth == nil {
|
|
return &Error{Code: "auth_not_found", Message: "auth is nil"}
|
|
}
|
|
if req == nil {
|
|
return &Error{Code: "invalid_request", Message: "http request is nil"}
|
|
}
|
|
if ctx != nil {
|
|
*req = *req.WithContext(ctx)
|
|
}
|
|
providerKey := executorKeyFromAuth(auth)
|
|
if providerKey == "" {
|
|
return &Error{Code: "provider_not_found", Message: "auth provider is empty"}
|
|
}
|
|
exec := m.executorFor(providerKey)
|
|
if exec == nil {
|
|
return &Error{Code: "provider_not_found", Message: "executor not registered for provider: " + providerKey}
|
|
}
|
|
preparer, ok := exec.(RequestPreparer)
|
|
if !ok || preparer == nil {
|
|
return &Error{Code: "not_supported", Message: "executor does not support http request preparation"}
|
|
}
|
|
return preparer.PrepareRequest(req, auth)
|
|
}
|
|
|
|
// NewHttpRequest constructs a new HTTP request and injects provider credentials into it.
|
|
func (m *Manager) NewHttpRequest(ctx context.Context, auth *Auth, method, targetURL string, body []byte, headers http.Header) (*http.Request, error) {
|
|
if ctx == nil {
|
|
ctx = context.Background()
|
|
}
|
|
method = strings.TrimSpace(method)
|
|
if method == "" {
|
|
method = http.MethodGet
|
|
}
|
|
var reader io.Reader
|
|
if body != nil {
|
|
reader = bytes.NewReader(body)
|
|
}
|
|
httpReq, err := http.NewRequestWithContext(ctx, method, targetURL, reader)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if headers != nil {
|
|
httpReq.Header = headers.Clone()
|
|
}
|
|
if errPrepare := m.PrepareHttpRequest(ctx, auth, httpReq); errPrepare != nil {
|
|
return nil, errPrepare
|
|
}
|
|
return httpReq, nil
|
|
}
|
|
|
|
// HttpRequest injects provider credentials into the supplied HTTP request and executes it.
|
|
func (m *Manager) HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) {
|
|
if m == nil {
|
|
return nil, &Error{Code: "provider_not_found", Message: "manager is nil"}
|
|
}
|
|
if auth == nil {
|
|
return nil, &Error{Code: "auth_not_found", Message: "auth is nil"}
|
|
}
|
|
if req == nil {
|
|
return nil, &Error{Code: "invalid_request", Message: "http request is nil"}
|
|
}
|
|
providerKey := executorKeyFromAuth(auth)
|
|
if providerKey == "" {
|
|
return nil, &Error{Code: "provider_not_found", Message: "auth provider is empty"}
|
|
}
|
|
exec := m.executorFor(providerKey)
|
|
if exec == nil {
|
|
return nil, &Error{Code: "provider_not_found", Message: "executor not registered for provider: " + providerKey}
|
|
}
|
|
return exec.HttpRequest(ctx, auth, req)
|
|
}
|