mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-03 04:50:52 +08:00
Merge pull request #161 from router-for-me/aistudio
Add websocket provider
This commit is contained in:
@@ -43,6 +43,9 @@ quota-exceeded:
|
|||||||
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
switch-project: true # Whether to automatically switch to another project when a quota is exceeded
|
||||||
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
switch-preview-model: true # Whether to automatically switch to a preview model when a quota is exceeded
|
||||||
|
|
||||||
|
# When true, enable authentication for the WebSocket API (/v1/ws).
|
||||||
|
ws-auth: false
|
||||||
|
|
||||||
# API keys for official Generative Language API
|
# API keys for official Generative Language API
|
||||||
#generative-language-api-key:
|
#generative-language-api-key:
|
||||||
# - "AIzaSy...01"
|
# - "AIzaSy...01"
|
||||||
|
|||||||
1
go.mod
1
go.mod
@@ -7,6 +7,7 @@ require (
|
|||||||
github.com/gin-gonic/gin v1.10.1
|
github.com/gin-gonic/gin v1.10.1
|
||||||
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
|
github.com/go-git/go-git/v6 v6.0.0-20251009132922-75a182125145
|
||||||
github.com/google/uuid v1.6.0
|
github.com/google/uuid v1.6.0
|
||||||
|
github.com/gorilla/websocket v1.5.3
|
||||||
github.com/jackc/pgx/v5 v5.7.6
|
github.com/jackc/pgx/v5 v5.7.6
|
||||||
github.com/joho/godotenv v1.5.1
|
github.com/joho/godotenv v1.5.1
|
||||||
github.com/klauspost/compress v1.17.4
|
github.com/klauspost/compress v1.17.4
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -66,6 +66,8 @@ github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/
|
|||||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||||
|
github.com/gorilla/websocket v1.5.3 h1:saDtZ6Pbx/0u+bgYQ3q96pZgCzfhKXGPqt7kZ72aNNg=
|
||||||
|
github.com/gorilla/websocket v1.5.3/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||||
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
|
||||||
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
|
||||||
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
github.com/jackc/pgservicefile v0.0.0-20240606120523-5a60cdf6a761 h1:iCEnooe7UlwOQYpKFhBabPMi4aNAfoODPEFNiAnClxo=
|
||||||
@@ -80,8 +82,6 @@ github.com/json-iterator/go v1.1.12 h1:PV8peI4a0ysnczrg+LtxykD8LfKY9ML6u2jnxaEnr
|
|||||||
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
github.com/json-iterator/go v1.1.12/go.mod h1:e30LSqwooZae/UwlEbR2852Gd8hjQvJoHmT4TnhNGBo=
|
||||||
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
|
github.com/kevinburke/ssh_config v1.4.0 h1:6xxtP5bZ2E4NF5tuQulISpTO2z8XbtH8cg1PWkxoFkQ=
|
||||||
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
|
github.com/kevinburke/ssh_config v1.4.0/go.mod h1:q2RIzfka+BXARoNexmF9gkxEX7DmvbW9P4hIVx2Kg4M=
|
||||||
github.com/klauspost/compress v1.17.3 h1:qkRjuerhUU1EmXLYGkSH6EZL+vPSxIrYjLNAK4slzwA=
|
|
||||||
github.com/klauspost/compress v1.17.3/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
|
||||||
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
github.com/klauspost/compress v1.17.4 h1:Ej5ixsIri7BrIjBkRZLTo6ghwrEtHFk7ijlczPW4fZ4=
|
||||||
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
github.com/klauspost/compress v1.17.4/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
|
||||||
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
github.com/klauspost/cpuid/v2 v2.0.1/go.mod h1:FInQzS24/EEf25PyTYn52gqo7WaD8xa0213Md/qVLRg=
|
||||||
|
|||||||
@@ -57,10 +57,12 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
|||||||
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
authHeaderGoogle := r.Header.Get("X-Goog-Api-Key")
|
||||||
authHeaderAnthropic := r.Header.Get("X-Api-Key")
|
authHeaderAnthropic := r.Header.Get("X-Api-Key")
|
||||||
queryKey := ""
|
queryKey := ""
|
||||||
|
queryAuthToken := ""
|
||||||
if r.URL != nil {
|
if r.URL != nil {
|
||||||
queryKey = r.URL.Query().Get("key")
|
queryKey = r.URL.Query().Get("key")
|
||||||
|
queryAuthToken = r.URL.Query().Get("auth_token")
|
||||||
}
|
}
|
||||||
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" {
|
if authHeader == "" && authHeaderGoogle == "" && authHeaderAnthropic == "" && queryKey == "" && queryAuthToken == "" {
|
||||||
return nil, sdkaccess.ErrNoCredentials
|
return nil, sdkaccess.ErrNoCredentials
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -74,6 +76,7 @@ func (p *provider) Authenticate(_ context.Context, r *http.Request) (*sdkaccess.
|
|||||||
{authHeaderGoogle, "x-goog-api-key"},
|
{authHeaderGoogle, "x-goog-api-key"},
|
||||||
{authHeaderAnthropic, "x-api-key"},
|
{authHeaderAnthropic, "x-api-key"},
|
||||||
{queryKey, "query-key"},
|
{queryKey, "query-key"},
|
||||||
|
{queryAuthToken, "query-auth-token"},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, candidate := range candidates {
|
for _, candidate := range candidates {
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||||
@@ -63,13 +64,11 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
|||||||
// It captures the URL, method, headers, and body. The request body is read and then
|
// It captures the URL, method, headers, and body. The request body is read and then
|
||||||
// restored so that it can be processed by subsequent handlers.
|
// restored so that it can be processed by subsequent handlers.
|
||||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||||
// Capture URL
|
// Capture URL with sensitive query parameters masked
|
||||||
url := c.Request.URL.String()
|
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||||
if c.Request.URL.Path != "" {
|
url := c.Request.URL.Path
|
||||||
url = c.Request.URL.Path
|
if maskedQuery != "" {
|
||||||
if c.Request.URL.RawQuery != "" {
|
url += "?" + maskedQuery
|
||||||
url += "?" + c.Request.URL.RawQuery
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Capture method
|
// Capture method
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -138,6 +139,12 @@ type Server struct {
|
|||||||
// currentPath is the absolute path to the current working directory.
|
// currentPath is the absolute path to the current working directory.
|
||||||
currentPath string
|
currentPath string
|
||||||
|
|
||||||
|
// wsRoutes tracks registered websocket upgrade paths.
|
||||||
|
wsRouteMu sync.Mutex
|
||||||
|
wsRoutes map[string]struct{}
|
||||||
|
wsAuthChanged func(bool, bool)
|
||||||
|
wsAuthEnabled atomic.Bool
|
||||||
|
|
||||||
// management handler
|
// management handler
|
||||||
mgmt *managementHandlers.Handler
|
mgmt *managementHandlers.Handler
|
||||||
|
|
||||||
@@ -228,7 +235,9 @@ func NewServer(cfg *config.Config, authManager *auth.Manager, accessManager *sdk
|
|||||||
configFilePath: configFilePath,
|
configFilePath: configFilePath,
|
||||||
currentPath: wd,
|
currentPath: wd,
|
||||||
envManagementSecret: envManagementSecret,
|
envManagementSecret: envManagementSecret,
|
||||||
|
wsRoutes: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
|
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
|
||||||
// Save initial YAML snapshot
|
// Save initial YAML snapshot
|
||||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||||
s.applyAccessConfig(nil, cfg)
|
s.applyAccessConfig(nil, cfg)
|
||||||
@@ -371,6 +380,43 @@ func (s *Server) setupRoutes() {
|
|||||||
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
// Management routes are registered lazily by registerManagementRoutes when a secret is configured.
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AttachWebsocketRoute registers a websocket upgrade handler on the primary Gin engine.
|
||||||
|
// The handler is served as-is without additional middleware beyond the standard stack already configured.
|
||||||
|
func (s *Server) AttachWebsocketRoute(path string, handler http.Handler) {
|
||||||
|
if s == nil || s.engine == nil || handler == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
trimmed := strings.TrimSpace(path)
|
||||||
|
if trimmed == "" {
|
||||||
|
trimmed = "/v1/ws"
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(trimmed, "/") {
|
||||||
|
trimmed = "/" + trimmed
|
||||||
|
}
|
||||||
|
s.wsRouteMu.Lock()
|
||||||
|
if _, exists := s.wsRoutes[trimmed]; exists {
|
||||||
|
s.wsRouteMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.wsRoutes[trimmed] = struct{}{}
|
||||||
|
s.wsRouteMu.Unlock()
|
||||||
|
|
||||||
|
authMiddleware := AuthMiddleware(s.accessManager)
|
||||||
|
conditionalAuth := func(c *gin.Context) {
|
||||||
|
if !s.wsAuthEnabled.Load() {
|
||||||
|
c.Next()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authMiddleware(c)
|
||||||
|
}
|
||||||
|
finalHandler := func(c *gin.Context) {
|
||||||
|
handler.ServeHTTP(c.Writer, c.Request)
|
||||||
|
c.Abort()
|
||||||
|
}
|
||||||
|
|
||||||
|
s.engine.GET(trimmed, conditionalAuth, finalHandler)
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) registerManagementRoutes() {
|
func (s *Server) registerManagementRoutes() {
|
||||||
if s == nil || s.engine == nil || s.mgmt == nil {
|
if s == nil || s.engine == nil || s.mgmt == nil {
|
||||||
return
|
return
|
||||||
@@ -770,6 +816,10 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
|
|
||||||
s.applyAccessConfig(oldCfg, cfg)
|
s.applyAccessConfig(oldCfg, cfg)
|
||||||
s.cfg = cfg
|
s.cfg = cfg
|
||||||
|
s.wsAuthEnabled.Store(cfg.WebsocketAuth)
|
||||||
|
if oldCfg != nil && s.wsAuthChanged != nil && oldCfg.WebsocketAuth != cfg.WebsocketAuth {
|
||||||
|
s.wsAuthChanged(oldCfg.WebsocketAuth, cfg.WebsocketAuth)
|
||||||
|
}
|
||||||
managementasset.SetCurrentConfig(cfg)
|
managementasset.SetCurrentConfig(cfg)
|
||||||
// Save YAML snapshot for next comparison
|
// Save YAML snapshot for next comparison
|
||||||
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
s.oldConfigYaml, _ = yaml.Marshal(cfg)
|
||||||
@@ -810,6 +860,13 @@ func (s *Server) UpdateClients(cfg *config.Config) {
|
|||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Server) SetWebsocketAuthChangeHandler(fn func(bool, bool)) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.wsAuthChanged = fn
|
||||||
|
}
|
||||||
|
|
||||||
// (management handlers moved to internal/api/handlers/management)
|
// (management handlers moved to internal/api/handlers/management)
|
||||||
|
|
||||||
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
// AuthMiddleware returns a Gin middleware handler that authenticates requests
|
||||||
|
|||||||
@@ -40,6 +40,9 @@ type Config struct {
|
|||||||
// QuotaExceeded defines the behavior when a quota is exceeded.
|
// QuotaExceeded defines the behavior when a quota is exceeded.
|
||||||
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
|
QuotaExceeded QuotaExceeded `yaml:"quota-exceeded" json:"quota-exceeded"`
|
||||||
|
|
||||||
|
// WebsocketAuth enables or disables authentication for the WebSocket API.
|
||||||
|
WebsocketAuth bool `yaml:"ws-auth" json:"ws-auth"`
|
||||||
|
|
||||||
// GlAPIKey is the API key for the generative language API.
|
// GlAPIKey is the API key for the generative language API.
|
||||||
GlAPIKey []string `yaml:"generative-language-api-key" json:"generative-language-api-key"`
|
GlAPIKey []string `yaml:"generative-language-api-key" json:"generative-language-api-key"`
|
||||||
|
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gin-gonic/gin"
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -23,7 +24,7 @@ func GinLogrusLogger() gin.HandlerFunc {
|
|||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
path := c.Request.URL.Path
|
path := c.Request.URL.Path
|
||||||
raw := c.Request.URL.RawQuery
|
raw := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||||
|
|
||||||
c.Next()
|
c.Next()
|
||||||
|
|
||||||
|
|||||||
348
internal/runtime/executor/aistudio_executor.go
Normal file
348
internal/runtime/executor/aistudio_executor.go
Normal file
@@ -0,0 +1,348 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
||||||
|
cliproxyauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
sdktranslator "github.com/router-for-me/CLIProxyAPI/v6/sdk/translator"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
// AistudioExecutor routes AI Studio requests through a websocket-backed transport.
|
||||||
|
type AistudioExecutor struct {
|
||||||
|
provider string
|
||||||
|
relay *wsrelay.Manager
|
||||||
|
cfg *config.Config
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewAistudioExecutor constructs a websocket executor for the provider name.
|
||||||
|
func NewAistudioExecutor(cfg *config.Config, provider string, relay *wsrelay.Manager) *AistudioExecutor {
|
||||||
|
return &AistudioExecutor{provider: strings.ToLower(provider), relay: relay, cfg: cfg}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Identifier returns the provider key served by this executor.
|
||||||
|
func (e *AistudioExecutor) Identifier() string { return e.provider }
|
||||||
|
|
||||||
|
// PrepareRequest is a no-op because websocket transport already injects headers.
|
||||||
|
func (e *AistudioExecutor) PrepareRequest(_ *http.Request, _ *cliproxyauth.Auth) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AistudioExecutor) Execute(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (resp cliproxyexecutor.Response, err error) {
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
|
translatedReq, body, err := e.translateRequest(req, opts, false)
|
||||||
|
if err != nil {
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
||||||
|
wsReq := &wsrelay.HTTPRequest{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
URL: endpoint,
|
||||||
|
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: body.payload,
|
||||||
|
}
|
||||||
|
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: endpoint,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: wsReq.Headers.Clone(),
|
||||||
|
Body: bytes.Clone(body.payload),
|
||||||
|
Provider: e.provider,
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
|
||||||
|
wsResp, err := e.relay.RoundTrip(ctx, e.provider, wsReq)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return resp, err
|
||||||
|
}
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, wsResp.Status, wsResp.Headers.Clone())
|
||||||
|
if len(wsResp.Body) > 0 {
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(wsResp.Body))
|
||||||
|
}
|
||||||
|
if wsResp.Status < 200 || wsResp.Status >= 300 {
|
||||||
|
return resp, statusErr{code: wsResp.Status, msg: string(wsResp.Body)}
|
||||||
|
}
|
||||||
|
reporter.publish(ctx, parseGeminiUsage(wsResp.Body))
|
||||||
|
var param any
|
||||||
|
out := sdktranslator.TranslateNonStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), bytes.Clone(translatedReq), bytes.Clone(wsResp.Body), ¶m)
|
||||||
|
resp = cliproxyexecutor.Response{Payload: []byte(out)}
|
||||||
|
return resp, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AistudioExecutor) ExecuteStream(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (stream <-chan cliproxyexecutor.StreamChunk, err error) {
|
||||||
|
reporter := newUsageReporter(ctx, e.Identifier(), req.Model, auth)
|
||||||
|
defer reporter.trackFailure(ctx, &err)
|
||||||
|
|
||||||
|
translatedReq, body, err := e.translateRequest(req, opts, true)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
endpoint := e.buildEndpoint(req.Model, body.action, opts.Alt)
|
||||||
|
wsReq := &wsrelay.HTTPRequest{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
URL: endpoint,
|
||||||
|
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: body.payload,
|
||||||
|
}
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: endpoint,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: wsReq.Headers.Clone(),
|
||||||
|
Body: bytes.Clone(body.payload),
|
||||||
|
Provider: e.provider,
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
wsStream, err := e.relay.Stream(ctx, e.provider, wsReq)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
stream = out
|
||||||
|
go func() {
|
||||||
|
defer close(out)
|
||||||
|
var param any
|
||||||
|
metadataLogged := false
|
||||||
|
for event := range wsStream {
|
||||||
|
if event.Err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, event.Err)
|
||||||
|
reporter.publishFailure(ctx)
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch event.Type {
|
||||||
|
case wsrelay.MessageTypeStreamStart:
|
||||||
|
if !metadataLogged && event.Status > 0 {
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||||
|
metadataLogged = true
|
||||||
|
}
|
||||||
|
case wsrelay.MessageTypeStreamChunk:
|
||||||
|
if len(event.Payload) > 0 {
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||||
|
filtered := filterAistudioUsageMetadata(event.Payload)
|
||||||
|
if detail, ok := parseGeminiStreamUsage(filtered); ok {
|
||||||
|
reporter.publish(ctx, detail)
|
||||||
|
}
|
||||||
|
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(filtered), ¶m)
|
||||||
|
for i := range lines {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||||
|
}
|
||||||
|
break
|
||||||
|
}
|
||||||
|
case wsrelay.MessageTypeStreamEnd:
|
||||||
|
return
|
||||||
|
case wsrelay.MessageTypeHTTPResp:
|
||||||
|
if !metadataLogged && event.Status > 0 {
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, event.Status, event.Headers.Clone())
|
||||||
|
metadataLogged = true
|
||||||
|
}
|
||||||
|
if len(event.Payload) > 0 {
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(event.Payload))
|
||||||
|
}
|
||||||
|
lines := sdktranslator.TranslateStream(ctx, body.toFormat, opts.SourceFormat, req.Model, bytes.Clone(opts.OriginalRequest), translatedReq, bytes.Clone(event.Payload), ¶m)
|
||||||
|
for i := range lines {
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Payload: []byte(lines[i])}
|
||||||
|
}
|
||||||
|
reporter.publish(ctx, parseGeminiUsage(event.Payload))
|
||||||
|
return
|
||||||
|
case wsrelay.MessageTypeError:
|
||||||
|
recordAPIResponseError(ctx, e.cfg, event.Err)
|
||||||
|
reporter.publishFailure(ctx)
|
||||||
|
out <- cliproxyexecutor.StreamChunk{Err: fmt.Errorf("wsrelay: %v", event.Err)}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return stream, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AistudioExecutor) CountTokens(ctx context.Context, auth *cliproxyauth.Auth, req cliproxyexecutor.Request, opts cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
_, body, err := e.translateRequest(req, opts, false)
|
||||||
|
if err != nil {
|
||||||
|
return cliproxyexecutor.Response{}, err
|
||||||
|
}
|
||||||
|
|
||||||
|
body.payload, _ = sjson.DeleteBytes(body.payload, "generationConfig")
|
||||||
|
body.payload, _ = sjson.DeleteBytes(body.payload, "tools")
|
||||||
|
|
||||||
|
endpoint := e.buildEndpoint(req.Model, "countTokens", "")
|
||||||
|
wsReq := &wsrelay.HTTPRequest{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
URL: endpoint,
|
||||||
|
Headers: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
Body: body.payload,
|
||||||
|
}
|
||||||
|
var authID, authLabel, authType, authValue string
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
authLabel = auth.Label
|
||||||
|
authType, authValue = auth.AccountInfo()
|
||||||
|
}
|
||||||
|
recordAPIRequest(ctx, e.cfg, upstreamRequestLog{
|
||||||
|
URL: endpoint,
|
||||||
|
Method: http.MethodPost,
|
||||||
|
Headers: wsReq.Headers.Clone(),
|
||||||
|
Body: bytes.Clone(body.payload),
|
||||||
|
Provider: e.provider,
|
||||||
|
AuthID: authID,
|
||||||
|
AuthLabel: authLabel,
|
||||||
|
AuthType: authType,
|
||||||
|
AuthValue: authValue,
|
||||||
|
})
|
||||||
|
resp, err := e.relay.RoundTrip(ctx, e.provider, wsReq)
|
||||||
|
if err != nil {
|
||||||
|
recordAPIResponseError(ctx, e.cfg, err)
|
||||||
|
return cliproxyexecutor.Response{}, err
|
||||||
|
}
|
||||||
|
recordAPIResponseMetadata(ctx, e.cfg, resp.Status, resp.Headers.Clone())
|
||||||
|
if len(resp.Body) > 0 {
|
||||||
|
appendAPIResponseChunk(ctx, e.cfg, bytes.Clone(resp.Body))
|
||||||
|
}
|
||||||
|
if resp.Status < 200 || resp.Status >= 300 {
|
||||||
|
return cliproxyexecutor.Response{}, statusErr{code: resp.Status, msg: string(resp.Body)}
|
||||||
|
}
|
||||||
|
totalTokens := gjson.GetBytes(resp.Body, "totalTokens").Int()
|
||||||
|
if totalTokens <= 0 {
|
||||||
|
return cliproxyexecutor.Response{}, fmt.Errorf("wsrelay: totalTokens missing in response")
|
||||||
|
}
|
||||||
|
translated := sdktranslator.TranslateTokenCount(ctx, body.toFormat, opts.SourceFormat, totalTokens, bytes.Clone(resp.Body))
|
||||||
|
return cliproxyexecutor.Response{Payload: []byte(translated)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AistudioExecutor) Refresh(ctx context.Context, auth *cliproxyauth.Auth) (*cliproxyauth.Auth, error) {
|
||||||
|
_ = ctx
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type translatedPayload struct {
|
||||||
|
payload []byte
|
||||||
|
action string
|
||||||
|
toFormat sdktranslator.Format
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AistudioExecutor) translateRequest(req cliproxyexecutor.Request, opts cliproxyexecutor.Options, stream bool) ([]byte, translatedPayload, error) {
|
||||||
|
from := opts.SourceFormat
|
||||||
|
to := sdktranslator.FromString("gemini")
|
||||||
|
payload := sdktranslator.TranslateRequest(from, to, req.Model, bytes.Clone(req.Payload), stream)
|
||||||
|
if budgetOverride, includeOverride, ok := util.GeminiThinkingFromMetadata(req.Metadata); ok {
|
||||||
|
payload = util.ApplyGeminiThinkingConfig(payload, budgetOverride, includeOverride)
|
||||||
|
}
|
||||||
|
payload = disableGeminiThinkingConfig(payload, req.Model)
|
||||||
|
payload = fixGeminiImageAspectRatio(req.Model, payload)
|
||||||
|
metadataAction := "generateContent"
|
||||||
|
if req.Metadata != nil {
|
||||||
|
if action, _ := req.Metadata["action"].(string); action == "countTokens" {
|
||||||
|
metadataAction = action
|
||||||
|
}
|
||||||
|
}
|
||||||
|
action := metadataAction
|
||||||
|
if stream && action != "countTokens" {
|
||||||
|
action = "streamGenerateContent"
|
||||||
|
}
|
||||||
|
payload, _ = sjson.DeleteBytes(payload, "session_id")
|
||||||
|
return payload, translatedPayload{payload: payload, action: action, toFormat: to}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *AistudioExecutor) buildEndpoint(model, action, alt string) string {
|
||||||
|
base := fmt.Sprintf("%s/%s/models/%s:%s", glEndpoint, glAPIVersion, model, action)
|
||||||
|
if action == "streamGenerateContent" {
|
||||||
|
if alt == "" {
|
||||||
|
return base + "?alt=sse"
|
||||||
|
}
|
||||||
|
return base + "?$alt=" + url.QueryEscape(alt)
|
||||||
|
}
|
||||||
|
if alt != "" && action != "countTokens" {
|
||||||
|
return base + "?$alt=" + url.QueryEscape(alt)
|
||||||
|
}
|
||||||
|
return base
|
||||||
|
}
|
||||||
|
|
||||||
|
// filterAistudioUsageMetadata removes usageMetadata from intermediate SSE events so that
|
||||||
|
// only the terminal chunk retains token statistics.
|
||||||
|
func filterAistudioUsageMetadata(payload []byte) []byte {
|
||||||
|
if len(payload) == 0 {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
|
||||||
|
lines := bytes.Split(payload, []byte("\n"))
|
||||||
|
modified := false
|
||||||
|
for idx, line := range lines {
|
||||||
|
trimmed := bytes.TrimSpace(line)
|
||||||
|
if len(trimmed) == 0 || !bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
dataIdx := bytes.Index(line, []byte("data:"))
|
||||||
|
if dataIdx < 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rawJSON := bytes.TrimSpace(line[dataIdx+5:])
|
||||||
|
cleaned, changed := stripUsageMetadataFromJSON(rawJSON)
|
||||||
|
if !changed {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
var rebuilt []byte
|
||||||
|
rebuilt = append(rebuilt, line[:dataIdx]...)
|
||||||
|
rebuilt = append(rebuilt, []byte("data:")...)
|
||||||
|
if len(cleaned) > 0 {
|
||||||
|
rebuilt = append(rebuilt, ' ')
|
||||||
|
rebuilt = append(rebuilt, cleaned...)
|
||||||
|
}
|
||||||
|
lines[idx] = rebuilt
|
||||||
|
modified = true
|
||||||
|
}
|
||||||
|
if !modified {
|
||||||
|
return payload
|
||||||
|
}
|
||||||
|
return bytes.Join(lines, []byte("\n"))
|
||||||
|
}
|
||||||
|
|
||||||
|
// stripUsageMetadataFromJSON drops usageMetadata when no finishReason is present.
|
||||||
|
func stripUsageMetadataFromJSON(rawJSON []byte) ([]byte, bool) {
|
||||||
|
jsonBytes := bytes.TrimSpace(rawJSON)
|
||||||
|
if len(jsonBytes) == 0 || !gjson.ValidBytes(jsonBytes) {
|
||||||
|
return rawJSON, false
|
||||||
|
}
|
||||||
|
finishReason := gjson.GetBytes(jsonBytes, "candidates.0.finishReason")
|
||||||
|
if finishReason.Exists() && finishReason.String() != "" {
|
||||||
|
return rawJSON, false
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(jsonBytes, "usageMetadata").Exists() {
|
||||||
|
return rawJSON, false
|
||||||
|
}
|
||||||
|
cleaned, err := sjson.DeleteBytes(jsonBytes, "usageMetadata")
|
||||||
|
if err != nil {
|
||||||
|
return rawJSON, false
|
||||||
|
}
|
||||||
|
return cleaned, true
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@
|
|||||||
package util
|
package util
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/config"
|
||||||
@@ -188,3 +189,56 @@ func MaskSensitiveHeaderValue(key, value string) string {
|
|||||||
return value
|
return value
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// MaskSensitiveQuery masks sensitive query parameters, e.g. auth_token, within the raw query string.
|
||||||
|
func MaskSensitiveQuery(raw string) string {
|
||||||
|
if raw == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
parts := strings.Split(raw, "&")
|
||||||
|
changed := false
|
||||||
|
for i, part := range parts {
|
||||||
|
if part == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
keyPart := part
|
||||||
|
valuePart := ""
|
||||||
|
if idx := strings.Index(part, "="); idx >= 0 {
|
||||||
|
keyPart = part[:idx]
|
||||||
|
valuePart = part[idx+1:]
|
||||||
|
}
|
||||||
|
decodedKey, err := url.QueryUnescape(keyPart)
|
||||||
|
if err != nil {
|
||||||
|
decodedKey = keyPart
|
||||||
|
}
|
||||||
|
if !shouldMaskQueryParam(decodedKey) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
decodedValue, err := url.QueryUnescape(valuePart)
|
||||||
|
if err != nil {
|
||||||
|
decodedValue = valuePart
|
||||||
|
}
|
||||||
|
masked := HideAPIKey(strings.TrimSpace(decodedValue))
|
||||||
|
parts[i] = keyPart + "=" + url.QueryEscape(masked)
|
||||||
|
changed = true
|
||||||
|
}
|
||||||
|
if !changed {
|
||||||
|
return raw
|
||||||
|
}
|
||||||
|
return strings.Join(parts, "&")
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldMaskQueryParam(key string) bool {
|
||||||
|
key = strings.ToLower(strings.TrimSpace(key))
|
||||||
|
if key == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
key = strings.TrimSuffix(key, "[]")
|
||||||
|
if key == "key" || strings.Contains(key, "api-key") || strings.Contains(key, "apikey") || strings.Contains(key, "api_key") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if strings.Contains(key, "token") || strings.Contains(key, "secret") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -1204,6 +1204,9 @@ func buildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if oldCfg.ProxyURL != newCfg.ProxyURL {
|
if oldCfg.ProxyURL != newCfg.ProxyURL {
|
||||||
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
|
changes = append(changes, fmt.Sprintf("proxy-url: %s -> %s", oldCfg.ProxyURL, newCfg.ProxyURL))
|
||||||
}
|
}
|
||||||
|
if oldCfg.WebsocketAuth != newCfg.WebsocketAuth {
|
||||||
|
changes = append(changes, fmt.Sprintf("ws-auth: %t -> %t", oldCfg.WebsocketAuth, newCfg.WebsocketAuth))
|
||||||
|
}
|
||||||
|
|
||||||
// Quota-exceeded behavior
|
// Quota-exceeded behavior
|
||||||
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
if oldCfg.QuotaExceeded.SwitchProject != newCfg.QuotaExceeded.SwitchProject {
|
||||||
|
|||||||
233
internal/wsrelay/http.go
Normal file
233
internal/wsrelay/http.go
Normal file
@@ -0,0 +1,233 @@
|
|||||||
|
package wsrelay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
)
|
||||||
|
|
||||||
|
// HTTPRequest represents a proxied HTTP request delivered to websocket clients.
|
||||||
|
type HTTPRequest struct {
|
||||||
|
Method string
|
||||||
|
URL string
|
||||||
|
Headers http.Header
|
||||||
|
Body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// HTTPResponse captures the response relayed back from websocket clients.
|
||||||
|
type HTTPResponse struct {
|
||||||
|
Status int
|
||||||
|
Headers http.Header
|
||||||
|
Body []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// StreamEvent represents a streaming response event from clients.
|
||||||
|
type StreamEvent struct {
|
||||||
|
Type string
|
||||||
|
Payload []byte
|
||||||
|
Status int
|
||||||
|
Headers http.Header
|
||||||
|
Err error
|
||||||
|
}
|
||||||
|
|
||||||
|
// RoundTrip executes a non-streaming HTTP request using the websocket provider.
|
||||||
|
func (m *Manager) RoundTrip(ctx context.Context, provider string, req *HTTPRequest) (*HTTPResponse, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("wsrelay: request is nil")
|
||||||
|
}
|
||||||
|
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
|
||||||
|
respCh, err := m.Send(ctx, provider, msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
var (
|
||||||
|
streamMode bool
|
||||||
|
streamResp *HTTPResponse
|
||||||
|
streamBody bytes.Buffer
|
||||||
|
)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case msg, ok := <-respCh:
|
||||||
|
if !ok {
|
||||||
|
if streamMode {
|
||||||
|
if streamResp == nil {
|
||||||
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||||
|
} else if streamResp.Headers == nil {
|
||||||
|
streamResp.Headers = make(http.Header)
|
||||||
|
}
|
||||||
|
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
|
||||||
|
return streamResp, nil
|
||||||
|
}
|
||||||
|
return nil, errors.New("wsrelay: connection closed during response")
|
||||||
|
}
|
||||||
|
switch msg.Type {
|
||||||
|
case MessageTypeHTTPResp:
|
||||||
|
resp := decodeResponse(msg.Payload)
|
||||||
|
if streamMode && streamBody.Len() > 0 && len(resp.Body) == 0 {
|
||||||
|
resp.Body = append(resp.Body[:0], streamBody.Bytes()...)
|
||||||
|
}
|
||||||
|
return resp, nil
|
||||||
|
case MessageTypeError:
|
||||||
|
return nil, decodeError(msg.Payload)
|
||||||
|
case MessageTypeStreamStart, MessageTypeStreamChunk:
|
||||||
|
if msg.Type == MessageTypeStreamStart {
|
||||||
|
streamMode = true
|
||||||
|
streamResp = decodeResponse(msg.Payload)
|
||||||
|
if streamResp.Headers == nil {
|
||||||
|
streamResp.Headers = make(http.Header)
|
||||||
|
}
|
||||||
|
streamBody.Reset()
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !streamMode {
|
||||||
|
streamMode = true
|
||||||
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||||
|
}
|
||||||
|
chunk := decodeChunk(msg.Payload)
|
||||||
|
if len(chunk) > 0 {
|
||||||
|
streamBody.Write(chunk)
|
||||||
|
}
|
||||||
|
case MessageTypeStreamEnd:
|
||||||
|
if !streamMode {
|
||||||
|
return &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}, nil
|
||||||
|
}
|
||||||
|
if streamResp == nil {
|
||||||
|
streamResp = &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||||
|
} else if streamResp.Headers == nil {
|
||||||
|
streamResp.Headers = make(http.Header)
|
||||||
|
}
|
||||||
|
streamResp.Body = append(streamResp.Body[:0], streamBody.Bytes()...)
|
||||||
|
return streamResp, nil
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stream executes a streaming HTTP request and returns channel with stream events.
|
||||||
|
func (m *Manager) Stream(ctx context.Context, provider string, req *HTTPRequest) (<-chan StreamEvent, error) {
|
||||||
|
if req == nil {
|
||||||
|
return nil, fmt.Errorf("wsrelay: request is nil")
|
||||||
|
}
|
||||||
|
msg := Message{ID: uuid.NewString(), Type: MessageTypeHTTPReq, Payload: encodeRequest(req)}
|
||||||
|
respCh, err := m.Send(ctx, provider, msg)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
out := make(chan StreamEvent)
|
||||||
|
go func() {
|
||||||
|
defer close(out)
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
out <- StreamEvent{Err: ctx.Err()}
|
||||||
|
return
|
||||||
|
case msg, ok := <-respCh:
|
||||||
|
if !ok {
|
||||||
|
out <- StreamEvent{Err: errors.New("wsrelay: stream closed")}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch msg.Type {
|
||||||
|
case MessageTypeStreamStart:
|
||||||
|
resp := decodeResponse(msg.Payload)
|
||||||
|
out <- StreamEvent{Type: MessageTypeStreamStart, Status: resp.Status, Headers: resp.Headers}
|
||||||
|
case MessageTypeStreamChunk:
|
||||||
|
chunk := decodeChunk(msg.Payload)
|
||||||
|
out <- StreamEvent{Type: MessageTypeStreamChunk, Payload: chunk}
|
||||||
|
case MessageTypeStreamEnd:
|
||||||
|
out <- StreamEvent{Type: MessageTypeStreamEnd}
|
||||||
|
return
|
||||||
|
case MessageTypeError:
|
||||||
|
out <- StreamEvent{Type: MessageTypeError, Err: decodeError(msg.Payload)}
|
||||||
|
return
|
||||||
|
case MessageTypeHTTPResp:
|
||||||
|
resp := decodeResponse(msg.Payload)
|
||||||
|
out <- StreamEvent{Type: MessageTypeHTTPResp, Status: resp.Status, Headers: resp.Headers, Payload: resp.Body}
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func encodeRequest(req *HTTPRequest) map[string]any {
|
||||||
|
headers := make(map[string]any, len(req.Headers))
|
||||||
|
for key, values := range req.Headers {
|
||||||
|
copyValues := make([]string, len(values))
|
||||||
|
copy(copyValues, values)
|
||||||
|
headers[key] = copyValues
|
||||||
|
}
|
||||||
|
return map[string]any{
|
||||||
|
"method": req.Method,
|
||||||
|
"url": req.URL,
|
||||||
|
"headers": headers,
|
||||||
|
"body": string(req.Body),
|
||||||
|
"sent_at": time.Now().UTC().Format(time.RFC3339Nano),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeResponse(payload map[string]any) *HTTPResponse {
|
||||||
|
if payload == nil {
|
||||||
|
return &HTTPResponse{Status: http.StatusBadGateway, Headers: make(http.Header)}
|
||||||
|
}
|
||||||
|
resp := &HTTPResponse{Status: http.StatusOK, Headers: make(http.Header)}
|
||||||
|
if status, ok := payload["status"].(float64); ok {
|
||||||
|
resp.Status = int(status)
|
||||||
|
}
|
||||||
|
if headers, ok := payload["headers"].(map[string]any); ok {
|
||||||
|
for key, raw := range headers {
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case []any:
|
||||||
|
for _, item := range v {
|
||||||
|
if str, ok := item.(string); ok {
|
||||||
|
resp.Headers.Add(key, str)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
case []string:
|
||||||
|
for _, str := range v {
|
||||||
|
resp.Headers.Add(key, str)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
resp.Headers.Set(key, v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if body, ok := payload["body"].(string); ok {
|
||||||
|
resp.Body = []byte(body)
|
||||||
|
}
|
||||||
|
return resp
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeChunk(payload map[string]any) []byte {
|
||||||
|
if payload == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if data, ok := payload["data"].(string); ok {
|
||||||
|
return []byte(data)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func decodeError(payload map[string]any) error {
|
||||||
|
if payload == nil {
|
||||||
|
return errors.New("wsrelay: unknown error")
|
||||||
|
}
|
||||||
|
message, _ := payload["error"].(string)
|
||||||
|
status := 0
|
||||||
|
if v, ok := payload["status"].(float64); ok {
|
||||||
|
status = int(v)
|
||||||
|
}
|
||||||
|
if message == "" {
|
||||||
|
message = "wsrelay: upstream error"
|
||||||
|
}
|
||||||
|
return fmt.Errorf("%s (status=%d)", message, status)
|
||||||
|
}
|
||||||
205
internal/wsrelay/manager.go
Normal file
205
internal/wsrelay/manager.go
Normal file
@@ -0,0 +1,205 @@
|
|||||||
|
package wsrelay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Manager exposes a websocket endpoint that proxies Gemini requests to
|
||||||
|
// connected clients.
|
||||||
|
type Manager struct {
|
||||||
|
path string
|
||||||
|
upgrader websocket.Upgrader
|
||||||
|
sessions map[string]*session
|
||||||
|
sessMutex sync.RWMutex
|
||||||
|
|
||||||
|
providerFactory func(*http.Request) (string, error)
|
||||||
|
onConnected func(string)
|
||||||
|
onDisconnected func(string, error)
|
||||||
|
|
||||||
|
logDebugf func(string, ...any)
|
||||||
|
logInfof func(string, ...any)
|
||||||
|
logWarnf func(string, ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options configures a Manager instance.
|
||||||
|
type Options struct {
|
||||||
|
Path string
|
||||||
|
ProviderFactory func(*http.Request) (string, error)
|
||||||
|
OnConnected func(string)
|
||||||
|
OnDisconnected func(string, error)
|
||||||
|
LogDebugf func(string, ...any)
|
||||||
|
LogInfof func(string, ...any)
|
||||||
|
LogWarnf func(string, ...any)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager builds a websocket relay manager with the supplied options.
|
||||||
|
func NewManager(opts Options) *Manager {
|
||||||
|
path := strings.TrimSpace(opts.Path)
|
||||||
|
if path == "" {
|
||||||
|
path = "/v1/ws"
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(path, "/") {
|
||||||
|
path = "/" + path
|
||||||
|
}
|
||||||
|
mgr := &Manager{
|
||||||
|
path: path,
|
||||||
|
sessions: make(map[string]*session),
|
||||||
|
upgrader: websocket.Upgrader{
|
||||||
|
ReadBufferSize: 1024,
|
||||||
|
WriteBufferSize: 1024,
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
},
|
||||||
|
providerFactory: opts.ProviderFactory,
|
||||||
|
onConnected: opts.OnConnected,
|
||||||
|
onDisconnected: opts.OnDisconnected,
|
||||||
|
logDebugf: opts.LogDebugf,
|
||||||
|
logInfof: opts.LogInfof,
|
||||||
|
logWarnf: opts.LogWarnf,
|
||||||
|
}
|
||||||
|
if mgr.logDebugf == nil {
|
||||||
|
mgr.logDebugf = func(string, ...any) {}
|
||||||
|
}
|
||||||
|
if mgr.logInfof == nil {
|
||||||
|
mgr.logInfof = func(string, ...any) {}
|
||||||
|
}
|
||||||
|
if mgr.logWarnf == nil {
|
||||||
|
mgr.logWarnf = func(s string, args ...any) { fmt.Printf(s+"\n", args...) }
|
||||||
|
}
|
||||||
|
return mgr
|
||||||
|
}
|
||||||
|
|
||||||
|
// Path returns the HTTP path the manager expects for websocket upgrades.
|
||||||
|
func (m *Manager) Path() string {
|
||||||
|
if m == nil {
|
||||||
|
return "/v1/ws"
|
||||||
|
}
|
||||||
|
return m.path
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handler exposes an http.Handler that upgrades connections to websocket sessions.
|
||||||
|
func (m *Manager) Handler() http.Handler {
|
||||||
|
return http.HandlerFunc(m.handleWebsocket)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully closes all active websocket sessions.
|
||||||
|
func (m *Manager) Stop(_ context.Context) error {
|
||||||
|
m.sessMutex.Lock()
|
||||||
|
sessions := make([]*session, 0, len(m.sessions))
|
||||||
|
for _, sess := range m.sessions {
|
||||||
|
sessions = append(sessions, sess)
|
||||||
|
}
|
||||||
|
m.sessions = make(map[string]*session)
|
||||||
|
m.sessMutex.Unlock()
|
||||||
|
|
||||||
|
for _, sess := range sessions {
|
||||||
|
if sess != nil {
|
||||||
|
sess.cleanup(errors.New("wsrelay: manager stopped"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleWebsocket upgrades the connection and wires the session into the pool.
|
||||||
|
func (m *Manager) handleWebsocket(w http.ResponseWriter, r *http.Request) {
|
||||||
|
expectedPath := m.Path()
|
||||||
|
if expectedPath != "" && r.URL != nil && r.URL.Path != expectedPath {
|
||||||
|
http.NotFound(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(r.Method, http.MethodGet) {
|
||||||
|
w.Header().Set("Allow", http.MethodGet)
|
||||||
|
http.Error(w, "method not allowed", http.StatusMethodNotAllowed)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
conn, err := m.upgrader.Upgrade(w, r, nil)
|
||||||
|
if err != nil {
|
||||||
|
m.logWarnf("wsrelay: upgrade failed: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s := newSession(conn, m, randomProviderName())
|
||||||
|
if m.providerFactory != nil {
|
||||||
|
name, err := m.providerFactory(r)
|
||||||
|
if err != nil {
|
||||||
|
s.cleanup(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.TrimSpace(name) != "" {
|
||||||
|
s.provider = strings.ToLower(name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if s.provider == "" {
|
||||||
|
s.provider = strings.ToLower(s.id)
|
||||||
|
}
|
||||||
|
m.sessMutex.Lock()
|
||||||
|
var replaced *session
|
||||||
|
if existing, ok := m.sessions[s.provider]; ok {
|
||||||
|
replaced = existing
|
||||||
|
}
|
||||||
|
m.sessions[s.provider] = s
|
||||||
|
m.sessMutex.Unlock()
|
||||||
|
|
||||||
|
if replaced != nil {
|
||||||
|
replaced.cleanup(errors.New("replaced by new connection"))
|
||||||
|
}
|
||||||
|
if m.onConnected != nil {
|
||||||
|
m.onConnected(s.provider)
|
||||||
|
}
|
||||||
|
|
||||||
|
go s.run(context.Background())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send forwards the message to the specific provider connection and returns a channel
|
||||||
|
// yielding response messages.
|
||||||
|
func (m *Manager) Send(ctx context.Context, provider string, msg Message) (<-chan Message, error) {
|
||||||
|
s := m.session(provider)
|
||||||
|
if s == nil {
|
||||||
|
return nil, fmt.Errorf("wsrelay: provider %s not connected", provider)
|
||||||
|
}
|
||||||
|
return s.request(ctx, msg)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) session(provider string) *session {
|
||||||
|
key := strings.ToLower(strings.TrimSpace(provider))
|
||||||
|
m.sessMutex.RLock()
|
||||||
|
s := m.sessions[key]
|
||||||
|
m.sessMutex.RUnlock()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleSessionClosed(s *session, cause error) {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
key := strings.ToLower(strings.TrimSpace(s.provider))
|
||||||
|
m.sessMutex.Lock()
|
||||||
|
if cur, ok := m.sessions[key]; ok && cur == s {
|
||||||
|
delete(m.sessions, key)
|
||||||
|
}
|
||||||
|
m.sessMutex.Unlock()
|
||||||
|
if m.onDisconnected != nil {
|
||||||
|
m.onDisconnected(s.provider, cause)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func randomProviderName() string {
|
||||||
|
const alphabet = "abcdefghijklmnopqrstuvwxyz0123456789"
|
||||||
|
buf := make([]byte, 16)
|
||||||
|
if _, err := rand.Read(buf); err != nil {
|
||||||
|
return fmt.Sprintf("aistudio-%x", time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
for i := range buf {
|
||||||
|
buf[i] = alphabet[int(buf[i])%len(alphabet)]
|
||||||
|
}
|
||||||
|
return "aistudio-" + string(buf)
|
||||||
|
}
|
||||||
27
internal/wsrelay/message.go
Normal file
27
internal/wsrelay/message.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package wsrelay
|
||||||
|
|
||||||
|
// Message represents the JSON payload exchanged with websocket clients.
|
||||||
|
type Message struct {
|
||||||
|
ID string `json:"id"`
|
||||||
|
Type string `json:"type"`
|
||||||
|
Payload map[string]any `json:"payload,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MessageTypeHTTPReq identifies an HTTP-style request envelope.
|
||||||
|
MessageTypeHTTPReq = "http_request"
|
||||||
|
// MessageTypeHTTPResp identifies a non-streaming HTTP response envelope.
|
||||||
|
MessageTypeHTTPResp = "http_response"
|
||||||
|
// MessageTypeStreamStart marks the beginning of a streaming response.
|
||||||
|
MessageTypeStreamStart = "stream_start"
|
||||||
|
// MessageTypeStreamChunk carries a streaming response chunk.
|
||||||
|
MessageTypeStreamChunk = "stream_chunk"
|
||||||
|
// MessageTypeStreamEnd marks the completion of a streaming response.
|
||||||
|
MessageTypeStreamEnd = "stream_end"
|
||||||
|
// MessageTypeError carries an error response.
|
||||||
|
MessageTypeError = "error"
|
||||||
|
// MessageTypePing represents ping messages from clients.
|
||||||
|
MessageTypePing = "ping"
|
||||||
|
// MessageTypePong represents pong responses back to clients.
|
||||||
|
MessageTypePong = "pong"
|
||||||
|
)
|
||||||
188
internal/wsrelay/session.go
Normal file
188
internal/wsrelay/session.go
Normal file
@@ -0,0 +1,188 @@
|
|||||||
|
package wsrelay
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
readTimeout = 60 * time.Second
|
||||||
|
writeTimeout = 10 * time.Second
|
||||||
|
maxInboundMessageLen = 64 << 20 // 64 MiB
|
||||||
|
heartbeatInterval = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
var errClosed = errors.New("websocket session closed")
|
||||||
|
|
||||||
|
type pendingRequest struct {
|
||||||
|
ch chan Message
|
||||||
|
closeOnce sync.Once
|
||||||
|
}
|
||||||
|
|
||||||
|
func (pr *pendingRequest) close() {
|
||||||
|
if pr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
pr.closeOnce.Do(func() {
|
||||||
|
close(pr.ch)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
type session struct {
|
||||||
|
conn *websocket.Conn
|
||||||
|
manager *Manager
|
||||||
|
provider string
|
||||||
|
id string
|
||||||
|
closed chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
writeMutex sync.Mutex
|
||||||
|
pending sync.Map // map[string]*pendingRequest
|
||||||
|
}
|
||||||
|
|
||||||
|
func newSession(conn *websocket.Conn, mgr *Manager, id string) *session {
|
||||||
|
s := &session{
|
||||||
|
conn: conn,
|
||||||
|
manager: mgr,
|
||||||
|
provider: "",
|
||||||
|
id: id,
|
||||||
|
closed: make(chan struct{}),
|
||||||
|
}
|
||||||
|
conn.SetReadLimit(maxInboundMessageLen)
|
||||||
|
conn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||||
|
conn.SetPongHandler(func(string) error {
|
||||||
|
conn.SetReadDeadline(time.Now().Add(readTimeout))
|
||||||
|
return nil
|
||||||
|
})
|
||||||
|
s.startHeartbeat()
|
||||||
|
return s
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) startHeartbeat() {
|
||||||
|
if s == nil || s.conn == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ticker := time.NewTicker(heartbeatInterval)
|
||||||
|
go func() {
|
||||||
|
defer ticker.Stop()
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-s.closed:
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
s.writeMutex.Lock()
|
||||||
|
err := s.conn.WriteControl(websocket.PingMessage, []byte("ping"), time.Now().Add(writeTimeout))
|
||||||
|
s.writeMutex.Unlock()
|
||||||
|
if err != nil {
|
||||||
|
s.cleanup(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) run(ctx context.Context) {
|
||||||
|
defer s.cleanup(errClosed)
|
||||||
|
for {
|
||||||
|
var msg Message
|
||||||
|
if err := s.conn.ReadJSON(&msg); err != nil {
|
||||||
|
s.cleanup(err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
s.dispatch(msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) dispatch(msg Message) {
|
||||||
|
if msg.Type == MessageTypePing {
|
||||||
|
_ = s.send(context.Background(), Message{ID: msg.ID, Type: MessageTypePong})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if value, ok := s.pending.Load(msg.ID); ok {
|
||||||
|
req := value.(*pendingRequest)
|
||||||
|
select {
|
||||||
|
case req.ch <- msg:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
|
||||||
|
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
||||||
|
actual.(*pendingRequest).close()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msg.Type == MessageTypeHTTPResp || msg.Type == MessageTypeError || msg.Type == MessageTypeStreamEnd {
|
||||||
|
s.manager.logDebugf("wsrelay: received terminal message for unknown id %s (provider=%s)", msg.ID, s.provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) send(ctx context.Context, msg Message) error {
|
||||||
|
select {
|
||||||
|
case <-s.closed:
|
||||||
|
return errClosed
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
s.writeMutex.Lock()
|
||||||
|
defer s.writeMutex.Unlock()
|
||||||
|
if err := s.conn.SetWriteDeadline(time.Now().Add(writeTimeout)); err != nil {
|
||||||
|
return fmt.Errorf("set write deadline: %w", err)
|
||||||
|
}
|
||||||
|
if err := s.conn.WriteJSON(msg); err != nil {
|
||||||
|
return fmt.Errorf("write json: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) request(ctx context.Context, msg Message) (<-chan Message, error) {
|
||||||
|
if msg.ID == "" {
|
||||||
|
return nil, fmt.Errorf("wsrelay: message id is required")
|
||||||
|
}
|
||||||
|
if _, loaded := s.pending.LoadOrStore(msg.ID, &pendingRequest{ch: make(chan Message, 8)}); loaded {
|
||||||
|
return nil, fmt.Errorf("wsrelay: duplicate message id %s", msg.ID)
|
||||||
|
}
|
||||||
|
value, _ := s.pending.Load(msg.ID)
|
||||||
|
req := value.(*pendingRequest)
|
||||||
|
if err := s.send(ctx, msg); err != nil {
|
||||||
|
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
||||||
|
req := actual.(*pendingRequest)
|
||||||
|
req.close()
|
||||||
|
}
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
go func() {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
if actual, loaded := s.pending.LoadAndDelete(msg.ID); loaded {
|
||||||
|
actual.(*pendingRequest).close()
|
||||||
|
}
|
||||||
|
case <-s.closed:
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
return req.ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *session) cleanup(cause error) {
|
||||||
|
s.closeOnce.Do(func() {
|
||||||
|
close(s.closed)
|
||||||
|
s.pending.Range(func(key, value any) bool {
|
||||||
|
req := value.(*pendingRequest)
|
||||||
|
msg := Message{ID: key.(string), Type: MessageTypeError, Payload: map[string]any{"error": cause.Error()}}
|
||||||
|
select {
|
||||||
|
case req.ch <- msg:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
req.close()
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
s.pending = sync.Map{}
|
||||||
|
_ = s.conn.Close()
|
||||||
|
if s.manager != nil {
|
||||||
|
s.manager.handleSessionClosed(s, cause)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -153,6 +153,17 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
|
|||||||
m.executors[executor.Identifier()] = executor
|
m.executors[executor.Identifier()] = executor
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UnregisterExecutor removes the executor associated with the provider key.
|
||||||
|
func (m *Manager) UnregisterExecutor(provider string) {
|
||||||
|
provider = strings.ToLower(strings.TrimSpace(provider))
|
||||||
|
if provider == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
m.mu.Lock()
|
||||||
|
delete(m.executors, provider)
|
||||||
|
m.mu.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
// Register inserts a new auth entry into the manager.
|
// Register inserts a new auth entry into the manager.
|
||||||
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
func (m *Manager) Register(ctx context.Context, auth *Auth) (*Auth, error) {
|
||||||
if auth == nil {
|
if auth == nil {
|
||||||
|
|||||||
@@ -156,7 +156,17 @@ func (a *Auth) AccountInfo() (string, string) {
|
|||||||
if v, ok := a.Metadata["email"].(string); ok {
|
if v, ok := a.Metadata["email"].(string); ok {
|
||||||
return "oauth", v
|
return "oauth", v
|
||||||
}
|
}
|
||||||
} else if a.Attributes != nil {
|
}
|
||||||
|
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") {
|
||||||
|
if label := strings.TrimSpace(a.Label); label != "" {
|
||||||
|
return "oauth", label
|
||||||
|
}
|
||||||
|
if id := strings.TrimSpace(a.ID); id != "" {
|
||||||
|
return "oauth", id
|
||||||
|
}
|
||||||
|
return "oauth", "aistudio"
|
||||||
|
}
|
||||||
|
if a.Attributes != nil {
|
||||||
if v := a.Attributes["api_key"]; v != "" {
|
if v := a.Attributes["api_key"]; v != "" {
|
||||||
return "api_key", v
|
return "api_key", v
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/runtime/executor"
|
||||||
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
_ "github.com/router-for-me/CLIProxyAPI/v6/internal/usage"
|
||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/watcher"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/wsrelay"
|
||||||
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
sdkaccess "github.com/router-for-me/CLIProxyAPI/v6/sdk/access"
|
||||||
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
sdkAuth "github.com/router-for-me/CLIProxyAPI/v6/sdk/auth"
|
||||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
@@ -82,6 +83,9 @@ type Service struct {
|
|||||||
|
|
||||||
// shutdownOnce ensures shutdown is called only once.
|
// shutdownOnce ensures shutdown is called only once.
|
||||||
shutdownOnce sync.Once
|
shutdownOnce sync.Once
|
||||||
|
|
||||||
|
// wsGateway manages websocket Gemini providers.
|
||||||
|
wsGateway *wsrelay.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// RegisterUsagePlugin registers a usage plugin on the global usage manager.
|
// RegisterUsagePlugin registers a usage plugin on the global usage manager.
|
||||||
@@ -172,6 +176,72 @@ func (s *Service) handleAuthUpdate(ctx context.Context, update watcher.AuthUpdat
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *Service) ensureWebsocketGateway() {
|
||||||
|
if s == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.wsGateway != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
opts := wsrelay.Options{
|
||||||
|
Path: "/v1/ws",
|
||||||
|
OnConnected: s.wsOnConnected,
|
||||||
|
OnDisconnected: s.wsOnDisconnected,
|
||||||
|
LogDebugf: log.Debugf,
|
||||||
|
LogInfof: log.Infof,
|
||||||
|
LogWarnf: log.Warnf,
|
||||||
|
}
|
||||||
|
s.wsGateway = wsrelay.NewManager(opts)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) wsOnConnected(provider string) {
|
||||||
|
if s == nil || provider == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !strings.HasPrefix(strings.ToLower(provider), "aistudio-") {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if s.coreManager != nil {
|
||||||
|
if existing, ok := s.coreManager.GetByID(provider); ok && existing != nil {
|
||||||
|
if !existing.Disabled && existing.Status == coreauth.StatusActive {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
now := time.Now().UTC()
|
||||||
|
auth := &coreauth.Auth{
|
||||||
|
ID: provider,
|
||||||
|
Provider: provider,
|
||||||
|
Label: provider,
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
CreatedAt: now,
|
||||||
|
UpdatedAt: now,
|
||||||
|
Attributes: map[string]string{"ws_provider": "gemini"},
|
||||||
|
}
|
||||||
|
log.Infof("websocket provider connected: %s", provider)
|
||||||
|
s.applyCoreAuthAddOrUpdate(context.Background(), auth)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) wsOnDisconnected(provider string, reason error) {
|
||||||
|
if s == nil || provider == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if reason != nil {
|
||||||
|
if strings.Contains(reason.Error(), "replaced by new connection") {
|
||||||
|
log.Infof("websocket provider replaced: %s", provider)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Warnf("websocket provider disconnected: %s (%v)", provider, reason)
|
||||||
|
} else {
|
||||||
|
log.Infof("websocket provider disconnected: %s", provider)
|
||||||
|
}
|
||||||
|
ctx := context.Background()
|
||||||
|
s.applyCoreAuthRemoval(ctx, provider)
|
||||||
|
if s.coreManager != nil {
|
||||||
|
s.coreManager.UnregisterExecutor(provider)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
|
func (s *Service) applyCoreAuthAddOrUpdate(ctx context.Context, auth *coreauth.Auth) {
|
||||||
if s == nil || auth == nil || auth.ID == "" {
|
if s == nil || auth == nil || auth.ID == "" {
|
||||||
return
|
return
|
||||||
@@ -247,6 +317,12 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
|||||||
s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(compatProviderKey, s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewOpenAICompatExecutor(compatProviderKey, s.cfg))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
if strings.HasPrefix(strings.ToLower(strings.TrimSpace(a.Provider)), "aistudio-") {
|
||||||
|
if s.wsGateway != nil {
|
||||||
|
s.coreManager.RegisterExecutor(executor.NewAistudioExecutor(s.cfg, a.Provider, s.wsGateway))
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
switch strings.ToLower(a.Provider) {
|
switch strings.ToLower(a.Provider) {
|
||||||
case "gemini":
|
case "gemini":
|
||||||
s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewGeminiExecutor(s.cfg))
|
||||||
@@ -342,6 +418,27 @@ func (s *Service) Run(ctx context.Context) error {
|
|||||||
s.authManager = newDefaultAuthManager()
|
s.authManager = newDefaultAuthManager()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.ensureWebsocketGateway()
|
||||||
|
if s.server != nil && s.wsGateway != nil {
|
||||||
|
s.server.AttachWebsocketRoute(s.wsGateway.Path(), s.wsGateway.Handler())
|
||||||
|
s.server.SetWebsocketAuthChangeHandler(func(oldEnabled, newEnabled bool) {
|
||||||
|
if oldEnabled == newEnabled {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !oldEnabled && newEnabled {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if errStop := s.wsGateway.Stop(ctx); errStop != nil {
|
||||||
|
log.Warnf("failed to reset websocket connections after ws-auth change %t -> %t: %v", oldEnabled, newEnabled, errStop)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("ws-auth enabled; existing websocket sessions terminated to enforce authentication")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Debugf("ws-auth disabled; existing websocket sessions remain connected")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
if s.hooks.OnBeforeStart != nil {
|
if s.hooks.OnBeforeStart != nil {
|
||||||
s.hooks.OnBeforeStart(s.cfg)
|
s.hooks.OnBeforeStart(s.cfg)
|
||||||
}
|
}
|
||||||
@@ -379,7 +476,6 @@ func (s *Service) Run(ctx context.Context) error {
|
|||||||
s.cfg = newCfg
|
s.cfg = newCfg
|
||||||
s.cfgMu.Unlock()
|
s.cfgMu.Unlock()
|
||||||
s.rebindExecutors()
|
s.rebindExecutors()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback)
|
watcherWrapper, err = s.watcherFactory(s.configPath, s.cfg.AuthDir, reloadCallback)
|
||||||
@@ -449,6 +545,14 @@ func (s *Service) Shutdown(ctx context.Context) error {
|
|||||||
shutdownErr = err
|
shutdownErr = err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if s.wsGateway != nil {
|
||||||
|
if err := s.wsGateway.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to stop websocket gateway: %v", err)
|
||||||
|
if shutdownErr == nil {
|
||||||
|
shutdownErr = err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
if s.authQueueStop != nil {
|
if s.authQueueStop != nil {
|
||||||
s.authQueueStop()
|
s.authQueueStop()
|
||||||
s.authQueueStop = nil
|
s.authQueueStop = nil
|
||||||
@@ -505,6 +609,13 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
}
|
}
|
||||||
provider := strings.ToLower(strings.TrimSpace(a.Provider))
|
provider := strings.ToLower(strings.TrimSpace(a.Provider))
|
||||||
compatProviderKey, compatDisplayName, compatDetected := openAICompatInfoFromAuth(a)
|
compatProviderKey, compatDisplayName, compatDetected := openAICompatInfoFromAuth(a)
|
||||||
|
if a.Attributes != nil {
|
||||||
|
if strings.EqualFold(a.Attributes["ws_provider"], "gemini") {
|
||||||
|
models := mergeGeminiModels()
|
||||||
|
GlobalModelRegistry().RegisterClient(a.ID, provider, models)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
if compatDetected {
|
if compatDetected {
|
||||||
provider = "openai-compatibility"
|
provider = "openai-compatibility"
|
||||||
}
|
}
|
||||||
@@ -611,3 +722,24 @@ func (s *Service) registerModelsForAuth(a *coreauth.Auth) {
|
|||||||
GlobalModelRegistry().RegisterClient(a.ID, key, models)
|
GlobalModelRegistry().RegisterClient(a.ID, key, models)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func mergeGeminiModels() []*ModelInfo {
|
||||||
|
models := make([]*ModelInfo, 0, 16)
|
||||||
|
seen := make(map[string]struct{})
|
||||||
|
appendModels := func(items []*ModelInfo) {
|
||||||
|
for i := range items {
|
||||||
|
m := items[i]
|
||||||
|
if m == nil || m.ID == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if _, ok := seen[m.ID]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
seen[m.ID] = struct{}{}
|
||||||
|
models = append(models, m)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
appendModels(registry.GetGeminiModels())
|
||||||
|
appendModels(registry.GetGeminiCLIModels())
|
||||||
|
return models
|
||||||
|
}
|
||||||
|
|||||||
Reference in New Issue
Block a user