Files
CLIProxyAPI/internal/wsrelay/session.go
hkfires 3839d93ba0 feat: add websocket routing and executor unregister API
- Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket
  upgrade handlers on the Gin engine.
- Track registered WS paths via wsRoutes with wsRouteMu to prevent
  duplicate registrations; initialize in NewServer and import sync.
- Add Manager.UnregisterExecutor(provider) for clean executor lifecycle
  management.
- Add github.com/gorilla/websocket v1.5.3 dependency and update go.sum.

Motivation: enable services to expose WS endpoints through the core server
and allow removing auth executors dynamically while avoiding duplicate
route setup. No breaking changes.
2025-10-26 07:46:03 +08:00

189 lines
4.2 KiB
Go

package wsrelay
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/gorilla/websocket"
)
const (
readTimeout = 60 * time.Second
writeTimeout = 10 * time.Second
maxInboundMessageLen = 64 << 20 // 64 MiB
heartbeatInterval = 30 * time.Second
)
var errClosed = errors.New("websocket session closed")
type pendingRequest struct {
ch chan Message
closeOnce sync.Once
}
func (pr *pendingRequest) close() {
if pr == nil {
return
}
pr.closeOnce.Do(func() {
close(pr.ch)
})
}
type session struct {
conn *websocket.Conn
manager *Manager
provider string
id string
closed chan struct{}
closeOnce sync.Once
writeMutex sync.Mutex
pending sync.Map // map[string]*pendingRequest
}
func newSession(conn *websocket.Conn, mgr *Manager, id string) *session {
s := &session{
conn: conn,
manager: mgr,
provider: "",
id: id,
closed: make(chan struct{}),
}
conn.SetReadLimit(maxInboundMessageLen)
conn.SetReadDeadline(time.Now().Add(readTimeout))
conn.SetPongHandler(func(string) error {
conn.SetReadDeadline(time.Now().Add(readTimeout))
return nil
})
s.startHeartbeat()
return s
}
func (s *session) startHeartbeat() {
if s == nil || s.conn == nil {
return
}
ticker := time.NewTicker(heartbeatInterval)
go func() {
defer ticker.Stop()
for {
select {
case <-s.closed:
return
case <-ticker.C:
s.writeMutex.Lock()
err := s.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(writeTimeout))
s.writeMutex.Unlock()
if err != nil {
s.cleanup(err)
return
}
}
}
}()
}
func (s *session) run(ctx context.Context) {
defer s.cleanup(errClosed)
for {
var msg Message
if err := s.conn.ReadJSON(&msg); err != nil {
s.cleanup(err)
return
}
s.dispatch(msg)
}
}
func (s *session) dispatch(msg Message) {
if msg.Type == MessageTypePing {
_ = s.send(context.Background(), Message{ID: msg.ID, Type: MessageTypePong})
return
}
if value, ok := s.pending.Load(msg.ID); ok {
req := value.(*pendingRequest)
select {
case req.ch <- msg:
default:
}
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
actual.(*pendingRequest).close()
}
}
return
}
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
s.manager.logDebugf("wsrelay: received terminal message for unknown id %s (provider=%s)", msg.ID, s.provider)
}
}
func (s *session) send(ctx context.Context, msg Message) error {
select {
case <-s.closed:
return errClosed
default:
}
s.writeMutex.Lock()
defer s.writeMutex.Unlock()
if err := s.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
return fmt.Errorf("set write deadline: %w", err)
}
if err := s.conn.WriteJSON(msg); err != nil {
return fmt.Errorf("write json: %w", err)
}
return nil
}
func (s *session) request(ctx context.Context, msg Message) (<-chan Message, error) {
if msg.ID == "" {
return nil, fmt.Errorf("wsrelay: message id is required")
}
if _, loaded := s.pending.LoadOrStore(msg.ID, &pendingRequest{ch: make(chan Message, 8)}); loaded {
return nil, fmt.Errorf("wsrelay: duplicate message id %s", msg.ID)
}
value, _ := s.pending.Load(msg.ID)
req := value.(*pendingRequest)
if err := s.send(ctx, msg); err != nil {
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
req := actual.(*pendingRequest)
req.close()
}
return nil, err
}
go func() {
select {
case <-ctx.Done():
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
actual.(*pendingRequest).close()
}
case <-s.closed:
}
}()
return req.ch, nil
}
func (s *session) cleanup(cause error) {
s.closeOnce.Do(func() {
close(s.closed)
s.pending.Range(func(key, value any) bool {
req := value.(*pendingRequest)
msg := Message{ID: key.(string), Type: MessageTypeError, Payload: map[string]any{"error": cause.Error()}}
select {
case req.ch <- msg:
default:
}
req.close()
return true
})
s.pending = sync.Map{}
_ = s.conn.Close()
if s.manager != nil {
s.manager.handleSessionClosed(s, cause)
}
})
}