mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-20 13:20:52 +08:00
Merge branch 'dev'
This commit is contained in:
@@ -15,10 +15,12 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const maxErrorOnlyCapturedRequestBodyBytes int64 = 1 << 20 // 1 MiB
|
||||||
|
|
||||||
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
// RequestLoggingMiddleware creates a Gin middleware that logs HTTP requests and responses.
|
||||||
// It captures detailed information about the request and response, including headers and body,
|
// It captures detailed information about the request and response, including headers and body,
|
||||||
// and uses the provided RequestLogger to record this data. When logging is disabled in the
|
// and uses the provided RequestLogger to record this data. When full request logging is disabled,
|
||||||
// logger, it still captures data so that upstream errors can be persisted.
|
// body capture is limited to small known-size payloads to avoid large per-request memory spikes.
|
||||||
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
||||||
return func(c *gin.Context) {
|
return func(c *gin.Context) {
|
||||||
if logger == nil {
|
if logger == nil {
|
||||||
@@ -26,7 +28,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if c.Request.Method == http.MethodGet {
|
if shouldSkipMethodForRequestLogging(c.Request) {
|
||||||
c.Next()
|
c.Next()
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -37,8 +39,10 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
loggerEnabled := logger.IsEnabled()
|
||||||
|
|
||||||
// Capture request information
|
// Capture request information
|
||||||
requestInfo, err := captureRequestInfo(c)
|
requestInfo, err := captureRequestInfo(c, shouldCaptureRequestBody(loggerEnabled, c.Request))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Log error but continue processing
|
// Log error but continue processing
|
||||||
// In a real implementation, you might want to use a proper logger here
|
// In a real implementation, you might want to use a proper logger here
|
||||||
@@ -48,7 +52,7 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
|||||||
|
|
||||||
// Create response writer wrapper
|
// Create response writer wrapper
|
||||||
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
wrapper := NewResponseWriterWrapper(c.Writer, logger, requestInfo)
|
||||||
if !logger.IsEnabled() {
|
if !loggerEnabled {
|
||||||
wrapper.logOnErrorOnly = true
|
wrapper.logOnErrorOnly = true
|
||||||
}
|
}
|
||||||
c.Writer = wrapper
|
c.Writer = wrapper
|
||||||
@@ -64,10 +68,47 @@ func RequestLoggingMiddleware(logger logging.RequestLogger) gin.HandlerFunc {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func shouldSkipMethodForRequestLogging(req *http.Request) bool {
|
||||||
|
if req == nil {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if req.Method != http.MethodGet {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return !isResponsesWebsocketUpgrade(req)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isResponsesWebsocketUpgrade(req *http.Request) bool {
|
||||||
|
if req == nil || req.URL == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if req.URL.Path != "/v1/responses" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return strings.EqualFold(strings.TrimSpace(req.Header.Get("Upgrade")), "websocket")
|
||||||
|
}
|
||||||
|
|
||||||
|
func shouldCaptureRequestBody(loggerEnabled bool, req *http.Request) bool {
|
||||||
|
if loggerEnabled {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if req == nil || req.Body == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
contentType := strings.ToLower(strings.TrimSpace(req.Header.Get("Content-Type")))
|
||||||
|
if strings.HasPrefix(contentType, "multipart/form-data") {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if req.ContentLength <= 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return req.ContentLength <= maxErrorOnlyCapturedRequestBodyBytes
|
||||||
|
}
|
||||||
|
|
||||||
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
// captureRequestInfo extracts relevant information from the incoming HTTP request.
|
||||||
// It captures the URL, method, headers, and body. The request body is read and then
|
// It captures the URL, method, headers, and body. The request body is read and then
|
||||||
// restored so that it can be processed by subsequent handlers.
|
// restored so that it can be processed by subsequent handlers.
|
||||||
func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
func captureRequestInfo(c *gin.Context, captureBody bool) (*RequestInfo, error) {
|
||||||
// Capture URL with sensitive query parameters masked
|
// Capture URL with sensitive query parameters masked
|
||||||
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
maskedQuery := util.MaskSensitiveQuery(c.Request.URL.RawQuery)
|
||||||
url := c.Request.URL.Path
|
url := c.Request.URL.Path
|
||||||
@@ -86,7 +127,7 @@ func captureRequestInfo(c *gin.Context) (*RequestInfo, error) {
|
|||||||
|
|
||||||
// Capture request body
|
// Capture request body
|
||||||
var body []byte
|
var body []byte
|
||||||
if c.Request.Body != nil {
|
if captureBody && c.Request.Body != nil {
|
||||||
// Read the body
|
// Read the body
|
||||||
bodyBytes, err := io.ReadAll(c.Request.Body)
|
bodyBytes, err := io.ReadAll(c.Request.Body)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
138
internal/api/middleware/request_logging_test.go
Normal file
138
internal/api/middleware/request_logging_test.go
Normal file
@@ -0,0 +1,138 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"io"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestShouldSkipMethodForRequestLogging(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
req *http.Request
|
||||||
|
skip bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil request",
|
||||||
|
req: nil,
|
||||||
|
skip: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "post request should not skip",
|
||||||
|
req: &http.Request{
|
||||||
|
Method: http.MethodPost,
|
||||||
|
URL: &url.URL{Path: "/v1/responses"},
|
||||||
|
},
|
||||||
|
skip: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "plain get should skip",
|
||||||
|
req: &http.Request{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
URL: &url.URL{Path: "/v1/models"},
|
||||||
|
Header: http.Header{},
|
||||||
|
},
|
||||||
|
skip: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses websocket upgrade should not skip",
|
||||||
|
req: &http.Request{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
URL: &url.URL{Path: "/v1/responses"},
|
||||||
|
Header: http.Header{"Upgrade": []string{"websocket"}},
|
||||||
|
},
|
||||||
|
skip: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "responses get without upgrade should skip",
|
||||||
|
req: &http.Request{
|
||||||
|
Method: http.MethodGet,
|
||||||
|
URL: &url.URL{Path: "/v1/responses"},
|
||||||
|
Header: http.Header{},
|
||||||
|
},
|
||||||
|
skip: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range tests {
|
||||||
|
got := shouldSkipMethodForRequestLogging(tests[i].req)
|
||||||
|
if got != tests[i].skip {
|
||||||
|
t.Fatalf("%s: got skip=%t, want %t", tests[i].name, got, tests[i].skip)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestShouldCaptureRequestBody(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
loggerEnabled bool
|
||||||
|
req *http.Request
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "logger enabled always captures",
|
||||||
|
loggerEnabled: true,
|
||||||
|
req: &http.Request{
|
||||||
|
Body: io.NopCloser(strings.NewReader("{}")),
|
||||||
|
ContentLength: -1,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "nil request",
|
||||||
|
loggerEnabled: false,
|
||||||
|
req: nil,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "small known size json in error-only mode",
|
||||||
|
loggerEnabled: false,
|
||||||
|
req: &http.Request{
|
||||||
|
Body: io.NopCloser(strings.NewReader("{}")),
|
||||||
|
ContentLength: 2,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
},
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "large known size skipped in error-only mode",
|
||||||
|
loggerEnabled: false,
|
||||||
|
req: &http.Request{
|
||||||
|
Body: io.NopCloser(strings.NewReader("x")),
|
||||||
|
ContentLength: maxErrorOnlyCapturedRequestBodyBytes + 1,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "unknown size skipped in error-only mode",
|
||||||
|
loggerEnabled: false,
|
||||||
|
req: &http.Request{
|
||||||
|
Body: io.NopCloser(strings.NewReader("x")),
|
||||||
|
ContentLength: -1,
|
||||||
|
Header: http.Header{"Content-Type": []string{"application/json"}},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "multipart skipped in error-only mode",
|
||||||
|
loggerEnabled: false,
|
||||||
|
req: &http.Request{
|
||||||
|
Body: io.NopCloser(strings.NewReader("x")),
|
||||||
|
ContentLength: 1,
|
||||||
|
Header: http.Header{"Content-Type": []string{"multipart/form-data; boundary=abc"}},
|
||||||
|
},
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := range tests {
|
||||||
|
got := shouldCaptureRequestBody(tests[i].loggerEnabled, tests[i].req)
|
||||||
|
if got != tests[i].want {
|
||||||
|
t.Fatalf("%s: got %t, want %t", tests[i].name, got, tests[i].want)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -14,6 +14,8 @@ import (
|
|||||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/logging"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const requestBodyOverrideContextKey = "REQUEST_BODY_OVERRIDE"
|
||||||
|
|
||||||
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
// RequestInfo holds essential details of an incoming HTTP request for logging purposes.
|
||||||
type RequestInfo struct {
|
type RequestInfo struct {
|
||||||
URL string // URL is the request URL.
|
URL string // URL is the request URL.
|
||||||
@@ -223,8 +225,8 @@ func (w *ResponseWriterWrapper) detectStreaming(contentType string) bool {
|
|||||||
|
|
||||||
// Only fall back to request payload hints when Content-Type is not set yet.
|
// Only fall back to request payload hints when Content-Type is not set yet.
|
||||||
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||||
bodyStr := string(w.requestInfo.Body)
|
return bytes.Contains(w.requestInfo.Body, []byte(`"stream": true`)) ||
|
||||||
return strings.Contains(bodyStr, `"stream": true`) || strings.Contains(bodyStr, `"stream":true`)
|
bytes.Contains(w.requestInfo.Body, []byte(`"stream":true`))
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
@@ -310,7 +312,7 @@ func (w *ResponseWriterWrapper) Finalize(c *gin.Context) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return w.logRequest(finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
return w.logRequest(w.extractRequestBody(c), finalStatusCode, w.cloneHeaders(), w.body.Bytes(), w.extractAPIRequest(c), w.extractAPIResponse(c), w.extractAPIResponseTimestamp(c), slicesAPIResponseError, forceLog)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
func (w *ResponseWriterWrapper) cloneHeaders() map[string][]string {
|
||||||
@@ -361,14 +363,30 @@ func (w *ResponseWriterWrapper) extractAPIResponseTimestamp(c *gin.Context) time
|
|||||||
return time.Time{}
|
return time.Time{}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *ResponseWriterWrapper) logRequest(statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
func (w *ResponseWriterWrapper) extractRequestBody(c *gin.Context) []byte {
|
||||||
if w.requestInfo == nil {
|
if c != nil {
|
||||||
|
if bodyOverride, isExist := c.Get(requestBodyOverrideContextKey); isExist {
|
||||||
|
switch value := bodyOverride.(type) {
|
||||||
|
case []byte:
|
||||||
|
if len(value) > 0 {
|
||||||
|
return bytes.Clone(value)
|
||||||
|
}
|
||||||
|
case string:
|
||||||
|
if strings.TrimSpace(value) != "" {
|
||||||
|
return []byte(value)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if w.requestInfo != nil && len(w.requestInfo.Body) > 0 {
|
||||||
|
return w.requestInfo.Body
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var requestBody []byte
|
func (w *ResponseWriterWrapper) logRequest(requestBody []byte, statusCode int, headers map[string][]string, body []byte, apiRequestBody, apiResponseBody []byte, apiResponseTimestamp time.Time, apiResponseErrors []*interfaces.ErrorMessage, forceLog bool) error {
|
||||||
if len(w.requestInfo.Body) > 0 {
|
if w.requestInfo == nil {
|
||||||
requestBody = w.requestInfo.Body
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if loggerWithOptions, ok := w.logger.(interface {
|
if loggerWithOptions, ok := w.logger.(interface {
|
||||||
|
|||||||
43
internal/api/middleware/response_writer_test.go
Normal file
43
internal/api/middleware/response_writer_test.go
Normal file
@@ -0,0 +1,43 @@
|
|||||||
|
package middleware
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestExtractRequestBodyPrefersOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{
|
||||||
|
requestInfo: &RequestInfo{Body: []byte("original-body")},
|
||||||
|
}
|
||||||
|
|
||||||
|
body := wrapper.extractRequestBody(c)
|
||||||
|
if string(body) != "original-body" {
|
||||||
|
t.Fatalf("request body = %q, want %q", string(body), "original-body")
|
||||||
|
}
|
||||||
|
|
||||||
|
c.Set(requestBodyOverrideContextKey, []byte("override-body"))
|
||||||
|
body = wrapper.extractRequestBody(c)
|
||||||
|
if string(body) != "override-body" {
|
||||||
|
t.Fatalf("request body = %q, want %q", string(body), "override-body")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtractRequestBodySupportsStringOverride(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
wrapper := &ResponseWriterWrapper{}
|
||||||
|
c.Set(requestBodyOverrideContextKey, "override-as-string")
|
||||||
|
|
||||||
|
body := wrapper.extractRequestBody(c)
|
||||||
|
if string(body) != "override-as-string" {
|
||||||
|
t.Fatalf("request body = %q, want %q", string(body), "override-as-string")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -323,6 +323,7 @@ func (s *Server) setupRoutes() {
|
|||||||
v1.POST("/completions", openaiHandlers.Completions)
|
v1.POST("/completions", openaiHandlers.Completions)
|
||||||
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
v1.POST("/messages", claudeCodeHandlers.ClaudeMessages)
|
||||||
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
v1.POST("/messages/count_tokens", claudeCodeHandlers.ClaudeCountTokens)
|
||||||
|
v1.GET("/responses", openaiResponsesHandlers.ResponsesWebsocket)
|
||||||
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
v1.POST("/responses", openaiResponsesHandlers.Responses)
|
||||||
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
|
v1.POST("/responses/compact", openaiResponsesHandlers.Compact)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -355,6 +355,9 @@ type CodexKey struct {
|
|||||||
// If empty, the default Codex API URL will be used.
|
// If empty, the default Codex API URL will be used.
|
||||||
BaseURL string `yaml:"base-url" json:"base-url"`
|
BaseURL string `yaml:"base-url" json:"base-url"`
|
||||||
|
|
||||||
|
// Websockets enables the Responses API websocket transport for this credential.
|
||||||
|
Websockets bool `yaml:"websockets,omitempty" json:"websockets,omitempty"`
|
||||||
|
|
||||||
// ProxyURL overrides the global proxy setting for this API key if provided.
|
// ProxyURL overrides the global proxy setting for this API key if provided.
|
||||||
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
ProxyURL string `yaml:"proxy-url" json:"proxy-url"`
|
||||||
|
|
||||||
|
|||||||
@@ -28,6 +28,17 @@ func GetClaudeModels() []*ModelInfo {
|
|||||||
MaxCompletionTokens: 64000,
|
MaxCompletionTokens: 64000,
|
||||||
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
ID: "claude-sonnet-4-6",
|
||||||
|
Object: "model",
|
||||||
|
Created: 1771372800, // 2026-02-17
|
||||||
|
OwnedBy: "anthropic",
|
||||||
|
Type: "claude",
|
||||||
|
DisplayName: "Claude 4.6 Sonnet",
|
||||||
|
ContextLength: 200000,
|
||||||
|
MaxCompletionTokens: 64000,
|
||||||
|
Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: false},
|
||||||
|
},
|
||||||
{
|
{
|
||||||
ID: "claude-opus-4-6",
|
ID: "claude-opus-4-6",
|
||||||
Object: "model",
|
Object: "model",
|
||||||
@@ -897,6 +908,8 @@ func GetAntigravityModelConfig() map[string]*AntigravityModelConfig {
|
|||||||
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-opus-4-5-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
"claude-opus-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
"claude-sonnet-4-5": {MaxCompletionTokens: 64000},
|
||||||
|
"claude-sonnet-4-6": {MaxCompletionTokens: 64000},
|
||||||
|
"claude-sonnet-4-6-thinking": {Thinking: &ThinkingSupport{Min: 1024, Max: 128000, ZeroAllowed: true, DynamicAllowed: true}, MaxCompletionTokens: 64000},
|
||||||
"gpt-oss-120b-medium": {},
|
"gpt-oss-120b-medium": {},
|
||||||
"tab_flash_lite_preview": {},
|
"tab_flash_lite_preview": {},
|
||||||
}
|
}
|
||||||
|
|||||||
1407
internal/runtime/executor/codex_websockets_executor.go
Normal file
1407
internal/runtime/executor/codex_websockets_executor.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -184,6 +184,9 @@ func BuildConfigChangeDetails(oldCfg, newCfg *config.Config) []string {
|
|||||||
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
if strings.TrimSpace(o.Prefix) != strings.TrimSpace(n.Prefix) {
|
||||||
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
changes = append(changes, fmt.Sprintf("codex[%d].prefix: %s -> %s", i, strings.TrimSpace(o.Prefix), strings.TrimSpace(n.Prefix)))
|
||||||
}
|
}
|
||||||
|
if o.Websockets != n.Websockets {
|
||||||
|
changes = append(changes, fmt.Sprintf("codex[%d].websockets: %t -> %t", i, o.Websockets, n.Websockets))
|
||||||
|
}
|
||||||
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
if strings.TrimSpace(o.APIKey) != strings.TrimSpace(n.APIKey) {
|
||||||
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
|
changes = append(changes, fmt.Sprintf("codex[%d].api-key: updated", i))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -160,6 +160,9 @@ func (s *ConfigSynthesizer) synthesizeCodexKeys(ctx *SynthesisContext) []*coreau
|
|||||||
if ck.BaseURL != "" {
|
if ck.BaseURL != "" {
|
||||||
attrs["base_url"] = ck.BaseURL
|
attrs["base_url"] = ck.BaseURL
|
||||||
}
|
}
|
||||||
|
if ck.Websockets {
|
||||||
|
attrs["websockets"] = "true"
|
||||||
|
}
|
||||||
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
|
if hash := diff.ComputeCodexModelsHash(ck.Models); hash != "" {
|
||||||
attrs["models_hash"] = hash
|
attrs["models_hash"] = hash
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -235,6 +235,7 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
|
|||||||
Prefix: "dev",
|
Prefix: "dev",
|
||||||
BaseURL: "https://api.openai.com",
|
BaseURL: "https://api.openai.com",
|
||||||
ProxyURL: "http://proxy.local",
|
ProxyURL: "http://proxy.local",
|
||||||
|
Websockets: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -259,6 +260,9 @@ func TestConfigSynthesizer_CodexKeys(t *testing.T) {
|
|||||||
if auths[0].ProxyURL != "http://proxy.local" {
|
if auths[0].ProxyURL != "http://proxy.local" {
|
||||||
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
t.Errorf("expected proxy_url http://proxy.local, got %s", auths[0].ProxyURL)
|
||||||
}
|
}
|
||||||
|
if auths[0].Attributes["websockets"] != "true" {
|
||||||
|
t.Errorf("expected websockets=true, got %s", auths[0].Attributes["websockets"])
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
func TestConfigSynthesizer_CodexKeys_SkipsEmptyAndHeaders(t *testing.T) {
|
||||||
|
|||||||
@@ -52,6 +52,45 @@ const (
|
|||||||
defaultStreamingBootstrapRetries = 0
|
defaultStreamingBootstrapRetries = 0
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type pinnedAuthContextKey struct{}
|
||||||
|
type selectedAuthCallbackContextKey struct{}
|
||||||
|
type executionSessionContextKey struct{}
|
||||||
|
|
||||||
|
// WithPinnedAuthID returns a child context that requests execution on a specific auth ID.
|
||||||
|
func WithPinnedAuthID(ctx context.Context, authID string) context.Context {
|
||||||
|
authID = strings.TrimSpace(authID)
|
||||||
|
if authID == "" {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, pinnedAuthContextKey{}, authID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithSelectedAuthIDCallback returns a child context that receives the selected auth ID.
|
||||||
|
func WithSelectedAuthIDCallback(ctx context.Context, callback func(string)) context.Context {
|
||||||
|
if callback == nil {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, selectedAuthCallbackContextKey{}, callback)
|
||||||
|
}
|
||||||
|
|
||||||
|
// WithExecutionSessionID returns a child context tagged with a long-lived execution session ID.
|
||||||
|
func WithExecutionSessionID(ctx context.Context, sessionID string) context.Context {
|
||||||
|
sessionID = strings.TrimSpace(sessionID)
|
||||||
|
if sessionID == "" {
|
||||||
|
return ctx
|
||||||
|
}
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, executionSessionContextKey{}, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
|
// BuildErrorResponseBody builds an OpenAI-compatible JSON error response body.
|
||||||
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
|
// If errText is already valid JSON, it is returned as-is to preserve upstream error payloads.
|
||||||
func BuildErrorResponseBody(status int, errText string) []byte {
|
func BuildErrorResponseBody(status int, errText string) []byte {
|
||||||
@@ -152,7 +191,59 @@ func requestExecutionMetadata(ctx context.Context) map[string]any {
|
|||||||
if key == "" {
|
if key == "" {
|
||||||
key = uuid.NewString()
|
key = uuid.NewString()
|
||||||
}
|
}
|
||||||
return map[string]any{idempotencyKeyMetadataKey: key}
|
|
||||||
|
meta := map[string]any{idempotencyKeyMetadataKey: key}
|
||||||
|
if pinnedAuthID := pinnedAuthIDFromContext(ctx); pinnedAuthID != "" {
|
||||||
|
meta[coreexecutor.PinnedAuthMetadataKey] = pinnedAuthID
|
||||||
|
}
|
||||||
|
if selectedCallback := selectedAuthIDCallbackFromContext(ctx); selectedCallback != nil {
|
||||||
|
meta[coreexecutor.SelectedAuthCallbackMetadataKey] = selectedCallback
|
||||||
|
}
|
||||||
|
if executionSessionID := executionSessionIDFromContext(ctx); executionSessionID != "" {
|
||||||
|
meta[coreexecutor.ExecutionSessionMetadataKey] = executionSessionID
|
||||||
|
}
|
||||||
|
return meta
|
||||||
|
}
|
||||||
|
|
||||||
|
func pinnedAuthIDFromContext(ctx context.Context) string {
|
||||||
|
if ctx == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw := ctx.Value(pinnedAuthContextKey{})
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(v)
|
||||||
|
case []byte:
|
||||||
|
return strings.TrimSpace(string(v))
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func selectedAuthIDCallbackFromContext(ctx context.Context) func(string) {
|
||||||
|
if ctx == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
raw := ctx.Value(selectedAuthCallbackContextKey{})
|
||||||
|
if callback, ok := raw.(func(string)); ok && callback != nil {
|
||||||
|
return callback
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func executionSessionIDFromContext(ctx context.Context) string {
|
||||||
|
if ctx == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw := ctx.Value(executionSessionContextKey{})
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(v)
|
||||||
|
case []byte:
|
||||||
|
return strings.TrimSpace(string(v))
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// BaseAPIHandler contains the handlers for API endpoints.
|
// BaseAPIHandler contains the handlers for API endpoints.
|
||||||
|
|||||||
@@ -122,6 +122,82 @@ func (e *payloadThenErrorStreamExecutor) Calls() int {
|
|||||||
return e.calls
|
return e.calls
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type authAwareStreamExecutor struct {
|
||||||
|
mu sync.Mutex
|
||||||
|
calls int
|
||||||
|
authIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) ExecuteStream(ctx context.Context, auth *coreauth.Auth, req coreexecutor.Request, opts coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
||||||
|
_ = ctx
|
||||||
|
_ = req
|
||||||
|
_ = opts
|
||||||
|
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||||
|
|
||||||
|
authID := ""
|
||||||
|
if auth != nil {
|
||||||
|
authID = auth.ID
|
||||||
|
}
|
||||||
|
|
||||||
|
e.mu.Lock()
|
||||||
|
e.calls++
|
||||||
|
e.authIDs = append(e.authIDs, authID)
|
||||||
|
e.mu.Unlock()
|
||||||
|
|
||||||
|
if authID == "auth1" {
|
||||||
|
ch <- coreexecutor.StreamChunk{
|
||||||
|
Err: &coreauth.Error{
|
||||||
|
Code: "unauthorized",
|
||||||
|
Message: "unauthorized",
|
||||||
|
Retryable: false,
|
||||||
|
HTTPStatus: http.StatusUnauthorized,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
close(ch)
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
||||||
|
close(ch)
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, &coreauth.Error{
|
||||||
|
Code: "not_implemented",
|
||||||
|
Message: "HttpRequest not implemented",
|
||||||
|
HTTPStatus: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) Calls() int {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
return e.calls
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *authAwareStreamExecutor) AuthIDs() []string {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
out := make([]string, len(e.authIDs))
|
||||||
|
copy(out, e.authIDs)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
||||||
executor := &failOnceStreamExecutor{}
|
executor := &failOnceStreamExecutor{}
|
||||||
manager := coreauth.NewManager(nil, nil, nil)
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
@@ -252,3 +328,128 @@ func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
|
|||||||
t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
|
t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_PinnedAuthKeepsSameUpstream(t *testing.T) {
|
||||||
|
executor := &authAwareStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth1 := &coreauth.Auth{
|
||||||
|
ID: "auth1",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test1@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth1): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
auth2 := &coreauth.Auth{
|
||||||
|
ID: "auth2",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test2@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth2): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
|
BootstrapRetries: 1,
|
||||||
|
},
|
||||||
|
}, manager)
|
||||||
|
ctx := WithPinnedAuthID(context.Background(), "auth1")
|
||||||
|
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
|
||||||
|
var gotErr error
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg != nil && msg.Error != nil {
|
||||||
|
gotErr = msg.Error
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(got) != 0 {
|
||||||
|
t.Fatalf("expected empty payload, got %q", string(got))
|
||||||
|
}
|
||||||
|
if gotErr == nil {
|
||||||
|
t.Fatalf("expected terminal error, got nil")
|
||||||
|
}
|
||||||
|
authIDs := executor.AuthIDs()
|
||||||
|
if len(authIDs) == 0 {
|
||||||
|
t.Fatalf("expected at least one upstream attempt")
|
||||||
|
}
|
||||||
|
for _, authID := range authIDs {
|
||||||
|
if authID != "auth1" {
|
||||||
|
t.Fatalf("expected all attempts on auth1, got sequence %v", authIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *testing.T) {
|
||||||
|
executor := &authAwareStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth2 := &coreauth.Auth{
|
||||||
|
ID: "auth2",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test2@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth2): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
||||||
|
Streaming: sdkconfig.StreamingConfig{
|
||||||
|
BootstrapRetries: 0,
|
||||||
|
},
|
||||||
|
}, manager)
|
||||||
|
|
||||||
|
selectedAuthID := ""
|
||||||
|
ctx := WithSelectedAuthIDCallback(context.Background(), func(authID string) {
|
||||||
|
selectedAuthID = authID
|
||||||
|
})
|
||||||
|
dataChan, errChan := handler.ExecuteStreamWithAuthManager(ctx, "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg != nil {
|
||||||
|
t.Fatalf("unexpected error: %+v", msg)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if string(got) != "ok" {
|
||||||
|
t.Fatalf("expected payload ok, got %q", string(got))
|
||||||
|
}
|
||||||
|
if selectedAuthID != "auth2" {
|
||||||
|
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
662
sdk/api/handlers/openai/openai_responses_websocket.go
Normal file
662
sdk/api/handlers/openai/openai_responses_websocket.go
Normal file
@@ -0,0 +1,662 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"github.com/gorilla/websocket"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
"github.com/tidwall/sjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
wsRequestTypeCreate = "response.create"
|
||||||
|
wsRequestTypeAppend = "response.append"
|
||||||
|
wsEventTypeError = "error"
|
||||||
|
wsEventTypeCompleted = "response.completed"
|
||||||
|
wsEventTypeDone = "response.done"
|
||||||
|
wsDoneMarker = "[DONE]"
|
||||||
|
wsTurnStateHeader = "x-codex-turn-state"
|
||||||
|
wsRequestBodyKey = "REQUEST_BODY_OVERRIDE"
|
||||||
|
wsPayloadLogMaxSize = 2048
|
||||||
|
)
|
||||||
|
|
||||||
|
var responsesWebsocketUpgrader = websocket.Upgrader{
|
||||||
|
ReadBufferSize: 4096,
|
||||||
|
WriteBufferSize: 4096,
|
||||||
|
CheckOrigin: func(r *http.Request) bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResponsesWebsocket handles websocket requests for /v1/responses.
|
||||||
|
// It accepts `response.create` and `response.append` requests and streams
|
||||||
|
// response events back as JSON websocket text messages.
|
||||||
|
func (h *OpenAIResponsesAPIHandler) ResponsesWebsocket(c *gin.Context) {
|
||||||
|
conn, err := responsesWebsocketUpgrader.Upgrade(c.Writer, c.Request, websocketUpgradeHeaders(c.Request))
|
||||||
|
if err != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
passthroughSessionID := uuid.NewString()
|
||||||
|
clientRemoteAddr := ""
|
||||||
|
if c != nil && c.Request != nil {
|
||||||
|
clientRemoteAddr = strings.TrimSpace(c.Request.RemoteAddr)
|
||||||
|
}
|
||||||
|
log.Infof("responses websocket: client connected id=%s remote=%s", passthroughSessionID, clientRemoteAddr)
|
||||||
|
var wsTerminateErr error
|
||||||
|
var wsBodyLog strings.Builder
|
||||||
|
defer func() {
|
||||||
|
if wsTerminateErr != nil {
|
||||||
|
// log.Infof("responses websocket: session closing id=%s reason=%v", passthroughSessionID, wsTerminateErr)
|
||||||
|
} else {
|
||||||
|
log.Infof("responses websocket: session closing id=%s", passthroughSessionID)
|
||||||
|
}
|
||||||
|
if h != nil && h.AuthManager != nil {
|
||||||
|
h.AuthManager.CloseExecutionSession(passthroughSessionID)
|
||||||
|
log.Infof("responses websocket: upstream execution session closed id=%s", passthroughSessionID)
|
||||||
|
}
|
||||||
|
setWebsocketRequestBody(c, wsBodyLog.String())
|
||||||
|
if errClose := conn.Close(); errClose != nil {
|
||||||
|
log.Warnf("responses websocket: close connection error: %v", errClose)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var lastRequest []byte
|
||||||
|
lastResponseOutput := []byte("[]")
|
||||||
|
pinnedAuthID := ""
|
||||||
|
|
||||||
|
for {
|
||||||
|
msgType, payload, errReadMessage := conn.ReadMessage()
|
||||||
|
if errReadMessage != nil {
|
||||||
|
wsTerminateErr = errReadMessage
|
||||||
|
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errReadMessage.Error()))
|
||||||
|
if websocket.IsCloseError(errReadMessage, websocket.CloseNormalClosure, websocket.CloseGoingAway, websocket.CloseNoStatusReceived) {
|
||||||
|
log.Infof("responses websocket: client disconnected id=%s error=%v", passthroughSessionID, errReadMessage)
|
||||||
|
} else {
|
||||||
|
// log.Warnf("responses websocket: read message failed id=%s error=%v", passthroughSessionID, errReadMessage)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if msgType != websocket.TextMessage && msgType != websocket.BinaryMessage {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
// log.Infof(
|
||||||
|
// "responses websocket: downstream_in id=%s type=%d event=%s payload=%s",
|
||||||
|
// passthroughSessionID,
|
||||||
|
// msgType,
|
||||||
|
// websocketPayloadEventType(payload),
|
||||||
|
// websocketPayloadPreview(payload),
|
||||||
|
// )
|
||||||
|
appendWebsocketEvent(&wsBodyLog, "request", payload)
|
||||||
|
|
||||||
|
allowIncrementalInputWithPreviousResponseID := websocketUpstreamSupportsIncrementalInput(nil, nil)
|
||||||
|
if pinnedAuthID != "" && h != nil && h.AuthManager != nil {
|
||||||
|
if pinnedAuth, ok := h.AuthManager.GetByID(pinnedAuthID); ok && pinnedAuth != nil {
|
||||||
|
allowIncrementalInputWithPreviousResponseID = websocketUpstreamSupportsIncrementalInput(pinnedAuth.Attributes, pinnedAuth.Metadata)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var requestJSON []byte
|
||||||
|
var updatedLastRequest []byte
|
||||||
|
var errMsg *interfaces.ErrorMessage
|
||||||
|
requestJSON, updatedLastRequest, errMsg = normalizeResponsesWebsocketRequestWithMode(
|
||||||
|
payload,
|
||||||
|
lastRequest,
|
||||||
|
lastResponseOutput,
|
||||||
|
allowIncrementalInputWithPreviousResponseID,
|
||||||
|
)
|
||||||
|
if errMsg != nil {
|
||||||
|
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||||
|
markAPIResponseTimestamp(c)
|
||||||
|
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||||
|
appendWebsocketEvent(&wsBodyLog, "response", errorPayload)
|
||||||
|
log.Infof(
|
||||||
|
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||||
|
passthroughSessionID,
|
||||||
|
websocket.TextMessage,
|
||||||
|
websocketPayloadEventType(errorPayload),
|
||||||
|
websocketPayloadPreview(errorPayload),
|
||||||
|
)
|
||||||
|
if errWrite != nil {
|
||||||
|
log.Warnf(
|
||||||
|
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||||
|
passthroughSessionID,
|
||||||
|
websocketPayloadEventType(errorPayload),
|
||||||
|
errWrite,
|
||||||
|
)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
lastRequest = updatedLastRequest
|
||||||
|
|
||||||
|
modelName := gjson.GetBytes(requestJSON, "model").String()
|
||||||
|
cliCtx, cliCancel := h.GetContextWithCancel(h, c, context.Background())
|
||||||
|
cliCtx = cliproxyexecutor.WithDownstreamWebsocket(cliCtx)
|
||||||
|
cliCtx = handlers.WithExecutionSessionID(cliCtx, passthroughSessionID)
|
||||||
|
if pinnedAuthID != "" {
|
||||||
|
cliCtx = handlers.WithPinnedAuthID(cliCtx, pinnedAuthID)
|
||||||
|
} else {
|
||||||
|
cliCtx = handlers.WithSelectedAuthIDCallback(cliCtx, func(authID string) {
|
||||||
|
pinnedAuthID = strings.TrimSpace(authID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
dataChan, errChan := h.ExecuteStreamWithAuthManager(cliCtx, h.HandlerType(), modelName, requestJSON, "")
|
||||||
|
|
||||||
|
completedOutput, errForward := h.forwardResponsesWebsocket(c, conn, cliCancel, dataChan, errChan, &wsBodyLog, passthroughSessionID)
|
||||||
|
if errForward != nil {
|
||||||
|
wsTerminateErr = errForward
|
||||||
|
appendWebsocketEvent(&wsBodyLog, "disconnect", []byte(errForward.Error()))
|
||||||
|
log.Warnf("responses websocket: forward failed id=%s error=%v", passthroughSessionID, errForward)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
lastResponseOutput = completedOutput
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func websocketUpgradeHeaders(req *http.Request) http.Header {
|
||||||
|
headers := http.Header{}
|
||||||
|
if req == nil {
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
// Keep the same sticky turn-state across reconnects when provided by the client.
|
||||||
|
turnState := strings.TrimSpace(req.Header.Get(wsTurnStateHeader))
|
||||||
|
if turnState != "" {
|
||||||
|
headers.Set(wsTurnStateHeader, turnState)
|
||||||
|
}
|
||||||
|
return headers
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponsesWebsocketRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||||
|
return normalizeResponsesWebsocketRequestWithMode(rawJSON, lastRequest, lastResponseOutput, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponsesWebsocketRequestWithMode(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||||
|
requestType := strings.TrimSpace(gjson.GetBytes(rawJSON, "type").String())
|
||||||
|
switch requestType {
|
||||||
|
case wsRequestTypeCreate:
|
||||||
|
// log.Infof("responses websocket: response.create request")
|
||||||
|
if len(lastRequest) == 0 {
|
||||||
|
return normalizeResponseCreateRequest(rawJSON)
|
||||||
|
}
|
||||||
|
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
||||||
|
case wsRequestTypeAppend:
|
||||||
|
// log.Infof("responses websocket: response.append request")
|
||||||
|
return normalizeResponseSubsequentRequest(rawJSON, lastRequest, lastResponseOutput, allowIncrementalInputWithPreviousResponseID)
|
||||||
|
default:
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("unsupported websocket request type: %s", requestType),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponseCreateRequest(rawJSON []byte) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||||
|
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||||
|
if errDelete != nil {
|
||||||
|
normalized = bytes.Clone(rawJSON)
|
||||||
|
}
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||||
|
if !gjson.GetBytes(normalized, "input").Exists() {
|
||||||
|
normalized, _ = sjson.SetRawBytes(normalized, "input", []byte("[]"))
|
||||||
|
}
|
||||||
|
|
||||||
|
modelName := strings.TrimSpace(gjson.GetBytes(normalized, "model").String())
|
||||||
|
if modelName == "" {
|
||||||
|
return nil, nil, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("missing model in response.create request"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return normalized, bytes.Clone(normalized), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeResponseSubsequentRequest(rawJSON []byte, lastRequest []byte, lastResponseOutput []byte, allowIncrementalInputWithPreviousResponseID bool) ([]byte, []byte, *interfaces.ErrorMessage) {
|
||||||
|
if len(lastRequest) == 0 {
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("websocket request received before response.create"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
nextInput := gjson.GetBytes(rawJSON, "input")
|
||||||
|
if !nextInput.Exists() || !nextInput.IsArray() {
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("websocket request requires array field: input"),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Websocket v2 mode uses response.create with previous_response_id + incremental input.
|
||||||
|
// Do not expand it into a full input transcript; upstream expects the incremental payload.
|
||||||
|
if allowIncrementalInputWithPreviousResponseID {
|
||||||
|
if prev := strings.TrimSpace(gjson.GetBytes(rawJSON, "previous_response_id").String()); prev != "" {
|
||||||
|
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||||
|
if errDelete != nil {
|
||||||
|
normalized = bytes.Clone(rawJSON)
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||||
|
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||||
|
if modelName != "" {
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||||
|
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||||
|
if instructions.Exists() {
|
||||||
|
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||||
|
return normalized, bytes.Clone(normalized), nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
existingInput := gjson.GetBytes(lastRequest, "input")
|
||||||
|
mergedInput, errMerge := mergeJSONArrayRaw(existingInput.Raw, normalizeJSONArrayRaw(lastResponseOutput))
|
||||||
|
if errMerge != nil {
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("invalid previous response output: %w", errMerge),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
mergedInput, errMerge = mergeJSONArrayRaw(mergedInput, nextInput.Raw)
|
||||||
|
if errMerge != nil {
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("invalid request input: %w", errMerge),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
normalized, errDelete := sjson.DeleteBytes(rawJSON, "type")
|
||||||
|
if errDelete != nil {
|
||||||
|
normalized = bytes.Clone(rawJSON)
|
||||||
|
}
|
||||||
|
normalized, _ = sjson.DeleteBytes(normalized, "previous_response_id")
|
||||||
|
var errSet error
|
||||||
|
normalized, errSet = sjson.SetRawBytes(normalized, "input", []byte(mergedInput))
|
||||||
|
if errSet != nil {
|
||||||
|
return nil, lastRequest, &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusBadRequest,
|
||||||
|
Error: fmt.Errorf("failed to merge websocket input: %w", errSet),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(normalized, "model").Exists() {
|
||||||
|
modelName := strings.TrimSpace(gjson.GetBytes(lastRequest, "model").String())
|
||||||
|
if modelName != "" {
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "model", modelName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(normalized, "instructions").Exists() {
|
||||||
|
instructions := gjson.GetBytes(lastRequest, "instructions")
|
||||||
|
if instructions.Exists() {
|
||||||
|
normalized, _ = sjson.SetRawBytes(normalized, "instructions", []byte(instructions.Raw))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
normalized, _ = sjson.SetBytes(normalized, "stream", true)
|
||||||
|
return normalized, bytes.Clone(normalized), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func websocketUpstreamSupportsIncrementalInput(attributes map[string]string, metadata map[string]any) bool {
|
||||||
|
if len(attributes) > 0 {
|
||||||
|
if raw := strings.TrimSpace(attributes["websockets"]); raw != "" {
|
||||||
|
parsed, errParse := strconv.ParseBool(raw)
|
||||||
|
if errParse == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(metadata) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
raw, ok := metadata["websockets"]
|
||||||
|
if !ok || raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch value := raw.(type) {
|
||||||
|
case bool:
|
||||||
|
return value
|
||||||
|
case string:
|
||||||
|
parsed, errParse := strconv.ParseBool(strings.TrimSpace(value))
|
||||||
|
if errParse == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func mergeJSONArrayRaw(existingRaw, appendRaw string) (string, error) {
|
||||||
|
existingRaw = strings.TrimSpace(existingRaw)
|
||||||
|
appendRaw = strings.TrimSpace(appendRaw)
|
||||||
|
if existingRaw == "" {
|
||||||
|
existingRaw = "[]"
|
||||||
|
}
|
||||||
|
if appendRaw == "" {
|
||||||
|
appendRaw = "[]"
|
||||||
|
}
|
||||||
|
|
||||||
|
var existing []json.RawMessage
|
||||||
|
if err := json.Unmarshal([]byte(existingRaw), &existing); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
var appendItems []json.RawMessage
|
||||||
|
if err := json.Unmarshal([]byte(appendRaw), &appendItems); err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := append(existing, appendItems...)
|
||||||
|
out, err := json.Marshal(merged)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(out), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func normalizeJSONArrayRaw(raw []byte) string {
|
||||||
|
trimmed := strings.TrimSpace(string(raw))
|
||||||
|
if trimmed == "" {
|
||||||
|
return "[]"
|
||||||
|
}
|
||||||
|
result := gjson.Parse(trimmed)
|
||||||
|
if result.Type == gjson.JSON && result.IsArray() {
|
||||||
|
return trimmed
|
||||||
|
}
|
||||||
|
return "[]"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (h *OpenAIResponsesAPIHandler) forwardResponsesWebsocket(
|
||||||
|
c *gin.Context,
|
||||||
|
conn *websocket.Conn,
|
||||||
|
cancel handlers.APIHandlerCancelFunc,
|
||||||
|
data <-chan []byte,
|
||||||
|
errs <-chan *interfaces.ErrorMessage,
|
||||||
|
wsBodyLog *strings.Builder,
|
||||||
|
sessionID string,
|
||||||
|
) ([]byte, error) {
|
||||||
|
completed := false
|
||||||
|
completedOutput := []byte("[]")
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-c.Request.Context().Done():
|
||||||
|
cancel(c.Request.Context().Err())
|
||||||
|
return completedOutput, c.Request.Context().Err()
|
||||||
|
case errMsg, ok := <-errs:
|
||||||
|
if !ok {
|
||||||
|
errs = nil
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if errMsg != nil {
|
||||||
|
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||||
|
markAPIResponseTimestamp(c)
|
||||||
|
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||||
|
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
||||||
|
log.Infof(
|
||||||
|
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||||
|
sessionID,
|
||||||
|
websocket.TextMessage,
|
||||||
|
websocketPayloadEventType(errorPayload),
|
||||||
|
websocketPayloadPreview(errorPayload),
|
||||||
|
)
|
||||||
|
if errWrite != nil {
|
||||||
|
// log.Warnf(
|
||||||
|
// "responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||||
|
// sessionID,
|
||||||
|
// websocketPayloadEventType(errorPayload),
|
||||||
|
// errWrite,
|
||||||
|
// )
|
||||||
|
cancel(errMsg.Error)
|
||||||
|
return completedOutput, errWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if errMsg != nil {
|
||||||
|
cancel(errMsg.Error)
|
||||||
|
} else {
|
||||||
|
cancel(nil)
|
||||||
|
}
|
||||||
|
return completedOutput, nil
|
||||||
|
case chunk, ok := <-data:
|
||||||
|
if !ok {
|
||||||
|
if !completed {
|
||||||
|
errMsg := &interfaces.ErrorMessage{
|
||||||
|
StatusCode: http.StatusRequestTimeout,
|
||||||
|
Error: fmt.Errorf("stream closed before response.completed"),
|
||||||
|
}
|
||||||
|
h.LoggingAPIResponseError(context.WithValue(context.Background(), "gin", c), errMsg)
|
||||||
|
markAPIResponseTimestamp(c)
|
||||||
|
errorPayload, errWrite := writeResponsesWebsocketError(conn, errMsg)
|
||||||
|
appendWebsocketEvent(wsBodyLog, "response", errorPayload)
|
||||||
|
log.Infof(
|
||||||
|
"responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||||
|
sessionID,
|
||||||
|
websocket.TextMessage,
|
||||||
|
websocketPayloadEventType(errorPayload),
|
||||||
|
websocketPayloadPreview(errorPayload),
|
||||||
|
)
|
||||||
|
if errWrite != nil {
|
||||||
|
log.Warnf(
|
||||||
|
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||||
|
sessionID,
|
||||||
|
websocketPayloadEventType(errorPayload),
|
||||||
|
errWrite,
|
||||||
|
)
|
||||||
|
cancel(errMsg.Error)
|
||||||
|
return completedOutput, errWrite
|
||||||
|
}
|
||||||
|
cancel(errMsg.Error)
|
||||||
|
return completedOutput, nil
|
||||||
|
}
|
||||||
|
cancel(nil)
|
||||||
|
return completedOutput, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||||
|
for i := range payloads {
|
||||||
|
eventType := gjson.GetBytes(payloads[i], "type").String()
|
||||||
|
if eventType == wsEventTypeCompleted {
|
||||||
|
// log.Infof("replace %s with %s", wsEventTypeCompleted, wsEventTypeDone)
|
||||||
|
payloads[i], _ = sjson.SetBytes(payloads[i], "type", wsEventTypeDone)
|
||||||
|
|
||||||
|
completed = true
|
||||||
|
completedOutput = responseCompletedOutputFromPayload(payloads[i])
|
||||||
|
}
|
||||||
|
markAPIResponseTimestamp(c)
|
||||||
|
appendWebsocketEvent(wsBodyLog, "response", payloads[i])
|
||||||
|
// log.Infof(
|
||||||
|
// "responses websocket: downstream_out id=%s type=%d event=%s payload=%s",
|
||||||
|
// sessionID,
|
||||||
|
// websocket.TextMessage,
|
||||||
|
// websocketPayloadEventType(payloads[i]),
|
||||||
|
// websocketPayloadPreview(payloads[i]),
|
||||||
|
// )
|
||||||
|
if errWrite := conn.WriteMessage(websocket.TextMessage, payloads[i]); errWrite != nil {
|
||||||
|
log.Warnf(
|
||||||
|
"responses websocket: downstream_out write failed id=%s event=%s error=%v",
|
||||||
|
sessionID,
|
||||||
|
websocketPayloadEventType(payloads[i]),
|
||||||
|
errWrite,
|
||||||
|
)
|
||||||
|
cancel(errWrite)
|
||||||
|
return completedOutput, errWrite
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func responseCompletedOutputFromPayload(payload []byte) []byte {
|
||||||
|
output := gjson.GetBytes(payload, "response.output")
|
||||||
|
if output.Exists() && output.IsArray() {
|
||||||
|
return bytes.Clone([]byte(output.Raw))
|
||||||
|
}
|
||||||
|
return []byte("[]")
|
||||||
|
}
|
||||||
|
|
||||||
|
func websocketJSONPayloadsFromChunk(chunk []byte) [][]byte {
|
||||||
|
payloads := make([][]byte, 0, 2)
|
||||||
|
lines := bytes.Split(chunk, []byte("\n"))
|
||||||
|
for i := range lines {
|
||||||
|
line := bytes.TrimSpace(lines[i])
|
||||||
|
if len(line) == 0 || bytes.HasPrefix(line, []byte("event:")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if bytes.HasPrefix(line, []byte("data:")) {
|
||||||
|
line = bytes.TrimSpace(line[len("data:"):])
|
||||||
|
}
|
||||||
|
if len(line) == 0 || bytes.Equal(line, []byte(wsDoneMarker)) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if json.Valid(line) {
|
||||||
|
payloads = append(payloads, bytes.Clone(line))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(payloads) > 0 {
|
||||||
|
return payloads
|
||||||
|
}
|
||||||
|
|
||||||
|
trimmed := bytes.TrimSpace(chunk)
|
||||||
|
if bytes.HasPrefix(trimmed, []byte("data:")) {
|
||||||
|
trimmed = bytes.TrimSpace(trimmed[len("data:"):])
|
||||||
|
}
|
||||||
|
if len(trimmed) > 0 && !bytes.Equal(trimmed, []byte(wsDoneMarker)) && json.Valid(trimmed) {
|
||||||
|
payloads = append(payloads, bytes.Clone(trimmed))
|
||||||
|
}
|
||||||
|
return payloads
|
||||||
|
}
|
||||||
|
|
||||||
|
func writeResponsesWebsocketError(conn *websocket.Conn, errMsg *interfaces.ErrorMessage) ([]byte, error) {
|
||||||
|
status := http.StatusInternalServerError
|
||||||
|
errText := http.StatusText(status)
|
||||||
|
if errMsg != nil {
|
||||||
|
if errMsg.StatusCode > 0 {
|
||||||
|
status = errMsg.StatusCode
|
||||||
|
errText = http.StatusText(status)
|
||||||
|
}
|
||||||
|
if errMsg.Error != nil && strings.TrimSpace(errMsg.Error.Error()) != "" {
|
||||||
|
errText = errMsg.Error.Error()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
body := handlers.BuildErrorResponseBody(status, errText)
|
||||||
|
payload := map[string]any{
|
||||||
|
"type": wsEventTypeError,
|
||||||
|
"status": status,
|
||||||
|
}
|
||||||
|
|
||||||
|
if errMsg != nil && errMsg.Addon != nil {
|
||||||
|
headers := map[string]any{}
|
||||||
|
for key, values := range errMsg.Addon {
|
||||||
|
if len(values) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
headers[key] = values[0]
|
||||||
|
}
|
||||||
|
if len(headers) > 0 {
|
||||||
|
payload["headers"] = headers
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(body) > 0 && json.Valid(body) {
|
||||||
|
var decoded map[string]any
|
||||||
|
if errDecode := json.Unmarshal(body, &decoded); errDecode == nil {
|
||||||
|
if inner, ok := decoded["error"]; ok {
|
||||||
|
payload["error"] = inner
|
||||||
|
} else {
|
||||||
|
payload["error"] = decoded
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, ok := payload["error"]; !ok {
|
||||||
|
payload["error"] = map[string]any{
|
||||||
|
"type": "server_error",
|
||||||
|
"message": errText,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(payload)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return data, conn.WriteMessage(websocket.TextMessage, data)
|
||||||
|
}
|
||||||
|
|
||||||
|
func appendWebsocketEvent(builder *strings.Builder, eventType string, payload []byte) {
|
||||||
|
if builder == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
trimmedPayload := bytes.TrimSpace(payload)
|
||||||
|
if len(trimmedPayload) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if builder.Len() > 0 {
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
builder.WriteString("websocket.")
|
||||||
|
builder.WriteString(eventType)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
builder.Write(trimmedPayload)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
|
func websocketPayloadEventType(payload []byte) string {
|
||||||
|
eventType := strings.TrimSpace(gjson.GetBytes(payload, "type").String())
|
||||||
|
if eventType == "" {
|
||||||
|
return "-"
|
||||||
|
}
|
||||||
|
return eventType
|
||||||
|
}
|
||||||
|
|
||||||
|
func websocketPayloadPreview(payload []byte) string {
|
||||||
|
trimmedPayload := bytes.TrimSpace(payload)
|
||||||
|
if len(trimmedPayload) == 0 {
|
||||||
|
return "<empty>"
|
||||||
|
}
|
||||||
|
preview := trimmedPayload
|
||||||
|
if len(preview) > wsPayloadLogMaxSize {
|
||||||
|
preview = preview[:wsPayloadLogMaxSize]
|
||||||
|
}
|
||||||
|
previewText := strings.ReplaceAll(string(preview), "\n", "\\n")
|
||||||
|
previewText = strings.ReplaceAll(previewText, "\r", "\\r")
|
||||||
|
if len(trimmedPayload) > wsPayloadLogMaxSize {
|
||||||
|
return fmt.Sprintf("%s...(truncated,total=%d)", previewText, len(trimmedPayload))
|
||||||
|
}
|
||||||
|
return previewText
|
||||||
|
}
|
||||||
|
|
||||||
|
func setWebsocketRequestBody(c *gin.Context, body string) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
trimmedBody := strings.TrimSpace(body)
|
||||||
|
if trimmedBody == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set(wsRequestBodyKey, []byte(trimmedBody))
|
||||||
|
}
|
||||||
|
|
||||||
|
func markAPIResponseTimestamp(c *gin.Context) {
|
||||||
|
if c == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if _, exists := c.Get("API_RESPONSE_TIMESTAMP"); exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.Set("API_RESPONSE_TIMESTAMP", time.Now())
|
||||||
|
}
|
||||||
249
sdk/api/handlers/openai/openai_responses_websocket_test.go
Normal file
249
sdk/api/handlers/openai/openai_responses_websocket_test.go
Normal file
@@ -0,0 +1,249 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/tidwall/gjson"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNormalizeResponsesWebsocketRequestCreate(t *testing.T) {
|
||||||
|
raw := []byte(`{"type":"response.create","model":"test-model","stream":false,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||||
|
|
||||||
|
normalized, last, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
|
||||||
|
if errMsg != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(normalized, "type").Exists() {
|
||||||
|
t.Fatalf("normalized create request must not include type field")
|
||||||
|
}
|
||||||
|
if !gjson.GetBytes(normalized, "stream").Bool() {
|
||||||
|
t.Fatalf("normalized create request must force stream=true")
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(normalized, "model").String() != "test-model" {
|
||||||
|
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
|
||||||
|
}
|
||||||
|
if !bytes.Equal(last, normalized) {
|
||||||
|
t.Fatalf("last request snapshot should match normalized request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesWebsocketRequestCreateWithHistory(t *testing.T) {
|
||||||
|
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||||
|
lastResponseOutput := []byte(`[
|
||||||
|
{"type":"function_call","id":"fc-1","call_id":"call-1"},
|
||||||
|
{"type":"message","id":"assistant-1"}
|
||||||
|
]`)
|
||||||
|
raw := []byte(`{"type":"response.create","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||||
|
|
||||||
|
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||||
|
if errMsg != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(normalized, "type").Exists() {
|
||||||
|
t.Fatalf("normalized subsequent create request must not include type field")
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(normalized, "model").String() != "test-model" {
|
||||||
|
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
|
||||||
|
}
|
||||||
|
|
||||||
|
input := gjson.GetBytes(normalized, "input").Array()
|
||||||
|
if len(input) != 4 {
|
||||||
|
t.Fatalf("merged input len = %d, want 4", len(input))
|
||||||
|
}
|
||||||
|
if input[0].Get("id").String() != "msg-1" ||
|
||||||
|
input[1].Get("id").String() != "fc-1" ||
|
||||||
|
input[2].Get("id").String() != "assistant-1" ||
|
||||||
|
input[3].Get("id").String() != "tool-out-1" {
|
||||||
|
t.Fatalf("unexpected merged input order")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(next, normalized) {
|
||||||
|
t.Fatalf("next request snapshot should match normalized request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDIncremental(t *testing.T) {
|
||||||
|
lastRequest := []byte(`{"model":"test-model","stream":true,"instructions":"be helpful","input":[{"type":"message","id":"msg-1"}]}`)
|
||||||
|
lastResponseOutput := []byte(`[
|
||||||
|
{"type":"function_call","id":"fc-1","call_id":"call-1"},
|
||||||
|
{"type":"message","id":"assistant-1"}
|
||||||
|
]`)
|
||||||
|
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||||
|
|
||||||
|
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, true)
|
||||||
|
if errMsg != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(normalized, "type").Exists() {
|
||||||
|
t.Fatalf("normalized request must not include type field")
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(normalized, "previous_response_id").String() != "resp-1" {
|
||||||
|
t.Fatalf("previous_response_id must be preserved in incremental mode")
|
||||||
|
}
|
||||||
|
input := gjson.GetBytes(normalized, "input").Array()
|
||||||
|
if len(input) != 1 {
|
||||||
|
t.Fatalf("incremental input len = %d, want 1", len(input))
|
||||||
|
}
|
||||||
|
if input[0].Get("id").String() != "tool-out-1" {
|
||||||
|
t.Fatalf("unexpected incremental input item id: %s", input[0].Get("id").String())
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(normalized, "model").String() != "test-model" {
|
||||||
|
t.Fatalf("unexpected model: %s", gjson.GetBytes(normalized, "model").String())
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(normalized, "instructions").String() != "be helpful" {
|
||||||
|
t.Fatalf("unexpected instructions: %s", gjson.GetBytes(normalized, "instructions").String())
|
||||||
|
}
|
||||||
|
if !bytes.Equal(next, normalized) {
|
||||||
|
t.Fatalf("next request snapshot should match normalized request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesWebsocketRequestWithPreviousResponseIDMergedWhenIncrementalDisabled(t *testing.T) {
|
||||||
|
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||||
|
lastResponseOutput := []byte(`[
|
||||||
|
{"type":"function_call","id":"fc-1","call_id":"call-1"},
|
||||||
|
{"type":"message","id":"assistant-1"}
|
||||||
|
]`)
|
||||||
|
raw := []byte(`{"type":"response.create","previous_response_id":"resp-1","input":[{"type":"function_call_output","call_id":"call-1","id":"tool-out-1"}]}`)
|
||||||
|
|
||||||
|
normalized, next, errMsg := normalizeResponsesWebsocketRequestWithMode(raw, lastRequest, lastResponseOutput, false)
|
||||||
|
if errMsg != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(normalized, "previous_response_id").Exists() {
|
||||||
|
t.Fatalf("previous_response_id must be removed when incremental mode is disabled")
|
||||||
|
}
|
||||||
|
input := gjson.GetBytes(normalized, "input").Array()
|
||||||
|
if len(input) != 4 {
|
||||||
|
t.Fatalf("merged input len = %d, want 4", len(input))
|
||||||
|
}
|
||||||
|
if input[0].Get("id").String() != "msg-1" ||
|
||||||
|
input[1].Get("id").String() != "fc-1" ||
|
||||||
|
input[2].Get("id").String() != "assistant-1" ||
|
||||||
|
input[3].Get("id").String() != "tool-out-1" {
|
||||||
|
t.Fatalf("unexpected merged input order")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(next, normalized) {
|
||||||
|
t.Fatalf("next request snapshot should match normalized request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesWebsocketRequestAppend(t *testing.T) {
|
||||||
|
lastRequest := []byte(`{"model":"test-model","stream":true,"input":[{"type":"message","id":"msg-1"}]}`)
|
||||||
|
lastResponseOutput := []byte(`[
|
||||||
|
{"type":"message","id":"assistant-1"},
|
||||||
|
{"type":"function_call_output","id":"tool-out-1"}
|
||||||
|
]`)
|
||||||
|
raw := []byte(`{"type":"response.append","input":[{"type":"message","id":"msg-2"},{"type":"message","id":"msg-3"}]}`)
|
||||||
|
|
||||||
|
normalized, next, errMsg := normalizeResponsesWebsocketRequest(raw, lastRequest, lastResponseOutput)
|
||||||
|
if errMsg != nil {
|
||||||
|
t.Fatalf("unexpected error: %v", errMsg.Error)
|
||||||
|
}
|
||||||
|
input := gjson.GetBytes(normalized, "input").Array()
|
||||||
|
if len(input) != 5 {
|
||||||
|
t.Fatalf("merged input len = %d, want 5", len(input))
|
||||||
|
}
|
||||||
|
if input[0].Get("id").String() != "msg-1" ||
|
||||||
|
input[1].Get("id").String() != "assistant-1" ||
|
||||||
|
input[2].Get("id").String() != "tool-out-1" ||
|
||||||
|
input[3].Get("id").String() != "msg-2" ||
|
||||||
|
input[4].Get("id").String() != "msg-3" {
|
||||||
|
t.Fatalf("unexpected merged input order")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(next, normalized) {
|
||||||
|
t.Fatalf("next request snapshot should match normalized append request")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNormalizeResponsesWebsocketRequestAppendWithoutCreate(t *testing.T) {
|
||||||
|
raw := []byte(`{"type":"response.append","input":[]}`)
|
||||||
|
|
||||||
|
_, _, errMsg := normalizeResponsesWebsocketRequest(raw, nil, nil)
|
||||||
|
if errMsg == nil {
|
||||||
|
t.Fatalf("expected error for append without previous request")
|
||||||
|
}
|
||||||
|
if errMsg.StatusCode != http.StatusBadRequest {
|
||||||
|
t.Fatalf("status = %d, want %d", errMsg.StatusCode, http.StatusBadRequest)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketJSONPayloadsFromChunk(t *testing.T) {
|
||||||
|
chunk := []byte("event: response.created\n\ndata: {\"type\":\"response.created\",\"response\":{\"id\":\"resp-1\"}}\n\ndata: [DONE]\n")
|
||||||
|
|
||||||
|
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||||
|
if len(payloads) != 1 {
|
||||||
|
t.Fatalf("payloads len = %d, want 1", len(payloads))
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(payloads[0], "type").String() != "response.created" {
|
||||||
|
t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestWebsocketJSONPayloadsFromPlainJSONChunk(t *testing.T) {
|
||||||
|
chunk := []byte(`{"type":"response.completed","response":{"id":"resp-1"}}`)
|
||||||
|
|
||||||
|
payloads := websocketJSONPayloadsFromChunk(chunk)
|
||||||
|
if len(payloads) != 1 {
|
||||||
|
t.Fatalf("payloads len = %d, want 1", len(payloads))
|
||||||
|
}
|
||||||
|
if gjson.GetBytes(payloads[0], "type").String() != "response.completed" {
|
||||||
|
t.Fatalf("unexpected payload type: %s", gjson.GetBytes(payloads[0], "type").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResponseCompletedOutputFromPayload(t *testing.T) {
|
||||||
|
payload := []byte(`{"type":"response.completed","response":{"id":"resp-1","output":[{"type":"message","id":"out-1"}]}}`)
|
||||||
|
|
||||||
|
output := responseCompletedOutputFromPayload(payload)
|
||||||
|
items := gjson.ParseBytes(output).Array()
|
||||||
|
if len(items) != 1 {
|
||||||
|
t.Fatalf("output len = %d, want 1", len(items))
|
||||||
|
}
|
||||||
|
if items[0].Get("id").String() != "out-1" {
|
||||||
|
t.Fatalf("unexpected output id: %s", items[0].Get("id").String())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAppendWebsocketEvent(t *testing.T) {
|
||||||
|
var builder strings.Builder
|
||||||
|
|
||||||
|
appendWebsocketEvent(&builder, "request", []byte(" {\"type\":\"response.create\"}\n"))
|
||||||
|
appendWebsocketEvent(&builder, "response", []byte("{\"type\":\"response.created\"}"))
|
||||||
|
|
||||||
|
got := builder.String()
|
||||||
|
if !strings.Contains(got, "websocket.request\n{\"type\":\"response.create\"}\n") {
|
||||||
|
t.Fatalf("request event not found in body: %s", got)
|
||||||
|
}
|
||||||
|
if !strings.Contains(got, "websocket.response\n{\"type\":\"response.created\"}\n") {
|
||||||
|
t.Fatalf("response event not found in body: %s", got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSetWebsocketRequestBody(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
|
||||||
|
setWebsocketRequestBody(c, " \n ")
|
||||||
|
if _, exists := c.Get(wsRequestBodyKey); exists {
|
||||||
|
t.Fatalf("request body key should not be set for empty body")
|
||||||
|
}
|
||||||
|
|
||||||
|
setWebsocketRequestBody(c, "event body")
|
||||||
|
value, exists := c.Get(wsRequestBodyKey)
|
||||||
|
if !exists {
|
||||||
|
t.Fatalf("request body key not set")
|
||||||
|
}
|
||||||
|
bodyBytes, ok := value.([]byte)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("request body key type mismatch")
|
||||||
|
}
|
||||||
|
if string(bodyBytes) != "event body" {
|
||||||
|
t.Fatalf("request body = %q, want %q", string(bodyBytes), "event body")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -41,6 +41,17 @@ type ProviderExecutor interface {
|
|||||||
HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
|
HttpRequest(ctx context.Context, auth *Auth, req *http.Request) (*http.Response, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExecutionSessionCloser allows executors to release per-session runtime resources.
|
||||||
|
type ExecutionSessionCloser interface {
|
||||||
|
CloseExecutionSession(sessionID string)
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
// CloseAllExecutionSessionsID asks an executor to release all active execution sessions.
|
||||||
|
// Executors that do not support this marker may ignore it.
|
||||||
|
CloseAllExecutionSessionsID = "__all_execution_sessions__"
|
||||||
|
)
|
||||||
|
|
||||||
// RefreshEvaluator allows runtime state to override refresh decisions.
|
// RefreshEvaluator allows runtime state to override refresh decisions.
|
||||||
type RefreshEvaluator interface {
|
type RefreshEvaluator interface {
|
||||||
ShouldRefresh(now time.Time, auth *Auth) bool
|
ShouldRefresh(now time.Time, auth *Auth) bool
|
||||||
@@ -389,9 +400,23 @@ func (m *Manager) RegisterExecutor(executor ProviderExecutor) {
|
|||||||
if executor == nil {
|
if executor == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
provider := strings.TrimSpace(executor.Identifier())
|
||||||
|
if provider == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var replaced ProviderExecutor
|
||||||
m.mu.Lock()
|
m.mu.Lock()
|
||||||
defer m.mu.Unlock()
|
replaced = m.executors[provider]
|
||||||
m.executors[executor.Identifier()] = executor
|
m.executors[provider] = executor
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
if replaced == nil || replaced == executor {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if closer, ok := replaced.(ExecutionSessionCloser); ok && closer != nil {
|
||||||
|
closer.CloseExecutionSession(CloseAllExecutionSessionsID)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// UnregisterExecutor removes the executor associated with the provider key.
|
// UnregisterExecutor removes the executor associated with the provider key.
|
||||||
@@ -581,6 +606,7 @@ func (m *Manager) executeMixedOnce(ctx context.Context, providers []string, req
|
|||||||
|
|
||||||
entry := logEntryWithRequestID(ctx)
|
entry := logEntryWithRequestID(ctx)
|
||||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||||
|
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
|
||||||
|
|
||||||
tried[auth.ID] = struct{}{}
|
tried[auth.ID] = struct{}{}
|
||||||
execCtx := ctx
|
execCtx := ctx
|
||||||
@@ -636,6 +662,7 @@ func (m *Manager) executeCountMixedOnce(ctx context.Context, providers []string,
|
|||||||
|
|
||||||
entry := logEntryWithRequestID(ctx)
|
entry := logEntryWithRequestID(ctx)
|
||||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||||
|
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
|
||||||
|
|
||||||
tried[auth.ID] = struct{}{}
|
tried[auth.ID] = struct{}{}
|
||||||
execCtx := ctx
|
execCtx := ctx
|
||||||
@@ -691,6 +718,7 @@ func (m *Manager) executeStreamMixedOnce(ctx context.Context, providers []string
|
|||||||
|
|
||||||
entry := logEntryWithRequestID(ctx)
|
entry := logEntryWithRequestID(ctx)
|
||||||
debugLogAuthSelection(entry, auth, provider, req.Model)
|
debugLogAuthSelection(entry, auth, provider, req.Model)
|
||||||
|
publishSelectedAuthMetadata(opts.Metadata, auth.ID)
|
||||||
|
|
||||||
tried[auth.ID] = struct{}{}
|
tried[auth.ID] = struct{}{}
|
||||||
execCtx := ctx
|
execCtx := ctx
|
||||||
@@ -794,6 +822,38 @@ func hasRequestedModelMetadata(meta map[string]any) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func pinnedAuthIDFromMetadata(meta map[string]any) string {
|
||||||
|
if len(meta) == 0 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
raw, ok := meta[cliproxyexecutor.PinnedAuthMetadataKey]
|
||||||
|
if !ok || raw == nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
switch val := raw.(type) {
|
||||||
|
case string:
|
||||||
|
return strings.TrimSpace(val)
|
||||||
|
case []byte:
|
||||||
|
return strings.TrimSpace(string(val))
|
||||||
|
default:
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func publishSelectedAuthMetadata(meta map[string]any, authID string) {
|
||||||
|
if len(meta) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
authID = strings.TrimSpace(authID)
|
||||||
|
if authID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
meta[cliproxyexecutor.SelectedAuthMetadataKey] = authID
|
||||||
|
if callback, ok := meta[cliproxyexecutor.SelectedAuthCallbackMetadataKey].(func(string)); ok && callback != nil {
|
||||||
|
callback(authID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func rewriteModelForAuth(model string, auth *Auth) string {
|
func rewriteModelForAuth(model string, auth *Auth) string {
|
||||||
if auth == nil || model == "" {
|
if auth == nil || model == "" {
|
||||||
return model
|
return model
|
||||||
@@ -1550,7 +1610,56 @@ func (m *Manager) GetByID(id string) (*Auth, bool) {
|
|||||||
return auth.Clone(), true
|
return auth.Clone(), true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Executor returns the registered provider executor for a provider key.
|
||||||
|
func (m *Manager) Executor(provider string) (ProviderExecutor, bool) {
|
||||||
|
if m == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
provider = strings.TrimSpace(provider)
|
||||||
|
if provider == "" {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
executor, okExecutor := m.executors[provider]
|
||||||
|
if !okExecutor {
|
||||||
|
lowerProvider := strings.ToLower(provider)
|
||||||
|
if lowerProvider != provider {
|
||||||
|
executor, okExecutor = m.executors[lowerProvider]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
if !okExecutor || executor == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
return executor, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// CloseExecutionSession asks all registered executors to release the supplied execution session.
|
||||||
|
func (m *Manager) CloseExecutionSession(sessionID string) {
|
||||||
|
sessionID = strings.TrimSpace(sessionID)
|
||||||
|
if m == nil || sessionID == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.RLock()
|
||||||
|
executors := make([]ProviderExecutor, 0, len(m.executors))
|
||||||
|
for _, exec := range m.executors {
|
||||||
|
executors = append(executors, exec)
|
||||||
|
}
|
||||||
|
m.mu.RUnlock()
|
||||||
|
|
||||||
|
for i := range executors {
|
||||||
|
if closer, ok := executors[i].(ExecutionSessionCloser); ok && closer != nil {
|
||||||
|
closer.CloseExecutionSession(sessionID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, error) {
|
||||||
|
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||||
|
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
executor, okExecutor := m.executors[provider]
|
executor, okExecutor := m.executors[provider]
|
||||||
if !okExecutor {
|
if !okExecutor {
|
||||||
@@ -1571,6 +1680,9 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
|||||||
if candidate.Provider != provider || candidate.Disabled {
|
if candidate.Provider != provider || candidate.Disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
if _, used := tried[candidate.ID]; used {
|
if _, used := tried[candidate.ID]; used {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -1606,6 +1718,8 @@ func (m *Manager) pickNext(ctx context.Context, provider, model string, opts cli
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model string, opts cliproxyexecutor.Options, tried map[string]struct{}) (*Auth, ProviderExecutor, string, error) {
|
||||||
|
pinnedAuthID := pinnedAuthIDFromMetadata(opts.Metadata)
|
||||||
|
|
||||||
providerSet := make(map[string]struct{}, len(providers))
|
providerSet := make(map[string]struct{}, len(providers))
|
||||||
for _, provider := range providers {
|
for _, provider := range providers {
|
||||||
p := strings.TrimSpace(strings.ToLower(provider))
|
p := strings.TrimSpace(strings.ToLower(provider))
|
||||||
@@ -1633,6 +1747,9 @@ func (m *Manager) pickNextMixed(ctx context.Context, providers []string, model s
|
|||||||
if candidate == nil || candidate.Disabled {
|
if candidate == nil || candidate.Disabled {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
if pinnedAuthID != "" && candidate.ID != pinnedAuthID {
|
||||||
|
continue
|
||||||
|
}
|
||||||
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
|
providerKey := strings.TrimSpace(strings.ToLower(candidate.Provider))
|
||||||
if providerKey == "" {
|
if providerKey == "" {
|
||||||
continue
|
continue
|
||||||
|
|||||||
100
sdk/cliproxy/auth/conductor_executor_replace_test.go
Normal file
100
sdk/cliproxy/auth/conductor_executor_replace_test.go
Normal file
@@ -0,0 +1,100 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
cliproxyexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
||||||
|
)
|
||||||
|
|
||||||
|
type replaceAwareExecutor struct {
|
||||||
|
id string
|
||||||
|
|
||||||
|
mu sync.Mutex
|
||||||
|
closedSessionIDs []string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *replaceAwareExecutor) Identifier() string {
|
||||||
|
return e.id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *replaceAwareExecutor) Execute(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *replaceAwareExecutor) ExecuteStream(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (<-chan cliproxyexecutor.StreamChunk, error) {
|
||||||
|
ch := make(chan cliproxyexecutor.StreamChunk)
|
||||||
|
close(ch)
|
||||||
|
return ch, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *replaceAwareExecutor) Refresh(_ context.Context, auth *Auth) (*Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *replaceAwareExecutor) CountTokens(context.Context, *Auth, cliproxyexecutor.Request, cliproxyexecutor.Options) (cliproxyexecutor.Response, error) {
|
||||||
|
return cliproxyexecutor.Response{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *replaceAwareExecutor) HttpRequest(context.Context, *Auth, *http.Request) (*http.Response, error) {
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *replaceAwareExecutor) CloseExecutionSession(sessionID string) {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
e.closedSessionIDs = append(e.closedSessionIDs, sessionID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *replaceAwareExecutor) ClosedSessionIDs() []string {
|
||||||
|
e.mu.Lock()
|
||||||
|
defer e.mu.Unlock()
|
||||||
|
out := make([]string, len(e.closedSessionIDs))
|
||||||
|
copy(out, e.closedSessionIDs)
|
||||||
|
return out
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerRegisterExecutorClosesReplacedExecutionSessions(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
manager := NewManager(nil, nil, nil)
|
||||||
|
replaced := &replaceAwareExecutor{id: "codex"}
|
||||||
|
current := &replaceAwareExecutor{id: "codex"}
|
||||||
|
|
||||||
|
manager.RegisterExecutor(replaced)
|
||||||
|
manager.RegisterExecutor(current)
|
||||||
|
|
||||||
|
closed := replaced.ClosedSessionIDs()
|
||||||
|
if len(closed) != 1 {
|
||||||
|
t.Fatalf("expected replaced executor close calls = 1, got %d", len(closed))
|
||||||
|
}
|
||||||
|
if closed[0] != CloseAllExecutionSessionsID {
|
||||||
|
t.Fatalf("expected close marker %q, got %q", CloseAllExecutionSessionsID, closed[0])
|
||||||
|
}
|
||||||
|
if len(current.ClosedSessionIDs()) != 0 {
|
||||||
|
t.Fatalf("expected current executor to stay open")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManagerExecutorReturnsRegisteredExecutor(t *testing.T) {
|
||||||
|
t.Parallel()
|
||||||
|
|
||||||
|
manager := NewManager(nil, nil, nil)
|
||||||
|
current := &replaceAwareExecutor{id: "codex"}
|
||||||
|
manager.RegisterExecutor(current)
|
||||||
|
|
||||||
|
resolved, okResolved := manager.Executor("CODEX")
|
||||||
|
if !okResolved {
|
||||||
|
t.Fatal("expected registered executor to be found")
|
||||||
|
}
|
||||||
|
if resolved != current {
|
||||||
|
t.Fatal("expected resolved executor to match registered executor")
|
||||||
|
}
|
||||||
|
|
||||||
|
_, okMissing := manager.Executor("unknown")
|
||||||
|
if okMissing {
|
||||||
|
t.Fatal("expected unknown provider lookup to fail")
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -134,6 +134,62 @@ func canonicalModelKey(model string) string {
|
|||||||
return modelName
|
return modelName
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func authWebsocketsEnabled(auth *Auth) bool {
|
||||||
|
if auth == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
if len(auth.Attributes) > 0 {
|
||||||
|
if raw := strings.TrimSpace(auth.Attributes["websockets"]); raw != "" {
|
||||||
|
parsed, errParse := strconv.ParseBool(raw)
|
||||||
|
if errParse == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(auth.Metadata) == 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
raw, ok := auth.Metadata["websockets"]
|
||||||
|
if !ok || raw == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
switch v := raw.(type) {
|
||||||
|
case bool:
|
||||||
|
return v
|
||||||
|
case string:
|
||||||
|
parsed, errParse := strconv.ParseBool(strings.TrimSpace(v))
|
||||||
|
if errParse == nil {
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func preferCodexWebsocketAuths(ctx context.Context, provider string, available []*Auth) []*Auth {
|
||||||
|
if len(available) == 0 {
|
||||||
|
return available
|
||||||
|
}
|
||||||
|
if !cliproxyexecutor.DownstreamWebsocket(ctx) {
|
||||||
|
return available
|
||||||
|
}
|
||||||
|
if !strings.EqualFold(strings.TrimSpace(provider), "codex") {
|
||||||
|
return available
|
||||||
|
}
|
||||||
|
|
||||||
|
wsEnabled := make([]*Auth, 0, len(available))
|
||||||
|
for i := 0; i < len(available); i++ {
|
||||||
|
candidate := available[i]
|
||||||
|
if authWebsocketsEnabled(candidate) {
|
||||||
|
wsEnabled = append(wsEnabled, candidate)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(wsEnabled) > 0 {
|
||||||
|
return wsEnabled
|
||||||
|
}
|
||||||
|
return available
|
||||||
|
}
|
||||||
|
|
||||||
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
func collectAvailableByPriority(auths []*Auth, model string, now time.Time) (available map[int][]*Auth, cooldownCount int, earliest time.Time) {
|
||||||
available = make(map[int][]*Auth)
|
available = make(map[int][]*Auth)
|
||||||
for i := 0; i < len(auths); i++ {
|
for i := 0; i < len(auths); i++ {
|
||||||
@@ -193,13 +249,13 @@ func getAvailableAuths(auths []*Auth, provider, model string, now time.Time) ([]
|
|||||||
|
|
||||||
// Pick selects the next available auth for the provider in a round-robin manner.
|
// Pick selects the next available auth for the provider in a round-robin manner.
|
||||||
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||||
_ = ctx
|
|
||||||
_ = opts
|
_ = opts
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
available, err := getAvailableAuths(auths, provider, model, now)
|
available, err := getAvailableAuths(auths, provider, model, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
available = preferCodexWebsocketAuths(ctx, provider, available)
|
||||||
key := provider + ":" + canonicalModelKey(model)
|
key := provider + ":" + canonicalModelKey(model)
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
if s.cursors == nil {
|
if s.cursors == nil {
|
||||||
@@ -226,13 +282,13 @@ func (s *RoundRobinSelector) Pick(ctx context.Context, provider, model string, o
|
|||||||
|
|
||||||
// Pick selects the first available auth for the provider in a deterministic manner.
|
// Pick selects the first available auth for the provider in a deterministic manner.
|
||||||
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
func (s *FillFirstSelector) Pick(ctx context.Context, provider, model string, opts cliproxyexecutor.Options, auths []*Auth) (*Auth, error) {
|
||||||
_ = ctx
|
|
||||||
_ = opts
|
_ = opts
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
available, err := getAvailableAuths(auths, provider, model, now)
|
available, err := getAvailableAuths(auths, provider, model, now)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
available = preferCodexWebsocketAuths(ctx, provider, available)
|
||||||
return available[0], nil
|
return available[0], nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
23
sdk/cliproxy/executor/context.go
Normal file
23
sdk/cliproxy/executor/context.go
Normal file
@@ -0,0 +1,23 @@
|
|||||||
|
package executor
|
||||||
|
|
||||||
|
import "context"
|
||||||
|
|
||||||
|
type downstreamWebsocketContextKey struct{}
|
||||||
|
|
||||||
|
// WithDownstreamWebsocket marks the current request as coming from a downstream websocket connection.
|
||||||
|
func WithDownstreamWebsocket(ctx context.Context) context.Context {
|
||||||
|
if ctx == nil {
|
||||||
|
ctx = context.Background()
|
||||||
|
}
|
||||||
|
return context.WithValue(ctx, downstreamWebsocketContextKey{}, true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DownstreamWebsocket reports whether the current request originates from a downstream websocket connection.
|
||||||
|
func DownstreamWebsocket(ctx context.Context) bool {
|
||||||
|
if ctx == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
raw := ctx.Value(downstreamWebsocketContextKey{})
|
||||||
|
enabled, ok := raw.(bool)
|
||||||
|
return ok && enabled
|
||||||
|
}
|
||||||
@@ -10,6 +10,17 @@ import (
|
|||||||
// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
|
// RequestedModelMetadataKey stores the client-requested model name in Options.Metadata.
|
||||||
const RequestedModelMetadataKey = "requested_model"
|
const RequestedModelMetadataKey = "requested_model"
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PinnedAuthMetadataKey locks execution to a specific auth ID.
|
||||||
|
PinnedAuthMetadataKey = "pinned_auth_id"
|
||||||
|
// SelectedAuthMetadataKey stores the auth ID selected by the scheduler.
|
||||||
|
SelectedAuthMetadataKey = "selected_auth_id"
|
||||||
|
// SelectedAuthCallbackMetadataKey carries an optional callback invoked with the selected auth ID.
|
||||||
|
SelectedAuthCallbackMetadataKey = "selected_auth_callback"
|
||||||
|
// ExecutionSessionMetadataKey identifies a long-lived downstream execution session.
|
||||||
|
ExecutionSessionMetadataKey = "execution_session_id"
|
||||||
|
)
|
||||||
|
|
||||||
// Request encapsulates the translated payload that will be sent to a provider executor.
|
// Request encapsulates the translated payload that will be sent to a provider executor.
|
||||||
type Request struct {
|
type Request struct {
|
||||||
// Model is the upstream model identifier after translation.
|
// Model is the upstream model identifier after translation.
|
||||||
|
|||||||
@@ -325,6 +325,9 @@ func (s *Service) applyCoreAuthRemoval(ctx context.Context, id string) {
|
|||||||
if _, err := s.coreManager.Update(ctx, existing); err != nil {
|
if _, err := s.coreManager.Update(ctx, existing); err != nil {
|
||||||
log.Errorf("failed to disable auth %s: %v", id, err)
|
log.Errorf("failed to disable auth %s: %v", id, err)
|
||||||
}
|
}
|
||||||
|
if strings.EqualFold(strings.TrimSpace(existing.Provider), "codex") {
|
||||||
|
s.ensureExecutorsForAuth(existing)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -357,7 +360,24 @@ func openAICompatInfoFromAuth(a *coreauth.Auth) (providerKey string, compatName
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
||||||
if s == nil || a == nil {
|
s.ensureExecutorsForAuthWithMode(a, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Service) ensureExecutorsForAuthWithMode(a *coreauth.Auth, forceReplace bool) {
|
||||||
|
if s == nil || s.coreManager == nil || a == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if strings.EqualFold(strings.TrimSpace(a.Provider), "codex") {
|
||||||
|
if !forceReplace {
|
||||||
|
existingExecutor, hasExecutor := s.coreManager.Executor("codex")
|
||||||
|
if hasExecutor {
|
||||||
|
_, isCodexAutoExecutor := existingExecutor.(*executor.CodexAutoExecutor)
|
||||||
|
if isCodexAutoExecutor {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.coreManager.RegisterExecutor(executor.NewCodexAutoExecutor(s.cfg))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
// Skip disabled auth entries when (re)binding executors.
|
// Skip disabled auth entries when (re)binding executors.
|
||||||
@@ -392,8 +412,6 @@ func (s *Service) ensureExecutorsForAuth(a *coreauth.Auth) {
|
|||||||
s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewAntigravityExecutor(s.cfg))
|
||||||
case "claude":
|
case "claude":
|
||||||
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewClaudeExecutor(s.cfg))
|
||||||
case "codex":
|
|
||||||
s.coreManager.RegisterExecutor(executor.NewCodexExecutor(s.cfg))
|
|
||||||
case "qwen":
|
case "qwen":
|
||||||
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
s.coreManager.RegisterExecutor(executor.NewQwenExecutor(s.cfg))
|
||||||
case "iflow":
|
case "iflow":
|
||||||
@@ -415,8 +433,15 @@ func (s *Service) rebindExecutors() {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
auths := s.coreManager.List()
|
auths := s.coreManager.List()
|
||||||
|
reboundCodex := false
|
||||||
for _, auth := range auths {
|
for _, auth := range auths {
|
||||||
s.ensureExecutorsForAuth(auth)
|
if auth != nil && strings.EqualFold(strings.TrimSpace(auth.Provider), "codex") {
|
||||||
|
if reboundCodex {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
reboundCodex = true
|
||||||
|
}
|
||||||
|
s.ensureExecutorsForAuthWithMode(auth, true)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
64
sdk/cliproxy/service_codex_executor_binding_test.go
Normal file
64
sdk/cliproxy/service_codex_executor_binding_test.go
Normal file
@@ -0,0 +1,64 @@
|
|||||||
|
package cliproxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestEnsureExecutorsForAuth_CodexDoesNotReplaceInNormalMode(t *testing.T) {
|
||||||
|
service := &Service{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
coreManager: coreauth.NewManager(nil, nil, nil),
|
||||||
|
}
|
||||||
|
auth := &coreauth.Auth{
|
||||||
|
ID: "codex-auth-1",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
service.ensureExecutorsForAuth(auth)
|
||||||
|
firstExecutor, okFirst := service.coreManager.Executor("codex")
|
||||||
|
if !okFirst || firstExecutor == nil {
|
||||||
|
t.Fatal("expected codex executor after first bind")
|
||||||
|
}
|
||||||
|
|
||||||
|
service.ensureExecutorsForAuth(auth)
|
||||||
|
secondExecutor, okSecond := service.coreManager.Executor("codex")
|
||||||
|
if !okSecond || secondExecutor == nil {
|
||||||
|
t.Fatal("expected codex executor after second bind")
|
||||||
|
}
|
||||||
|
|
||||||
|
if firstExecutor != secondExecutor {
|
||||||
|
t.Fatal("expected codex executor to stay unchanged in normal mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnsureExecutorsForAuthWithMode_CodexForceReplace(t *testing.T) {
|
||||||
|
service := &Service{
|
||||||
|
cfg: &config.Config{},
|
||||||
|
coreManager: coreauth.NewManager(nil, nil, nil),
|
||||||
|
}
|
||||||
|
auth := &coreauth.Auth{
|
||||||
|
ID: "codex-auth-2",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
}
|
||||||
|
|
||||||
|
service.ensureExecutorsForAuth(auth)
|
||||||
|
firstExecutor, okFirst := service.coreManager.Executor("codex")
|
||||||
|
if !okFirst || firstExecutor == nil {
|
||||||
|
t.Fatal("expected codex executor after first bind")
|
||||||
|
}
|
||||||
|
|
||||||
|
service.ensureExecutorsForAuthWithMode(auth, true)
|
||||||
|
secondExecutor, okSecond := service.coreManager.Executor("codex")
|
||||||
|
if !okSecond || secondExecutor == nil {
|
||||||
|
t.Fatal("expected codex executor after forced rebind")
|
||||||
|
}
|
||||||
|
|
||||||
|
if firstExecutor == secondExecutor {
|
||||||
|
t.Fatal("expected codex executor replacement in force mode")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user