feat(aistudio): support non-streaming responses

This commit is contained in:
hkfires
2025-10-25 15:57:27 +08:00
parent c32e013605
commit 8aaed4cf09

View File

@@ -1,6 +1,7 @@
package wsrelay package wsrelay
import ( import (
"bytes"
"context" "context"
"errors" "errors"
"fmt" "fmt"
@@ -44,21 +45,66 @@ func (m *Manager) RoundTrip(ctx context.Context, provider string, req *HTTPReque
if err != nil { if err != nil {
return nil, err return nil, err
} }
var (
streamMode bool
streamResp *HTTPResponse
streamBody bytes.Buffer
)
for { for {
select { select {
case <-ctx.Done(): case <-ctx.Done():
return nil, ctx.Err() return nil, ctx.Err()
case msg, ok := <-respCh: case msg, ok := <-respCh:
if !ok { if !ok {
if streamMode {
if streamResp == nil {
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
} else if streamResp.Headers == nil {
streamResp.Headers = make(http.Header)
}
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
return streamResp, nil
}
return nil, errors.New("wsrelay: connection closed during response") return nil, errors.New("wsrelay: connection closed during response")
} }
switch msg.Type { switch msg.Type {
case MessageTypeHTTPResp: case MessageTypeHTTPResp:
return decodeResponse(msg.Payload), nil resp := decodeResponse(msg.Payload)
if streamMode && streamBody.Len() > 0 && len(resp.Body) == 0 {
resp.Body = append(resp.Body[:0], streamBody.Bytes()...)
}
return resp, nil
case MessageTypeError: case MessageTypeError:
return nil, decodeError(msg.Payload) return nil, decodeError(msg.Payload)
case MessageTypeStreamStart, MessageTypeStreamChunk: case MessageTypeStreamStart, MessageTypeStreamChunk:
// Ignore streaming noise in non-stream requests. if msg.Type == MessageTypeStreamStart {
streamMode = true
streamResp = decodeResponse(msg.Payload)
if streamResp.Headers == nil {
streamResp.Headers = make(http.Header)
}
streamBody.Reset()
continue
}
if !streamMode {
streamMode = true
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
}
chunk := decodeChunk(msg.Payload)
if len(chunk) > 0 {
streamBody.Write(chunk)
}
case MessageTypeStreamEnd:
if !streamMode {
return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil
}
if streamResp == nil {
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
} else if streamResp.Headers == nil {
streamResp.Headers = make(http.Header)
}
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
return streamResp, nil
default: default:
} }
} }