mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-28 12:04:44 +08:00
Compare commits
35 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
c66cb0afd2 | ||
|
|
fb48eee973 | ||
|
|
bb44e5ec44 | ||
|
|
0659ffab75 | ||
|
|
7cb398d167 | ||
|
|
c3e12c5e58 | ||
|
|
1825fc7503 | ||
|
|
48732ba05e | ||
|
|
acf483c9e6 | ||
|
|
492b9c46f0 | ||
|
|
eb7571936c | ||
|
|
5382764d8a | ||
|
|
49c8ec69d0 | ||
|
|
713388dd7b | ||
|
|
e6c7af0fa9 | ||
|
|
d210be06c2 | ||
|
|
5936f9895c | ||
|
|
0cbfe7f457 | ||
|
|
b9ae4ab803 | ||
|
|
a45c6defa7 | ||
|
|
40bee3e8d9 | ||
|
|
93147dddeb | ||
|
|
c0f9b15a58 | ||
|
|
6f2fbdcbae | ||
|
|
65debb874f | ||
|
|
3caadac003 | ||
|
|
6a9e3a6b84 | ||
|
|
269972440a | ||
|
|
cce13e6ad2 | ||
|
|
8a565dcad8 | ||
|
|
d536110404 | ||
|
|
48e957ddff | ||
|
|
94563d622c | ||
|
|
ce0c6aa82b | ||
|
|
3c85d2a4d7 |
@@ -58,6 +58,7 @@ func main() {
|
||||
// Command-line flags to control the application's behavior.
|
||||
var login bool
|
||||
var codexLogin bool
|
||||
var codexDeviceLogin bool
|
||||
var claudeLogin bool
|
||||
var qwenLogin bool
|
||||
var iflowLogin bool
|
||||
@@ -76,6 +77,7 @@ func main() {
|
||||
// Define command-line flags for different operation modes.
|
||||
flag.BoolVar(&login, "login", false, "Login Google Account")
|
||||
flag.BoolVar(&codexLogin, "codex-login", false, "Login to Codex using OAuth")
|
||||
flag.BoolVar(&codexDeviceLogin, "codex-device-login", false, "Login to Codex using device code flow")
|
||||
flag.BoolVar(&claudeLogin, "claude-login", false, "Login to Claude using OAuth")
|
||||
flag.BoolVar(&qwenLogin, "qwen-login", false, "Login to Qwen using OAuth")
|
||||
flag.BoolVar(&iflowLogin, "iflow-login", false, "Login to iFlow using OAuth")
|
||||
@@ -467,6 +469,9 @@ func main() {
|
||||
} else if codexLogin {
|
||||
// Handle Codex login
|
||||
cmd.DoCodexLogin(cfg, options)
|
||||
} else if codexDeviceLogin {
|
||||
// Handle Codex device-code login
|
||||
cmd.DoCodexDeviceLogin(cfg, options)
|
||||
} else if claudeLogin {
|
||||
// Handle Claude login
|
||||
cmd.DoClaudeLogin(cfg, options)
|
||||
|
||||
@@ -945,11 +945,17 @@ func (h *Handler) saveTokenRecord(ctx context.Context, record *coreauth.Auth) (s
|
||||
if store == nil {
|
||||
return "", fmt.Errorf("token store unavailable")
|
||||
}
|
||||
if h.postAuthHook != nil {
|
||||
if err := h.postAuthHook(ctx, record); err != nil {
|
||||
return "", fmt.Errorf("post-auth hook failed: %w", err)
|
||||
}
|
||||
}
|
||||
return store.Save(ctx, record)
|
||||
}
|
||||
|
||||
func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Claude authentication...")
|
||||
|
||||
@@ -1094,6 +1100,7 @@ func (h *Handler) RequestAnthropicToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
proxyHTTPClient := util.SetProxy(&h.cfg.SDKConfig, &http.Client{})
|
||||
ctx = context.WithValue(ctx, oauth2.HTTPClient, proxyHTTPClient)
|
||||
|
||||
@@ -1352,6 +1359,7 @@ func (h *Handler) RequestGeminiCLIToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Codex authentication...")
|
||||
|
||||
@@ -1497,6 +1505,7 @@ func (h *Handler) RequestCodexToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Antigravity authentication...")
|
||||
|
||||
@@ -1661,6 +1670,7 @@ func (h *Handler) RequestAntigravityToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Qwen authentication...")
|
||||
|
||||
@@ -1716,6 +1726,7 @@ func (h *Handler) RequestQwenToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing Kimi authentication...")
|
||||
|
||||
@@ -1792,6 +1803,7 @@ func (h *Handler) RequestKimiToken(c *gin.Context) {
|
||||
|
||||
func (h *Handler) RequestIFlowToken(c *gin.Context) {
|
||||
ctx := context.Background()
|
||||
ctx = PopulateAuthContext(ctx, c)
|
||||
|
||||
fmt.Println("Initializing iFlow authentication...")
|
||||
|
||||
@@ -2412,3 +2424,12 @@ func (h *Handler) GetAuthStatus(c *gin.Context) {
|
||||
}
|
||||
c.JSON(http.StatusOK, gin.H{"status": "wait"})
|
||||
}
|
||||
|
||||
// PopulateAuthContext extracts request info and adds it to the context
|
||||
func PopulateAuthContext(ctx context.Context, c *gin.Context) context.Context {
|
||||
info := &coreauth.RequestInfo{
|
||||
Query: c.Request.URL.Query(),
|
||||
Headers: c.Request.Header,
|
||||
}
|
||||
return coreauth.WithRequestInfo(ctx, info)
|
||||
}
|
||||
|
||||
@@ -47,6 +47,7 @@ type Handler struct {
|
||||
allowRemoteOverride bool
|
||||
envSecret string
|
||||
logDir string
|
||||
postAuthHook coreauth.PostAuthHook
|
||||
}
|
||||
|
||||
// NewHandler creates a new management handler instance.
|
||||
@@ -128,6 +129,11 @@ func (h *Handler) SetLogDirectory(dir string) {
|
||||
h.logDir = dir
|
||||
}
|
||||
|
||||
// SetPostAuthHook registers a hook to be called after auth record creation but before persistence.
|
||||
func (h *Handler) SetPostAuthHook(hook coreauth.PostAuthHook) {
|
||||
h.postAuthHook = hook
|
||||
}
|
||||
|
||||
// Middleware enforces access control for management endpoints.
|
||||
// All requests (local and remote) require a valid management key.
|
||||
// Additionally, remote access requires allow-remote-management=true.
|
||||
|
||||
@@ -3,6 +3,8 @@ package amp
|
||||
import (
|
||||
"bytes"
|
||||
"compress/gzip"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
@@ -188,6 +190,10 @@ func createReverseProxy(upstreamURL string, secretSource SecretSource) (*httputi
|
||||
|
||||
// Error handler for proxy failures
|
||||
proxy.ErrorHandler = func(rw http.ResponseWriter, req *http.Request, err error) {
|
||||
// Client-side cancellations are common during polling; suppress logging in this case
|
||||
if errors.Is(err, context.Canceled) {
|
||||
return
|
||||
}
|
||||
log.Errorf("amp upstream proxy error for %s %s: %v", req.Method, req.URL.Path, err)
|
||||
rw.Header().Set("Content-Type", "application/json")
|
||||
rw.WriteHeader(http.StatusBadGateway)
|
||||
|
||||
@@ -493,6 +493,30 @@ func TestReverseProxy_ErrorHandler(t *testing.T) {
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_ErrorHandler_ContextCanceled(t *testing.T) {
|
||||
// Test that context.Canceled errors return 499 without generic error response
|
||||
proxy, err := createReverseProxy("http://example.com", NewStaticSecretSource(""))
|
||||
if err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
// Create a canceled context to trigger the cancellation path
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel() // Cancel immediately
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/test", nil).WithContext(ctx)
|
||||
rr := httptest.NewRecorder()
|
||||
|
||||
// Directly invoke the ErrorHandler with context.Canceled
|
||||
proxy.ErrorHandler(rr, req, context.Canceled)
|
||||
|
||||
// Body should be empty for canceled requests (no JSON error response)
|
||||
body := rr.Body.Bytes()
|
||||
if len(body) > 0 {
|
||||
t.Fatalf("expected empty body for canceled context, got: %s", body)
|
||||
}
|
||||
}
|
||||
|
||||
func TestReverseProxy_FullRoundTrip_Gzip(t *testing.T) {
|
||||
// Upstream returns gzipped JSON without Content-Encoding header
|
||||
upstream := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
@@ -51,6 +51,7 @@ type serverOptionConfig struct {
|
||||
keepAliveEnabled bool
|
||||
keepAliveTimeout time.Duration
|
||||
keepAliveOnTimeout func()
|
||||
postAuthHook auth.PostAuthHook
|
||||
}
|
||||
|
||||
// ServerOption customises HTTP server construction.
|
||||
@@ -111,6 +112,13 @@ func WithRequestLoggerFactory(factory func(*config.Config, string) logging.Reque
|
||||
}
|
||||
}
|
||||
|
||||
// WithPostAuthHook registers a hook to be called after auth record creation.
|
||||
func WithPostAuthHook(hook auth.PostAuthHook) ServerOption {
|
||||
return func(cfg *serverOptionConfig) {
|
||||
cfg.postAuthHook = hook
|
||||
}
|
||||
}
|
||||
|
||||
// Server represents the main API server.
|
||||
// It encapsulates the Gin engine, HTTP server, handlers, and configuration.
|
||||
type Server struct {
|
||||
@@ -262,6 +270,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
||||
}
|
||||
logDir := logging.ResolveLogDirectory(cfg)
|
||||
s.mgmt.SetLogDirectory(logDir)
|
||||
if optionState.postAuthHook != nil {
|
||||
s.mgmt.SetPostAuthHook(optionState.postAuthHook)
|
||||
}
|
||||
s.localPassword = optionState.localPassword
|
||||
|
||||
// Setup routes
|
||||
|
||||
@@ -36,11 +36,21 @@ type ClaudeTokenStorage struct {
|
||||
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *ClaudeTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Claude token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -65,8 +75,14 @@ func (ts *ClaudeTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
// Encode and write the token data as JSON
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -71,16 +71,26 @@ func (o *CodexAuth) GenerateAuthURL(state string, pkceCodes *PKCECodes) (string,
|
||||
// It performs an HTTP POST request to the OpenAI token endpoint with the provided
|
||||
// authorization code and PKCE verifier.
|
||||
func (o *CodexAuth) ExchangeCodeForTokens(ctx context.Context, code string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
return o.ExchangeCodeForTokensWithRedirect(ctx, code, RedirectURI, pkceCodes)
|
||||
}
|
||||
|
||||
// ExchangeCodeForTokensWithRedirect exchanges an authorization code for tokens using
|
||||
// a caller-provided redirect URI. This supports alternate auth flows such as device
|
||||
// login while preserving the existing token parsing and storage behavior.
|
||||
func (o *CodexAuth) ExchangeCodeForTokensWithRedirect(ctx context.Context, code, redirectURI string, pkceCodes *PKCECodes) (*CodexAuthBundle, error) {
|
||||
if pkceCodes == nil {
|
||||
return nil, fmt.Errorf("PKCE codes are required for token exchange")
|
||||
}
|
||||
if strings.TrimSpace(redirectURI) == "" {
|
||||
return nil, fmt.Errorf("redirect URI is required for token exchange")
|
||||
}
|
||||
|
||||
// Prepare token exchange request
|
||||
data := url.Values{
|
||||
"grant_type": {"authorization_code"},
|
||||
"client_id": {ClientID},
|
||||
"code": {code},
|
||||
"redirect_uri": {RedirectURI},
|
||||
"redirect_uri": {strings.TrimSpace(redirectURI)},
|
||||
"code_verifier": {pkceCodes.CodeVerifier},
|
||||
}
|
||||
|
||||
|
||||
@@ -32,11 +32,21 @@ type CodexTokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *CodexTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Codex token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -58,7 +68,13 @@ func (ts *CodexTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -35,11 +35,21 @@ type GeminiTokenStorage struct {
|
||||
|
||||
// Type indicates the authentication provider type, always "gemini" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *GeminiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Gemini token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -49,6 +59,11 @@ type GeminiTokenStorage struct {
|
||||
func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
misc.LogSavingCredentials(authFilePath)
|
||||
ts.Type = "gemini"
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
if err := os.MkdirAll(filepath.Dir(authFilePath), 0700); err != nil {
|
||||
return fmt.Errorf("failed to create directory: %v", err)
|
||||
}
|
||||
@@ -63,7 +78,9 @@ func (ts *GeminiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
enc := json.NewEncoder(f)
|
||||
enc.SetIndent("", " ")
|
||||
if err := enc.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -21,6 +21,15 @@ type IFlowTokenStorage struct {
|
||||
Scope string `json:"scope"`
|
||||
Cookie string `json:"cookie"`
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *IFlowTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serialises the token storage to disk.
|
||||
@@ -37,7 +46,13 @@ func (ts *IFlowTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
}
|
||||
defer func() { _ = f.Close() }()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("iflow token: encode token failed: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -29,6 +29,15 @@ type KimiTokenStorage struct {
|
||||
Expired string `json:"expired,omitempty"`
|
||||
// Type indicates the authentication provider type, always "kimi" for this storage.
|
||||
Type string `json:"type"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *KimiTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// KimiTokenData holds the raw OAuth token response from Kimi.
|
||||
@@ -86,9 +95,15 @@ func (ts *KimiTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
encoder := json.NewEncoder(f)
|
||||
encoder.SetIndent("", " ")
|
||||
if err = encoder.Encode(ts); err != nil {
|
||||
if err = encoder.Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
@@ -30,11 +30,21 @@ type QwenTokenStorage struct {
|
||||
Type string `json:"type"`
|
||||
// Expire is the timestamp when the current access token expires.
|
||||
Expire string `json:"expired"`
|
||||
|
||||
// Metadata holds arbitrary key-value pairs injected via hooks.
|
||||
// It is not exported to JSON directly to allow flattening during serialization.
|
||||
Metadata map[string]any `json:"-"`
|
||||
}
|
||||
|
||||
// SetMetadata allows external callers to inject metadata into the storage before saving.
|
||||
func (ts *QwenTokenStorage) SetMetadata(meta map[string]any) {
|
||||
ts.Metadata = meta
|
||||
}
|
||||
|
||||
// SaveTokenToFile serializes the Qwen token storage to a JSON file.
|
||||
// This method creates the necessary directory structure and writes the token
|
||||
// data in JSON format to the specified file path for persistent storage.
|
||||
// It merges any injected metadata into the top-level JSON object.
|
||||
//
|
||||
// Parameters:
|
||||
// - authFilePath: The full path where the token file should be saved
|
||||
@@ -56,7 +66,13 @@ func (ts *QwenTokenStorage) SaveTokenToFile(authFilePath string) error {
|
||||
_ = f.Close()
|
||||
}()
|
||||
|
||||
if err = json.NewEncoder(f).Encode(ts); err != nil {
|
||||
// Merge metadata using helper
|
||||
data, errMerge := misc.MergeMetadata(ts, ts.Metadata)
|
||||
if errMerge != nil {
|
||||
return fmt.Errorf("failed to merge metadata: %w", errMerge)
|
||||
}
|
||||
|
||||
if err = json.NewEncoder(f).Encode(data); err != nil {
|
||||
return fmt.Errorf("failed to write token to file: %w", err)
|
||||
}
|
||||
return nil
|
||||
|
||||
60
internal/cmd/openai_device_login.go
Normal file
60
internal/cmd/openai_device_login.go
Normal file
@@ -0,0 +1,60 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codexLoginModeMetadataKey = "codex_login_mode"
|
||||
codexLoginModeDevice = "device"
|
||||
)
|
||||
|
||||
// DoCodexDeviceLogin triggers the Codex device-code flow while keeping the
|
||||
// existing codex-login OAuth callback flow intact.
|
||||
func DoCodexDeviceLogin(cfg *config.Config, options *LoginOptions) {
|
||||
if options == nil {
|
||||
options = &LoginOptions{}
|
||||
}
|
||||
|
||||
promptFn := options.Prompt
|
||||
if promptFn == nil {
|
||||
promptFn = defaultProjectPrompt()
|
||||
}
|
||||
|
||||
manager := newAuthManager()
|
||||
|
||||
authOpts := &sdkAuth.LoginOptions{
|
||||
NoBrowser: options.NoBrowser,
|
||||
CallbackPort: options.CallbackPort,
|
||||
Metadata: map[string]string{
|
||||
codexLoginModeMetadataKey: codexLoginModeDevice,
|
||||
},
|
||||
Prompt: promptFn,
|
||||
}
|
||||
|
||||
_, savedPath, err := manager.Login(context.Background(), "codex", cfg, authOpts)
|
||||
if err != nil {
|
||||
if authErr, ok := errors.AsType[*codex.AuthenticationError](err); ok {
|
||||
log.Error(codex.GetUserFriendlyMessage(authErr))
|
||||
if authErr.Type == codex.ErrPortInUse.Type {
|
||||
os.Exit(codex.ErrPortInUse.Code)
|
||||
}
|
||||
return
|
||||
}
|
||||
fmt.Printf("Codex device authentication failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
|
||||
if savedPath != "" {
|
||||
fmt.Printf("Authentication saved to %s\n", savedPath)
|
||||
}
|
||||
fmt.Println("Codex device authentication successful!")
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
package misc
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
@@ -24,3 +25,37 @@ func LogSavingCredentials(path string) {
|
||||
func LogCredentialSeparator() {
|
||||
log.Debug(credentialSeparator)
|
||||
}
|
||||
|
||||
// MergeMetadata serializes the source struct into a map and merges the provided metadata into it.
|
||||
func MergeMetadata(source any, metadata map[string]any) (map[string]any, error) {
|
||||
var data map[string]any
|
||||
|
||||
// Fast path: if source is already a map, just copy it to avoid mutation of original
|
||||
if srcMap, ok := source.(map[string]any); ok {
|
||||
data = make(map[string]any, len(srcMap)+len(metadata))
|
||||
for k, v := range srcMap {
|
||||
data[k] = v
|
||||
}
|
||||
} else {
|
||||
// Slow path: marshal to JSON and back to map to respect JSON tags
|
||||
temp, err := json.Marshal(source)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to marshal source: %w", err)
|
||||
}
|
||||
if err := json.Unmarshal(temp, &data); err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal to map: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Merge extra metadata
|
||||
if metadata != nil {
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
for k, v := range metadata {
|
||||
data[k] = v
|
||||
}
|
||||
}
|
||||
|
||||
return data, nil
|
||||
}
|
||||
|
||||
@@ -322,7 +322,7 @@ func GetGeminiVertexModels() []*ModelInfo {
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 1, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-pro-image-preview",
|
||||
@@ -466,6 +466,21 @@ func GetGeminiCLIModels() []*ModelInfo {
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3.1-pro-preview",
|
||||
Object: "model",
|
||||
Created: 1771459200,
|
||||
OwnedBy: "google",
|
||||
Type: "gemini",
|
||||
Name: "models/gemini-3.1-pro-preview",
|
||||
Version: "3.1",
|
||||
DisplayName: "Gemini 3.1 Pro Preview",
|
||||
Description: "Gemini 3.1 Pro Preview",
|
||||
InputTokenLimit: 1048576,
|
||||
OutputTokenLimit: 65536,
|
||||
SupportedGenerationMethods: []string{"generateContent", "countTokens", "createCachedContent", "batchGenerateContent"},
|
||||
Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}},
|
||||
},
|
||||
{
|
||||
ID: "gemini-3-flash-preview",
|
||||
Object: "model",
|
||||
@@ -948,6 +963,7 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"gemini-2.5-flash-lite": {Thinking: &ThinkingSupport{Min: 0, Max: 24576, ZeroAllowed: true, DynamicAllowed: true}},
|
||||
"gemini-3-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-pro-image": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3.1-pro-high": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"low", "high"}}},
|
||||
"gemini-3-flash": {Thinking: &ThinkingSupport{Min: 128, Max: 32768, ZeroAllowed: false, DynamicAllowed: true, Levels: []string{"minimal", "low", "medium", "high"}}},
|
||||
"claude-sonnet-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
|
||||
@@ -10,53 +10,10 @@ import (
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/thinking"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
// validReasoningEffortLevels contains the standard values accepted by the
|
||||
// OpenAI reasoning_effort field. Provider-specific extensions (xhigh, minimal,
|
||||
// auto) are NOT in this set and must be clamped before use.
|
||||
var validReasoningEffortLevels = map[string]struct{}{
|
||||
"none": {},
|
||||
"low": {},
|
||||
"medium": {},
|
||||
"high": {},
|
||||
}
|
||||
|
||||
// clampReasoningEffort maps any thinking level string to a value that is safe
|
||||
// to send as OpenAI reasoning_effort. Non-standard CPA-internal values are
|
||||
// mapped to the nearest standard equivalent.
|
||||
//
|
||||
// Mapping rules:
|
||||
// - none / low / medium / high → returned as-is (already valid)
|
||||
// - xhigh → "high" (nearest lower standard level)
|
||||
// - minimal → "low" (nearest higher standard level)
|
||||
// - auto → "medium" (reasonable default)
|
||||
// - anything else → "medium" (safe default)
|
||||
func clampReasoningEffort(level string) string {
|
||||
if _, ok := validReasoningEffortLevels[level]; ok {
|
||||
return level
|
||||
}
|
||||
var clamped string
|
||||
switch level {
|
||||
case string(thinking.LevelXHigh):
|
||||
clamped = string(thinking.LevelHigh)
|
||||
case string(thinking.LevelMinimal):
|
||||
clamped = string(thinking.LevelLow)
|
||||
case string(thinking.LevelAuto):
|
||||
clamped = string(thinking.LevelMedium)
|
||||
default:
|
||||
clamped = string(thinking.LevelMedium)
|
||||
}
|
||||
log.WithFields(log.Fields{
|
||||
"original": level,
|
||||
"clamped": clamped,
|
||||
}).Debug("openai: reasoning_effort clamped to nearest valid standard value")
|
||||
return clamped
|
||||
}
|
||||
|
||||
// Applier implements thinking.ProviderApplier for OpenAI models.
|
||||
//
|
||||
// OpenAI-specific behavior:
|
||||
@@ -101,7 +58,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
}
|
||||
|
||||
if config.Mode == thinking.ModeLevel {
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(string(config.Level)))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", string(config.Level))
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -122,7 +79,7 @@ func (a *Applier) Apply(body []byte, config thinking.ThinkingConfig, modelInfo *
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
@@ -157,7 +114,7 @@ func applyCompatibleOpenAI(body []byte, config thinking.ThinkingConfig) ([]byte,
|
||||
return body, nil
|
||||
}
|
||||
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", clampReasoningEffort(effort))
|
||||
result, _ := sjson.SetBytes(body, "reasoning_effort", effort)
|
||||
return result, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -95,9 +95,9 @@ func ConvertAntigravityResponseToOpenAI(_ context.Context, _ string, originalReq
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
template, _ = sjson.Set(template, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
|
||||
@@ -199,6 +199,21 @@ func ConvertOpenAIRequestToClaude(modelName string, inputRawJSON []byte, stream
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", imagePart)
|
||||
}
|
||||
}
|
||||
|
||||
case "file":
|
||||
fileData := part.Get("file.file_data").String()
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
semicolonIdx := strings.Index(fileData, ";")
|
||||
commaIdx := strings.Index(fileData, ",")
|
||||
if semicolonIdx != -1 && commaIdx != -1 && commaIdx > semicolonIdx {
|
||||
mediaType := strings.TrimPrefix(fileData[:semicolonIdx], "data:")
|
||||
data := fileData[commaIdx+1:]
|
||||
docPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
docPart, _ = sjson.Set(docPart, "source.media_type", mediaType)
|
||||
docPart, _ = sjson.Set(docPart, "source.data", data)
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", docPart)
|
||||
}
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
|
||||
@@ -155,6 +155,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
var textAggregate strings.Builder
|
||||
var partsJSON []string
|
||||
hasImage := false
|
||||
hasFile := false
|
||||
if parts := item.Get("content"); parts.Exists() && parts.IsArray() {
|
||||
parts.ForEach(func(_, part gjson.Result) bool {
|
||||
ptype := part.Get("type").String()
|
||||
@@ -207,6 +208,30 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
hasImage = true
|
||||
}
|
||||
}
|
||||
case "input_file":
|
||||
fileData := part.Get("file_data").String()
|
||||
if fileData != "" {
|
||||
mediaType := "application/octet-stream"
|
||||
data := fileData
|
||||
if strings.HasPrefix(fileData, "data:") {
|
||||
trimmed := strings.TrimPrefix(fileData, "data:")
|
||||
mediaAndData := strings.SplitN(trimmed, ";base64,", 2)
|
||||
if len(mediaAndData) == 2 {
|
||||
if mediaAndData[0] != "" {
|
||||
mediaType = mediaAndData[0]
|
||||
}
|
||||
data = mediaAndData[1]
|
||||
}
|
||||
}
|
||||
contentPart := `{"type":"document","source":{"type":"base64","media_type":"","data":""}}`
|
||||
contentPart, _ = sjson.Set(contentPart, "source.media_type", mediaType)
|
||||
contentPart, _ = sjson.Set(contentPart, "source.data", data)
|
||||
partsJSON = append(partsJSON, contentPart)
|
||||
if role == "" {
|
||||
role = "user"
|
||||
}
|
||||
hasFile = true
|
||||
}
|
||||
}
|
||||
return true
|
||||
})
|
||||
@@ -228,7 +253,7 @@ func ConvertOpenAIResponsesRequestToClaude(modelName string, inputRawJSON []byte
|
||||
if len(partsJSON) > 0 {
|
||||
msg := `{"role":"","content":[]}`
|
||||
msg, _ = sjson.Set(msg, "role", role)
|
||||
if len(partsJSON) == 1 && !hasImage {
|
||||
if len(partsJSON) == 1 && !hasImage && !hasFile {
|
||||
// Preserve legacy behavior for single text content
|
||||
msg, _ = sjson.Delete(msg, "content")
|
||||
textPart := gjson.Parse(partsJSON[0])
|
||||
|
||||
@@ -180,7 +180,19 @@ func ConvertOpenAIRequestToCodex(modelName string, inputRawJSON []byte, stream b
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
}
|
||||
case "file":
|
||||
// Files are not specified in examples; skip for now
|
||||
if role == "user" {
|
||||
fileData := it.Get("file.file_data").String()
|
||||
filename := it.Get("file.filename").String()
|
||||
if fileData != "" {
|
||||
part := `{}`
|
||||
part, _ = sjson.Set(part, "type", "input_file")
|
||||
part, _ = sjson.Set(part, "file_data", fileData)
|
||||
if filename != "" {
|
||||
part, _ = sjson.Set(part, "filename", filename)
|
||||
}
|
||||
msg, _ = sjson.SetRaw(msg, "content.-1", part)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -100,7 +100,7 @@ func ConvertCliResponseToOpenAI(_ context.Context, _ string, originalRequestRawJ
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
|
||||
@@ -100,9 +100,9 @@ func ConvertGeminiResponseToOpenAI(_ context.Context, _ string, originalRequestR
|
||||
if totalTokenCountResult := usageResult.Get("totalTokenCount"); totalTokenCountResult.Exists() {
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.total_tokens", totalTokenCountResult.Int())
|
||||
}
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int() - cachedTokenCount
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
baseTemplate, _ = sjson.Set(baseTemplate, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
@@ -297,7 +297,7 @@ func ConvertGeminiResponseToOpenAINonStream(_ context.Context, _ string, origina
|
||||
promptTokenCount := usageResult.Get("promptTokenCount").Int()
|
||||
thoughtsTokenCount := usageResult.Get("thoughtsTokenCount").Int()
|
||||
cachedTokenCount := usageResult.Get("cachedContentTokenCount").Int()
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount+thoughtsTokenCount)
|
||||
template, _ = sjson.Set(template, "usage.prompt_tokens", promptTokenCount)
|
||||
if thoughtsTokenCount > 0 {
|
||||
template, _ = sjson.Set(template, "usage.completion_tokens_details.reasoning_tokens", thoughtsTokenCount)
|
||||
}
|
||||
|
||||
@@ -531,8 +531,8 @@ func ConvertGeminiResponseToOpenAIResponses(_ context.Context, modelName string,
|
||||
|
||||
// usage mapping
|
||||
if um := root.Get("usageMetadata"); um.Exists() {
|
||||
// input tokens = prompt + thoughts
|
||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||
// input tokens = prompt only (thoughts go to output)
|
||||
input := um.Get("promptTokenCount").Int()
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens", input)
|
||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||
completed, _ = sjson.Set(completed, "response.usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||
@@ -737,8 +737,8 @@ func ConvertGeminiResponseToOpenAIResponsesNonStream(_ context.Context, _ string
|
||||
|
||||
// usage mapping
|
||||
if um := root.Get("usageMetadata"); um.Exists() {
|
||||
// input tokens = prompt + thoughts
|
||||
input := um.Get("promptTokenCount").Int() + um.Get("thoughtsTokenCount").Int()
|
||||
// input tokens = prompt only (thoughts go to output)
|
||||
input := um.Get("promptTokenCount").Int()
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens", input)
|
||||
// cached token details: align with OpenAI "cached_tokens" semantics.
|
||||
resp, _ = sjson.Set(resp, "usage.input_tokens_details.cached_tokens", um.Get("cachedContentTokenCount").Int())
|
||||
|
||||
@@ -716,6 +716,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
if handlerType == "openai-response" {
|
||||
if err := validateSSEDataJSON(chunk.Payload); err != nil {
|
||||
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
|
||||
return
|
||||
}
|
||||
}
|
||||
sentPayload = true
|
||||
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
||||
return
|
||||
@@ -727,6 +733,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return dataChan, upstreamHeaders, errChan
|
||||
}
|
||||
|
||||
func validateSSEDataJSON(chunk []byte) error {
|
||||
for _, line := range bytes.Split(chunk, []byte("\n")) {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[5:])
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if json.Valid(data) {
|
||||
continue
|
||||
}
|
||||
const max = 512
|
||||
preview := data
|
||||
if len(preview) > max {
|
||||
preview = preview[:max]
|
||||
}
|
||||
return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func statusFromError(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
|
||||
@@ -134,6 +134,37 @@ type authAwareStreamExecutor struct {
|
||||
authIDs []string
|
||||
}
|
||||
|
||||
type invalidJSONStreamExecutor struct{}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
|
||||
close(ch)
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, &coreauth.Error{
|
||||
Code: "not_implemented",
|
||||
Message: "HttpRequest not implemented",
|
||||
HTTPStatus: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
|
||||
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
|
||||
executor := &invalidJSONStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty payload, got %q", string(got))
|
||||
}
|
||||
|
||||
gotErr := false
|
||||
for msg := range errChan {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.StatusCode != http.StatusBadGateway {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
|
||||
}
|
||||
if msg.Error == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
gotErr = true
|
||||
}
|
||||
if !gotErr {
|
||||
t.Fatalf("expected terminal error")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,8 +265,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
||||
chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatalf("expected gin writer to implement http.Flusher")
|
||||
}
|
||||
|
||||
data := make(chan []byte)
|
||||
errs := make(chan *interfaces.ErrorMessage, 1)
|
||||
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
|
||||
body := recorder.Body.String()
|
||||
if !strings.Contains(body, `"type":"error"`) {
|
||||
t.Fatalf("expected responses error chunk, got: %q", body)
|
||||
}
|
||||
if strings.Contains(body, `"error":{`) {
|
||||
t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
|
||||
}
|
||||
}
|
||||
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type openAIResponsesStreamErrorChunk struct {
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
SequenceNumber int `json:"sequence_number"`
|
||||
}
|
||||
|
||||
func openAIResponsesStreamErrorCode(status int) string {
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
return "invalid_api_key"
|
||||
case http.StatusForbidden:
|
||||
return "insufficient_quota"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
return "model_not_found"
|
||||
case http.StatusRequestTimeout:
|
||||
return "request_timeout"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
return "internal_server_error"
|
||||
}
|
||||
if status >= http.StatusBadRequest {
|
||||
return "invalid_request_error"
|
||||
}
|
||||
return "unknown_error"
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
|
||||
//
|
||||
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
|
||||
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
|
||||
// of chunks that requires a top-level `type` field.
|
||||
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
|
||||
if status <= 0 {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
if sequenceNumber < 0 {
|
||||
sequenceNumber = 0
|
||||
}
|
||||
|
||||
message := strings.TrimSpace(errText)
|
||||
if message == "" {
|
||||
message = http.StatusText(status)
|
||||
}
|
||||
|
||||
code := openAIResponsesStreamErrorCode(status)
|
||||
|
||||
trimmed := strings.TrimSpace(errText)
|
||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
|
||||
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
|
||||
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||
message = strings.TrimSpace(m)
|
||||
}
|
||||
if v, ok := payload["code"]; ok && v != nil {
|
||||
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||
code = strings.TrimSpace(c)
|
||||
} else {
|
||||
code = strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
|
||||
sequenceNumber = int(v)
|
||||
}
|
||||
}
|
||||
if e, ok := payload["error"].(map[string]any); ok {
|
||||
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||
message = strings.TrimSpace(m)
|
||||
}
|
||||
if v, ok := e["code"]; ok && v != nil {
|
||||
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||
code = strings.TrimSpace(c)
|
||||
} else {
|
||||
code = strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(code) == "" {
|
||||
code = "unknown_error"
|
||||
}
|
||||
|
||||
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
|
||||
Type: "error",
|
||||
Code: code,
|
||||
Message: message,
|
||||
SequenceNumber: sequenceNumber,
|
||||
})
|
||||
if err == nil {
|
||||
return data
|
||||
}
|
||||
|
||||
// Extremely defensive fallback.
|
||||
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
|
||||
Type: "error",
|
||||
Code: "internal_server_error",
|
||||
Message: message,
|
||||
SequenceNumber: sequenceNumber,
|
||||
})
|
||||
if len(data) > 0 {
|
||||
return data
|
||||
}
|
||||
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
|
||||
}
|
||||
48
sdk/api/handlers/openai_responses_stream_error_test.go
Normal file
48
sdk/api/handlers/openai_responses_stream_error_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) {
|
||||
chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0)
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if payload["type"] != "error" {
|
||||
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||
}
|
||||
if payload["code"] != "internal_server_error" {
|
||||
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||
}
|
||||
if payload["message"] != "unexpected EOF" {
|
||||
t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF")
|
||||
}
|
||||
if payload["sequence_number"] != float64(0) {
|
||||
t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) {
|
||||
chunk := BuildOpenAIResponsesStreamErrorChunk(
|
||||
http.StatusInternalServerError,
|
||||
`{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`,
|
||||
0,
|
||||
)
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if payload["type"] != "error" {
|
||||
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||
}
|
||||
if payload["code"] != "internal_server_error" {
|
||||
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||
}
|
||||
if payload["message"] != "oops" {
|
||||
t.Fatalf("message = %v, want %q", payload["message"], "oops")
|
||||
}
|
||||
}
|
||||
@@ -2,8 +2,6 @@ package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
@@ -48,6 +46,10 @@ func (a *CodexAuthenticator) Login(ctx context.Context, cfg *config.Config, opts
|
||||
opts = &LoginOptions{}
|
||||
}
|
||||
|
||||
if shouldUseCodexDeviceFlow(opts) {
|
||||
return a.loginWithDeviceFlow(ctx, cfg, opts)
|
||||
}
|
||||
|
||||
callbackPort := a.CallbackPort
|
||||
if opts.CallbackPort > 0 {
|
||||
callbackPort = opts.CallbackPort
|
||||
@@ -186,39 +188,5 @@ waitForCallback:
|
||||
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
|
||||
}
|
||||
|
||||
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||
|
||||
if tokenStorage == nil || tokenStorage.Email == "" {
|
||||
return nil, fmt.Errorf("codex token storage missing account information")
|
||||
}
|
||||
|
||||
planType := ""
|
||||
hashAccountID := ""
|
||||
if tokenStorage.IDToken != "" {
|
||||
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
|
||||
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
|
||||
if accountID != "" {
|
||||
digest := sha256.Sum256([]byte(accountID))
|
||||
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||
}
|
||||
}
|
||||
}
|
||||
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
||||
metadata := map[string]any{
|
||||
"email": tokenStorage.Email,
|
||||
}
|
||||
|
||||
fmt.Println("Codex authentication successful")
|
||||
if authBundle.APIKey != "" {
|
||||
fmt.Println("Codex API key obtained and stored")
|
||||
}
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
return a.buildAuthRecord(authSvc, authBundle)
|
||||
}
|
||||
|
||||
291
sdk/auth/codex_device.go
Normal file
291
sdk/auth/codex_device.go
Normal file
@@ -0,0 +1,291 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/auth/codex"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/browser"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
codexLoginModeMetadataKey = "codex_login_mode"
|
||||
codexLoginModeDevice = "device"
|
||||
codexDeviceUserCodeURL = "https://auth.openai.com/api/accounts/deviceauth/usercode"
|
||||
codexDeviceTokenURL = "https://auth.openai.com/api/accounts/deviceauth/token"
|
||||
codexDeviceVerificationURL = "https://auth.openai.com/codex/device"
|
||||
codexDeviceTokenExchangeRedirectURI = "https://auth.openai.com/deviceauth/callback"
|
||||
codexDeviceTimeout = 15 * time.Minute
|
||||
codexDeviceDefaultPollIntervalSeconds = 5
|
||||
)
|
||||
|
||||
type codexDeviceUserCodeRequest struct {
|
||||
ClientID string `json:"client_id"`
|
||||
}
|
||||
|
||||
type codexDeviceUserCodeResponse struct {
|
||||
DeviceAuthID string `json:"device_auth_id"`
|
||||
UserCode string `json:"user_code"`
|
||||
UserCodeAlt string `json:"usercode"`
|
||||
Interval json.RawMessage `json:"interval"`
|
||||
}
|
||||
|
||||
type codexDeviceTokenRequest struct {
|
||||
DeviceAuthID string `json:"device_auth_id"`
|
||||
UserCode string `json:"user_code"`
|
||||
}
|
||||
|
||||
type codexDeviceTokenResponse struct {
|
||||
AuthorizationCode string `json:"authorization_code"`
|
||||
CodeVerifier string `json:"code_verifier"`
|
||||
CodeChallenge string `json:"code_challenge"`
|
||||
}
|
||||
|
||||
func shouldUseCodexDeviceFlow(opts *LoginOptions) bool {
|
||||
if opts == nil || opts.Metadata == nil {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(opts.Metadata[codexLoginModeMetadataKey]), codexLoginModeDevice)
|
||||
}
|
||||
|
||||
func (a *CodexAuthenticator) loginWithDeviceFlow(ctx context.Context, cfg *config.Config, opts *LoginOptions) (*coreauth.Auth, error) {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
|
||||
httpClient := util.SetProxy(&cfg.SDKConfig, &http.Client{})
|
||||
|
||||
userCodeResp, err := requestCodexDeviceUserCode(ctx, httpClient)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
deviceCode := strings.TrimSpace(userCodeResp.UserCode)
|
||||
if deviceCode == "" {
|
||||
deviceCode = strings.TrimSpace(userCodeResp.UserCodeAlt)
|
||||
}
|
||||
deviceAuthID := strings.TrimSpace(userCodeResp.DeviceAuthID)
|
||||
if deviceCode == "" || deviceAuthID == "" {
|
||||
return nil, fmt.Errorf("codex device flow did not return required fields")
|
||||
}
|
||||
|
||||
pollInterval := parseCodexDevicePollInterval(userCodeResp.Interval)
|
||||
|
||||
fmt.Println("Starting Codex device authentication...")
|
||||
fmt.Printf("Codex device URL: %s\n", codexDeviceVerificationURL)
|
||||
fmt.Printf("Codex device code: %s\n", deviceCode)
|
||||
|
||||
if !opts.NoBrowser {
|
||||
if !browser.IsAvailable() {
|
||||
log.Warn("No browser available; please open the device URL manually")
|
||||
} else if errOpen := browser.OpenURL(codexDeviceVerificationURL); errOpen != nil {
|
||||
log.Warnf("Failed to open browser automatically: %v", errOpen)
|
||||
}
|
||||
}
|
||||
|
||||
tokenResp, err := pollCodexDeviceToken(ctx, httpClient, deviceAuthID, deviceCode, pollInterval)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
authCode := strings.TrimSpace(tokenResp.AuthorizationCode)
|
||||
codeVerifier := strings.TrimSpace(tokenResp.CodeVerifier)
|
||||
codeChallenge := strings.TrimSpace(tokenResp.CodeChallenge)
|
||||
if authCode == "" || codeVerifier == "" || codeChallenge == "" {
|
||||
return nil, fmt.Errorf("codex device flow token response missing required fields")
|
||||
}
|
||||
|
||||
authSvc := codex.NewCodexAuth(cfg)
|
||||
authBundle, err := authSvc.ExchangeCodeForTokensWithRedirect(
|
||||
ctx,
|
||||
authCode,
|
||||
codexDeviceTokenExchangeRedirectURI,
|
||||
&codex.PKCECodes{
|
||||
CodeVerifier: codeVerifier,
|
||||
CodeChallenge: codeChallenge,
|
||||
},
|
||||
)
|
||||
if err != nil {
|
||||
return nil, codex.NewAuthenticationError(codex.ErrCodeExchangeFailed, err)
|
||||
}
|
||||
|
||||
return a.buildAuthRecord(authSvc, authBundle)
|
||||
}
|
||||
|
||||
func requestCodexDeviceUserCode(ctx context.Context, client *http.Client) (*codexDeviceUserCodeResponse, error) {
|
||||
body, err := json.Marshal(codexDeviceUserCodeRequest{ClientID: codex.ClientID})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode codex device request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceUserCodeURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create codex device request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to request codex device code: %w", err)
|
||||
}
|
||||
defer func() { _ = resp.Body.Close() }()
|
||||
|
||||
respBody, err := io.ReadAll(resp.Body)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to read codex device code response: %w", err)
|
||||
}
|
||||
|
||||
if !codexDeviceIsSuccessStatus(resp.StatusCode) {
|
||||
trimmed := strings.TrimSpace(string(respBody))
|
||||
if resp.StatusCode == http.StatusNotFound {
|
||||
return nil, fmt.Errorf("codex device endpoint is unavailable (status %d)", resp.StatusCode)
|
||||
}
|
||||
if trimmed == "" {
|
||||
trimmed = "empty response body"
|
||||
}
|
||||
return nil, fmt.Errorf("codex device code request failed with status %d: %s", resp.StatusCode, trimmed)
|
||||
}
|
||||
|
||||
var parsed codexDeviceUserCodeResponse
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode codex device code response: %w", err)
|
||||
}
|
||||
|
||||
return &parsed, nil
|
||||
}
|
||||
|
||||
func pollCodexDeviceToken(ctx context.Context, client *http.Client, deviceAuthID, userCode string, interval time.Duration) (*codexDeviceTokenResponse, error) {
|
||||
deadline := time.Now().Add(codexDeviceTimeout)
|
||||
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
return nil, fmt.Errorf("codex device authentication timed out after 15 minutes")
|
||||
}
|
||||
|
||||
body, err := json.Marshal(codexDeviceTokenRequest{
|
||||
DeviceAuthID: deviceAuthID,
|
||||
UserCode: userCode,
|
||||
})
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to encode codex device poll request: %w", err)
|
||||
}
|
||||
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodPost, codexDeviceTokenURL, bytes.NewReader(body))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to create codex device poll request: %w", err)
|
||||
}
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
req.Header.Set("Accept", "application/json")
|
||||
|
||||
resp, err := client.Do(req)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to poll codex device token: %w", err)
|
||||
}
|
||||
|
||||
respBody, readErr := io.ReadAll(resp.Body)
|
||||
_ = resp.Body.Close()
|
||||
if readErr != nil {
|
||||
return nil, fmt.Errorf("failed to read codex device poll response: %w", readErr)
|
||||
}
|
||||
|
||||
switch {
|
||||
case codexDeviceIsSuccessStatus(resp.StatusCode):
|
||||
var parsed codexDeviceTokenResponse
|
||||
if err := json.Unmarshal(respBody, &parsed); err != nil {
|
||||
return nil, fmt.Errorf("failed to decode codex device token response: %w", err)
|
||||
}
|
||||
return &parsed, nil
|
||||
case resp.StatusCode == http.StatusForbidden || resp.StatusCode == http.StatusNotFound:
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case <-time.After(interval):
|
||||
continue
|
||||
}
|
||||
default:
|
||||
trimmed := strings.TrimSpace(string(respBody))
|
||||
if trimmed == "" {
|
||||
trimmed = "empty response body"
|
||||
}
|
||||
return nil, fmt.Errorf("codex device token polling failed with status %d: %s", resp.StatusCode, trimmed)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func parseCodexDevicePollInterval(raw json.RawMessage) time.Duration {
|
||||
defaultInterval := time.Duration(codexDeviceDefaultPollIntervalSeconds) * time.Second
|
||||
if len(raw) == 0 {
|
||||
return defaultInterval
|
||||
}
|
||||
|
||||
var asString string
|
||||
if err := json.Unmarshal(raw, &asString); err == nil {
|
||||
if seconds, convErr := strconv.Atoi(strings.TrimSpace(asString)); convErr == nil && seconds > 0 {
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
}
|
||||
|
||||
var asInt int
|
||||
if err := json.Unmarshal(raw, &asInt); err == nil && asInt > 0 {
|
||||
return time.Duration(asInt) * time.Second
|
||||
}
|
||||
|
||||
return defaultInterval
|
||||
}
|
||||
|
||||
func codexDeviceIsSuccessStatus(code int) bool {
|
||||
return code >= 200 && code < 300
|
||||
}
|
||||
|
||||
func (a *CodexAuthenticator) buildAuthRecord(authSvc *codex.CodexAuth, authBundle *codex.CodexAuthBundle) (*coreauth.Auth, error) {
|
||||
tokenStorage := authSvc.CreateTokenStorage(authBundle)
|
||||
|
||||
if tokenStorage == nil || tokenStorage.Email == "" {
|
||||
return nil, fmt.Errorf("codex token storage missing account information")
|
||||
}
|
||||
|
||||
planType := ""
|
||||
hashAccountID := ""
|
||||
if tokenStorage.IDToken != "" {
|
||||
if claims, errParse := codex.ParseJWTToken(tokenStorage.IDToken); errParse == nil && claims != nil {
|
||||
planType = strings.TrimSpace(claims.CodexAuthInfo.ChatgptPlanType)
|
||||
accountID := strings.TrimSpace(claims.CodexAuthInfo.ChatgptAccountID)
|
||||
if accountID != "" {
|
||||
digest := sha256.Sum256([]byte(accountID))
|
||||
hashAccountID = hex.EncodeToString(digest[:])[:8]
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fileName := codex.CredentialFileName(tokenStorage.Email, planType, hashAccountID, true)
|
||||
metadata := map[string]any{
|
||||
"email": tokenStorage.Email,
|
||||
}
|
||||
|
||||
fmt.Println("Codex authentication successful")
|
||||
if authBundle.APIKey != "" {
|
||||
fmt.Println("Codex API key obtained and stored")
|
||||
}
|
||||
|
||||
return &coreauth.Auth{
|
||||
ID: fileName,
|
||||
Provider: a.Provider(),
|
||||
FileName: fileName,
|
||||
Storage: tokenStorage,
|
||||
Metadata: metadata,
|
||||
}, nil
|
||||
}
|
||||
@@ -64,8 +64,16 @@ func (s *FileTokenStore) Save(ctx context.Context, auth *cliproxyauth.Auth) (str
|
||||
return "", fmt.Errorf("auth filestore: create dir failed: %w", err)
|
||||
}
|
||||
|
||||
// metadataSetter is a private interface for TokenStorage implementations that support metadata injection.
|
||||
type metadataSetter interface {
|
||||
SetMetadata(map[string]any)
|
||||
}
|
||||
|
||||
switch {
|
||||
case auth.Storage != nil:
|
||||
if setter, ok := auth.Storage.(metadataSetter); ok {
|
||||
setter.SetMetadata(auth.Metadata)
|
||||
}
|
||||
if err = auth.Storage.SaveTokenToFile(path); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
@@ -5,6 +5,7 @@ import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"math"
|
||||
"math/rand/v2"
|
||||
"net/http"
|
||||
"sort"
|
||||
"strconv"
|
||||
@@ -248,6 +249,9 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
||||
}
|
||||
|
||||
// Pick selects the next available auth for the provider in a round-robin manner.
|
||||
// For gemini-cli virtual auths (identified by the gemini_virtual_parent attribute),
|
||||
// a two-level round-robin is used: first cycling across credential groups (parent
|
||||
// accounts), then cycling within each group's project auths.
|
||||
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
_ = opts
|
||||
now := time.Now()
|
||||
@@ -265,21 +269,87 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
|
||||
if limit <= 0 {
|
||||
limit = 4096
|
||||
}
|
||||
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
||||
s.cursors = make(map[string]int)
|
||||
}
|
||||
index := s.cursors[key]
|
||||
|
||||
// Check if any available auth has gemini_virtual_parent attribute,
|
||||
// indicating gemini-cli virtual auths that should use credential-level polling.
|
||||
groups, parentOrder := groupByVirtualParent(available)
|
||||
if len(parentOrder) > 1 {
|
||||
// Two-level round-robin: first select a credential group, then pick within it.
|
||||
groupKey := key + "::group"
|
||||
s.ensureCursorKey(groupKey, limit)
|
||||
if _, exists := s.cursors[groupKey]; !exists {
|
||||
// Seed with a random initial offset so the starting credential is randomized.
|
||||
s.cursors[groupKey] = rand.IntN(len(parentOrder))
|
||||
}
|
||||
groupIndex := s.cursors[groupKey]
|
||||
if groupIndex >= 2_147_483_640 {
|
||||
groupIndex = 0
|
||||
}
|
||||
s.cursors[groupKey] = groupIndex + 1
|
||||
|
||||
selectedParent := parentOrder[groupIndex%len(parentOrder)]
|
||||
group := groups[selectedParent]
|
||||
|
||||
// Second level: round-robin within the selected credential group.
|
||||
innerKey := key + "::cred:" + selectedParent
|
||||
s.ensureCursorKey(innerKey, limit)
|
||||
innerIndex := s.cursors[innerKey]
|
||||
if innerIndex >= 2_147_483_640 {
|
||||
innerIndex = 0
|
||||
}
|
||||
s.cursors[innerKey] = innerIndex + 1
|
||||
s.mu.Unlock()
|
||||
return group[innerIndex%len(group)], nil
|
||||
}
|
||||
|
||||
// Flat round-robin for non-grouped auths (original behavior).
|
||||
s.ensureCursorKey(key, limit)
|
||||
index := s.cursors[key]
|
||||
if index >= 2_147_483_640 {
|
||||
index = 0
|
||||
}
|
||||
|
||||
s.cursors[key] = index + 1
|
||||
s.mu.Unlock()
|
||||
// log.Debugf("available: %d, index: %d, key: %d", len(available), index, index%len(available))
|
||||
return available[index%len(available)], nil
|
||||
}
|
||||
|
||||
// ensureCursorKey ensures the cursor map has capacity for the given key.
|
||||
// Must be called with s.mu held.
|
||||
func (s *RoundRobinSelector) ensureCursorKey(key string, limit int) {
|
||||
if _, ok := s.cursors[key]; !ok && len(s.cursors) >= limit {
|
||||
s.cursors = make(map[string]int)
|
||||
}
|
||||
}
|
||||
|
||||
// groupByVirtualParent groups auths by their gemini_virtual_parent attribute.
|
||||
// Returns a map of parentID -> auths and a sorted slice of parent IDs for stable iteration.
|
||||
// Only auths with a non-empty gemini_virtual_parent are grouped; if any auth lacks
|
||||
// this attribute, nil/nil is returned so the caller falls back to flat round-robin.
|
||||
func groupByVirtualParent(auths []*Auth) (map[string][]*Auth, []string) {
|
||||
if len(auths) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
groups := make(map[string][]*Auth)
|
||||
for _, a := range auths {
|
||||
parent := ""
|
||||
if a.Attributes != nil {
|
||||
parent = strings.TrimSpace(a.Attributes["gemini_virtual_parent"])
|
||||
}
|
||||
if parent == "" {
|
||||
// Non-virtual auth present; fall back to flat round-robin.
|
||||
return nil, nil
|
||||
}
|
||||
groups[parent] = append(groups[parent], a)
|
||||
}
|
||||
// Collect parent IDs in sorted order for stable cursor indexing.
|
||||
parentOrder := make([]string, 0, len(groups))
|
||||
for p := range groups {
|
||||
parentOrder = append(parentOrder, p)
|
||||
}
|
||||
sort.Strings(parentOrder)
|
||||
return groups, parentOrder
|
||||
}
|
||||
|
||||
// Pick selects the first available auth for the provider in a deterministic manner.
|
||||
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
_ = opts
|
||||
|
||||
@@ -402,3 +402,128 @@ func TestRoundRobinSelectorPick_CursorKeyCap(t *testing.T) {
|
||||
t.Fatalf("selector.cursors missing key %q", "gemini:m3")
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_GeminiCLICredentialGrouping(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
|
||||
// Simulate two gemini-cli credentials, each with multiple projects:
|
||||
// Credential A (parent = "cred-a.json") has 3 projects
|
||||
// Credential B (parent = "cred-b.json") has 2 projects
|
||||
auths := []*Auth{
|
||||
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-b.json::proj-b1", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
|
||||
{ID: "cred-b.json::proj-b2", Attributes: map[string]string{"gemini_virtual_parent": "cred-b.json"}},
|
||||
}
|
||||
|
||||
// Two-level round-robin: consecutive picks must alternate between credentials.
|
||||
// Credential group order is randomized, but within each call the group cursor
|
||||
// advances by 1, so consecutive picks should cycle through different parents.
|
||||
picks := make([]string, 6)
|
||||
parents := make([]string, 6)
|
||||
for i := 0; i < 6; i++ {
|
||||
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() #%d auth = nil", i)
|
||||
}
|
||||
picks[i] = got.ID
|
||||
parents[i] = got.Attributes["gemini_virtual_parent"]
|
||||
}
|
||||
|
||||
// Verify property: consecutive picks must alternate between credential groups.
|
||||
for i := 1; i < len(parents); i++ {
|
||||
if parents[i] == parents[i-1] {
|
||||
t.Fatalf("Pick() #%d and #%d both from same parent %q (IDs: %q, %q); expected alternating credentials",
|
||||
i-1, i, parents[i], picks[i-1], picks[i])
|
||||
}
|
||||
}
|
||||
|
||||
// Verify property: each credential's projects are picked in sequence (round-robin within group).
|
||||
credPicks := map[string][]string{}
|
||||
for i, id := range picks {
|
||||
credPicks[parents[i]] = append(credPicks[parents[i]], id)
|
||||
}
|
||||
for parent, ids := range credPicks {
|
||||
for i := 1; i < len(ids); i++ {
|
||||
if ids[i] == ids[i-1] {
|
||||
t.Fatalf("Credential %q picked same project %q twice in a row", parent, ids[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_SingleParentFallsBackToFlat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
|
||||
// All auths from the same parent - should fall back to flat round-robin
|
||||
// because there's only one credential group (no benefit from two-level).
|
||||
auths := []*Auth{
|
||||
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a2", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-a.json::proj-a3", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
}
|
||||
|
||||
// With single parent group, parentOrder has length 1, so it uses flat round-robin.
|
||||
// Sorted by ID: proj-a1, proj-a2, proj-a3
|
||||
want := []string{
|
||||
"cred-a.json::proj-a1",
|
||||
"cred-a.json::proj-a2",
|
||||
"cred-a.json::proj-a3",
|
||||
"cred-a.json::proj-a1",
|
||||
}
|
||||
|
||||
for i, expectedID := range want {
|
||||
got, err := selector.Pick(context.Background(), "gemini-cli", "gemini-2.5-pro", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() #%d auth = nil", i)
|
||||
}
|
||||
if got.ID != expectedID {
|
||||
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestRoundRobinSelectorPick_MixedVirtualAndNonVirtualFallsBackToFlat(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
selector := &RoundRobinSelector{}
|
||||
|
||||
// Mix of virtual and non-virtual auths (e.g., a regular gemini-cli auth without projects
|
||||
// alongside virtual ones). Should fall back to flat round-robin.
|
||||
auths := []*Auth{
|
||||
{ID: "cred-a.json::proj-a1", Attributes: map[string]string{"gemini_virtual_parent": "cred-a.json"}},
|
||||
{ID: "cred-regular.json"}, // no gemini_virtual_parent
|
||||
}
|
||||
|
||||
// groupByVirtualParent returns nil when any auth lacks the attribute,
|
||||
// so flat round-robin is used. Sorted by ID: cred-a.json::proj-a1, cred-regular.json
|
||||
want := []string{
|
||||
"cred-a.json::proj-a1",
|
||||
"cred-regular.json",
|
||||
"cred-a.json::proj-a1",
|
||||
}
|
||||
|
||||
for i, expectedID := range want {
|
||||
got, err := selector.Pick(context.Background(), "gemini-cli", "", cliproxyexecutor.Options{}, auths)
|
||||
if err != nil {
|
||||
t.Fatalf("Pick() #%d error = %v", i, err)
|
||||
}
|
||||
if got == nil {
|
||||
t.Fatalf("Pick() #%d auth = nil", i)
|
||||
}
|
||||
if got.ID != expectedID {
|
||||
t.Fatalf("Pick() #%d auth.ID = %q, want %q", i, got.ID, expectedID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
@@ -12,6 +15,33 @@ import (
|
||||
baseauth "github.com/router-for-me/CLIProxyAPI/v6/internal/auth"
|
||||
)
|
||||
|
||||
// PostAuthHook defines a function that is called after an Auth record is created
|
||||
// but before it is persisted to storage. This allows for modification of the
|
||||
// Auth record (e.g., injecting metadata) based on external context.
|
||||
type PostAuthHook func(context.Context, *Auth) error
|
||||
|
||||
// RequestInfo holds information extracted from the HTTP request.
|
||||
// It is injected into the context passed to PostAuthHook.
|
||||
type RequestInfo struct {
|
||||
Query url.Values
|
||||
Headers http.Header
|
||||
}
|
||||
|
||||
type requestInfoKey struct{}
|
||||
|
||||
// WithRequestInfo returns a new context with the given RequestInfo attached.
|
||||
func WithRequestInfo(ctx context.Context, info *RequestInfo) context.Context {
|
||||
return context.WithValue(ctx, requestInfoKey{}, info)
|
||||
}
|
||||
|
||||
// GetRequestInfo retrieves the RequestInfo from the context, if present.
|
||||
func GetRequestInfo(ctx context.Context) *RequestInfo {
|
||||
if val, ok := ctx.Value(requestInfoKey{}).(*RequestInfo); ok {
|
||||
return val
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Auth encapsulates the runtime state and metadata associated with a single credential.
|
||||
type Auth struct {
|
||||
// ID uniquely identifies the auth record across restarts.
|
||||
|
||||
@@ -153,6 +153,16 @@ func (b *Builder) WithLocalManagementPassword(password string) *Builder {
|
||||
return b
|
||||
}
|
||||
|
||||
// WithPostAuthHook registers a hook to be called after an Auth record is created
|
||||
// but before it is persisted to storage.
|
||||
func (b *Builder) WithPostAuthHook(hook coreauth.PostAuthHook) *Builder {
|
||||
if hook == nil {
|
||||
return b
|
||||
}
|
||||
b.serverOptions = append(b.serverOptions, api.WithPostAuthHook(hook))
|
||||
return b
|
||||
}
|
||||
|
||||
// Build validates inputs, applies defaults, and returns a ready-to-run service.
|
||||
func (b *Builder) Build() (*Service, error) {
|
||||
if b.cfg == nil {
|
||||
|
||||
Reference in New Issue
Block a user