mirror of
https://github.com/router-for-me/CLIProxyAPI.git
synced 2026-02-02 20:40:52 +08:00
133 lines
3.9 KiB
Go
package handlers
|
|
|
|
import (
|
|
"context"
|
|
"net/http"
|
|
"sync"
|
|
"testing"
|
|
|
|
"github.com/router-for-me/CLIProxyAPI/v6/internal/registry"
|
|
coreauth "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/auth"
|
|
coreexecutor "github.com/router-for-me/CLIProxyAPI/v6/sdk/cliproxy/executor"
|
|
sdkconfig "github.com/router-for-me/CLIProxyAPI/v6/sdk/config"
|
|
)
|
|
|
|
// failOnceStreamExecutor is a test double for the stream-retry test below.
// Its first ExecuteStream call yields a single unauthorized error chunk;
// every later call yields a single "ok" payload chunk.
type failOnceStreamExecutor struct {
	// mu guards calls.
	mu sync.Mutex
	// calls is the number of ExecuteStream invocations seen so far.
	calls int
}
|
|
|
|
// Identifier returns the fixed provider key "codex", the same value the test
// assigns to the Provider field of the auths it registers.
func (e *failOnceStreamExecutor) Identifier() string {
	return "codex"
}
|
|
|
|
func (e *failOnceStreamExecutor) Execute(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "Execute not implemented"}
|
|
}
|
|
|
|
func (e *failOnceStreamExecutor) ExecuteStream(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (<-chan coreexecutor.StreamChunk, error) {
|
|
e.mu.Lock()
|
|
e.calls++
|
|
call := e.calls
|
|
e.mu.Unlock()
|
|
|
|
ch := make(chan coreexecutor.StreamChunk, 1)
|
|
if call == 1 {
|
|
ch <- coreexecutor.StreamChunk{
|
|
Err: &coreauth.Error{
|
|
Code: "unauthorized",
|
|
Message: "unauthorized",
|
|
Retryable: false,
|
|
HTTPStatus: http.StatusUnauthorized,
|
|
},
|
|
}
|
|
close(ch)
|
|
return ch, nil
|
|
}
|
|
|
|
ch <- coreexecutor.StreamChunk{Payload: []byte("ok")}
|
|
close(ch)
|
|
return ch, nil
|
|
}
|
|
|
|
func (e *failOnceStreamExecutor) Refresh(ctx context.Context, auth *coreauth.Auth) (*coreauth.Auth, error) {
|
|
return auth, nil
|
|
}
|
|
|
|
func (e *failOnceStreamExecutor) CountTokens(context.Context, *coreauth.Auth, coreexecutor.Request, coreexecutor.Options) (coreexecutor.Response, error) {
|
|
return coreexecutor.Response{}, &coreauth.Error{Code: "not_implemented", Message: "CountTokens not implemented"}
|
|
}
|
|
|
|
func (e *failOnceStreamExecutor) HttpRequest(ctx context.Context, auth *coreauth.Auth, req *http.Request) (*http.Response, error) {
|
|
return nil, &coreauth.Error{
|
|
Code: "not_implemented",
|
|
Message: "HttpRequest not implemented",
|
|
HTTPStatus: http.StatusNotImplemented,
|
|
}
|
|
}
|
|
|
|
func (e *failOnceStreamExecutor) Calls() int {
|
|
e.mu.Lock()
|
|
defer e.mu.Unlock()
|
|
return e.calls
|
|
}
|
|
|
|
func TestExecuteStreamWithAuthManager_RetriesBeforeFirstByte(t *testing.T) {
|
|
executor := &failOnceStreamExecutor{}
|
|
manager := coreauth.NewManager(nil, nil, nil)
|
|
manager.RegisterExecutor(executor)
|
|
|
|
auth1 := &coreauth.Auth{
|
|
ID: "auth1",
|
|
Provider: "codex",
|
|
Status: coreauth.StatusActive,
|
|
Metadata: map[string]any{"email": "test1@example.com"},
|
|
}
|
|
if _, err := manager.Register(context.Background(), auth1); err != nil {
|
|
t.Fatalf("manager.Register(auth1): %v", err)
|
|
}
|
|
|
|
auth2 := &coreauth.Auth{
|
|
ID: "auth2",
|
|
Provider: "codex",
|
|
Status: coreauth.StatusActive,
|
|
Metadata: map[string]any{"email": "test2@example.com"},
|
|
}
|
|
if _, err := manager.Register(context.Background(), auth2); err != nil {
|
|
t.Fatalf("manager.Register(auth2): %v", err)
|
|
}
|
|
|
|
registry.GetGlobalRegistry().RegisterClient(auth1.ID, auth1.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
|
registry.GetGlobalRegistry().RegisterClient(auth2.ID, auth2.Provider, []*registry.ModelInfo{{ID: "test-model"}})
|
|
t.Cleanup(func() {
|
|
registry.GetGlobalRegistry().UnregisterClient(auth1.ID)
|
|
registry.GetGlobalRegistry().UnregisterClient(auth2.ID)
|
|
})
|
|
|
|
handler := NewBaseAPIHandlers(&sdkconfig.SDKConfig{
|
|
Streaming: sdkconfig.StreamingConfig{
|
|
BootstrapRetries: 1,
|
|
},
|
|
}, manager)
|
|
dataChan, errChan := handler.ExecuteStreamWithAuthManager(context.Background(), "openai", "test-model", []byte(`{"model":"test-model"}`), "")
|
|
if dataChan == nil || errChan == nil {
|
|
t.Fatalf("expected non-nil channels")
|
|
}
|
|
|
|
var got []byte
|
|
for chunk := range dataChan {
|
|
got = append(got, chunk...)
|
|
}
|
|
|
|
for msg := range errChan {
|
|
if msg != nil {
|
|
t.Fatalf("unexpected error: %+v", msg)
|
|
}
|
|
}
|
|
|
|
if string(got) != "ok" {
|
|
t.Fatalf("expected payload ok, got %q", string(got))
|
|
}
|
|
if executor.Calls() != 2 {
|
|
t.Fatalf("expected 2 stream attempts, got %d", executor.Calls())
|
|
}
|
|
}
|