package openai

import (
	"bytes"
	"context"
	"encoding/json"
	"fmt"
	"net/http"
	"strconv"
	"strings"
	"time"

	"github.com/gin-gonic/gin"
	"github.com/google/uuid"
	"github.com/gorilla/websocket"
	"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
	"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
	cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
	log "github.com/sirupsen/logrus"
	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

const (
	// Values of the "type" field accepted on client messages; anything else is
	// rejected by normalizeResponsesWebsocketRequestWithMode.
	wsRequestTypeCreate = "response.create"
	wsRequestTypeAppend = "response.append"

	// Values of the "type" field on events sent to the client.
	// wsEventTypeError is produced by writeResponsesWebsocketError, and
	// forwardResponsesWebsocket rewrites wsEventTypeCompleted events to
	// wsEventTypeDone before forwarding them.
	wsEventTypeError     = "error"
	wsEventTypeCompleted = "response.completed"
	wsEventTypeDone      = "response.done"

	// wsDoneMarker is the stream sentinel that websocketJSONPayloadsFromChunk
	// drops instead of forwarding.
	wsDoneMarker = "[DONE]"

	// wsTurnStateHeader is the request header that websocketUpgradeHeaders
	// echoes back on the upgrade response.
	wsTurnStateHeader = "x-codex-turn-state"

	// wsRequestBodyKey is the gin context key under which
	// setWebsocketRequestBody stores the accumulated websocket event log.
	wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"

	// wsPayloadLogMaxSize is the maximum number of payload bytes that
	// websocketPayloadPreview includes in a log line.
	wsPayloadLogMaxSize = 2048
)

// responsesWebsocketUpgrader upgrades /v1/responses HTTP requests to
// websocket connections. CheckOrigin accepts every origin.
var responsesWebsocketUpgrader = websocket.Upgrader{
	ReadBufferSize:  4096,
	WriteBufferSize: 4096,
	CheckOrigin: func(r *http.Request) bool {
		return true
	},
}

// ResponsesWebsocket handles websocket requests for /v1/responses.
// It accepts `response.create` and `response.append` requests and streams
// response events back as JSON websocket text messages.
func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) { conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request)) if err != nil { return } passthroughSessionID := uuid.NewString() clientRemoteAddr := "" if c != nil && c.Request != nil { clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr) } log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr) var wsTerminateErr error var wsBodyLog strings.Builder defer func() { if wsTerminateErr != nil { // log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr) } else { log.Infof("responses websocket: session closing id=%s", passthroughSessionID) } if h != nil && h.AuthManager != nil { h.AuthManager.CloseExecutionSession(passthroughSessionID) log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID) } setWebsocketRequestBody(c, wsBodyLog.String()) if errClose := conn.Close(); errClose != nil { log.Warnf("responses websocket: close connection error: %v", errClose) } }() var lastRequest []byte lastResponseOutput := []byte("[]") pinnedAuthID := "" for { msgType, payload, errReadMessage := conn.ReadMessage() if errReadMessage != nil { wsTerminateErr = errReadMessage appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error())) if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) { log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage) } else { // log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage) } return } if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage { continue } // log.Infof( // "responses websocket: downstream_in id=%s type=%d event=%s payload=%s", // passthroughSessionID, // msgType, // 
websocketPayloadEventType(payload), // websocketPayloadPreview(payload), // ) appendWebsocketEvent(&wsBodyLog, "request", payload) allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil) if pinnedAuthID != "" && h != nil && h.AuthManager != nil { if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil { allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata) } } var requestJSON []byte var updatedLastRequest []byte var errMsg *interfaces.ErrorMessage requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode( payload, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID, ) if errMsg != nil { h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) appendWebsocketEvent(&wsBodyLog, "response", errorPayload) log.Infof( "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", passthroughSessionID, websocket.TextMessage, websocketPayloadEventType(errorPayload), websocketPayloadPreview(errorPayload), ) if errWrite != nil { log.Warnf( "responses websocket: downstream_out write failed id=%s event=%s error=%v", passthroughSessionID, websocketPayloadEventType(errorPayload), errWrite, ) return } continue } lastRequest = updatedLastRequest modelName := gjson.GetBytes(requestJSON, "model").String() cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background()) cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx) cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID) if pinnedAuthID != "" { cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID) } else { cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) { pinnedAuthID = strings.TrimSpace(authID) }) } dataChan, errChan := 
h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "") completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID) if errForward != nil { wsTerminateErr = errForward appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error())) log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward) return } lastResponseOutput = completedOutput } } func websocketUpgradeHeaders(req *http.Request) http.Header { headers := http.Header{} if req == nil { return headers } // Keep the same sticky turn-state across reconnects when provided by the client. turnState := strings.TrimSpace(req.Header.Get(wsTurnStateHeader)) if turnState != "" { headers.Set(wsTurnStateHeader, turnState) } return headers } func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) { return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true) } func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) { requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String()) switch requestType { case wsRequestTypeCreate: // log.Infof("responses websocket: response.create request") if len(lastRequest) == 0 { return normalizeResponseCreateRequest(rawJSON) } return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID) case wsRequestTypeAppend: // log.Infof("responses websocket: response.append request") return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID) default: return nil, lastRequest, &interfaces.ErrorMessage{ StatusCode: http.StatusBadRequest, Error: 
fmt.Errorf("unsupported websocket request type: %s", requestType), } } } func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) { normalized, errDelete := sjson.DeleteBytes(rawJSON, "type") if errDelete != nil { normalized = bytes.Clone(rawJSON) } normalized, _ = sjson.SetBytes(normalized, "stream", true) if !gjson.GetBytes(normalized, "input").Exists() { normalized, _ = sjson.SetRawBytes(normalized, "input", []byte("[]")) } modelName := strings.TrimSpace(gjson.GetBytes(normalized, "model").String()) if modelName == "" { return nil, nil, &interfaces.ErrorMessage{ StatusCode: http.StatusBadRequest, Error: fmt.Errorf("missing model in response.create request"), } } return normalized, bytes.Clone(normalized), nil } func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) { if len(lastRequest) == 0 { return nil, lastRequest, &interfaces.ErrorMessage{ StatusCode: http.StatusBadRequest, Error: fmt.Errorf("websocket request received before response.create"), } } nextInput := gjson.GetBytes(rawJSON, "input") if !nextInput.Exists() || !nextInput.IsArray() { return nil, lastRequest, &interfaces.ErrorMessage{ StatusCode: http.StatusBadRequest, Error: fmt.Errorf("websocket request requires array field: input"), } } // Websocket v2 mode uses response.create with previous_response_id + incremental input. // Do not expand it into a full input transcript; upstream expects the incremental payload. 
if allowIncrementalInputWithPreviousResponseID { if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" { normalized, errDelete := sjson.DeleteBytes(rawJSON, "type") if errDelete != nil { normalized = bytes.Clone(rawJSON) } if !gjson.GetBytes(normalized, "model").Exists() { modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String()) if modelName != "" { normalized, _ = sjson.SetBytes(normalized, "model", modelName) } } if !gjson.GetBytes(normalized, "instructions").Exists() { instructions := gjson.GetBytes(lastRequest, "instructions") if instructions.Exists() { normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw)) } } normalized, _ = sjson.SetBytes(normalized, "stream", true) return normalized, bytes.Clone(normalized), nil } } existingInput := gjson.GetBytes(lastRequest, "input") mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput)) if errMerge != nil { return nil, lastRequest, &interfaces.ErrorMessage{ StatusCode: http.StatusBadRequest, Error: fmt.Errorf("invalid previous response output: %w", errMerge), } } mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw) if errMerge != nil { return nil, lastRequest, &interfaces.ErrorMessage{ StatusCode: http.StatusBadRequest, Error: fmt.Errorf("invalid request input: %w", errMerge), } } normalized, errDelete := sjson.DeleteBytes(rawJSON, "type") if errDelete != nil { normalized = bytes.Clone(rawJSON) } normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id") var errSet error normalized, errSet = sjson.SetRawBytes(normalized, "input", []byte(mergedInput)) if errSet != nil { return nil, lastRequest, &interfaces.ErrorMessage{ StatusCode: http.StatusBadRequest, Error: fmt.Errorf("failed to merge websocket input: %w", errSet), } } if !gjson.GetBytes(normalized, "model").Exists() { modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, 
"model").String()) if modelName != "" { normalized, _ = sjson.SetBytes(normalized, "model", modelName) } } if !gjson.GetBytes(normalized, "instructions").Exists() { instructions := gjson.GetBytes(lastRequest, "instructions") if instructions.Exists() { normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw)) } } normalized, _ = sjson.SetBytes(normalized, "stream", true) return normalized, bytes.Clone(normalized), nil } func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool { if len(attributes) > 0 { if raw := strings.TrimSpace(attributes["websockets"]); raw != "" { parsed, errParse := strconv.ParseBool(raw) if errParse == nil { return parsed } } } if len(metadata) == 0 { return false } raw, ok := metadata["websockets"] if !ok || raw == nil { return false } switch value := raw.(type) { case bool: return value case string: parsed, errParse := strconv.ParseBool(strings.TrimSpace(value)) if errParse == nil { return parsed } default: } return false } func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) { existingRaw = strings.TrimSpace(existingRaw) appendRaw = strings.TrimSpace(appendRaw) if existingRaw == "" { existingRaw = "[]" } if appendRaw == "" { appendRaw = "[]" } var existing []json.RawMessage if err := json.Unmarshal([]byte(existingRaw), &existing); err != nil { return "", err } var appendItems []json.RawMessage if err := json.Unmarshal([]byte(appendRaw), &appendItems); err != nil { return "", err } merged := append(existing, appendItems...) 
out, err := json.Marshal(merged) if err != nil { return "", err } return string(out), nil } func normalizeJSONArrayRaw(raw []byte) string { trimmed := strings.TrimSpace(string(raw)) if trimmed == "" { return "[]" } result := gjson.Parse(trimmed) if result.Type == gjson.JSON && result.IsArray() { return trimmed } return "[]" } func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket( c *gin.Context, conn *websocket.Conn, cancel handlers.APIHandlerCancelFunc, data <-chan []byte, errs <-chan *interfaces.ErrorMessage, wsBodyLog *strings.Builder, sessionID string, ) ([]byte, error) { completed := false completedOutput := []byte("[]") for { select { case <-c.Request.Context().Done(): cancel(c.Request.Context().Err()) return completedOutput, c.Request.Context().Err() case errMsg, ok := <-errs: if !ok { errs = nil continue } if errMsg != nil { h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg) appendWebsocketEvent(wsBodyLog, "response", errorPayload) log.Infof( "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", sessionID, websocket.TextMessage, websocketPayloadEventType(errorPayload), websocketPayloadPreview(errorPayload), ) if errWrite != nil { // log.Warnf( // "responses websocket: downstream_out write failed id=%s event=%s error=%v", // sessionID, // websocketPayloadEventType(errorPayload), // errWrite, // ) cancel(errMsg.Error) return completedOutput, errWrite } } if errMsg != nil { cancel(errMsg.Error) } else { cancel(nil) } return completedOutput, nil case chunk, ok := <-data: if !ok { if !completed { errMsg := &interfaces.ErrorMessage{ StatusCode: http.StatusRequestTimeout, Error: fmt.Errorf("stream closed before response.completed"), } h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg) markAPIResponseTimestamp(c) errorPayload, errWrite := writeResponsesWebsocketError(conn, 
errMsg) appendWebsocketEvent(wsBodyLog, "response", errorPayload) log.Infof( "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", sessionID, websocket.TextMessage, websocketPayloadEventType(errorPayload), websocketPayloadPreview(errorPayload), ) if errWrite != nil { log.Warnf( "responses websocket: downstream_out write failed id=%s event=%s error=%v", sessionID, websocketPayloadEventType(errorPayload), errWrite, ) cancel(errMsg.Error) return completedOutput, errWrite } cancel(errMsg.Error) return completedOutput, nil } cancel(nil) return completedOutput, nil } payloads := websocketJSONPayloadsFromChunk(chunk) for i := range payloads { eventType := gjson.GetBytes(payloads[i], "type").String() if eventType == wsEventTypeCompleted { // log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone) payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone) completed = true completedOutput = responseCompletedOutputFromPayload(payloads[i]) } markAPIResponseTimestamp(c) appendWebsocketEvent(wsBodyLog, "response", payloads[i]) // log.Infof( // "responses websocket: downstream_out id=%s type=%d event=%s payload=%s", // sessionID, // websocket.TextMessage, // websocketPayloadEventType(payloads[i]), // websocketPayloadPreview(payloads[i]), // ) if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil { log.Warnf( "responses websocket: downstream_out write failed id=%s event=%s error=%v", sessionID, websocketPayloadEventType(payloads[i]), errWrite, ) cancel(errWrite) return completedOutput, errWrite } } } } } func responseCompletedOutputFromPayload(payload []byte) []byte { output := gjson.GetBytes(payload, "response.output") if output.Exists() && output.IsArray() { return bytes.Clone([]byte(output.Raw)) } return []byte("[]") } func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte { payloads := make([][]byte, 0, 2) lines := bytes.Split(chunk, []byte("\n")) for i := range lines { line := 
bytes.TrimSpace(lines[i]) if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) { continue } if bytes.HasPrefix(line, []byte("data:")) { line = bytes.TrimSpace(line[len("data:"):]) } if len(line) == 0 || bytes.Equal(line, []byte(wsDoneMarker)) { continue } if json.Valid(line) { payloads = append(payloads, bytes.Clone(line)) } } if len(payloads) > 0 { return payloads } trimmed := bytes.TrimSpace(chunk) if bytes.HasPrefix(trimmed, []byte("data:")) { trimmed = bytes.TrimSpace(trimmed[len("data:"):]) } if len(trimmed) > 0 && !bytes.Equal(trimmed, []byte(wsDoneMarker)) && json.Valid(trimmed) { payloads = append(payloads, bytes.Clone(trimmed)) } return payloads } func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) { status := http.StatusInternalServerError errText := http.StatusText(status) if errMsg != nil { if errMsg.StatusCode > 0 { status = errMsg.StatusCode errText = http.StatusText(status) } if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" { errText = errMsg.Error.Error() } } body := handlers.BuildErrorResponseBody(status, errText) payload := map[string]any{ "type": wsEventTypeError, "status": status, } if errMsg != nil && errMsg.Addon != nil { headers := map[string]any{} for key, values := range errMsg.Addon { if len(values) == 0 { continue } headers[key] = values[0] } if len(headers) > 0 { payload["headers"] = headers } } if len(body) > 0 && json.Valid(body) { var decoded map[string]any if errDecode := json.Unmarshal(body, &decoded); errDecode == nil { if inner, ok := decoded["error"]; ok { payload["error"] = inner } else { payload["error"] = decoded } } } if _, ok := payload["error"]; !ok { payload["error"] = map[string]any{ "type": "server_error", "message": errText, } } data, err := json.Marshal(payload) if err != nil { return nil, err } return data, conn.WriteMessage(websocket.TextMessage, data) } func appendWebsocketEvent(builder *strings.Builder, eventType string, payload 
[]byte) { if builder == nil { return } trimmedPayload := bytes.TrimSpace(payload) if len(trimmedPayload) == 0 { return } if builder.Len() > 0 { builder.WriteString("\n") } builder.WriteString("websocket.") builder.WriteString(eventType) builder.WriteString("\n") builder.Write(trimmedPayload) builder.WriteString("\n") } func websocketPayloadEventType(payload []byte) string { eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String()) if eventType == "" { return "-" } return eventType } func websocketPayloadPreview(payload []byte) string { trimmedPayload := bytes.TrimSpace(payload) if len(trimmedPayload) == 0 { return "" } preview := trimmedPayload if len(preview) > wsPayloadLogMaxSize { preview = preview[:wsPayloadLogMaxSize] } previewText := strings.ReplaceAll(string(preview), "\n", "\\n") previewText = strings.ReplaceAll(previewText, "\r", "\\r") if len(trimmedPayload) > wsPayloadLogMaxSize { return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload)) } return previewText } func setWebsocketRequestBody(c *gin.Context, body string) { if c == nil { return } trimmedBody := strings.TrimSpace(body) if trimmedBody == "" { return } c.Set(wsRequestBodyKey, []byte(trimmedBody)) } func markAPIResponseTimestamp(c *gin.Context) { if c == nil { return } if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); exists { return } c.Set("API_RESPONSE_TIMESTAMP", time.Now()) }