Files
CLIProxyAPI/sdk/cliproxy/auth/manager.go
Luis Pater 837ae1b1b3 chore(logging): add debug logs for executor Refresh methods
- Introduced `logrus` for structured debugging across all executors.
- Added debug log messages in `Refresh` methods for better traceability.
- Updated `Manager` to log additional details during refresh checks.
2025-09-22 20:03:31 +08:00

917 lines
25 KiB
Go

package auth
import (
"context"
"encoding/json"
"errors"
"net/http"
"strconv"
"strings"
"sync"
"time"
"github.com/google/uuid"
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
)
// ProviderExecutor defines the contract required by Manager to execute provider calls.
// Executors are stored by Manager.RegisterExecutor under the key returned by
// Identifier and looked up by provider name when a request or refresh is dispatched.
type ProviderExecutor interface {
// Identifier returns the provider key handled by this executor.
// The value is used verbatim as the key of the manager's executor map.
Identifier() string
// Execute handles non-streaming execution and returns the provider response payload.
// A returned error that satisfies cliproxyexecutor.StatusError has its status code
// recorded on the execution result by the manager.
Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
// ExecuteStream handles streaming execution and returns a channel of provider chunks.
// The manager relays every chunk to its caller and treats the first chunk carrying
// a non-nil Err as a failed execution.
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error)
// Refresh attempts to refresh provider credentials and returns the updated auth state.
// Returning a nil auth together with a nil error makes the manager keep the auth
// it passed in.
Refresh(ctx context.Context, auth *Auth) (*Auth, error)
}
// RefreshEvaluator allows runtime state to override refresh decisions.
// When an auth's Runtime value implements this interface, the manager returns the
// result of ShouldRefresh instead of applying its own interval/expiry heuristics.
// The disabled check and the NextRefreshAfter backoff are still applied first.
type RefreshEvaluator interface {
// ShouldRefresh reports whether auth should be refreshed at time now.
ShouldRefresh(now time.Time, auth *Auth) bool
}
// Timing parameters for the background auto-refresh loop.
const (
// refreshCheckInterval is the ticker period used by StartAutoRefresh between sweeps.
refreshCheckInterval = 5 * time.Second
// refreshPendingBackoff is added to NextRefreshAfter when a refresh is scheduled,
// so the same auth is not scheduled again while that refresh is in flight.
refreshPendingBackoff = time.Minute
// refreshFailureBackoff is added to NextRefreshAfter after a refresh attempt fails.
refreshFailureBackoff = 5 * time.Minute
)
// Result captures execution outcome used to adjust auth state.
// It is passed to Manager.MarkResult, which updates the matching in-memory auth
// and then forwards the same value to Hook.OnResult.
type Result struct {
// AuthID references the auth that produced this result.
// MarkResult ignores results whose AuthID is empty.
AuthID string
// Provider is copied for convenience when emitting hooks.
Provider string
// Model is the upstream model identifier used for the request.
// When non-empty it is also used to set or clear the per-model quota flag
// in the global registry.
Model string
// Success marks whether the execution succeeded.
Success bool
// Error describes the failure when Success is false.
Error *Error
}
// Selector chooses an auth candidate for execution.
// The manager passes cloned auths that belong to the requested provider, are not
// disabled, and have not yet been tried for the current request. Returning a nil
// auth with a nil error is reported by the manager as "auth_not_found".
type Selector interface {
// Pick returns the auth to use from auths, or an error when none is suitable.
Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error)
}
// Hook captures lifecycle callbacks for observing auth changes.
// The manager invokes these callbacks synchronously from Register, Update and
// MarkResult, after it has released its internal lock.
type Hook interface {
// OnAuthRegistered fires when a new auth is registered.
// The auth argument is a clone of the registered entry.
OnAuthRegistered(ctx context.Context, auth *Auth)
// OnAuthUpdated fires when an existing auth changes state.
// The auth argument is a clone of the updated entry.
OnAuthUpdated(ctx context.Context, auth *Auth)
// OnResult fires when execution result is recorded.
OnResult(ctx context.Context, result Result)
}
// NoopHook provides optional hook defaults.
// Every method is an empty implementation; NewManager installs a NoopHook when
// the caller passes a nil Hook.
type NoopHook struct{}
// OnAuthRegistered implements Hook and does nothing.
func (NoopHook) OnAuthRegistered(context.Context, *Auth) {}
// OnAuthUpdated implements Hook and does nothing.
func (NoopHook) OnAuthUpdated(context.Context, *Auth) {}
// OnResult implements Hook and does nothing.
func (NoopHook) OnResult(context.Context, Result) {}
// Manager orchestrates auth lifecycle, selection, execution, and persistence.
// Construct it with NewManager so that the maps below are initialised.
type Manager struct {
// store is the persistence backend; it may be nil, in which case nothing is saved or loaded.
store Store
// executors maps a provider identifier to the executor registered for it.
executors map[string]ProviderExecutor
// selector picks one auth among the candidates for a request.
selector Selector
// hook receives registration, update and result notifications.
hook Hook
// mu is taken by the accessors in this file around executors, auths,
// providerOffsets and rtProvider, and by SetStore around store.
mu sync.RWMutex
// auths holds the in-memory auth entries keyed by auth ID.
auths map[string]*Auth
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
providerOffsets map[string]int
// Optional HTTP RoundTripper provider injected by host.
rtProvider RoundTripperProvider
// Auto refresh state
// refreshCancel stops the running auto-refresh loop. StartAutoRefresh and
// StopAutoRefresh read and write it without taking mu.
refreshCancel context.CancelFunc
}
// NewManager builds a Manager around the given store.
// A nil selector is replaced by a RoundRobinSelector and a nil hook by a
// NoopHook, so both arguments are optional.
func NewManager(store Store, selector Selector, hook Hook) *Manager {
	mgr := &Manager{
		store:           store,
		selector:        selector,
		hook:            hook,
		executors:       make(map[string]ProviderExecutor),
		auths:           make(map[string]*Auth),
		providerOffsets: make(map[string]int),
	}
	// Fill in the defaults after construction.
	if mgr.selector == nil {
		mgr.selector = &RoundRobinSelector{}
	}
	if mgr.hook == nil {
		mgr.hook = NoopHook{}
	}
	return mgr
}
// SetStore replaces the persistence backend used by the manager.
// The swap happens under the manager's write lock.
func (m *Manager) SetStore(store Store) {
	m.mu.Lock()
	m.store = store
	m.mu.Unlock()
}
// SetRoundTripperProvider registers the provider that supplies a per-auth
// http.RoundTripper. Passing nil removes any previously registered provider.
func (m *Manager) SetRoundTripperProvider(p RoundTripperProvider) {
	m.mu.Lock()
	defer m.mu.Unlock()
	m.rtProvider = p
}
// RegisterExecutor stores executor under the key returned by its Identifier
// method, replacing any executor previously registered for that key.
// A nil executor is ignored.
func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
	if executor != nil {
		m.mu.Lock()
		defer m.mu.Unlock()
		m.executors[executor.Identifier()] = executor
	}
}
// Register adds auth to the manager and returns a clone of it.
// An empty ID is filled in with a freshly generated UUID (the caller's value is
// mutated). A nil auth yields (nil, nil). The entry is stored as a clone,
// persisted on a best-effort basis, and announced through Hook.OnAuthRegistered.
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
	if auth == nil {
		return nil, nil
	}
	if auth.ID == "" {
		auth.ID = uuid.NewString()
	}

	stored := auth.Clone()
	m.mu.Lock()
	m.auths[auth.ID] = stored
	m.mu.Unlock()

	// Persistence is best-effort: a failed save does not undo the registration.
	_ = m.persist(ctx, auth)

	// The hook and the caller each receive their own clone.
	forHook := auth.Clone()
	m.hook.OnAuthRegistered(ctx, forHook)
	return auth.Clone(), nil
}
// Update stores auth under its ID, replacing whatever entry was there, and
// returns a clone of it. A nil auth or an empty ID yields (nil, nil).
// The entry is persisted on a best-effort basis and announced through
// Hook.OnAuthUpdated.
func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
	if auth == nil || auth.ID == "" {
		return nil, nil
	}

	stored := auth.Clone()
	m.mu.Lock()
	m.auths[auth.ID] = stored
	m.mu.Unlock()

	// Persistence is best-effort: a failed save does not undo the update.
	_ = m.persist(ctx, auth)

	// The hook and the caller each receive their own clone.
	forHook := auth.Clone()
	m.hook.OnAuthUpdated(ctx, forHook)
	return auth.Clone(), nil
}
// Load replaces the in-memory auth set with the contents of the backing store.
// It is a no-op when no store is configured. A listing error is returned
// unchanged and leaves the current auth set untouched. Nil entries and entries
// without an ID are skipped. The manager's write lock is held for the whole call.
func (m *Manager) Load(ctx context.Context) error {
	m.mu.Lock()
	defer m.mu.Unlock()

	if m.store == nil {
		return nil
	}
	items, err := m.store.List(ctx)
	if err != nil {
		return err
	}

	loaded := make(map[string]*Auth, len(items))
	for _, item := range items {
		if item != nil && item.ID != "" {
			loaded[item.ID] = item.Clone()
		}
	}
	m.auths = loaded
	return nil
}
// Execute runs a non-streaming request against the first provider that succeeds.
// The provider list is normalized, rotated according to the per-model cursor, and
// tried in order; the cursor is advanced once when the call returns. If every
// provider fails, the error from the last one is returned.
func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	var empty cliproxyexecutor.Response

	candidates := m.normalizeProviders(providers)
	if len(candidates) == 0 {
		return empty, &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	ordered := m.rotateProviders(req.Model, candidates)
	defer m.advanceProviderCursor(req.Model, candidates)

	var lastErr error
	for _, provider := range ordered {
		resp, err := m.executeWithProvider(ctx, provider, req, opts)
		if err != nil {
			lastErr = err
			continue
		}
		return resp, nil
	}
	if lastErr == nil {
		return empty, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	return empty, lastErr
}
// ExecuteStream runs a streaming request against the first provider that
// manages to open a stream. The provider list is normalized, rotated according to
// the per-model cursor, and tried in order; the cursor is advanced once when the
// call returns. If every provider fails, the error from the last one is returned.
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
	candidates := m.normalizeProviders(providers)
	if len(candidates) == 0 {
		return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
	}
	ordered := m.rotateProviders(req.Model, candidates)
	defer m.advanceProviderCursor(req.Model, candidates)

	var lastErr error
	for _, provider := range ordered {
		stream, err := m.executeStreamWithProvider(ctx, provider, req, opts)
		if err != nil {
			lastErr = err
			continue
		}
		return stream, nil
	}
	if lastErr == nil {
		return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
	}
	return nil, lastErr
}
// executeWithProvider runs a non-streaming request against a single provider,
// trying one auth after another until one succeeds or no untried auth is left.
// Every attempt is reported through MarkResult. When auths run out, the error of
// the last failed attempt is returned, or the pick error if nothing was attempted.
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
	var zero cliproxyexecutor.Response
	if provider == "" {
		return zero, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
	}

	var lastErr error
	attempted := make(map[string]struct{})
	for {
		auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, attempted)
		if errPick != nil {
			if lastErr == nil {
				return zero, errPick
			}
			return zero, lastErr
		}

		// Log which credential is about to be used; secrets are masked.
		switch kind, info := auth.AccountInfo(); kind {
		case "api_key":
			log.Debugf("Use API key %s for model %s", util.HideAPIKey(info), req.Model)
		case "oauth":
			log.Debugf("Use OAuth %s for model %s", info, req.Model)
		case "cookie":
			log.Debugf("Use Cookie %s for model %s", util.HideAPIKey(info), req.Model)
		}
		attempted[auth.ID] = struct{}{}

		// Expose the per-auth transport under both the typed and the string key.
		execCtx := ctx
		if rt := m.roundTripperFor(auth); rt != nil {
			execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
			execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
		}

		resp, errExec := executor.Execute(execCtx, auth, req, opts)
		if errExec == nil {
			m.MarkResult(execCtx, Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: true})
			return resp, nil
		}

		failure := &Error{Message: errExec.Error()}
		var se cliproxyexecutor.StatusError
		if errors.As(errExec, &se) && se != nil {
			failure.HTTPStatus = se.StatusCode()
		}
		m.MarkResult(execCtx, Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: false, Error: failure})
		lastErr = errExec
	}
}
// executeStreamWithProvider opens a stream against a single provider, trying one
// auth after another until a stream is established or no untried auth is left.
// A failure to open the stream is reported through MarkResult immediately. Once a
// stream is open, a relay goroutine forwards every chunk on an unbuffered channel
// and reports one result: a failure on the first chunk carrying an error, or a
// success when the upstream channel closes without any error chunk.
func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
	if provider == "" {
		return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
	}

	var lastErr error
	attempted := make(map[string]struct{})
	for {
		auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, attempted)
		if errPick != nil {
			if lastErr == nil {
				return nil, errPick
			}
			return nil, lastErr
		}

		// Log which credential is about to be used; secrets are masked.
		switch kind, info := auth.AccountInfo(); kind {
		case "api_key":
			log.Debugf("Use API key %s for model %s", util.HideAPIKey(info), req.Model)
		case "oauth":
			log.Debugf("Use OAuth %s for model %s", info, req.Model)
		case "cookie":
			log.Debugf("Use Cookie %s for model %s", util.HideAPIKey(info), req.Model)
		}
		attempted[auth.ID] = struct{}{}

		// Expose the per-auth transport under both the typed and the string key.
		execCtx := ctx
		if rt := m.roundTripperFor(auth); rt != nil {
			execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
			execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
		}

		upstream, errStream := executor.ExecuteStream(execCtx, auth, req, opts)
		if errStream != nil {
			failure := &Error{Message: errStream.Error()}
			var se cliproxyexecutor.StatusError
			if errors.As(errStream, &se) && se != nil {
				failure.HTTPStatus = se.StatusCode()
			}
			m.MarkResult(execCtx, Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: false, Error: failure})
			lastErr = errStream
			continue
		}

		out := make(chan cliproxyexecutor.StreamChunk)
		go func(relayCtx context.Context, relayAuth *Auth, relayProvider string, source <-chan cliproxyexecutor.StreamChunk) {
			defer close(out)
			sawError := false
			for chunk := range source {
				// Only the first error chunk is reported; later chunks are still relayed.
				if !sawError && chunk.Err != nil {
					sawError = true
					failure := &Error{Message: chunk.Err.Error()}
					var se cliproxyexecutor.StatusError
					if errors.As(chunk.Err, &se) && se != nil {
						failure.HTTPStatus = se.StatusCode()
					}
					m.MarkResult(relayCtx, Result{AuthID: relayAuth.ID, Provider: relayProvider, Model: req.Model, Success: false, Error: failure})
				}
				out <- chunk
			}
			if sawError {
				return
			}
			m.MarkResult(relayCtx, Result{AuthID: relayAuth.ID, Provider: relayProvider, Model: req.Model, Success: true})
		}(execCtx, auth.Clone(), provider, upstream)
		return out, nil
	}
}
// normalizeProviders lower-cases and trims each provider name, drops empty
// names, and removes duplicates while keeping first-seen order.
// It returns nil for an empty input slice.
func (m *Manager) normalizeProviders(providers []string) []string {
	if len(providers) == 0 {
		return nil
	}
	unique := make([]string, 0, len(providers))
	known := make(map[string]struct{}, len(providers))
	for _, raw := range providers {
		lowered := strings.ToLower(raw)
		key := strings.TrimSpace(lowered)
		if key == "" {
			continue
		}
		if _, dup := known[key]; !dup {
			known[key] = struct{}{}
			unique = append(unique, key)
		}
	}
	return unique
}
// rotateProviders returns providers rotated left by the cursor stored for model,
// so that successive requests start with a different provider.
// When the effective offset is zero the input slice itself is returned;
// otherwise a new slice is allocated. An empty input yields nil.
func (m *Manager) rotateProviders(model string, providers []string) []string {
	count := len(providers)
	if count == 0 {
		return nil
	}

	m.mu.RLock()
	start := m.providerOffsets[model]
	m.mu.RUnlock()

	// A stored cursor may exceed the current list length; wrap it. A non-positive
	// result means "no rotation".
	start %= count
	if start <= 0 {
		return providers
	}

	rotated := make([]string, 0, count)
	rotated = append(rotated, providers[start:]...)
	return append(rotated, providers[:start]...)
}
// advanceProviderCursor moves the rotation cursor for model one step forward,
// wrapping at len(providers). With an empty provider list the cursor for model
// is removed instead.
func (m *Manager) advanceProviderCursor(model string, providers []string) {
	m.mu.Lock()
	defer m.mu.Unlock()

	if len(providers) == 0 {
		delete(m.providerOffsets, model)
		return
	}
	m.providerOffsets[model] = (m.providerOffsets[model] + 1) % len(providers)
}
// MarkResult applies an execution result to the matching in-memory auth and then
// notifies Hook.OnResult. Results with an empty AuthID are ignored entirely.
//
// On success the auth's error and quota flags are cleared. On failure the auth is
// marked unavailable and, depending on the status code carried by result.Error,
// a status message, quota state and NextRefreshAfter backoff are set. The auth is
// persisted (best-effort) while the manager lock is still held; the hook runs
// after the lock is released, whether or not the auth was found.
func (m *Manager) MarkResult(ctx context.Context, result Result) {
	if result.AuthID == "" {
		return
	}

	m.mu.Lock()
	if entry, found := m.auths[result.AuthID]; found && entry != nil {
		stamp := time.Now()
		if !result.Success {
			// Start from a generic transient-error state.
			entry.Unavailable = true
			entry.Status = StatusError
			entry.UpdatedAt = stamp
			if failure := result.Error; failure != nil {
				entry.LastError = &Error{Code: failure.Code, Message: failure.Message, Retryable: failure.Retryable}
			}

			// Refine backoff/quota from the status code, when the error exposes one.
			// 401 -> auth issue; 402/429 -> quota; 5xx -> transient.
			var statusCode int
			if coder, matches := any(result.Error).(interface{ StatusCode() int }); matches && coder != nil {
				statusCode = coder.StatusCode()
			}
			switch statusCode {
			case 401:
				entry.StatusMessage = "unauthorized"
				entry.NextRefreshAfter = stamp.Add(5 * time.Minute)
			case 402, 429:
				entry.StatusMessage = "quota exhausted"
				entry.Quota.Exceeded = true
				entry.Quota.Reason = "quota"
				entry.Quota.NextRecoverAt = stamp.Add(10 * time.Minute)
				entry.NextRefreshAfter = entry.Quota.NextRecoverAt
				if result.Model != "" {
					registry.GetGlobalRegistry().SetModelQuotaExceeded(entry.ID, result.Model)
				}
			case 403, 408, 500, 502, 503, 504:
				entry.StatusMessage = "transient upstream error"
				entry.NextRefreshAfter = stamp.Add(1 * time.Minute)
			default:
				// Keep any existing message; otherwise fall back to a generic one.
				if entry.StatusMessage == "" {
					entry.StatusMessage = "request failed"
				}
			}
		} else {
			// A success wipes transient error and quota flags.
			entry.Unavailable = false
			entry.Status = StatusActive
			entry.StatusMessage = ""
			entry.Quota.Exceeded = false
			entry.Quota.Reason = ""
			entry.Quota.NextRecoverAt = time.Time{}
			entry.LastError = nil
			entry.UpdatedAt = stamp
			if result.Model != "" {
				registry.GetGlobalRegistry().ClearModelQuotaExceeded(entry.ID, result.Model)
			}
		}
		// Best-effort persistence (only metadata is stored for file store).
		_ = m.persist(ctx, entry)
	}
	m.mu.Unlock()

	m.hook.OnResult(ctx, result)
}
// List returns clones of every auth currently held by the manager.
// The order of the returned slice follows map iteration and is not stable.
func (m *Manager) List() []*Auth {
	m.mu.RLock()
	defer m.mu.RUnlock()

	clones := make([]*Auth, 0, len(m.auths))
	for _, entry := range m.auths {
		copied := entry.Clone()
		clones = append(clones, copied)
	}
	return clones
}
// GetByID returns a clone of the auth stored under id.
// The boolean is false when id is empty or no entry exists for it.
func (m *Manager) GetByID(id string) (*Auth, bool) {
	if id == "" {
		return nil, false
	}
	m.mu.RLock()
	defer m.mu.RUnlock()

	if entry, found := m.auths[id]; found {
		return entry.Clone(), true
	}
	return nil, false
}
// pickNext selects the next auth to try for provider, together with the executor
// registered for that provider. Candidates are clones of the enabled auths of
// that provider whose IDs are not in tried. The selector is invoked after the
// manager lock has been released.
//
// It returns an *Error with code "executor_not_found" when no executor is
// registered, "auth_not_found" when no candidate remains or the selector returns
// nil, and the selector's own error otherwise.
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
	var pool []*Auth

	m.mu.RLock()
	executor, registered := m.executors[provider]
	if registered {
		pool = make([]*Auth, 0, len(m.auths))
		for _, entry := range m.auths {
			if entry.Provider != provider || entry.Disabled {
				continue
			}
			if _, seen := tried[entry.ID]; seen {
				continue
			}
			pool = append(pool, entry.Clone())
		}
	}
	m.mu.RUnlock()

	if !registered {
		return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
	}
	if len(pool) == 0 {
		return nil, nil, &Error{Code: "auth_not_found", Message: "no auth available"}
	}

	picked, err := m.selector.Pick(ctx, provider, model, opts, pool)
	switch {
	case err != nil:
		return nil, nil, err
	case picked == nil:
		return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
	}
	return picked, executor, nil
}
// persist saves auth through the configured store and returns the store's error.
// It is a no-op returning nil when there is no store, when auth is nil, or when
// auth carries no metadata (e.g., runtime-only auths).
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
	skip := m.store == nil || auth == nil || auth.Metadata == nil
	if skip {
		return nil
	}
	return m.store.Save(ctx, auth)
}
// StartAutoRefresh launches a background loop that evaluates auth freshness
// and triggers refresh operations when required. One sweep runs immediately and
// further sweeps follow on every tick until the derived context is cancelled.
// Only one loop is kept alive; starting a new one cancels the previous run.
//
// The interval argument is accepted for API compatibility but is currently
// ignored: the loop always ticks at refreshCheckInterval. The previous clamp
// assigned refreshCheckInterval in both of its branches, so this matches the
// existing runtime behaviour while removing the dead conditional.
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
	interval = refreshCheckInterval

	// Stop a previously started loop before replacing its cancel function.
	if m.refreshCancel != nil {
		m.refreshCancel()
		m.refreshCancel = nil
	}
	ctx, cancel := context.WithCancel(parent)
	m.refreshCancel = cancel

	go func() {
		ticker := time.NewTicker(interval)
		defer ticker.Stop()
		// Sweep once right away instead of waiting for the first tick.
		m.checkRefreshes(ctx)
		for {
			select {
			case <-ctx.Done():
				return
			case <-ticker.C:
				m.checkRefreshes(ctx)
			}
		}
	}()
}
// StopAutoRefresh cancels the background refresh loop started by
// StartAutoRefresh. It does nothing when no loop is running.
func (m *Manager) StopAutoRefresh() {
	cancel := m.refreshCancel
	if cancel == nil {
		return
	}
	cancel()
	m.refreshCancel = nil
}
// checkRefreshes performs one refresh sweep over a snapshot of all auths.
// An auth is refreshed in its own goroutine when it is due, an executor is
// registered for its provider, and it could be marked as pending. The three
// checks run in that order and stop at the first one that fails.
func (m *Manager) checkRefreshes(ctx context.Context) {
	log.Debugf("checking refreshes")
	now := time.Now()
	for _, candidate := range m.snapshotAuths() {
		log.Debugf("checking refresh for %s, %s", candidate.Provider, candidate.ID)
		due := m.shouldRefresh(candidate, now) &&
			m.executorFor(candidate.Provider) != nil &&
			m.markRefreshPending(candidate.ID, now)
		if due {
			go m.refreshAuth(ctx, candidate.ID)
		}
	}
}
// snapshotAuths returns clones of all auths, taken under the read lock, so the
// caller can inspect them without holding the manager lock.
func (m *Manager) snapshotAuths() []*Auth {
	m.mu.RLock()
	defer m.mu.RUnlock()

	snapshot := make([]*Auth, 0, len(m.auths))
	for _, entry := range m.auths {
		copied := entry.Clone()
		snapshot = append(snapshot, copied)
	}
	return snapshot
}
// shouldRefresh reports whether auth a is due for a credential refresh at time now.
//
// The decision is made in this order:
//  1. nil or disabled auths are never refreshed;
//  2. an unexpired NextRefreshAfter backoff suppresses the refresh;
//  3. a Runtime value implementing RefreshEvaluator decides on its own;
//  4. a preferred interval from metadata/attributes triggers a refresh when the
//     expiry is within that interval, when no last-refresh time is known, or
//     when the interval has elapsed since the last refresh;
//  5. otherwise the provider's refresh lead is compared against the expiry or,
//     failing that, against the last refresh time.
//
// Every comparison is made against the now argument so that one sweep evaluates
// all auths against the same instant.
func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
	if a == nil || a.Disabled {
		return false
	}
	if !a.NextRefreshAfter.IsZero() && now.Before(a.NextRefreshAfter) {
		return false
	}
	if evaluator, ok := a.Runtime.(RefreshEvaluator); ok && evaluator != nil {
		return evaluator.ShouldRefresh(now, a)
	}

	// Fall back to a timestamp recorded in metadata/attributes when the struct
	// field is unset.
	lastRefresh := a.LastRefreshedAt
	if lastRefresh.IsZero() {
		if ts, ok := authLastRefreshTimestamp(a); ok {
			lastRefresh = ts
		}
	}

	expiry, hasExpiry := a.ExpirationTime()
	expiryKnown := hasExpiry && !expiry.IsZero()

	if interval := authPreferredInterval(a); interval > 0 {
		if expiryKnown {
			if !expiry.After(now) {
				return true
			}
			if expiry.Sub(now) <= interval {
				return true
			}
		}
		if lastRefresh.IsZero() {
			return true
		}
		return now.Sub(lastRefresh) >= interval
	}

	provider := strings.ToLower(a.Provider)
	lead := ProviderRefreshLead(provider, a.Runtime)
	if lead <= 0 {
		// No lead configured: refresh only once the credential has expired.
		if expiryKnown {
			return now.After(expiry)
		}
		return false
	}
	if expiryKnown {
		// Use the supplied instant rather than the wall clock (was time.Until).
		return expiry.Sub(now) <= lead
	}
	if !lastRefresh.IsZero() {
		return now.Sub(lastRefresh) >= lead
	}
	return true
}
// authPreferredInterval returns the refresh interval configured on the auth
// itself, looking first in Metadata and then in Attributes under the same set of
// key spellings. It returns 0 when a is nil or no positive interval is found.
func authPreferredInterval(a *Auth) time.Duration {
	if a == nil {
		return 0
	}
	keys := []string{"refresh_interval_seconds", "refreshIntervalSeconds", "refresh_interval", "refreshInterval"}
	if fromMeta := durationFromMetadata(a.Metadata, keys...); fromMeta > 0 {
		return fromMeta
	}
	if fromAttrs := durationFromAttributes(a.Attributes, keys...); fromAttrs > 0 {
		return fromAttrs
	}
	return 0
}
// durationFromMetadata returns the first positive duration found in meta under
// any of keys, checked in the order given. Values are interpreted by
// parseDurationValue. It returns 0 when nothing usable is present.
func durationFromMetadata(meta map[string]any, keys ...string) time.Duration {
	if len(meta) == 0 {
		return 0
	}
	for _, key := range keys {
		raw, present := meta[key]
		if !present {
			continue
		}
		if parsed := parseDurationValue(raw); parsed > 0 {
			return parsed
		}
	}
	return 0
}
// durationFromAttributes returns the first positive duration found in attrs
// under any of keys, checked in the order given. Values are interpreted by
// parseDurationString. It returns 0 when nothing usable is present.
func durationFromAttributes(attrs map[string]string, keys ...string) time.Duration {
	if len(attrs) == 0 {
		return 0
	}
	for _, key := range keys {
		raw, present := attrs[key]
		if !present {
			continue
		}
		if parsed := parseDurationString(raw); parsed > 0 {
			return parsed
		}
	}
	return 0
}
// parseDurationValue converts a loosely typed value into a positive duration.
//
// A time.Duration is returned as is. Integer and floating-point values, as well
// as json.Number, are interpreted as a number of seconds. Strings are handled by
// parseDurationString. The result is 0 for unsupported types, for non-positive
// values, and for values too large to be represented as a time.Duration (the
// unguarded conversion would overflow and could yield a negative duration).
func parseDurationValue(val any) time.Duration {
	switch v := val.(type) {
	case time.Duration:
		if v <= 0 {
			return 0
		}
		return v
	case int:
		return durationFromIntSeconds(int64(v))
	case int32:
		return durationFromIntSeconds(int64(v))
	case int64:
		return durationFromIntSeconds(v)
	case uint:
		return durationFromUintSeconds(uint64(v))
	case uint32:
		return durationFromUintSeconds(uint64(v))
	case uint64:
		return durationFromUintSeconds(v)
	case float32:
		return durationFromFloatSeconds(float64(v))
	case float64:
		return durationFromFloatSeconds(v)
	case json.Number:
		// Prefer the exact integer form; fall back to float for fractional input.
		if i, err := v.Int64(); err == nil {
			return durationFromIntSeconds(i)
		}
		if f, err := v.Float64(); err == nil {
			return durationFromFloatSeconds(f)
		}
	case string:
		return parseDurationString(v)
	}
	return 0
}

