feat(logging, executor): add request logging tests and WebSocket-based Codex executor

- Introduced unit tests for request logging middleware to enhance coverage.
- Added WebSocket-based Codex executor to support Responses API upgrade.
- Updated middleware logic to selectively capture request bodies for memory efficiency.
- Enhanced Codex configuration handling with new WebSocket attributes.
This commit is contained in:
Luis Pater
2026-02-19 01:57:02 +08:00
parent 55789df275
commit bb86a0c0c4
24 changed files with 3332 additions and 34 deletions

View File

@@ -15,10 +15,12 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
)
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
// It captures detailed information about the request and response, including headers and body,
// and uses the provided RequestLogger to record this data. When logging is disabled in the
// logger, it still captures data so that upstream errors can be persisted.
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return func(c *gin.Context) {
if logger == nil {
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return
}
if c.Request.Method == http.MethodGet {
if shouldSkipMethodForRequestLogging(c.Request) {
c.Next()
return
}
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
return
}
loggerEnabled := logger.IsEnabled()
// Capture request information
requestInfo, err := captureRequestInfo(c)
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
if err != nil {
// Log error but continue processing
// In a real implementation, you might want to use a proper logger here
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
// Create response writer wrapper
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
if !logger.IsEnabled() {
if !loggerEnabled {
wrapper.logOnErrorOnly = true
}
c.Writer = wrapper
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
}
}
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
if req == nil {
return true
}
if req.Method != http.MethodGet {
return false
}
return !isResponsesWebsocketUpgrade(req)
}
func isResponsesWebsocketUpgrade(req *http.Request) bool {
if req == nil || req.URL == nil {
return false
}
if req.URL.Path != "/v1/responses" {
return false
}
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
}
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
if loggerEnabled {
return true
}
if req == nil || req.Body == nil {
return false
}
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
if strings.HasPrefix(contentType, "multipart/form-data") {
return false
}
if req.ContentLength <= 0 {
return false
}
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
}
// captureRequestInfo extracts relevant information from the incoming HTTP request.
// It captures the URL, method, headers, and body. The request body is read and then
// restored so that it can be processed by subsequent handlers.
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
// Capture URL with sensitive query parameters masked
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
url := c.Request.URL.Path
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
// Capture request body
var body []byte
if c.Request.Body != nil {
if captureBody && c.Request.Body != nil {
// Read the body
bodyBytes, err := io.ReadAll(c.Request.Body)
if err != nil {

View File

@@ -0,0 +1,138 @@
package middleware
import (
"io"
"net/http"
"net/url"
"strings"
"testing"
)
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
tests := []struct {
name string
req *http.Request
skip bool
}{
{
name: "nil request",
req: nil,
skip: true,
},
{
name: "post request should not skip",
req: &http.Request{
Method: http.MethodPost,
URL: &url.URL{Path: "/v1/responses"},
},
skip: false,
},
{
name: "plain get should skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/models"},
Header: http.Header{},
},
skip: true,
},
{
name: "responses websocket upgrade should not skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/responses"},
Header: http.Header{"Upgrade": []string{"websocket"}},
},
skip: false,
},
{
name: "responses get without upgrade should skip",
req: &http.Request{
Method: http.MethodGet,
URL: &url.URL{Path: "/v1/responses"},
Header: http.Header{},
},
skip: true,
},
}
for i := range tests {
got := shouldSkipMethodForRequestLogging(tests[i].req)
if got != tests[i].skip {
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
}
}
}
func TestShouldCaptureRequestBody(t *testing.T) {
tests := []struct {
name string
loggerEnabled bool
req *http.Request
want bool
}{
{
name: "logger enabled always captures",
loggerEnabled: true,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("{}")),
ContentLength: -1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: true,
},
{
name: "nil request",
loggerEnabled: false,
req: nil,
want: false,
},
{
name: "small known size json in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("{}")),
ContentLength: 2,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: true,
},
{
name: "large known size skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: false,
},
{
name: "unknown size skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: -1,
Header: http.Header{"Content-Type": []string{"application/json"}},
},
want: false,
},
{
name: "multipart skipped in error-only mode",
loggerEnabled: false,
req: &http.Request{
Body: io.NopCloser(strings.NewReader("x")),
ContentLength: 1,
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
},
want: false,
},
}
for i := range tests {
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
if got != tests[i].want {
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
}
}
}

View File

@@ -14,6 +14,8 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
)
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
type RequestInfo struct {
URL string // URL is the request URL.
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
// Only fall back to request payload hints when Content-Type is not set yet.
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
bodyStr := string(w.requestInfo.Body)
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
}
return false
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
return nil
}
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
}
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
@@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
return time.Time{}
}
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
if c != nil {
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
switch value := bodyOverride.(type) {
case []byte:
if len(value) > 0 {
return bytes.Clone(value)
}
case string:
if strings.TrimSpace(value) != "" {
return []byte(value)
}
}
}
}
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
return w.requestInfo.Body
}
return nil
}
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
if w.requestInfo == nil {
return nil
}
var requestBody []byte
if len(w.requestInfo.Body) > 0 {
requestBody = w.requestInfo.Body
}
if loggerWithOptions, ok := w.logger.(interface {
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
}); ok {

View File

@@ -0,0 +1,43 @@
package middleware
import (
"net/http/httptest"
"testing"
"github.com/gin-gonic/gin"
)
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{
requestInfo: &RequestInfo{Body: []byte("original-body")},
}
body := wrapper.extractRequestBody(c)
if string(body) != "original-body" {
t.Fatalf("request body = %q, want %q", string(body), "original-body")
}
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
body = wrapper.extractRequestBody(c)
if string(body) != "override-body" {
t.Fatalf("request body = %q, want %q", string(body), "override-body")
}
}
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
wrapper := &ResponseWriterWrapper{}
c.Set(requestBodyOverrideContextKey, "override-as-string")
body := wrapper.extractRequestBody(c)
if string(body) != "override-as-string" {
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
}
}

View File

@@ -323,6 +323,7 @@ func (s *Server) setupRoutes() {
v1.POST("/completions", openaiHandlers.Completions)
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
v1.POST("/responses", openaiResponsesHandlers.Responses)
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
}

View File

