mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-28 08:36:09 +08:00
fix(responses): reject invalid SSE data JSON
Guard the openai-response streaming path against truncated/invalid SSE data payloads by validating data: JSON before forwarding; surface a 502 terminal error instead of letting clients crash with JSON parse errors.
This commit is contained in:
@@ -716,6 +716,12 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return
|
||||
}
|
||||
if len(chunk.Payload) > 0 {
|
||||
if handlerType == "openai-response" {
|
||||
if err := validateSSEDataJSON(chunk.Payload); err != nil {
|
||||
_ = sendErr(&interfaces.ErrorMessage{StatusCode: http.StatusBadGateway, Error: err})
|
||||
return
|
||||
}
|
||||
}
|
||||
sentPayload = true
|
||||
if okSendData := sendData(cloneBytes(chunk.Payload)); !okSendData {
|
||||
return
|
||||
@@ -727,6 +733,35 @@ func (h *BaseAPIHandler) ExecuteStreamWithAuthManager(ctx context.Context, handl
|
||||
return dataChan, upstreamHeaders, errChan
|
||||
}
|
||||
|
||||
func validateSSEDataJSON(chunk []byte) error {
|
||||
for _, line := range bytes.Split(chunk, []byte("\n")) {
|
||||
line = bytes.TrimSpace(line)
|
||||
if len(line) == 0 {
|
||||
continue
|
||||
}
|
||||
if !bytes.HasPrefix(line, []byte("data:")) {
|
||||
continue
|
||||
}
|
||||
data := bytes.TrimSpace(line[5:])
|
||||
if len(data) == 0 {
|
||||
continue
|
||||
}
|
||||
if bytes.Equal(data, []byte("[DONE]")) {
|
||||
continue
|
||||
}
|
||||
if json.Valid(data) {
|
||||
continue
|
||||
}
|
||||
const max = 512
|
||||
preview := data
|
||||
if len(preview) > max {
|
||||
preview = preview[:max]
|
||||
}
|
||||
return fmt.Errorf("invalid SSE data JSON (len=%d): %q", len(data), preview)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func statusFromError(err error) int {
|
||||
if err == nil {
|
||||
return 0
|
||||
|
||||
@@ -134,6 +134,37 @@ type authAwareStreamExecutor struct {
|
||||
authIDs []string
|
||||
}
|
||||
|
||||
type invalidJSONStreamExecutor struct{}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (*coreexecutor.StreamResult, error) {
|
||||
ch := make(chan coreexecutor.StreamChunk, 1)
|
||||
ch <- coreexecutor.StreamChunk{Payload: []byte("event: response.completed\ndata: {\"type\"")}
|
||||
close(ch)
|
||||
return &coreexecutor.StreamResult{Chunks: ch}, nil
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
||||
return auth, nil
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
||||
}
|
||||
|
||||
func (e *invalidJSONStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
||||
return nil, &coreauth.Error{
|
||||
Code: "not_implemented",
|
||||
Message: "HttpRequest not implemented",
|
||||
HTTPStatus: http.StatusNotImplemented,
|
||||
}
|
||||
}
|
||||
|
||||
func (e *authAwareStreamExecutor) Identifier() string { return "codex" }
|
||||
|
||||
func (e *authAwareStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
||||
@@ -524,3 +555,55 @@ func TestExecuteStreamWithAuthManager_SelectedAuthCallbackReceivesAuthID(t *test
|
||||
t.Fatalf("selectedAuthID = %q, want %q", selectedAuthID, "auth2")
|
||||
}
|
||||
}
|
||||
|
||||
func TestExecuteStreamWithAuthManager_ValidatesOpenAIResponsesStreamDataJSON(t *testing.T) {
|
||||
executor := &invalidJSONStreamExecutor{}
|
||||
manager := coreauth.NewManager(nil, nil, nil)
|
||||
manager.RegisterExecutor(executor)
|
||||
|
||||
auth1 := &coreauth.Auth{
|
||||
ID: "auth1",
|
||||
Provider: "codex",
|
||||
Status: coreauth.StatusActive,
|
||||
Metadata: map[string]any{"email": "test1@example.com"},
|
||||
}
|
||||
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
||||
t.Fatalf("manager.Register(auth1): %v", err)
|
||||
}
|
||||
|
||||
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
||||
t.Cleanup(func() {
|
||||
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
||||
})
|
||||
|
||||
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{}, manager)
|
||||
dataChan, _, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai-response", "test-model", []byte(`{"model":"test-model"}`), "")
|
||||
if dataChan == nil || errChan == nil {
|
||||
t.Fatalf("expected non-nil channels")
|
||||
}
|
||||
|
||||
var got []byte
|
||||
for chunk := range dataChan {
|
||||
got = append(got, chunk...)
|
||||
}
|
||||
if len(got) != 0 {
|
||||
t.Fatalf("expected empty payload, got %q", string(got))
|
||||
}
|
||||
|
||||
gotErr := false
|
||||
for msg := range errChan {
|
||||
if msg == nil {
|
||||
continue
|
||||
}
|
||||
if msg.StatusCode != http.StatusBadGateway {
|
||||
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, msg.StatusCode)
|
||||
}
|
||||
if msg.Error == nil {
|
||||
t.Fatalf("expected error")
|
||||
}
|
||||
gotErr = true
|
||||
}
|
||||
if !gotErr {
|
||||
t.Fatalf("expected terminal error")
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user