// parseDurationString parses raw either as a Go duration string ("1m30s") or,
// failing that, as a decimal number of seconds ("90", "2.5"). Surrounding
// whitespace is ignored. It returns 0 for empty, malformed, non-positive,
// non-finite or out-of-range input.
func parseDurationString(raw string) time.Duration {
	s := strings.TrimSpace(raw)
	if s == "" {
		return 0
	}
	if dur, err := time.ParseDuration(s); err == nil && dur > 0 {
		return dur
	}
	if secs, err := strconv.ParseFloat(s, 64); err == nil {
		return durationFromFloatSeconds(secs)
	}
	return 0
}

// durationFromIntSeconds converts a signed number of seconds to a duration,
// returning 0 when secs is not positive or the result would overflow.
func durationFromIntSeconds(secs int64) time.Duration {
	const maxSecs = int64(1<<63-1) / int64(time.Second)
	if secs <= 0 || secs > maxSecs {
		return 0
	}
	return time.Duration(secs) * time.Second
}

// durationFromUintSeconds converts an unsigned number of seconds to a duration,
// returning 0 when secs is zero or the result would overflow.
func durationFromUintSeconds(secs uint64) time.Duration {
	const maxSecs = uint64(1<<63-1) / uint64(time.Second)
	if secs == 0 || secs > maxSecs {
		return 0
	}
	return time.Duration(secs) * time.Second
}

