mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
v6 version first commit
This commit is contained in:
32
sdk/cliproxy/auth/errors.go
Normal file
32
sdk/cliproxy/auth/errors.go
Normal file
@@ -0,0 +1,32 @@
|
||||
package auth
|
||||
|
||||
// Error describes an authentication related failure in a provider agnostic format.
|
||||
type Error struct {
|
||||
// Code is a short machine readable identifier.
|
||||
Code string `json:"code,omitempty"`
|
||||
// Message is a human readable description of the failure.
|
||||
Message string `json:"message"`
|
||||
// Retryable indicates whether a retry might fix the issue automatically.
|
||||
Retryable bool `json:"retryable"`
|
||||
// HTTPStatus optionally records an HTTP-like status code for the error.
|
||||
HTTPStatus int `json:"http_status,omitempty"`
|
||||
}
|
||||
|
||||
// Error implements the error interface.
|
||||
func (e *Error) Error() string {
|
||||
if e == nil {
|
||||
return ""
|
||||
}
|
||||
if e.Code == "" {
|
||||
return e.Message
|
||||
}
|
||||
return e.Code + ": " + e.Message
|
||||
}
|
||||
|
||||
// StatusCode implements optional status accessor for manager decision making.
|
||||
func (e *Error) StatusCode() int {
|
||||
if e == nil {
|
||||
return 0
|
||||
}
|
||||
return e.HTTPStatus
|
||||
}
|
||||
247
sdk/cliproxy/auth/filestore.go
Normal file
247
sdk/cliproxy/auth/filestore.go
Normal file
@@ -0,0 +1,247 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
)
|
||||
|
||||
// FileStore implements Store backed by JSON files in a directory.
|
||||
type FileStore struct {
|
||||
dir string
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
// NewFileStore builds a file-backed store rooted at dir.
|
||||
func NewFileStore(dir string) *FileStore {
|
||||
return &FileStore{dir: dir}
|
||||
}
|
||||
|
||||
// List enumerates all auth JSON files under the store directory.
|
||||
func (s *FileStore) List(ctx context.Context) ([]*Auth, error) {
|
||||
if s.dir == "" {
|
||||
return nil, fmt.Errorf("auth filestore: directory not configured")
|
||||
}
|
||||
entries := make([]*Auth, 0)
|
||||
err := filepath.WalkDir(s.dir, func(path string, d fs.DirEntry, err error) error {
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if d.IsDir() {
|
||||
return nil
|
||||
}
|
||||
if !strings.HasSuffix(strings.ToLower(d.Name()), ".json") {
|
||||
return nil
|
||||
}
|
||||
auth, err := s.readFile(path)
|
||||
if err != nil {
|
||||
// Record error but keep scanning to surface remaining auths.
|
||||
return nil
|
||||
}
|
||||
if auth != nil {
|
||||
entries = append(entries, auth)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return entries, nil
|
||||
}
|
||||
|
||||
// Save writes the auth metadata back to its source file location.
|
||||
func (s *FileStore) Save(ctx context.Context, auth *Auth) error {
|
||||
if auth == nil {
|
||||
return fmt.Errorf("auth filestore: auth is nil")
|
||||
}
|
||||
path := s.resolvePath(auth)
|
||||
if path == "" {
|
||||
return fmt.Errorf("auth filestore: missing file path attribute for %s", auth.ID)
|
||||
}
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
if err := os.MkdirAll(filepath.Dir(path), 0o700); err != nil {
|
||||
return fmt.Errorf("auth filestore: create dir failed: %w", err)
|
||||
}
|
||||
raw, err := json.Marshal(auth.Metadata)
|
||||
if err != nil {
|
||||
return fmt.Errorf("auth filestore: marshal metadata failed: %w", err)
|
||||
}
|
||||
if existing, err := os.ReadFile(path); err == nil {
|
||||
if jsonEqual(existing, raw) {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
tmp := path + ".tmp"
|
||||
if err = os.WriteFile(tmp, raw, 0o600); err != nil {
|
||||
return fmt.Errorf("auth filestore: write temp failed: %w", err)
|
||||
}
|
||||
if err = os.Rename(tmp, path); err != nil {
|
||||
return fmt.Errorf("auth filestore: rename failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func jsonEqual(a, b []byte) bool {
|
||||
var objA any
|
||||
var objB any
|
||||
if err := json.Unmarshal(a, &objA); err != nil {
|
||||
return false
|
||||
}
|
||||
if err := json.Unmarshal(b, &objB); err != nil {
|
||||
return false
|
||||
}
|
||||
return deepEqualJSON(objA, objB)
|
||||
}
|
||||
|
||||
func deepEqualJSON(a, b any) bool {
|
||||
switch valA := a.(type) {
|
||||
case map[string]any:
|
||||
valB, ok := b.(map[string]any)
|
||||
if !ok || len(valA) != len(valB) {
|
||||
return false
|
||||
}
|
||||
for key, subA := range valA {
|
||||
subB, ok := valB[key]
|
||||
if !ok || !deepEqualJSON(subA, subB) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case []any:
|
||||
sliceB, ok := b.([]any)
|
||||
if !ok || len(valA) != len(sliceB) {
|
||||
return false
|
||||
}
|
||||
for i := range valA {
|
||||
if !deepEqualJSON(valA[i], sliceB[i]) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
case float64:
|
||||
valB, ok := b.(float64)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return valA == valB
|
||||
case string:
|
||||
valB, ok := b.(string)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return valA == valB
|
||||
case bool:
|
||||
valB, ok := b.(bool)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return valA == valB
|
||||
case nil:
|
||||
return b == nil
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// Delete removes the auth file.
|
||||
func (s *FileStore) Delete(ctx context.Context, id string) error {
|
||||
if id == "" {
|
||||
return fmt.Errorf("auth filestore: id is empty")
|
||||
}
|
||||
path := filepath.Join(s.dir, id)
|
||||
if strings.ContainsRune(id, os.PathSeparator) {
|
||||
path = id
|
||||
}
|
||||
if err := os.Remove(path); err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("auth filestore: delete failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *FileStore) readFile(path string) (*Auth, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("read file: %w", err)
|
||||
}
|
||||
if len(data) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
metadata := make(map[string]any)
|
||||
if err = json.Unmarshal(data, &metadata); err != nil {
|
||||
return nil, fmt.Errorf("unmarshal auth json: %w", err)
|
||||
}
|
||||
provider, _ := metadata["type"].(string)
|
||||
if provider == "" {
|
||||
provider = "unknown"
|
||||
}
|
||||
info, err := os.Stat(path)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("stat file: %w", err)
|
||||
}
|
||||
id := s.idFor(path)
|
||||
auth := &Auth{
|
||||
ID: id,
|
||||
Provider: provider,
|
||||
Label: s.labelFor(metadata),
|
||||
Status: StatusActive,
|
||||
Attributes: map[string]string{"path": path},
|
||||
Metadata: metadata,
|
||||
CreatedAt: info.ModTime(),
|
||||
UpdatedAt: info.ModTime(),
|
||||
LastRefreshedAt: time.Time{},
|
||||
NextRefreshAfter: time.Time{},
|
||||
}
|
||||
if email, ok := metadata["email"].(string); ok && email != "" {
|
||||
auth.Attributes["email"] = email
|
||||
}
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (s *FileStore) idFor(path string) string {
|
||||
rel, err := filepath.Rel(s.dir, path)
|
||||
if err != nil {
|
||||
return path
|
||||
}
|
||||
return rel
|
||||
}
|
||||
|
||||
func (s *FileStore) resolvePath(auth *Auth) string {
|
||||
if auth == nil {
|
||||
return ""
|
||||
}
|
||||
if auth.Attributes != nil {
|
||||
if p := auth.Attributes["path"]; p != "" {
|
||||
return p
|
||||
}
|
||||
}
|
||||
if filepath.IsAbs(auth.ID) {
|
||||
return auth.ID
|
||||
}
|
||||
if auth.ID == "" {
|
||||
return ""
|
||||
}
|
||||
return filepath.Join(s.dir, auth.ID)
|
||||
}
|
||||
|
||||
func (s *FileStore) labelFor(metadata map[string]any) string {
|
||||
if metadata == nil {
|
||||
return ""
|
||||
}
|
||||
if v, ok := metadata["label"].(string); ok && v != "" {
|
||||
return v
|
||||
}
|
||||
if v, ok := metadata["email"].(string); ok && v != "" {
|
||||
return v
|
||||
}
|
||||
if project, ok := metadata["project_id"].(string); ok && project != "" {
|
||||
return project
|
||||
}
|
||||
return ""
|
||||
}
|
||||
908
sdk/cliproxy/auth/manager.go
Normal file
908
sdk/cliproxy/auth/manager.go
Normal file
@@ -0,0 +1,908 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ProviderExecutor defines the contract required by Manager to execute provider calls.
|
||||
type ProviderExecutor interface {
|
||||
// Identifier returns the provider key handled by this executor.
|
||||
Identifier() string
|
||||
// Execute handles non-streaming execution and returns the provider response payload.
|
||||
Execute(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error)
|
||||
// ExecuteStream handles streaming execution and returns a channel of provider chunks.
|
||||
ExecuteStream(ctx context.Context, auth *Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error)
|
||||
// Refresh attempts to refresh provider credentials and returns the updated auth state.
|
||||
Refresh(ctx context.Context, auth *Auth) (*Auth, error)
|
||||
}
|
||||
|
||||
// RefreshEvaluator allows runtime state to override refresh decisions.
|
||||
type RefreshEvaluator interface {
|
||||
ShouldRefresh(now time.Time, auth *Auth) bool
|
||||
}
|
||||
|
||||
const (
|
||||
refreshCheckInterval = 5 * time.Second
|
||||
refreshPendingBackoff = time.Minute
|
||||
refreshFailureBackoff = 5 * time.Minute
|
||||
)
|
||||
|
||||
// Result captures execution outcome used to adjust auth state.
|
||||
type Result struct {
|
||||
// AuthID references the auth that produced this result.
|
||||
AuthID string
|
||||
// Provider is copied for convenience when emitting hooks.
|
||||
Provider string
|
||||
// Model is the upstream model identifier used for the request.
|
||||
Model string
|
||||
// Success marks whether the execution succeeded.
|
||||
Success bool
|
||||
// Error describes the failure when Success is false.
|
||||
Error *Error
|
||||
}
|
||||
|
||||
// Selector chooses an auth candidate for execution.
|
||||
type Selector interface {
|
||||
Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error)
|
||||
}
|
||||
|
||||
// Hook captures lifecycle callbacks for observing auth changes.
|
||||
type Hook interface {
|
||||
// OnAuthRegistered fires when a new auth is registered.
|
||||
OnAuthRegistered(ctx context.Context, auth *Auth)
|
||||
// OnAuthUpdated fires when an existing auth changes state.
|
||||
OnAuthUpdated(ctx context.Context, auth *Auth)
|
||||
// OnResult fires when execution result is recorded.
|
||||
OnResult(ctx context.Context, result Result)
|
||||
}
|
||||
|
||||
// NoopHook provides optional hook defaults.
|
||||
type NoopHook struct{}
|
||||
|
||||
// OnAuthRegistered implements Hook.
|
||||
func (NoopHook) OnAuthRegistered(context.Context, *Auth) {}
|
||||
|
||||
// OnAuthUpdated implements Hook.
|
||||
func (NoopHook) OnAuthUpdated(context.Context, *Auth) {}
|
||||
|
||||
// OnResult implements Hook.
|
||||
func (NoopHook) OnResult(context.Context, Result) {}
|
||||
|
||||
// Manager orchestrates auth lifecycle, selection, execution, and persistence.
|
||||
type Manager struct {
|
||||
store Store
|
||||
executors map[string]ProviderExecutor
|
||||
selector Selector
|
||||
hook Hook
|
||||
mu sync.RWMutex
|
||||
auths map[string]*Auth
|
||||
// providerOffsets tracks per-model provider rotation state for multi-provider routing.
|
||||
providerOffsets map[string]int
|
||||
|
||||
// Optional HTTP RoundTripper provider injected by host.
|
||||
rtProvider RoundTripperProvider
|
||||
|
||||
// Auto refresh state
|
||||
refreshCancel context.CancelFunc
|
||||
}
|
||||
|
||||
// NewManager constructs a manager with optional custom selector and hook.
|
||||
func NewManager(store Store, selector Selector, hook Hook) *Manager {
|
||||
if selector == nil {
|
||||
selector = &RoundRobinSelector{}
|
||||
}
|
||||
if hook == nil {
|
||||
hook = NoopHook{}
|
||||
}
|
||||
return &Manager{
|
||||
store: store,
|
||||
executors: make(map[string]ProviderExecutor),
|
||||
selector: selector,
|
||||
hook: hook,
|
||||
auths: make(map[string]*Auth),
|
||||
providerOffsets: make(map[string]int),
|
||||
}
|
||||
}
|
||||
|
||||
// SetStore swaps the underlying persistence store.
|
||||
func (m *Manager) SetStore(store Store) {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.store = store
|
||||
}
|
||||
|
||||
// SetRoundTripperProvider register a provider that returns a per-auth RoundTripper.
|
||||
func (m *Manager) SetRoundTripperProvider(p RoundTripperProvider) {
|
||||
m.mu.Lock()
|
||||
m.rtProvider = p
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// RegisterExecutor registers a provider executor with the manager.
|
||||
func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
|
||||
if executor == nil {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.executors[executor.Identifier()] = executor
|
||||
}
|
||||
|
||||
// Register inserts a new auth entry into the manager.
|
||||
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
if auth == nil {
|
||||
return nil, nil
|
||||
}
|
||||
if auth.ID == "" {
|
||||
auth.ID = uuid.NewString()
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
m.mu.Unlock()
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthRegistered(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
}
|
||||
|
||||
// Update replaces an existing auth entry and notifies hooks.
|
||||
func (m *Manager) Update(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
if auth == nil || auth.ID == "" {
|
||||
return nil, nil
|
||||
}
|
||||
m.mu.Lock()
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
m.mu.Unlock()
|
||||
_ = m.persist(ctx, auth)
|
||||
m.hook.OnAuthUpdated(ctx, auth.Clone())
|
||||
return auth.Clone(), nil
|
||||
}
|
||||
|
||||
// Load resets manager state from the backing store.
|
||||
func (m *Manager) Load(ctx context.Context) error {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.store == nil {
|
||||
return nil
|
||||
}
|
||||
items, err := m.store.List(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
m.auths = make(map[string]*Auth, len(items))
|
||||
for _, auth := range items {
|
||||
if auth == nil || auth.ID == "" {
|
||||
continue
|
||||
}
|
||||
m.auths[auth.ID] = auth.Clone()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Execute performs a non-streaming execution using the configured selector and executor.
|
||||
// It supports multiple providers for the same model and round-robins the starting provider per model.
|
||||
func (m *Manager) Execute(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
normalized := m.normalizeProviders(providers)
|
||||
if len(normalized) == 0 {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
defer m.advanceProviderCursor(req.Model, normalized)
|
||||
|
||||
var lastErr error
|
||||
for _, provider := range rotated {
|
||||
resp, errExec := m.executeWithProvider(ctx, provider, req, opts)
|
||||
if errExec == nil {
|
||||
return resp, nil
|
||||
}
|
||||
lastErr = errExec
|
||||
}
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
// ExecuteStream performs a streaming execution using the configured selector and executor.
|
||||
// It supports multiple providers for the same model and round-robins the starting provider per model.
|
||||
func (m *Manager) ExecuteStream(ctx context.Context, providers []string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
normalized := m.normalizeProviders(providers)
|
||||
if len(normalized) == 0 {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "no provider supplied"}
|
||||
}
|
||||
rotated := m.rotateProviders(req.Model, normalized)
|
||||
defer m.advanceProviderCursor(req.Model, normalized)
|
||||
|
||||
var lastErr error
|
||||
for _, provider := range rotated {
|
||||
chunks, errStream := m.executeStreamWithProvider(ctx, provider, req, opts)
|
||||
if errStream == nil {
|
||||
return chunks, nil
|
||||
}
|
||||
lastErr = errStream
|
||||
}
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
|
||||
func (m *Manager) executeWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
if provider == "" {
|
||||
return cliproxyexecutor.Response{}, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return cliproxyexecutor.Response{}, lastErr
|
||||
}
|
||||
return cliproxyexecutor.Response{}, errPick
|
||||
}
|
||||
|
||||
if isAPIKey, info := auth.AccountInfo(); isAPIKey {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(info), req.Model)
|
||||
} else {
|
||||
log.Debugf("Use OAuth %s for model %s", info, req.Model)
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
resp, errExec := executor.Execute(execCtx, auth, req, opts)
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: errExec == nil}
|
||||
if errExec != nil {
|
||||
result.Error = &Error{Message: errExec.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errExec, &se) && se != nil {
|
||||
result.Error.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errExec
|
||||
continue
|
||||
}
|
||||
m.MarkResult(execCtx, result)
|
||||
return resp, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) executeStreamWithProvider(ctx context.Context, provider string, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
if provider == "" {
|
||||
return nil, &Error{Code: "provider_not_found", Message: "provider identifier is empty"}
|
||||
}
|
||||
tried := make(map[string]struct{})
|
||||
var lastErr error
|
||||
for {
|
||||
auth, executor, errPick := m.pickNext(ctx, provider, req.Model, opts, tried)
|
||||
if errPick != nil {
|
||||
if lastErr != nil {
|
||||
return nil, lastErr
|
||||
}
|
||||
return nil, errPick
|
||||
}
|
||||
|
||||
if isAPIKey, info := auth.AccountInfo(); isAPIKey {
|
||||
log.Debugf("Use API key %s for model %s", util.HideAPIKey(info), req.Model)
|
||||
} else {
|
||||
log.Debugf("Use OAuth %s for model %s", info, req.Model)
|
||||
}
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
if rt := m.roundTripperFor(auth); rt != nil {
|
||||
execCtx = context.WithValue(execCtx, roundTripperContextKey{}, rt)
|
||||
execCtx = context.WithValue(execCtx, "cliproxy.roundtripper", rt)
|
||||
}
|
||||
chunks, errStream := executor.ExecuteStream(execCtx, auth, req, opts)
|
||||
if errStream != nil {
|
||||
rerr := &Error{Message: errStream.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(errStream, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
result := Result{AuthID: auth.ID, Provider: provider, Model: req.Model, Success: false, Error: rerr}
|
||||
m.MarkResult(execCtx, result)
|
||||
lastErr = errStream
|
||||
continue
|
||||
}
|
||||
out := make(chan cliproxyexecutor.StreamChunk)
|
||||
go func(streamCtx context.Context, streamAuth *Auth, streamProvider string, streamChunks <-chan cliproxyexecutor.StreamChunk) {
|
||||
defer close(out)
|
||||
var failed bool
|
||||
for chunk := range streamChunks {
|
||||
if chunk.Err != nil && !failed {
|
||||
failed = true
|
||||
rerr := &Error{Message: chunk.Err.Error()}
|
||||
var se cliproxyexecutor.StatusError
|
||||
if errors.As(chunk.Err, &se) && se != nil {
|
||||
rerr.HTTPStatus = se.StatusCode()
|
||||
}
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: false, Error: rerr})
|
||||
}
|
||||
out <- chunk
|
||||
}
|
||||
if !failed {
|
||||
m.MarkResult(streamCtx, Result{AuthID: streamAuth.ID, Provider: streamProvider, Model: req.Model, Success: true})
|
||||
}
|
||||
}(execCtx, auth.Clone(), provider, chunks)
|
||||
return out, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) normalizeProviders(providers []string) []string {
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
}
|
||||
result := make([]string, 0, len(providers))
|
||||
seen := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
p := strings.TrimSpace(strings.ToLower(provider))
|
||||
if p == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[p]; ok {
|
||||
continue
|
||||
}
|
||||
seen[p] = struct{}{}
|
||||
result = append(result, p)
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (m *Manager) rotateProviders(model string, providers []string) []string {
|
||||
if len(providers) == 0 {
|
||||
return nil
|
||||
}
|
||||
m.mu.RLock()
|
||||
offset := m.providerOffsets[model]
|
||||
m.mu.RUnlock()
|
||||
if len(providers) > 0 {
|
||||
offset %= len(providers)
|
||||
}
|
||||
if offset < 0 {
|
||||
offset = 0
|
||||
}
|
||||
if offset == 0 {
|
||||
return providers
|
||||
}
|
||||
rotated := make([]string, 0, len(providers))
|
||||
rotated = append(rotated, providers[offset:]...)
|
||||
rotated = append(rotated, providers[:offset]...)
|
||||
return rotated
|
||||
}
|
||||
|
||||
func (m *Manager) advanceProviderCursor(model string, providers []string) {
|
||||
if len(providers) == 0 {
|
||||
m.mu.Lock()
|
||||
delete(m.providerOffsets, model)
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
current := m.providerOffsets[model]
|
||||
m.providerOffsets[model] = (current + 1) % len(providers)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// MarkResult records an execution result and notifies hooks.
|
||||
func (m *Manager) MarkResult(ctx context.Context, result Result) {
|
||||
if result.AuthID == "" {
|
||||
return
|
||||
}
|
||||
// Update in-memory auth status based on result.
|
||||
m.mu.Lock()
|
||||
if auth, ok := m.auths[result.AuthID]; ok && auth != nil {
|
||||
now := time.Now()
|
||||
if result.Success {
|
||||
// Clear transient error/quota flags on success.
|
||||
auth.Unavailable = false
|
||||
auth.Status = StatusActive
|
||||
auth.StatusMessage = ""
|
||||
auth.Quota.Exceeded = false
|
||||
auth.Quota.Reason = ""
|
||||
auth.Quota.NextRecoverAt = time.Time{}
|
||||
auth.LastError = nil
|
||||
auth.UpdatedAt = now
|
||||
if result.Model != "" {
|
||||
registry.GetGlobalRegistry().ClearModelQuotaExceeded(auth.ID, result.Model)
|
||||
}
|
||||
} else {
|
||||
// Default transient error state.
|
||||
auth.Unavailable = true
|
||||
auth.Status = StatusError
|
||||
auth.UpdatedAt = now
|
||||
if result.Error != nil {
|
||||
auth.LastError = &Error{Code: result.Error.Code, Message: result.Error.Message, Retryable: result.Error.Retryable}
|
||||
}
|
||||
// If the error carries a status code, adjust backoff/quota accordingly.
|
||||
// 401 -> auth issue; 402/429 -> quota; 5xx -> transient.
|
||||
var statusCode int
|
||||
if se, isOk := any(result.Error).(interface{ StatusCode() int }); isOk && se != nil {
|
||||
statusCode = se.StatusCode()
|
||||
}
|
||||
switch statusCode {
|
||||
case 401:
|
||||
auth.StatusMessage = "unauthorized"
|
||||
auth.NextRefreshAfter = now.Add(5 * time.Minute)
|
||||
case 402, 429:
|
||||
auth.StatusMessage = "quota exhausted"
|
||||
auth.Quota.Exceeded = true
|
||||
auth.Quota.Reason = "quota"
|
||||
auth.Quota.NextRecoverAt = now.Add(10 * time.Minute)
|
||||
auth.NextRefreshAfter = auth.Quota.NextRecoverAt
|
||||
if result.Model != "" {
|
||||
registry.GetGlobalRegistry().SetModelQuotaExceeded(auth.ID, result.Model)
|
||||
}
|
||||
case 403, 408, 500, 502, 503, 504:
|
||||
auth.StatusMessage = "transient upstream error"
|
||||
auth.NextRefreshAfter = now.Add(1 * time.Minute)
|
||||
default:
|
||||
// keep generic
|
||||
if auth.StatusMessage == "" {
|
||||
auth.StatusMessage = "request failed"
|
||||
}
|
||||
}
|
||||
}
|
||||
// Persist best-effort (only metadata is stored for file store).
|
||||
_ = m.persist(ctx, auth)
|
||||
}
|
||||
m.mu.Unlock()
|
||||
|
||||
m.hook.OnResult(ctx, result)
|
||||
}
|
||||
|
||||
// List returns all auth entries currently known by the manager.
|
||||
func (m *Manager) List() []*Auth {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
list := make([]*Auth, 0, len(m.auths))
|
||||
for _, auth := range m.auths {
|
||||
list = append(list, auth.Clone())
|
||||
}
|
||||
return list
|
||||
}
|
||||
|
||||
// GetByID retrieves an auth entry by its ID.
|
||||
|
||||
func (m *Manager) GetByID(id string) (*Auth, bool) {
|
||||
if id == "" {
|
||||
return nil, false
|
||||
}
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
auth, ok := m.auths[id]
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return auth.Clone(), true
|
||||
}
|
||||
|
||||
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||
m.mu.RLock()
|
||||
executor, okExecutor := m.executors[provider]
|
||||
if !okExecutor {
|
||||
m.mu.RUnlock()
|
||||
return nil, nil, &Error{Code: "executor_not_found", Message: "executor not registered"}
|
||||
}
|
||||
candidates := make([]*Auth, 0, len(m.auths))
|
||||
for _, auth := range m.auths {
|
||||
if auth.Provider != provider || auth.Disabled {
|
||||
continue
|
||||
}
|
||||
if _, used := tried[auth.ID]; used {
|
||||
continue
|
||||
}
|
||||
candidates = append(candidates, auth.Clone())
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
if len(candidates) == 0 {
|
||||
return nil, nil, &Error{Code: "auth_not_found", Message: "no auth available"}
|
||||
}
|
||||
auth, errPick := m.selector.Pick(ctx, provider, model, opts, candidates)
|
||||
if errPick != nil {
|
||||
return nil, nil, errPick
|
||||
}
|
||||
if auth == nil {
|
||||
return nil, nil, &Error{Code: "auth_not_found", Message: "selector returned no auth"}
|
||||
}
|
||||
return auth, executor, nil
|
||||
}
|
||||
|
||||
func (m *Manager) persist(ctx context.Context, auth *Auth) error {
|
||||
if m.store == nil || auth == nil {
|
||||
return nil
|
||||
}
|
||||
// Skip persistence when metadata is absent (e.g., runtime-only auths).
|
||||
if auth.Metadata == nil {
|
||||
return nil
|
||||
}
|
||||
return m.store.Save(ctx, auth)
|
||||
}
|
||||
|
||||
// StartAutoRefresh launches a background loop that evaluates auth freshness
|
||||
// every few seconds and triggers refresh operations when required.
|
||||
// Only one loop is kept alive; starting a new one cancels the previous run.
|
||||
func (m *Manager) StartAutoRefresh(parent context.Context, interval time.Duration) {
|
||||
if interval <= 0 || interval > refreshCheckInterval {
|
||||
interval = refreshCheckInterval
|
||||
} else {
|
||||
interval = refreshCheckInterval
|
||||
}
|
||||
if m.refreshCancel != nil {
|
||||
m.refreshCancel()
|
||||
m.refreshCancel = nil
|
||||
}
|
||||
ctx, cancel := context.WithCancel(parent)
|
||||
m.refreshCancel = cancel
|
||||
go func() {
|
||||
ticker := time.NewTicker(interval)
|
||||
defer ticker.Stop()
|
||||
m.checkRefreshes(ctx)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-ticker.C:
|
||||
m.checkRefreshes(ctx)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// StopAutoRefresh cancels the background refresh loop, if running.
|
||||
func (m *Manager) StopAutoRefresh() {
|
||||
if m.refreshCancel != nil {
|
||||
m.refreshCancel()
|
||||
m.refreshCancel = nil
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) checkRefreshes(ctx context.Context) {
|
||||
now := time.Now()
|
||||
snapshot := m.snapshotAuths()
|
||||
for _, a := range snapshot {
|
||||
if !m.shouldRefresh(a, now) {
|
||||
continue
|
||||
}
|
||||
if exec := m.executorFor(a.Provider); exec == nil {
|
||||
continue
|
||||
}
|
||||
if !m.markRefreshPending(a.ID, now) {
|
||||
continue
|
||||
}
|
||||
go m.refreshAuth(ctx, a.ID)
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) snapshotAuths() []*Auth {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
out := make([]*Auth, 0, len(m.auths))
|
||||
for _, a := range m.auths {
|
||||
out = append(out, a.Clone())
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
func (m *Manager) shouldRefresh(a *Auth, now time.Time) bool {
|
||||
if a == nil || a.Disabled {
|
||||
return false
|
||||
}
|
||||
if !a.NextRefreshAfter.IsZero() && now.Before(a.NextRefreshAfter) {
|
||||
return false
|
||||
}
|
||||
if evaluator, ok := a.Runtime.(RefreshEvaluator); ok && evaluator != nil {
|
||||
return evaluator.ShouldRefresh(now, a)
|
||||
}
|
||||
|
||||
lastRefresh := a.LastRefreshedAt
|
||||
if lastRefresh.IsZero() {
|
||||
if ts, ok := authLastRefreshTimestamp(a); ok {
|
||||
lastRefresh = ts
|
||||
}
|
||||
}
|
||||
|
||||
expiry, hasExpiry := a.ExpirationTime()
|
||||
|
||||
if interval := authPreferredInterval(a); interval > 0 {
|
||||
if hasExpiry && !expiry.IsZero() {
|
||||
if !expiry.After(now) {
|
||||
return true
|
||||
}
|
||||
if expiry.Sub(now) <= interval {
|
||||
return true
|
||||
}
|
||||
}
|
||||
if lastRefresh.IsZero() {
|
||||
return true
|
||||
}
|
||||
return now.Sub(lastRefresh) >= interval
|
||||
}
|
||||
|
||||
provider := strings.ToLower(a.Provider)
|
||||
lead := ProviderRefreshLead(provider, a.Runtime)
|
||||
if lead <= 0 {
|
||||
if hasExpiry && !expiry.IsZero() {
|
||||
return now.After(expiry)
|
||||
}
|
||||
return false
|
||||
}
|
||||
if hasExpiry && !expiry.IsZero() {
|
||||
return time.Until(expiry) <= lead
|
||||
}
|
||||
if !lastRefresh.IsZero() {
|
||||
return now.Sub(lastRefresh) >= lead
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
func authPreferredInterval(a *Auth) time.Duration {
|
||||
if a == nil {
|
||||
return 0
|
||||
}
|
||||
if d := durationFromMetadata(a.Metadata, "refresh_interval_seconds", "refreshIntervalSeconds", "refresh_interval", "refreshInterval"); d > 0 {
|
||||
return d
|
||||
}
|
||||
if d := durationFromAttributes(a.Attributes, "refresh_interval_seconds", "refreshIntervalSeconds", "refresh_interval", "refreshInterval"); d > 0 {
|
||||
return d
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func durationFromMetadata(meta map[string]any, keys ...string) time.Duration {
|
||||
if len(meta) == 0 {
|
||||
return 0
|
||||
}
|
||||
for _, key := range keys {
|
||||
if val, ok := meta[key]; ok {
|
||||
if dur := parseDurationValue(val); dur > 0 {
|
||||
return dur
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func durationFromAttributes(attrs map[string]string, keys ...string) time.Duration {
|
||||
if len(attrs) == 0 {
|
||||
return 0
|
||||
}
|
||||
for _, key := range keys {
|
||||
if val, ok := attrs[key]; ok {
|
||||
if dur := parseDurationString(val); dur > 0 {
|
||||
return dur
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func parseDurationValue(val any) time.Duration {
|
||||
switch v := val.(type) {
|
||||
case time.Duration:
|
||||
if v <= 0 {
|
||||
return 0
|
||||
}
|
||||
return v
|
||||
case int:
|
||||
if v <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(v) * time.Second
|
||||
case int32:
|
||||
if v <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(v) * time.Second
|
||||
case int64:
|
||||
if v <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(v) * time.Second
|
||||
case uint:
|
||||
if v == 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(v) * time.Second
|
||||
case uint32:
|
||||
if v == 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(v) * time.Second
|
||||
case uint64:
|
||||
if v == 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(v) * time.Second
|
||||
case float32:
|
||||
if v <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(float64(v) * float64(time.Second))
|
||||
case float64:
|
||||
if v <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(v * float64(time.Second))
|
||||
case json.Number:
|
||||
if i, err := v.Int64(); err == nil {
|
||||
if i <= 0 {
|
||||
return 0
|
||||
}
|
||||
return time.Duration(i) * time.Second
|
||||
}
|
||||
if f, err := v.Float64(); err == nil && f > 0 {
|
||||
return time.Duration(f * float64(time.Second))
|
||||
}
|
||||
case string:
|
||||
return parseDurationString(v)
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func parseDurationString(raw string) time.Duration {
|
||||
s := strings.TrimSpace(raw)
|
||||
if s == "" {
|
||||
return 0
|
||||
}
|
||||
if dur, err := time.ParseDuration(s); err == nil && dur > 0 {
|
||||
return dur
|
||||
}
|
||||
if secs, err := strconv.ParseFloat(s, 64); err == nil && secs > 0 {
|
||||
return time.Duration(secs * float64(time.Second))
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func authLastRefreshTimestamp(a *Auth) (time.Time, bool) {
|
||||
if a == nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
if a.Metadata != nil {
|
||||
if ts, ok := lookupMetadataTime(a.Metadata, "last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"); ok {
|
||||
return ts, true
|
||||
}
|
||||
}
|
||||
if a.Attributes != nil {
|
||||
for _, key := range []string{"last_refresh", "lastRefresh", "last_refreshed_at", "lastRefreshedAt"} {
|
||||
if val := strings.TrimSpace(a.Attributes[key]); val != "" {
|
||||
if ts, ok := parseTimeValue(val); ok {
|
||||
return ts, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func lookupMetadataTime(meta map[string]any, keys ...string) (time.Time, bool) {
|
||||
for _, key := range keys {
|
||||
if val, ok := meta[key]; ok {
|
||||
if ts, ok := parseTimeValue(val); ok {
|
||||
return ts, true
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func (m *Manager) markRefreshPending(id string, now time.Time) bool {
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
auth, ok := m.auths[id]
|
||||
if !ok || auth == nil || auth.Disabled {
|
||||
return false
|
||||
}
|
||||
if !auth.NextRefreshAfter.IsZero() && now.Before(auth.NextRefreshAfter) {
|
||||
return false
|
||||
}
|
||||
auth.NextRefreshAfter = now.Add(refreshPendingBackoff)
|
||||
m.auths[id] = auth
|
||||
return true
|
||||
}
|
||||
|
||||
func (m *Manager) refreshAuth(ctx context.Context, id string) {
|
||||
m.mu.RLock()
|
||||
auth := m.auths[id]
|
||||
var exec ProviderExecutor
|
||||
if auth != nil {
|
||||
exec = m.executors[auth.Provider]
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
if auth == nil || exec == nil {
|
||||
return
|
||||
}
|
||||
cloned := auth.Clone()
|
||||
updated, err := exec.Refresh(ctx, cloned)
|
||||
now := time.Now()
|
||||
if err != nil {
|
||||
m.mu.Lock()
|
||||
if current := m.auths[id]; current != nil {
|
||||
current.NextRefreshAfter = now.Add(refreshFailureBackoff)
|
||||
current.LastError = &Error{Message: err.Error()}
|
||||
m.auths[id] = current
|
||||
}
|
||||
m.mu.Unlock()
|
||||
return
|
||||
}
|
||||
if updated == nil {
|
||||
updated = cloned
|
||||
}
|
||||
updated.Runtime = auth.Runtime
|
||||
updated.LastRefreshedAt = now
|
||||
updated.NextRefreshAfter = time.Time{}
|
||||
updated.LastError = nil
|
||||
updated.UpdatedAt = now
|
||||
_, _ = m.Update(ctx, updated)
|
||||
}
|
||||
|
||||
func (m *Manager) executorFor(provider string) ProviderExecutor {
|
||||
m.mu.RLock()
|
||||
defer m.mu.RUnlock()
|
||||
return m.executors[provider]
|
||||
}
|
||||
|
||||
// roundTripperContextKey is an unexported context key type to avoid collisions.
|
||||
type roundTripperContextKey struct{}
|
||||
|
||||
// roundTripperFor retrieves an HTTP RoundTripper for the given auth if a provider is registered.
|
||||
func (m *Manager) roundTripperFor(auth *Auth) http.RoundTripper {
|
||||
m.mu.RLock()
|
||||
p := m.rtProvider
|
||||
m.mu.RUnlock()
|
||||
if p == nil || auth == nil {
|
||||
return nil
|
||||
}
|
||||
return p.RoundTripperFor(auth)
|
||||
}
|
||||
|
||||
// RoundTripperProvider defines a minimal provider of per-auth HTTP transports.
|
||||
type RoundTripperProvider interface {
|
||||
RoundTripperFor(auth *Auth) http.RoundTripper
|
||||
}
|
||||
|
||||
// RequestPreparer is an optional interface that provider executors can implement
|
||||
// to mutate outbound HTTP requests with provider credentials.
|
||||
type RequestPreparer interface {
|
||||
PrepareRequest(req *http.Request, auth *Auth) error
|
||||
}
|
||||
|
||||
// InjectCredentials delegates per-provider HTTP request preparation when supported.
|
||||
// If the registered executor for the auth provider implements RequestPreparer,
|
||||
// it will be invoked to modify the request (e.g., add headers).
|
||||
func (m *Manager) InjectCredentials(req *http.Request, authID string) error {
|
||||
if req == nil || authID == "" {
|
||||
return nil
|
||||
}
|
||||
m.mu.RLock()
|
||||
a := m.auths[authID]
|
||||
var exec ProviderExecutor
|
||||
if a != nil {
|
||||
exec = m.executors[a.Provider]
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
if a == nil || exec == nil {
|
||||
return nil
|
||||
}
|
||||
if p, ok := exec.(RequestPreparer); ok && p != nil {
|
||||
return p.PrepareRequest(req, a)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
48
sdk/cliproxy/auth/selector.go
Normal file
48
sdk/cliproxy/auth/selector.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
// RoundRobinSelector provides a simple provider scoped round-robin selection strategy.
|
||||
type RoundRobinSelector struct {
|
||||
mu sync.Mutex
|
||||
cursors map[string]int
|
||||
}
|
||||
|
||||
// Pick selects the next available auth for the provider in a round-robin manner.
|
||||
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
_ = ctx
|
||||
_ = opts
|
||||
if len(auths) == 0 {
|
||||
return nil, &Error{Code: "auth_not_found", Message: "no auth candidates"}
|
||||
}
|
||||
if s.cursors == nil {
|
||||
s.cursors = make(map[string]int)
|
||||
}
|
||||
available := make([]*Auth, 0, len(auths))
|
||||
now := time.Now()
|
||||
for i := range auths {
|
||||
candidate := auths[i]
|
||||
if candidate.Unavailable && candidate.Quota.NextRecoverAt.After(now) {
|
||||
continue
|
||||
}
|
||||
if candidate.Status == StatusDisabled || candidate.Disabled {
|
||||
continue
|
||||
}
|
||||
available = append(available, candidate)
|
||||
}
|
||||
if len(available) == 0 {
|
||||
return nil, &Error{Code: "auth_unavailable", Message: "no auth available"}
|
||||
}
|
||||
key := provider + ":" + model
|
||||
s.mu.Lock()
|
||||
index := s.cursors[key]
|
||||
s.cursors[key] = (index + 1) % len(available)
|
||||
s.mu.Unlock()
|
||||
return available[index%len(available)], nil
|
||||
}
|
||||
19
sdk/cliproxy/auth/status.go
Normal file
19
sdk/cliproxy/auth/status.go
Normal file
@@ -0,0 +1,19 @@
|
||||
package auth
|
||||
|
||||
// Status represents the lifecycle state of an Auth entry.
|
||||
type Status string
|
||||
|
||||
const (
|
||||
// StatusUnknown means the auth state could not be determined.
|
||||
StatusUnknown Status = "unknown"
|
||||
// StatusActive indicates the auth is valid and ready for execution.
|
||||
StatusActive Status = "active"
|
||||
// StatusPending indicates the auth is waiting for an external action, such as MFA.
|
||||
StatusPending Status = "pending"
|
||||
// StatusRefreshing indicates the auth is undergoing a refresh flow.
|
||||
StatusRefreshing Status = "refreshing"
|
||||
// StatusError indicates the auth is temporarily unavailable due to errors.
|
||||
StatusError Status = "error"
|
||||
// StatusDisabled marks the auth as intentionally disabled.
|
||||
StatusDisabled Status = "disabled"
|
||||
)
|
||||
13
sdk/cliproxy/auth/store.go
Normal file
13
sdk/cliproxy/auth/store.go
Normal file
@@ -0,0 +1,13 @@
|
||||
package auth
|
||||
|
||||
import "context"
|
||||
|
||||
// Store abstracts persistence of Auth state across restarts.
|
||||
type Store interface {
|
||||
// List returns all auth records stored in the backend.
|
||||
List(ctx context.Context) ([]*Auth, error)
|
||||
// Save persists the provided auth record, replacing any existing one with same ID.
|
||||
Save(ctx context.Context, auth *Auth) error
|
||||
// Delete removes the auth record identified by id.
|
||||
Delete(ctx context.Context, id string) error
|
||||
}
|
||||
218
sdk/cliproxy/auth/types.go
Normal file
218
sdk/cliproxy/auth/types.go
Normal file
@@ -0,0 +1,218 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
clipauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
)
|
||||
|
||||
// Auth encapsulates the runtime state and metadata associated with a single credential.
|
||||
type Auth struct {
|
||||
// ID uniquely identifies the auth record across restarts.
|
||||
ID string `json:"id"`
|
||||
// Provider is the upstream provider key (e.g. "gemini", "claude").
|
||||
Provider string `json:"provider"`
|
||||
// Label is an optional human readable label for logging.
|
||||
Label string `json:"label,omitempty"`
|
||||
// Status is the lifecycle status managed by the AuthManager.
|
||||
Status Status `json:"status"`
|
||||
// StatusMessage holds a short description for the current status.
|
||||
StatusMessage string `json:"status_message,omitempty"`
|
||||
// Disabled indicates the auth is intentionally disabled by operator.
|
||||
Disabled bool `json:"disabled"`
|
||||
// Unavailable flags transient provider unavailability (e.g. quota exceeded).
|
||||
Unavailable bool `json:"unavailable"`
|
||||
// ProxyURL overrides the global proxy setting for this auth if provided.
|
||||
ProxyURL string `json:"proxy_url,omitempty"`
|
||||
// Attributes stores provider specific metadata needed by executors (immutable configuration).
|
||||
Attributes map[string]string `json:"attributes,omitempty"`
|
||||
// Metadata stores runtime mutable provider state (e.g. tokens, cookies).
|
||||
Metadata map[string]any `json:"metadata,omitempty"`
|
||||
// Quota captures recent quota information for load balancers.
|
||||
Quota QuotaState `json:"quota"`
|
||||
// LastError stores the last failure encountered while executing or refreshing.
|
||||
LastError *Error `json:"last_error,omitempty"`
|
||||
// CreatedAt is the creation timestamp in UTC.
|
||||
CreatedAt time.Time `json:"created_at"`
|
||||
// UpdatedAt is the last modification timestamp in UTC.
|
||||
UpdatedAt time.Time `json:"updated_at"`
|
||||
// LastRefreshedAt records the last successful refresh time in UTC.
|
||||
LastRefreshedAt time.Time `json:"last_refreshed_at"`
|
||||
// NextRefreshAfter is the earliest time a refresh should retrigger.
|
||||
NextRefreshAfter time.Time `json:"next_refresh_after"`
|
||||
|
||||
// Runtime carries non-serialisable data used during execution (in-memory only).
|
||||
Runtime any `json:"-"`
|
||||
}
|
||||
|
||||
// QuotaState contains limiter tracking data for a credential.
|
||||
type QuotaState struct {
|
||||
// Exceeded indicates the credential recently hit a quota error.
|
||||
Exceeded bool `json:"exceeded"`
|
||||
// Reason provides an optional provider specific human readable description.
|
||||
Reason string `json:"reason,omitempty"`
|
||||
// NextRecoverAt is when the credential may become available again.
|
||||
NextRecoverAt time.Time `json:"next_recover_at"`
|
||||
}
|
||||
|
||||
// Clone shallow copies the Auth structure, duplicating maps to avoid accidental mutation.
|
||||
func (a *Auth) Clone() *Auth {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
copyAuth := *a
|
||||
if len(a.Attributes) > 0 {
|
||||
copyAuth.Attributes = make(map[string]string, len(a.Attributes))
|
||||
for key, value := range a.Attributes {
|
||||
copyAuth.Attributes[key] = value
|
||||
}
|
||||
}
|
||||
if len(a.Metadata) > 0 {
|
||||
copyAuth.Metadata = make(map[string]any, len(a.Metadata))
|
||||
for key, value := range a.Metadata {
|
||||
copyAuth.Metadata[key] = value
|
||||
}
|
||||
}
|
||||
copyAuth.Runtime = a.Runtime
|
||||
return ©Auth
|
||||
}
|
||||
|
||||
func (a *Auth) AccountInfo() (bool, string) {
|
||||
if a == nil {
|
||||
return false, ""
|
||||
}
|
||||
if a.Metadata != nil {
|
||||
if v, ok := a.Metadata["email"].(string); ok {
|
||||
return false, v
|
||||
}
|
||||
} else if a.Attributes != nil {
|
||||
if v := a.Attributes["api_key"]; v != "" {
|
||||
return true, v
|
||||
}
|
||||
}
|
||||
return false, ""
|
||||
}
|
||||
|
||||
// ExpirationTime attempts to extract the credential expiration timestamp from metadata.
|
||||
// It inspects common keys such as "expired", "expire", "expires_at", and also
|
||||
// nested "token" objects to remain compatible with legacy auth file formats.
|
||||
func (a *Auth) ExpirationTime() (time.Time, bool) {
|
||||
if a == nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
if ts, ok := expirationFromMap(a.Metadata); ok {
|
||||
return ts, true
|
||||
}
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
var defaultAuthenticatorFactories = map[string]func() clipauth.Authenticator{
|
||||
"codex": func() clipauth.Authenticator { return clipauth.NewCodexAuthenticator() },
|
||||
"claude": func() clipauth.Authenticator { return clipauth.NewClaudeAuthenticator() },
|
||||
"qwen": func() clipauth.Authenticator { return clipauth.NewQwenAuthenticator() },
|
||||
"gemini": func() clipauth.Authenticator { return clipauth.NewGeminiAuthenticator() },
|
||||
"gemini-cli": func() clipauth.Authenticator { return clipauth.NewGeminiAuthenticator() },
|
||||
}
|
||||
|
||||
var expireKeys = [...]string{"expired", "expire", "expires_at", "expiresAt", "expiry", "expires"}
|
||||
|
||||
func expirationFromMap(meta map[string]any) (time.Time, bool) {
|
||||
if meta == nil {
|
||||
return time.Time{}, false
|
||||
}
|
||||
for _, key := range expireKeys {
|
||||
if v, ok := meta[key]; ok {
|
||||
if ts, ok := parseTimeValue(v); ok {
|
||||
return ts, true
|
||||
}
|
||||
}
|
||||
}
|
||||
for _, nestedKey := range []string{"token", "Token"} {
|
||||
if nested, ok := meta[nestedKey]; ok {
|
||||
switch val := nested.(type) {
|
||||
case map[string]any:
|
||||
if ts, ok := expirationFromMap(val); ok {
|
||||
return ts, true
|
||||
}
|
||||
case map[string]string:
|
||||
temp := make(map[string]any, len(val))
|
||||
for k, v := range val {
|
||||
temp[k] = v
|
||||
}
|
||||
if ts, ok := expirationFromMap(temp); ok {
|
||||
return ts, true
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func ProviderRefreshLead(provider string, runtime any) time.Duration {
|
||||
provider = strings.ToLower(provider)
|
||||
if runtime != nil {
|
||||
if eval, ok := runtime.(interface{ RefreshLead() *time.Duration }); ok {
|
||||
if lead := eval.RefreshLead(); lead != nil && *lead > 0 {
|
||||
return *lead
|
||||
}
|
||||
}
|
||||
}
|
||||
if factory, ok := defaultAuthenticatorFactories[provider]; ok {
|
||||
if auth := factory(); auth != nil {
|
||||
if lead := auth.RefreshLead(); lead != nil && *lead > 0 {
|
||||
return *lead
|
||||
}
|
||||
}
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func parseTimeValue(v any) (time.Time, bool) {
|
||||
switch value := v.(type) {
|
||||
case string:
|
||||
s := strings.TrimSpace(value)
|
||||
if s == "" {
|
||||
return time.Time{}, false
|
||||
}
|
||||
layouts := []string{
|
||||
time.RFC3339,
|
||||
time.RFC3339Nano,
|
||||
"2006-01-02 15:04:05",
|
||||
"2006-01-02T15:04:05Z07:00",
|
||||
}
|
||||
for _, layout := range layouts {
|
||||
if ts, err := time.Parse(layout, s); err == nil {
|
||||
return ts, true
|
||||
}
|
||||
}
|
||||
if unix, err := strconv.ParseInt(s, 10, 64); err == nil {
|
||||
return normaliseUnix(unix), true
|
||||
}
|
||||
case float64:
|
||||
return normaliseUnix(int64(value)), true
|
||||
case int64:
|
||||
return normaliseUnix(value), true
|
||||
case json.Number:
|
||||
if i, err := value.Int64(); err == nil {
|
||||
return normaliseUnix(i), true
|
||||
}
|
||||
if f, err := value.Float64(); err == nil {
|
||||
return normaliseUnix(int64(f)), true
|
||||
}
|
||||
}
|
||||
return time.Time{}, false
|
||||
}
|
||||
|
||||
func normaliseUnix(raw int64) time.Time {
|
||||
if raw <= 0 {
|
||||
return time.Time{}
|
||||
}
|
||||
// Heuristic: treat values with millisecond precision (>1e12) accordingly.
|
||||
if raw > 1_000_000_000_000 {
|
||||
return time.UnixMilli(raw)
|
||||
}
|
||||
return time.Unix(raw, 0)
|
||||
}
|
||||
Reference in New Issue
Block a user