mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
234 lines
6.1 KiB
Go
234 lines
6.1 KiB
Go
package wsrelay
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"time"
|
|
|
|
"github.com/google/uuid"
|
|
)
|
|
|
|
// Types exchanged with websocket clients when proxying HTTP traffic.
type (
	// HTTPRequest represents a proxied HTTP request delivered to websocket clients.
	HTTPRequest struct {
		Method, URL string      // HTTP method and target URL
		Headers     http.Header // request headers
		Body        []byte      // raw request body
	}

	// HTTPResponse captures the response relayed back from websocket clients.
	HTTPResponse struct {
		Status  int         // HTTP status code
		Headers http.Header // response headers
		Body    []byte      // response body
	}

	// StreamEvent represents a streaming response event from clients.
	StreamEvent struct {
		Type    string      // message type that produced the event
		Payload []byte      // chunk data, or the body of a plain HTTP response
		Status  int         // status code, set for stream-start and HTTP-response events
		Headers http.Header // headers, set for stream-start and HTTP-response events
		Err     error       // non-nil for error events, cancellation, or a closed stream
	}
)
|
|
|
|
// NonStream executes a non-streaming HTTP request using the websocket provider.
|
|
func (m *Manager) NonStream(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("wsrelay: request is nil")
|
|
}
|
|
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
|
|
respCh, err := m.Send(ctx, provider, msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
var (
|
|
streamMode bool
|
|
streamResp *HTTPResponse
|
|
streamBody bytes.Buffer
|
|
)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
return nil, ctx.Err()
|
|
case msg, ok := <-respCh:
|
|
if !ok {
|
|
if streamMode {
|
|
if streamResp == nil {
|
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
|
} else if streamResp.Headers == nil {
|
|
streamResp.Headers = make(http.Header)
|
|
}
|
|
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
|
|
return streamResp, nil
|
|
}
|
|
return nil, errors.New("wsrelay: connection closed during response")
|
|
}
|
|
switch msg.Type {
|
|
case MessageTypeHTTPResp:
|
|
resp := decodeResponse(msg.Payload)
|
|
if streamMode && streamBody.Len() > 0 && len(resp.Body) == 0 {
|
|
resp.Body = append(resp.Body[:0], streamBody.Bytes()...)
|
|
}
|
|
return resp, nil
|
|
case MessageTypeError:
|
|
return nil, decodeError(msg.Payload)
|
|
case MessageTypeStreamStart, MessageTypeStreamChunk:
|
|
if msg.Type == MessageTypeStreamStart {
|
|
streamMode = true
|
|
streamResp = decodeResponse(msg.Payload)
|
|
if streamResp.Headers == nil {
|
|
streamResp.Headers = make(http.Header)
|
|
}
|
|
streamBody.Reset()
|
|
continue
|
|
}
|
|
if !streamMode {
|
|
streamMode = true
|
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
|
}
|
|
chunk := decodeChunk(msg.Payload)
|
|
if len(chunk) > 0 {
|
|
streamBody.Write(chunk)
|
|
}
|
|
case MessageTypeStreamEnd:
|
|
if !streamMode {
|
|
return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil
|
|
}
|
|
if streamResp == nil {
|
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
|
} else if streamResp.Headers == nil {
|
|
streamResp.Headers = make(http.Header)
|
|
}
|
|
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
|
|
return streamResp, nil
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
// Stream executes a streaming HTTP request and returns channel with stream events.
|
|
func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) {
|
|
if req == nil {
|
|
return nil, fmt.Errorf("wsrelay: request is nil")
|
|
}
|
|
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
|
|
respCh, err := m.Send(ctx, provider, msg)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
out := make(chan StreamEvent)
|
|
go func() {
|
|
defer close(out)
|
|
for {
|
|
select {
|
|
case <-ctx.Done():
|
|
out <- StreamEvent{Err: ctx.Err()}
|
|
return
|
|
case msg, ok := <-respCh:
|
|
if !ok {
|
|
out <- StreamEvent{Err: errors.New("wsrelay: stream closed")}
|
|
return
|
|
}
|
|
switch msg.Type {
|
|
case MessageTypeStreamStart:
|
|
resp := decodeResponse(msg.Payload)
|
|
out <- StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers}
|
|
case MessageTypeStreamChunk:
|
|
chunk := decodeChunk(msg.Payload)
|
|
out <- StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk}
|
|
case MessageTypeStreamEnd:
|
|
out <- StreamEvent{Type: MessageTypeStreamEnd}
|
|
return
|
|
case MessageTypeError:
|
|
out <- StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)}
|
|
return
|
|
case MessageTypeHTTPResp:
|
|
resp := decodeResponse(msg.Payload)
|
|
out <- StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body}
|
|
return
|
|
default:
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
return out, nil
|
|
}
|
|
|
|
func encodeRequest(req *HTTPRequest) map[string]any {
|
|
headers := make(map[string]any, len(req.Headers))
|
|
for key, values := range req.Headers {
|
|
copyValues := make([]string, len(values))
|
|
copy(copyValues, values)
|
|
headers[key] = copyValues
|
|
}
|
|
return map[string]any{
|
|
"method": req.Method,
|
|
"url": req.URL,
|
|
"headers": headers,
|
|
"body": string(req.Body),
|
|
"sent_at": time.Now().UTC().Format(time.RFC3339Nano),
|
|
}
|
|
}
|
|
|
|
func decodeResponse(payload map[string]any) *HTTPResponse {
|
|
if payload == nil {
|
|
return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)}
|
|
}
|
|
resp := &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
|
if status, ok := payload["status"].(float64); ok {
|
|
resp.Status = int(status)
|
|
}
|
|
if headers, ok := payload["headers"].(map[string]any); ok {
|
|
for key, raw := range headers {
|
|
switch v := raw.(type) {
|
|
case []any:
|
|
for _, item := range v {
|
|
if str, ok := item.(string); ok {
|
|
resp.Headers.Add(key, str)
|
|
}
|
|
}
|
|
case []string:
|
|
for _, str := range v {
|
|
resp.Headers.Add(key, str)
|
|
}
|
|
case string:
|
|
resp.Headers.Set(key, v)
|
|
}
|
|
}
|
|
}
|
|
if body, ok := payload["body"].(string); ok {
|
|
resp.Body = []byte(body)
|
|
}
|
|
return resp
|
|
}
|
|
|
|
// decodeChunk extracts the bytes of a stream-chunk payload from its string
// "data" field. It returns nil when payload is nil or when "data" is missing
// or not a string.
func decodeChunk(payload map[string]any) []byte {
	// Indexing a nil map is safe in Go and yields the zero value, so a nil
	// payload falls through to the !ok branch.
	data, ok := payload["data"].(string)
	if !ok {
		return nil
	}
	return []byte(data)
}
|
|
|
|
// decodeError converts an error payload into a Go error.
//
// The payload is handled as follows:
//   - A nil payload yields "wsrelay: unknown error".
//   - Otherwise the message comes from the string "error" field, falling back
//     to "wsrelay: upstream error" when that field is empty or absent.
//   - The message is suffixed with "(status=N)". N is the float64 "status"
//     field truncated to int, or 0 when that field is absent or not a float64.
func decodeError(payload map[string]any) error {
	if payload == nil {
		return errors.New("wsrelay: unknown error")
	}
	var code int
	if raw, ok := payload["status"].(float64); ok {
		code = int(raw)
	}
	text, _ := payload["error"].(string)
	if text == "" {
		text = "wsrelay: upstream error"
	}
	return fmt.Errorf("%s (status=%d)", text, code)
}
|