// durationFromFloatSeconds converts a fractional number of seconds to a
// duration, returning 0 when secs is not positive, is NaN or infinite, or the
// result would not fit into a time.Duration.
func durationFromFloatSeconds(secs float64) time.Duration {
	// limit is the first float64 value beyond the int64 range; converting such a
	// value to an integer has an implementation-specific result in Go.
	const limit = float64(1 << 63)
	ns := secs * float64(time.Second)
	// The negated comparison also rejects NaN.
	if !(ns > 0) || ns >= limit {
		return 0
	}
	return time.Duration(ns)
}
// authLastRefreshTimestamp looks for a last-refresh time recorded on the auth,
// first in Metadata and then in Attributes, under the same set of key spellings.
// Attribute values are trimmed and skipped when blank. The boolean is false when
// a is nil or no parseable timestamp is found.
func authLastRefreshTimestamp(a *Auth) (time.Time, bool) {
	if a == nil {
		return time.Time{}, false
	}
	keys := []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"}

	if a.Metadata != nil {
		if ts, ok := lookupMetadataTime(a.Metadata, keys...); ok {
			return ts, true
		}
	}
	if a.Attributes != nil {
		for _, key := range keys {
			val := strings.TrimSpace(a.Attributes[key])
			if val == "" {
				continue
			}
			if ts, ok := parseTimeValue(val); ok {
				return ts, true
			}
		}
	}
	return time.Time{}, false
}
// lookupMetadataTime returns the first value in meta, among keys in the order
// given, that parseTimeValue accepts as a timestamp. The boolean is false when
// no key yields a parseable value.
func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) {
	for _, key := range keys {
		raw, present := meta[key]
		if !present {
			continue
		}
		if ts, parsed := parseTimeValue(raw); parsed {
			return ts, true
		}
	}
	return time.Time{}, false
}
// markRefreshPending claims the auth with the given id for a refresh by pushing
// its NextRefreshAfter forward by refreshPendingBackoff. It returns false, and
// changes nothing, when the auth is missing, disabled, or still inside an
// earlier backoff window. The check and the update happen under the write lock.
func (m *Manager) markRefreshPending(id string, now time.Time) bool {
	m.mu.Lock()
	defer m.mu.Unlock()

	// A missing key yields a nil entry, so one check covers both cases.
	entry := m.auths[id]
	if entry == nil || entry.Disabled {
		return false
	}
	if next := entry.NextRefreshAfter; !next.IsZero() && now.Before(next) {
		return false
	}
	entry.NextRefreshAfter = now.Add(refreshPendingBackoff)
	m.auths[id] = entry
	return true
}
// refreshAuth refreshes the credentials of the auth with the given id through
// the executor registered for its provider. It does nothing when the auth or the
// executor is missing.
//
// The executor receives a clone of the stored auth. The clone is taken while the
// read lock is held, because other methods mutate the stored entry under the
// write lock and copying it outside the lock would race with them.
//
// On failure the stored entry gets a refreshFailureBackoff and the error is
// recorded as LastError. On success the refreshed auth (or the clone, when the
// executor returns nil) inherits the stored Runtime value, has its refresh
// bookkeeping reset, and replaces the stored entry through Update.
func (m *Manager) refreshAuth(ctx context.Context, id string) {
	m.mu.RLock()
	auth := m.auths[id]
	var (
		exec   ProviderExecutor
		cloned *Auth
	)
	if auth != nil {
		exec = m.executors[auth.Provider]
		if exec != nil {
			cloned = auth.Clone()
		}
	}
	m.mu.RUnlock()
	if auth == nil || exec == nil {
		return
	}

	updated, err := exec.Refresh(ctx, cloned)
	now := time.Now()
	if err != nil {
		m.mu.Lock()
		// Re-read the entry: it may have been replaced or removed meanwhile.
		if current := m.auths[id]; current != nil {
			current.NextRefreshAfter = now.Add(refreshFailureBackoff)
			current.LastError = &Error{Message: err.Error()}
			m.auths[id] = current
		}
		m.mu.Unlock()
		return
	}
	if updated == nil {
		updated = cloned
	}
	updated.Runtime = auth.Runtime
	updated.LastRefreshedAt = now
	updated.NextRefreshAfter = time.Time{}
	updated.LastError = nil
	updated.UpdatedAt = now
	// Update reports no error in practice; its result is not needed here.
	_, _ = m.Update(ctx, updated)
}
// executorFor returns the executor registered for provider, or nil when none is
// registered. The lookup is done under the read lock.
func (m *Manager) executorFor(provider string) ProviderExecutor {
	m.mu.RLock()
	exec := m.executors[provider]
	m.mu.RUnlock()
	return exec
}
// roundTripperContextKey is an unexported context key type to avoid collisions.
// The manager stores the per-auth http.RoundTripper under a value of this type
// in the context it passes to executors.
type roundTripperContextKey struct{}
// roundTripperFor asks the registered RoundTripperProvider for the transport to
// use with auth. It returns nil when auth is nil or no provider is registered.
// The provider is called after the manager lock has been released.
func (m *Manager) roundTripperFor(auth *Auth) http.RoundTripper {
	if auth == nil {
		return nil
	}
	m.mu.RLock()
	provider := m.rtProvider
	m.mu.RUnlock()
	if provider == nil {
		return nil
	}
	return provider.RoundTripperFor(auth)
}
// RoundTripperProvider defines a minimal provider of per-auth HTTP transports.
// It is installed with Manager.SetRoundTripperProvider and consulted before each
// execution attempt.
type RoundTripperProvider interface {
// RoundTripperFor returns the transport to use for auth. A nil result makes the
// manager leave the execution context unchanged.
RoundTripperFor(auth *Auth) http.RoundTripper
}
// RequestPreparer is an optional interface that provider executors can implement
// to mutate outbound HTTP requests with provider credentials.
// Manager.InjectCredentials checks for it with a type assertion on the executor.
type RequestPreparer interface {
// PrepareRequest modifies req using the credentials in auth and returns an
// error when it cannot do so.
PrepareRequest(req *http.Request, auth *Auth) error
}
// InjectCredentials delegates per-provider HTTP request preparation when supported.
// If the registered executor for the auth provider implements RequestPreparer,
// it will be invoked to modify the request (e.g., add headers).
//
// The call is a no-op returning nil when req is nil, authID is empty, the auth
// is unknown, no executor is registered for its provider, or the executor does
// not implement RequestPreparer. Otherwise the executor's error is returned.
//
// The executor receives a clone of the stored auth, taken while the read lock is
// held. This matches how auths are handed to executors for execution and keeps
// the executor from reading the shared entry while other methods mutate it
// under the write lock.
func (m *Manager) InjectCredentials(req *http.Request, authID string) error {
	if req == nil || authID == "" {
		return nil
	}

	var (
		snapshot *Auth
		exec     ProviderExecutor
	)
	m.mu.RLock()
	if a := m.auths[authID]; a != nil {
		exec = m.executors[a.Provider]
		if exec != nil {
			snapshot = a.Clone()
		}
	}
	m.mu.RUnlock()
	if snapshot == nil || exec == nil {
		return nil
	}

	if p, ok := exec.(RequestPreparer); ok && p != nil {
		return p.PrepareRequest(req, snapshot)
	}
	return nil
}