mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-20 05:10:52 +08:00
feat(logging, executor): add request logging tests and WebSocket-based Codex executor
- Introduced unit tests for request logging middleware to enhance coverage. - Added WebSocket-based Codex executor to support Responses API upgrade. - Updated middleware logic to selectively capture request bodies for memory efficiency. - Enhanced Codex configuration handling with new WebSocket attributes.
This commit is contained in:
@@ -15,10 +15,12 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||
)
|
||||
|
||||
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
|
||||
|
||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||
// It captures detailed information about the request and response, including headers and body,
|
||||
// and uses the provided RequestLogger to record this data. When logging is disabled in the
|
||||
// logger, it still captures data so that upstream errors can be persisted.
|
||||
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
|
||||
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
|
||||
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
if logger == nil {
|
||||
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
if c.Request.Method == http.MethodGet {
|
||||
if shouldSkipMethodForRequestLogging(c.Request) {
|
||||
c.Next()
|
||||
return
|
||||
}
|
||||
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
return
|
||||
}
|
||||
|
||||
loggerEnabled := logger.IsEnabled()
|
||||
|
||||
// Capture request information
|
||||
requestInfo, err := captureRequestInfo(c)
|
||||
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
|
||||
if err != nil {
|
||||
// Log error but continue processing
|
||||
// In a real implementation, you might want to use a proper logger here
|
||||
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
|
||||
// Create response writer wrapper
|
||||
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||
if !logger.IsEnabled() {
|
||||
if !loggerEnabled {
|
||||
wrapper.logOnErrorOnly = true
|
||||
}
|
||||
c.Writer = wrapper
|
||||
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
|
||||
if req == nil {
|
||||
return true
|
||||
}
|
||||
if req.Method != http.MethodGet {
|
||||
return false
|
||||
}
|
||||
return !isResponsesWebsocketUpgrade(req)
|
||||
}
|
||||
|
||||
func isResponsesWebsocketUpgrade(req *http.Request) bool {
|
||||
if req == nil || req.URL == nil {
|
||||
return false
|
||||
}
|
||||
if req.URL.Path != "/v1/responses" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
|
||||
}
|
||||
|
||||
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
|
||||
if loggerEnabled {
|
||||
return true
|
||||
}
|
||||
if req == nil || req.Body == nil {
|
||||
return false
|
||||
}
|
||||
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
|
||||
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||
return false
|
||||
}
|
||||
if req.ContentLength <= 0 {
|
||||
return false
|
||||
}
|
||||
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
|
||||
}
|
||||
|
||||
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
||||
// It captures the URL, method, headers, and body. The request body is read and then
|
||||
// restored so that it can be processed by subsequent handlers.
|
||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
|
||||
// Capture URL with sensitive query parameters masked
|
||||
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||
url := c.Request.URL.Path
|
||||
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
||||
|
||||
// Capture request body
|
||||
var body []byte
|
||||
if c.Request.Body != nil {
|
||||
if captureBody && c.Request.Body != nil {
|
||||
// Read the body
|
||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||
if err != nil {
|
||||
|
||||
138
internal/api/middleware/request_logging_test.go
Normal file
138
internal/api/middleware/request_logging_test.go
Normal file
@@ -0,0 +1,138 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"io"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
req *http.Request
|
||||
skip bool
|
||||
}{
|
||||
{
|
||||
name: "nil request",
|
||||
req: nil,
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "post request should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodPost,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "plain get should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/models"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
{
|
||||
name: "responses websocket upgrade should not skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{"Upgrade": []string{"websocket"}},
|
||||
},
|
||||
skip: false,
|
||||
},
|
||||
{
|
||||
name: "responses get without upgrade should skip",
|
||||
req: &http.Request{
|
||||
Method: http.MethodGet,
|
||||
URL: &url.URL{Path: "/v1/responses"},
|
||||
Header: http.Header{},
|
||||
},
|
||||
skip: true,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldSkipMethodForRequestLogging(tests[i].req)
|
||||
if got != tests[i].skip {
|
||||
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestShouldCaptureRequestBody(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
loggerEnabled bool
|
||||
req *http.Request
|
||||
want bool
|
||||
}{
|
||||
{
|
||||
name: "logger enabled always captures",
|
||||
loggerEnabled: true,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "nil request",
|
||||
loggerEnabled: false,
|
||||
req: nil,
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "small known size json in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("{}")),
|
||||
ContentLength: 2,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: true,
|
||||
},
|
||||
{
|
||||
name: "large known size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "unknown size skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: -1,
|
||||
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
{
|
||||
name: "multipart skipped in error-only mode",
|
||||
loggerEnabled: false,
|
||||
req: &http.Request{
|
||||
Body: io.NopCloser(strings.NewReader("x")),
|
||||
ContentLength: 1,
|
||||
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
|
||||
},
|
||||
want: false,
|
||||
},
|
||||
}
|
||||
|
||||
for i := range tests {
|
||||
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
|
||||
if got != tests[i].want {
|
||||
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||
)
|
||||
|
||||
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||
|
||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||
type RequestInfo struct {
|
||||
URL string // URL is the request URL.
|
||||
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
||||
|
||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
bodyStr := string(w.requestInfo.Body)
|
||||
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
||||
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
|
||||
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
|
||||
}
|
||||
|
||||
return false
|
||||
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||
@@ -361,16 +363,32 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
||||
return time.Time{}
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||
if c != nil {
|
||||
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
||||
switch value := bodyOverride.(type) {
|
||||
case []byte:
|
||||
if len(value) > 0 {
|
||||
return bytes.Clone(value)
|
||||
}
|
||||
case string:
|
||||
if strings.TrimSpace(value) != "" {
|
||||
return []byte(value)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||
return w.requestInfo.Body
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||
if w.requestInfo == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
var requestBody []byte
|
||||
if len(w.requestInfo.Body) > 0 {
|
||||
requestBody = w.requestInfo.Body
|
||||
}
|
||||
|
||||
if loggerWithOptions, ok := w.logger.(interface {
|
||||
LogRequestWithOptions(string, string, map[string][]string, []byte, int, map[string][]string, []byte, []byte, []byte, []*interfaces.ErrorMessage, bool, string, time.Time, time.Time) error
|
||||
}); ok {
|
||||
|
||||
43
internal/api/middleware/response_writer_test.go
Normal file
43
internal/api/middleware/response_writer_test.go
Normal file
@@ -0,0 +1,43 @@
|
||||
package middleware
|
||||
|
||||
import (
|
||||
"net/http/httptest"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{
|
||||
requestInfo: &RequestInfo{Body: []byte("original-body")},
|
||||
}
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "original-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "original-body")
|
||||
}
|
||||
|
||||
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
|
||||
body = wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-body" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-body")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
wrapper := &ResponseWriterWrapper{}
|
||||
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||
|
||||
body := wrapper.extractRequestBody(c)
|
||||
if string(body) != "override-as-string" {
|
||||
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||
}
|
||||
}
|
||||
@@ -323,6 +323,7 @@ func (s *Server) setupRoutes() {
|
||||
v1.POST("/completions", openaiHandlers.Completions)
|
||||
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
||||
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
|
||||
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
||||
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
|
||||
}
|
||||
|
||||
@@ -355,6 +355,9 @@ type CodexKey struct {
|
||||
// If empty, the default Codex API URL will be used.
|
||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||
|
||||
// Websockets enables the Responses API websocket transport for this credential.
|
||||
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
|
||||
|
||||
// ProxyURL overrides the global proxy setting for this API key if provided.
|
||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||
|
||||
|
||||
@@ -19,6 +19,7 @@ import (
|
||||
// - codex
|
||||
// - qwen
|
||||
// - iflow
|
||||
// - kimi
|
||||
// - antigravity (returns static overrides only)
|
||||
func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
key := strings.ToLower(strings.TrimSpace(channel))
|
||||
@@ -39,6 +40,8 @@ func GetStaticModelDefinitionsByChannel(channel string) []*ModelInfo {
|
||||
return GetQwenModels()
|
||||
case "iflow":
|
||||
return GetIFlowModels()
|
||||
case "kimi":
|
||||
return GetKimiModels()
|
||||
case "antigravity":
|
||||
cfg := GetAntigravityModelConfig()
|
||||
if len(cfg) == 0 {
|
||||
@@ -83,6 +86,7 @@ func LookupStaticModelInfo(modelID string) *ModelInfo {
|
||||
GetOpenAIModels(),
|
||||
GetQwenModels(),
|
||||
GetIFlowModels(),
|
||||
GetKimiModels(),
|
||||
}
|
||||
for _, models := range allModels {
|
||||
for _, m := range models {
|
||||
|
||||
@@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-sonnet-4-6",
|
||||
Object: "model",
|
||||
Created: 1771372800, // 2026-02-17
|
||||
OwnedBy: "anthropic",
|
||||
Type: "claude",
|
||||
DisplayName: "Claude 4.6 Sonnet",
|
||||
ContextLength: 200000,
|
||||
MaxCompletionTokens: 64000,
|
||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||
},
|
||||
{
|
||||
ID: "claude-opus-4-6",
|
||||
Object: "model",
|
||||
@@ -788,6 +799,19 @@ func GetQwenModels() []*ModelInfo {
|
||||
MaxCompletionTokens: 2048,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
{
|
||||
ID: "coder-model",
|
||||
Object: "model",
|
||||
Created: 1771171200,
|
||||
OwnedBy: "qwen",
|
||||
Type: "qwen",
|
||||
Version: "3.5",
|
||||
DisplayName: "Qwen 3.5 Plus",
|
||||
Description: "efficient hybrid model with leading coding performance",
|
||||
ContextLength: 1048576,
|
||||
MaxCompletionTokens: 65536,
|
||||
SupportedParameters: []string{"temperature", "top_p", "max_tokens", "stream", "stop"},
|
||||
},
|
||||
{
|
||||
ID: "vision-model",
|
||||
Object: "model",
|
||||
@@ -884,6 +908,8 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
|
||||
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||
"gpt-oss-120b-medium": {},
|
||||
"tab_flash_lite_preview": {},
|
||||
}
|
||||
|
||||
1407
internal/runtime/executor/codex_websockets_executor.go
Normal file
1407
internal/runtime/executor/codex_websockets_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -22,9 +22,7 @@ import (
|
||||
)
|
||||
|
||||
const (
|
||||
qwenUserAgent = "google-api-nodejs-client/9.15.1"
|
||||
qwenXGoogAPIClient = "gl-node/22.17.0"
|
||||
qwenClientMetadataValue = "ideType=IDE_UNSPECIFIED,platform=PLATFORM_UNSPECIFIED,pluginType=GEMINI"
|
||||
qwenUserAgent = "QwenCode/0.10.3 (darwin; arm64)"
|
||||
)
|
||||
|
||||
// QwenExecutor is a stateless executor for Qwen Code using OpenAI-compatible chat completions.
|
||||
@@ -344,8 +342,18 @@ func applyQwenHeaders(r *http.Request, token string, stream bool) {
|
||||
r.Header.Set("Content-Type", "application/json")
|
||||
r.Header.Set("Authorization", "Bearer "+token)
|
||||
r.Header.Set("User-Agent", qwenUserAgent)
|
||||
r.Header.Set("X-Goog-Api-Client", qwenXGoogAPIClient)
|
||||
r.Header.Set("Client-Metadata", qwenClientMetadataValue)
|
||||
r.Header.Set("X-Dashscope-Useragent", qwenUserAgent)
|
||||
r.Header.Set("X-Stainless-Runtime-Version", "v22.17.0")
|
||||
r.Header.Set("Sec-Fetch-Mode", "cors")
|
||||
r.Header.Set("X-Stainless-Lang", "js")
|
||||
r.Header.Set("X-Stainless-Arch", "arm64")
|
||||
r.Header.Set("X-Stainless-Package-Version", "5.11.0")
|
||||
r.Header.Set("X-Dashscope-Cachecontrol", "enable")
|
||||
r.Header.Set("X-Stainless-Retry-Count", "0")
|
||||
r.Header.Set("X-Stainless-Os", "MacOS")
|
||||
r.Header.Set("X-Dashscope-Authtype", "qwen-oauth")
|
||||
r.Header.Set("X-Stainless-Runtime", "node")
|
||||
|
||||
if stream {
|
||||
r.Header.Set("Accept", "text/event-stream")
|
||||
return
|
||||
|
||||
@@ -184,6 +184,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||
}
|
||||
if o.Websockets != n.Websockets {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].websockets: %t -> %t", i, o.Websockets, n.Websockets))
|
||||
}
|
||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
|
||||
}
|
||||
|
||||
@@ -160,6 +160,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
|
||||
if ck.BaseURL != "" {
|
||||
attrs["base_url"] = ck.BaseURL
|
||||
}
|
||||
if ck.Websockets {
|
||||
attrs["websockets"] = "true"
|
||||
}
|
||||
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
|
||||
attrs["models_hash"] = hash
|
||||
}
|
||||
|
||||
@@ -231,10 +231,11 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
|
||||
Config: &config.Config{
|
||||
CodexKey: []config.CodexKey{
|
||||
{
|
||||
APIKey: "codex-key-123",
|
||||
Prefix: "dev",
|
||||
BaseURL: "https://api.openai.com",
|
||||
ProxyURL: "http://proxy.local",
|
||||
APIKey: "codex-key-123",
|
||||
Prefix: "dev",
|
||||
BaseURL: "https://api.openai.com",
|
||||
ProxyURL: "http://proxy.local",
|
||||
Websockets: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
@@ -259,6 +260,9 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
|
||||
if auths[0].ProxyURL != "http://proxy.local" {
|
||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||
}
|
||||
if auths[0].Attributes["websockets"] != "true" {
|
||||
t.Errorf("expected websockets=true, got %s", auths[0].Attributes["websockets"])
|
||||
}
|
||||
}
|
||||
|
||||
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
||||
|
||||
@@ -52,6 +52,45 @@ const (
|
||||
defaultStreamingBootstrapRetries = 0
|
||||
)
|
||||
|
||||
type pinnedAuthContextKey struct{}
|
||||
type selectedAuthCallbackContextKey struct{}
|
||||
type executionSessionContextKey struct{}
|
||||
|
||||
// WithPinnedAuthID returns a child context that requests execution on a specific auth ID.
|
||||
func WithPinnedAuthID(ctx context.Context, authID string) context.Context {
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return ctx
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, pinnedAuthContextKey{}, authID)
|
||||
}
|
||||
|
||||
// WithSelectedAuthIDCallback returns a child context that receives the selected auth ID.
|
||||
func WithSelectedAuthIDCallback(ctx context.Context, callback func(string)) context.Context {
|
||||
if callback == nil {
|
||||
return ctx
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, selectedAuthCallbackContextKey{}, callback)
|
||||
}
|
||||
|
||||
// WithExecutionSessionID returns a child context tagged with a long-lived execution session ID.
|
||||
func WithExecutionSessionID(ctx context.Context, sessionID string) context.Context {
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if sessionID == "" {
|
||||
return ctx
|
||||
}
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, executionSessionContextKey{}, sessionID)
|
||||
}
|
||||
|
||||
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
|
||||
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
|
||||
func BuildErrorResponseBody(status int, errText string) []byte {
|
||||
@@ -152,7 +191,59 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
||||
if key == "" {
|
||||
key = uuid.NewString()
|
||||
}
|
||||
return map[string]any{idempotencyKeyMetadataKey: key}
|
||||
|
||||
meta := map[string]any{idempotencyKeyMetadataKey: key}
|
||||
if pinnedAuthID := pinnedAuthIDFromContext(ctx); pinnedAuthID != "" {
|
||||
meta[coreexecutor.PinnedAuthMetadataKey] = pinnedAuthID
|
||||
}
|
||||
if selectedCallback := selectedAuthIDCallbackFromContext(ctx); selectedCallback != nil {
|
||||
meta[coreexecutor.SelectedAuthCallbackMetadataKey] = selectedCallback
|
||||
}
|
||||
if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" {
|
||||
meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID
|
||||
}
|
||||
return meta
|
||||
}
|
||||
|
||||
func pinnedAuthIDFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
raw := ctx.Value(pinnedAuthContextKey{})
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(v))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func selectedAuthIDCallbackFromContext(ctx context.Context) func(string) {
|
||||
if ctx == nil {
|
||||
return nil
|
||||
}
|
||||
raw := ctx.Value(selectedAuthCallbackContextKey{})
|
||||
if callback, ok := raw.(func(string)); ok && callback != nil {
|
||||
return callback
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func executionSessionIDFromContext(ctx context.Context) string {
|
||||
if ctx == nil {
|
||||
return ""
|
||||
}
|
||||
raw := ctx.Value(executionSessionContextKey{})
|
||||
switch v := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(v)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(v))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
// BaseAPIHandler contains the handlers for API endpoints.
|
||||
|
||||
@@ -122,6 +122,82 @@ func (e *payloadThenErrorStreamExecutor) Calls() int {
|
||||
return e.calls
|
||||
}
|
||||
|
||||
type authAwareStreamExecutor struct {
|
||||
mu sync.Mutex
|
||||
calls int
|
||||
authIDs []string
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||
_ = ctx
|
||||
_ = req
|
||||
_ = opts
|
||||
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||
|
||||
authID := ""
|
||||
if auth != nil {
|
||||
authID = auth.ID
|
||||
}
|
||||
|
||||
e.mu.Lock()
|
||||
e.calls++
|
||||
e.authIDs = append(e.authIDs, authID)
|
||||
e.mu.Unlock()
|
||||
|
||||
if authID == "auth1" {
|
||||
ch <- coreexecutor.StreamChunk{
|
||||
Err: &coreauth.Error{
|
||||
Code: "unauthorized",
|
||||
Message: "unauthorized",
|
||||
Retryable: false,
|
||||
HTTPStatus: http.StatusUnauthorized,
|
||||
},
|
||||
}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, &coreauth.Error{
|
||||
Code: "not_implemented",
|
||||
Message: "HttpRequest not implemented",
|
||||
HTTPStatus: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Calls() int {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
return e.calls
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) AuthIDs() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.authIDs))
|
||||
copy(out, e.authIDs)
|
||||
return out
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||
executor := &failOnceStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
@@ -252,3 +328,128 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
||||
t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) {
|
||||
executor := &authAwareStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
auth2 := &coreauth.Auth{
|
||||
ID: "auth2",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test2@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||
t.Fatalf("manager.Register(auth2): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: 1,
|
||||
},
|
||||
}, manager)
|
||||
ctx := WithPinnedAuthID(context.Background(), "auth1")
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
|
||||
var gotErr error
|
||||
for msg := range errChan {
|
||||
if msg != nil && msg.Error != nil {
|
||||
gotErr = msg.Error
|
||||
}
|
||||
}
|
||||
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty payload, got %q", string(got))
|
||||
}
|
||||
if gotErr == nil {
|
||||
t.Fatalf("expected terminal error, got nil")
|
||||
}
|
||||
authIDs := executor.AuthIDs()
|
||||
if len(authIDs) == 0 {
|
||||
t.Fatalf("expected at least one upstream attempt")
|
||||
}
|
||||
for _, authID := range authIDs {
|
||||
if authID != "auth1" {
|
||||
t.Fatalf("expected all attempts on auth1, got sequence %v", authIDs)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *testing.T) {
|
||||
executor := &authAwareStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth2 := &coreauth.Auth{
|
||||
ID: "auth2",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test2@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||
t.Fatalf("manager.Register(auth2): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||
Streaming: sdkconfig.StreamingConfig{
|
||||
BootstrapRetries: 0,
|
||||
},
|
||||
}, manager)
|
||||
|
||||
selectedAuthID := ""
|
||||
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
|
||||
selectedAuthID = authID
|
||||
})
|
||||
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
for msg := range errChan {
|
||||
if msg != nil {
|
||||
t.Fatalf("unexpected error: %+v", msg)
|
||||
}
|
||||
}
|
||||
|
||||
if string(got) != "ok" {
|
||||
t.Fatalf("expected payload ok, got %q", string(got))
|
||||
}
|
||||
if selectedAuthID != "auth2" {
|
||||
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
||||
}
|
||||
}
|
||||
|
||||
662
sdk/api/handlers/openai/openai_responses_websocket.go
Normal file
662
sdk/api/handlers/openai/openai_responses_websocket.go
Normal file
@@ -0,0 +1,662 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/google/uuid"
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
const (
|
||||
wsRequestTypeCreate = "response.create"
|
||||
wsRequestTypeAppend = "response.append"
|
||||
wsEventTypeError = "error"
|
||||
wsEventTypeCompleted = "response.completed"
|
||||
wsEventTypeDone = "response.done"
|
||||
wsDoneMarker = "[DONE]"
|
||||
wsTurnStateHeader = "x-codex-turn-state"
|
||||
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||
wsPayloadLogMaxSize = 2048
|
||||
)
|
||||
|
||||
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||
ReadBufferSize: 4096,
|
||||
WriteBufferSize: 4096,
|
||||
CheckOrigin: func(r *http.Request) bool {
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
// ResponsesWebsocket handles websocket requests for /v1/responses.
|
||||
// It accepts `response.create` and `response.append` requests and streams
|
||||
// response events back as JSON websocket text messages.
|
||||
func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||
conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request))
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
passthroughSessionID := uuid.NewString()
|
||||
clientRemoteAddr := ""
|
||||
if c != nil && c.Request != nil {
|
||||
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
||||
}
|
||||
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
|
||||
var wsTerminateErr error
|
||||
var wsBodyLog strings.Builder
|
||||
defer func() {
|
||||
if wsTerminateErr != nil {
|
||||
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
|
||||
} else {
|
||||
log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
|
||||
}
|
||||
if h != nil && h.AuthManager != nil {
|
||||
h.AuthManager.CloseExecutionSession(passthroughSessionID)
|
||||
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
|
||||
}
|
||||
setWebsocketRequestBody(c, wsBodyLog.String())
|
||||
if errClose := conn.Close(); errClose != nil {
|
||||
log.Warnf("responses websocket: close connection error: %v", errClose)
|
||||
}
|
||||
}()
|
||||
|
||||
var lastRequest []byte
|
||||
lastResponseOutput := []byte("[]")
|
||||
pinnedAuthID := ""
|
||||
|
||||
for {
|
||||
msgType, payload, errReadMessage := conn.ReadMessage()
|
||||
if errReadMessage != nil {
|
||||
wsTerminateErr = errReadMessage
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
|
||||
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
|
||||
} else {
|
||||
// log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage)
|
||||
}
|
||||
return
|
||||
}
|
||||
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||
continue
|
||||
}
|
||||
// log.Infof(
|
||||
// "responses websocket: downstream_in id=%s type=%d event=%s payload=%s",
|
||||
// passthroughSessionID,
|
||||
// msgType,
|
||||
// websocketPayloadEventType(payload),
|
||||
// websocketPayloadPreview(payload),
|
||||
// )
|
||||
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
||||
|
||||
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
|
||||
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||
}
|
||||
}
|
||||
|
||||
var requestJSON []byte
|
||||
var updatedLastRequest []byte
|
||||
var errMsg *interfaces.ErrorMessage
|
||||
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode(
|
||||
payload,
|
||||
lastRequest,
|
||||
lastResponseOutput,
|
||||
allowIncrementalInputWithPreviousResponseID,
|
||||
)
|
||||
if errMsg != nil {
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
passthroughSessionID,
|
||||
websocket.TextMessage,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
websocketPayloadPreview(errorPayload),
|
||||
)
|
||||
if errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
passthroughSessionID,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
errWrite,
|
||||
)
|
||||
return
|
||||
}
|
||||
continue
|
||||
}
|
||||
lastRequest = updatedLastRequest
|
||||
|
||||
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||
cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx)
|
||||
cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID)
|
||||
if pinnedAuthID != "" {
|
||||
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
|
||||
} else {
|
||||
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
|
||||
pinnedAuthID = strings.TrimSpace(authID)
|
||||
})
|
||||
}
|
||||
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||
|
||||
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
|
||||
if errForward != nil {
|
||||
wsTerminateErr = errForward
|
||||
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
|
||||
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
|
||||
return
|
||||
}
|
||||
lastResponseOutput = completedOutput
|
||||
}
|
||||
}
|
||||
|
||||
func websocketUpgradeHeaders(req *http.Request) http.Header {
|
||||
headers := http.Header{}
|
||||
if req == nil {
|
||||
return headers
|
||||
}
|
||||
|
||||
// Keep the same sticky turn-state across reconnects when provided by the client.
|
||||
turnState := strings.TrimSpace(req.Header.Get(wsTurnStateHeader))
|
||||
if turnState != "" {
|
||||
headers.Set(wsTurnStateHeader, turnState)
|
||||
}
|
||||
return headers
|
||||
}
|
||||
|
||||
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
|
||||
}
|
||||
|
||||
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
||||
switch requestType {
|
||||
case wsRequestTypeCreate:
|
||||
// log.Infof("responses websocket: response.create request")
|
||||
if len(lastRequest) == 0 {
|
||||
return normalizeResponseCreateRequest(rawJSON)
|
||||
}
|
||||
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
||||
case wsRequestTypeAppend:
|
||||
// log.Infof("responses websocket: response.append request")
|
||||
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
||||
default:
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("unsupported websocket request type: %s", requestType),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
normalized = bytes.Clone(rawJSON)
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||
if !gjson.GetBytes(normalized, "input").Exists() {
|
||||
normalized, _ = sjson.SetRawBytes(normalized, "input", []byte("[]"))
|
||||
}
|
||||
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(normalized, "model").String())
|
||||
if modelName == "" {
|
||||
return nil, nil, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("missing model in response.create request"),
|
||||
}
|
||||
}
|
||||
return normalized, bytes.Clone(normalized), nil
|
||||
}
|
||||
|
||||
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||
if len(lastRequest) == 0 {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("websocket request received before response.create"),
|
||||
}
|
||||
}
|
||||
|
||||
nextInput := gjson.GetBytes(rawJSON, "input")
|
||||
if !nextInput.Exists() || !nextInput.IsArray() {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("websocket request requires array field: input"),
|
||||
}
|
||||
}
|
||||
|
||||
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
|
||||
// Do not expand it into a full input transcript; upstream expects the incremental payload.
|
||||
if allowIncrementalInputWithPreviousResponseID {
|
||||
if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" {
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
normalized = bytes.Clone(rawJSON)
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
if modelName != "" {
|
||||
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||
}
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||
if instructions.Exists() {
|
||||
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||
}
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||
return normalized, bytes.Clone(normalized), nil
|
||||
}
|
||||
}
|
||||
|
||||
existingInput := gjson.GetBytes(lastRequest, "input")
|
||||
mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
|
||||
if errMerge != nil {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
|
||||
}
|
||||
}
|
||||
|
||||
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
|
||||
if errMerge != nil {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("invalid request input: %w", errMerge),
|
||||
}
|
||||
}
|
||||
|
||||
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||
if errDelete != nil {
|
||||
normalized = bytes.Clone(rawJSON)
|
||||
}
|
||||
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
|
||||
var errSet error
|
||||
normalized, errSet = sjson.SetRawBytes(normalized, "input", []byte(mergedInput))
|
||||
if errSet != nil {
|
||||
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusBadRequest,
|
||||
Error: fmt.Errorf("failed to merge websocket input: %w", errSet),
|
||||
}
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||
if modelName != "" {
|
||||
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||
}
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||
if instructions.Exists() {
|
||||
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||
}
|
||||
}
|
||||
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||
return normalized, bytes.Clone(normalized), nil
|
||||
}
|
||||
|
||||
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
|
||||
if len(attributes) > 0 {
|
||||
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
|
||||
parsed, errParse := strconv.ParseBool(raw)
|
||||
if errParse == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(metadata) == 0 {
|
||||
return false
|
||||
}
|
||||
raw, ok := metadata["websockets"]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch value := raw.(type) {
|
||||
case bool:
|
||||
return value
|
||||
case string:
|
||||
parsed, errParse := strconv.ParseBool(strings.TrimSpace(value))
|
||||
if errParse == nil {
|
||||
return parsed
|
||||
}
|
||||
default:
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
|
||||
existingRaw = strings.TrimSpace(existingRaw)
|
||||
appendRaw = strings.TrimSpace(appendRaw)
|
||||
if existingRaw == "" {
|
||||
existingRaw = "[]"
|
||||
}
|
||||
if appendRaw == "" {
|
||||
appendRaw = "[]"
|
||||
}
|
||||
|
||||
var existing []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(existingRaw), &existing); err != nil {
|
||||
return "", err
|
||||
}
|
||||
var appendItems []json.RawMessage
|
||||
if err := json.Unmarshal([]byte(appendRaw), &appendItems); err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
merged := append(existing, appendItems...)
|
||||
out, err := json.Marshal(merged)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func normalizeJSONArrayRaw(raw []byte) string {
|
||||
trimmed := strings.TrimSpace(string(raw))
|
||||
if trimmed == "" {
|
||||
return "[]"
|
||||
}
|
||||
result := gjson.Parse(trimmed)
|
||||
if result.Type == gjson.JSON && result.IsArray() {
|
||||
return trimmed
|
||||
}
|
||||
return "[]"
|
||||
}
|
||||
|
||||
func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||
c *gin.Context,
|
||||
conn *websocket.Conn,
|
||||
cancel handlers.APIHandlerCancelFunc,
|
||||
data <-chan []byte,
|
||||
errs <-chan *interfaces.ErrorMessage,
|
||||
wsBodyLog *strings.Builder,
|
||||
sessionID string,
|
||||
) ([]byte, error) {
|
||||
completed := false
|
||||
completedOutput := []byte("[]")
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-c.Request.Context().Done():
|
||||
cancel(c.Request.Context().Err())
|
||||
return completedOutput, c.Request.Context().Err()
|
||||
case errMsg, ok := <-errs:
|
||||
if !ok {
|
||||
errs = nil
|
||||
continue
|
||||
}
|
||||
if errMsg != nil {
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
sessionID,
|
||||
websocket.TextMessage,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
websocketPayloadPreview(errorPayload),
|
||||
)
|
||||
if errWrite != nil {
|
||||
// log.Warnf(
|
||||
// "responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
// sessionID,
|
||||
// websocketPayloadEventType(errorPayload),
|
||||
// errWrite,
|
||||
// )
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, errWrite
|
||||
}
|
||||
}
|
||||
if errMsg != nil {
|
||||
cancel(errMsg.Error)
|
||||
} else {
|
||||
cancel(nil)
|
||||
}
|
||||
return completedOutput, nil
|
||||
case chunk, ok := <-data:
|
||||
if !ok {
|
||||
if !completed {
|
||||
errMsg := &interfaces.ErrorMessage{
|
||||
StatusCode: http.StatusRequestTimeout,
|
||||
Error: fmt.Errorf("stream closed before response.completed"),
|
||||
}
|
||||
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||
markAPIResponseTimestamp(c)
|
||||
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
||||
log.Infof(
|
||||
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
sessionID,
|
||||
websocket.TextMessage,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
websocketPayloadPreview(errorPayload),
|
||||
)
|
||||
if errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
sessionID,
|
||||
websocketPayloadEventType(errorPayload),
|
||||
errWrite,
|
||||
)
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, errWrite
|
||||
}
|
||||
cancel(errMsg.Error)
|
||||
return completedOutput, nil
|
||||
}
|
||||
cancel(nil)
|
||||
return completedOutput, nil
|
||||
}
|
||||
|
||||
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||
for i := range payloads {
|
||||
eventType := gjson.GetBytes(payloads[i], "type").String()
|
||||
if eventType == wsEventTypeCompleted {
|
||||
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
|
||||
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
|
||||
|
||||
completed = true
|
||||
completedOutput = responseCompletedOutputFromPayload(payloads[i])
|
||||
}
|
||||
markAPIResponseTimestamp(c)
|
||||
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
||||
// log.Infof(
|
||||
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||
// sessionID,
|
||||
// websocket.TextMessage,
|
||||
// websocketPayloadEventType(payloads[i]),
|
||||
// websocketPayloadPreview(payloads[i]),
|
||||
// )
|
||||
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
||||
log.Warnf(
|
||||
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||
sessionID,
|
||||
websocketPayloadEventType(payloads[i]),
|
||||
errWrite,
|
||||
)
|
||||
cancel(errWrite)
|
||||
return completedOutput, errWrite
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func responseCompletedOutputFromPayload(payload []byte) []byte {
|
||||
output := gjson.GetBytes(payload, "response.output")
|
||||
if output.Exists() && output.IsArray() {
|
||||
return bytes.Clone([]byte(output.Raw))
|
||||
}
|
||||
return []byte("[]")
|
||||
}
|
||||
|
||||
func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
|
||||
payloads := make([][]byte, 0, 2)
|
||||
lines := bytes.Split(chunk, []byte("\n"))
|
||||
for i := range lines {
|
||||
line := bytes.TrimSpace(lines[i])
|
||||
if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) {
|
||||
continue
|
||||
}
|
||||
if bytes.HasPrefix(line, []byte("data:")) {
|
||||
line = bytes.TrimSpace(line[len("data:"):])
|
||||
}
|
||||
if len(line) == 0 || bytes.Equal(line, []byte(wsDoneMarker)) {
|
||||
continue
|
||||
}
|
||||
if json.Valid(line) {
|
||||
payloads = append(payloads, bytes.Clone(line))
|
||||
}
|
||||
}
|
||||
|
||||
if len(payloads) > 0 {
|
||||
return payloads
|
||||
}
|
||||
|
||||
trimmed := bytes.TrimSpace(chunk)
|
||||
if bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||
trimmed = bytes.TrimSpace(trimmed[len("data:"):])
|
||||
}
|
||||
if len(trimmed) > 0 && !bytes.Equal(trimmed, []byte(wsDoneMarker)) && json.Valid(trimmed) {
|
||||
payloads = append(payloads, bytes.Clone(trimmed))
|
||||
}
|
||||
return payloads
|
||||
}
|
||||
|
||||
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
|
||||
status := http.StatusInternalServerError
|
||||
errText := http.StatusText(status)
|
||||
if errMsg != nil {
|
||||
if errMsg.StatusCode > 0 {
|
||||
status = errMsg.StatusCode
|
||||
errText = http.StatusText(status)
|
||||
}
|
||||
if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
}
|
||||
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
payload := map[string]any{
|
||||
"type": wsEventTypeError,
|
||||
"status": status,
|
||||
}
|
||||
|
||||
if errMsg != nil && errMsg.Addon != nil {
|
||||
headers := map[string]any{}
|
||||
for key, values := range errMsg.Addon {
|
||||
if len(values) == 0 {
|
||||
continue
|
||||
}
|
||||
headers[key] = values[0]
|
||||
}
|
||||
if len(headers) > 0 {
|
||||
payload["headers"] = headers
|
||||
}
|
||||
}
|
||||
|
||||
if len(body) > 0 && json.Valid(body) {
|
||||
var decoded map[string]any
|
||||
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
|
||||
if inner, ok := decoded["error"]; ok {
|
||||
payload["error"] = inner
|
||||
} else {
|
||||
payload["error"] = decoded
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if _, ok := payload["error"]; !ok {
|
||||
payload["error"] = map[string]any{
|
||||
"type": "server_error",
|
||||
"message": errText,
|
||||
}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(payload)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, conn.WriteMessage(websocket.TextMessage, data)
|
||||
}
|
||||
|
||||
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
||||
if builder == nil {
|
||||
return
|
||||
}
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) == 0 {
|
||||
return
|
||||
}
|
||||
if builder.Len() > 0 {
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
builder.WriteString("websocket.")
|
||||
builder.WriteString(eventType)
|
||||
builder.WriteString("\n")
|
||||
builder.Write(trimmedPayload)
|
||||
builder.WriteString("\n")
|
||||
}
|
||||
|
||||
func websocketPayloadEventType(payload []byte) string {
|
||||
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
||||
if eventType == "" {
|
||||
return "-"
|
||||
}
|
||||
return eventType
|
||||
}
|
||||
|
||||
func websocketPayloadPreview(payload []byte) string {
|
||||
trimmedPayload := bytes.TrimSpace(payload)
|
||||
if len(trimmedPayload) == 0 {
|
||||
return "<empty>"
|
||||
}
|
||||
preview := trimmedPayload
|
||||
if len(preview) > wsPayloadLogMaxSize {
|
||||
preview = preview[:wsPayloadLogMaxSize]
|
||||
}
|
||||
previewText := strings.ReplaceAll(string(preview), "\n", "\\n")
|
||||
previewText = strings.ReplaceAll(previewText, "\r", "\\r")
|
||||
if len(trimmedPayload) > wsPayloadLogMaxSize {
|
||||
return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload))
|
||||
}
|
||||
return previewText
|
||||
}
|
||||
|
||||
func setWebsocketRequestBody(c *gin.Context, body string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
trimmedBody := strings.TrimSpace(body)
|
||||
if trimmedBody == "" {
|
||||
return
|
||||
}
|
||||
c.Set(wsRequestBodyKey, []byte(trimmedBody))
|
||||
}
|
||||
|
||||
func markAPIResponseTimestamp(c *gin.Context) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); exists {
|
||||
return
|
||||
}
|
||||
c.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
||||
}
|
||||
249
sdk/api/handlers/openai/openai_responses_websocket_test.go
Normal file
249
sdk/api/handlers/openai/openai_responses_websocket_test.go
Normal file
@@ -0,0 +1,249 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/tidwall/gjson"
|
||||
)
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
||||
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
|
||||
normalized, last, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
if gjson.GetBytes(normalized, "type").Exists() {
|
||||
t.Fatalf("normalized create request must not include type field")
|
||||
}
|
||||
if !gjson.GetBytes(normalized, "stream").Bool() {
|
||||
t.Fatalf("normalized create request must force stream=true")
|
||||
}
|
||||
if gjson.GetBytes(normalized, "model").String() != "test-model" {
|
||||
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
|
||||
}
|
||||
if !bytes.Equal(last, normalized) {
|
||||
t.Fatalf("last request snapshot should match normalized request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestCreateWithHistory(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"function_call","id":"fc-1","call_id":"call-1"},
|
||||
{"type":"message","id":"assistant-1"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
if gjson.GetBytes(normalized, "type").Exists() {
|
||||
t.Fatalf("normalized subsequent create request must not include type field")
|
||||
}
|
||||
if gjson.GetBytes(normalized, "model").String() != "test-model" {
|
||||
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
|
||||
}
|
||||
|
||||
input := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(input) != 4 {
|
||||
t.Fatalf("merged input len = %d, want 4", len(input))
|
||||
}
|
||||
if input[0].Get("id").String() != "msg-1" ||
|
||||
input[1].Get("id").String() != "fc-1" ||
|
||||
input[2].Get("id").String() != "assistant-1" ||
|
||||
input[3].Get("id").String() != "tool-out-1" {
|
||||
t.Fatalf("unexpected merged input order")
|
||||
}
|
||||
if !bytes.Equal(next, normalized) {
|
||||
t.Fatalf("next request snapshot should match normalized request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"instructions":"be helpful","input":[{"type":"message","id":"msg-1"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"function_call","id":"fc-1","call_id":"call-1"},
|
||||
{"type":"message","id":"assistant-1"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
if gjson.GetBytes(normalized, "type").Exists() {
|
||||
t.Fatalf("normalized request must not include type field")
|
||||
}
|
||||
if gjson.GetBytes(normalized, "previous_response_id").String() != "resp-1" {
|
||||
t.Fatalf("previous_response_id must be preserved in incremental mode")
|
||||
}
|
||||
input := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(input) != 1 {
|
||||
t.Fatalf("incremental input len = %d, want 1", len(input))
|
||||
}
|
||||
if input[0].Get("id").String() != "tool-out-1" {
|
||||
t.Fatalf("unexpected incremental input item id: %s", input[0].Get("id").String())
|
||||
}
|
||||
if gjson.GetBytes(normalized, "model").String() != "test-model" {
|
||||
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
|
||||
}
|
||||
if gjson.GetBytes(normalized, "instructions").String() != "be helpful" {
|
||||
t.Fatalf("unexpected instructions: %s", gjson.GetBytes(normalized, "instructions").String())
|
||||
}
|
||||
if !bytes.Equal(next, normalized) {
|
||||
t.Fatalf("next request snapshot should match normalized request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"function_call","id":"fc-1","call_id":"call-1"},
|
||||
{"type":"message","id":"assistant-1"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
if gjson.GetBytes(normalized, "previous_response_id").Exists() {
|
||||
t.Fatalf("previous_response_id must be removed when incremental mode is disabled")
|
||||
}
|
||||
input := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(input) != 4 {
|
||||
t.Fatalf("merged input len = %d, want 4", len(input))
|
||||
}
|
||||
if input[0].Get("id").String() != "msg-1" ||
|
||||
input[1].Get("id").String() != "fc-1" ||
|
||||
input[2].Get("id").String() != "assistant-1" ||
|
||||
input[3].Get("id").String() != "tool-out-1" {
|
||||
t.Fatalf("unexpected merged input order")
|
||||
}
|
||||
if !bytes.Equal(next, normalized) {
|
||||
t.Fatalf("next request snapshot should match normalized request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestAppend(t *testing.T) {
|
||||
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||
lastResponseOutput := []byte(`[
|
||||
{"type":"message","id":"assistant-1"},
|
||||
{"type":"function_call_output","id":"tool-out-1"}
|
||||
]`)
|
||||
raw := []byte(`{"type":"response.append","input":[{"type":"message","id":"msg-2"},{"type":"message","id":"msg-3"}]}`)
|
||||
|
||||
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||
if errMsg != nil {
|
||||
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||
}
|
||||
input := gjson.GetBytes(normalized, "input").Array()
|
||||
if len(input) != 5 {
|
||||
t.Fatalf("merged input len = %d, want 5", len(input))
|
||||
}
|
||||
if input[0].Get("id").String() != "msg-1" ||
|
||||
input[1].Get("id").String() != "assistant-1" ||
|
||||
input[2].Get("id").String() != "tool-out-1" ||
|
||||
input[3].Get("id").String() != "msg-2" ||
|
||||
input[4].Get("id").String() != "msg-3" {
|
||||
t.Fatalf("unexpected merged input order")
|
||||
}
|
||||
if !bytes.Equal(next, normalized) {
|
||||
t.Fatalf("next request snapshot should match normalized append request")
|
||||
}
|
||||
}
|
||||
|
||||
func TestNormalizeResponsesWebsocketRequestAppendWithoutCreate(t *testing.T) {
|
||||
raw := []byte(`{"type":"response.append","input":[]}`)
|
||||
|
||||
_, _, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
|
||||
if errMsg == nil {
|
||||
t.Fatalf("expected error for append without previous request")
|
||||
}
|
||||
if errMsg.StatusCode != http.StatusBadRequest {
|
||||
t.Fatalf("status = %d, want %d", errMsg.StatusCode, http.StatusBadRequest)
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketJSONPayloadsFromChunk(t *testing.T) {
|
||||
chunk := []byte("event: response.created\n\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\ndata: [DONE]\n")
|
||||
|
||||
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||
if len(payloads) != 1 {
|
||||
t.Fatalf("payloads len = %d, want 1", len(payloads))
|
||||
}
|
||||
if gjson.GetBytes(payloads[0], "type").String() != "response.created" {
|
||||
t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestWebsocketJSONPayloadsFromPlainJSONChunk(t *testing.T) {
|
||||
chunk := []byte(`{"type":"response.completed","response":{"id":"resp-1"}}`)
|
||||
|
||||
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||
if len(payloads) != 1 {
|
||||
t.Fatalf("payloads len = %d, want 1", len(payloads))
|
||||
}
|
||||
if gjson.GetBytes(payloads[0], "type").String() != "response.completed" {
|
||||
t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestResponseCompletedOutputFromPayload(t *testing.T) {
|
||||
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"message","id":"out-1"}]}}`)
|
||||
|
||||
output := responseCompletedOutputFromPayload(payload)
|
||||
items := gjson.ParseBytes(output).Array()
|
||||
if len(items) != 1 {
|
||||
t.Fatalf("output len = %d, want 1", len(items))
|
||||
}
|
||||
if items[0].Get("id").String() != "out-1" {
|
||||
t.Fatalf("unexpected output id: %s", items[0].Get("id").String())
|
||||
}
|
||||
}
|
||||
|
||||
func TestAppendWebsocketEvent(t *testing.T) {
|
||||
var builder strings.Builder
|
||||
|
||||
appendWebsocketEvent(&builder, "request", []byte(" {\"type\":\"response.create\"}\n"))
|
||||
appendWebsocketEvent(&builder, "response", []byte("{\"type\":\"response.created\"}"))
|
||||
|
||||
got := builder.String()
|
||||
if !strings.Contains(got, "websocket.request\n{\"type\":\"response.create\"}\n") {
|
||||
t.Fatalf("request event not found in body: %s", got)
|
||||
}
|
||||
if !strings.Contains(got, "websocket.response\n{\"type\":\"response.created\"}\n") {
|
||||
t.Fatalf("response event not found in body: %s", got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSetWebsocketRequestBody(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
|
||||
setWebsocketRequestBody(c, " \n ")
|
||||
if _, exists := c.Get(wsRequestBodyKey); exists {
|
||||
t.Fatalf("request body key should not be set for empty body")
|
||||
}
|
||||
|
||||
setWebsocketRequestBody(c, "event body")
|
||||
value, exists := c.Get(wsRequestBodyKey)
|
||||
if !exists {
|
||||
t.Fatalf("request body key not set")
|
||||
}
|
||||
bodyBytes, ok := value.([]byte)
|
||||
if !ok {
|
||||
t.Fatalf("request body key type mismatch")
|
||||
}
|
||||
if string(bodyBytes) != "event body" {
|
||||
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
|
||||
}
|
||||
}
|
||||
@@ -41,6 +41,17 @@ type ProviderExecutor interface {
|
||||
HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
|
||||
}
|
||||
|
||||
// ExecutionSessionCloser allows executors to release per-session runtime resources.
|
||||
type ExecutionSessionCloser interface {
|
||||
CloseExecutionSession(sessionID string)
|
||||
}
|
||||
|
||||
const (
|
||||
// CloseAllExecutionSessionsID asks an executor to release all active execution sessions.
|
||||
// Executors that do not support this marker may ignore it.
|
||||
CloseAllExecutionSessionsID = "__all_execution_sessions__"
|
||||
)
|
||||
|
||||
// RefreshEvaluator allows runtime state to override refresh decisions.
|
||||
type RefreshEvaluator interface {
|
||||
ShouldRefresh(now time.Time, auth *Auth) bool
|
||||
@@ -389,9 +400,23 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
|
||||
if executor == nil {
|
||||
return
|
||||
}
|
||||
provider := strings.TrimSpace(executor.Identifier())
|
||||
if provider == "" {
|
||||
return
|
||||
}
|
||||
|
||||
var replaced ProviderExecutor
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
m.executors[executor.Identifier()] = executor
|
||||
replaced = m.executors[provider]
|
||||
m.executors[provider] = executor
|
||||
m.mu.Unlock()
|
||||
|
||||
if replaced == nil || replaced == executor {
|
||||
return
|
||||
}
|
||||
if closer, ok := replaced.(ExecutionSessionCloser); ok && closer != nil {
|
||||
closer.CloseExecutionSession(CloseAllExecutionSessionsID)
|
||||
}
|
||||
}
|
||||
|
||||
// UnregisterExecutor removes the executor associated with the provider key.
|
||||
@@ -581,6 +606,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
@@ -636,6 +662,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
@@ -691,6 +718,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
||||
|
||||
entry := logEntryWithRequestID(ctx)
|
||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
|
||||
|
||||
tried[auth.ID] = struct{}{}
|
||||
execCtx := ctx
|
||||
@@ -794,6 +822,38 @@ func hasRequestedModelMetadata(meta map[string]any) bool {
|
||||
}
|
||||
}
|
||||
|
||||
func pinnedAuthIDFromMetadata(meta map[string]any) string {
|
||||
if len(meta) == 0 {
|
||||
return ""
|
||||
}
|
||||
raw, ok := meta[cliproxyexecutor.PinnedAuthMetadataKey]
|
||||
if !ok || raw == nil {
|
||||
return ""
|
||||
}
|
||||
switch val := raw.(type) {
|
||||
case string:
|
||||
return strings.TrimSpace(val)
|
||||
case []byte:
|
||||
return strings.TrimSpace(string(val))
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func publishSelectedAuthMetadata(meta map[string]any, authID string) {
|
||||
if len(meta) == 0 {
|
||||
return
|
||||
}
|
||||
authID = strings.TrimSpace(authID)
|
||||
if authID == "" {
|
||||
return
|
||||
}
|
||||
meta[cliproxyexecutor.SelectedAuthMetadataKey] = authID
|
||||
if callback, ok := meta[cliproxyexecutor.SelectedAuthCallbackMetadataKey].(func(string)); ok && callback != nil {
|
||||
callback(authID)
|
||||
}
|
||||
}
|
||||
|
||||
func rewriteModelForAuth(model string, auth *Auth) string {
|
||||
if auth == nil || model == "" {
|
||||
return model
|
||||
@@ -1550,7 +1610,56 @@ func (m *Manager) GetByID(id string) (*Auth, bool) {
|
||||
return auth.Clone(), true
|
||||
}
|
||||
|
||||
// Executor returns the registered provider executor for a provider key.
|
||||
func (m *Manager) Executor(provider string) (ProviderExecutor, bool) {
|
||||
if m == nil {
|
||||
return nil, false
|
||||
}
|
||||
provider = strings.TrimSpace(provider)
|
||||
if provider == "" {
|
||||
return nil, false
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
executor, okExecutor := m.executors[provider]
|
||||
if !okExecutor {
|
||||
lowerProvider := strings.ToLower(provider)
|
||||
if lowerProvider != provider {
|
||||
executor, okExecutor = m.executors[lowerProvider]
|
||||
}
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
if !okExecutor || executor == nil {
|
||||
return nil, false
|
||||
}
|
||||
return executor, true
|
||||
}
|
||||
|
||||
// CloseExecutionSession asks all registered executors to release the supplied execution session.
|
||||
func (m *Manager) CloseExecutionSession(sessionID string) {
|
||||
sessionID = strings.TrimSpace(sessionID)
|
||||
if m == nil || sessionID == "" {
|
||||
return
|
||||
}
|
||||
|
||||
m.mu.RLock()
|
||||
executors := make([]ProviderExecutor, 0, len(m.executors))
|
||||
for _, exec := range m.executors {
|
||||
executors = append(executors, exec)
|
||||
}
|
||||
m.mu.RUnlock()
|
||||
|
||||
for i := range executors {
|
||||
if closer, ok := executors[i].(ExecutionSessionCloser); ok && closer != nil {
|
||||
closer.CloseExecutionSession(sessionID)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
|
||||
m.mu.RLock()
|
||||
executor, okExecutor := m.executors[provider]
|
||||
if !okExecutor {
|
||||
@@ -1571,6 +1680,9 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
||||
if candidate.Provider != provider || candidate.Disabled {
|
||||
continue
|
||||
}
|
||||
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
|
||||
continue
|
||||
}
|
||||
if _, used := tried[candidate.ID]; used {
|
||||
continue
|
||||
}
|
||||
@@ -1606,6 +1718,8 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
||||
}
|
||||
|
||||
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||
|
||||
providerSet := make(map[string]struct{}, len(providers))
|
||||
for _, provider := range providers {
|
||||
p := strings.TrimSpace(strings.ToLower(provider))
|
||||
@@ -1633,6 +1747,9 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
|
||||
if candidate == nil || candidate.Disabled {
|
||||
continue
|
||||
}
|
||||
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
|
||||
continue
|
||||
}
|
||||
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
|
||||
if providerKey == "" {
|
||||
continue
|
||||
|
||||
100
sdk/cliproxy/auth/conductor_executor_replace_test.go
Normal file
100
sdk/cliproxy/auth/conductor_executor_replace_test.go
Normal file
@@ -0,0 +1,100 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net/http"
|
||||
"sync"
|
||||
"testing"
|
||||
|
||||
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||
)
|
||||
|
||||
type replaceAwareExecutor struct {
|
||||
id string
|
||||
|
||||
mu sync.Mutex
|
||||
closedSessionIDs []string
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) Identifier() string {
|
||||
return e.id
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||
ch := make(chan cliproxyexecutor.StreamChunk)
|
||||
close(ch)
|
||||
return ch, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||
return cliproxyexecutor.Response{}, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) CloseExecutionSession(sessionID string) {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
e.closedSessionIDs = append(e.closedSessionIDs, sessionID)
|
||||
}
|
||||
|
||||
func (e *replaceAwareExecutor) ClosedSessionIDs() []string {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
out := make([]string, len(e.closedSessionIDs))
|
||||
copy(out, e.closedSessionIDs)
|
||||
return out
|
||||
}
|
||||
|
||||
func TestManagerRegisterExecutorClosesReplacedExecutionSessions(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, nil, nil)
|
||||
replaced := &replaceAwareExecutor{id: "codex"}
|
||||
current := &replaceAwareExecutor{id: "codex"}
|
||||
|
||||
manager.RegisterExecutor(replaced)
|
||||
manager.RegisterExecutor(current)
|
||||
|
||||
closed := replaced.ClosedSessionIDs()
|
||||
if len(closed) != 1 {
|
||||
t.Fatalf("expected replaced executor close calls = 1, got %d", len(closed))
|
||||
}
|
||||
if closed[0] != CloseAllExecutionSessionsID {
|
||||
t.Fatalf("expected close marker %q, got %q", CloseAllExecutionSessionsID, closed[0])
|
||||
}
|
||||
if len(current.ClosedSessionIDs()) != 0 {
|
||||
t.Fatalf("expected current executor to stay open")
|
||||
}
|
||||
}
|
||||
|
||||
func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
manager := NewManager(nil, nil, nil)
|
||||
current := &replaceAwareExecutor{id: "codex"}
|
||||
manager.RegisterExecutor(current)
|
||||
|
||||
resolved, okResolved := manager.Executor("CODEX")
|
||||
if !okResolved {
|
||||
t.Fatal("expected registered executor to be found")
|
||||
}
|
||||
if resolved != current {
|
||||
t.Fatal("expected resolved executor to match registered executor")
|
||||
}
|
||||
|
||||
_, okMissing := manager.Executor("unknown")
|
||||
if okMissing {
|
||||
t.Fatal("expected unknown provider lookup to fail")
|
||||
}
|
||||
}
|
||||
@@ -134,6 +134,62 @@ func canonicalModelKey(model string) string {
|
||||
return modelName
|
||||
}
|
||||
|
||||
func authWebsocketsEnabled(auth *Auth) bool {
|
||||
if auth == nil {
|
||||
return false
|
||||
}
|
||||
if len(auth.Attributes) > 0 {
|
||||
if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" {
|
||||
parsed, errParse := strconv.ParseBool(raw)
|
||||
if errParse == nil {
|
||||
return parsed
|
||||
}
|
||||
}
|
||||
}
|
||||
if len(auth.Metadata) == 0 {
|
||||
return false
|
||||
}
|
||||
raw, ok := auth.Metadata["websockets"]
|
||||
if !ok || raw == nil {
|
||||
return false
|
||||
}
|
||||
switch v := raw.(type) {
|
||||
case bool:
|
||||
return v
|
||||
case string:
|
||||
parsed, errParse := strconv.ParseBool(strings.TrimSpace(v))
|
||||
if errParse == nil {
|
||||
return parsed
|
||||
}
|
||||
default:
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func preferCodexWebsocketAuths(ctx context.Context, provider string, available []*Auth) []*Auth {
|
||||
if len(available) == 0 {
|
||||
return available
|
||||
}
|
||||
if !cliproxyexecutor.DownstreamWebsocket(ctx) {
|
||||
return available
|
||||
}
|
||||
if !strings.EqualFold(strings.TrimSpace(provider), "codex") {
|
||||
return available
|
||||
}
|
||||
|
||||
wsEnabled := make([]*Auth, 0, len(available))
|
||||
for i := 0; i < len(available); i++ {
|
||||
candidate := available[i]
|
||||
if authWebsocketsEnabled(candidate) {
|
||||
wsEnabled = append(wsEnabled, candidate)
|
||||
}
|
||||
}
|
||||
if len(wsEnabled) > 0 {
|
||||
return wsEnabled
|
||||
}
|
||||
return available
|
||||
}
|
||||
|
||||
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
||||
available = make(map[int][]*Auth)
|
||||
for i := 0; i < len(auths); i++ {
|
||||
@@ -193,13 +249,13 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
||||
|
||||
// Pick selects the next available auth for the provider in a round-robin manner.
|
||||
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
_ = ctx
|
||||
_ = opts
|
||||
now := time.Now()
|
||||
available, err := getAvailableAuths(auths, provider, model, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
available = preferCodexWebsocketAuths(ctx, provider, available)
|
||||
key := provider + ":" + canonicalModelKey(model)
|
||||
s.mu.Lock()
|
||||
if s.cursors == nil {
|
||||
@@ -226,13 +282,13 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
|
||||
|
||||
// Pick selects the first available auth for the provider in a deterministic manner.
|
||||
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||
_ = ctx
|
||||
_ = opts
|
||||
now := time.Now()
|
||||
available, err := getAvailableAuths(auths, provider, model, now)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
available = preferCodexWebsocketAuths(ctx, provider, available)
|
||||
return available[0], nil
|
||||
}
|
||||
|
||||
|
||||
23
sdk/cliproxy/executor/context.go
Normal file
23
sdk/cliproxy/executor/context.go
Normal file
@@ -0,0 +1,23 @@
|
||||
package executor
|
||||
|
||||
import "context"
|
||||
|
||||
type downstreamWebsocketContextKey struct{}
|
||||
|
||||
// WithDownstreamWebsocket marks the current request as coming from a downstream websocket connection.
|
||||
func WithDownstreamWebsocket(ctx context.Context) context.Context {
|
||||
if ctx == nil {
|
||||
ctx = context.Background()
|
||||
}
|
||||
return context.WithValue(ctx, downstreamWebsocketContextKey{}, true)
|
||||
}
|
||||
|
||||
// DownstreamWebsocket reports whether the current request originates from a downstream websocket connection.
|
||||
func DownstreamWebsocket(ctx context.Context) bool {
|
||||
if ctx == nil {
|
||||
return false
|
||||
}
|
||||
raw := ctx.Value(downstreamWebsocketContextKey{})
|
||||
enabled, ok := raw.(bool)
|
||||
return ok && enabled
|
||||
}
|
||||
@@ -10,6 +10,17 @@ import (
|
||||
// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
|
||||
const RequestedModelMetadataKey = "requested_model"
|
||||
|
||||
const (
|
||||
// PinnedAuthMetadataKey locks execution to a specific auth ID.
|
||||
PinnedAuthMetadataKey = "pinned_auth_id"
|
||||
// SelectedAuthMetadataKey stores the auth ID selected by the scheduler.
|
||||
SelectedAuthMetadataKey = "selected_auth_id"
|
||||
// SelectedAuthCallbackMetadataKey carries an optional callback invoked with the selected auth ID.
|
||||
SelectedAuthCallbackMetadataKey = "selected_auth_callback"
|
||||
// ExecutionSessionMetadataKey identifies a long-lived downstream execution session.
|
||||
ExecutionSessionMetadataKey = "execution_session_id"
|
||||
)
|
||||
|
||||
// Request encapsulates the translated payload that will be sent to a provider executor.
|
||||
type Request struct {
|
||||
// Model is the upstream model identifier after translation.
|
||||
|
||||
@@ -325,6 +325,9 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
||||
if _, err := s.coreManager.Update(ctx, existing); err != nil {
|
||||
log.Errorf("failed to disable auth %s: %v", id, err)
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
|
||||
s.ensureExecutorsForAuth(existing)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -357,7 +360,24 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName
|
||||
}
|
||||
|
||||
func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
||||
if s == nil || a == nil {
|
||||
s.ensureExecutorsForAuthWithMode(a, false)
|
||||
}
|
||||
|
||||
func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace bool) {
|
||||
if s == nil || s.coreManager == nil || a == nil {
|
||||
return
|
||||
}
|
||||
if strings.EqualFold(strings.TrimSpace(a.Provider), "codex") {
|
||||
if !forceReplace {
|
||||
existingExecutor, hasExecutor := s.coreManager.Executor("codex")
|
||||
if hasExecutor {
|
||||
_, isCodexAutoExecutor := existingExecutor.(*executor.CodexAutoExecutor)
|
||||
if isCodexAutoExecutor {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
s.coreManager.RegisterExecutor(executor.NewCodexAutoExecutor(s.cfg))
|
||||
return
|
||||
}
|
||||
// Skip disabled auth entries when (re)binding executors.
|
||||
@@ -392,8 +412,6 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
||||
s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg))
|
||||
case "claude":
|
||||
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
|
||||
case "codex":
|
||||
s.coreManager.RegisterExecutor(executor.NewCodexExecutor(s.cfg))
|
||||
case "qwen":
|
||||
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
||||
case "iflow":
|
||||
@@ -415,8 +433,15 @@ func (s *Service) rebindExecutors() {
|
||||
return
|
||||
}
|
||||
auths := s.coreManager.List()
|
||||
reboundCodex := false
|
||||
for _, auth := range auths {
|
||||
s.ensureExecutorsForAuth(auth)
|
||||
if auth != nil && strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
|
||||
if reboundCodex {
|
||||
continue
|
||||
}
|
||||
reboundCodex = true
|
||||
}
|
||||
s.ensureExecutorsForAuthWithMode(auth, true)
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
64
sdk/cliproxy/service_codex_executor_binding_test.go
Normal file
64
sdk/cliproxy/service_codex_executor_binding_test.go
Normal file
@@ -0,0 +1,64 @@
|
||||
package cliproxy
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestEnsureExecutorsForAuth_CodexDoesNotReplaceInNormalMode(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: &config.Config{},
|
||||
coreManager: coreauth.NewManager(nil, nil, nil),
|
||||
}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "codex-auth-1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
}
|
||||
|
||||
service.ensureExecutorsForAuth(auth)
|
||||
firstExecutor, okFirst := service.coreManager.Executor("codex")
|
||||
if !okFirst || firstExecutor == nil {
|
||||
t.Fatal("expected codex executor after first bind")
|
||||
}
|
||||
|
||||
service.ensureExecutorsForAuth(auth)
|
||||
secondExecutor, okSecond := service.coreManager.Executor("codex")
|
||||
if !okSecond || secondExecutor == nil {
|
||||
t.Fatal("expected codex executor after second bind")
|
||||
}
|
||||
|
||||
if firstExecutor != secondExecutor {
|
||||
t.Fatal("expected codex executor to stay unchanged in normal mode")
|
||||
}
|
||||
}
|
||||
|
||||
func TestEnsureExecutorsForAuthWithMode_CodexForceReplace(t *testing.T) {
|
||||
service := &Service{
|
||||
cfg: &config.Config{},
|
||||
coreManager: coreauth.NewManager(nil, nil, nil),
|
||||
}
|
||||
auth := &coreauth.Auth{
|
||||
ID: "codex-auth-2",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
}
|
||||
|
||||
service.ensureExecutorsForAuth(auth)
|
||||
firstExecutor, okFirst := service.coreManager.Executor("codex")
|
||||
if !okFirst || firstExecutor == nil {
|
||||
t.Fatal("expected codex executor after first bind")
|
||||
}
|
||||
|
||||
service.ensureExecutorsForAuthWithMode(auth, true)
|
||||
secondExecutor, okSecond := service.coreManager.Executor("codex")
|
||||
if !okSecond || secondExecutor == nil {
|
||||
t.Fatal("expected codex executor after forced rebind")
|
||||
}
|
||||
|
||||
if firstExecutor == secondExecutor {
|
||||
t.Fatal("expected codex executor replacement in force mode")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user