package handlers import ( "context" "net/http" "sync" "testing" "github.com/router-for-me/CLIProxyAPI/v6/internal/registry" coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth" coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor" sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config" ) type failOnceStreamExecutor struct { mu sync.Mutex calls int } func (e *failOnceStreamExecutor) Identifier() string { return "codex" } func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"} } func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) { e.mu.Lock() e.calls++ call := e.calls e.mu.Unlock() ch := make(chan coreexecutor.StreamChunk, 1) if call == 1 { ch <- coreexecutor.StreamChunk{ Err: &coreauth.Error{ Code: "unauthorized", Message: "unauthorized", Retryable: false, HTTPStatus: http.StatusUnauthorized, }, } close(ch) return ch, nil } ch <- coreexecutor.StreamChunk{Payload: []byte("ok")} close(ch) return ch, nil } func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) { return auth, nil } func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) { return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"} } func (e *failOnceStreamExecutor) Calls() int { e.mu.Lock() defer e.mu.Unlock() return e.calls } func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) { executor := &failOnceStreamExecutor{} manager := coreauth.NewManager(nil, nil, nil) manager.RegisterExecutor(executor) auth1 := &coreauth.Auth{ ID: "auth1", Provider: "codex", Status: coreauth.StatusActive, Metadata: map[string]any{"email": "test1@example.com"}, } if _, err := manager.Register(context.Background(), auth1); err != nil { t.Fatalf("manager.Register(auth1): %v", err) } auth2 := &coreauth.Auth{ ID: "auth2", Provider: "codex", Status: coreauth.StatusActive, Metadata: map[string]any{"email": "test2@example.com"}, } if _, err := manager.Register(context.Background(), auth2); err != nil { t.Fatalf("manager.Register(auth2): %v", err) } registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}}) registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}}) t.Cleanup(func() { registry.GetGlobalRegistry().UnregisterClient(auth1.ID) registry.GetGlobalRegistry().UnregisterClient(auth2.ID) }) bootstrapRetries := 1 handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{ Streaming: sdkconfig.StreamingConfig{ BootstrapRetries: &bootstrapRetries, }, }, manager, nil) dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "") if dataChan == nil || errChan == nil { t.Fatalf("expected non-nil channels") } var got []byte for chunk := range dataChan { got = append(got, chunk...) } for msg := range errChan { if msg != nil { t.Fatalf("unexpected error: %+v", msg) } } if string(got) != "ok" { t.Fatalf("expected payload ok, got %q", string(got)) } if executor.Calls() != 2 { t.Fatalf("expected 2 stream attempts, got %d", executor.Calls()) } }