mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat(aistudio): support non-streaming responses
This commit is contained in:
@@ -1,6 +1,7 @@
|
|||||||
package wsrelay
|
package wsrelay
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -44,21 +45,66 @@ func (m *Manager) RoundTrip(ctx context.Context, provider string, req *HTTPReque
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
var (
|
||||||
|
streamMode bool
|
||||||
|
streamResp *HTTPResponse
|
||||||
|
streamBody bytes.Buffer
|
||||||
|
)
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
return nil, ctx.Err()
|
return nil, ctx.Err()
|
||||||
case msg, ok := <-respCh:
|
case msg, ok := <-respCh:
|
||||||
if !ok {
|
if !ok {
|
||||||
|
if streamMode {
|
||||||
|
if streamResp == nil {
|
||||||
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||||
|
} else if streamResp.Headers == nil {
|
||||||
|
streamResp.Headers = make(http.Header)
|
||||||
|
}
|
||||||
|
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
|
||||||
|
return streamResp, nil
|
||||||
|
}
|
||||||
return nil, errors.New("wsrelay: connection closed during response")
|
return nil, errors.New("wsrelay: connection closed during response")
|
||||||
}
|
}
|
||||||
switch msg.Type {
|
switch msg.Type {
|
||||||
case MessageTypeHTTPResp:
|
case MessageTypeHTTPResp:
|
||||||
return decodeResponse(msg.Payload), nil
|
resp := decodeResponse(msg.Payload)
|
||||||
|
if streamMode && streamBody.Len() > 0 && len(resp.Body) == 0 {
|
||||||
|
resp.Body = append(resp.Body[:0], streamBody.Bytes()...)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
case MessageTypeError:
|
case MessageTypeError:
|
||||||
return nil, decodeError(msg.Payload)
|
return nil, decodeError(msg.Payload)
|
||||||
case MessageTypeStreamStart, MessageTypeStreamChunk:
|
case MessageTypeStreamStart, MessageTypeStreamChunk:
|
||||||
// Ignore streaming noise in non-stream requests.
|
if msg.Type == MessageTypeStreamStart {
|
||||||
|
streamMode = true
|
||||||
|
streamResp = decodeResponse(msg.Payload)
|
||||||
|
if streamResp.Headers == nil {
|
||||||
|
streamResp.Headers = make(http.Header)
|
||||||
|
}
|
||||||
|
streamBody.Reset()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !streamMode {
|
||||||
|
streamMode = true
|
||||||
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||||
|
}
|
||||||
|
chunk := decodeChunk(msg.Payload)
|
||||||
|
if len(chunk) > 0 {
|
||||||
|
streamBody.Write(chunk)
|
||||||
|
}
|
||||||
|
case MessageTypeStreamEnd:
|
||||||
|
if !streamMode {
|
||||||
|
return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil
|
||||||
|
}
|
||||||
|
if streamResp == nil {
|
||||||
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||||
|
} else if streamResp.Headers == nil {
|
||||||
|
streamResp.Headers = make(http.Header)
|
||||||
|
}
|
||||||
|
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
|
||||||
|
return streamResp, nil
|
||||||
default:
|
default:
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user