mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
feat: add websocket routing and executor unregister API
- Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket upgrade handlers on the Gin engine. - Track registered WS paths via wsRoutes with wsRouteMu to prevent duplicate registrations; initialize in NewServer and import sync. - Add Manager.UnregisterExecutor(provider) for clean executor lifecycle management. - Add github.com/gorilla/websocket v1.5.3 dependency and update go.sum. Motivation: enable services to expose WS endpoints through the core server and allow removing auth executors dynamically while avoiding duplicate route setup. No breaking changes.
This commit is contained in:
187
internal/wsrelay/http.go
Normal file
187
internal/wsrelay/http.go
Normal file
@@ -0,0 +1,187 @@
|
||||
package wsrelay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
// HTTPRequest represents a proxied HTTP request delivered to websocket clients.
|
||||
type HTTPRequest struct {
|
||||
Method string
|
||||
URL string
|
||||
Headers http.Header
|
||||
Body []byte
|
||||
}
|
||||
|
||||
// HTTPResponse captures the response relayed back from websocket clients.
|
||||
type HTTPResponse struct {
|
||||
Status int
|
||||
Headers http.Header
|
||||
Body []byte
|
||||
}
|
||||
|
||||
// StreamEvent represents a streaming response event from clients.
|
||||
type StreamEvent struct {
|
||||
Type string
|
||||
Payload []byte
|
||||
Status int
|
||||
Headers http.Header
|
||||
Err error
|
||||
}
|
||||
|
||||
// RoundTrip executes a non-streaming HTTP request using the websocket provider.
|
||||
func (m *Manager) RoundTrip(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("wsrelay: request is nil")
|
||||
}
|
||||
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
|
||||
respCh, err := m.Send(ctx, provider, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return nil, ctx.Err()
|
||||
case msg, ok := <-respCh:
|
||||
if !ok {
|
||||
return nil, errors.New("wsrelay: connection closed during response")
|
||||
}
|
||||
switch msg.Type {
|
||||
case MessageTypeHTTPResp:
|
||||
return decodeResponse(msg.Payload), nil
|
||||
case MessageTypeError:
|
||||
return nil, decodeError(msg.Payload)
|
||||
case MessageTypeStreamStart, MessageTypeStreamChunk:
|
||||
// Ignore streaming noise in non-stream requests.
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// Stream executes a streaming HTTP request and returns channel with stream events.
|
||||
func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) {
|
||||
if req == nil {
|
||||
return nil, fmt.Errorf("wsrelay: request is nil")
|
||||
}
|
||||
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
|
||||
respCh, err := m.Send(ctx, provider, msg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out := make(chan StreamEvent)
|
||||
go func() {
|
||||
defer close(out)
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
out <- StreamEvent{Err: ctx.Err()}
|
||||
return
|
||||
case msg, ok := <-respCh:
|
||||
if !ok {
|
||||
out <- StreamEvent{Err: errors.New("wsrelay: stream closed")}
|
||||
return
|
||||
}
|
||||
switch msg.Type {
|
||||
case MessageTypeStreamStart:
|
||||
resp := decodeResponse(msg.Payload)
|
||||
out <- StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers}
|
||||
case MessageTypeStreamChunk:
|
||||
chunk := decodeChunk(msg.Payload)
|
||||
out <- StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk}
|
||||
case MessageTypeStreamEnd:
|
||||
out <- StreamEvent{Type: MessageTypeStreamEnd}
|
||||
return
|
||||
case MessageTypeError:
|
||||
out <- StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)}
|
||||
return
|
||||
case MessageTypeHTTPResp:
|
||||
resp := decodeResponse(msg.Payload)
|
||||
out <- StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body}
|
||||
return
|
||||
default:
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
return out, nil
|
||||
}
|
||||
|
||||
func encodeRequest(req *HTTPRequest) map[string]any {
|
||||
headers := make(map[string]any, len(req.Headers))
|
||||
for key, values := range req.Headers {
|
||||
copyValues := make([]string, len(values))
|
||||
copy(copyValues, values)
|
||||
headers[key] = copyValues
|
||||
}
|
||||
return map[string]any{
|
||||
"method": req.Method,
|
||||
"url": req.URL,
|
||||
"headers": headers,
|
||||
"body": string(req.Body),
|
||||
"sent_at": time.Now().UTC().Format(time.RFC3339Nano),
|
||||
}
|
||||
}
|
||||
|
||||
func decodeResponse(payload map[string]any) *HTTPResponse {
|
||||
if payload == nil {
|
||||
return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)}
|
||||
}
|
||||
resp := &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||
if status, ok := payload["status"].(float64); ok {
|
||||
resp.Status = int(status)
|
||||
}
|
||||
if headers, ok := payload["headers"].(map[string]any); ok {
|
||||
for key, raw := range headers {
|
||||
switch v := raw.(type) {
|
||||
case []any:
|
||||
for _, item := range v {
|
||||
if str, ok := item.(string); ok {
|
||||
resp.Headers.Add(key, str)
|
||||
}
|
||||
}
|
||||
case []string:
|
||||
for _, str := range v {
|
||||
resp.Headers.Add(key, str)
|
||||
}
|
||||
case string:
|
||||
resp.Headers.Set(key, v)
|
||||
}
|
||||
}
|
||||
}
|
||||
if body, ok := payload["body"].(string); ok {
|
||||
resp.Body = []byte(body)
|
||||
}
|
||||
return resp
|
||||
}
|
||||
|
||||
func decodeChunk(payload map[string]any) []byte {
|
||||
if payload == nil {
|
||||
return nil
|
||||
}
|
||||
if data, ok := payload["data"].(string); ok {
|
||||
return []byte(data)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func decodeError(payload map[string]any) error {
|
||||
if payload == nil {
|
||||
return errors.New("wsrelay: unknown error")
|
||||
}
|
||||
message, _ := payload["error"].(string)
|
||||
status := 0
|
||||
if v, ok := payload["status"].(float64); ok {
|
||||
status = int(v)
|
||||
}
|
||||
if message == "" {
|
||||
message = "wsrelay: upstream error"
|
||||
}
|
||||
return fmt.Errorf("%s (status=%d)", message, status)
|
||||
}
|
||||
200
internal/wsrelay/manager.go
Normal file
200
internal/wsrelay/manager.go
Normal file
@@ -0,0 +1,200 @@
|
||||
package wsrelay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
// Manager exposes a websocket endpoint that proxies Gemini requests to
|
||||
// connected clients.
|
||||
type Manager struct {
|
||||
path string
|
||||
upgrader websocket.Upgrader
|
||||
sessions map[string]*session
|
||||
sessMutex sync.RWMutex
|
||||
|
||||
providerFactory func(*http.Request) (string, error)
|
||||
onConnected func(string)
|
||||
onDisconnected func(string, error)
|
||||
|
||||
logDebugf func(string, ...any)
|
||||
logInfof func(string, ...any)
|
||||
logWarnf func(string, ...any)
|
||||
}
|
||||
|
||||
// Options configures a Manager instance.
|
||||
type Options struct {
|
||||
Path string
|
||||
ProviderFactory func(*http.Request) (string, error)
|
||||
OnConnected func(string)
|
||||
OnDisconnected func(string, error)
|
||||
LogDebugf func(string, ...any)
|
||||
LogInfof func(string, ...any)
|
||||
LogWarnf func(string, ...any)
|
||||
}
|
||||
|
||||
// NewManager builds a websocket relay manager with the supplied options.
|
||||
func NewManager(opts Options) *Manager {
|
||||
path := strings.TrimSpace(opts.Path)
|
||||
if path == "" {
|
||||
path = "/v1/ws"
|
||||
}
|
||||
if !strings.HasPrefix(path, "/") {
|
||||
path = "/" + path
|
||||
}
|
||||
mgr := &Manager{
|
||||
path: path,
|
||||
sessions: make(map[string]*session),
|
||||
upgrader: websocket.Upgrader{
|
||||
ReadBufferSize: 1024,
|
||||
WriteBufferSize: 1024,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
},
|
||||
providerFactory: opts.ProviderFactory,
|
||||
onConnected: opts.OnConnected,
|
||||
onDisconnected: opts.OnDisconnected,
|
||||
logDebugf: opts.LogDebugf,
|
||||
logInfof: opts.LogInfof,
|
||||
logWarnf: opts.LogWarnf,
|
||||
}
|
||||
if mgr.logDebugf == nil {
|
||||
mgr.logDebugf = func(string, ...any) {}
|
||||
}
|
||||
if mgr.logInfof == nil {
|
||||
mgr.logInfof = func(string, ...any) {}
|
||||
}
|
||||
if mgr.logWarnf == nil {
|
||||
mgr.logWarnf = func(s string, args ...any) { fmt.Printf(s+"\n", args...) }
|
||||
}
|
||||
return mgr
|
||||
}
|
||||
|
||||
// Path returns the HTTP path the manager expects for websocket upgrades.
|
||||
func (m *Manager) Path() string {
|
||||
if m == nil {
|
||||
return "/v1/ws"
|
||||
}
|
||||
return m.path
|
||||
}
|
||||
|
||||
// Handler exposes an http.Handler that upgrades connections to websocket sessions.
|
||||
func (m *Manager) Handler() http.Handler {
|
||||
return http.HandlerFunc(m.handleWebsocket)
|
||||
}
|
||||
|
||||
// Stop gracefully closes all active websocket sessions.
|
||||
func (m *Manager) Stop(_ context.Context) error {
|
||||
m.sessMutex.Lock()
|
||||
sessions := make([]*session, 0, len(m.sessions))
|
||||
for _, sess := range m.sessions {
|
||||
sessions = append(sessions, sess)
|
||||
}
|
||||
m.sessions = make(map[string]*session)
|
||||
m.sessMutex.Unlock()
|
||||
|
||||
for _, sess := range sessions {
|
||||
if sess != nil {
|
||||
sess.cleanup(errors.New("wsrelay: manager stopped"))
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// handleWebsocket upgrades the connection and wires the session into the pool.
|
||||
func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||
expectedPath := m.Path()
|
||||
if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath {
|
||||
http.NotFound(w, r)
|
||||
return
|
||||
}
|
||||
if !strings.EqualFold(r.Method, http.MethodGet) {
|
||||
w.Header().Set("Allow", http.MethodGet)
|
||||
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||
return
|
||||
}
|
||||
conn, err := m.upgrader.Upgrade(w, r, nil)
|
||||
if err != nil {
|
||||
m.logWarnf("wsrelay: upgrade failed: %v", err)
|
||||
return
|
||||
}
|
||||
s := newSession(conn, m, randomProviderName())
|
||||
if m.providerFactory != nil {
|
||||
name, err := m.providerFactory(r)
|
||||
if err != nil {
|
||||
s.cleanup(err)
|
||||
return
|
||||
}
|
||||
if strings.TrimSpace(name) != "" {
|
||||
s.provider = strings.ToLower(name)
|
||||
}
|
||||
}
|
||||
if s.provider == "" {
|
||||
s.provider = strings.ToLower(s.id)
|
||||
}
|
||||
m.sessMutex.Lock()
|
||||
if existing, ok := m.sessions[s.provider]; ok {
|
||||
existing.cleanup(errors.New("replaced by new connection"))
|
||||
}
|
||||
m.sessions[s.provider] = s
|
||||
m.sessMutex.Unlock()
|
||||
if m.onConnected != nil {
|
||||
m.onConnected(s.provider)
|
||||
}
|
||||
|
||||
go s.run(context.Background())
|
||||
}
|
||||
|
||||
// Send forwards the message to the specific provider connection and returns a channel
|
||||
// yielding response messages.
|
||||
func (m *Manager) Send(ctx context.Context, provider string, msg Message) (<-chan Message, error) {
|
||||
s := m.session(provider)
|
||||
if s == nil {
|
||||
return nil, fmt.Errorf("wsrelay: provider %s not connected", provider)
|
||||
}
|
||||
return s.request(ctx, msg)
|
||||
}
|
||||
|
||||
func (m *Manager) session(provider string) *session {
|
||||
key := strings.ToLower(strings.TrimSpace(provider))
|
||||
m.sessMutex.RLock()
|
||||
s := m.sessions[key]
|
||||
m.sessMutex.RUnlock()
|
||||
return s
|
||||
}
|
||||
|
||||
func (m *Manager) handleSessionClosed(s *session, cause error) {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
key := strings.ToLower(strings.TrimSpace(s.provider))
|
||||
m.sessMutex.Lock()
|
||||
if cur, ok := m.sessions[key]; ok && cur == s {
|
||||
delete(m.sessions, key)
|
||||
}
|
||||
m.sessMutex.Unlock()
|
||||
if m.onDisconnected != nil {
|
||||
m.onDisconnected(s.provider, cause)
|
||||
}
|
||||
}
|
||||
|
||||
func randomProviderName() string {
|
||||
const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||
buf := make([]byte, 16)
|
||||
if _, err := rand.Read(buf); err != nil {
|
||||
return fmt.Sprintf("aistudio-%x", time.Now().UnixNano())
|
||||
}
|
||||
for i := range buf {
|
||||
buf[i] = alphabet[int(buf[i])%len(alphabet)]
|
||||
}
|
||||
return "aistudio-" + string(buf)
|
||||
}
|
||||
27
internal/wsrelay/message.go
Normal file
27
internal/wsrelay/message.go
Normal file
@@ -0,0 +1,27 @@
|
||||
package wsrelay
|
||||
|
||||
// Message represents the JSON payload exchanged with websocket clients.
|
||||
type Message struct {
|
||||
ID string `json:"id"`
|
||||
Type string `json:"type"`
|
||||
Payload map[string]any `json:"payload,omitempty"`
|
||||
}
|
||||
|
||||
const (
|
||||
// MessageTypeHTTPReq identifies an HTTP-style request envelope.
|
||||
MessageTypeHTTPReq = "http_request"
|
||||
// MessageTypeHTTPResp identifies a non-streaming HTTP response envelope.
|
||||
MessageTypeHTTPResp = "http_response"
|
||||
// MessageTypeStreamStart marks the beginning of a streaming response.
|
||||
MessageTypeStreamStart = "stream_start"
|
||||
// MessageTypeStreamChunk carries a streaming response chunk.
|
||||
MessageTypeStreamChunk = "stream_chunk"
|
||||
// MessageTypeStreamEnd marks the completion of a streaming response.
|
||||
MessageTypeStreamEnd = "stream_end"
|
||||
// MessageTypeError carries an error response.
|
||||
MessageTypeError = "error"
|
||||
// MessageTypePing represents ping messages from clients.
|
||||
MessageTypePing = "ping"
|
||||
// MessageTypePong represents pong responses back to clients.
|
||||
MessageTypePong = "pong"
|
||||
)
|
||||
188
internal/wsrelay/session.go
Normal file
188
internal/wsrelay/session.go
Normal file
@@ -0,0 +1,188 @@
|
||||
package wsrelay
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
)
|
||||
|
||||
const (
|
||||
readTimeout = 60 * time.Second
|
||||
writeTimeout = 10 * time.Second
|
||||
maxInboundMessageLen = 64 << 20 // 64 MiB
|
||||
heartbeatInterval = 30 * time.Second
|
||||
)
|
||||
|
||||
var errClosed = errors.New("websocket session closed")
|
||||
|
||||
type pendingRequest struct {
|
||||
ch chan Message
|
||||
closeOnce sync.Once
|
||||
}
|
||||
|
||||
func (pr *pendingRequest) close() {
|
||||
if pr == nil {
|
||||
return
|
||||
}
|
||||
pr.closeOnce.Do(func() {
|
||||
close(pr.ch)
|
||||
})
|
||||
}
|
||||
|
||||
type session struct {
|
||||
conn *websocket.Conn
|
||||
manager *Manager
|
||||
provider string
|
||||
id string
|
||||
closed chan struct{}
|
||||
closeOnce sync.Once
|
||||
writeMutex sync.Mutex
|
||||
pending sync.Map // map[string]*pendingRequest
|
||||
}
|
||||
|
||||
func newSession(conn *websocket.Conn, mgr *Manager, id string) *session {
|
||||
s := &session{
|
||||
conn: conn,
|
||||
manager: mgr,
|
||||
provider: "",
|
||||
id: id,
|
||||
closed: make(chan struct{}),
|
||||
}
|
||||
conn.SetReadLimit(maxInboundMessageLen)
|
||||
conn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
conn.SetPongHandler(func(string) error {
|
||||
conn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||
return nil
|
||||
})
|
||||
s.startHeartbeat()
|
||||
return s
|
||||
}
|
||||
|
||||
func (s *session) startHeartbeat() {
|
||||
if s == nil || s.conn == nil {
|
||||
return
|
||||
}
|
||||
ticker := time.NewTicker(heartbeatInterval)
|
||||
go func() {
|
||||
defer ticker.Stop()
|
||||
for {
|
||||
select {
|
||||
case <-s.closed:
|
||||
return
|
||||
case <-ticker.C:
|
||||
s.writeMutex.Lock()
|
||||
err := s.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(writeTimeout))
|
||||
s.writeMutex.Unlock()
|
||||
if err != nil {
|
||||
s.cleanup(err)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *session) run(ctx context.Context) {
|
||||
defer s.cleanup(errClosed)
|
||||
for {
|
||||
var msg Message
|
||||
if err := s.conn.ReadJSON(&msg); err != nil {
|
||||
s.cleanup(err)
|
||||
return
|
||||
}
|
||||
s.dispatch(msg)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) dispatch(msg Message) {
|
||||
if msg.Type == MessageTypePing {
|
||||
_ = s.send(context.Background(), Message{ID: msg.ID, Type: MessageTypePong})
|
||||
return
|
||||
}
|
||||
if value, ok := s.pending.Load(msg.ID); ok {
|
||||
req := value.(*pendingRequest)
|
||||
select {
|
||||
case req.ch <- msg:
|
||||
default:
|
||||
}
|
||||
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
|
||||
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
||||
actual.(*pendingRequest).close()
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
|
||||
s.manager.logDebugf("wsrelay: received terminal message for unknown id %s (provider=%s)", msg.ID, s.provider)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *session) send(ctx context.Context, msg Message) error {
|
||||
select {
|
||||
case <-s.closed:
|
||||
return errClosed
|
||||
default:
|
||||
}
|
||||
s.writeMutex.Lock()
|
||||
defer s.writeMutex.Unlock()
|
||||
if err := s.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
|
||||
return fmt.Errorf("set write deadline: %w", err)
|
||||
}
|
||||
if err := s.conn.WriteJSON(msg); err != nil {
|
||||
return fmt.Errorf("write json: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *session) request(ctx context.Context, msg Message) (<-chan Message, error) {
|
||||
if msg.ID == "" {
|
||||
return nil, fmt.Errorf("wsrelay: message id is required")
|
||||
}
|
||||
if _, loaded := s.pending.LoadOrStore(msg.ID, &pendingRequest{ch: make(chan Message, 8)}); loaded {
|
||||
return nil, fmt.Errorf("wsrelay: duplicate message id %s", msg.ID)
|
||||
}
|
||||
value, _ := s.pending.Load(msg.ID)
|
||||
req := value.(*pendingRequest)
|
||||
if err := s.send(ctx, msg); err != nil {
|
||||
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
||||
req := actual.(*pendingRequest)
|
||||
req.close()
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
go func() {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
||||
actual.(*pendingRequest).close()
|
||||
}
|
||||
case <-s.closed:
|
||||
}
|
||||
}()
|
||||
return req.ch, nil
|
||||
}
|
||||
|
||||
func (s *session) cleanup(cause error) {
|
||||
s.closeOnce.Do(func() {
|
||||
close(s.closed)
|
||||
s.pending.Range(func(key, value any) bool {
|
||||
req := value.(*pendingRequest)
|
||||
msg := Message{ID: key.(string), Type: MessageTypeError, Payload: map[string]any{"error": cause.Error()}}
|
||||
select {
|
||||
case req.ch <- msg:
|
||||
default:
|
||||
}
|
||||
req.close()
|
||||
return true
|
||||
})
|
||||
s.pending = sync.Map{}
|
||||
_ = s.conn.Close()
|
||||
if s.manager != nil {
|
||||
s.manager.handleSessionClosed(s, cause)
|
||||
}
|
||||
})
|
||||
}
|
||||
Reference in New Issue
Block a user