Files
CLIProxyAPI/internal/wsrelay/http.go
Luis Pater e93e05ae25 refactor: consolidate channel send logic with context-safe handlers
Optimize channel operations by introducing reusable context-aware send functions (`send` and `sendErr`) across `wsrelay`, `handlers`, and `cliproxy`. Ensure graceful handling of canceled contexts during stream operations.
2026-01-28 10:58:35 +08:00

249 lines
6.4 KiB
Go

package wsrelay
import (
"bytes"
"context"
"errors"
"fmt"
"net/http"
"time"
"github.com/google/uuid"
)
// HTTPRequest represents a proxied HTTP request delivered to websocket clients.
type HTTPRequest struct {
Method string
URL string
Headers http.Header
Body []byte
}
// HTTPResponse captures the response relayed back from websocket clients.
type HTTPResponse struct {
Status int
Headers http.Header
Body []byte
}
// StreamEvent represents a streaming response event from clients.
type StreamEvent struct {
Type string
Payload []byte
Status int
Headers http.Header
Err error
}
// NonStream executes a non-streaming HTTP request using the websocket provider.
func (m *Manager) NonStream(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) {
if req == nil {
return nil, fmt.Errorf("wsrelay: request is nil")
}
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
respCh, err := m.Send(ctx, provider, msg)
if err != nil {
return nil, err
}
var (
streamMode bool
streamResp *HTTPResponse
streamBody bytes.Buffer
)
for {
select {
case <-ctx.Done():
return nil, ctx.Err()
case msg, ok := <-respCh:
if !ok {
if streamMode {
if streamResp == nil {
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
} else if streamResp.Headers == nil {
streamResp.Headers = make(http.Header)
}
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
return streamResp, nil
}
return nil, errors.New("wsrelay: connection closed during response")
}
switch msg.Type {
case MessageTypeHTTPResp:
resp := decodeResponse(msg.Payload)
if streamMode && streamBody.Len() > 0 && len(resp.Body) == 0 {
resp.Body = append(resp.Body[:0], streamBody.Bytes()...)
}
return resp, nil
case MessageTypeError:
return nil, decodeError(msg.Payload)
case MessageTypeStreamStart, MessageTypeStreamChunk:
if msg.Type == MessageTypeStreamStart {
streamMode = true
streamResp = decodeResponse(msg.Payload)
if streamResp.Headers == nil {
streamResp.Headers = make(http.Header)
}
streamBody.Reset()
continue
}
if !streamMode {
streamMode = true
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
}
chunk := decodeChunk(msg.Payload)
if len(chunk) > 0 {
streamBody.Write(chunk)
}
case MessageTypeStreamEnd:
if !streamMode {
return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil
}
if streamResp == nil {
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
} else if streamResp.Headers == nil {
streamResp.Headers = make(http.Header)
}
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
return streamResp, nil
default:
}
}
}
}
// Stream executes a streaming HTTP request and returns channel with stream events.
func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) {
if req == nil {
return nil, fmt.Errorf("wsrelay: request is nil")
}
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
respCh, err := m.Send(ctx, provider, msg)
if err != nil {
return nil, err
}
out := make(chan StreamEvent)
go func() {
defer close(out)
send := func(ev StreamEvent) bool {
if ctx == nil {
out <- ev
return true
}
select {
case <-ctx.Done():
return false
case out <- ev:
return true
}
}
for {
select {
case <-ctx.Done():
return
case msg, ok := <-respCh:
if !ok {
_ = send(StreamEvent{Err: errors.New("wsrelay: stream closed")})
return
}
switch msg.Type {
case MessageTypeStreamStart:
resp := decodeResponse(msg.Payload)
if okSend := send(StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers}); !okSend {
return
}
case MessageTypeStreamChunk:
chunk := decodeChunk(msg.Payload)
if okSend := send(StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk}); !okSend {
return
}
case MessageTypeStreamEnd:
_ = send(StreamEvent{Type: MessageTypeStreamEnd})
return
case MessageTypeError:
_ = send(StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)})
return
case MessageTypeHTTPResp:
resp := decodeResponse(msg.Payload)
_ = send(StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body})
return
default:
}
}
}
}()
return out, nil
}
func encodeRequest(req *HTTPRequest) map[string]any {
headers := make(map[string]any, len(req.Headers))
for key, values := range req.Headers {
copyValues := make([]string, len(values))
copy(copyValues, values)
headers[key] = copyValues
}
return map[string]any{
"method": req.Method,
"url": req.URL,
"headers": headers,
"body": string(req.Body),
"sent_at": time.Now().UTC().Format(time.RFC3339Nano),
}
}
func decodeResponse(payload map[string]any) *HTTPResponse {
if payload == nil {
return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)}
}
resp := &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
if status, ok := payload["status"].(float64); ok {
resp.Status = int(status)
}
if headers, ok := payload["headers"].(map[string]any); ok {
for key, raw := range headers {
switch v := raw.(type) {
case []any:
for _, item := range v {
if str, ok := item.(string); ok {
resp.Headers.Add(key, str)
}
}
case []string:
for _, str := range v {
resp.Headers.Add(key, str)
}
case string:
resp.Headers.Set(key, v)
}
}
}
if body, ok := payload["body"].(string); ok {
resp.Body = []byte(body)
}
return resp
}
func decodeChunk(payload map[string]any) []byte {
if payload == nil {
return nil
}
if data, ok := payload["data"].(string); ok {
return []byte(data)
}
return nil
}
func decodeError(payload map[string]any) error {
if payload == nil {
return errors.New("wsrelay: unknown error")
}
message, _ := payload["error"].(string)
status := 0
if v, ok := payload["status"].(float64); ok {
status = int(v)
}
if message == "" {
message = "wsrelay: upstream error"
}
return fmt.Errorf("%s (status=%d)", message, status)
}