From bb86a0c0c44d1ed019c18320d2ee626843d6262f Mon Sep 17 00:00:00 2001 From: Luis Pater Date: Thu, 19 Feb 2026 01:57:02 +0800 Subject: [PATCH] feat(logging, executor): add request logging tests and WebSocket-based Codex executor - Introduced unit tests for request logging middleware to enhance coverage. - Added WebSocket-based Codex executor to support Responses API upgrade. - Updated middleware logic to selectively capture request bodies for memory efficiency. - Enhanced Codex configuration handling with new WebSocket attributes. --- internal/api/middleware/request_logging.go | 55 +- .../api/middleware/request_logging_test.go | 138 ++ internal/api/middleware/response_writer.go | 36 +- .../api/middleware/response_writer_test.go | 43 + internal/api/server.go | 1 + internal/config/config.go | 3 + internal/registry/model_definitions.go | 4 + .../registry/model_definitions_static_data.go | 26 + .../executor/codex_websockets_executor.go | 1407 +++++++++++++++++ internal/runtime/executor/qwen_executor.go | 18 +- internal/watcher/diff/config_diff.go | 3 + internal/watcher/synthesizer/config.go | 3 + internal/watcher/synthesizer/config_test.go | 12 +- sdk/api/handlers/handlers.go | 93 +- .../handlers_stream_bootstrap_test.go | 201 +++ .../openai/openai_responses_websocket.go | 662 ++++++++ .../openai/openai_responses_websocket_test.go | 249 +++ sdk/cliproxy/auth/conductor.go | 121 +- .../auth/conductor_executor_replace_test.go | 100 ++ sdk/cliproxy/auth/selector.go | 60 +- sdk/cliproxy/executor/context.go | 23 + sdk/cliproxy/executor/types.go | 11 + sdk/cliproxy/service.go | 33 +- .../service_codex_executor_binding_test.go | 64 + 24 files changed, 3332 insertions(+), 34 deletions(-) create mode 100644 internal/api/middleware/request_logging_test.go create mode 100644 internal/api/middleware/response_writer_test.go create mode 100644 internal/runtime/executor/codex_websockets_executor.go create mode 100644 sdk/api/handlers/openai/openai_responses_websocket.go create mode 
100644 sdk/api/handlers/openai/openai_responses_websocket_test.go create mode 100644 sdk/cliproxy/auth/conductor_executor_replace_test.go create mode 100644 sdk/cliproxy/executor/context.go create mode 100644 sdk/cliproxy/service_codex_executor_binding_test.go diff --git a/internal/api/middleware/request_logging.go b/internal/api/middleware/request_logging.go index 2c9fdbdd..b57dd8aa 100644 --- a/internal/api/middleware/request_logging.go +++ b/internal/api/middleware/request_logging.go @@ -15,10 +15,12 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/util" ) +const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB + // RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses. // It captures detailed information about the request and response, including headers and body, -// and uses the provided RequestLogger to record this data. When logging is disabled in the -// logger, it still captures data so that upstream errors can be persisted. +// and uses the provided RequestLogger to record this data. When full request logging is disabled, +// body capture is limited to small known-size payloads to avoid large per-request memory spikes. 
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc { return func(c *gin.Context) { if logger == nil { @@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc { return } - if c.Request.Method == http.MethodGet { + if shouldSkipMethodForRequestLogging(c.Request) { c.Next() return } @@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc { return } + loggerEnabled := logger.IsEnabled() + // Capture request information - requestInfo, err := captureRequestInfo(c) + requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request)) if err != nil { // Log error but continue processing // In a real implementation, you might want to use a proper logger here @@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc { // Create response writer wrapper wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo) - if !logger.IsEnabled() { + if !loggerEnabled { wrapper.logOnErrorOnly = true } c.Writer = wrapper @@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc { } } +func shouldSkipMethodForRequestLogging(req *http.Request) bool { + if req == nil { + return true + } + if req.Method != http.MethodGet { + return false + } + return !isResponsesWebsocketUpgrade(req) +} + +func isResponsesWebsocketUpgrade(req *http.Request) bool { + if req == nil || req.URL == nil { + return false + } + if req.URL.Path != "/v1/responses" { + return false + } + return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket") +} + +func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool { + if loggerEnabled { + return true + } + if req == nil || req.Body == nil { + return false + } + contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type"))) + if strings.HasPrefix(contentType, "multipart/form-data") { + return false + } + if 
req.ContentLength <= 0 { + return false + } + return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes +} + // captureRequestInfo extracts relevant information from the incoming HTTP request. // It captures the URL, method, headers, and body. The request body is read and then // restored so that it can be processed by subsequent handlers. -func captureRequestInfo(c *gin.Context) (*RequestInfo, error) { +func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) { // Capture URL with sensitive query parameters masked maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery) url := c.Request.URL.Path @@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) { // Capture request body var body []byte - if c.Request.Body != nil { + if captureBody && c.Request.Body != nil { // Read the body bodyBytes, err := io.ReadAll(c.Request.Body) if err != nil { diff --git a/internal/api/middleware/request_logging_test.go b/internal/api/middleware/request_logging_test.go new file mode 100644 index 00000000..c4354678 --- /dev/null +++ b/internal/api/middleware/request_logging_test.go @@ -0,0 +1,138 @@ +package middleware + +import ( + "io" + "net/http" + "net/url" + "strings" + "testing" +) + +func TestShouldSkipMethodForRequestLogging(t *testing.T) { + tests := []struct { + name string + req *http.Request + skip bool + }{ + { + name: "nil request", + req: nil, + skip: true, + }, + { + name: "post request should not skip", + req: &http.Request{ + Method: http.MethodPost, + URL: &url.URL{Path: "/v1/responses"}, + }, + skip: false, + }, + { + name: "plain get should skip", + req: &http.Request{ + Method: http.MethodGet, + URL: &url.URL{Path: "/v1/models"}, + Header: http.Header{}, + }, + skip: true, + }, + { + name: "responses websocket upgrade should not skip", + req: &http.Request{ + Method: http.MethodGet, + URL: &url.URL{Path: "/v1/responses"}, + Header: http.Header{"Upgrade": []string{"websocket"}}, + }, + skip: false, + }, 
+ { + name: "responses get without upgrade should skip", + req: &http.Request{ + Method: http.MethodGet, + URL: &url.URL{Path: "/v1/responses"}, + Header: http.Header{}, + }, + skip: true, + }, + } + + for i := range tests { + got := shouldSkipMethodForRequestLogging(tests[i].req) + if got != tests[i].skip { + t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip) + } + } +} + +func TestShouldCaptureRequestBody(t *testing.T) { + tests := []struct { + name string + loggerEnabled bool + req *http.Request + want bool + }{ + { + name: "logger enabled always captures", + loggerEnabled: true, + req: &http.Request{ + Body: io.NopCloser(strings.NewReader("{}")), + ContentLength: -1, + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, + want: true, + }, + { + name: "nil request", + loggerEnabled: false, + req: nil, + want: false, + }, + { + name: "small known size json in error-only mode", + loggerEnabled: false, + req: &http.Request{ + Body: io.NopCloser(strings.NewReader("{}")), + ContentLength: 2, + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, + want: true, + }, + { + name: "large known size skipped in error-only mode", + loggerEnabled: false, + req: &http.Request{ + Body: io.NopCloser(strings.NewReader("x")), + ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1, + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, + want: false, + }, + { + name: "unknown size skipped in error-only mode", + loggerEnabled: false, + req: &http.Request{ + Body: io.NopCloser(strings.NewReader("x")), + ContentLength: -1, + Header: http.Header{"Content-Type": []string{"application/json"}}, + }, + want: false, + }, + { + name: "multipart skipped in error-only mode", + loggerEnabled: false, + req: &http.Request{ + Body: io.NopCloser(strings.NewReader("x")), + ContentLength: 1, + Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}}, + }, + want: false, + }, + } + + for i := 
range tests { + got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req) + if got != tests[i].want { + t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want) + } + } +} diff --git a/internal/api/middleware/response_writer.go b/internal/api/middleware/response_writer.go index 50fa1c69..363278ab 100644 --- a/internal/api/middleware/response_writer.go +++ b/internal/api/middleware/response_writer.go @@ -14,6 +14,8 @@ import ( "github.com/router-for-me/CLIProxyAPI/v6/internal/logging" ) +const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE" + // RequestInfo holds essential details of an incoming HTTP request for logging purposes. type RequestInfo struct { URL string // URL is the request URL. @@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool { // Only fall back to request payload hints when Content-Type is not set yet. if w.requestInfo != nil && len(w.requestInfo.Body) > 0 { - bodyStr := string(w.requestInfo.Body) - return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`) + return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) || + bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`)) } return false @@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error { return nil } - return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog) + return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog) } func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string { @@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time return time.Time{} } -func (w *ResponseWriterWrapper) 
logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error { +func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte { + if c != nil { + if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist { + switch value := bodyOverride.(type) { + case []byte: + if len(value) > 0 { + return bytes.Clone(value) + } + case string: + if strings.TrimSpace(value) != "" { + return []byte(value) + } + } + } + } + if w.requestInfo != nil && len(w.requestInfo.Body) > 0 { + return w.requestInfo.Body + } + return nil +} + +func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error { if w.requestInfo == nil { return nil } - var requestBody []byte - if len(w.requestInfo.Body) > 0 { - requestBody = w.requestInfo.Body - } - if loggerWithOptions, ok := w.logger.(interface { LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error }); ok { diff --git a/internal/api/middleware/response_writer_test.go b/internal/api/middleware/response_writer_test.go new file mode 100644 index 00000000..fa4708e4 --- /dev/null +++ b/internal/api/middleware/response_writer_test.go @@ -0,0 +1,43 @@ +package middleware + +import ( + "net/http/httptest" + "testing" + + "github.com/gin-gonic/gin" +) + +func TestExtractRequestBodyPrefersOverride(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + + wrapper := &ResponseWriterWrapper{ + requestInfo: &RequestInfo{Body: []byte("original-body")}, + } + + body := wrapper.extractRequestBody(c) + 
if string(body) != "original-body" { + t.Fatalf("request body = %q, want %q", string(body), "original-body") + } + + c.Set(requestBodyOverrideContextKey, []byte("override-body")) + body = wrapper.extractRequestBody(c) + if string(body) != "override-body" { + t.Fatalf("request body = %q, want %q", string(body), "override-body") + } +} + +func TestExtractRequestBodySupportsStringOverride(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + + wrapper := &ResponseWriterWrapper{} + c.Set(requestBodyOverrideContextKey, "override-as-string") + + body := wrapper.extractRequestBody(c) + if string(body) != "override-as-string" { + t.Fatalf("request body = %q, want %q", string(body), "override-as-string") + } +} diff --git a/internal/api/server.go b/internal/api/server.go index 4cbcbba2..932bb4b0 100644 --- a/internal/api/server.go +++ b/internal/api/server.go @@ -323,6 +323,7 @@ func (s *Server) setupRoutes() { v1.POST("/completions", openaiHandlers.Completions) v1.POST("/messages", claudeCodeHandlers.ClaudeMessages) v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens) + v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket) v1.POST("/responses", openaiResponsesHandlers.Responses) v1.POST("/responses/compact", openaiResponsesHandlers.Compact) } diff --git a/internal/config/config.go b/internal/config/config.go index c78b2582..6a1a24c1 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -355,6 +355,9 @@ type CodexKey struct { // If empty, the default Codex API URL will be used. BaseURL string `yaml:"base-url" json:"base-url"` + // Websockets enables the Responses API websocket transport for this credential. + Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"` + // ProxyURL overrides the global proxy setting for this API key if provided. 
ProxyURL string `yaml:"proxy-url" json:"proxy-url"` diff --git a/internal/registry/model_definitions.go b/internal/registry/model_definitions.go index 585bdf8c..c1796979 100644 --- a/internal/registry/model_definitions.go +++ b/internal/registry/model_definitions.go @@ -19,6 +19,7 @@ import ( // - codex // - qwen // - iflow +// - kimi // - antigravity (returns static overrides only) func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { key := strings.ToLower(strings.TrimSpace(channel)) @@ -39,6 +40,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo { return GetQwenModels() case "iflow": return GetIFlowModels() + case "kimi": + return GetKimiModels() case "antigravity": cfg := GetAntigravityModelConfig() if len(cfg) == 0 { @@ -83,6 +86,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo { GetOpenAIModels(), GetQwenModels(), GetIFlowModels(), + GetKimiModels(), } for _, models := range allModels { for _, m := range models { diff --git a/internal/registry/model_definitions_static_data.go b/internal/registry/model_definitions_static_data.go index 39b2aa0c..144c4bce 100644 --- a/internal/registry/model_definitions_static_data.go +++ b/internal/registry/model_definitions_static_data.go @@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo { MaxCompletionTokens: 64000, Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false}, }, + { + ID: "claude-sonnet-4-6", + Object: "model", + Created: 1771372800, // 2026-02-17 + OwnedBy: "anthropic", + Type: "claude", + DisplayName: "Claude 4.6 Sonnet", + ContextLength: 200000, + MaxCompletionTokens: 64000, + Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false}, + }, { ID: "claude-opus-4-6", Object: "model", @@ -788,6 +799,19 @@ func GetQwenModels() []*ModelInfo { MaxCompletionTokens: 2048, SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"}, }, + { + ID: "coder-model", + Object: 
"model", + Created: 1771171200, + OwnedBy: "qwen", + Type: "qwen", + Version: "3.5", + DisplayName: "Qwen 3.5 Plus", + Description: "efficient hybrid model with leading coding performance", + ContextLength: 1048576, + MaxCompletionTokens: 65536, + SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"}, + }, { ID: "vision-model", Object: "model", @@ -884,6 +908,8 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig { "claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000}, "claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000}, "claude-sonnet-4-5": {MaxCompletionTokens: 64000}, + "claude-sonnet-4-6": {MaxCompletionTokens: 64000}, + "claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000}, "gpt-oss-120b-medium": {}, "tab_flash_lite_preview": {}, } diff --git a/internal/runtime/executor/codex_websockets_executor.go b/internal/runtime/executor/codex_websockets_executor.go new file mode 100644 index 00000000..38ffad77 --- /dev/null +++ b/internal/runtime/executor/codex_websockets_executor.go @@ -0,0 +1,1407 @@ +// Package executor provides runtime execution capabilities for various AI service providers. +// This file implements a Codex executor that uses the Responses API WebSocket transport. 
+package executor + +import ( + "bytes" + "context" + "fmt" + "io" + "net" + "net/http" + "net/url" + "strconv" + "strings" + "sync" + "time" + + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/router-for-me/CLIProxyAPI/v6/internal/config" + "github.com/router-for-me/CLIProxyAPI/v6/internal/misc" + "github.com/router-for-me/CLIProxyAPI/v6/internal/thinking" + "github.com/router-for-me/CLIProxyAPI/v6/internal/util" + cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" + "golang.org/x/net/proxy" +) + +const ( + codexResponsesWebsocketBetaHeaderValue = "responses_websockets=2026-02-04" + codexResponsesWebsocketIdleTimeout = 5 * time.Minute + codexResponsesWebsocketHandshakeTO = 30 * time.Second +) + +// CodexWebsocketsExecutor executes Codex Responses requests using a WebSocket transport. +// +// It preserves the existing CodexExecutor HTTP implementation as a fallback for endpoints +// not available over WebSocket (e.g. /responses/compact) and for websocket upgrade failures. +type CodexWebsocketsExecutor struct { + *CodexExecutor + + sessMu sync.Mutex + sessions map[string]*codexWebsocketSession +} + +type codexWebsocketSession struct { + sessionID string + + reqMu sync.Mutex + + connMu sync.Mutex + conn *websocket.Conn + wsURL string + authID string + + // connCreateSent tracks whether a `response.create` message has been successfully sent + // on the current websocket connection. The upstream expects the first message on each + // connection to be `response.create`. 
+ connCreateSent bool + + writeMu sync.Mutex + + activeMu sync.Mutex + activeCh chan codexWebsocketRead + activeDone <-chan struct{} + activeCancel context.CancelFunc + + readerConn *websocket.Conn +} + +func NewCodexWebsocketsExecutor(cfg *config.Config) *CodexWebsocketsExecutor { + return &CodexWebsocketsExecutor{ + CodexExecutor: NewCodexExecutor(cfg), + sessions: make(map[string]*codexWebsocketSession), + } +} + +type codexWebsocketRead struct { + conn *websocket.Conn + msgType int + payload []byte + err error +} + +func (s *codexWebsocketSession) setActive(ch chan codexWebsocketRead) { + if s == nil { + return + } + s.activeMu.Lock() + if s.activeCancel != nil { + s.activeCancel() + s.activeCancel = nil + s.activeDone = nil + } + s.activeCh = ch + if ch != nil { + activeCtx, activeCancel := context.WithCancel(context.Background()) + s.activeDone = activeCtx.Done() + s.activeCancel = activeCancel + } + s.activeMu.Unlock() +} + +func (s *codexWebsocketSession) clearActive(ch chan codexWebsocketRead) { + if s == nil { + return + } + s.activeMu.Lock() + if s.activeCh == ch { + s.activeCh = nil + if s.activeCancel != nil { + s.activeCancel() + } + s.activeCancel = nil + s.activeDone = nil + } + s.activeMu.Unlock() +} + +func (s *codexWebsocketSession) writeMessage(conn *websocket.Conn, msgType int, payload []byte) error { + if s == nil { + return fmt.Errorf("codex websockets executor: session is nil") + } + if conn == nil { + return fmt.Errorf("codex websockets executor: websocket conn is nil") + } + s.writeMu.Lock() + defer s.writeMu.Unlock() + return conn.WriteMessage(msgType, payload) +} + +func (s *codexWebsocketSession) configureConn(conn *websocket.Conn) { + if s == nil || conn == nil { + return + } + conn.SetPingHandler(func(appData string) error { + s.writeMu.Lock() + defer s.writeMu.Unlock() + // Reply pongs from the same write lock to avoid concurrent writes. 
+ return conn.WriteControl(websocket.PongMessage, []byte(appData), time.Now().Add(10*time.Second)) + }) +} + +func (e *CodexWebsocketsExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) { + if ctx == nil { + ctx = context.Background() + } + if opts.Alt == "responses/compact" { + return e.CodexExecutor.executeCompact(ctx, auth, req, opts) + } + + baseModel := thinking.ParseSuffix(req.Model).ModelName + apiKey, baseURL := codexCreds(auth) + if baseURL == "" { + baseURL = "https://chatgpt.com/backend-api/codex" + } + + reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) + defer reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("codex") + originalPayloadSource := req.Payload + if len(opts.OriginalRequest) > 0 { + originalPayloadSource = opts.OriginalRequest + } + originalPayload := originalPayloadSource + originalTranslated := sdktranslator.TranslateRequest(from, to, baseModel, originalPayload, false) + body := sdktranslator.TranslateRequest(from, to, baseModel, req.Payload, false) + + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) + if err != nil { + return resp, err + } + + requestedModel := payloadRequestedModel(opts, req.Model) + body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, originalTranslated, requestedModel) + body, _ = sjson.SetBytes(body, "model", baseModel) + body, _ = sjson.SetBytes(body, "stream", true) + body, _ = sjson.DeleteBytes(body, "previous_response_id") + body, _ = sjson.DeleteBytes(body, "prompt_cache_retention") + body, _ = sjson.DeleteBytes(body, "safety_identifier") + if !gjson.GetBytes(body, "instructions").Exists() { + body, _ = sjson.SetBytes(body, "instructions", "") + } + + httpURL := strings.TrimSuffix(baseURL, "/") + "/responses" + wsURL, err := buildCodexResponsesWebsocketURL(httpURL) + 
if err != nil { + return resp, err + } + + body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body) + wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + + executionSessionID := executionSessionIDFromOptions(opts) + var sess *codexWebsocketSession + if executionSessionID != "" { + sess = e.getOrCreateSession(executionSessionID) + sess.reqMu.Lock() + defer sess.reqMu.Unlock() + } + + allowAppend := true + if sess != nil { + sess.connMu.Lock() + allowAppend = sess.connCreateSent + sess.connMu.Unlock() + } + wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend) + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: wsURL, + Method: "WEBSOCKET", + Headers: wsHeaders.Clone(), + Body: wsReqBody, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) + if respHS != nil { + recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone()) + } + if errDial != nil { + bodyErr := websocketHandshakeBody(respHS) + if len(bodyErr) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bodyErr) + } + if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired { + return e.CodexExecutor.Execute(ctx, auth, req, opts) + } + if respHS != nil && respHS.StatusCode > 0 { + return resp, statusErr{code: respHS.StatusCode, msg: string(bodyErr)} + } + recordAPIResponseError(ctx, e.cfg, errDial) + return resp, errDial + } + closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error") + if sess == nil { + logCodexWebsocketConnected(executionSessionID, authID, wsURL) + defer func() { + reason := "completed" + if err != nil { + reason = "error" + } + 
logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, reason, err) + if errClose := conn.Close(); errClose != nil { + log.Errorf("codex websockets executor: close websocket error: %v", errClose) + } + }() + } + + var readCh chan codexWebsocketRead + if sess != nil { + readCh = make(chan codexWebsocketRead, 4096) + sess.setActive(readCh) + defer sess.clearActive(readCh) + } + + if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil { + if sess != nil { + e.invalidateUpstreamConn(sess, conn, "send_error", errSend) + + // Retry once with a fresh websocket connection. This is mainly to handle + // upstream closing the socket between sequential requests within the same + // execution session. + connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) + if errDialRetry == nil && connRetry != nil { + sess.connMu.Lock() + allowAppend = sess.connCreateSent + sess.connMu.Unlock() + wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend) + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: wsURL, + Method: "WEBSOCKET", + Headers: wsHeaders.Clone(), + Body: wsReqBodyRetry, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry == nil { + conn = connRetry + wsReqBody = wsReqBodyRetry + } else { + e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry) + recordAPIResponseError(ctx, e.cfg, errSendRetry) + return resp, errSendRetry + } + } else { + recordAPIResponseError(ctx, e.cfg, errDialRetry) + return resp, errDialRetry + } + } else { + recordAPIResponseError(ctx, e.cfg, errSend) + return resp, errSend + } + } + markCodexWebsocketCreateSent(sess, conn, wsReqBody) + + for { + if ctx != nil && ctx.Err() != nil { + return resp, ctx.Err() + } + msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, 
readCh) + if errRead != nil { + recordAPIResponseError(ctx, e.cfg, errRead) + return resp, errRead + } + if msgType != websocket.TextMessage { + if msgType == websocket.BinaryMessage { + err = fmt.Errorf("codex websockets executor: unexpected binary message") + if sess != nil { + e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err) + } + recordAPIResponseError(ctx, e.cfg, err) + return resp, err + } + continue + } + + payload = bytes.TrimSpace(payload) + if len(payload) == 0 { + continue + } + appendAPIResponseChunk(ctx, e.cfg, payload) + + if wsErr, ok := parseCodexWebsocketError(payload); ok { + if sess != nil { + e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr) + } + recordAPIResponseError(ctx, e.cfg, wsErr) + return resp, wsErr + } + + payload = normalizeCodexWebsocketCompletion(payload) + eventType := gjson.GetBytes(payload, "type").String() + if eventType == "response.completed" { + if detail, ok := parseCodexUsage(payload); ok { + reporter.publish(ctx, detail) + } + var param any + out := sdktranslator.TranslateNonStream(ctx, to, from, req.Model, originalPayload, body, payload, &param) + resp = cliproxyexecutor.Response{Payload: []byte(out)} + return resp, nil + } + } +} + +func (e *CodexWebsocketsExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) { + log.Debugf("Executing Codex Websockets stream request with auth ID: %s, model: %s", auth.ID, req.Model) + if ctx == nil { + ctx = context.Background() + } + if opts.Alt == "responses/compact" { + return nil, statusErr{code: http.StatusBadRequest, msg: "streaming not supported for /responses/compact"} + } + + baseModel := thinking.ParseSuffix(req.Model).ModelName + apiKey, baseURL := codexCreds(auth) + if baseURL == "" { + baseURL = "https://chatgpt.com/backend-api/codex" + } + + reporter := newUsageReporter(ctx, e.Identifier(), baseModel, auth) + defer 
reporter.trackFailure(ctx, &err) + + from := opts.SourceFormat + to := sdktranslator.FromString("codex") + body := req.Payload + + body, err = thinking.ApplyThinking(body, req.Model, from.String(), to.String(), e.Identifier()) + if err != nil { + return nil, err + } + + requestedModel := payloadRequestedModel(opts, req.Model) + body = applyPayloadConfigWithRoot(e.cfg, baseModel, to.String(), "", body, body, requestedModel) + + httpURL := strings.TrimSuffix(baseURL, "/") + "/responses" + wsURL, err := buildCodexResponsesWebsocketURL(httpURL) + if err != nil { + return nil, err + } + + body, wsHeaders := applyCodexPromptCacheHeaders(from, req, body) + wsHeaders = applyCodexWebsocketHeaders(ctx, wsHeaders, auth, apiKey) + + var authID, authLabel, authType, authValue string + if auth != nil { + authID = auth.ID + authLabel = auth.Label + authType, authValue = auth.AccountInfo() + } + + executionSessionID := executionSessionIDFromOptions(opts) + var sess *codexWebsocketSession + if executionSessionID != "" { + sess = e.getOrCreateSession(executionSessionID) + sess.reqMu.Lock() + } + + allowAppend := true + if sess != nil { + sess.connMu.Lock() + allowAppend = sess.connCreateSent + sess.connMu.Unlock() + } + wsReqBody := buildCodexWebsocketRequestBody(body, allowAppend) + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: wsURL, + Method: "WEBSOCKET", + Headers: wsHeaders.Clone(), + Body: wsReqBody, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + + conn, respHS, errDial := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) + if respHS != nil { + recordAPIResponseMetadata(ctx, e.cfg, respHS.StatusCode, respHS.Header.Clone()) + } + if errDial != nil { + bodyErr := websocketHandshakeBody(respHS) + if len(bodyErr) > 0 { + appendAPIResponseChunk(ctx, e.cfg, bodyErr) + } + if respHS != nil && respHS.StatusCode == http.StatusUpgradeRequired { + return e.CodexExecutor.ExecuteStream(ctx, 
auth, req, opts) + } + if respHS != nil && respHS.StatusCode > 0 { + return nil, statusErr{code: respHS.StatusCode, msg: string(bodyErr)} + } + recordAPIResponseError(ctx, e.cfg, errDial) + if sess != nil { + sess.reqMu.Unlock() + } + return nil, errDial + } + closeHTTPResponseBody(respHS, "codex websockets executor: close handshake response body error") + + if sess == nil { + logCodexWebsocketConnected(executionSessionID, authID, wsURL) + } + + var readCh chan codexWebsocketRead + if sess != nil { + readCh = make(chan codexWebsocketRead, 4096) + sess.setActive(readCh) + } + + if errSend := writeCodexWebsocketMessage(sess, conn, wsReqBody); errSend != nil { + recordAPIResponseError(ctx, e.cfg, errSend) + if sess != nil { + e.invalidateUpstreamConn(sess, conn, "send_error", errSend) + + // Retry once with a new websocket connection for the same execution session. + connRetry, _, errDialRetry := e.ensureUpstreamConn(ctx, auth, sess, authID, wsURL, wsHeaders) + if errDialRetry != nil || connRetry == nil { + recordAPIResponseError(ctx, e.cfg, errDialRetry) + sess.clearActive(readCh) + sess.reqMu.Unlock() + return nil, errDialRetry + } + sess.connMu.Lock() + allowAppend = sess.connCreateSent + sess.connMu.Unlock() + wsReqBodyRetry := buildCodexWebsocketRequestBody(body, allowAppend) + recordAPIRequest(ctx, e.cfg, upstreamRequestLog{ + URL: wsURL, + Method: "WEBSOCKET", + Headers: wsHeaders.Clone(), + Body: wsReqBodyRetry, + Provider: e.Identifier(), + AuthID: authID, + AuthLabel: authLabel, + AuthType: authType, + AuthValue: authValue, + }) + if errSendRetry := writeCodexWebsocketMessage(sess, connRetry, wsReqBodyRetry); errSendRetry != nil { + recordAPIResponseError(ctx, e.cfg, errSendRetry) + e.invalidateUpstreamConn(sess, connRetry, "send_error", errSendRetry) + sess.clearActive(readCh) + sess.reqMu.Unlock() + return nil, errSendRetry + } + conn = connRetry + wsReqBody = wsReqBodyRetry + } else { + logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, 
"send_error", errSend) + if errClose := conn.Close(); errClose != nil { + log.Errorf("codex websockets executor: close websocket error: %v", errClose) + } + return nil, errSend + } + } + markCodexWebsocketCreateSent(sess, conn, wsReqBody) + + out := make(chan cliproxyexecutor.StreamChunk) + stream = out + go func() { + terminateReason := "completed" + var terminateErr error + + defer close(out) + defer func() { + if sess != nil { + sess.clearActive(readCh) + sess.reqMu.Unlock() + return + } + logCodexWebsocketDisconnected(executionSessionID, authID, wsURL, terminateReason, terminateErr) + if errClose := conn.Close(); errClose != nil { + log.Errorf("codex websockets executor: close websocket error: %v", errClose) + } + }() + + send := func(chunk cliproxyexecutor.StreamChunk) bool { + if ctx == nil { + out <- chunk + return true + } + select { + case out <- chunk: + return true + case <-ctx.Done(): + return false + } + } + + var param any + for { + if ctx != nil && ctx.Err() != nil { + terminateReason = "context_done" + terminateErr = ctx.Err() + _ = send(cliproxyexecutor.StreamChunk{Err: ctx.Err()}) + return + } + msgType, payload, errRead := readCodexWebsocketMessage(ctx, sess, conn, readCh) + if errRead != nil { + if sess != nil && ctx != nil && ctx.Err() != nil { + terminateReason = "context_done" + terminateErr = ctx.Err() + _ = send(cliproxyexecutor.StreamChunk{Err: ctx.Err()}) + return + } + terminateReason = "read_error" + terminateErr = errRead + recordAPIResponseError(ctx, e.cfg, errRead) + reporter.publishFailure(ctx) + _ = send(cliproxyexecutor.StreamChunk{Err: errRead}) + return + } + if msgType != websocket.TextMessage { + if msgType == websocket.BinaryMessage { + err = fmt.Errorf("codex websockets executor: unexpected binary message") + terminateReason = "unexpected_binary" + terminateErr = err + recordAPIResponseError(ctx, e.cfg, err) + reporter.publishFailure(ctx) + if sess != nil { + e.invalidateUpstreamConn(sess, conn, "unexpected_binary", err) + } 
+ _ = send(cliproxyexecutor.StreamChunk{Err: err}) + return + } + continue + } + + payload = bytes.TrimSpace(payload) + if len(payload) == 0 { + continue + } + appendAPIResponseChunk(ctx, e.cfg, payload) + + if wsErr, ok := parseCodexWebsocketError(payload); ok { + terminateReason = "upstream_error" + terminateErr = wsErr + recordAPIResponseError(ctx, e.cfg, wsErr) + reporter.publishFailure(ctx) + if sess != nil { + e.invalidateUpstreamConn(sess, conn, "upstream_error", wsErr) + } + _ = send(cliproxyexecutor.StreamChunk{Err: wsErr}) + return + } + + payload = normalizeCodexWebsocketCompletion(payload) + eventType := gjson.GetBytes(payload, "type").String() + if eventType == "response.completed" || eventType == "response.done" { + if detail, ok := parseCodexUsage(payload); ok { + reporter.publish(ctx, detail) + } + } + + line := encodeCodexWebsocketAsSSE(payload) + chunks := sdktranslator.TranslateStream(ctx, to, from, req.Model, body, body, line, ¶m) + for i := range chunks { + if !send(cliproxyexecutor.StreamChunk{Payload: []byte(chunks[i])}) { + terminateReason = "context_done" + terminateErr = ctx.Err() + return + } + } + if eventType == "response.completed" || eventType == "response.done" { + return + } + } + }() + + return stream, nil +} + +func (e *CodexWebsocketsExecutor) dialCodexWebsocket(ctx context.Context, auth *cliproxyauth.Auth, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { + dialer := newProxyAwareWebsocketDialer(e.cfg, auth) + dialer.HandshakeTimeout = codexResponsesWebsocketHandshakeTO + dialer.EnableCompression = true + if ctx == nil { + ctx = context.Background() + } + conn, resp, err := dialer.DialContext(ctx, wsURL, headers) + if conn != nil { + // Avoid gorilla/websocket flate tail validation issues on some upstreams/Go versions. + // Negotiating permessage-deflate is fine; we just don't compress outbound messages. 
+ conn.EnableWriteCompression(false) + } + return conn, resp, err +} + +func writeCodexWebsocketMessage(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) error { + if sess != nil { + return sess.writeMessage(conn, websocket.TextMessage, payload) + } + if conn == nil { + return fmt.Errorf("codex websockets executor: websocket conn is nil") + } + return conn.WriteMessage(websocket.TextMessage, payload) +} + +func buildCodexWebsocketRequestBody(body []byte, allowAppend bool) []byte { + if len(body) == 0 { + return nil + } + + // Codex CLI websocket v2 uses `response.create` with `previous_response_id` for incremental turns. + // The upstream ChatGPT Codex websocket currently rejects that with close 1008 (policy violation). + // Fall back to v1 `response.append` semantics on the same websocket connection to keep the session alive. + // + // NOTE: The upstream expects the first websocket event on each connection to be `response.create`, + // so we only use `response.append` after we have initialized the current connection. 
+ if allowAppend { + if prev := strings.TrimSpace(gjson.GetBytes(body, "previous_response_id").String()); prev != "" { + inputNode := gjson.GetBytes(body, "input") + wsReqBody := []byte(`{}`) + wsReqBody, _ = sjson.SetBytes(wsReqBody, "type", "response.append") + if inputNode.Exists() && inputNode.IsArray() && strings.TrimSpace(inputNode.Raw) != "" { + wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte(inputNode.Raw)) + return wsReqBody + } + wsReqBody, _ = sjson.SetRawBytes(wsReqBody, "input", []byte("[]")) + return wsReqBody + } + } + + wsReqBody, errSet := sjson.SetBytes(bytes.Clone(body), "type", "response.create") + if errSet == nil && len(wsReqBody) > 0 { + return wsReqBody + } + fallback := bytes.Clone(body) + fallback, _ = sjson.SetBytes(fallback, "type", "response.create") + return fallback +} + +func readCodexWebsocketMessage(ctx context.Context, sess *codexWebsocketSession, conn *websocket.Conn, readCh chan codexWebsocketRead) (int, []byte, error) { + if sess == nil { + if conn == nil { + return 0, nil, fmt.Errorf("codex websockets executor: websocket conn is nil") + } + _ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout)) + msgType, payload, errRead := conn.ReadMessage() + return msgType, payload, errRead + } + if conn == nil { + return 0, nil, fmt.Errorf("codex websockets executor: websocket conn is nil") + } + if readCh == nil { + return 0, nil, fmt.Errorf("codex websockets executor: session read channel is nil") + } + for { + select { + case <-ctx.Done(): + return 0, nil, ctx.Err() + case ev, ok := <-readCh: + if !ok { + return 0, nil, fmt.Errorf("codex websockets executor: session read channel closed") + } + if ev.conn != conn { + continue + } + if ev.err != nil { + return 0, nil, ev.err + } + return ev.msgType, ev.payload, nil + } + } +} + +func markCodexWebsocketCreateSent(sess *codexWebsocketSession, conn *websocket.Conn, payload []byte) { + if sess == nil || conn == nil || len(payload) == 0 { + return + } + if 
strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "response.create" { + return + } + + sess.connMu.Lock() + if sess.conn == conn { + sess.connCreateSent = true + } + sess.connMu.Unlock() +} + +func newProxyAwareWebsocketDialer(cfg *config.Config, auth *cliproxyauth.Auth) *websocket.Dialer { + dialer := &websocket.Dialer{ + Proxy: http.ProxyFromEnvironment, + HandshakeTimeout: codexResponsesWebsocketHandshakeTO, + EnableCompression: true, + NetDialContext: (&net.Dialer{ + Timeout: 30 * time.Second, + KeepAlive: 30 * time.Second, + }).DialContext, + } + + proxyURL := "" + if auth != nil { + proxyURL = strings.TrimSpace(auth.ProxyURL) + } + if proxyURL == "" && cfg != nil { + proxyURL = strings.TrimSpace(cfg.ProxyURL) + } + if proxyURL == "" { + return dialer + } + + parsedURL, errParse := url.Parse(proxyURL) + if errParse != nil { + log.Errorf("codex websockets executor: parse proxy URL failed: %v", errParse) + return dialer + } + + switch parsedURL.Scheme { + case "socks5": + var proxyAuth *proxy.Auth + if parsedURL.User != nil { + username := parsedURL.User.Username() + password, _ := parsedURL.User.Password() + proxyAuth = &proxy.Auth{User: username, Password: password} + } + socksDialer, errSOCKS5 := proxy.SOCKS5("tcp", parsedURL.Host, proxyAuth, proxy.Direct) + if errSOCKS5 != nil { + log.Errorf("codex websockets executor: create SOCKS5 dialer failed: %v", errSOCKS5) + return dialer + } + dialer.Proxy = nil + dialer.NetDialContext = func(_ context.Context, network, addr string) (net.Conn, error) { + return socksDialer.Dial(network, addr) + } + case "http", "https": + dialer.Proxy = http.ProxyURL(parsedURL) + default: + log.Errorf("codex websockets executor: unsupported proxy scheme: %s", parsedURL.Scheme) + } + + return dialer +} + +func buildCodexResponsesWebsocketURL(httpURL string) (string, error) { + parsed, err := url.Parse(strings.TrimSpace(httpURL)) + if err != nil { + return "", err + } + switch strings.ToLower(parsed.Scheme) { + case 
"http": + parsed.Scheme = "ws" + case "https": + parsed.Scheme = "wss" + } + return parsed.String(), nil +} + +func applyCodexPromptCacheHeaders(from sdktranslator.Format, req cliproxyexecutor.Request, rawJSON []byte) ([]byte, http.Header) { + headers := http.Header{} + if len(rawJSON) == 0 { + return rawJSON, headers + } + + var cache codexCache + if from == "claude" { + userIDResult := gjson.GetBytes(req.Payload, "metadata.user_id") + if userIDResult.Exists() { + key := fmt.Sprintf("%s-%s", req.Model, userIDResult.String()) + if cached, ok := getCodexCache(key); ok { + cache = cached + } else { + cache = codexCache{ + ID: uuid.New().String(), + Expire: time.Now().Add(1 * time.Hour), + } + setCodexCache(key, cache) + } + } + } else if from == "openai-response" { + if promptCacheKey := gjson.GetBytes(req.Payload, "prompt_cache_key"); promptCacheKey.Exists() { + cache.ID = promptCacheKey.String() + } + } + + if cache.ID != "" { + rawJSON, _ = sjson.SetBytes(rawJSON, "prompt_cache_key", cache.ID) + headers.Set("Conversation_id", cache.ID) + headers.Set("Session_id", cache.ID) + } + + return rawJSON, headers +} + +func applyCodexWebsocketHeaders(ctx context.Context, headers http.Header, auth *cliproxyauth.Auth, token string) http.Header { + if headers == nil { + headers = http.Header{} + } + if strings.TrimSpace(token) != "" { + headers.Set("Authorization", "Bearer "+token) + } + + var ginHeaders http.Header + if ginCtx := ginContextFrom(ctx); ginCtx != nil && ginCtx.Request != nil { + ginHeaders = ginCtx.Request.Header + } + + misc.EnsureHeader(headers, ginHeaders, "x-codex-beta-features", "") + misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-state", "") + misc.EnsureHeader(headers, ginHeaders, "x-codex-turn-metadata", "") + misc.EnsureHeader(headers, ginHeaders, "x-responsesapi-include-timing-metrics", "") + + misc.EnsureHeader(headers, ginHeaders, "Version", codexClientVersion) + betaHeader := strings.TrimSpace(headers.Get("OpenAI-Beta")) + if betaHeader == "" 
&& ginHeaders != nil { + betaHeader = strings.TrimSpace(ginHeaders.Get("OpenAI-Beta")) + } + if betaHeader == "" || !strings.Contains(betaHeader, "responses_websockets=") { + betaHeader = codexResponsesWebsocketBetaHeaderValue + } + headers.Set("OpenAI-Beta", betaHeader) + misc.EnsureHeader(headers, ginHeaders, "Session_id", uuid.NewString()) + misc.EnsureHeader(headers, ginHeaders, "User-Agent", codexUserAgent) + + isAPIKey := false + if auth != nil && auth.Attributes != nil { + if v := strings.TrimSpace(auth.Attributes["api_key"]); v != "" { + isAPIKey = true + } + } + if !isAPIKey { + headers.Set("Originator", "codex_cli_rs") + if auth != nil && auth.Metadata != nil { + if accountID, ok := auth.Metadata["account_id"].(string); ok { + if trimmed := strings.TrimSpace(accountID); trimmed != "" { + headers.Set("Chatgpt-Account-Id", trimmed) + } + } + } + } + + var attrs map[string]string + if auth != nil { + attrs = auth.Attributes + } + util.ApplyCustomHeadersFromAttrs(&http.Request{Header: headers}, attrs) + + return headers +} + +type statusErrWithHeaders struct { + statusErr + headers http.Header +} + +func (e statusErrWithHeaders) Headers() http.Header { + if e.headers == nil { + return nil + } + return e.headers.Clone() +} + +func parseCodexWebsocketError(payload []byte) (error, bool) { + if len(payload) == 0 { + return nil, false + } + if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) != "error" { + return nil, false + } + status := int(gjson.GetBytes(payload, "status").Int()) + if status == 0 { + status = int(gjson.GetBytes(payload, "status_code").Int()) + } + if status <= 0 { + return nil, false + } + + out := []byte(`{}`) + if errNode := gjson.GetBytes(payload, "error"); errNode.Exists() { + raw := errNode.Raw + if errNode.Type == gjson.String { + raw = errNode.Raw + } + out, _ = sjson.SetRawBytes(out, "error", []byte(raw)) + } else { + out, _ = sjson.SetBytes(out, "error.type", "server_error") + out, _ = sjson.SetBytes(out, "error.message", 
http.StatusText(status)) + } + + headers := parseCodexWebsocketErrorHeaders(payload) + return statusErrWithHeaders{ + statusErr: statusErr{code: status, msg: string(out)}, + headers: headers, + }, true +} + +func parseCodexWebsocketErrorHeaders(payload []byte) http.Header { + headersNode := gjson.GetBytes(payload, "headers") + if !headersNode.Exists() || !headersNode.IsObject() { + return nil + } + mapped := make(http.Header) + headersNode.ForEach(func(key, value gjson.Result) bool { + name := strings.TrimSpace(key.String()) + if name == "" { + return true + } + switch value.Type { + case gjson.String: + if v := strings.TrimSpace(value.String()); v != "" { + mapped.Set(name, v) + } + case gjson.Number, gjson.True, gjson.False: + if v := strings.TrimSpace(value.Raw); v != "" { + mapped.Set(name, v) + } + default: + } + return true + }) + if len(mapped) == 0 { + return nil + } + return mapped +} + +func normalizeCodexWebsocketCompletion(payload []byte) []byte { + if strings.TrimSpace(gjson.GetBytes(payload, "type").String()) == "response.done" { + updated, err := sjson.SetBytes(payload, "type", "response.completed") + if err == nil && len(updated) > 0 { + return updated + } + } + return payload +} + +func encodeCodexWebsocketAsSSE(payload []byte) []byte { + if len(payload) == 0 { + return nil + } + line := make([]byte, 0, len("data: ")+len(payload)) + line = append(line, []byte("data: ")...) + line = append(line, payload...) 
+ return line +} + +func websocketHandshakeBody(resp *http.Response) []byte { + if resp == nil || resp.Body == nil { + return nil + } + body, _ := io.ReadAll(resp.Body) + closeHTTPResponseBody(resp, "codex websockets executor: close handshake response body error") + if len(body) == 0 { + return nil + } + return body +} + +func closeHTTPResponseBody(resp *http.Response, logPrefix string) { + if resp == nil || resp.Body == nil { + return + } + if errClose := resp.Body.Close(); errClose != nil { + log.Errorf("%s: %v", logPrefix, errClose) + } +} + +func closeOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} { + done := make(chan struct{}) + if ctx == nil || conn == nil { + return done + } + go func() { + select { + case <-done: + case <-ctx.Done(): + _ = conn.Close() + } + }() + return done +} + +func cancelReadOnContextDone(ctx context.Context, conn *websocket.Conn) chan struct{} { + done := make(chan struct{}) + if ctx == nil || conn == nil { + return done + } + go func() { + select { + case <-done: + case <-ctx.Done(): + _ = conn.SetReadDeadline(time.Now()) + } + }() + return done +} + +func executionSessionIDFromOptions(opts cliproxyexecutor.Options) string { + if len(opts.Metadata) == 0 { + return "" + } + raw, ok := opts.Metadata[cliproxyexecutor.ExecutionSessionMetadataKey] + if !ok || raw == nil { + return "" + } + switch v := raw.(type) { + case string: + return strings.TrimSpace(v) + case []byte: + return strings.TrimSpace(string(v)) + default: + return "" + } +} + +func (e *CodexWebsocketsExecutor) getOrCreateSession(sessionID string) *codexWebsocketSession { + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return nil + } + e.sessMu.Lock() + defer e.sessMu.Unlock() + if e.sessions == nil { + e.sessions = make(map[string]*codexWebsocketSession) + } + if sess, ok := e.sessions[sessionID]; ok && sess != nil { + return sess + } + sess := &codexWebsocketSession{sessionID: sessionID} + e.sessions[sessionID] = sess + return 
sess +} + +func (e *CodexWebsocketsExecutor) ensureUpstreamConn(ctx context.Context, auth *cliproxyauth.Auth, sess *codexWebsocketSession, authID string, wsURL string, headers http.Header) (*websocket.Conn, *http.Response, error) { + if sess == nil { + return e.dialCodexWebsocket(ctx, auth, wsURL, headers) + } + + sess.connMu.Lock() + conn := sess.conn + readerConn := sess.readerConn + sess.connMu.Unlock() + if conn != nil { + if readerConn != conn { + sess.connMu.Lock() + sess.readerConn = conn + sess.connMu.Unlock() + sess.configureConn(conn) + go e.readUpstreamLoop(sess, conn) + } + return conn, nil, nil + } + + conn, resp, errDial := e.dialCodexWebsocket(ctx, auth, wsURL, headers) + if errDial != nil { + return nil, resp, errDial + } + + sess.connMu.Lock() + if sess.conn != nil { + previous := sess.conn + sess.connMu.Unlock() + if errClose := conn.Close(); errClose != nil { + log.Errorf("codex websockets executor: close websocket error: %v", errClose) + } + return previous, nil, nil + } + sess.conn = conn + sess.wsURL = wsURL + sess.authID = authID + sess.connCreateSent = false + sess.readerConn = conn + sess.connMu.Unlock() + + sess.configureConn(conn) + go e.readUpstreamLoop(sess, conn) + logCodexWebsocketConnected(sess.sessionID, authID, wsURL) + return conn, resp, nil +} + +func (e *CodexWebsocketsExecutor) readUpstreamLoop(sess *codexWebsocketSession, conn *websocket.Conn) { + if e == nil || sess == nil || conn == nil { + return + } + for { + _ = conn.SetReadDeadline(time.Now().Add(codexResponsesWebsocketIdleTimeout)) + msgType, payload, errRead := conn.ReadMessage() + if errRead != nil { + sess.activeMu.Lock() + ch := sess.activeCh + done := sess.activeDone + sess.activeMu.Unlock() + if ch != nil { + select { + case ch <- codexWebsocketRead{conn: conn, err: errRead}: + case <-done: + default: + } + sess.clearActive(ch) + close(ch) + } + e.invalidateUpstreamConn(sess, conn, "upstream_disconnected", errRead) + return + } + + if msgType != 
websocket.TextMessage { + if msgType == websocket.BinaryMessage { + errBinary := fmt.Errorf("codex websockets executor: unexpected binary message") + sess.activeMu.Lock() + ch := sess.activeCh + done := sess.activeDone + sess.activeMu.Unlock() + if ch != nil { + select { + case ch <- codexWebsocketRead{conn: conn, err: errBinary}: + case <-done: + default: + } + sess.clearActive(ch) + close(ch) + } + e.invalidateUpstreamConn(sess, conn, "unexpected_binary", errBinary) + return + } + continue + } + + sess.activeMu.Lock() + ch := sess.activeCh + done := sess.activeDone + sess.activeMu.Unlock() + if ch == nil { + continue + } + select { + case ch <- codexWebsocketRead{conn: conn, msgType: msgType, payload: payload}: + case <-done: + } + } +} + +func (e *CodexWebsocketsExecutor) invalidateUpstreamConn(sess *codexWebsocketSession, conn *websocket.Conn, reason string, err error) { + if sess == nil || conn == nil { + return + } + + sess.connMu.Lock() + current := sess.conn + authID := sess.authID + wsURL := sess.wsURL + sessionID := sess.sessionID + if current == nil || current != conn { + sess.connMu.Unlock() + return + } + sess.conn = nil + sess.connCreateSent = false + if sess.readerConn == conn { + sess.readerConn = nil + } + sess.connMu.Unlock() + + logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, err) + if errClose := conn.Close(); errClose != nil { + log.Errorf("codex websockets executor: close websocket error: %v", errClose) + } +} + +func (e *CodexWebsocketsExecutor) CloseExecutionSession(sessionID string) { + sessionID = strings.TrimSpace(sessionID) + if e == nil { + return + } + if sessionID == "" { + return + } + if sessionID == cliproxyauth.CloseAllExecutionSessionsID { + e.closeAllExecutionSessions("executor_replaced") + return + } + + e.sessMu.Lock() + sess := e.sessions[sessionID] + delete(e.sessions, sessionID) + e.sessMu.Unlock() + + e.closeExecutionSession(sess, "session_closed") +} + +func (e *CodexWebsocketsExecutor) 
closeAllExecutionSessions(reason string) { + if e == nil { + return + } + + e.sessMu.Lock() + sessions := make([]*codexWebsocketSession, 0, len(e.sessions)) + for sessionID, sess := range e.sessions { + delete(e.sessions, sessionID) + if sess != nil { + sessions = append(sessions, sess) + } + } + e.sessMu.Unlock() + + for i := range sessions { + e.closeExecutionSession(sessions[i], reason) + } +} + +func (e *CodexWebsocketsExecutor) closeExecutionSession(sess *codexWebsocketSession, reason string) { + if sess == nil { + return + } + reason = strings.TrimSpace(reason) + if reason == "" { + reason = "session_closed" + } + + sess.connMu.Lock() + conn := sess.conn + authID := sess.authID + wsURL := sess.wsURL + sess.conn = nil + sess.connCreateSent = false + if sess.readerConn == conn { + sess.readerConn = nil + } + sessionID := sess.sessionID + sess.connMu.Unlock() + + if conn == nil { + return + } + logCodexWebsocketDisconnected(sessionID, authID, wsURL, reason, nil) + if errClose := conn.Close(); errClose != nil { + log.Errorf("codex websockets executor: close websocket error: %v", errClose) + } +} + +func logCodexWebsocketConnected(sessionID string, authID string, wsURL string) { + log.Infof("codex websockets: upstream connected session=%s auth=%s url=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL)) +} + +func logCodexWebsocketDisconnected(sessionID string, authID string, wsURL string, reason string, err error) { + if err != nil { + log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s err=%v", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason), err) + return + } + log.Infof("codex websockets: upstream disconnected session=%s auth=%s url=%s reason=%s", strings.TrimSpace(sessionID), strings.TrimSpace(authID), strings.TrimSpace(wsURL), strings.TrimSpace(reason)) +} + +// CodexAutoExecutor routes Codex requests to the websocket transport 
only when: +// 1. The downstream transport is websocket, and +// 2. The selected auth enables websockets. +// +// For non-websocket downstream requests, it always uses the legacy HTTP implementation. +type CodexAutoExecutor struct { + httpExec *CodexExecutor + wsExec *CodexWebsocketsExecutor +} + +func NewCodexAutoExecutor(cfg *config.Config) *CodexAutoExecutor { + return &CodexAutoExecutor{ + httpExec: NewCodexExecutor(cfg), + wsExec: NewCodexWebsocketsExecutor(cfg), + } +} + +func (e *CodexAutoExecutor) Identifier() string { return "codex" } + +func (e *CodexAutoExecutor) PrepareRequest(req *http.Request, auth *cliproxyauth.Auth) error { + if e == nil || e.httpExec == nil { + return nil + } + return e.httpExec.PrepareRequest(req, auth) +} + +func (e *CodexAutoExecutor) HttpRequest(ctx context.Context, auth *cliproxyauth.Auth, req *http.Request) (*http.Response, error) { + if e == nil || e.httpExec == nil { + return nil, fmt.Errorf("codex auto executor: http executor is nil") + } + return e.httpExec.HttpRequest(ctx, auth, req) +} + +func (e *CodexAutoExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + if e == nil || e.httpExec == nil || e.wsExec == nil { + return cliproxyexecutor.Response{}, fmt.Errorf("codex auto executor: executor is nil") + } + if cliproxyexecutor.DownstreamWebsocket(ctx) && codexWebsocketsEnabled(auth) { + return e.wsExec.Execute(ctx, auth, req, opts) + } + return e.httpExec.Execute(ctx, auth, req, opts) +} + +func (e *CodexAutoExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { + if e == nil || e.httpExec == nil || e.wsExec == nil { + return nil, fmt.Errorf("codex auto executor: executor is nil") + } + if cliproxyexecutor.DownstreamWebsocket(ctx) && codexWebsocketsEnabled(auth) { + return 
e.wsExec.ExecuteStream(ctx, auth, req, opts) + } + return e.httpExec.ExecuteStream(ctx, auth, req, opts) +} + +func (e *CodexAutoExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) { + if e == nil || e.httpExec == nil { + return nil, fmt.Errorf("codex auto executor: http executor is nil") + } + return e.httpExec.Refresh(ctx, auth) +} + +func (e *CodexAutoExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + if e == nil || e.httpExec == nil { + return cliproxyexecutor.Response{}, fmt.Errorf("codex auto executor: http executor is nil") + } + return e.httpExec.CountTokens(ctx, auth, req, opts) +} + +func (e *CodexAutoExecutor) CloseExecutionSession(sessionID string) { + if e == nil || e.wsExec == nil { + return + } + e.wsExec.CloseExecutionSession(sessionID) +} + +func codexWebsocketsEnabled(auth *cliproxyauth.Auth) bool { + if auth == nil { + return false + } + if len(auth.Attributes) > 0 { + if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" { + parsed, errParse := strconv.ParseBool(raw) + if errParse == nil { + return parsed + } + } + } + if len(auth.Metadata) == 0 { + return false + } + raw, ok := auth.Metadata["websockets"] + if !ok || raw == nil { + return false + } + switch v := raw.(type) { + case bool: + return v + case string: + parsed, errParse := strconv.ParseBool(strings.TrimSpace(v)) + if errParse == nil { + return parsed + } + default: + } + return false +} diff --git a/internal/runtime/executor/qwen_executor.go b/internal/runtime/executor/qwen_executor.go index 28b803ad..69e1f7fa 100644 --- a/internal/runtime/executor/qwen_executor.go +++ b/internal/runtime/executor/qwen_executor.go @@ -22,9 +22,7 @@ import ( ) const ( - qwenUserAgent = "google-api-nodejs-client/9.15.1" - qwenXGoogAPIClient = "gl-node/22.17.0" - qwenClientMetadataValue = 
"ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI" + qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)" ) // QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions. @@ -344,8 +342,18 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) { r.Header.Set("Content-Type", "application/json") r.Header.Set("Authorization", "Bearer "+token) r.Header.Set("User-Agent", qwenUserAgent) - r.Header.Set("X-Goog-Api-Client", qwenXGoogAPIClient) - r.Header.Set("Client-Metadata", qwenClientMetadataValue) + r.Header.Set("X-Dashscope-Useragent", qwenUserAgent) + r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0") + r.Header.Set("Sec-Fetch-Mode", "cors") + r.Header.Set("X-Stainless-Lang", "js") + r.Header.Set("X-Stainless-Arch", "arm64") + r.Header.Set("X-Stainless-Package-Version", "5.11.0") + r.Header.Set("X-Dashscope-Cachecontrol", "enable") + r.Header.Set("X-Stainless-Retry-Count", "0") + r.Header.Set("X-Stainless-Os", "MacOS") + r.Header.Set("X-Dashscope-Authtype", "qwen-oauth") + r.Header.Set("X-Stainless-Runtime", "node") + if stream { r.Header.Set("Accept", "text/event-stream") return diff --git a/internal/watcher/diff/config_diff.go b/internal/watcher/diff/config_diff.go index 98698ead..6687749e 100644 --- a/internal/watcher/diff/config_diff.go +++ b/internal/watcher/diff/config_diff.go @@ -184,6 +184,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string { if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) { changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix))) } + if o.Websockets != n.Websockets { + changes = append(changes, fmt.Sprintf("codex[%d].websockets: %t -> %t", i, o.Websockets, n.Websockets)) + } if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) { changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i)) } diff --git a/internal/watcher/synthesizer/config.go 
b/internal/watcher/synthesizer/config.go index b1ae5885..69194efc 100644 --- a/internal/watcher/synthesizer/config.go +++ b/internal/watcher/synthesizer/config.go @@ -160,6 +160,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau if ck.BaseURL != "" { attrs["base_url"] = ck.BaseURL } + if ck.Websockets { + attrs["websockets"] = "true" + } if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" { attrs["models_hash"] = hash } diff --git a/internal/watcher/synthesizer/config_test.go b/internal/watcher/synthesizer/config_test.go index 32af7c27..437f18d1 100644 --- a/internal/watcher/synthesizer/config_test.go +++ b/internal/watcher/synthesizer/config_test.go @@ -231,10 +231,11 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) { Config: &config.Config{ CodexKey: []config.CodexKey{ { - APIKey: "codex-key-123", - Prefix: "dev", - BaseURL: "https://api.openai.com", - ProxyURL: "http://proxy.local", + APIKey: "codex-key-123", + Prefix: "dev", + BaseURL: "https://api.openai.com", + ProxyURL: "http://proxy.local", + Websockets: true, }, }, }, @@ -259,6 +260,9 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) { if auths[0].ProxyURL != "http://proxy.local" { t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL) } + if auths[0].Attributes["websockets"] != "true" { + t.Errorf("expected websockets=true, got %s", auths[0].Attributes["websockets"]) + } } func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) { diff --git a/sdk/api/handlers/handlers.go b/sdk/api/handlers/handlers.go index 4ad2efb0..23ef6535 100644 --- a/sdk/api/handlers/handlers.go +++ b/sdk/api/handlers/handlers.go @@ -52,6 +52,45 @@ const ( defaultStreamingBootstrapRetries = 0 ) +type pinnedAuthContextKey struct{} +type selectedAuthCallbackContextKey struct{} +type executionSessionContextKey struct{} + +// WithPinnedAuthID returns a child context that requests execution on a specific auth ID. 
+func WithPinnedAuthID(ctx context.Context, authID string) context.Context { + authID = strings.TrimSpace(authID) + if authID == "" { + return ctx + } + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, pinnedAuthContextKey{}, authID) +} + +// WithSelectedAuthIDCallback returns a child context that receives the selected auth ID. +func WithSelectedAuthIDCallback(ctx context.Context, callback func(string)) context.Context { + if callback == nil { + return ctx + } + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, selectedAuthCallbackContextKey{}, callback) +} + +// WithExecutionSessionID returns a child context tagged with a long-lived execution session ID. +func WithExecutionSessionID(ctx context.Context, sessionID string) context.Context { + sessionID = strings.TrimSpace(sessionID) + if sessionID == "" { + return ctx + } + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, executionSessionContextKey{}, sessionID) +} + // BuildErrorResponseBody builds an OpenAI-compatible JSON error response body. // If errText is already valid JSON, it is returned as-is to preserve upstream error payloads. 
func BuildErrorResponseBody(status int, errText string) []byte { @@ -152,7 +191,59 @@ func requestExecutionMetadata(ctx context.Context) map[string]any { if key == "" { key = uuid.NewString() } - return map[string]any{idempotencyKeyMetadataKey: key} + + meta := map[string]any{idempotencyKeyMetadataKey: key} + if pinnedAuthID := pinnedAuthIDFromContext(ctx); pinnedAuthID != "" { + meta[coreexecutor.PinnedAuthMetadataKey] = pinnedAuthID + } + if selectedCallback := selectedAuthIDCallbackFromContext(ctx); selectedCallback != nil { + meta[coreexecutor.SelectedAuthCallbackMetadataKey] = selectedCallback + } + if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" { + meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID + } + return meta +} + +func pinnedAuthIDFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + raw := ctx.Value(pinnedAuthContextKey{}) + switch v := raw.(type) { + case string: + return strings.TrimSpace(v) + case []byte: + return strings.TrimSpace(string(v)) + default: + return "" + } +} + +func selectedAuthIDCallbackFromContext(ctx context.Context) func(string) { + if ctx == nil { + return nil + } + raw := ctx.Value(selectedAuthCallbackContextKey{}) + if callback, ok := raw.(func(string)); ok && callback != nil { + return callback + } + return nil +} + +func executionSessionIDFromContext(ctx context.Context) string { + if ctx == nil { + return "" + } + raw := ctx.Value(executionSessionContextKey{}) + switch v := raw.(type) { + case string: + return strings.TrimSpace(v) + case []byte: + return strings.TrimSpace(string(v)) + default: + return "" + } } // BaseAPIHandler contains the handlers for API endpoints. 
diff --git a/sdk/api/handlers/handlers_stream_bootstrap_test.go b/sdk/api/handlers/handlers_stream_bootstrap_test.go index 7814ff1b..66a49e52 100644 --- a/sdk/api/handlers/handlers_stream_bootstrap_test.go +++ b/sdk/api/handlers/handlers_stream_bootstrap_test.go @@ -122,6 +122,82 @@ func (e *payloadThenErrorStreamExecutor) Calls() int { return e.calls } +type authAwareStreamExecutor struct { + mu sync.Mutex + calls int + authIDs []string +} + +func (e *authAwareStreamExecutor) Identifier() string { return "codex" } + +func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} +} + +func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { + _ = ctx + _ = req + _ = opts + ch := make(chan coreexecutor.StreamChunk, 1) + + authID := "" + if auth != nil { + authID = auth.ID + } + + e.mu.Lock() + e.calls++ + e.authIDs = append(e.authIDs, authID) + e.mu.Unlock() + + if authID == "auth1" { + ch <- coreexecutor.StreamChunk{ + Err: &coreauth.Error{ + Code: "unauthorized", + Message: "unauthorized", + Retryable: false, + HTTPStatus: http.StatusUnauthorized, + }, + } + close(ch) + return ch, nil + } + + ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} + close(ch) + return ch, nil +} + +func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { + return auth, nil +} + +func (e *authAwareStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { + return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"} +} + +func (e *authAwareStreamExecutor) HttpRequest(ctx 
context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) { + return nil, &coreauth.Error{ + Code: "not_implemented", + Message: "HttpRequest not implemented", + HTTPStatus: http.StatusNotImplemented, + } +} + +func (e *authAwareStreamExecutor) Calls() int { + e.mu.Lock() + defer e.mu.Unlock() + return e.calls +} + +func (e *authAwareStreamExecutor) AuthIDs() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.authIDs)) + copy(out, e.authIDs) + return out +} + func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) { executor := &failOnceStreamExecutor{} manager := coreauth.NewManager(nil, nil, nil) @@ -252,3 +328,128 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) { t.Fatalf("expected 1 stream attempt, got %d", executor.Calls()) } } + +func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) { + executor := &authAwareStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth1 := &coreauth.Auth{ + ID: "auth1", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test1@example.com"}, + } + if _, err := manager.Register(context.Background(), auth1); err != nil { + t.Fatalf("manager.Register(auth1): %v", err) + } + + auth2 := &coreauth.Auth{ + ID: "auth2", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test2@example.com"}, + } + if _, err := manager.Register(context.Background(), auth2); err != nil { + t.Fatalf("manager.Register(auth2): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth1.ID) + registry.GetGlobalRegistry().UnregisterClient(auth2.ID) + 
}) + + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{ + Streaming: sdkconfig.StreamingConfig{ + BootstrapRetries: 1, + }, + }, manager) + ctx := WithPinnedAuthID(context.Background(), "auth1") + dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + var got []byte + for chunk := range dataChan { + got = append(got, chunk...) + } + + var gotErr error + for msg := range errChan { + if msg != nil && msg.Error != nil { + gotErr = msg.Error + } + } + + if len(got) != 0 { + t.Fatalf("expected empty payload, got %q", string(got)) + } + if gotErr == nil { + t.Fatalf("expected terminal error, got nil") + } + authIDs := executor.AuthIDs() + if len(authIDs) == 0 { + t.Fatalf("expected at least one upstream attempt") + } + for _, authID := range authIDs { + if authID != "auth1" { + t.Fatalf("expected all attempts on auth1, got sequence %v", authIDs) + } + } +} + +func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *testing.T) { + executor := &authAwareStreamExecutor{} + manager := coreauth.NewManager(nil, nil, nil) + manager.RegisterExecutor(executor) + + auth2 := &coreauth.Auth{ + ID: "auth2", + Provider: "codex", + Status: coreauth.StatusActive, + Metadata: map[string]any{"email": "test2@example.com"}, + } + if _, err := manager.Register(context.Background(), auth2); err != nil { + t.Fatalf("manager.Register(auth2): %v", err) + } + + registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}}) + t.Cleanup(func() { + registry.GetGlobalRegistry().UnregisterClient(auth2.ID) + }) + + handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{ + Streaming: sdkconfig.StreamingConfig{ + BootstrapRetries: 0, + }, + }, manager) + + selectedAuthID := "" + ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) { + selectedAuthID = authID 
+ }) + dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "") + if dataChan == nil || errChan == nil { + t.Fatalf("expected non-nil channels") + } + + var got []byte + for chunk := range dataChan { + got = append(got, chunk...) + } + for msg := range errChan { + if msg != nil { + t.Fatalf("unexpected error: %+v", msg) + } + } + + if string(got) != "ok" { + t.Fatalf("expected payload ok, got %q", string(got)) + } + if selectedAuthID != "auth2" { + t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2") + } +} diff --git a/sdk/api/handlers/openai/openai_responses_websocket.go b/sdk/api/handlers/openai/openai_responses_websocket.go new file mode 100644 index 00000000..bcf09311 --- /dev/null +++ b/sdk/api/handlers/openai/openai_responses_websocket.go @@ -0,0 +1,662 @@ +package openai + +import ( + "bytes" + "context" + "encoding/json" + "fmt" + "net/http" + "strconv" + "strings" + "time" + + "github.com/gin-gonic/gin" + "github.com/google/uuid" + "github.com/gorilla/websocket" + "github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers" + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" + log "github.com/sirupsen/logrus" + "github.com/tidwall/gjson" + "github.com/tidwall/sjson" +) + +const ( + wsRequestTypeCreate = "response.create" + wsRequestTypeAppend = "response.append" + wsEventTypeError = "error" + wsEventTypeCompleted = "response.completed" + wsEventTypeDone = "response.done" + wsDoneMarker = "[DONE]" + wsTurnStateHeader = "x-codex-turn-state" + wsRequestBodyKey = "REQUEST_BODY_OVERRIDE" + wsPayloadLogMaxSize = 2048 +) + +var responsesWebsocketUpgrader = websocket.Upgrader{ + ReadBufferSize: 4096, + WriteBufferSize: 4096, + CheckOrigin: func(r *http.Request) bool { + return true + }, +} + +// ResponsesWebsocket handles websocket requests for /v1/responses. 
+// It accepts `response.create` and `response.append` requests and streams +// response events back as JSON websocket text messages. +func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { + conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request)) + if err != nil { + return + } + passthroughSessionID := uuid.NewString() + clientRemoteAddr := "" + if c != nil && c.Request != nil { + clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr) + } + log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr) + var wsTerminateErr error + var wsBodyLog strings.Builder + defer func() { + if wsTerminateErr != nil { + // log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr) + } else { + log.Infof("responses websocket: session closing id=%s", passthroughSessionID) + } + if h != nil && h.AuthManager != nil { + h.AuthManager.CloseExecutionSession(passthroughSessionID) + log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID) + } + setWebsocketRequestBody(c, wsBodyLog.String()) + if errClose := conn.Close(); errClose != nil { + log.Warnf("responses websocket: close connection error: %v", errClose) + } + }() + + var lastRequest []byte + lastResponseOutput := []byte("[]") + pinnedAuthID := "" + + for { + msgType, payload, errReadMessage := conn.ReadMessage() + if errReadMessage != nil { + wsTerminateErr = errReadMessage + appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error())) + if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { + log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage) + } else { + // log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage) + } + return + } + 
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage { + continue + } + // log.Infof( + // "responses websocket: downstream_in id=%s type=%d event=%s payload=%s", + // passthroughSessionID, + // msgType, + // websocketPayloadEventType(payload), + // websocketPayloadPreview(payload), + // ) + appendWebsocketEvent(&wsBodyLog, "request", payload) + + allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil) + if pinnedAuthID != "" && h != nil && h.AuthManager != nil { + if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { + allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) + } + } + + var requestJSON []byte + var updatedLastRequest []byte + var errMsg *interfaces.ErrorMessage + requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode( + payload, + lastRequest, + lastResponseOutput, + allowIncrementalInputWithPreviousResponseID, + ) + if errMsg != nil { + h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) + markAPIResponseTimestamp(c) + errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) + appendWebsocketEvent(&wsBodyLog, "response", errorPayload) + log.Infof( + "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", + passthroughSessionID, + websocket.TextMessage, + websocketPayloadEventType(errorPayload), + websocketPayloadPreview(errorPayload), + ) + if errWrite != nil { + log.Warnf( + "responses websocket: downstream_out write failed id=%s event=%s error=%v", + passthroughSessionID, + websocketPayloadEventType(errorPayload), + errWrite, + ) + return + } + continue + } + lastRequest = updatedLastRequest + + modelName := gjson.GetBytes(requestJSON, "model").String() + cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) + cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx) + 
cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID) + if pinnedAuthID != "" { + cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID) + } else { + cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) { + pinnedAuthID = strings.TrimSpace(authID) + }) + } + dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") + + completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) + if errForward != nil { + wsTerminateErr = errForward + appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error())) + log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward) + return + } + lastResponseOutput = completedOutput + } +} + +func websocketUpgradeHeaders(req *http.Request) http.Header { + headers := http.Header{} + if req == nil { + return headers + } + + // Keep the same sticky turn-state across reconnects when provided by the client. 
+ turnState := strings.TrimSpace(req.Header.Get(wsTurnStateHeader)) + if turnState != "" { + headers.Set(wsTurnStateHeader, turnState) + } + return headers +} + +func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) { + return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true) +} + +func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) { + requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) + switch requestType { + case wsRequestTypeCreate: + // log.Infof("responses websocket: response.create request") + if len(lastRequest) == 0 { + return normalizeResponseCreateRequest(rawJSON) + } + return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID) + case wsRequestTypeAppend: + // log.Infof("responses websocket: response.append request") + return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID) + default: + return nil, lastRequest, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("unsupported websocket request type: %s", requestType), + } + } +} + +func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) { + normalized, errDelete := sjson.DeleteBytes(rawJSON, "type") + if errDelete != nil { + normalized = bytes.Clone(rawJSON) + } + normalized, _ = sjson.SetBytes(normalized, "stream", true) + if !gjson.GetBytes(normalized, "input").Exists() { + normalized, _ = sjson.SetRawBytes(normalized, "input", []byte("[]")) + } + + modelName := strings.TrimSpace(gjson.GetBytes(normalized, "model").String()) + if modelName == "" { + return nil, nil, &interfaces.ErrorMessage{ 
+ StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("missing model in response.create request"), + } + } + return normalized, bytes.Clone(normalized), nil +} + +func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) { + if len(lastRequest) == 0 { + return nil, lastRequest, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("websocket request received before response.create"), + } + } + + nextInput := gjson.GetBytes(rawJSON, "input") + if !nextInput.Exists() || !nextInput.IsArray() { + return nil, lastRequest, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("websocket request requires array field: input"), + } + } + + // Websocket v2 mode uses response.create with previous_response_id + incremental input. + // Do not expand it into a full input transcript; upstream expects the incremental payload. 
+ if allowIncrementalInputWithPreviousResponseID { + if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" { + normalized, errDelete := sjson.DeleteBytes(rawJSON, "type") + if errDelete != nil { + normalized = bytes.Clone(rawJSON) + } + if !gjson.GetBytes(normalized, "model").Exists() { + modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + if modelName != "" { + normalized, _ = sjson.SetBytes(normalized, "model", modelName) + } + } + if !gjson.GetBytes(normalized, "instructions").Exists() { + instructions := gjson.GetBytes(lastRequest, "instructions") + if instructions.Exists() { + normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw)) + } + } + normalized, _ = sjson.SetBytes(normalized, "stream", true) + return normalized, bytes.Clone(normalized), nil + } + } + + existingInput := gjson.GetBytes(lastRequest, "input") + mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput)) + if errMerge != nil { + return nil, lastRequest, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("invalid previous response output: %w", errMerge), + } + } + + mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw) + if errMerge != nil { + return nil, lastRequest, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("invalid request input: %w", errMerge), + } + } + + normalized, errDelete := sjson.DeleteBytes(rawJSON, "type") + if errDelete != nil { + normalized = bytes.Clone(rawJSON) + } + normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id") + var errSet error + normalized, errSet = sjson.SetRawBytes(normalized, "input", []byte(mergedInput)) + if errSet != nil { + return nil, lastRequest, &interfaces.ErrorMessage{ + StatusCode: http.StatusBadRequest, + Error: fmt.Errorf("failed to merge websocket input: %w", errSet), + } + } + if 
!gjson.GetBytes(normalized, "model").Exists() { + modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) + if modelName != "" { + normalized, _ = sjson.SetBytes(normalized, "model", modelName) + } + } + if !gjson.GetBytes(normalized, "instructions").Exists() { + instructions := gjson.GetBytes(lastRequest, "instructions") + if instructions.Exists() { + normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw)) + } + } + normalized, _ = sjson.SetBytes(normalized, "stream", true) + return normalized, bytes.Clone(normalized), nil +} + +func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool { + if len(attributes) > 0 { + if raw := strings.TrimSpace(attributes["websockets"]); raw != "" { + parsed, errParse := strconv.ParseBool(raw) + if errParse == nil { + return parsed + } + } + } + if len(metadata) == 0 { + return false + } + raw, ok := metadata["websockets"] + if !ok || raw == nil { + return false + } + switch value := raw.(type) { + case bool: + return value + case string: + parsed, errParse := strconv.ParseBool(strings.TrimSpace(value)) + if errParse == nil { + return parsed + } + default: + } + return false +} + +func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) { + existingRaw = strings.TrimSpace(existingRaw) + appendRaw = strings.TrimSpace(appendRaw) + if existingRaw == "" { + existingRaw = "[]" + } + if appendRaw == "" { + appendRaw = "[]" + } + + var existing []json.RawMessage + if err := json.Unmarshal([]byte(existingRaw), &existing); err != nil { + return "", err + } + var appendItems []json.RawMessage + if err := json.Unmarshal([]byte(appendRaw), &appendItems); err != nil { + return "", err + } + + merged := append(existing, appendItems...) 
+ out, err := json.Marshal(merged) + if err != nil { + return "", err + } + return string(out), nil +} + +func normalizeJSONArrayRaw(raw []byte) string { + trimmed := strings.TrimSpace(string(raw)) + if trimmed == "" { + return "[]" + } + result := gjson.Parse(trimmed) + if result.Type == gjson.JSON && result.IsArray() { + return trimmed + } + return "[]" +} + +func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( + c *gin.Context, + conn *websocket.Conn, + cancel handlers.APIHandlerCancelFunc, + data <-chan []byte, + errs <-chan *interfaces.ErrorMessage, + wsBodyLog *strings.Builder, + sessionID string, +) ([]byte, error) { + completed := false + completedOutput := []byte("[]") + + for { + select { + case <-c.Request.Context().Done(): + cancel(c.Request.Context().Err()) + return completedOutput, c.Request.Context().Err() + case errMsg, ok := <-errs: + if !ok { + errs = nil + continue + } + if errMsg != nil { + h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) + markAPIResponseTimestamp(c) + errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) + appendWebsocketEvent(wsBodyLog, "response", errorPayload) + log.Infof( + "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", + sessionID, + websocket.TextMessage, + websocketPayloadEventType(errorPayload), + websocketPayloadPreview(errorPayload), + ) + if errWrite != nil { + // log.Warnf( + // "responses websocket: downstream_out write failed id=%s event=%s error=%v", + // sessionID, + // websocketPayloadEventType(errorPayload), + // errWrite, + // ) + cancel(errMsg.Error) + return completedOutput, errWrite + } + } + if errMsg != nil { + cancel(errMsg.Error) + } else { + cancel(nil) + } + return completedOutput, nil + case chunk, ok := <-data: + if !ok { + if !completed { + errMsg := &interfaces.ErrorMessage{ + StatusCode: http.StatusRequestTimeout, + Error: fmt.Errorf("stream closed before response.completed"), + } + 
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) + markAPIResponseTimestamp(c) + errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) + appendWebsocketEvent(wsBodyLog, "response", errorPayload) + log.Infof( + "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", + sessionID, + websocket.TextMessage, + websocketPayloadEventType(errorPayload), + websocketPayloadPreview(errorPayload), + ) + if errWrite != nil { + log.Warnf( + "responses websocket: downstream_out write failed id=%s event=%s error=%v", + sessionID, + websocketPayloadEventType(errorPayload), + errWrite, + ) + cancel(errMsg.Error) + return completedOutput, errWrite + } + cancel(errMsg.Error) + return completedOutput, nil + } + cancel(nil) + return completedOutput, nil + } + + payloads := websocketJSONPayloadsFromChunk(chunk) + for i := range payloads { + eventType := gjson.GetBytes(payloads[i], "type").String() + if eventType == wsEventTypeCompleted { + // log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone) + payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone) + + completed = true + completedOutput = responseCompletedOutputFromPayload(payloads[i]) + } + markAPIResponseTimestamp(c) + appendWebsocketEvent(wsBodyLog, "response", payloads[i]) + // log.Infof( + // "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", + // sessionID, + // websocket.TextMessage, + // websocketPayloadEventType(payloads[i]), + // websocketPayloadPreview(payloads[i]), + // ) + if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil { + log.Warnf( + "responses websocket: downstream_out write failed id=%s event=%s error=%v", + sessionID, + websocketPayloadEventType(payloads[i]), + errWrite, + ) + cancel(errWrite) + return completedOutput, errWrite + } + } + } + } +} + +func responseCompletedOutputFromPayload(payload []byte) []byte { + output := gjson.GetBytes(payload, 
"response.output") + if output.Exists() && output.IsArray() { + return bytes.Clone([]byte(output.Raw)) + } + return []byte("[]") +} + +func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte { + payloads := make([][]byte, 0, 2) + lines := bytes.Split(chunk, []byte("\n")) + for i := range lines { + line := bytes.TrimSpace(lines[i]) + if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) { + continue + } + if bytes.HasPrefix(line, []byte("data:")) { + line = bytes.TrimSpace(line[len("data:"):]) + } + if len(line) == 0 || bytes.Equal(line, []byte(wsDoneMarker)) { + continue + } + if json.Valid(line) { + payloads = append(payloads, bytes.Clone(line)) + } + } + + if len(payloads) > 0 { + return payloads + } + + trimmed := bytes.TrimSpace(chunk) + if bytes.HasPrefix(trimmed, []byte("data:")) { + trimmed = bytes.TrimSpace(trimmed[len("data:"):]) + } + if len(trimmed) > 0 && !bytes.Equal(trimmed, []byte(wsDoneMarker)) && json.Valid(trimmed) { + payloads = append(payloads, bytes.Clone(trimmed)) + } + return payloads +} + +func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) { + status := http.StatusInternalServerError + errText := http.StatusText(status) + if errMsg != nil { + if errMsg.StatusCode > 0 { + status = errMsg.StatusCode + errText = http.StatusText(status) + } + if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" { + errText = errMsg.Error.Error() + } + } + + body := handlers.BuildErrorResponseBody(status, errText) + payload := map[string]any{ + "type": wsEventTypeError, + "status": status, + } + + if errMsg != nil && errMsg.Addon != nil { + headers := map[string]any{} + for key, values := range errMsg.Addon { + if len(values) == 0 { + continue + } + headers[key] = values[0] + } + if len(headers) > 0 { + payload["headers"] = headers + } + } + + if len(body) > 0 && json.Valid(body) { + var decoded map[string]any + if errDecode := json.Unmarshal(body, &decoded); errDecode == nil { + 
if inner, ok := decoded["error"]; ok { + payload["error"] = inner + } else { + payload["error"] = decoded + } + } + } + + if _, ok := payload["error"]; !ok { + payload["error"] = map[string]any{ + "type": "server_error", + "message": errText, + } + } + + data, err := json.Marshal(payload) + if err != nil { + return nil, err + } + return data, conn.WriteMessage(websocket.TextMessage, data) +} + +func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) { + if builder == nil { + return + } + trimmedPayload := bytes.TrimSpace(payload) + if len(trimmedPayload) == 0 { + return + } + if builder.Len() > 0 { + builder.WriteString("\n") + } + builder.WriteString("websocket.") + builder.WriteString(eventType) + builder.WriteString("\n") + builder.Write(trimmedPayload) + builder.WriteString("\n") +} + +func websocketPayloadEventType(payload []byte) string { + eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String()) + if eventType == "" { + return "-" + } + return eventType +} + +func websocketPayloadPreview(payload []byte) string { + trimmedPayload := bytes.TrimSpace(payload) + if len(trimmedPayload) == 0 { + return "" + } + preview := trimmedPayload + if len(preview) > wsPayloadLogMaxSize { + preview = preview[:wsPayloadLogMaxSize] + } + previewText := strings.ReplaceAll(string(preview), "\n", "\\n") + previewText = strings.ReplaceAll(previewText, "\r", "\\r") + if len(trimmedPayload) > wsPayloadLogMaxSize { + return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload)) + } + return previewText +} + +func setWebsocketRequestBody(c *gin.Context, body string) { + if c == nil { + return + } + trimmedBody := strings.TrimSpace(body) + if trimmedBody == "" { + return + } + c.Set(wsRequestBodyKey, []byte(trimmedBody)) +} + +func markAPIResponseTimestamp(c *gin.Context) { + if c == nil { + return + } + if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); exists { + return + } + c.Set("API_RESPONSE_TIMESTAMP", time.Now()) 
+} diff --git a/sdk/api/handlers/openai/openai_responses_websocket_test.go b/sdk/api/handlers/openai/openai_responses_websocket_test.go new file mode 100644 index 00000000..9b6cec78 --- /dev/null +++ b/sdk/api/handlers/openai/openai_responses_websocket_test.go @@ -0,0 +1,249 @@ +package openai + +import ( + "bytes" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/gin-gonic/gin" + "github.com/tidwall/gjson" +) + +func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) { + raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`) + + normalized, last, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + if gjson.GetBytes(normalized, "type").Exists() { + t.Fatalf("normalized create request must not include type field") + } + if !gjson.GetBytes(normalized, "stream").Bool() { + t.Fatalf("normalized create request must force stream=true") + } + if gjson.GetBytes(normalized, "model").String() != "test-model" { + t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String()) + } + if !bytes.Equal(last, normalized) { + t.Fatalf("last request snapshot should match normalized request") + } +} + +func TestNormalizeResponsesWebsocketRequestCreateWithHistory(t *testing.T) { + lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`) + lastResponseOutput := []byte(`[ + {"type":"function_call","id":"fc-1","call_id":"call-1"}, + {"type":"message","id":"assistant-1"} + ]`) + raw := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`) + + normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + if gjson.GetBytes(normalized, "type").Exists() { + t.Fatalf("normalized 
subsequent create request must not include type field") + } + if gjson.GetBytes(normalized, "model").String() != "test-model" { + t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String()) + } + + input := gjson.GetBytes(normalized, "input").Array() + if len(input) != 4 { + t.Fatalf("merged input len = %d, want 4", len(input)) + } + if input[0].Get("id").String() != "msg-1" || + input[1].Get("id").String() != "fc-1" || + input[2].Get("id").String() != "assistant-1" || + input[3].Get("id").String() != "tool-out-1" { + t.Fatalf("unexpected merged input order") + } + if !bytes.Equal(next, normalized) { + t.Fatalf("next request snapshot should match normalized request") + } +} + +func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *testing.T) { + lastRequest := []byte(`{"model":"test-model","stream":true,"instructions":"be helpful","input":[{"type":"message","id":"msg-1"}]}`) + lastResponseOutput := []byte(`[ + {"type":"function_call","id":"fc-1","call_id":"call-1"}, + {"type":"message","id":"assistant-1"} + ]`) + raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`) + + normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + if gjson.GetBytes(normalized, "type").Exists() { + t.Fatalf("normalized request must not include type field") + } + if gjson.GetBytes(normalized, "previous_response_id").String() != "resp-1" { + t.Fatalf("previous_response_id must be preserved in incremental mode") + } + input := gjson.GetBytes(normalized, "input").Array() + if len(input) != 1 { + t.Fatalf("incremental input len = %d, want 1", len(input)) + } + if input[0].Get("id").String() != "tool-out-1" { + t.Fatalf("unexpected incremental input item id: %s", input[0].Get("id").String()) + } + if 
gjson.GetBytes(normalized, "model").String() != "test-model" { + t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String()) + } + if gjson.GetBytes(normalized, "instructions").String() != "be helpful" { + t.Fatalf("unexpected instructions: %s", gjson.GetBytes(normalized, "instructions").String()) + } + if !bytes.Equal(next, normalized) { + t.Fatalf("next request snapshot should match normalized request") + } +} + +func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled(t *testing.T) { + lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`) + lastResponseOutput := []byte(`[ + {"type":"function_call","id":"fc-1","call_id":"call-1"}, + {"type":"message","id":"assistant-1"} + ]`) + raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`) + + normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + if gjson.GetBytes(normalized, "previous_response_id").Exists() { + t.Fatalf("previous_response_id must be removed when incremental mode is disabled") + } + input := gjson.GetBytes(normalized, "input").Array() + if len(input) != 4 { + t.Fatalf("merged input len = %d, want 4", len(input)) + } + if input[0].Get("id").String() != "msg-1" || + input[1].Get("id").String() != "fc-1" || + input[2].Get("id").String() != "assistant-1" || + input[3].Get("id").String() != "tool-out-1" { + t.Fatalf("unexpected merged input order") + } + if !bytes.Equal(next, normalized) { + t.Fatalf("next request snapshot should match normalized request") + } +} + +func TestNormalizeResponsesWebsocketRequestAppend(t *testing.T) { + lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`) + lastResponseOutput := []byte(`[ + 
{"type":"message","id":"assistant-1"}, + {"type":"function_call_output","id":"tool-out-1"} + ]`) + raw := []byte(`{"type":"response.append","input":[{"type":"message","id":"msg-2"},{"type":"message","id":"msg-3"}]}`) + + normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput) + if errMsg != nil { + t.Fatalf("unexpected error: %v", errMsg.Error) + } + input := gjson.GetBytes(normalized, "input").Array() + if len(input) != 5 { + t.Fatalf("merged input len = %d, want 5", len(input)) + } + if input[0].Get("id").String() != "msg-1" || + input[1].Get("id").String() != "assistant-1" || + input[2].Get("id").String() != "tool-out-1" || + input[3].Get("id").String() != "msg-2" || + input[4].Get("id").String() != "msg-3" { + t.Fatalf("unexpected merged input order") + } + if !bytes.Equal(next, normalized) { + t.Fatalf("next request snapshot should match normalized append request") + } +} + +func TestNormalizeResponsesWebsocketRequestAppendWithoutCreate(t *testing.T) { + raw := []byte(`{"type":"response.append","input":[]}`) + + _, _, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil) + if errMsg == nil { + t.Fatalf("expected error for append without previous request") + } + if errMsg.StatusCode != http.StatusBadRequest { + t.Fatalf("status = %d, want %d", errMsg.StatusCode, http.StatusBadRequest) + } +} + +func TestWebsocketJSONPayloadsFromChunk(t *testing.T) { + chunk := []byte("event: response.created\n\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\ndata: [DONE]\n") + + payloads := websocketJSONPayloadsFromChunk(chunk) + if len(payloads) != 1 { + t.Fatalf("payloads len = %d, want 1", len(payloads)) + } + if gjson.GetBytes(payloads[0], "type").String() != "response.created" { + t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String()) + } +} + +func TestWebsocketJSONPayloadsFromPlainJSONChunk(t *testing.T) { + chunk := 
[]byte(`{"type":"response.completed","response":{"id":"resp-1"}}`) + + payloads := websocketJSONPayloadsFromChunk(chunk) + if len(payloads) != 1 { + t.Fatalf("payloads len = %d, want 1", len(payloads)) + } + if gjson.GetBytes(payloads[0], "type").String() != "response.completed" { + t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String()) + } +} + +func TestResponseCompletedOutputFromPayload(t *testing.T) { + payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"message","id":"out-1"}]}}`) + + output := responseCompletedOutputFromPayload(payload) + items := gjson.ParseBytes(output).Array() + if len(items) != 1 { + t.Fatalf("output len = %d, want 1", len(items)) + } + if items[0].Get("id").String() != "out-1" { + t.Fatalf("unexpected output id: %s", items[0].Get("id").String()) + } +} + +func TestAppendWebsocketEvent(t *testing.T) { + var builder strings.Builder + + appendWebsocketEvent(&builder, "request", []byte(" {\"type\":\"response.create\"}\n")) + appendWebsocketEvent(&builder, "response", []byte("{\"type\":\"response.created\"}")) + + got := builder.String() + if !strings.Contains(got, "websocket.request\n{\"type\":\"response.create\"}\n") { + t.Fatalf("request event not found in body: %s", got) + } + if !strings.Contains(got, "websocket.response\n{\"type\":\"response.created\"}\n") { + t.Fatalf("response event not found in body: %s", got) + } +} + +func TestSetWebsocketRequestBody(t *testing.T) { + gin.SetMode(gin.TestMode) + recorder := httptest.NewRecorder() + c, _ := gin.CreateTestContext(recorder) + + setWebsocketRequestBody(c, " \n ") + if _, exists := c.Get(wsRequestBodyKey); exists { + t.Fatalf("request body key should not be set for empty body") + } + + setWebsocketRequestBody(c, "event body") + value, exists := c.Get(wsRequestBodyKey) + if !exists { + t.Fatalf("request body key not set") + } + bodyBytes, ok := value.([]byte) + if !ok { + t.Fatalf("request body key type mismatch") + 
} + if string(bodyBytes) != "event body" { + t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body") + } +} diff --git a/sdk/cliproxy/auth/conductor.go b/sdk/cliproxy/auth/conductor.go index 2c3e9f48..76aae228 100644 --- a/sdk/cliproxy/auth/conductor.go +++ b/sdk/cliproxy/auth/conductor.go @@ -41,6 +41,17 @@ type ProviderExecutor interface { HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error) } +// ExecutionSessionCloser allows executors to release per-session runtime resources. +type ExecutionSessionCloser interface { + CloseExecutionSession(sessionID string) +} + +const ( + // CloseAllExecutionSessionsID asks an executor to release all active execution sessions. + // Executors that do not support this marker may ignore it. + CloseAllExecutionSessionsID = "__all_execution_sessions__" +) + // RefreshEvaluator allows runtime state to override refresh decisions. type RefreshEvaluator interface { ShouldRefresh(now time.Time, auth *Auth) bool @@ -389,9 +400,23 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) { if executor == nil { return } + provider := strings.TrimSpace(executor.Identifier()) + if provider == "" { + return + } + + var replaced ProviderExecutor m.mu.Lock() - defer m.mu.Unlock() - m.executors[executor.Identifier()] = executor + replaced = m.executors[provider] + m.executors[provider] = executor + m.mu.Unlock() + + if replaced == nil || replaced == executor { + return + } + if closer, ok := replaced.(ExecutionSessionCloser); ok && closer != nil { + closer.CloseExecutionSession(CloseAllExecutionSessionsID) + } } // UnregisterExecutor removes the executor associated with the provider key. 
@@ -581,6 +606,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req entry := logEntryWithRequestID(ctx) debugLogAuthSelection(entry, auth, provider, req.Model) + publishSelectedAuthMetadata(opts.Metadata, auth.ID) tried[auth.ID] = struct{}{} execCtx := ctx @@ -636,6 +662,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string, entry := logEntryWithRequestID(ctx) debugLogAuthSelection(entry, auth, provider, req.Model) + publishSelectedAuthMetadata(opts.Metadata, auth.ID) tried[auth.ID] = struct{}{} execCtx := ctx @@ -691,6 +718,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string entry := logEntryWithRequestID(ctx) debugLogAuthSelection(entry, auth, provider, req.Model) + publishSelectedAuthMetadata(opts.Metadata, auth.ID) tried[auth.ID] = struct{}{} execCtx := ctx @@ -794,6 +822,38 @@ func hasRequestedModelMetadata(meta map[string]any) bool { } } +func pinnedAuthIDFromMetadata(meta map[string]any) string { + if len(meta) == 0 { + return "" + } + raw, ok := meta[cliproxyexecutor.PinnedAuthMetadataKey] + if !ok || raw == nil { + return "" + } + switch val := raw.(type) { + case string: + return strings.TrimSpace(val) + case []byte: + return strings.TrimSpace(string(val)) + default: + return "" + } +} + +func publishSelectedAuthMetadata(meta map[string]any, authID string) { + if len(meta) == 0 { + return + } + authID = strings.TrimSpace(authID) + if authID == "" { + return + } + meta[cliproxyexecutor.SelectedAuthMetadataKey] = authID + if callback, ok := meta[cliproxyexecutor.SelectedAuthCallbackMetadataKey].(func(string)); ok && callback != nil { + callback(authID) + } +} + func rewriteModelForAuth(model string, auth *Auth) string { if auth == nil || model == "" { return model @@ -1550,7 +1610,56 @@ func (m *Manager) GetByID(id string) (*Auth, bool) { return auth.Clone(), true } +// Executor returns the registered provider executor for a provider key. 
+func (m *Manager) Executor(provider string) (ProviderExecutor, bool) { + if m == nil { + return nil, false + } + provider = strings.TrimSpace(provider) + if provider == "" { + return nil, false + } + + m.mu.RLock() + executor, okExecutor := m.executors[provider] + if !okExecutor { + lowerProvider := strings.ToLower(provider) + if lowerProvider != provider { + executor, okExecutor = m.executors[lowerProvider] + } + } + m.mu.RUnlock() + + if !okExecutor || executor == nil { + return nil, false + } + return executor, true +} + +// CloseExecutionSession asks all registered executors to release the supplied execution session. +func (m *Manager) CloseExecutionSession(sessionID string) { + sessionID = strings.TrimSpace(sessionID) + if m == nil || sessionID == "" { + return + } + + m.mu.RLock() + executors := make([]ProviderExecutor, 0, len(m.executors)) + for _, exec := range m.executors { + executors = append(executors, exec) + } + m.mu.RUnlock() + + for i := range executors { + if closer, ok := executors[i].(ExecutionSessionCloser); ok && closer != nil { + closer.CloseExecutionSession(sessionID) + } + } +} + func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) { + pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) + m.mu.RLock() executor, okExecutor := m.executors[provider] if !okExecutor { @@ -1571,6 +1680,9 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli if candidate.Provider != provider || candidate.Disabled { continue } + if pinnedAuthID != "" && candidate.ID != pinnedAuthID { + continue + } if _, used := tried[candidate.ID]; used { continue } @@ -1606,6 +1718,8 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli } func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, 
string, error) { + pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata) + providerSet := make(map[string]struct{}, len(providers)) for _, provider := range providers { p := strings.TrimSpace(strings.ToLower(provider)) @@ -1633,6 +1747,9 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s if candidate == nil || candidate.Disabled { continue } + if pinnedAuthID != "" && candidate.ID != pinnedAuthID { + continue + } providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider)) if providerKey == "" { continue diff --git a/sdk/cliproxy/auth/conductor_executor_replace_test.go b/sdk/cliproxy/auth/conductor_executor_replace_test.go new file mode 100644 index 00000000..3854f341 --- /dev/null +++ b/sdk/cliproxy/auth/conductor_executor_replace_test.go @@ -0,0 +1,100 @@ +package auth + +import ( + "context" + "net/http" + "sync" + "testing" + + cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" +) + +type replaceAwareExecutor struct { + id string + + mu sync.Mutex + closedSessionIDs []string +} + +func (e *replaceAwareExecutor) Identifier() string { + return e.id +} + +func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, nil +} + +func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) { + ch := make(chan cliproxyexecutor.StreamChunk) + close(ch) + return ch, nil +} + +func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) { + return auth, nil +} + +func (e *replaceAwareExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) { + return cliproxyexecutor.Response{}, nil +} + +func (e *replaceAwareExecutor) HttpRequest(context.Context, *Auth, *http.Request) 
(*http.Response, error) { + return nil, nil +} + +func (e *replaceAwareExecutor) CloseExecutionSession(sessionID string) { + e.mu.Lock() + defer e.mu.Unlock() + e.closedSessionIDs = append(e.closedSessionIDs, sessionID) +} + +func (e *replaceAwareExecutor) ClosedSessionIDs() []string { + e.mu.Lock() + defer e.mu.Unlock() + out := make([]string, len(e.closedSessionIDs)) + copy(out, e.closedSessionIDs) + return out +} + +func TestManagerRegisterExecutorClosesReplacedExecutionSessions(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, nil, nil) + replaced := &replaceAwareExecutor{id: "codex"} + current := &replaceAwareExecutor{id: "codex"} + + manager.RegisterExecutor(replaced) + manager.RegisterExecutor(current) + + closed := replaced.ClosedSessionIDs() + if len(closed) != 1 { + t.Fatalf("expected replaced executor close calls = 1, got %d", len(closed)) + } + if closed[0] != CloseAllExecutionSessionsID { + t.Fatalf("expected close marker %q, got %q", CloseAllExecutionSessionsID, closed[0]) + } + if len(current.ClosedSessionIDs()) != 0 { + t.Fatalf("expected current executor to stay open") + } +} + +func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) { + t.Parallel() + + manager := NewManager(nil, nil, nil) + current := &replaceAwareExecutor{id: "codex"} + manager.RegisterExecutor(current) + + resolved, okResolved := manager.Executor("CODEX") + if !okResolved { + t.Fatal("expected registered executor to be found") + } + if resolved != current { + t.Fatal("expected resolved executor to match registered executor") + } + + _, okMissing := manager.Executor("unknown") + if okMissing { + t.Fatal("expected unknown provider lookup to fail") + } +} diff --git a/sdk/cliproxy/auth/selector.go b/sdk/cliproxy/auth/selector.go index 28500881..a173ed01 100644 --- a/sdk/cliproxy/auth/selector.go +++ b/sdk/cliproxy/auth/selector.go @@ -134,6 +134,62 @@ func canonicalModelKey(model string) string { return modelName } +func authWebsocketsEnabled(auth *Auth) bool { 
+ if auth == nil { + return false + } + if len(auth.Attributes) > 0 { + if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" { + parsed, errParse := strconv.ParseBool(raw) + if errParse == nil { + return parsed + } + } + } + if len(auth.Metadata) == 0 { + return false + } + raw, ok := auth.Metadata["websockets"] + if !ok || raw == nil { + return false + } + switch v := raw.(type) { + case bool: + return v + case string: + parsed, errParse := strconv.ParseBool(strings.TrimSpace(v)) + if errParse == nil { + return parsed + } + default: + } + return false +} + +func preferCodexWebsocketAuths(ctx context.Context, provider string, available []*Auth) []*Auth { + if len(available) == 0 { + return available + } + if !cliproxyexecutor.DownstreamWebsocket(ctx) { + return available + } + if !strings.EqualFold(strings.TrimSpace(provider), "codex") { + return available + } + + wsEnabled := make([]*Auth, 0, len(available)) + for i := 0; i < len(available); i++ { + candidate := available[i] + if authWebsocketsEnabled(candidate) { + wsEnabled = append(wsEnabled, candidate) + } + } + if len(wsEnabled) > 0 { + return wsEnabled + } + return available +} + func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) { available = make(map[int][]*Auth) for i := 0; i < len(auths); i++ { @@ -193,13 +249,13 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([] // Pick selects the next available auth for the provider in a round-robin manner. 
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { - _ = ctx _ = opts now := time.Now() available, err := getAvailableAuths(auths, provider, model, now) if err != nil { return nil, err } + available = preferCodexWebsocketAuths(ctx, provider, available) key := provider + ":" + canonicalModelKey(model) s.mu.Lock() if s.cursors == nil { @@ -226,13 +282,13 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o // Pick selects the first available auth for the provider in a deterministic manner. func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) { - _ = ctx _ = opts now := time.Now() available, err := getAvailableAuths(auths, provider, model, now) if err != nil { return nil, err } + available = preferCodexWebsocketAuths(ctx, provider, available) return available[0], nil } diff --git a/sdk/cliproxy/executor/context.go b/sdk/cliproxy/executor/context.go new file mode 100644 index 00000000..367b507e --- /dev/null +++ b/sdk/cliproxy/executor/context.go @@ -0,0 +1,23 @@ +package executor + +import "context" + +type downstreamWebsocketContextKey struct{} + +// WithDownstreamWebsocket marks the current request as coming from a downstream websocket connection. +func WithDownstreamWebsocket(ctx context.Context) context.Context { + if ctx == nil { + ctx = context.Background() + } + return context.WithValue(ctx, downstreamWebsocketContextKey{}, true) +} + +// DownstreamWebsocket reports whether the current request originates from a downstream websocket connection. 
+func DownstreamWebsocket(ctx context.Context) bool { + if ctx == nil { + return false + } + raw := ctx.Value(downstreamWebsocketContextKey{}) + enabled, ok := raw.(bool) + return ok && enabled +} diff --git a/sdk/cliproxy/executor/types.go b/sdk/cliproxy/executor/types.go index 8c11bbc4..4e917eb7 100644 --- a/sdk/cliproxy/executor/types.go +++ b/sdk/cliproxy/executor/types.go @@ -10,6 +10,17 @@ import ( // RequestedModelMetadataKey stores the client-requested model name in Options.Metadata. const RequestedModelMetadataKey = "requested_model" +const ( + // PinnedAuthMetadataKey locks execution to a specific auth ID. + PinnedAuthMetadataKey = "pinned_auth_id" + // SelectedAuthMetadataKey stores the auth ID selected by the scheduler. + SelectedAuthMetadataKey = "selected_auth_id" + // SelectedAuthCallbackMetadataKey carries an optional callback invoked with the selected auth ID. + SelectedAuthCallbackMetadataKey = "selected_auth_callback" + // ExecutionSessionMetadataKey identifies a long-lived downstream execution session. + ExecutionSessionMetadataKey = "execution_session_id" +) + // Request encapsulates the translated payload that will be sent to a provider executor. type Request struct { // Model is the upstream model identifier after translation. 
diff --git a/sdk/cliproxy/service.go b/sdk/cliproxy/service.go index 536329b5..e89c49c0 100644 --- a/sdk/cliproxy/service.go +++ b/sdk/cliproxy/service.go @@ -325,6 +325,9 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) { if _, err := s.coreManager.Update(ctx, existing); err != nil { log.Errorf("failed to disable auth %s: %v", id, err) } + if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") { + s.ensureExecutorsForAuth(existing) + } } } @@ -357,7 +360,24 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName } func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { - if s == nil || a == nil { + s.ensureExecutorsForAuthWithMode(a, false) +} + +func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace bool) { + if s == nil || s.coreManager == nil || a == nil { + return + } + if strings.EqualFold(strings.TrimSpace(a.Provider), "codex") { + if !forceReplace { + existingExecutor, hasExecutor := s.coreManager.Executor("codex") + if hasExecutor { + _, isCodexAutoExecutor := existingExecutor.(*executor.CodexAutoExecutor) + if isCodexAutoExecutor { + return + } + } + } + s.coreManager.RegisterExecutor(executor.NewCodexAutoExecutor(s.cfg)) return } // Skip disabled auth entries when (re)binding executors. 
@@ -392,8 +412,6 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) { s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg)) case "claude": s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg)) - case "codex": - s.coreManager.RegisterExecutor(executor.NewCodexExecutor(s.cfg)) case "qwen": s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg)) case "iflow": @@ -415,8 +433,15 @@ func (s *Service) rebindExecutors() { return } auths := s.coreManager.List() + reboundCodex := false for _, auth := range auths { - s.ensureExecutorsForAuth(auth) + if auth != nil && strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") { + if reboundCodex { + continue + } + reboundCodex = true + } + s.ensureExecutorsForAuthWithMode(auth, true) } } diff --git a/sdk/cliproxy/service_codex_executor_binding_test.go b/sdk/cliproxy/service_codex_executor_binding_test.go new file mode 100644 index 00000000..bb4fc84e --- /dev/null +++ b/sdk/cliproxy/service_codex_executor_binding_test.go @@ -0,0 +1,64 @@ +package cliproxy + +import ( + "testing" + + coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" + "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" +) + +func TestEnsureExecutorsForAuth_CodexDoesNotReplaceInNormalMode(t *testing.T) { + service := &Service{ + cfg: &config.Config{}, + coreManager: coreauth.NewManager(nil, nil, nil), + } + auth := &coreauth.Auth{ + ID: "codex-auth-1", + Provider: "codex", + Status: coreauth.StatusActive, + } + + service.ensureExecutorsForAuth(auth) + firstExecutor, okFirst := service.coreManager.Executor("codex") + if !okFirst || firstExecutor == nil { + t.Fatal("expected codex executor after first bind") + } + + service.ensureExecutorsForAuth(auth) + secondExecutor, okSecond := service.coreManager.Executor("codex") + if !okSecond || secondExecutor == nil { + t.Fatal("expected codex executor after second bind") + } + + if firstExecutor != secondExecutor { + t.Fatal("expected codex executor to 
stay unchanged in normal mode") + } +} + +func TestEnsureExecutorsForAuthWithMode_CodexForceReplace(t *testing.T) { + service := &Service{ + cfg: &config.Config{}, + coreManager: coreauth.NewManager(nil, nil, nil), + } + auth := &coreauth.Auth{ + ID: "codex-auth-2", + Provider: "codex", + Status: coreauth.StatusActive, + } + + service.ensureExecutorsForAuth(auth) + firstExecutor, okFirst := service.coreManager.Executor("codex") + if !okFirst || firstExecutor == nil { + t.Fatal("expected codex executor after first bind") + } + + service.ensureExecutorsForAuthWithMode(auth, true) + secondExecutor, okSecond := service.coreManager.Executor("codex") + if !okSecond || secondExecutor == nil { + t.Fatal("expected codex executor after forced rebind") + } + + if firstExecutor == secondExecutor { + t.Fatal("expected codex executor replacement in force mode") + } +}