From 8aaed4cf09c6290bdce14b8d39717ff3a27bf786 Mon Sep 17 00:00:00 2001 From: hkfires <10558748+hkfires@users.noreply.github.com> Date: Sat, 25 Oct 2025 15:57:27 +0800 Subject: [PATCH] feat(aistudio): support non-streaming responses --- internal/wsrelay/http.go | 50 ++++++++++++++++++++++++++++++++++++++-- 1 file changed, 48 insertions(+), 2 deletions(-) diff --git a/internal/wsrelay/http.go b/internal/wsrelay/http.go index 96f80ec3..f34a61ca 100644 --- a/internal/wsrelay/http.go +++ b/internal/wsrelay/http.go @@ -1,6 +1,7 @@ package wsrelay import ( + "bytes" "context" "errors" "fmt" @@ -44,21 +45,66 @@ func (m *Manager) RoundTrip(ctx context.Context, provider string, req *HTTPReque if err != nil { return nil, err } + var ( + streamMode bool + streamResp *HTTPResponse + streamBody bytes.Buffer + ) for { select { case <-ctx.Done(): return nil, ctx.Err() case msg, ok := <-respCh: if !ok { + if streamMode { + if streamResp == nil { + streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} + } else if streamResp.Headers == nil { + streamResp.Headers = make(http.Header) + } + streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...) + return streamResp, nil + } return nil, errors.New("wsrelay: connection closed during response") } switch msg.Type { case MessageTypeHTTPResp: - return decodeResponse(msg.Payload), nil + resp := decodeResponse(msg.Payload) + if streamMode && streamBody.Len() > 0 && len(resp.Body) == 0 { + resp.Body = append(resp.Body[:0], streamBody.Bytes()...) + } + return resp, nil case MessageTypeError: return nil, decodeError(msg.Payload) case MessageTypeStreamStart, MessageTypeStreamChunk: - // Ignore streaming noise in non-stream requests. + if msg.Type == MessageTypeStreamStart { + streamMode = true + streamResp = decodeResponse(msg.Payload) + if streamResp.Headers == nil { + streamResp.Headers = make(http.Header) + } + streamBody.Reset() + continue + } + if !streamMode { + streamMode = true + streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} + } + chunk := decodeChunk(msg.Payload) + if len(chunk) > 0 { + streamBody.Write(chunk) + } + case MessageTypeStreamEnd: + if !streamMode { + return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil + } + if streamResp == nil { + streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)} + } else if streamResp.Headers == nil { + streamResp.Headers = make(http.Header) + } + streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...) + return streamResp, nil default: } }