mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-28 12:04:44 +08:00
Merge pull request #1680 from canxin121/fix/responses-stream-error-chunks
fix(responses): emit schema-valid SSE chunks
This commit is contained in:
@@ -716,6 +716,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
if len(chunk.Payload) > 0 {
|
if len(chunk.Payload) > 0 {
|
||||||
|
if handlerType == "openai-response" {
|
||||||
|
if err := validateSSEDataJSON(chunk.Payload); err != nil {
|
||||||
|
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
sentPayload = true
|
sentPayload = true
|
||||||
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
||||||
return
|
return
|
||||||
@@ -727,6 +733,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
|||||||
return dataChan, upstreamHeaders, errChan
|
return dataChan, upstreamHeaders, errChan
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateSSEDataJSON(chunk []byte) error {
|
||||||
|
for _, line := range bytes.Split(chunk, []byte("\n")) {
|
||||||
|
line = bytes.TrimSpace(line)
|
||||||
|
if len(line) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
data := bytes.TrimSpace(line[5:])
|
||||||
|
if len(data) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if bytes.Equal(data, []byte("[DONE]")) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if json.Valid(data) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
const max = 512
|
||||||
|
preview := data
|
||||||
|
if len(preview) > max {
|
||||||
|
preview = preview[:max]
|
||||||
|
}
|
||||||
|
return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func statusFromError(err error) int {
|
func statusFromError(err error) int {
|
||||||
if err == nil {
|
if err == nil {
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
@@ -134,6 +134,37 @@ type authAwareStreamExecutor struct {
|
|||||||
authIDs []string
|
authIDs []string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type invalidJSONStreamExecutor struct{}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||||
|
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||||
|
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
|
||||||
|
close(ch)
|
||||||
|
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||||
|
return auth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||||
|
return nil, &coreauth.Error{
|
||||||
|
Code: "not_implemented",
|
||||||
|
Message: "HttpRequest not implemented",
|
||||||
|
HTTPStatus: http.StatusNotImplemented,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||||
|
|
||||||
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||||
@@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
|
|||||||
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
|
||||||
|
executor := &invalidJSONStreamExecutor{}
|
||||||
|
manager := coreauth.NewManager(nil, nil, nil)
|
||||||
|
manager.RegisterExecutor(executor)
|
||||||
|
|
||||||
|
auth1 := &coreauth.Auth{
|
||||||
|
ID: "auth1",
|
||||||
|
Provider: "codex",
|
||||||
|
Status: coreauth.StatusActive,
|
||||||
|
Metadata: map[string]any{"email": "test1@example.com"},
|
||||||
|
}
|
||||||
|
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||||
|
t.Fatalf("manager.Register(auth1): %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||||
|
t.Cleanup(func() {
|
||||||
|
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||||
|
})
|
||||||
|
|
||||||
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||||
|
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||||
|
if dataChan == nil || errChan == nil {
|
||||||
|
t.Fatalf("expected non-nil channels")
|
||||||
|
}
|
||||||
|
|
||||||
|
var got []byte
|
||||||
|
for chunk := range dataChan {
|
||||||
|
got = append(got, chunk...)
|
||||||
|
}
|
||||||
|
if len(got) != 0 {
|
||||||
|
t.Fatalf("expected empty payload, got %q", string(got))
|
||||||
|
}
|
||||||
|
|
||||||
|
gotErr := false
|
||||||
|
for msg := range errChan {
|
||||||
|
if msg == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if msg.StatusCode != http.StatusBadGateway {
|
||||||
|
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
|
||||||
|
}
|
||||||
|
if msg.Error == nil {
|
||||||
|
t.Fatalf("expected error")
|
||||||
|
}
|
||||||
|
gotErr = true
|
||||||
|
}
|
||||||
|
if !gotErr {
|
||||||
|
t.Fatalf("expected terminal error")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -265,8 +265,8 @@ func (h *OpenAIResponsesAPIHandler) forwardResponsesStream(c *gin.Context, flush
|
|||||||
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
if errMsg.Error != nil && errMsg.Error.Error() != "" {
|
||||||
errText = errMsg.Error.Error()
|
errText = errMsg.Error.Error()
|
||||||
}
|
}
|
||||||
body := handlers.BuildErrorResponseBody(status, errText)
|
chunk := handlers.BuildOpenAIResponsesStreamErrorChunk(status, errText, 0)
|
||||||
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(body))
|
_, _ = fmt.Fprintf(c.Writer, "\nevent: error\ndata: %s\n\n", string(chunk))
|
||||||
},
|
},
|
||||||
WriteDone: func() {
|
WriteDone: func() {
|
||||||
_, _ = c.Writer.Write([]byte("\n"))
|
_, _ = c.Writer.Write([]byte("\n"))
|
||||||
|
|||||||
@@ -0,0 +1,43 @@
|
|||||||
|
package openai
|
||||||
|
|
||||||
|
import (
|
||||||
|
"errors"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/gin-gonic/gin"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/interfaces"
|
||||||
|
"github.com/router-for-me/CLIProxyAPI/v6/sdk/api/handlers"
|
||||||
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestForwardResponsesStreamTerminalErrorUsesResponsesErrorChunk(t *testing.T) {
|
||||||
|
gin.SetMode(gin.TestMode)
|
||||||
|
base := handlers.NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, nil)
|
||||||
|
h := NewOpenAIResponsesAPIHandler(base)
|
||||||
|
|
||||||
|
recorder := httptest.NewRecorder()
|
||||||
|
c, _ := gin.CreateTestContext(recorder)
|
||||||
|
c.Request = httptest.NewRequest(http.MethodPost, "/v1/responses", nil)
|
||||||
|
|
||||||
|
flusher, ok := c.Writer.(http.Flusher)
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("expected gin writer to implement http.Flusher")
|
||||||
|
}
|
||||||
|
|
||||||
|
data := make(chan []byte)
|
||||||
|
errs := make(chan *interfaces.ErrorMessage, 1)
|
||||||
|
errs <- &interfaces.ErrorMessage{StatusCode: http.StatusInternalServerError, Error: errors.New("unexpected EOF")}
|
||||||
|
close(errs)
|
||||||
|
|
||||||
|
h.forwardResponsesStream(c, flusher, func(error) {}, data, errs)
|
||||||
|
body := recorder.Body.String()
|
||||||
|
if !strings.Contains(body, `"type":"error"`) {
|
||||||
|
t.Fatalf("expected responses error chunk, got: %q", body)
|
||||||
|
}
|
||||||
|
if strings.Contains(body, `"error":{`) {
|
||||||
|
t.Fatalf("expected streaming error chunk (top-level type), got HTTP error body: %q", body)
|
||||||
|
}
|
||||||
|
}
|
||||||
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
119
sdk/api/handlers/openai_responses_stream_error.go
Normal file
@@ -0,0 +1,119 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"net/http"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
type openAIResponsesStreamErrorChunk struct {
|
||||||
|
Type string `json:"type"`
|
||||||
|
Code string `json:"code"`
|
||||||
|
Message string `json:"message"`
|
||||||
|
SequenceNumber int `json:"sequence_number"`
|
||||||
|
}
|
||||||
|
|
||||||
|
func openAIResponsesStreamErrorCode(status int) string {
|
||||||
|
switch status {
|
||||||
|
case http.StatusUnauthorized:
|
||||||
|
return "invalid_api_key"
|
||||||
|
case http.StatusForbidden:
|
||||||
|
return "insufficient_quota"
|
||||||
|
case http.StatusTooManyRequests:
|
||||||
|
return "rate_limit_exceeded"
|
||||||
|
case http.StatusNotFound:
|
||||||
|
return "model_not_found"
|
||||||
|
case http.StatusRequestTimeout:
|
||||||
|
return "request_timeout"
|
||||||
|
default:
|
||||||
|
if status >= http.StatusInternalServerError {
|
||||||
|
return "internal_server_error"
|
||||||
|
}
|
||||||
|
if status >= http.StatusBadRequest {
|
||||||
|
return "invalid_request_error"
|
||||||
|
}
|
||||||
|
return "unknown_error"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BuildOpenAIResponsesStreamErrorChunk builds an OpenAI Responses streaming error chunk.
|
||||||
|
//
|
||||||
|
// Important: OpenAI's HTTP error bodies are shaped like {"error":{...}}; those are valid for
|
||||||
|
// non-streaming responses, but streaming clients validate SSE `data:` payloads against a union
|
||||||
|
// of chunks that requires a top-level `type` field.
|
||||||
|
func BuildOpenAIResponsesStreamErrorChunk(status int, errText string, sequenceNumber int) []byte {
|
||||||
|
if status <= 0 {
|
||||||
|
status = http.StatusInternalServerError
|
||||||
|
}
|
||||||
|
if sequenceNumber < 0 {
|
||||||
|
sequenceNumber = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
message := strings.TrimSpace(errText)
|
||||||
|
if message == "" {
|
||||||
|
message = http.StatusText(status)
|
||||||
|
}
|
||||||
|
|
||||||
|
code := openAIResponsesStreamErrorCode(status)
|
||||||
|
|
||||||
|
trimmed := strings.TrimSpace(errText)
|
||||||
|
if trimmed != "" && json.Valid([]byte(trimmed)) {
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal([]byte(trimmed), &payload); err == nil {
|
||||||
|
if t, ok := payload["type"].(string); ok && strings.TrimSpace(t) == "error" {
|
||||||
|
if m, ok := payload["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||||
|
message = strings.TrimSpace(m)
|
||||||
|
}
|
||||||
|
if v, ok := payload["code"]; ok && v != nil {
|
||||||
|
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||||
|
code = strings.TrimSpace(c)
|
||||||
|
} else {
|
||||||
|
code = strings.TrimSpace(fmt.Sprint(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if v, ok := payload["sequence_number"].(float64); ok && sequenceNumber == 0 {
|
||||||
|
sequenceNumber = int(v)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if e, ok := payload["error"].(map[string]any); ok {
|
||||||
|
if m, ok := e["message"].(string); ok && strings.TrimSpace(m) != "" {
|
||||||
|
message = strings.TrimSpace(m)
|
||||||
|
}
|
||||||
|
if v, ok := e["code"]; ok && v != nil {
|
||||||
|
if c, ok := v.(string); ok && strings.TrimSpace(c) != "" {
|
||||||
|
code = strings.TrimSpace(c)
|
||||||
|
} else {
|
||||||
|
code = strings.TrimSpace(fmt.Sprint(v))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.TrimSpace(code) == "" {
|
||||||
|
code = "unknown_error"
|
||||||
|
}
|
||||||
|
|
||||||
|
data, err := json.Marshal(openAIResponsesStreamErrorChunk{
|
||||||
|
Type: "error",
|
||||||
|
Code: code,
|
||||||
|
Message: message,
|
||||||
|
SequenceNumber: sequenceNumber,
|
||||||
|
})
|
||||||
|
if err == nil {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extremely defensive fallback.
|
||||||
|
data, _ = json.Marshal(openAIResponsesStreamErrorChunk{
|
||||||
|
Type: "error",
|
||||||
|
Code: "internal_server_error",
|
||||||
|
Message: message,
|
||||||
|
SequenceNumber: sequenceNumber,
|
||||||
|
})
|
||||||
|
if len(data) > 0 {
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
return []byte(`{"type":"error","code":"internal_server_error","message":"internal error","sequence_number":0}`)
|
||||||
|
}
|
||||||
48
sdk/api/handlers/openai_responses_stream_error_test.go
Normal file
48
sdk/api/handlers/openai_responses_stream_error_test.go
Normal file
@@ -0,0 +1,48 @@
|
|||||||
|
package handlers
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net/http"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestBuildOpenAIResponsesStreamErrorChunk(t *testing.T) {
|
||||||
|
chunk := BuildOpenAIResponsesStreamErrorChunk(http.StatusInternalServerError, "unexpected EOF", 0)
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if payload["type"] != "error" {
|
||||||
|
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||||
|
}
|
||||||
|
if payload["code"] != "internal_server_error" {
|
||||||
|
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||||
|
}
|
||||||
|
if payload["message"] != "unexpected EOF" {
|
||||||
|
t.Fatalf("message = %v, want %q", payload["message"], "unexpected EOF")
|
||||||
|
}
|
||||||
|
if payload["sequence_number"] != float64(0) {
|
||||||
|
t.Fatalf("sequence_number = %v, want %v", payload["sequence_number"], 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestBuildOpenAIResponsesStreamErrorChunkExtractsHTTPErrorBody(t *testing.T) {
|
||||||
|
chunk := BuildOpenAIResponsesStreamErrorChunk(
|
||||||
|
http.StatusInternalServerError,
|
||||||
|
`{"error":{"message":"oops","type":"server_error","code":"internal_server_error"}}`,
|
||||||
|
0,
|
||||||
|
)
|
||||||
|
var payload map[string]any
|
||||||
|
if err := json.Unmarshal(chunk, &payload); err != nil {
|
||||||
|
t.Fatalf("unmarshal: %v", err)
|
||||||
|
}
|
||||||
|
if payload["type"] != "error" {
|
||||||
|
t.Fatalf("type = %v, want %q", payload["type"], "error")
|
||||||
|
}
|
||||||
|
if payload["code"] != "internal_server_error" {
|
||||||
|
t.Fatalf("code = %v, want %q", payload["code"], "internal_server_error")
|
||||||
|
}
|
||||||
|
if payload["message"] != "oops" {
|
||||||
|
t.Fatalf("message = %v, want %q", payload["message"], "oops")
|
||||||
|
}
|
||||||
|
}
|
||||||
Reference in New Issue
Block a user