mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-18 12:20:52 +08:00
feat: add websocket routing and executor unregister API
- Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket upgrade handlers on the Gin engine. - Track registered WS paths via wsRoutes with wsRouteMu to prevent duplicate registrations; initialize in NewServer and import sync. - Add Manager.UnregisterExecutor(provider) for clean executor lifecycle management. - Add github.com/gorilla/websocket v1.5.3 dependency and update go.sum. Motivation: enable services to expose WS endpoints through the core server and allow removing auth executors dynamically while avoiding duplicate route setup. No breaking changes.
This commit is contained in:
187
internal/wsrelay/http.go
Normal file
187
internal/wsrelay/http.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package wsrelay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// HTTPRequest represents a proxied HTTP request delivered to websocket clients.
|
||||
type HTTPRequest struct {
|
||||
Method string
|
||||
URL string
|
||||
Headers http.Header
|
||||
Body []byte
|
||||
}
|
||||
|
||||
// HTTPResponse captures the response relayed back from websocket clients.
|
||||
type HTTPResponse struct {
|
||||
Status int
|
||||
Headers http.Header
|
||||
Body []byte
|
||||
}
|
||||
|
||||
// StreamEvent represents a streaming response event from clients.
|
||||
type StreamEvent struct {
|
||||
Type string
|
||||
Payload []byte
|
||||
Status int
|
||||
Headers http.Header
|
||||
Err error
|
||||
}
|
||||
|
||||
// RoundTrip executes a non-streaming HTTP request using the websocket provider.
|
||||
func (m *Manager) RoundTrip(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("wsrelay: request is nil")
|
||||
}
|
||||
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
|
||||
respCh, err := m.Send(ctx, provider, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case msg, ok := <-respCh:
|
||||
if !ok {
|
||||
return nil, errors.New("wsrelay: connection closed during response")
|
||||
}
|
||||
switch msg.Type {
|
||||
case MessageTypeHTTPResp:
|
||||
return decodeResponse(msg.Payload), nil
|
||||
case MessageTypeError:
|
||||
return nil, decodeError(msg.Payload)
|
||||
case MessageTypeStreamStart, MessageTypeStreamChunk:
|
||||
// Ignore streaming noise in non-stream requests.
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream executes a streaming HTTP request and returns channel with stream events.
|
||||
func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("wsrelay: request is nil")
|
||||
}
|
||||
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
|
||||
respCh, err := m.Send(ctx, provider, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan StreamEvent)
|
||||
go func() {
|
||||
defer close(out)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
out <- StreamEvent{Err: ctx.Err()}
|
||||
return
|
||||
case msg, ok := <-respCh:
|
||||
if !ok {
|
||||
out <- StreamEvent{Err: errors.New("wsrelay: stream closed")}
|
||||
return
|
||||
}
|
||||
switch msg.Type {
|
||||
case MessageTypeStreamStart:
|
||||
resp := decodeResponse(msg.Payload)
|
||||
out <- StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers}
|
||||
case MessageTypeStreamChunk:
|
||||
chunk := decodeChunk(msg.Payload)
|
||||
out <- StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk}
|
||||
case MessageTypeStreamEnd:
|
||||
out <- StreamEvent{Type: MessageTypeStreamEnd}
|
||||
return
|
||||
case MessageTypeError:
|
||||
out <- StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)}
|
||||
return
|
||||
case MessageTypeHTTPResp:
|
||||
resp := decodeResponse(msg.Payload)
|
||||
out <- StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body}
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func encodeRequest(req *HTTPRequest) map[string]any {
|
||||
headers := make(map[string]any, len(req.Headers))
|
||||
for key, values := range req.Headers {
|
||||
copyValues := make([]string, len(values))
|
||||
copy(copyValues, values)
|
||||
headers[key] = copyValues
|
||||
}
|
||||
return map[string]any{
|
||||
"method": req.Method,
|
||||
"url": req.URL,
|
||||
"headers": headers,
|
||||
"body": string(req.Body),
|
||||
"sent_at": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
}
|
||||
}
|
||||
|
||||
func decodeResponse(payload map[string]any) *HTTPResponse {
|
||||
if payload == nil {
|
||||
return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)}
|
||||
}
|
||||
resp := &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||
if status, ok := payload["status"].(float64); ok {
|
||||
resp.Status = int(status)
|
||||
}
|
||||
if headers, ok := payload["headers"].(map[string]any); ok {
|
||||
for key, raw := range headers {
|
||||
switch v := raw.(type) {
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
resp.Headers.Add(key, str)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, str := range v {
|
||||
resp.Headers.Add(key, str)
|
||||
}
|
||||
case string:
|
||||
resp.Headers.Set(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if body, ok := payload["body"].(string); ok {
|
||||
resp.Body = []byte(body)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func decodeChunk(payload map[string]any) []byte {
|
||||
if payload == nil {
|
||||
return nil
|
||||
}
|
||||
if data, ok := payload["data"].(string); ok {
|
||||
return []byte(data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeError(payload map[string]any) error {
|
||||
if payload == nil {
|
||||
return errors.New("wsrelay: unknown error")
|
||||
}
|
||||
message, _ := payload["error"].(string)
|
||||
status := 0
|
||||
if v, ok := payload["status"].(float64); ok {
|
||||
status = int(v)
|
||||
}
|
||||
if message == "" {
|
||||
message = "wsrelay: upstream error"
|
||||
}
|
||||
return fmt.Errorf("%s (status=%d)", message, status)
|
||||
}
|
||||
Reference in New Issue
Block a user