mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
- Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket upgrade handlers on the Gin engine. - Track registered WS paths via wsRoutes with wsRouteMu to prevent duplicate registrations; initialize in NewServer and import sync. - Add Manager.UnregisterExecutor(provider) for clean executor lifecycle management. - Add github.com/gorilla/websocket v1.5.3 dependency and update go.sum. Motivation: enable services to expose WS endpoints through the core server and allow removing auth executors dynamically while avoiding duplicate route setup. No breaking changes.
189 lines
4.2 KiB
Go
189 lines
4.2 KiB
Go
package wsrelay
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Connection tuning shared by every relay session.
const (
	// readTimeout is how long a read may wait for the peer before the
	// connection is treated as dead; pongs from the peer push it out again.
	readTimeout = time.Minute

	// writeTimeout bounds every outbound frame, including heartbeat pings.
	writeTimeout = 10 * time.Second

	// maxInboundMessageLen caps the size of a single inbound message (64 MiB).
	maxInboundMessageLen = 64 * 1024 * 1024

	// heartbeatInterval is the spacing between server-initiated pings.
	heartbeatInterval = 30 * time.Second
)

// errClosed is reported for operations attempted on, or interrupted by, a
// session that has already shut down.
var errClosed = errors.New("websocket session closed")
|
|
|
|
type pendingRequest struct {
|
|
ch chan Message
|
|
closeOnce sync.Once
|
|
}
|
|
|
|
func (pr *pendingRequest) close() {
|
|
if pr == nil {
|
|
return
|
|
}
|
|
pr.closeOnce.Do(func() {
|
|
close(pr.ch)
|
|
})
|
|
}
|
|
|
|
type session struct {
|
|
conn *websocket.Conn
|
|
manager *Manager
|
|
provider string
|
|
id string
|
|
closed chan struct{}
|
|
closeOnce sync.Once
|
|
writeMutex sync.Mutex
|
|
pending sync.Map // map[string]*pendingRequest
|
|
}
|
|
|
|
func newSession(conn *websocket.Conn, mgr *Manager, id string) *session {
|
|
s := &session{
|
|
conn: conn,
|
|
manager: mgr,
|
|
provider: "",
|
|
id: id,
|
|
closed: make(chan struct{}),
|
|
}
|
|
conn.SetReadLimit(maxInboundMessageLen)
|
|
conn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
conn.SetPongHandler(func(string) error {
|
|
conn.SetReadDeadline(time.Now().Add(readTimeout))
|
|
return nil
|
|
})
|
|
s.startHeartbeat()
|
|
return s
|
|
}
|
|
|
|
func (s *session) startHeartbeat() {
|
|
if s == nil || s.conn == nil {
|
|
return
|
|
}
|
|
ticker := time.NewTicker(heartbeatInterval)
|
|
go func() {
|
|
defer ticker.Stop()
|
|
for {
|
|
select {
|
|
case <-s.closed:
|
|
return
|
|
case <-ticker.C:
|
|
s.writeMutex.Lock()
|
|
err := s.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(writeTimeout))
|
|
s.writeMutex.Unlock()
|
|
if err != nil {
|
|
s.cleanup(err)
|
|
return
|
|
}
|
|
}
|
|
}
|
|
}()
|
|
}
|
|
|
|
func (s *session) run(ctx context.Context) {
|
|
defer s.cleanup(errClosed)
|
|
for {
|
|
var msg Message
|
|
if err := s.conn.ReadJSON(&msg); err != nil {
|
|
s.cleanup(err)
|
|
return
|
|
}
|
|
s.dispatch(msg)
|
|
}
|
|
}
|
|
|
|
func (s *session) dispatch(msg Message) {
|
|
if msg.Type == MessageTypePing {
|
|
_ = s.send(context.Background(), Message{ID: msg.ID, Type: MessageTypePong})
|
|
return
|
|
}
|
|
if value, ok := s.pending.Load(msg.ID); ok {
|
|
req := value.(*pendingRequest)
|
|
select {
|
|
case req.ch <- msg:
|
|
default:
|
|
}
|
|
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
|
|
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
|
actual.(*pendingRequest).close()
|
|
}
|
|
}
|
|
return
|
|
}
|
|
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
|
|
s.manager.logDebugf("wsrelay: received terminal message for unknown id %s (provider=%s)", msg.ID, s.provider)
|
|
}
|
|
}
|
|
|
|
func (s *session) send(ctx context.Context, msg Message) error {
|
|
select {
|
|
case <-s.closed:
|
|
return errClosed
|
|
default:
|
|
}
|
|
s.writeMutex.Lock()
|
|
defer s.writeMutex.Unlock()
|
|
if err := s.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
|
|
return fmt.Errorf("set write deadline: %w", err)
|
|
}
|
|
if err := s.conn.WriteJSON(msg); err != nil {
|
|
return fmt.Errorf("write json: %w", err)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
func (s *session) request(ctx context.Context, msg Message) (<-chan Message, error) {
|
|
if msg.ID == "" {
|
|
return nil, fmt.Errorf("wsrelay: message id is required")
|
|
}
|
|
if _, loaded := s.pending.LoadOrStore(msg.ID, &pendingRequest{ch: make(chan Message, 8)}); loaded {
|
|
return nil, fmt.Errorf("wsrelay: duplicate message id %s", msg.ID)
|
|
}
|
|
value, _ := s.pending.Load(msg.ID)
|
|
req := value.(*pendingRequest)
|
|
if err := s.send(ctx, msg); err != nil {
|
|
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
|
req := actual.(*pendingRequest)
|
|
req.close()
|
|
}
|
|
return nil, err
|
|
}
|
|
go func() {
|
|
select {
|
|
case <-ctx.Done():
|
|
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
|
actual.(*pendingRequest).close()
|
|
}
|
|
case <-s.closed:
|
|
}
|
|
}()
|
|
return req.ch, nil
|
|
}
|
|
|
|
func (s *session) cleanup(cause error) {
|
|
s.closeOnce.Do(func() {
|
|
close(s.closed)
|
|
s.pending.Range(func(key, value any) bool {
|
|
req := value.(*pendingRequest)
|
|
msg := Message{ID: key.(string), Type: MessageTypeError, Payload: map[string]any{"error": cause.Error()}}
|
|
select {
|
|
case req.ch <- msg:
|
|
default:
|
|
}
|
|
req.close()
|
|
return true
|
|
})
|
|
s.pending = sync.Map{}
|
|
_ = s.conn.Close()
|
|
if s.manager != nil {
|
|
s.manager.handleSessionClosed(s, cause)
|
|
}
|
|
})
|
|
}
|