mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 13:00:52 +08:00
feat: add websocket routing and executor unregister API
- Introduce Server.AttachWebsocketRoute(path, handler) to mount websocket upgrade handlers on the Gin engine. - Track registered WS paths via wsRoutes with wsRouteMu to prevent duplicate registrations; initialize in NewServer and import sync. - Add Manager.UnregisterExecutor(provider) for clean executor lifecycle management. - Add github.com/gorilla/websocket v1.5.3 dependency and update go.sum. Motivation: enable services to expose WS endpoints through the core server and allow removing auth executors dynamically while avoiding duplicate route setup. No breaking changes.
This commit is contained in:
@@ -153,6 +153,17 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
|
||||
m.executors[executor.Identifier()] = executor
|
||||
}
|
||||
|
||||
// UnregisterExecutor removes the executor associated with the provider key.
|
||||
func (m *Manager) UnregisterExecutor(provider string) {
|
||||
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||
if provider == "" {
|
||||
return
|
||||
}
|
||||
m.mu.Lock()
|
||||
delete(m.executors, provider)
|
||||
m.mu.Unlock()
|
||||
}
|
||||
|
||||
// Register inserts a new auth entry into the manager.
|
||||
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||
if auth == nil {
|
||||
|
||||
@@ -156,7 +156,17 @@ func (a *Auth) AccountInfo() (string, string) {
|
||||
if v, ok := a.Metadata["email"].(string); ok {
|
||||
return "oauth", v
|
||||
}
|
||||
} else if a.Attributes != nil {
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") {
|
||||
if label := strings.TrimSpace(a.Label); label != "" {
|
||||
return "oauth", label
|
||||
}
|
||||
if id := strings.TrimSpace(a.ID); id != "" {
|
||||
return "oauth", id
|
||||
}
|
||||
return "oauth", "aistudio"
|
||||
}
|
||||
if a.Attributes != nil {
|
||||
if v := a.Attributes["api_key"]; v != "" {
|
||||
return "api_key", v
|
||||
}
|
||||
|
||||
@@ -18,6 +18,7 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
@@ -82,6 +83,9 @@ type Service struct {
|
||||
|
||||
// shutdownOnce ensures shutdown is called only once.
|
||||
shutdownOnce sync.Once
|
||||
|
||||
// wsGateway manages websocket Gemini providers.
|
||||
wsGateway *wsrelay.Manager
|
||||
}
|
||||
|
||||
// RegisterUsagePlugin registers a usage plugin on the global usage manager.
|
||||
@@ -172,6 +176,66 @@ func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdat
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) ensureWebsocketGateway() {
|
||||
if s == nil {
|
||||
return
|
||||
}
|
||||
if s.wsGateway != nil {
|
||||
return
|
||||
}
|
||||
opts := wsrelay.Options{
|
||||
Path: "/v1/ws",
|
||||
OnConnected: s.wsOnConnected,
|
||||
OnDisconnected: s.wsOnDisconnected,
|
||||
LogDebugf: log.Debugf,
|
||||
LogInfof: log.Infof,
|
||||
LogWarnf: log.Warnf,
|
||||
}
|
||||
s.wsGateway = wsrelay.NewManager(opts)
|
||||
}
|
||||
|
||||
func (s *Service) wsOnConnected(provider string) {
|
||||
if s == nil || provider == "" {
|
||||
return
|
||||
}
|
||||
if !strings.HasPrefix(strings.ToLower(provider), "aistudio-") {
|
||||
return
|
||||
}
|
||||
if s.coreManager != nil {
|
||||
if existing, ok := s.coreManager.GetByID(provider); ok && existing != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
now := time.Now().UTC()
|
||||
auth := &coreauth.Auth{
|
||||
ID: provider,
|
||||
Provider: provider,
|
||||
Label: provider,
|
||||
Status: coreauth.StatusActive,
|
||||
CreatedAt: now,
|
||||
UpdatedAt: now,
|
||||
Attributes: map[string]string{"ws_provider": "gemini"},
|
||||
}
|
||||
log.Infof("websocket provider connected: %s", provider)
|
||||
s.applyCoreAuthAddOrUpdate(context.Background(), auth)
|
||||
}
|
||||
|
||||
func (s *Service) wsOnDisconnected(provider string, reason error) {
|
||||
if s == nil || provider == "" {
|
||||
return
|
||||
}
|
||||
if reason != nil {
|
||||
log.Warnf("websocket provider disconnected: %s (%v)", provider, reason)
|
||||
} else {
|
||||
log.Infof("websocket provider disconnected: %s", provider)
|
||||
}
|
||||
ctx := context.Background()
|
||||
s.applyCoreAuthRemoval(ctx, provider)
|
||||
if s.coreManager != nil {
|
||||
s.coreManager.UnregisterExecutor(provider)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
|
||||
if s == nil || auth == nil || auth.ID == "" {
|
||||
return
|
||||
@@ -247,6 +311,12 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
||||
s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(compatProviderKey, s.cfg))
|
||||
return
|
||||
}
|
||||
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") {
|
||||
if s.wsGateway != nil {
|
||||
s.coreManager.RegisterExecutor(executor.NewAistudioExecutor(s.cfg, a.Provider, s.wsGateway))
|
||||
}
|
||||
return
|
||||
}
|
||||
switch strings.ToLower(a.Provider) {
|
||||
case "gemini":
|
||||
s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
|
||||
@@ -342,6 +412,11 @@ func (s *Service) Run(ctx context.Context) error {
|
||||
s.authManager = newDefaultAuthManager()
|
||||
}
|
||||
|
||||
s.ensureWebsocketGateway()
|
||||
if s.server != nil && s.wsGateway != nil {
|
||||
s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler())
|
||||
}
|
||||
|
||||
if s.hooks.OnBeforeStart != nil {
|
||||
s.hooks.OnBeforeStart(s.cfg)
|
||||
}
|
||||
@@ -449,6 +524,14 @@ func (s *Service) Shutdown(ctx context.Context) error {
|
||||
shutdownErr = err
|
||||
}
|
||||
}
|
||||
if s.wsGateway != nil {
|
||||
if err := s.wsGateway.Stop(ctx); err != nil {
|
||||
log.Errorf("failed to stop websocket gateway: %v", err)
|
||||
if shutdownErr == nil {
|
||||
shutdownErr = err
|
||||
}
|
||||
}
|
||||
}
|
||||
if s.authQueueStop != nil {
|
||||
s.authQueueStop()
|
||||
s.authQueueStop = nil
|
||||
@@ -505,6 +588,13 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
}
|
||||
provider := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||
compatProviderKey, compatDisplayName, compatDetected := openAICompatInfoFromAuth(a)
|
||||
if a.Attributes != nil {
|
||||
if strings.EqualFold(a.Attributes["ws_provider"], "gemini") {
|
||||
models := mergeGeminiModels()
|
||||
GlobalModelRegistry().RegisterClient(a.ID, provider, models)
|
||||
return
|
||||
}
|
||||
}
|
||||
if compatDetected {
|
||||
provider = "openai-compatibility"
|
||||
}
|
||||
@@ -611,3 +701,24 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
||||
GlobalModelRegistry().RegisterClient(a.ID, key, models)
|
||||
}
|
||||
}
|
||||
|
||||
func mergeGeminiModels() []*ModelInfo {
|
||||
models := make([]*ModelInfo, 0, 16)
|
||||
seen := make(map[string]struct{})
|
||||
appendModels := func(items []*ModelInfo) {
|
||||
for i := range items {
|
||||
m := items[i]
|
||||
if m == nil || m.ID == "" {
|
||||
continue
|
||||
}
|
||||
if _, ok := seen[m.ID]; ok {
|
||||
continue
|
||||
}
|
||||
seen[m.ID] = struct{}{}
|
||||
models = append(models, m)
|
||||
}
|
||||
}
|
||||
appendModels(registry.GetGeminiModels())
|
||||
appendModels(registry.GetGeminiCLIModels())
|
||||
return models
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user