mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-20 13:20:52 +08:00
feat(logging, executor): add request logging tests and WebSocket-based Codex executor
- Introduced unit tests for request logging middleware to enhance coverage. - Added WebSocket-based Codex executor to support Responses API upgrade. - Updated middleware logic to selectively capture request bodies for memory efficiency. - Enhanced Codex configuration handling with new WebSocket attributes.
This commit is contained in:
662
sdk/api/handlers/openai/openai_responses_websocket.go
Normal file
662
sdk/api/handlers/openai/openai_responses_websocket.go
Normal file
@@ -0,0 +1,662 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
wsRequestTypeCreate = "response.create"
|
||||
wsRequestTypeAppend = "response.append"
|
||||
wsEventTypeError = "error"
|
||||
wsEventTypeCompleted = "response.completed"
|
||||
wsEventTypeDone = "response.done"
|
||||
wsDoneMarker = "[DONE]"
|
||||
wsTurnStateHeader = "x-codex-turn-state"
|
||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||
wsPayloadLogMaxSize = 2048
|
||||
)
|
||||
|
||||
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
// ResponsesWebsocket handles websocket requests for /v1/responses.
|
||||
// It accepts `response.create` and `response.append` requests and streams
|
||||
// response events back as JSON websocket text messages.
|
||||
func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
passthroughSessionID := uuid.NewString()
|
||||
clientRemoteAddr := ""
|
||||
if c != nil && c.Request != nil {
|
||||
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
||||
}
|
||||
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
|
||||
var wsTerminateErr error
|
||||
var wsBodyLog strings.Builder
|
||||
defer func() {
|
||||
if wsTerminateErr != nil {
|
||||
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
|
||||
} else {
|
||||
log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
|
||||
}
|
||||
if h != nil && h.AuthManager != nil {
|
||||
h.AuthManager.CloseExecutionSession(passthroughSessionID)
|
||||
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
|
||||
}
|
||||
setWebsocketRequestBody(c, wsBodyLog.String())
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Warnf("responses websocket: close connection error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
var lastRequest []byte
|
||||
lastResponseOutput := []byte("[]")
|
||||
pinnedAuthID := ""
|
||||
|
||||
for {
|
||||
msgType, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
wsTerminateErr = errReadMessage
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
|
||||
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
|
||||
} else {
|
||||
// log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage)
|
||||
}
|
||||
return
|
||||
}
|
||||
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||
continue
|
||||
}
|
||||
// log.Infof(
|
||||
// "responses websocket: downstream_in id=%s type=%d event=%s payload=%s",
|
||||
// passthroughSessionID,
|
||||
// msgType,
|
||||
// websocketPayloadEventType(payload),
|
||||
// websocketPayloadPreview(payload),
|
||||
// )
|
||||
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
||||
|
||||
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
|
||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
var requestJSON []byte
|
||||
var updatedLastRequest []byte
|
||||
var errMsg *interfaces.ErrorMessage
|
||||
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode(
|
||||
payload,
|
||||
lastRequest,
|
||||
lastResponseOutput,
|
||||
allowIncrementalInputWithPreviousResponseID,
|
||||
)
|
||||
if errMsg != nil {
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
passthroughSessionID,
|
||||
websocket.TextMessage,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
websocketPayloadPreview(errorPayload),
|
||||
)
|
||||
if errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
passthroughSessionID,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
errWrite,
|
||||
)
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
lastRequest = updatedLastRequest
|
||||
|
||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx)
|
||||
cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID)
|
||||
if pinnedAuthID != "" {
|
||||
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
|
||||
} else {
|
||||
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
|
||||
pinnedAuthID = strings.TrimSpace(authID)
|
||||
})
|
||||
}
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
|
||||
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
|
||||
if errForward != nil {
|
||||
wsTerminateErr = errForward
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
|
||||
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
|
||||
return
|
||||
}
|
||||
lastResponseOutput = completedOutput
|
||||
}
|
||||
}
|
||||
|
||||
func websocketUpgradeHeaders(req *http.Request) http.Header {
|
||||
headers := http.Header{}
|
||||
if req == nil {
|
||||
return headers
|
||||
}
|
||||
|
||||
// Keep the same sticky turn-state across reconnects when provided by the client.
|
||||
turnState := strings.TrimSpace(req.Header.Get(wsTurnStateHeader))
|
||||
if turnState != "" {
|
||||
headers.Set(wsTurnStateHeader, turnState)
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
|
||||
}
|
||||
|
||||
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
||||
switch requestType {
|
||||
case wsRequestTypeCreate:
|
||||
// log.Infof("responses websocket: response.create request")
|
||||
if len(lastRequest) == 0 {
|
||||
return normalizeResponseCreateRequest(rawJSON)
|
||||
}
|
||||
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
||||
case wsRequestTypeAppend:
|
||||
// log.Infof("responses websocket: response.append request")
|
||||
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
||||
default:
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("unsupported websocket request type: %s", requestType),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
normalized = bytes.Clone(rawJSON)
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||
if !gjson.GetBytes(normalized, "input").Exists() {
|
||||
normalized, _ = sjson.SetRawBytes(normalized, "input", []byte("[]"))
|
||||
}
|
||||
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(normalized, "model").String())
|
||||
if modelName == "" {
|
||||
return nil, nil, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("missing model in response.create request"),
|
||||
}
|
||||
}
|
||||
return normalized, bytes.Clone(normalized), nil
|
||||
}
|
||||
|
||||
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
if len(lastRequest) == 0 {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("websocket request received before response.create"),
|
||||
}
|
||||
}
|
||||
|
||||
nextInput := gjson.GetBytes(rawJSON, "input")
|
||||
if !nextInput.Exists() || !nextInput.IsArray() {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("websocket request requires array field: input"),
|
||||
}
|
||||
}
|
||||
|
||||
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
|
||||
// Do not expand it into a full input transcript; upstream expects the incremental payload.
|
||||
if allowIncrementalInputWithPreviousResponseID {
|
||||
if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" {
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
normalized = bytes.Clone(rawJSON)
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
if modelName != "" {
|
||||
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||
}
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||
if instructions.Exists() {
|
||||
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||
}
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||
return normalized, bytes.Clone(normalized), nil
|
||||
}
|
||||
}
|
||||
|
||||
existingInput := gjson.GetBytes(lastRequest, "input")
|
||||
mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
|
||||
if errMerge != nil {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
|
||||
}
|
||||
}
|
||||
|
||||
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
|
||||
if errMerge != nil {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("invalid request input: %w", errMerge),
|
||||
}
|
||||
}
|
||||
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
normalized = bytes.Clone(rawJSON)
|
||||
}
|
||||
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
|
||||
var errSet error
|
||||
normalized, errSet = sjson.SetRawBytes(normalized, "input", []byte(mergedInput))
|
||||
if errSet != nil {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("failed to merge websocket input: %w", errSet),
|
||||
}
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
if modelName != "" {
|
||||
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||
}
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||
if instructions.Exists() {
|
||||
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||
}
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||
return normalized, bytes.Clone(normalized), nil
|
||||
}
|
||||
|
||||
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
|
||||
if len(attributes) > 0 {
|
||||
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
|
||||
parsed, errParse := strconv.ParseBool(raw)
|
||||
if errParse == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(metadata) == 0 {
|
||||
return false
|
||||
}
|
||||
raw, ok := metadata["websockets"]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case bool:
|
||||
return value
|
||||
case string:
|
||||
parsed, errParse := strconv.ParseBool(strings.TrimSpace(value))
|
||||
if errParse == nil {
|
||||
return parsed
|
||||
}
|
||||
default:
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
|
||||
existingRaw = strings.TrimSpace(existingRaw)
|
||||
appendRaw = strings.TrimSpace(appendRaw)
|
||||
if existingRaw == "" {
|
||||
existingRaw = "[]"
|
||||
}
|
||||
if appendRaw == "" {
|
||||
appendRaw = "[]"
|
||||
}
|
||||
|
||||
var existing []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(existingRaw), &existing); err != nil {
|
||||
return "", err
|
||||
}
|
||||
var appendItems []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(appendRaw), &appendItems); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
merged := append(existing, appendItems...)
|
||||
out, err := json.Marshal(merged)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func normalizeJSONArrayRaw(raw []byte) string {
|
||||
trimmed := strings.TrimSpace(string(raw))
|
||||
if trimmed == "" {
|
||||
return "[]"
|
||||
}
|
||||
result := gjson.Parse(trimmed)
|
||||
if result.Type == gjson.JSON && result.IsArray() {
|
||||
return trimmed
|
||||
}
|
||||
return "[]"
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
c *gin.Context,
|
||||
conn *websocket.Conn,
|
||||
cancel handlers.APIHandlerCancelFunc,
|
||||
data <-chan []byte,
|
||||
errs <-chan *interfaces.ErrorMessage,
|
||||
wsBodyLog *strings.Builder,
|
||||
sessionID string,
|
||||
) ([]byte, error) {
|
||||
completed := false
|
||||
completedOutput := []byte("[]")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return completedOutput, c.Request.Context().Err()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
errs = nil
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
sessionID,
|
||||
websocket.TextMessage,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
websocketPayloadPreview(errorPayload),
|
||||
)
|
||||
if errWrite != nil {
|
||||
// log.Warnf(
|
||||
// "responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
// sessionID,
|
||||
// websocketPayloadEventType(errorPayload),
|
||||
// errWrite,
|
||||
// )
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, errWrite
|
||||
}
|
||||
}
|
||||
if errMsg != nil {
|
||||
cancel(errMsg.Error)
|
||||
} else {
|
||||
cancel(nil)
|
||||
}
|
||||
return completedOutput, nil
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
if !completed {
|
||||
errMsg := &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusRequestTimeout,
|
||||
Error: fmt.Errorf("stream closed before response.completed"),
|
||||
}
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
sessionID,
|
||||
websocket.TextMessage,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
websocketPayloadPreview(errorPayload),
|
||||
)
|
||||
if errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
sessionID,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
errWrite,
|
||||
)
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, errWrite
|
||||
}
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, nil
|
||||
}
|
||||
cancel(nil)
|
||||
return completedOutput, nil
|
||||
}
|
||||
|
||||
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||
for i := range payloads {
|
||||
eventType := gjson.GetBytes(payloads[i], "type").String()
|
||||
if eventType == wsEventTypeCompleted {
|
||||
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
|
||||
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
|
||||
|
||||
completed = true
|
||||
completedOutput = responseCompletedOutputFromPayload(payloads[i])
|
||||
}
|
||||
markAPIResponseTimestamp(c)
|
||||
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
||||
// log.Infof(
|
||||
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
// sessionID,
|
||||
// websocket.TextMessage,
|
||||
// websocketPayloadEventType(payloads[i]),
|
||||
// websocketPayloadPreview(payloads[i]),
|
||||
// )
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
sessionID,
|
||||
websocketPayloadEventType(payloads[i]),
|
||||
errWrite,
|
||||
)
|
||||
cancel(errWrite)
|
||||
return completedOutput, errWrite
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func responseCompletedOutputFromPayload(payload []byte) []byte {
|
||||
output := gjson.GetBytes(payload, "response.output")
|
||||
if output.Exists() && output.IsArray() {
|
||||
return bytes.Clone([]byte(output.Raw))
|
||||
}
|
||||
return []byte("[]")
|
||||
}
|
||||
|
||||
func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
|
||||
payloads := make([][]byte, 0, 2)
|
||||
lines := bytes.Split(chunk, []byte("\n"))
|
||||
for i := range lines {
|
||||
line := bytes.TrimSpace(lines[i])
|
||||
if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(line, []byte("data:")) {
|
||||
line = bytes.TrimSpace(line[len("data:"):])
|
||||
}
|
||||
if len(line) == 0 || bytes.Equal(line, []byte(wsDoneMarker)) {
|
||||
continue
|
||||
}
|
||||
if json.Valid(line) {
|
||||
payloads = append(payloads, bytes.Clone(line))
|
||||
}
|
||||
}
|
||||
|
||||
if len(payloads) > 0 {
|
||||
return payloads
|
||||
}
|
||||
|
||||
trimmed := bytes.TrimSpace(chunk)
|
||||
if bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||
trimmed = bytes.TrimSpace(trimmed[len("data:"):])
|
||||
}
|
||||
if len(trimmed) > 0 && !bytes.Equal(trimmed, []byte(wsDoneMarker)) && json.Valid(trimmed) {
|
||||
payloads = append(payloads, bytes.Clone(trimmed))
|
||||
}
|
||||
return payloads
|
||||
}
|
||||
|
||||
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
|
||||
status := http.StatusInternalServerError
|
||||
errText := http.StatusText(status)
|
||||
if errMsg != nil {
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
errText = http.StatusText(status)
|
||||
}
|
||||
if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
}
|
||||
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
payload := map[string]any{
|
||||
"type": wsEventTypeError,
|
||||
"status": status,
|
||||
}
|
||||
|
||||
if errMsg != nil && errMsg.Addon != nil {
|
||||
headers := map[string]any{}
|
||||
for key, values := range errMsg.Addon {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
headers[key] = values[0]
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
payload["headers"] = headers
|
||||
}
|
||||
}
|
||||
|
||||
if len(body) > 0 && json.Valid(body) {
|
||||
var decoded map[string]any
|
||||
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
|
||||
if inner, ok := decoded["error"]; ok {
|
||||
payload["error"] = inner
|
||||
} else {
|
||||
payload["error"] = decoded
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := payload["error"]; !ok {
|
||||
payload["error"] = map[string]any{
|
||||
"type": "server_error",
|
||||
"message": errText,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
||||
if builder == nil {
|
||||
return
|
||||
}
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) == 0 {
|
||||
return
|
||||
}
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString("websocket.")
|
||||
builder.WriteString(eventType)
|
||||
builder.WriteString("\n")
|
||||
builder.Write(trimmedPayload)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
func websocketPayloadEventType(payload []byte) string {
|
||||
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
||||
if eventType == "" {
|
||||
return "-"
|
||||
}
|
||||
return eventType
|
||||
}
|
||||
|
||||
func websocketPayloadPreview(payload []byte) string {
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
preview := trimmedPayload
|
||||
if len(preview) > wsPayloadLogMaxSize {
|
||||
preview = preview[:wsPayloadLogMaxSize]
|
||||
}
|
||||
previewText := strings.ReplaceAll(string(preview), "\n", "\\n")
|
||||
previewText = strings.ReplaceAll(previewText, "\r", "\\r")
|
||||
if len(trimmedPayload) > wsPayloadLogMaxSize {
|
||||
return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload))
|
||||
}
|
||||
return previewText
|
||||
}
|
||||
|
||||
func setWebsocketRequestBody(c *gin.Context, body string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
trimmedBody := strings.TrimSpace(body)
|
||||
if trimmedBody == "" {
|
||||
return
|
||||
}
|
||||
c.Set(wsRequestBodyKey, []byte(trimmedBody))
|
||||
}
|
||||
|
||||
func markAPIResponseTimestamp(c *gin.Context) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); exists {
|
||||
return
|
||||
}
|
||||
c.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
||||
}
|
||||
Reference in New Issue
Block a user