Files
CLIProxyAPI/internal/wsrelay/manager.go

206 lines
5.0 KiB
Go

package wsrelay
import (
"context"
"crypto/rand"
"errors"
"fmt"
"net/http"
"strings"
"sync"
"time"
"github.com/gorilla/websocket"
)
// Manager exposes a websocket endpoint that proxies Gemini requests to
// connected clients.
type Manager struct {
path string
upgrader websocket.Upgrader
sessions map[string]*session
sessMutex sync.RWMutex
providerFactory func(*http.Request) (string, error)
onConnected func(string)
onDisconnected func(string, error)
logDebugf func(string, ...any)
logInfof func(string, ...any)
logWarnf func(string, ...any)
}
// Options configures a Manager instance.
type Options struct {
Path string
ProviderFactory func(*http.Request) (string, error)
OnConnected func(string)
OnDisconnected func(string, error)
LogDebugf func(string, ...any)
LogInfof func(string, ...any)
LogWarnf func(string, ...any)
}
// NewManager builds a websocket relay manager with the supplied options.
func NewManager(opts Options) *Manager {
path := strings.TrimSpace(opts.Path)
if path == "" {
path = "/v1/ws"
}
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
mgr := &Manager{
path: path,
sessions: make(map[string]*session),
upgrader: websocket.Upgrader{
ReadBufferSize: 1024,
WriteBufferSize: 1024,
CheckOrigin: func(r *http.Request) bool {
return true
},
},
providerFactory: opts.ProviderFactory,
onConnected: opts.OnConnected,
onDisconnected: opts.OnDisconnected,
logDebugf: opts.LogDebugf,
logInfof: opts.LogInfof,
logWarnf: opts.LogWarnf,
}
if mgr.logDebugf == nil {
mgr.logDebugf = func(string, ...any) {}
}
if mgr.logInfof == nil {
mgr.logInfof = func(string, ...any) {}
}
if mgr.logWarnf == nil {
mgr.logWarnf = func(s string, args ...any) { fmt.Printf(s+"\n", args...) }
}
return mgr
}
// Path returns the HTTP path the manager expects for websocket upgrades.
func (m *Manager) Path() string {
if m == nil {
return "/v1/ws"
}
return m.path
}
// Handler exposes an http.Handler that upgrades connections to websocket sessions.
func (m *Manager) Handler() http.Handler {
return http.HandlerFunc(m.handleWebsocket)
}
// Stop gracefully closes all active websocket sessions.
func (m *Manager) Stop(_ context.Context) error {
m.sessMutex.Lock()
sessions := make([]*session, 0, len(m.sessions))
for _, sess := range m.sessions {
sessions = append(sessions, sess)
}
m.sessions = make(map[string]*session)
m.sessMutex.Unlock()
for _, sess := range sessions {
if sess != nil {
sess.cleanup(errors.New("wsrelay: manager stopped"))
}
}
return nil
}
// handleWebsocket upgrades the connection and wires the session into the pool.
func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) {
expectedPath := m.Path()
if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath {
http.NotFound(w, r)
return
}
if !strings.EqualFold(r.Method, http.MethodGet) {
w.Header().Set("Allow", http.MethodGet)
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
return
}
conn, err := m.upgrader.Upgrade(w, r, nil)
if err != nil {
m.logWarnf("wsrelay: upgrade failed: %v", err)
return
}
s := newSession(conn, m, randomProviderName())
if m.providerFactory != nil {
name, err := m.providerFactory(r)
if err != nil {
s.cleanup(err)
return
}
if strings.TrimSpace(name) != "" {
s.provider = strings.ToLower(name)
}
}
if s.provider == "" {
s.provider = strings.ToLower(s.id)
}
m.sessMutex.Lock()
var replaced *session
if existing, ok := m.sessions[s.provider]; ok {
replaced = existing
}
m.sessions[s.provider] = s
m.sessMutex.Unlock()
if replaced != nil {
replaced.cleanup(errors.New("replaced by new connection"))
}
if m.onConnected != nil {
m.onConnected(s.provider)
}
go s.run(context.Background())
}
// Send forwards the message to the specific provider connection and returns a channel
// yielding response messages.
func (m *Manager) Send(ctx context.Context, provider string, msg Message) (<-chan Message, error) {
s := m.session(provider)
if s == nil {
return nil, fmt.Errorf("wsrelay: provider %s not connected", provider)
}
return s.request(ctx, msg)
}
func (m *Manager) session(provider string) *session {
key := strings.ToLower(strings.TrimSpace(provider))
m.sessMutex.RLock()
s := m.sessions[key]
m.sessMutex.RUnlock()
return s
}
func (m *Manager) handleSessionClosed(s *session, cause error) {
if s == nil {
return
}
key := strings.ToLower(strings.TrimSpace(s.provider))
m.sessMutex.Lock()
if cur, ok := m.sessions[key]; ok && cur == s {
delete(m.sessions, key)
}
m.sessMutex.Unlock()
if m.onDisconnected != nil {
m.onDisconnected(s.provider, cause)
}
}
func randomProviderName() string {
const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
buf := make([]byte, 16)
if _, err := rand.Read(buf); err != nil {
return fmt.Sprintf("aistudio-%x", time.Now().UnixNano())
}
for i := range buf {
buf[i] = alphabet[int(buf[i])%len(alphabet)]
}
return "aistudio-" + string(buf)
}