@@ -355,6 +355,9 @@ type CodexKey struct {
// If empty, the default Codex API URL will be used.
BaseURL string `yaml:"base-url" json:"base-url"`
// Websockets enables the Responses API websocket transport for this credential.
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
// ProxyURL overrides the global proxy setting for this API key if provided.
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`

View File

@@ -19,6 +19,7 @@ import (
// - codex
// - qwen
// - iflow
// - kimi
// - antigravity (returns static overrides only)
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
key := strings.ToLower(strings.TrimSpace(channel))
@@ -39,6 +40,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
return GetQwenModels()
case "iflow":
return GetIFlowModels()
case "kimi":
return GetKimiModels()
case "antigravity":
cfg := GetAntigravityModelConfig()
if len(cfg) == 0 {
@@ -83,6 +86,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
GetOpenAIModels(),
GetQwenModels(),
GetIFlowModels(),
GetKimiModels(),
}
for _, models := range allModels {
for _, m := range models {

View File

@@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo {
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-sonnet-4-6",
Object: "model",
Created: 1771372800, // 2026-02-17
OwnedBy: "anthropic",
Type: "claude",
DisplayName: "Claude 4.6 Sonnet",
ContextLength: 200000,
MaxCompletionTokens: 64000,
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
},
{
ID: "claude-opus-4-6",
Object: "model",
@@ -788,6 +799,19 @@ func GetQwenModels() []*ModelInfo {
MaxCompletionTokens: 2048,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "coder-model",
Object: "model",
Created: 1771171200,
OwnedBy: "qwen",
Type: "qwen",
Version: "3.5",
DisplayName: "Qwen 3.5 Plus",
Description: "efficient hybrid model with leading coding performance",
ContextLength: 1048576,
MaxCompletionTokens: 65536,
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
},
{
ID: "vision-model",
Object: "model",
@@ -884,6 +908,8 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
"gpt-oss-120b-medium": {},
"tab_flash_lite_preview": {},
}

File diff suppressed because it is too large Load Diff

View File

@@ -22,9 +22,7 @@ import (
)
const (
qwenUserAgent = "google-api-nodejs-client/9.15.1"
qwenXGoogAPIClient = "gl-node/22.17.0"
qwenClientMetadataValue = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
)
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
@@ -344,8 +342,18 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
r.Header.Set("Content-Type", "application/json")
r.Header.Set("Authorization", "Bearer "+token)
r.Header.Set("User-Agent", qwenUserAgent)
r.Header.Set("X-Goog-Api-Client", qwenXGoogAPIClient)
r.Header.Set("Client-Metadata", qwenClientMetadataValue)
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
r.Header.Set("Sec-Fetch-Mode", "cors")
r.Header.Set("X-Stainless-Lang", "js")
r.Header.Set("X-Stainless-Arch", "arm64")
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
r.Header.Set("X-Stainless-Retry-Count", "0")
r.Header.Set("X-Stainless-Os", "MacOS")
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
r.Header.Set("X-Stainless-Runtime", "node")
if stream {
r.Header.Set("Accept", "text/event-stream")
return

View File

@@ -184,6 +184,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
}
if o.Websockets != n.Websockets {
changes = append(changes, fmt.Sprintf("codex[%d].websockets: %t -> %t", i, o.Websockets, n.Websockets))
}
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
}

View File

@@ -160,6 +160,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
if ck.BaseURL != "" {
attrs["base_url"] = ck.BaseURL
}
if ck.Websockets {
attrs["websockets"] = "true"
}
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
attrs["models_hash"] = hash
}

View File

@@ -231,10 +231,11 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
Config: &config.Config{
CodexKey: []config.CodexKey{
{
APIKey: "codex-key-123",
Prefix: "dev",
BaseURL: "https://api.openai.com",
ProxyURL: "http://proxy.local",
APIKey: "codex-key-123",
Prefix: "dev",
BaseURL: "https://api.openai.com",
ProxyURL: "http://proxy.local",
Websockets: true,
},
},
},
@@ -259,6 +260,9 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
if auths[0].ProxyURL != "http://proxy.local" {
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
}
if auths[0].Attributes["websockets"] != "true" {
t.Errorf("expected websockets=true, got %s", auths[0].Attributes["websockets"])
}
}
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {

View File

@@ -52,6 +52,45 @@ const (
defaultStreamingBootstrapRetries = 0
)
type pinnedAuthContextKey struct{}
type selectedAuthCallbackContextKey struct{}
type executionSessionContextKey struct{}
// WithPinnedAuthID returns a child context that requests execution on a specific auth ID.
func WithPinnedAuthID(ctx context.Context, authID string) context.Context {
authID = strings.TrimSpace(authID)
if authID == "" {
return ctx
}
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, pinnedAuthContextKey{}, authID)
}
// WithSelectedAuthIDCallback returns a child context that receives the selected auth ID.
func WithSelectedAuthIDCallback(ctx context.Context, callback func(string)) context.Context {
if callback == nil {
return ctx
}
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, selectedAuthCallbackContextKey{}, callback)
}
// WithExecutionSessionID returns a child context tagged with a long-lived execution session ID.
func WithExecutionSessionID(ctx context.Context, sessionID string) context.Context {
sessionID = strings.TrimSpace(sessionID)
if sessionID == "" {
return ctx
}
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, executionSessionContextKey{}, sessionID)
}
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
func BuildErrorResponseBody(status int, errText string) []byte {
@@ -152,7 +191,59 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
if key == "" {
key = uuid.NewString()
}
return map[string]any{idempotencyKeyMetadataKey: key}
meta := map[string]any{idempotencyKeyMetadataKey: key}
if pinnedAuthID := pinnedAuthIDFromContext(ctx); pinnedAuthID != "" {
meta[coreexecutor.PinnedAuthMetadataKey] = pinnedAuthID
}
if selectedCallback := selectedAuthIDCallbackFromContext(ctx); selectedCallback != nil {
meta[coreexecutor.SelectedAuthCallbackMetadataKey] = selectedCallback
}
if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" {
meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID
}
return meta
}
func pinnedAuthIDFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
raw := ctx.Value(pinnedAuthContextKey{})
switch v := raw.(type) {
case string:
return strings.TrimSpace(v)
case []byte:
return strings.TrimSpace(string(v))
default:
return ""
}
}
func selectedAuthIDCallbackFromContext(ctx context.Context) func(string) {
if ctx == nil {
return nil
}
raw := ctx.Value(selectedAuthCallbackContextKey{})
if callback, ok := raw.(func(string)); ok && callback != nil {
return callback
}
return nil
}
func executionSessionIDFromContext(ctx context.Context) string {
if ctx == nil {
return ""
}
raw := ctx.Value(executionSessionContextKey{})
switch v := raw.(type) {
case string:
return strings.TrimSpace(v)
case []byte:
return strings.TrimSpace(string(v))
default:
return ""
}
}
// BaseAPIHandler contains the handlers for API endpoints.

View File

@@ -122,6 +122,82 @@ func (e *payloadThenErrorStreamExecutor) Calls() int {
return e.calls
}
type authAwareStreamExecutor struct {
mu sync.Mutex
calls int
authIDs []string
}
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
_ = ctx
_ = req
_ = opts
ch := make(chan coreexecutor.StreamChunk, 1)
authID := ""
if auth != nil {
authID = auth.ID
}
e.mu.Lock()
e.calls++
e.authIDs = append(e.authIDs, authID)
e.mu.Unlock()
if authID == "auth1" {
ch <- coreexecutor.StreamChunk{
Err: &coreauth.Error{
Code: "unauthorized",
Message: "unauthorized",
Retryable: false,
HTTPStatus: http.StatusUnauthorized,
},
}
close(ch)
return ch, nil
}
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
close(ch)
return ch, nil
}
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *authAwareStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}
func (e *authAwareStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
return nil, &coreauth.Error{
Code: "not_implemented",
Message: "HttpRequest not implemented",
HTTPStatus: http.StatusNotImplemented,
}
}
func (e *authAwareStreamExecutor) Calls() int {
e.mu.Lock()
defer e.mu.Unlock()
return e.calls
}
func (e *authAwareStreamExecutor) AuthIDs() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.authIDs))
copy(out, e.authIDs)
return out
}
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
executor := &failOnceStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
@@ -252,3 +328,128 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
}
}
func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) {
executor := &authAwareStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
auth2 := &coreauth.Auth{
ID: "auth2",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test2@example.com"},
}
if _, err := manager.Register(context.Background(), auth2); err != nil {
t.Fatalf("manager.Register(auth2): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
}, manager)
ctx := WithPinnedAuthID(context.Background(), "auth1")
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
var gotErr error
for msg := range errChan {
if msg != nil && msg.Error != nil {
gotErr = msg.Error
}
}
if len(got) != 0 {
t.Fatalf("expected empty payload, got %q", string(got))
}
if gotErr == nil {
t.Fatalf("expected terminal error, got nil")
}
authIDs := executor.AuthIDs()
if len(authIDs) == 0 {
t.Fatalf("expected at least one upstream attempt")
}
for _, authID := range authIDs {
if authID != "auth1" {
t.Fatalf("expected all attempts on auth1, got sequence %v", authIDs)
}
}
}
func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *testing.T) {
executor := &authAwareStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth2 := &coreauth.Auth{
ID: "auth2",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test2@example.com"},
}
if _, err := manager.Register(context.Background(), auth2); err != nil {
t.Fatalf("manager.Register(auth2): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 0,
},
}, manager)
selectedAuthID := ""
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
selectedAuthID = authID
})
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
for msg := range errChan {
if msg != nil {
t.Fatalf("unexpected error: %+v", msg)
}
}
if string(got) != "ok" {
t.Fatalf("expected payload ok, got %q", string(got))
}
if selectedAuthID != "auth2" {
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
}
}

View File

@@ -0,0 +1,662 @@
package openai
import (
"bytes"
"context"
"encoding/json"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/gin-gonic/gin"
"github.com/google/uuid"
"github.com/gorilla/websocket"
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
log "github.com/sirupsen/logrus"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
const (
wsRequestTypeCreate = "response.create"
wsRequestTypeAppend = "response.append"
wsEventTypeError = "error"
wsEventTypeCompleted = "response.completed"
wsEventTypeDone = "response.done"
wsDoneMarker = "[DONE]"
wsTurnStateHeader = "x-codex-turn-state"
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
wsPayloadLogMaxSize = 2048
)
var responsesWebsocketUpgrader = websocket.Upgrader{
ReadBufferSize: 4096,
WriteBufferSize: 4096,
CheckOrigin: func(r *http.Request) bool {
return true
},
}
// ResponsesWebsocket handles websocket requests for /v1/responses.
// It accepts `response.create` and `response.append` requests and streams
// response events back as JSON websocket text messages.
func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request))
if err != nil {
return
}
passthroughSessionID := uuid.NewString()
clientRemoteAddr := ""
if c != nil && c.Request != nil {
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
}
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
var wsTerminateErr error
var wsBodyLog strings.Builder
defer func() {
if wsTerminateErr != nil {
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
} else {
log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
}
if h != nil && h.AuthManager != nil {
h.AuthManager.CloseExecutionSession(passthroughSessionID)
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
}
setWebsocketRequestBody(c, wsBodyLog.String())
if errClose := conn.Close(); errClose != nil {
log.Warnf("responses websocket: close connection error: %v", errClose)
}
}()
var lastRequest []byte
lastResponseOutput := []byte("[]")
pinnedAuthID := ""
for {
msgType, payload, errReadMessage := conn.ReadMessage()
if errReadMessage != nil {
wsTerminateErr = errReadMessage
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
} else {
// log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage)
}
return
}
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
continue
}
// log.Infof(
// "responses websocket: downstream_in id=%s type=%d event=%s payload=%s",
// passthroughSessionID,
// msgType,
// websocketPayloadEventType(payload),
// websocketPayloadPreview(payload),
// )
appendWebsocketEvent(&wsBodyLog, "request", payload)
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
}
}
var requestJSON []byte
var updatedLastRequest []byte
var errMsg *interfaces.ErrorMessage
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode(
payload,
lastRequest,
lastResponseOutput,
allowIncrementalInputWithPreviousResponseID,
)
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
passthroughSessionID,
websocket.TextMessage,
websocketPayloadEventType(errorPayload),
websocketPayloadPreview(errorPayload),
)
if errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
passthroughSessionID,
websocketPayloadEventType(errorPayload),
errWrite,
)
return
}
continue
}
lastRequest = updatedLastRequest
modelName := gjson.GetBytes(requestJSON, "model").String()
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx)
cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID)
if pinnedAuthID != "" {
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
} else {
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
pinnedAuthID = strings.TrimSpace(authID)
})
}
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
if errForward != nil {
wsTerminateErr = errForward
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
return
}
lastResponseOutput = completedOutput
}
}
func websocketUpgradeHeaders(req *http.Request) http.Header {
headers := http.Header{}
if req == nil {
return headers
}
// Keep the same sticky turn-state across reconnects when provided by the client.
turnState := strings.TrimSpace(req.Header.Get(wsTurnStateHeader))
if turnState != "" {
headers.Set(wsTurnStateHeader, turnState)
}
return headers
}
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
}
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
switch requestType {
case wsRequestTypeCreate:
// log.Infof("responses websocket: response.create request")
if len(lastRequest) == 0 {
return normalizeResponseCreateRequest(rawJSON)
}
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
case wsRequestTypeAppend:
// log.Infof("responses websocket: response.append request")
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
default:
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("unsupported websocket request type: %s", requestType),
}
}
}
func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
normalized = bytes.Clone(rawJSON)
}
normalized, _ = sjson.SetBytes(normalized, "stream", true)
if !gjson.GetBytes(normalized, "input").Exists() {
normalized, _ = sjson.SetRawBytes(normalized, "input", []byte("[]"))
}
modelName := strings.TrimSpace(gjson.GetBytes(normalized, "model").String())
if modelName == "" {
return nil, nil, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("missing model in response.create request"),
}
}
return normalized, bytes.Clone(normalized), nil
}
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
if len(lastRequest) == 0 {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("websocket request received before response.create"),
}
}
nextInput := gjson.GetBytes(rawJSON, "input")
if !nextInput.Exists() || !nextInput.IsArray() {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("websocket request requires array field: input"),
}
}
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
// Do not expand it into a full input transcript; upstream expects the incremental payload.
if allowIncrementalInputWithPreviousResponseID {
if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" {
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
normalized = bytes.Clone(rawJSON)
}
if !gjson.GetBytes(normalized, "model").Exists() {
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
if modelName != "" {
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
}
}
if !gjson.GetBytes(normalized, "instructions").Exists() {
instructions := gjson.GetBytes(lastRequest, "instructions")
if instructions.Exists() {
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
}
}
normalized, _ = sjson.SetBytes(normalized, "stream", true)
return normalized, bytes.Clone(normalized), nil
}
}
existingInput := gjson.GetBytes(lastRequest, "input")
mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
if errMerge != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
}
}
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
if errMerge != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("invalid request input: %w", errMerge),
}
}
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
if errDelete != nil {
normalized = bytes.Clone(rawJSON)
}
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
var errSet error
normalized, errSet = sjson.SetRawBytes(normalized, "input", []byte(mergedInput))
if errSet != nil {
return nil, lastRequest, &interfaces.ErrorMessage{
StatusCode: http.StatusBadRequest,
Error: fmt.Errorf("failed to merge websocket input: %w", errSet),
}
}
if !gjson.GetBytes(normalized, "model").Exists() {
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
if modelName != "" {
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
}
}
if !gjson.GetBytes(normalized, "instructions").Exists() {
instructions := gjson.GetBytes(lastRequest, "instructions")
if instructions.Exists() {
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
}
}
normalized, _ = sjson.SetBytes(normalized, "stream", true)
return normalized, bytes.Clone(normalized), nil
}
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
if len(attributes) > 0 {
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
parsed, errParse := strconv.ParseBool(raw)
if errParse == nil {
return parsed
}
}
}
if len(metadata) == 0 {
return false
}
raw, ok := metadata["websockets"]
if !ok || raw == nil {
return false
}
switch value := raw.(type) {
case bool:
return value
case string:
parsed, errParse := strconv.ParseBool(strings.TrimSpace(value))
if errParse == nil {
return parsed
}
default:
}
return false
}
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
existingRaw = strings.TrimSpace(existingRaw)
appendRaw = strings.TrimSpace(appendRaw)
if existingRaw == "" {
existingRaw = "[]"
}
if appendRaw == "" {
appendRaw = "[]"
}
var existing []json.RawMessage
if err := json.Unmarshal([]byte(existingRaw), &existing); err != nil {
return "", err
}
var appendItems []json.RawMessage
if err := json.Unmarshal([]byte(appendRaw), &appendItems); err != nil {
return "", err
}
merged := append(existing, appendItems...)
out, err := json.Marshal(merged)
if err != nil {
return "", err
}
return string(out), nil
}
func normalizeJSONArrayRaw(raw []byte) string {
trimmed := strings.TrimSpace(string(raw))
if trimmed == "" {
return "[]"
}
result := gjson.Parse(trimmed)
if result.Type == gjson.JSON && result.IsArray() {
return trimmed
}
return "[]"
}
func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
c *gin.Context,
conn *websocket.Conn,
cancel handlers.APIHandlerCancelFunc,
data <-chan []byte,
errs <-chan *interfaces.ErrorMessage,
wsBodyLog *strings.Builder,
sessionID string,
) ([]byte, error) {
completed := false
completedOutput := []byte("[]")
for {
select {
case <-c.Request.Context().Done():
cancel(c.Request.Context().Err())
return completedOutput, c.Request.Context().Err()
case errMsg, ok := <-errs:
if !ok {
errs = nil
continue
}
if errMsg != nil {
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
sessionID,
websocket.TextMessage,
websocketPayloadEventType(errorPayload),
websocketPayloadPreview(errorPayload),
)
if errWrite != nil {
// log.Warnf(
// "responses websocket: downstream_out write failed id=%s event=%s error=%v",
// sessionID,
// websocketPayloadEventType(errorPayload),
// errWrite,
// )
cancel(errMsg.Error)
return completedOutput, errWrite
}
}
if errMsg != nil {
cancel(errMsg.Error)
} else {
cancel(nil)
}
return completedOutput, nil
case chunk, ok := <-data:
if !ok {
if !completed {
errMsg := &interfaces.ErrorMessage{
StatusCode: http.StatusRequestTimeout,
Error: fmt.Errorf("stream closed before response.completed"),
}
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
markAPIResponseTimestamp(c)
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
log.Infof(
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
sessionID,
websocket.TextMessage,
websocketPayloadEventType(errorPayload),
websocketPayloadPreview(errorPayload),
)
if errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID,
websocketPayloadEventType(errorPayload),
errWrite,
)
cancel(errMsg.Error)
return completedOutput, errWrite
}
cancel(errMsg.Error)
return completedOutput, nil
}
cancel(nil)
return completedOutput, nil
}
payloads := websocketJSONPayloadsFromChunk(chunk)
for i := range payloads {
eventType := gjson.GetBytes(payloads[i], "type").String()
if eventType == wsEventTypeCompleted {
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
completed = true
completedOutput = responseCompletedOutputFromPayload(payloads[i])
}
markAPIResponseTimestamp(c)
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
// log.Infof(
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
// sessionID,
// websocket.TextMessage,
// websocketPayloadEventType(payloads[i]),
// websocketPayloadPreview(payloads[i]),
// )
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
log.Warnf(
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
sessionID,
websocketPayloadEventType(payloads[i]),
errWrite,
)
cancel(errWrite)
return completedOutput, errWrite
}
}
}
}
}
func responseCompletedOutputFromPayload(payload []byte) []byte {
output := gjson.GetBytes(payload, "response.output")
if output.Exists() && output.IsArray() {
return bytes.Clone([]byte(output.Raw))
}
return []byte("[]")
}
func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
payloads := make([][]byte, 0, 2)
lines := bytes.Split(chunk, []byte("\n"))
for i := range lines {
line := bytes.TrimSpace(lines[i])
if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) {
continue
}
if bytes.HasPrefix(line, []byte("data:")) {
line = bytes.TrimSpace(line[len("data:"):])
}
if len(line) == 0 || bytes.Equal(line, []byte(wsDoneMarker)) {
continue
}
if json.Valid(line) {
payloads = append(payloads, bytes.Clone(line))
}
}
if len(payloads) > 0 {
return payloads
}
trimmed := bytes.TrimSpace(chunk)
if bytes.HasPrefix(trimmed, []byte("data:")) {
trimmed = bytes.TrimSpace(trimmed[len("data:"):])
}
if len(trimmed) > 0 && !bytes.Equal(trimmed, []byte(wsDoneMarker)) && json.Valid(trimmed) {
payloads = append(payloads, bytes.Clone(trimmed))
}
return payloads
}
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
status := http.StatusInternalServerError
errText := http.StatusText(status)
if errMsg != nil {
if errMsg.StatusCode > 0 {
status = errMsg.StatusCode
errText = http.StatusText(status)
}
if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
errText = errMsg.Error.Error()
}
}
body := handlers.BuildErrorResponseBody(status, errText)
payload := map[string]any{
"type": wsEventTypeError,
"status": status,
}
if errMsg != nil && errMsg.Addon != nil {
headers := map[string]any{}
for key, values := range errMsg.Addon {
if len(values) == 0 {
continue
}
headers[key] = values[0]
}
if len(headers) > 0 {
payload["headers"] = headers
}
}
if len(body) > 0 && json.Valid(body) {
var decoded map[string]any
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
if inner, ok := decoded["error"]; ok {
payload["error"] = inner
} else {
payload["error"] = decoded
}
}
}
if _, ok := payload["error"]; !ok {
payload["error"] = map[string]any{
"type": "server_error",
"message": errText,
}
}
data, err := json.Marshal(payload)
if err != nil {
return nil, err
}
return data, conn.WriteMessage(websocket.TextMessage, data)
}
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
if builder == nil {
return
}
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return
}
if builder.Len() > 0 {
builder.WriteString("\n")
}
builder.WriteString("websocket.")
builder.WriteString(eventType)
builder.WriteString("\n")
builder.Write(trimmedPayload)
builder.WriteString("\n")
}
func websocketPayloadEventType(payload []byte) string {
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
if eventType == "" {
return "-"
}
return eventType
}
func websocketPayloadPreview(payload []byte) string {
trimmedPayload := bytes.TrimSpace(payload)
if len(trimmedPayload) == 0 {
return "<empty>"
}
preview := trimmedPayload
if len(preview) > wsPayloadLogMaxSize {
preview = preview[:wsPayloadLogMaxSize]
}
previewText := strings.ReplaceAll(string(preview), "\n", "\\n")
previewText = strings.ReplaceAll(previewText, "\r", "\\r")
if len(trimmedPayload) > wsPayloadLogMaxSize {
return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload))
}
return previewText
}
func setWebsocketRequestBody(c *gin.Context, body string) {
if c == nil {
return
}
trimmedBody := strings.TrimSpace(body)
if trimmedBody == "" {
return
}
c.Set(wsRequestBodyKey, []byte(trimmedBody))
}
func markAPIResponseTimestamp(c *gin.Context) {
if c == nil {
return
}
if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); exists {
return
}
c.Set("API_RESPONSE_TIMESTAMP", time.Now())
}

View File

@@ -0,0 +1,249 @@
package openai
import (
"bytes"
"net/http"
"net/http/httptest"
"strings"
"testing"
"github.com/gin-gonic/gin"
"github.com/tidwall/gjson"
)
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
normalized, last, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
if gjson.GetBytes(normalized, "type").Exists() {
t.Fatalf("normalized create request must not include type field")
}
if !gjson.GetBytes(normalized, "stream").Bool() {
t.Fatalf("normalized create request must force stream=true")
}
if gjson.GetBytes(normalized, "model").String() != "test-model" {
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
}
if !bytes.Equal(last, normalized) {
t.Fatalf("last request snapshot should match normalized request")
}
}
func TestNormalizeResponsesWebsocketRequestCreateWithHistory(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1"},
{"type":"message","id":"assistant-1"}
]`)
raw := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
if gjson.GetBytes(normalized, "type").Exists() {
t.Fatalf("normalized subsequent create request must not include type field")
}
if gjson.GetBytes(normalized, "model").String() != "test-model" {
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
}
input := gjson.GetBytes(normalized, "input").Array()
if len(input) != 4 {
t.Fatalf("merged input len = %d, want 4", len(input))
}
if input[0].Get("id").String() != "msg-1" ||
input[1].Get("id").String() != "fc-1" ||
input[2].Get("id").String() != "assistant-1" ||
input[3].Get("id").String() != "tool-out-1" {
t.Fatalf("unexpected merged input order")
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match normalized request")
}
}
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"instructions":"be helpful","input":[{"type":"message","id":"msg-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1"},
{"type":"message","id":"assistant-1"}
]`)
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
if gjson.GetBytes(normalized, "type").Exists() {
t.Fatalf("normalized request must not include type field")
}
if gjson.GetBytes(normalized, "previous_response_id").String() != "resp-1" {
t.Fatalf("previous_response_id must be preserved in incremental mode")
}
input := gjson.GetBytes(normalized, "input").Array()
if len(input) != 1 {
t.Fatalf("incremental input len = %d, want 1", len(input))
}
if input[0].Get("id").String() != "tool-out-1" {
t.Fatalf("unexpected incremental input item id: %s", input[0].Get("id").String())
}
if gjson.GetBytes(normalized, "model").String() != "test-model" {
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
}
if gjson.GetBytes(normalized, "instructions").String() != "be helpful" {
t.Fatalf("unexpected instructions: %s", gjson.GetBytes(normalized, "instructions").String())
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match normalized request")
}
}
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"function_call","id":"fc-1","call_id":"call-1"},
{"type":"message","id":"assistant-1"}
]`)
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
if gjson.GetBytes(normalized, "previous_response_id").Exists() {
t.Fatalf("previous_response_id must be removed when incremental mode is disabled")
}
input := gjson.GetBytes(normalized, "input").Array()
if len(input) != 4 {
t.Fatalf("merged input len = %d, want 4", len(input))
}
if input[0].Get("id").String() != "msg-1" ||
input[1].Get("id").String() != "fc-1" ||
input[2].Get("id").String() != "assistant-1" ||
input[3].Get("id").String() != "tool-out-1" {
t.Fatalf("unexpected merged input order")
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match normalized request")
}
}
func TestNormalizeResponsesWebsocketRequestAppend(t *testing.T) {
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
lastResponseOutput := []byte(`[
{"type":"message","id":"assistant-1"},
{"type":"function_call_output","id":"tool-out-1"}
]`)
raw := []byte(`{"type":"response.append","input":[{"type":"message","id":"msg-2"},{"type":"message","id":"msg-3"}]}`)
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
if errMsg != nil {
t.Fatalf("unexpected error: %v", errMsg.Error)
}
input := gjson.GetBytes(normalized, "input").Array()
if len(input) != 5 {
t.Fatalf("merged input len = %d, want 5", len(input))
}
if input[0].Get("id").String() != "msg-1" ||
input[1].Get("id").String() != "assistant-1" ||
input[2].Get("id").String() != "tool-out-1" ||
input[3].Get("id").String() != "msg-2" ||
input[4].Get("id").String() != "msg-3" {
t.Fatalf("unexpected merged input order")
}
if !bytes.Equal(next, normalized) {
t.Fatalf("next request snapshot should match normalized append request")
}
}
func TestNormalizeResponsesWebsocketRequestAppendWithoutCreate(t *testing.T) {
raw := []byte(`{"type":"response.append","input":[]}`)
_, _, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
if errMsg == nil {
t.Fatalf("expected error for append without previous request")
}
if errMsg.StatusCode != http.StatusBadRequest {
t.Fatalf("status = %d, want %d", errMsg.StatusCode, http.StatusBadRequest)
}
}
func TestWebsocketJSONPayloadsFromChunk(t *testing.T) {
chunk := []byte("event: response.created\n\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\ndata: [DONE]\n")
payloads := websocketJSONPayloadsFromChunk(chunk)
if len(payloads) != 1 {
t.Fatalf("payloads len = %d, want 1", len(payloads))
}
if gjson.GetBytes(payloads[0], "type").String() != "response.created" {
t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
}
}
func TestWebsocketJSONPayloadsFromPlainJSONChunk(t *testing.T) {
chunk := []byte(`{"type":"response.completed","response":{"id":"resp-1"}}`)
payloads := websocketJSONPayloadsFromChunk(chunk)
if len(payloads) != 1 {
t.Fatalf("payloads len = %d, want 1", len(payloads))
}
if gjson.GetBytes(payloads[0], "type").String() != "response.completed" {
t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
}
}
func TestResponseCompletedOutputFromPayload(t *testing.T) {
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"message","id":"out-1"}]}}`)
output := responseCompletedOutputFromPayload(payload)
items := gjson.ParseBytes(output).Array()
if len(items) != 1 {
t.Fatalf("output len = %d, want 1", len(items))
}
if items[0].Get("id").String() != "out-1" {
t.Fatalf("unexpected output id: %s", items[0].Get("id").String())
}
}
func TestAppendWebsocketEvent(t *testing.T) {
var builder strings.Builder
appendWebsocketEvent(&builder, "request", []byte(" {\"type\":\"response.create\"}\n"))
appendWebsocketEvent(&builder, "response", []byte("{\"type\":\"response.created\"}"))
got := builder.String()
if !strings.Contains(got, "websocket.request\n{\"type\":\"response.create\"}\n") {
t.Fatalf("request event not found in body: %s", got)
}
if !strings.Contains(got, "websocket.response\n{\"type\":\"response.created\"}\n") {
t.Fatalf("response event not found in body: %s", got)
}
}
func TestSetWebsocketRequestBody(t *testing.T) {
gin.SetMode(gin.TestMode)
recorder := httptest.NewRecorder()
c, _ := gin.CreateTestContext(recorder)
setWebsocketRequestBody(c, " \n ")
if _, exists := c.Get(wsRequestBodyKey); exists {
t.Fatalf("request body key should not be set for empty body")
}
setWebsocketRequestBody(c, "event body")
value, exists := c.Get(wsRequestBodyKey)
if !exists {
t.Fatalf("request body key not set")
}
bodyBytes, ok := value.([]byte)
if !ok {
t.Fatalf("request body key type mismatch")
}
if string(bodyBytes) != "event body" {
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
}
}

View File

@@ -41,6 +41,17 @@ type ProviderExecutor interface {
HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
}
// ExecutionSessionCloser allows executors to release per-session runtime resources.
type ExecutionSessionCloser interface {
CloseExecutionSession(sessionID string)
}
const (
// CloseAllExecutionSessionsID asks an executor to release all active execution sessions.
// Executors that do not support this marker may ignore it.
CloseAllExecutionSessionsID = "__all_execution_sessions__"
)
// RefreshEvaluator allows runtime state to override refresh decisions.
type RefreshEvaluator interface {
ShouldRefresh(now time.Time, auth *Auth) bool
@@ -389,9 +400,23 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
if executor == nil {
return
}
provider := strings.TrimSpace(executor.Identifier())
if provider == "" {
return
}
var replaced ProviderExecutor
m.mu.Lock()
defer m.mu.Unlock()
m.executors[executor.Identifier()] = executor
replaced = m.executors[provider]
m.executors[provider] = executor
m.mu.Unlock()
if replaced == nil || replaced == executor {
return
}
if closer, ok := replaced.(ExecutionSessionCloser); ok && closer != nil {
closer.CloseExecutionSession(CloseAllExecutionSessionsID)
}
}
// UnregisterExecutor removes the executor associated with the provider key.
@@ -581,6 +606,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -636,6 +662,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -691,6 +718,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
entry := logEntryWithRequestID(ctx)
debugLogAuthSelection(entry, auth, provider, req.Model)
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
tried[auth.ID] = struct{}{}
execCtx := ctx
@@ -794,6 +822,38 @@ func hasRequestedModelMetadata(meta map[string]any) bool {
}
}
func pinnedAuthIDFromMetadata(meta map[string]any) string {
if len(meta) == 0 {
return ""
}
raw, ok := meta[cliproxyexecutor.PinnedAuthMetadataKey]
if !ok || raw == nil {
return ""
}
switch val := raw.(type) {
case string:
return strings.TrimSpace(val)
case []byte:
return strings.TrimSpace(string(val))
default:
return ""
}
}
func publishSelectedAuthMetadata(meta map[string]any, authID string) {
if len(meta) == 0 {
return
}
authID = strings.TrimSpace(authID)
if authID == "" {
return
}
meta[cliproxyexecutor.SelectedAuthMetadataKey] = authID
if callback, ok := meta[cliproxyexecutor.SelectedAuthCallbackMetadataKey].(func(string)); ok && callback != nil {
callback(authID)
}
}
func rewriteModelForAuth(model string, auth *Auth) string {
if auth == nil || model == "" {
return model
@@ -1550,7 +1610,56 @@ func (m *Manager) GetByID(id string) (*Auth, bool) {
return auth.Clone(), true
}
// Executor returns the registered provider executor for a provider key.
func (m *Manager) Executor(provider string) (ProviderExecutor, bool) {
if m == nil {
return nil, false
}
provider = strings.TrimSpace(provider)
if provider == "" {
return nil, false
}
m.mu.RLock()
executor, okExecutor := m.executors[provider]
if !okExecutor {
lowerProvider := strings.ToLower(provider)
if lowerProvider != provider {
executor, okExecutor = m.executors[lowerProvider]
}
}
m.mu.RUnlock()
if !okExecutor || executor == nil {
return nil, false
}
return executor, true
}
// CloseExecutionSession asks all registered executors to release the supplied execution session.
func (m *Manager) CloseExecutionSession(sessionID string) {
sessionID = strings.TrimSpace(sessionID)
if m == nil || sessionID == "" {
return
}
m.mu.RLock()
executors := make([]ProviderExecutor, 0, len(m.executors))
for _, exec := range m.executors {
executors = append(executors, exec)
}
m.mu.RUnlock()
for i := range executors {
if closer, ok := executors[i].(ExecutionSessionCloser); ok && closer != nil {
closer.CloseExecutionSession(sessionID)
}
}
}
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
m.mu.RLock()
executor, okExecutor := m.executors[provider]
if !okExecutor {
@@ -1571,6 +1680,9 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
if candidate.Provider != provider || candidate.Disabled {
continue
}
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
continue
}
if _, used := tried[candidate.ID]; used {
continue
}
@@ -1606,6 +1718,8 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
}
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
providerSet := make(map[string]struct{}, len(providers))
for _, provider := range providers {
p := strings.TrimSpace(strings.ToLower(provider))
@@ -1633,6 +1747,9 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
if candidate == nil || candidate.Disabled {
continue
}
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
continue
}
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
if providerKey == "" {
continue

View File

@@ -0,0 +1,100 @@
package auth
import (
"context"
"net/http"
"sync"
"testing"
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
)
type replaceAwareExecutor struct {
id string
mu sync.Mutex
closedSessionIDs []string
}
func (e *replaceAwareExecutor) Identifier() string {
return e.id
}
func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
ch := make(chan cliproxyexecutor.StreamChunk)
close(ch)
return ch, nil
}
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
return auth, nil
}
func (e *replaceAwareExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
return cliproxyexecutor.Response{}, nil
}
func (e *replaceAwareExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
return nil, nil
}
func (e *replaceAwareExecutor) CloseExecutionSession(sessionID string) {
e.mu.Lock()
defer e.mu.Unlock()
e.closedSessionIDs = append(e.closedSessionIDs, sessionID)
}
func (e *replaceAwareExecutor) ClosedSessionIDs() []string {
e.mu.Lock()
defer e.mu.Unlock()
out := make([]string, len(e.closedSessionIDs))
copy(out, e.closedSessionIDs)
return out
}
func TestManagerRegisterExecutorClosesReplacedExecutionSessions(t *testing.T) {
t.Parallel()
manager := NewManager(nil, nil, nil)
replaced := &replaceAwareExecutor{id: "codex"}
current := &replaceAwareExecutor{id: "codex"}
manager.RegisterExecutor(replaced)
manager.RegisterExecutor(current)
closed := replaced.ClosedSessionIDs()
if len(closed) != 1 {
t.Fatalf("expected replaced executor close calls = 1, got %d", len(closed))
}
if closed[0] != CloseAllExecutionSessionsID {
t.Fatalf("expected close marker %q, got %q", CloseAllExecutionSessionsID, closed[0])
}
if len(current.ClosedSessionIDs()) != 0 {
t.Fatalf("expected current executor to stay open")
}
}
func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
t.Parallel()
manager := NewManager(nil, nil, nil)
current := &replaceAwareExecutor{id: "codex"}
manager.RegisterExecutor(current)
resolved, okResolved := manager.Executor("CODEX")
if !okResolved {
t.Fatal("expected registered executor to be found")
}
if resolved != current {
t.Fatal("expected resolved executor to match registered executor")
}
_, okMissing := manager.Executor("unknown")
if okMissing {
t.Fatal("expected unknown provider lookup to fail")
}
}

View File

@@ -134,6 +134,62 @@ func canonicalModelKey(model string) string {
return modelName
}
func authWebsocketsEnabled(auth *Auth) bool {
if auth == nil {
return false
}
if len(auth.Attributes) > 0 {
if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" {
parsed, errParse := strconv.ParseBool(raw)
if errParse == nil {
return parsed
}
}
}
if len(auth.Metadata) == 0 {
return false
}
raw, ok := auth.Metadata["websockets"]
if !ok || raw == nil {
return false
}
switch v := raw.(type) {
case bool:
return v
case string:
parsed, errParse := strconv.ParseBool(strings.TrimSpace(v))
if errParse == nil {
return parsed
}
default:
}
return false
}
func preferCodexWebsocketAuths(ctx context.Context, provider string, available []*Auth) []*Auth {
if len(available) == 0 {
return available
}
if !cliproxyexecutor.DownstreamWebsocket(ctx) {
return available
}
if !strings.EqualFold(strings.TrimSpace(provider), "codex") {
return available
}
wsEnabled := make([]*Auth, 0, len(available))
for i := 0; i < len(available); i++ {
candidate := available[i]
if authWebsocketsEnabled(candidate) {
wsEnabled = append(wsEnabled, candidate)
}
}
if len(wsEnabled) > 0 {
return wsEnabled
}
return available
}
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
available = make(map[int][]*Auth)
for i := 0; i < len(auths); i++ {
@@ -193,13 +249,13 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
// Pick selects the next available auth for the provider in a round-robin manner.
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = ctx
_ = opts
now := time.Now()
available, err := getAvailableAuths(auths, provider, model, now)
if err != nil {
return nil, err
}
available = preferCodexWebsocketAuths(ctx, provider, available)
key := provider + ":" + canonicalModelKey(model)
s.mu.Lock()
if s.cursors == nil {
@@ -226,13 +282,13 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
// Pick selects the first available auth for the provider in a deterministic manner.
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
_ = ctx
_ = opts
now := time.Now()
available, err := getAvailableAuths(auths, provider, model, now)
if err != nil {
return nil, err
}
available = preferCodexWebsocketAuths(ctx, provider, available)
return available[0], nil
}

View File

@@ -0,0 +1,23 @@
package executor
import "context"
type downstreamWebsocketContextKey struct{}
// WithDownstreamWebsocket marks the current request as coming from a downstream websocket connection.
func WithDownstreamWebsocket(ctx context.Context) context.Context {
if ctx == nil {
ctx = context.Background()
}
return context.WithValue(ctx, downstreamWebsocketContextKey{}, true)
}
// DownstreamWebsocket reports whether the current request originates from a downstream websocket connection.
func DownstreamWebsocket(ctx context.Context) bool {
if ctx == nil {
return false
}
raw := ctx.Value(downstreamWebsocketContextKey{})
enabled, ok := raw.(bool)
return ok && enabled
}

View File

@@ -10,6 +10,17 @@ import (
// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
const RequestedModelMetadataKey = "requested_model"
const (
// PinnedAuthMetadataKey locks execution to a specific auth ID.
PinnedAuthMetadataKey = "pinned_auth_id"
// SelectedAuthMetadataKey stores the auth ID selected by the scheduler.
SelectedAuthMetadataKey = "selected_auth_id"
// SelectedAuthCallbackMetadataKey carries an optional callback invoked with the selected auth ID.
SelectedAuthCallbackMetadataKey = "selected_auth_callback"
// ExecutionSessionMetadataKey identifies a long-lived downstream execution session.
ExecutionSessionMetadataKey = "execution_session_id"
)
// Request encapsulates the translated payload that will be sent to a provider executor.
type Request struct {
// Model is the upstream model identifier after translation.

View File

@@ -325,6 +325,9 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
if _, err := s.coreManager.Update(ctx, existing); err != nil {
log.Errorf("failed to disable auth %s: %v", id, err)
}
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
s.ensureExecutorsForAuth(existing)
}
}
}
@@ -357,7 +360,24 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName
}
func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
if s == nil || a == nil {
s.ensureExecutorsForAuthWithMode(a, false)
}
func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace bool) {
if s == nil || s.coreManager == nil || a == nil {
return
}
if strings.EqualFold(strings.TrimSpace(a.Provider), "codex") {
if !forceReplace {
existingExecutor, hasExecutor := s.coreManager.Executor("codex")
if hasExecutor {
_, isCodexAutoExecutor := existingExecutor.(*executor.CodexAutoExecutor)
if isCodexAutoExecutor {
return
}
}
}
s.coreManager.RegisterExecutor(executor.NewCodexAutoExecutor(s.cfg))
return
}
// Skip disabled auth entries when (re)binding executors.
@@ -392,8 +412,6 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg))
case "claude":
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
case "codex":
s.coreManager.RegisterExecutor(executor.NewCodexExecutor(s.cfg))
case "qwen":
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
case "iflow":
@@ -415,8 +433,15 @@ func (s *Service) rebindExecutors() {
return
}
auths := s.coreManager.List()
reboundCodex := false
for _, auth := range auths {
s.ensureExecutorsForAuth(auth)
if auth != nil && strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
if reboundCodex {
continue
}
reboundCodex = true
}
s.ensureExecutorsForAuthWithMode(auth, true)
}
}

View File

@@ -0,0 +1,64 @@
package cliproxy
import (
"testing"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
)
func TestEnsureExecutorsForAuth_CodexDoesNotReplaceInNormalMode(t *testing.T) {
service := &Service{
cfg: &config.Config{},
coreManager: coreauth.NewManager(nil, nil, nil),
}
auth := &coreauth.Auth{
ID: "codex-auth-1",
Provider: "codex",
Status: coreauth.StatusActive,
}
service.ensureExecutorsForAuth(auth)
firstExecutor, okFirst := service.coreManager.Executor("codex")
if !okFirst || firstExecutor == nil {
t.Fatal("expected codex executor after first bind")
}
service.ensureExecutorsForAuth(auth)
secondExecutor, okSecond := service.coreManager.Executor("codex")
if !okSecond || secondExecutor == nil {
t.Fatal("expected codex executor after second bind")
}
if firstExecutor != secondExecutor {
t.Fatal("expected codex executor to stay unchanged in normal mode")
}
}
func TestEnsureExecutorsForAuthWithMode_CodexForceReplace(t *testing.T) {
service := &Service{
cfg: &config.Config{},
coreManager: coreauth.NewManager(nil, nil, nil),
}
auth := &coreauth.Auth{
ID: "codex-auth-2",
Provider: "codex",
Status: coreauth.StatusActive,
}
service.ensureExecutorsForAuth(auth)
firstExecutor, okFirst := service.coreManager.Executor("codex")
if !okFirst || firstExecutor == nil {
t.Fatal("expected codex executor after first bind")
}
service.ensureExecutorsForAuthWithMode(auth, true)
secondExecutor, okSecond := service.coreManager.Executor("codex")
if !okSecond || secondExecutor == nil {
t.Fatal("expected codex executor after forced rebind")
}
if firstExecutor == secondExecutor {
t.Fatal("expected codex executor replacement in force mode")
}
}