feat: add websocket routing and executor unregister API

- Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket
  upgrade handlers on the Gin engine.
- Track registered WS paths via wsRoutes with wsRouteMu to prevent
  duplicate registrations; initialize in NewServer and import sync.
- Add Manager.UnregisterExecutor(provider) for clean executor lifecycle
  management.
- Add github.com/gorilla/websocket v1.5.3 dependency and update go.sum.

Motivation: enable services to expose WS endpoints through the core server
and allow removing auth executors dynamically while avoiding duplicate
route setup. No breaking changes.
This commit is contained in:
hkfires
2025-10-25 11:30:39 +08:00
parent a552a45b81
commit 3839d93ba0
11 changed files with 1035 additions and 3 deletions

View File

@@ -153,6 +153,17 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
m.executors[executor.Identifier()] = executor
}
// UnregisterExecutor removes the executor associated with the provider key.
func (m *Manager) UnregisterExecutor(provider string) {
provider = strings.ToLower(strings.TrimSpace(provider))
if provider == "" {
return
}
m.mu.Lock()
delete(m.executors, provider)
m.mu.Unlock()
}
// Register inserts a new auth entry into the manager.
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
if auth == nil {

View File

@@ -156,7 +156,17 @@ func (a *Auth) AccountInfo() (string, string) {
if v, ok := a.Metadata["email"].(string); ok {
return "oauth", v
}
} else if a.Attributes != nil {
}
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") {
if label := strings.TrimSpace(a.Label); label != "" {
return "oauth", label
}
if id := strings.TrimSpace(a.ID); id != "" {
return "oauth", id
}
return "oauth", "aistudio"
}
if a.Attributes != nil {
if v := a.Attributes["api_key"]; v != "" {
return "api_key", v
}

View File

@@ -18,6 +18,7 @@ import (
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
@@ -82,6 +83,9 @@ type Service struct {
// shutdownOnce ensures shutdown is called only once.
shutdownOnce sync.Once
// wsGateway manages websocket Gemini providers.
wsGateway *wsrelay.Manager
}
// RegisterUsagePlugin registers a usage plugin on the global usage manager.
@@ -172,6 +176,66 @@ func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdat
}
}
func (s *Service) ensureWebsocketGateway() {
if s == nil {
return
}
if s.wsGateway != nil {
return
}
opts := wsrelay.Options{
Path: "/v1/ws",
OnConnected: s.wsOnConnected,
OnDisconnected: s.wsOnDisconnected,
LogDebugf: log.Debugf,
LogInfof: log.Infof,
LogWarnf: log.Warnf,
}
s.wsGateway = wsrelay.NewManager(opts)
}
func (s *Service) wsOnConnected(provider string) {
if s == nil || provider == "" {
return
}
if !strings.HasPrefix(strings.ToLower(provider), "aistudio-") {
return
}
if s.coreManager != nil {
if existing, ok := s.coreManager.GetByID(provider); ok && existing != nil {
return
}
}
now := time.Now().UTC()
auth := &coreauth.Auth{
ID: provider,
Provider: provider,
Label: provider,
Status: coreauth.StatusActive,
CreatedAt: now,
UpdatedAt: now,
Attributes: map[string]string{"ws_provider": "gemini"},
}
log.Infof("websocket provider connected: %s", provider)
s.applyCoreAuthAddOrUpdate(context.Background(), auth)
}
func (s *Service) wsOnDisconnected(provider string, reason error) {
if s == nil || provider == "" {
return
}
if reason != nil {
log.Warnf("websocket provider disconnected: %s (%v)", provider, reason)
} else {
log.Infof("websocket provider disconnected: %s", provider)
}
ctx := context.Background()
s.applyCoreAuthRemoval(ctx, provider)
if s.coreManager != nil {
s.coreManager.UnregisterExecutor(provider)
}
}
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
if s == nil || auth == nil || auth.ID == "" {
return
@@ -247,6 +311,12 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(compatProviderKey, s.cfg))
return
}
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") {
if s.wsGateway != nil {
s.coreManager.RegisterExecutor(executor.NewAistudioExecutor(s.cfg, a.Provider, s.wsGateway))
}
return
}
switch strings.ToLower(a.Provider) {
case "gemini":
s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
@@ -342,6 +412,11 @@ func (s *Service) Run(ctx context.Context) error {
s.authManager = newDefaultAuthManager()
}
s.ensureWebsocketGateway()
if s.server != nil && s.wsGateway != nil {
s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler())
}
if s.hooks.OnBeforeStart != nil {
s.hooks.OnBeforeStart(s.cfg)
}
@@ -449,6 +524,14 @@ func (s *Service) Shutdown(ctx context.Context) error {
shutdownErr = err
}
}
if s.wsGateway != nil {
if err := s.wsGateway.Stop(ctx); err != nil {
log.Errorf("failed to stop websocket gateway: %v", err)
if shutdownErr == nil {
shutdownErr = err
}
}
}
if s.authQueueStop != nil {
s.authQueueStop()
s.authQueueStop = nil
@@ -505,6 +588,13 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
}
provider := strings.ToLower(strings.TrimSpace(a.Provider))
compatProviderKey, compatDisplayName, compatDetected := openAICompatInfoFromAuth(a)
if a.Attributes != nil {
if strings.EqualFold(a.Attributes["ws_provider"], "gemini") {
models := mergeGeminiModels()
GlobalModelRegistry().RegisterClient(a.ID, provider, models)
return
}
}
if compatDetected {
provider = "openai-compatibility"
}
@@ -611,3 +701,24 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
GlobalModelRegistry().RegisterClient(a.ID, key, models)
}
}
func mergeGeminiModels() []*ModelInfo {
models := make([]*ModelInfo, 0, 16)
seen := make(map[string]struct{})
appendModels := func(items []*ModelInfo) {
for i := range items {
m := items[i]
if m == nil || m.ID == "" {
continue
}
if _, ok := seen[m.ID]; ok {
continue
}
seen[m.ID] = struct{}{}
models = append(models, m)
}
}
appendModels(registry.GetGeminiModels())
appendModels(registry.GetGeminiCLIModels())
return models
}