feat(handlers): add test to verify no retries after partial stream response

Introduce `TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte` to validate that stream executions do not retry after receiving partial responses. Implement `payloadThenErrorStreamExecutor` for test coverage of this behavior.
This commit is contained in:
Luis Pater
2026-01-29 17:30:48 +08:00
parent 189a066807
commit 4eb1e6093f

View File

@@ -70,6 +70,58 @@ func (e *failOnceStreamExecutor) Calls() int {
return e.calls
}
type payloadThenErrorStreamExecutor struct {
mu sync.Mutex
calls int
}
func (e *payloadThenErrorStreamExecutor) Identifier() string { return "codex" }
func (e *payloadThenErrorStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
}
func (e *payloadThenErrorStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
e.mu.Lock()
e.calls++
e.mu.Unlock()
ch := make(chan coreexecutor.StreamChunk, 2)
ch <- coreexecutor.StreamChunk{Payload: []byte("partial")}
ch <- coreexecutor.StreamChunk{
Err: &coreauth.Error{
Code: "upstream_closed",
Message: "upstream closed",
Retryable: false,
HTTPStatus: http.StatusBadGateway,
},
}
close(ch)
return ch, nil
}
func (e *payloadThenErrorStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
return auth, nil
}
func (e *payloadThenErrorStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
}
func (e *payloadThenErrorStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
return nil, &coreauth.Error{
Code: "not_implemented",
Message: "HttpRequest not implemented",
HTTPStatus: http.StatusNotImplemented,
}
}
func (e *payloadThenErrorStreamExecutor) Calls() int {
e.mu.Lock()
defer e.mu.Unlock()
return e.calls
}
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
executor := &failOnceStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
@@ -130,3 +182,73 @@ func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
}
}
func TestExecuteStreamWithAuthManager_DoesNotRetryAfterFirstByte(t *testing.T) {
executor := &payloadThenErrorStreamExecutor{}
manager := coreauth.NewManager(nil, nil, nil)
manager.RegisterExecutor(executor)
auth1 := &coreauth.Auth{
ID: "auth1",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test1@example.com"},
}
if _, err := manager.Register(context.Background(), auth1); err != nil {
t.Fatalf("manager.Register(auth1): %v", err)
}
auth2 := &coreauth.Auth{
ID: "auth2",
Provider: "codex",
Status: coreauth.StatusActive,
Metadata: map[string]any{"email": "test2@example.com"},
}
if _, err := manager.Register(context.Background(), auth2); err != nil {
t.Fatalf("manager.Register(auth2): %v", err)
}
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
t.Cleanup(func() {
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
})
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
Streaming: sdkconfig.StreamingConfig{
BootstrapRetries: 1,
},
}, manager)
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
if dataChan == nil || errChan == nil {
t.Fatalf("expected non-nil channels")
}
var got []byte
for chunk := range dataChan {
got = append(got, chunk...)
}
var gotErr error
var gotStatus int
for msg := range errChan {
if msg != nil && msg.Error != nil {
gotErr = msg.Error
gotStatus = msg.StatusCode
}
}
if string(got) != "partial" {
t.Fatalf("expected payload partial, got %q", string(got))
}
if gotErr == nil {
t.Fatalf("expected terminal error, got nil")
}
if gotStatus != http.StatusBadGateway {
t.Fatalf("expected status %d, got %d", http.StatusBadGateway, gotStatus)
}
if executor.Calls() != 1 {
t.Fatalf("expected 1 stream attempt, got %d", executor.Calls())
}
}