mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
206 lines
5.0 KiB
Go
206 lines
5.0 KiB
Go
package wsrelay
|
|
|
|
import (
|
|
"context"
|
|
"crypto/rand"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"strings"
|
|
"sync"
|
|
"time"
|
|
|
|
"github.com/gorilla/websocket"
|
|
)
|
|
|
|
// Manager exposes a websocket endpoint that proxies Gemini requests to
|
|
// connected clients.
|
|
type Manager struct {
|
|
path string
|
|
upgrader websocket.Upgrader
|
|
sessions map[string]*session
|
|
sessMutex sync.RWMutex
|
|
|
|
providerFactory func(*http.Request) (string, error)
|
|
onConnected func(string)
|
|
onDisconnected func(string, error)
|
|
|
|
logDebugf func(string, ...any)
|
|
logInfof func(string, ...any)
|
|
logWarnf func(string, ...any)
|
|
}
|
|
|
|
// Options configures a Manager instance.
|
|
type Options struct {
|
|
Path string
|
|
ProviderFactory func(*http.Request) (string, error)
|
|
OnConnected func(string)
|
|
OnDisconnected func(string, error)
|
|
LogDebugf func(string, ...any)
|
|
LogInfof func(string, ...any)
|
|
LogWarnf func(string, ...any)
|
|
}
|
|
|
|
// NewManager builds a websocket relay manager with the supplied options.
|
|
func NewManager(opts Options) *Manager {
|
|
path := strings.TrimSpace(opts.Path)
|
|
if path == "" {
|
|
path = "/v1/ws"
|
|
}
|
|
if !strings.HasPrefix(path, "/") {
|
|
path = "/" + path
|
|
}
|
|
mgr := &Manager{
|
|
path: path,
|
|
sessions: make(map[string]*session),
|
|
upgrader: websocket.Upgrader{
|
|
ReadBufferSize: 1024,
|
|
WriteBufferSize: 1024,
|
|
CheckOrigin: func(r *http.Request) bool {
|
|
return true
|
|
},
|
|
},
|
|
providerFactory: opts.ProviderFactory,
|
|
onConnected: opts.OnConnected,
|
|
onDisconnected: opts.OnDisconnected,
|
|
logDebugf: opts.LogDebugf,
|
|
logInfof: opts.LogInfof,
|
|
logWarnf: opts.LogWarnf,
|
|
}
|
|
if mgr.logDebugf == nil {
|
|
mgr.logDebugf = func(string, ...any) {}
|
|
}
|
|
if mgr.logInfof == nil {
|
|
mgr.logInfof = func(string, ...any) {}
|
|
}
|
|
if mgr.logWarnf == nil {
|
|
mgr.logWarnf = func(s string, args ...any) { fmt.Printf(s+"\n", args...) }
|
|
}
|
|
return mgr
|
|
}
|
|
|
|
// Path returns the HTTP path the manager expects for websocket upgrades.
|
|
func (m *Manager) Path() string {
|
|
if m == nil {
|
|
return "/v1/ws"
|
|
}
|
|
return m.path
|
|
}
|
|
|
|
// Handler exposes an http.Handler that upgrades connections to websocket sessions.
|
|
func (m *Manager) Handler() http.Handler {
|
|
return http.HandlerFunc(m.handleWebsocket)
|
|
}
|
|
|
|
// Stop gracefully closes all active websocket sessions.
|
|
func (m *Manager) Stop(_ context.Context) error {
|
|
m.sessMutex.Lock()
|
|
sessions := make([]*session, 0, len(m.sessions))
|
|
for _, sess := range m.sessions {
|
|
sessions = append(sessions, sess)
|
|
}
|
|
m.sessions = make(map[string]*session)
|
|
m.sessMutex.Unlock()
|
|
|
|
for _, sess := range sessions {
|
|
if sess != nil {
|
|
sess.cleanup(errors.New("wsrelay: manager stopped"))
|
|
}
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// handleWebsocket upgrades the connection and wires the session into the pool.
|
|
func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) {
|
|
expectedPath := m.Path()
|
|
if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath {
|
|
http.NotFound(w, r)
|
|
return
|
|
}
|
|
if !strings.EqualFold(r.Method, http.MethodGet) {
|
|
w.Header().Set("Allow", http.MethodGet)
|
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
|
return
|
|
}
|
|
conn, err := m.upgrader.Upgrade(w, r, nil)
|
|
if err != nil {
|
|
m.logWarnf("wsrelay: upgrade failed: %v", err)
|
|
return
|
|
}
|
|
s := newSession(conn, m, randomProviderName())
|
|
if m.providerFactory != nil {
|
|
name, err := m.providerFactory(r)
|
|
if err != nil {
|
|
s.cleanup(err)
|
|
return
|
|
}
|
|
if strings.TrimSpace(name) != "" {
|
|
s.provider = strings.ToLower(name)
|
|
}
|
|
}
|
|
if s.provider == "" {
|
|
s.provider = strings.ToLower(s.id)
|
|
}
|
|
m.sessMutex.Lock()
|
|
var replaced *session
|
|
if existing, ok := m.sessions[s.provider]; ok {
|
|
replaced = existing
|
|
}
|
|
m.sessions[s.provider] = s
|
|
m.sessMutex.Unlock()
|
|
|
|
if replaced != nil {
|
|
replaced.cleanup(errors.New("replaced by new connection"))
|
|
}
|
|
if m.onConnected != nil {
|
|
m.onConnected(s.provider)
|
|
}
|
|
|
|
go s.run(context.Background())
|
|
}
|
|
|
|
// Send forwards the message to the specific provider connection and returns a channel
|
|
// yielding response messages.
|
|
func (m *Manager) Send(ctx context.Context, provider string, msg Message) (<-chan Message, error) {
|
|
s := m.session(provider)
|
|
if s == nil {
|
|
return nil, fmt.Errorf("wsrelay: provider %s not connected", provider)
|
|
}
|
|
return s.request(ctx, msg)
|
|
}
|
|
|
|
func (m *Manager) session(provider string) *session {
|
|
key := strings.ToLower(strings.TrimSpace(provider))
|
|
m.sessMutex.RLock()
|
|
s := m.sessions[key]
|
|
m.sessMutex.RUnlock()
|
|
return s
|
|
}
|
|
|
|
func (m *Manager) handleSessionClosed(s *session, cause error) {
|
|
if s == nil {
|
|
return
|
|
}
|
|
key := strings.ToLower(strings.TrimSpace(s.provider))
|
|
m.sessMutex.Lock()
|
|
if cur, ok := m.sessions[key]; ok && cur == s {
|
|
delete(m.sessions, key)
|
|
}
|
|
m.sessMutex.Unlock()
|
|
if m.onDisconnected != nil {
|
|
m.onDisconnected(s.provider, cause)
|
|
}
|
|
}
|
|
|
|
func randomProviderName() string {
|
|
const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
|
|
buf := make([]byte, 16)
|
|
if _, err := rand.Read(buf); err != nil {
|
|
return fmt.Sprintf("aistudio-%x", time.Now().UnixNano())
|
|
}
|
|
for i := range buf {
|
|
buf[i] = alphabet[int(buf[i])%len(alphabet)]
|
|
}
|
|
return "aistudio-" + string(buf)
|
|
}
|