mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-28 11:54:02 +08:00
Merge pull request #1680 from canxin121/fix/responses-stream-error-chunks
fix(responses): emit schema-valid SSE chunks
This commit is contained in:
@@ -716,6 +716,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
if handlerType == "openai-response" {
|
||||
if err := validateSSEDataJSON(chunk.Payload); err != nil {
|
||||
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
|
||||
return
|
||||
}
|
||||
}
|
||||
sentPayload = true
|
||||
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
||||
return
|
||||
@@ -727,6 +733,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return dataChan, upstreamHeaders, errChan
|
||||
}
|
||||
|
||||
func validateSSEDataJSON(chunk []byte) error {
|
||||
for _, line := range bytes.Split(chunk, []byte("\n")) {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[5:])
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if json.Valid(data) {
|
||||
continue
|
||||
}
|
||||
const max = 512
|
||||
preview := data
|
||||
if len(preview) > max {
|
||||
preview = preview[:max]
|
||||
}
|
||||
return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func statusFromError(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
|
||||
@@ -134,6 +134,37 @@ type authAwareStreamExecutor struct {
|
||||
authIDs []string
|
||||
}
|
||||
|
||||
type invalidJSONStreamExecutor struct{}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
|
||||
close(ch)
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, &coreauth.Error{
|
||||
Code: "not_implemented",
|
||||
Message: "HttpRequest not implemented",
|
||||
HTTPStatus: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
|
||||
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
|
||||
executor := &invalidJSONStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty payload, got %q", string(got))
|
||||
}
|
||||
|
||||
gotErr := false
|
||||
for msg := range errChan {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.StatusCode != http.StatusBadGateway {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
|
||||
}
|
||||
if msg.Error == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
gotErr = true
|
||||
}
|
||||
if !gotErr {
|
||||
t.Fatalf("expected terminal error")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -265,8 +265,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
|
||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||
errText = errMsg.Error.Error()
|
||||
}
|
||||
body := handlers.BuildErrorResponseBody(status, errText)
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
||||
chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
|
||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
|
||||
},
|
||||
WriteDone: func() {
|
||||
_, _ = c.Writer.Write([]byte("\n"))
|
||||
|
||||
@@ -0,0 +1,43 @@
|
||||
package openai
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||
)
|
||||
|
||||
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
|
||||
gin.SetMode(gin.TestMode)
|
||||
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
|
||||
h := NewOpenAIResponsesAPIHandler(base)
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
c, _ := gin.CreateTestContext(recorder)
|
||||
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||
|
||||
flusher, ok := c.Writer.(http.Flusher)
|
||||
if !ok {
|
||||
t.Fatalf("expected gin writer to implement http.Flusher")
|
||||
}
|
||||
|
||||
data := make(chan []byte)
|
||||
errs := make(chan *interfaces.ErrorMessage, 1)
|
||||
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
|
||||
close(errs)
|
||||
|
||||
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
|
||||
body := recorder.Body.String()
|
||||
if !strings.Contains(body, `"type":"error"`) {
|
||||
t.Fatalf("expected responses error chunk, got: %q", body)
|
||||
}
|
||||
if strings.Contains(body, `"error":{`) {
|
||||
t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
|
||||
}
|
||||
}
|
||||
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
@@ -0,0 +1,119 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
)
|
||||
|
||||
type openAIResponsesStreamErrorChunk struct {
|
||||
Type string `json:"type"`
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
SequenceNumber int `json:"sequence_number"`
|
||||
}
|
||||
|
||||
func openAIResponsesStreamErrorCode(status int) string {
|
||||
switch status {
|
||||
case http.StatusUnauthorized:
|
||||
return "invalid_api_key"
|
||||
case http.StatusForbidden:
|
||||
return "insufficient_quota"
|
||||
case http.StatusTooManyRequests:
|
||||
return "rate_limit_exceeded"
|
||||
case http.StatusNotFound:
|
||||
return "model_not_found"
|
||||
case http.StatusRequestTimeout:
|
||||
return "request_timeout"
|
||||
default:
|
||||
if status >= http.StatusInternalServerError {
|
||||
return "internal_server_error"
|
||||
}
|
||||
if status >= http.StatusBadRequest {
|
||||
return "invalid_request_error"
|
||||
}
|
||||
return "unknown_error"
|
||||
}
|
||||
}
|
||||
|
||||
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
|
||||
//
|
||||
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
|
||||
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
|
||||
// of chunks that requires a top-level `type` field.
|
||||
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
|
||||
if status <= 0 {
|
||||
status = http.StatusInternalServerError
|
||||
}
|
||||
if sequenceNumber < 0 {
|
||||
sequenceNumber = 0
|
||||
}
|
||||
|
||||
message := strings.TrimSpace(errText)
|
||||
if message == "" {
|
||||
message = http.StatusText(status)
|
||||
}
|
||||
|
||||
code := openAIResponsesStreamErrorCode(status)
|
||||
|
||||
trimmed := strings.TrimSpace(errText)
|
||||
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
|
||||
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
|
||||
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||
message = strings.TrimSpace(m)
|
||||
}
|
||||
if v, ok := payload["code"]; ok && v != nil {
|
||||
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||
code = strings.TrimSpace(c)
|
||||
} else {
|
||||
code = strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
|
||||
sequenceNumber = int(v)
|
||||
}
|
||||
}
|
||||
if e, ok := payload["error"].(map[string]any); ok {
|
||||
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||
message = strings.TrimSpace(m)
|
||||
}
|
||||
if v, ok := e["code"]; ok && v != nil {
|
||||
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||
code = strings.TrimSpace(c)
|
||||
} else {
|
||||
code = strings.TrimSpace(fmt.Sprint(v))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
if strings.TrimSpace(code) == "" {
|
||||
code = "unknown_error"
|
||||
}
|
||||
|
||||
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
|
||||
Type: "error",
|
||||
Code: code,
|
||||
Message: message,
|
||||
SequenceNumber: sequenceNumber,
|
||||
})
|
||||
if err == nil {
|
||||
return data
|
||||
}
|
||||
|
||||
// Extremely defensive fallback.
|
||||
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
|
||||
Type: "error",
|
||||
Code: "internal_server_error",
|
||||
Message: message,
|
||||
SequenceNumber: sequenceNumber,
|
||||
})
|
||||
if len(data) > 0 {
|
||||
return data
|
||||
}
|
||||
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
|
||||
}
|
||||
48
sdk/api/handlers/openai_responses_stream_error_test.go
Normal file
48
sdk/api/handlers/openai_responses_stream_error_test.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package handlers
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) {
|
||||
chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0)
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if payload["type"] != "error" {
|
||||
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||
}
|
||||
if payload["code"] != "internal_server_error" {
|
||||
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||
}
|
||||
if payload["message"] != "unexpected EOF" {
|
||||
t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF")
|
||||
}
|
||||
if payload["sequence_number"] != float64(0) {
|
||||
t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0)
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) {
|
||||
chunk := BuildOpenAIResponsesStreamErrorChunk(
|
||||
http.StatusInternalServerError,
|
||||
`{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`,
|
||||
0,
|
||||
)
|
||||
var payload map[string]any
|
||||
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||
t.Fatalf("unmarshal: %v", err)
|
||||
}
|
||||
if payload["type"] != "error" {
|
||||
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||
}
|
||||
if payload["code"] != "internal_server_error" {
|
||||
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||
}
|
||||
if payload["message"] != "oops" {
|
||||
t.Fatalf("message = %v, want %q", payload["message"], "oops")
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user