mirror of
https://github.com/earendil-works/pi.git
synced 2026-06-18 15:54:04 +08:00
feat(ai,coding-agent): add faux provider and ModelRegistry factories
This commit is contained in:
@@ -636,6 +636,92 @@ The library uses a registry of API implementations. Built-in APIs include:
|
||||
- **`azure-openai-responses`**: Azure OpenAI Responses API (`streamAzureOpenAIResponses`, `AzureOpenAIResponsesOptions`)
|
||||
- **`bedrock-converse-stream`**: Amazon Bedrock Converse API (`streamBedrock`, `BedrockOptions`)
|
||||
|
||||
### Faux provider for tests

`registerFauxProvider()` registers a temporary in-memory provider for tests and demos. It is opt-in and not part of the built-in provider set.
|
||||
|
||||
```typescript
import {
  complete,
  fauxAssistantMessage,
  fauxText,
  fauxThinking,
  fauxToolCall,
  registerFauxProvider,
  stream,
} from '@mariozechner/pi-ai';

const registration = registerFauxProvider({
  tokensPerSecond: 50 // optional
});

const model = registration.getModel();
const context = {
  messages: [{ role: 'user', content: 'Summarize package.json and then call echo', timestamp: Date.now() }]
};

registration.setResponses([
  fauxAssistantMessage([
    fauxThinking('Need to inspect package metadata first.'),
    fauxToolCall('echo', { text: 'package.json' })
  ], { stopReason: 'toolUse' })
]);

const first = await complete(model, context, {
  sessionId: 'session-1',
  cacheRetention: 'short'
});
context.messages.push(first);

context.messages.push({
  role: 'toolResult',
  toolCallId: first.content.find((block) => block.type === 'toolCall')!.id,
  toolName: 'echo',
  content: [{ type: 'text', text: 'package.json contents here' }],
  isError: false,
  timestamp: Date.now()
});

registration.setResponses([
  fauxAssistantMessage([
    fauxThinking('Now I can summarize the tool output.'),
    fauxText('Here is the summary.')
  ])
]);

const s = stream(model, context);
for await (const event of s) {
  console.log(event.type);
}

// Optional: register multiple faux models for model-switching tests
const multiModel = registerFauxProvider({
  models: [
    { id: 'faux-fast', reasoning: false },
    { id: 'faux-thinker', reasoning: true }
  ]
});
const thinker = multiModel.getModel('faux-thinker');

console.log(thinker?.reasoning);
console.log(registration.getPendingResponseCount());
console.log(registration.state.callCount);
registration.unregister();
multiModel.unregister();
```
|
||||
|
||||
Notes:
- Responses are consumed from a queue in request start order.
- If the queue is empty, the faux provider returns an assistant error message with `errorMessage: "No more faux responses queued"`.
- Use `registration.setResponses([...])` to replace the remaining queue and `registration.appendResponses([...])` to add more responses.
- `registration.models` exposes all registered faux models. `registration.getModel()` returns the first one, and `registration.getModel(id)` returns a specific one.
- Use `fauxAssistantMessage(...)` for scripted assistant replies. Use `fauxText(...)`, `fauxThinking(...)`, and `fauxToolCall(...)` to build content blocks without filling in low-level fields manually.
- `registration.unregister()` removes the temporary provider from the global API registry.
- Usage is estimated at roughly 1 token per 4 characters. When `sessionId` is present and `cacheRetention` is not `"none"`, prompt cache reads and writes are simulated automatically.
- Tool call arguments stream incrementally via `toolcall_delta` chunks.
- By default, each streamed chunk is emitted on its own microtask. Set `tokensPerSecond` to pace chunk delivery in real time.
- The intended use is one deterministic scripted flow per registration. If you need independent concurrent flows, register separate faux providers.
|
||||
|
||||
### Providers and Models
|
||||
|
||||
A **provider** offers models through a specific API. For example:
|
||||
|
||||
@@ -7,6 +7,7 @@ export * from "./models.js";
|
||||
export type { BedrockOptions } from "./providers/amazon-bedrock.js";
|
||||
export type { AnthropicOptions } from "./providers/anthropic.js";
|
||||
export type { AzureOpenAIResponsesOptions } from "./providers/azure-openai-responses.js";
|
||||
export * from "./providers/faux.js";
|
||||
export type { GoogleOptions } from "./providers/google.js";
|
||||
export type { GoogleGeminiCliOptions, GoogleThinkingLevel } from "./providers/google-gemini-cli.js";
|
||||
export type { GoogleVertexOptions } from "./providers/google-vertex.js";
|
||||
|
||||
@@ -0,0 +1,498 @@
|
||||
import { registerApiProvider, unregisterApiProviders } from "../api-registry.js";
|
||||
import type {
|
||||
AssistantMessage,
|
||||
AssistantMessageEventStream,
|
||||
Context,
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
TextContent,
|
||||
ThinkingContent,
|
||||
ToolCall,
|
||||
ToolResultMessage,
|
||||
Usage,
|
||||
} from "../types.js";
|
||||
import { createAssistantMessageEventStream } from "../utils/event-stream.js";
|
||||
|
||||
/** Base API identifier; also the `api` stamped on messages built by `fauxAssistantMessage`. */
const DEFAULT_API = "faux";
/** Provider name used when the caller does not supply one. */
const DEFAULT_PROVIDER = "faux";
/** Id of the single model registered when no model definitions are given. */
const DEFAULT_MODEL_ID = "faux-1";
/** Display name of the default model. */
const DEFAULT_MODEL_NAME = "Faux Model";
/** Placeholder base URL; the faux provider never performs network requests. */
const DEFAULT_BASE_URL = "http://localhost:0";
/** Default lower bound (in estimated tokens) for the size of a streamed chunk. */
const DEFAULT_MIN_TOKEN_SIZE = 3;
/** Default upper bound (in estimated tokens) for the size of a streamed chunk. */
const DEFAULT_MAX_TOKEN_SIZE = 5;

/** All-zero usage record used as the placeholder before usage is estimated. */
const DEFAULT_USAGE: Usage = {
	input: 0,
	output: 0,
	cacheRead: 0,
	cacheWrite: 0,
	totalTokens: 0,
	cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
};
|
||||
|
||||
/**
 * Describes one model exposed by a faux provider registration.
 *
 * Only `id` is required; `registerFauxProvider` fills in every omitted field.
 */
export interface FauxModelDefinition {
	/** Model id used to look the model up and stamped on returned messages. */
	id: string;
	/** Display name; falls back to `id` when omitted. */
	name?: string;
	/** Whether the model is advertised as a reasoning model; defaults to `false`. */
	reasoning?: boolean;
	/** Accepted input modalities; defaults to both `"text"` and `"image"`. */
	input?: ("text" | "image")[];
	/** Per-category cost figures; default to all zeros. */
	cost?: { input: number; output: number; cacheRead: number; cacheWrite: number };
	/** Advertised context window; defaults to 128000. */
	contextWindow?: number;
	/** Advertised maximum output tokens; defaults to 16384. */
	maxTokens?: number;
}
|
||||
|
||||
/** The content block kinds a scripted faux assistant message may contain. */
export type FauxContentBlock = TextContent | ThinkingContent | ToolCall;
|
||||
|
||||
/**
 * Builds a text content block.
 *
 * @param text - The text the block carries.
 * @returns A `{ type: "text" }` block wrapping `text`.
 */
export function fauxText(text: string): TextContent {
	return { type: "text", text };
}
|
||||
|
||||
/**
 * Builds a thinking content block.
 *
 * @param thinking - The reasoning text the block carries.
 * @returns A `{ type: "thinking" }` block wrapping `thinking`.
 */
export function fauxThinking(thinking: string): ThinkingContent {
	return { type: "thinking", thinking };
}
|
||||
|
||||
/**
 * Builds a tool call content block.
 *
 * @param name - Name of the tool being called.
 * @param arguments_ - Arguments passed to the tool; stored by reference, not copied.
 * @param options - Optional overrides; `id` pins the tool call id, otherwise a random one is generated.
 * @returns A `{ type: "toolCall" }` block.
 */
export function fauxToolCall(name: string, arguments_: ToolCall["arguments"], options: { id?: string } = {}): ToolCall {
	return {
		type: "toolCall",
		id: options.id ?? randomId("tool"),
		name,
		arguments: arguments_,
	};
}
|
||||
|
||||
/**
 * Normalizes the accepted content shorthand into an array of content blocks.
 *
 * A plain string becomes a single text block, a single block is wrapped in an
 * array, and an array is returned as-is (same reference).
 */
function normalizeFauxAssistantContent(content: string | FauxContentBlock | FauxContentBlock[]): FauxContentBlock[] {
	if (typeof content === "string") {
		return [fauxText(content)];
	}
	return Array.isArray(content) ? content : [content];
}
|
||||
|
||||
/**
 * Builds a scripted assistant message for use as a queued faux response
 * or as a prior assistant turn in a context.
 *
 * The message is stamped with the default faux api/provider/model ids and
 * zeroed usage.
 *
 * @param content - A string (becomes one text block), a single block, or an array of blocks.
 * @param options - Optional `stopReason` (defaults to `"stop"`), `errorMessage`,
 *   `responseId`, and `timestamp` (defaults to `Date.now()`).
 * @returns The assembled assistant message.
 */
export function fauxAssistantMessage(
	content: string | FauxContentBlock | FauxContentBlock[],
	options: {
		stopReason?: AssistantMessage["stopReason"];
		errorMessage?: string;
		responseId?: string;
		timestamp?: number;
	} = {},
): AssistantMessage {
	return {
		role: "assistant",
		content: normalizeFauxAssistantContent(content),
		api: DEFAULT_API,
		provider: DEFAULT_PROVIDER,
		model: DEFAULT_MODEL_ID,
		// Fresh copy per message: sharing the module-level object would let a
		// consumer that mutates one message's usage corrupt every other message.
		usage: structuredClone(DEFAULT_USAGE),
		stopReason: options.stopReason ?? "stop",
		errorMessage: options.errorMessage,
		responseId: options.responseId,
		timestamp: options.timestamp ?? Date.now(),
	};
}
|
||||
|
||||
/**
 * Callback that produces a faux response at request time.
 *
 * It receives the request context, the stream options, the registration's
 * shared call-count state, and the model the request was made against. It may
 * return the message directly or as a promise.
 */
export type FauxResponseFactory = (
	context: Context,
	options: StreamOptions | undefined,
	state: { callCount: number },
	model: Model<string>,
) => AssistantMessage | Promise<AssistantMessage>;
|
||||
|
||||
/** One queued response: either a ready-made assistant message or a factory that builds one per request. */
export type FauxResponseStep = AssistantMessage | FauxResponseFactory;
|
||||
|
||||
/** Options accepted by `registerFauxProvider`. All fields are optional. */
export interface RegisterFauxProviderOptions {
	/** API identifier to register under; a random `faux:`-prefixed id is generated when omitted. */
	api?: string;
	/** Provider name stamped on models and messages; defaults to `"faux"`. */
	provider?: string;
	/** Models to expose; a single default model is registered when omitted or empty. */
	models?: FauxModelDefinition[];
	/**
	 * Pacing for streamed chunks in estimated tokens per second. When omitted
	 * or not positive, each chunk is emitted on its own microtask instead.
	 */
	tokensPerSecond?: number;
	/**
	 * Bounds (in estimated tokens, 4 characters each) for the randomly chosen
	 * size of each streamed chunk. Defaults to 3..5; `min` is clamped to at
	 * least 1 and at most `max`.
	 */
	tokenSize?: {
		min?: number;
		max?: number;
	};
}
|
||||
|
||||
/** Handle returned by `registerFauxProvider` for scripting and tearing down a faux provider. */
export interface FauxProviderRegistration {
	/** API identifier the provider was registered under. */
	api: string;
	/** All registered faux models; always contains at least one. */
	models: [Model<string>, ...Model<string>[]];
	/** Returns the first registered model. */
	getModel(): Model<string>;
	/** Returns the model with the given id, or `undefined` when none matches. */
	getModel(modelId: string): Model<string> | undefined;
	/** Shared mutable state; `callCount` is incremented once per stream request. */
	state: { callCount: number };
	/** Replaces the queue of pending responses. */
	setResponses: (responses: FauxResponseStep[]) => void;
	/** Appends responses to the end of the pending queue. */
	appendResponses: (responses: FauxResponseStep[]) => void;
	/** Number of responses still waiting in the queue. */
	getPendingResponseCount: () => number;
	/** Removes this provider from the API registry. */
	unregister: () => void;
}
|
||||
|
||||
/**
 * Estimates the token count of a string as one token per four UTF-16 code
 * units, rounded up. An empty string yields 0.
 */
function estimateTokens(text: string): number {
	return Math.ceil(text.length / 4);
}
|
||||
|
||||
/**
 * Generates an id of the form `<prefix>:<epoch millis>:<random base-36 suffix>`.
 *
 * Uses `Math.random`, so the result is suitable for test fixtures only and
 * not for anything security-sensitive.
 */
function randomId(prefix: string): string {
	return `${prefix}:${Date.now()}:${Math.random().toString(36).slice(2)}`;
}
|
||||
|
||||
/**
 * Flattens user/tool-result content into a single string for token estimation.
 *
 * A string is returned unchanged. In a block array, text blocks contribute
 * their text and image blocks contribute a `[image:<mimeType>:<data length>]`
 * placeholder; the pieces are joined with newlines.
 */
function contentToText(content: string | Array<TextContent | ImageContent>): string {
	if (typeof content === "string") {
		return content;
	}
	return content
		.map((block) => {
			if (block.type === "text") {
				return block.text;
			}
			return `[image:${block.mimeType}:${block.data.length}]`;
		})
		.join("\n");
}
|
||||
|
||||
/**
 * Flattens assistant content blocks into a single string for token estimation.
 *
 * Text and thinking blocks contribute their text; tool calls contribute
 * `<name>:<JSON-encoded arguments>`. The pieces are joined with newlines.
 */
function assistantContentToText(content: Array<TextContent | ThinkingContent | ToolCall>): string {
	return content
		.map((block) => {
			if (block.type === "text") {
				return block.text;
			}
			if (block.type === "thinking") {
				return block.thinking;
			}
			return `${block.name}:${JSON.stringify(block.arguments)}`;
		})
		.join("\n");
}
|
||||
|
||||
/**
 * Flattens a tool result message into text for token estimation: the tool
 * name followed by each content block's text form, one per line.
 */
function toolResultToText(message: ToolResultMessage): string {
	return [message.toolName, ...message.content.map((block) => contentToText([block]))].join("\n");
}
|
||||
|
||||
/**
 * Flattens any context message into text for token estimation, dispatching
 * on its role. Anything that is neither a user nor an assistant message is
 * treated as a tool result.
 */
function messageToText(message: Message): string {
	if (message.role === "user") {
		return contentToText(message.content);
	}
	if (message.role === "assistant") {
		return assistantContentToText(message.content);
	}
	return toolResultToText(message);
}
|
||||
|
||||
/**
 * Serializes a request context into one string used for prompt-size and
 * cache-prefix estimation.
 *
 * Sections appear in this order, separated by blank lines: `system:<prompt>`
 * (when a system prompt is set), one `<role>:<text>` entry per message, and
 * `tools:<JSON>` (when at least one tool is defined).
 */
function serializeContext(context: Context): string {
	const parts: string[] = [];
	if (context.systemPrompt) {
		parts.push(`system:${context.systemPrompt}`);
	}
	for (const message of context.messages) {
		parts.push(`${message.role}:${messageToText(message)}`);
	}
	if (context.tools?.length) {
		parts.push(`tools:${JSON.stringify(context.tools)}`);
	}
	return parts.join("\n\n");
}
|
||||
|
||||
/**
 * Returns the number of leading UTF-16 code units that `a` and `b` share.
 */
function commonPrefixLength(a: string, b: string): number {
	const length = Math.min(a.length, b.length);
	let index = 0;
	while (index < length && a[index] === b[index]) {
		index++;
	}
	return index;
}
|
||||
|
||||
/**
 * Returns a copy of `message` whose usage is estimated from the serialized
 * request context and the message's own content.
 *
 * Without a session id, or with `cacheRetention` set to `"none"`, the whole
 * prompt counts as input and no cache activity is reported. Otherwise the
 * prompt is compared with the previous prompt stored for that session:
 * the shared prefix counts as a cache read, the remainder as a cache write,
 * and input is the prompt size minus the cache read. A session's first prompt
 * is reported entirely as a cache write.
 *
 * @param message - Message to annotate; not mutated.
 * @param context - Request context the prompt estimate is derived from.
 * @param options - Stream options supplying `sessionId` and `cacheRetention`.
 * @param promptCache - Per-registration map of session id to last prompt text;
 *   updated in place when cache simulation applies.
 * @returns The message with a freshly built usage record (all costs zero).
 */
function withUsageEstimate(
	message: AssistantMessage,
	context: Context,
	options: StreamOptions | undefined,
	promptCache: Map<string, string>,
): AssistantMessage {
	const promptText = serializeContext(context);
	const promptTokens = estimateTokens(promptText);
	const outputTokens = estimateTokens(assistantContentToText(message.content));
	let input = promptTokens;
	let cacheRead = 0;
	let cacheWrite = 0;
	const sessionId = options?.sessionId;

	if (sessionId && options?.cacheRetention !== "none") {
		const previousPrompt = promptCache.get(sessionId);
		if (previousPrompt) {
			const cachedChars = commonPrefixLength(previousPrompt, promptText);
			cacheRead = estimateTokens(previousPrompt.slice(0, cachedChars));
			cacheWrite = estimateTokens(promptText.slice(cachedChars));
			input = Math.max(0, promptTokens - cacheRead);
		} else {
			cacheWrite = promptTokens;
		}
		// Remember this prompt so the next request in the session can diff against it.
		promptCache.set(sessionId, promptText);
	}

	return {
		...message,
		usage: {
			input,
			output: outputTokens,
			cacheRead,
			cacheWrite,
			totalTokens: input + outputTokens + cacheRead + cacheWrite,
			cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
		},
	};
}
|
||||
|
||||
/**
 * Splits `text` into consecutive chunks that simulate streamed tokens.
 *
 * Each chunk's size is drawn at random between `minTokenSize` and
 * `maxTokenSize` estimated tokens (4 UTF-16 code units per token, at least
 * one code unit). Concatenating the chunks always reproduces `text`.
 *
 * @param text - Text to split.
 * @param minTokenSize - Lower bound of the per-chunk size, in estimated tokens.
 * @param maxTokenSize - Upper bound of the per-chunk size, in estimated tokens.
 * @returns The chunks in order; a single empty string when `text` is empty.
 */
function splitStringByTokenSize(text: string, minTokenSize: number, maxTokenSize: number): string[] {
	const chunks: string[] = [];
	let index = 0;
	while (index < text.length) {
		const tokenSize = minTokenSize + Math.floor(Math.random() * (maxTokenSize - minTokenSize + 1));
		const requestedSize = Math.max(1, tokenSize * 4);
		// A NaN size would stop the loop after one empty chunk and drop the text.
		const charSize = Number.isNaN(requestedSize) ? 1 : requestedSize;
		let end = index + charSize;
		// Never cut between a high and a low surrogate: a lone surrogate in a
		// delta is not valid text on its own and breaks when encoded.
		const lastUnit = text.charCodeAt(end - 1);
		const nextUnit = text.charCodeAt(end);
		if (lastUnit >= 0xd800 && lastUnit <= 0xdbff && nextUnit >= 0xdc00 && nextUnit <= 0xdfff) {
			end++;
		}
		chunks.push(text.slice(index, end));
		index = end;
	}
	return chunks.length > 0 ? chunks : [""];
}
|
||||
|
||||
/**
 * Deep-copies a scripted message and re-stamps it with the api, provider,
 * and model id of the request being served, so queued responses are never
 * mutated and always report the registration they were returned from.
 *
 * Missing `timestamp` and `usage` are filled with the current time and a
 * zeroed usage record.
 */
function cloneMessage(message: AssistantMessage, api: string, provider: string, modelId: string): AssistantMessage {
	const cloned = structuredClone(message);
	return {
		...cloned,
		api,
		provider,
		model: modelId,
		timestamp: cloned.timestamp ?? Date.now(),
		// Copy the default so the result never aliases the shared module-level object.
		usage: cloned.usage ?? structuredClone(DEFAULT_USAGE),
	};
}
|
||||
|
||||
/**
 * Builds an empty assistant message with `stopReason: "error"` describing
 * `error`.
 *
 * @param error - The failure; an `Error` contributes its `message`, anything
 *   else is converted with `String()`.
 * @param api - API identifier to stamp on the message.
 * @param provider - Provider name to stamp on the message.
 * @param modelId - Model id to stamp on the message.
 */
function createErrorMessage(error: unknown, api: string, provider: string, modelId: string): AssistantMessage {
	return {
		role: "assistant",
		content: [],
		api,
		provider,
		model: modelId,
		// Fresh copy: error messages are handed to consumers, which must not be
		// able to mutate the shared module-level default through them.
		usage: structuredClone(DEFAULT_USAGE),
		stopReason: "error",
		errorMessage: error instanceof Error ? error.message : String(error),
		timestamp: Date.now(),
	};
}
|
||||
|
||||
/**
 * Returns a copy of the partially streamed message marked as aborted: the
 * stop reason becomes `"aborted"`, the error message is set, and the
 * timestamp is refreshed. Content streamed so far is kept.
 */
function createAbortedMessage(partial: AssistantMessage): AssistantMessage {
	return {
		...partial,
		stopReason: "aborted",
		errorMessage: "Request was aborted",
		timestamp: Date.now(),
	};
}
|
||||
|
||||
/**
 * Waits until the given chunk should be delivered.
 *
 * With no (or a non-positive) `tokensPerSecond`, resolves on the next
 * microtask. Otherwise resolves after a timer whose delay is the chunk's
 * estimated token count divided by the rate.
 */
function scheduleChunk(chunk: string, tokensPerSecond: number | undefined): Promise<void> {
	if (!tokensPerSecond || tokensPerSecond <= 0) {
		return new Promise((resolve) => queueMicrotask(resolve));
	}
	const delayMs = (estimateTokens(chunk) / tokensPerSecond) * 1000;
	return new Promise((resolve) => setTimeout(resolve, delayMs));
}
|
||||
|
||||
/**
 * Replays a complete assistant message onto `stream` as incremental events.
 *
 * Emits `start`, then for each content block a `*_start` event, one
 * `*_delta` per randomly sized chunk, and a `*_end` event, and finally
 * `done` (or `error` when the message's own stop reason is `"error"` or
 * `"aborted"`) before ending the stream with the message.
 *
 * The abort signal is checked before starting, before each block, and after
 * each chunk delay. On abort the stream receives an `error` event carrying
 * the content streamed so far and is ended immediately.
 *
 * Tool call arguments are streamed as chunks of their JSON encoding; the
 * partial's `arguments` stay empty until the block ends.
 *
 * @param stream - Event stream to push events to and end.
 * @param message - The full message to replay.
 * @param minTokenSize - Lower bound of the chunk size, in estimated tokens.
 * @param maxTokenSize - Upper bound of the chunk size, in estimated tokens.
 * @param tokensPerSecond - Optional pacing rate; see `scheduleChunk`.
 * @param signal - Optional abort signal for the request.
 */
async function streamWithDeltas(
	stream: AssistantMessageEventStream,
	message: AssistantMessage,
	minTokenSize: number,
	maxTokenSize: number,
	tokensPerSecond: number | undefined,
	signal: AbortSignal | undefined,
): Promise<void> {
	const partial: AssistantMessage = { ...message, content: [] };

	// Terminates the stream with an "aborted" error if the request was
	// aborted. Returns true when the caller must stop streaming.
	const abortIfRequested = (): boolean => {
		if (!signal?.aborted) {
			return false;
		}
		const aborted = createAbortedMessage(partial);
		stream.push({ type: "error", reason: "aborted", error: aborted });
		stream.end(aborted);
		return true;
	};

	if (abortIfRequested()) {
		return;
	}

	stream.push({ type: "start", partial: { ...partial } });

	for (let index = 0; index < message.content.length; index++) {
		if (abortIfRequested()) {
			return;
		}

		const block = message.content[index];

		if (block.type === "thinking") {
			partial.content = [...partial.content, { type: "thinking", thinking: "" }];
			stream.push({ type: "thinking_start", contentIndex: index, partial: { ...partial } });
			for (const chunk of splitStringByTokenSize(block.thinking, minTokenSize, maxTokenSize)) {
				await scheduleChunk(chunk, tokensPerSecond);
				if (abortIfRequested()) {
					return;
				}
				(partial.content[index] as ThinkingContent).thinking += chunk;
				stream.push({ type: "thinking_delta", contentIndex: index, delta: chunk, partial: { ...partial } });
			}
			stream.push({
				type: "thinking_end",
				contentIndex: index,
				content: block.thinking,
				partial: { ...partial },
			});
			continue;
		}

		if (block.type === "text") {
			partial.content = [...partial.content, { type: "text", text: "" }];
			stream.push({ type: "text_start", contentIndex: index, partial: { ...partial } });
			for (const chunk of splitStringByTokenSize(block.text, minTokenSize, maxTokenSize)) {
				await scheduleChunk(chunk, tokensPerSecond);
				if (abortIfRequested()) {
					return;
				}
				(partial.content[index] as TextContent).text += chunk;
				stream.push({ type: "text_delta", contentIndex: index, delta: chunk, partial: { ...partial } });
			}
			stream.push({ type: "text_end", contentIndex: index, content: block.text, partial: { ...partial } });
			continue;
		}

		// Remaining case: a tool call. Its arguments are streamed as raw JSON text.
		partial.content = [...partial.content, { type: "toolCall", id: block.id, name: block.name, arguments: {} }];
		stream.push({ type: "toolcall_start", contentIndex: index, partial: { ...partial } });
		for (const chunk of splitStringByTokenSize(JSON.stringify(block.arguments), minTokenSize, maxTokenSize)) {
			await scheduleChunk(chunk, tokensPerSecond);
			if (abortIfRequested()) {
				return;
			}
			stream.push({ type: "toolcall_delta", contentIndex: index, delta: chunk, partial: { ...partial } });
		}
		(partial.content[index] as ToolCall).arguments = block.arguments;
		stream.push({ type: "toolcall_end", contentIndex: index, toolCall: block, partial: { ...partial } });
	}

	if (message.stopReason === "error" || message.stopReason === "aborted") {
		stream.push({ type: "error", reason: message.stopReason, error: message });
		stream.end(message);
		return;
	}

	stream.push({ type: "done", reason: message.stopReason, message });
	stream.end(message);
}
|
||||
|
||||
/**
 * Registers a temporary in-memory provider that replays scripted responses,
 * for use in tests and demos.
 *
 * Each stream request takes the next queued response at the moment the
 * request starts and increments `state.callCount`. A factory step is invoked
 * with the request context, options, state, and model. The resulting message
 * is deep-copied, re-stamped with this registration's api/provider and the
 * requested model id, given estimated usage, and streamed as incremental
 * events. When the queue is empty, or a factory throws, the stream yields a
 * single `error` event instead.
 *
 * @param options - Optional api/provider names, model definitions, pacing,
 *   and chunk-size bounds.
 * @returns A registration handle for scripting responses, inspecting state,
 *   and unregistering the provider.
 */
export function registerFauxProvider(options: RegisterFauxProviderOptions = {}): FauxProviderRegistration {
	const api = options.api ?? randomId(DEFAULT_API);
	const provider = options.provider ?? DEFAULT_PROVIDER;
	const sourceId = randomId("faux-provider");
	// Clamp so that 1 <= minTokenSize <= maxTokenSize regardless of what was passed.
	const minTokenSize = Math.max(
		1,
		Math.min(options.tokenSize?.min ?? DEFAULT_MIN_TOKEN_SIZE, options.tokenSize?.max ?? DEFAULT_MAX_TOKEN_SIZE),
	);
	const maxTokenSize = Math.max(minTokenSize, options.tokenSize?.max ?? DEFAULT_MAX_TOKEN_SIZE);
	let pendingResponses: FauxResponseStep[] = [];
	const tokensPerSecond = options.tokensPerSecond;
	const state = { callCount: 0 };
	// Last serialized prompt per session id, used to simulate prompt caching.
	const promptCache = new Map<string, string>();

	const modelDefinitions = options.models?.length
		? options.models
		: [
				{
					id: DEFAULT_MODEL_ID,
					name: DEFAULT_MODEL_NAME,
					reasoning: false,
					input: ["text", "image"] as ("text" | "image")[],
					cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
					contextWindow: 128000,
					maxTokens: 16384,
				},
			];
	// The fallback above guarantees at least one definition, hence the non-empty tuple.
	const models = modelDefinitions.map((definition) => ({
		id: definition.id,
		name: definition.name ?? definition.id,
		api,
		provider,
		baseUrl: DEFAULT_BASE_URL,
		reasoning: definition.reasoning ?? false,
		input: definition.input ?? ["text", "image"],
		cost: definition.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
		contextWindow: definition.contextWindow ?? 128000,
		maxTokens: definition.maxTokens ?? 16384,
	})) as [Model<string>, ...Model<string>[]];

	const stream: StreamFunction<string, StreamOptions> = (requestModel, context, streamOptions) => {
		const outer = createAssistantMessageEventStream();
		// Dequeue synchronously so responses are matched to requests in start order.
		const step = pendingResponses.shift();
		state.callCount++;

		queueMicrotask(async () => {
			try {
				if (!step) {
					let message = createErrorMessage(
						new Error("No more faux responses queued"),
						api,
						provider,
						requestModel.id,
					);
					message = withUsageEstimate(message, context, streamOptions, promptCache);
					outer.push({ type: "error", reason: "error", error: message });
					outer.end(message);
					return;
				}

				const resolved =
					typeof step === "function" ? await step(context, streamOptions, state, requestModel) : step;
				let message = cloneMessage(resolved, api, provider, requestModel.id);
				message = withUsageEstimate(message, context, streamOptions, promptCache);
				await streamWithDeltas(outer, message, minTokenSize, maxTokenSize, tokensPerSecond, streamOptions?.signal);
			} catch (error) {
				const message = createErrorMessage(error, api, provider, requestModel.id);
				outer.push({ type: "error", reason: "error", error: message });
				outer.end(message);
			}
		});

		return outer;
	};

	const streamSimple: StreamFunction<string, SimpleStreamOptions> = (streamModel, context, streamOptions) =>
		stream(streamModel, context, streamOptions);

	registerApiProvider({ api, stream, streamSimple }, sourceId);

	function getModel(): Model<string>;
	function getModel(requestedModelId: string): Model<string> | undefined;
	function getModel(requestedModelId?: string): Model<string> | undefined {
		// Only an omitted id selects the default model; an empty-string id is a
		// real lookup and finds nothing unless a model is registered under "".
		if (requestedModelId === undefined) {
			return models[0];
		}
		return models.find((candidate) => candidate.id === requestedModelId);
	}

	return {
		api,
		models,
		getModel,
		state,
		setResponses(responses) {
			// Copy so later mutation of the caller's array cannot alter the queue.
			pendingResponses = [...responses];
		},
		appendResponses(responses) {
			pendingResponses.push(...responses);
		},
		getPendingResponseCount() {
			return pendingResponses.length;
		},
		unregister() {
			unregisterApiProviders(sourceId);
		},
	};
}
|
||||
@@ -0,0 +1,597 @@
|
||||
import { afterEach, describe, expect, it } from "vitest";
|
||||
import {
|
||||
complete,
|
||||
fauxAssistantMessage,
|
||||
fauxText,
|
||||
fauxThinking,
|
||||
fauxToolCall,
|
||||
registerFauxProvider,
|
||||
stream,
|
||||
Type,
|
||||
} from "../src/index.js";
|
||||
import type { AssistantMessageEvent, Context } from "../src/types.js";
|
||||
|
||||
/** Drains an event stream and returns every event it yielded, in order. */
async function collectEvents(streamResult: ReturnType<typeof stream>): Promise<AssistantMessageEvent[]> {
	const events: AssistantMessageEvent[] = [];
	for await (const event of streamResult) {
		events.push(event);
	}
	return events;
}
|
||||
|
||||
// Registrations created by the current test; each test pushes its own here.
const registrations: Array<{ unregister: () => void }> = [];

// Unregister every faux provider after each test. `splice(0)` empties the
// shared list first so nothing carries over into the next test.
afterEach(() => {
	for (const registration of registrations.splice(0)) {
		registration.unregister();
	}
});
|
||||
|
||||
describe("faux provider", () => {
|
||||
it("registers a custom provider and estimates usage", async () => {
|
||||
const registration = registerFauxProvider();
|
||||
registrations.push(registration);
|
||||
registration.setResponses([fauxAssistantMessage("hello world")]);
|
||||
|
||||
const context: Context = {
|
||||
systemPrompt: "Be concise.",
|
||||
messages: [{ role: "user", content: "hi there", timestamp: Date.now() }],
|
||||
};
|
||||
|
||||
const response = await complete(registration.getModel(), context);
|
||||
expect(response.content).toEqual([{ type: "text", text: "hello world" }]);
|
||||
expect(response.usage.input).toBeGreaterThan(0);
|
||||
expect(response.usage.output).toBeGreaterThan(0);
|
||||
expect(response.usage.totalTokens).toBe(response.usage.input + response.usage.output);
|
||||
expect(registration.state.callCount).toBe(1);
|
||||
});
|
||||
|
||||
it("supports helper blocks for text, thinking, and tool calls", async () => {
|
||||
const registration = registerFauxProvider();
|
||||
registrations.push(registration);
|
||||
registration.setResponses([
|
||||
fauxAssistantMessage([fauxThinking("think"), fauxToolCall("echo", { text: "hi" }), fauxText("done")], {
|
||||
stopReason: "toolUse",
|
||||
}),
|
||||
]);
|
||||
|
||||
const response = await complete(registration.getModel(), {
|
||||
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
|
||||
});
|
||||
|
||||
expect(response.content).toEqual([
|
||||
{ type: "thinking", thinking: "think" },
|
||||
{ type: "toolCall", id: expect.any(String), name: "echo", arguments: { text: "hi" } },
|
||||
{ type: "text", text: "done" },
|
||||
]);
|
||||
expect(response.stopReason).toBe("toolUse");
|
||||
});
|
||||
|
||||
it("supports multiple models with per-model reasoning and model-aware factories", async () => {
|
||||
const registration = registerFauxProvider({
|
||||
models: [
|
||||
{ id: "faux-fast", name: "Faux Fast", reasoning: false },
|
||||
{ id: "faux-thinker", name: "Faux Thinker", reasoning: true },
|
||||
],
|
||||
});
|
||||
registrations.push(registration);
|
||||
registration.setResponses([
|
||||
(_context, _options, _state, model) => fauxAssistantMessage(`${model.id}:${String(model.reasoning)}`),
|
||||
(_context, _options, _state, model) => fauxAssistantMessage(`${model.id}:${String(model.reasoning)}`),
|
||||
]);
|
||||
|
||||
expect(registration.models.map((model) => model.id)).toEqual(["faux-fast", "faux-thinker"]);
|
||||
expect(registration.getModel()).toBe(registration.models[0]);
|
||||
expect(registration.getModel("faux-fast")?.reasoning).toBe(false);
|
||||
expect(registration.getModel("faux-thinker")?.reasoning).toBe(true);
|
||||
|
||||
const fast = await complete(registration.getModel("faux-fast")!, {
|
||||
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
|
||||
});
|
||||
const thinker = await complete(registration.getModel("faux-thinker")!, {
|
||||
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
|
||||
});
|
||||
|
||||
expect(fast.content).toEqual([{ type: "text", text: "faux-fast:false" }]);
|
||||
expect(thinker.content).toEqual([{ type: "text", text: "faux-thinker:true" }]);
|
||||
});
|
||||
|
||||
it("rewrites api, provider, and model on returned messages", async () => {
|
||||
const registration = registerFauxProvider({
|
||||
api: "faux:test",
|
||||
provider: "faux-provider",
|
||||
models: [{ id: "faux-model" }],
|
||||
});
|
||||
registrations.push(registration);
|
||||
registration.setResponses([fauxAssistantMessage("hello")]);
|
||||
|
||||
const response = await complete(registration.getModel(), {
|
||||
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
|
||||
});
|
||||
|
||||
expect(response.api).toBe("faux:test");
|
||||
expect(response.provider).toBe("faux-provider");
|
||||
expect(response.model).toBe("faux-model");
|
||||
});
|
||||
|
||||
it("consumes queued responses in order and errors when exhausted", async () => {
|
||||
const registration = registerFauxProvider();
|
||||
registrations.push(registration);
|
||||
registration.setResponses([fauxAssistantMessage("first"), fauxAssistantMessage("second")]);
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
|
||||
};
|
||||
|
||||
const first = await complete(registration.getModel(), context);
|
||||
const second = await complete(registration.getModel(), context);
|
||||
const exhausted = await complete(registration.getModel(), context);
|
||||
|
||||
expect(first.content).toEqual([{ type: "text", text: "first" }]);
|
||||
expect(second.content).toEqual([{ type: "text", text: "second" }]);
|
||||
expect(exhausted.stopReason).toBe("error");
|
||||
expect(exhausted.errorMessage).toBe("No more faux responses queued");
|
||||
expect(registration.getPendingResponseCount()).toBe(0);
|
||||
expect(registration.state.callCount).toBe(3);
|
||||
});
|
||||
|
||||
it("can replace and append queued responses", async () => {
|
||||
const registration = registerFauxProvider();
|
||||
registrations.push(registration);
|
||||
registration.setResponses([fauxAssistantMessage("first")]);
|
||||
|
||||
const context: Context = {
|
||||
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
|
||||
};
|
||||
|
||||
expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "first" }]);
|
||||
expect(registration.getPendingResponseCount()).toBe(0);
|
||||
|
||||
registration.setResponses([fauxAssistantMessage("second")]);
|
||||
expect(registration.getPendingResponseCount()).toBe(1);
|
||||
expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "second" }]);
|
||||
|
||||
registration.appendResponses([fauxAssistantMessage("third"), fauxAssistantMessage("fourth")]);
|
||||
expect(registration.getPendingResponseCount()).toBe(2);
|
||||
expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "third" }]);
|
||||
expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "fourth" }]);
|
||||
expect(registration.getPendingResponseCount()).toBe(0);
|
||||
});
|
||||
|
||||
it("supports async response factories", async () => {
|
||||
const registration = registerFauxProvider();
|
||||
registrations.push(registration);
|
||||
registration.setResponses([
|
||||
async (context, _options, state) => fauxAssistantMessage(`${context.messages.length}:${state.callCount}`),
|
||||
]);
|
||||
|
||||
const response = await complete(registration.getModel(), {
|
||||
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
|
||||
});
|
||||
|
||||
expect(response.content).toEqual([{ type: "text", text: "1:1" }]);
|
||||
});
|
||||
|
||||
it("emits an error when a response factory throws", async () => {
|
||||
const registration = registerFauxProvider();
|
||||
registrations.push(registration);
|
||||
registration.setResponses([
|
||||
() => {
|
||||
throw new Error("boom");
|
||||
},
|
||||
]);
|
||||
|
||||
const events = await collectEvents(
|
||||
stream(registration.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }),
|
||||
);
|
||||
|
||||
expect(events).toHaveLength(1);
|
||||
expect(events[0].type).toBe("error");
|
||||
if (events[0].type === "error") {
|
||||
expect(events[0].error.stopReason).toBe("error");
|
||||
expect(events[0].error.errorMessage).toBe("boom");
|
||||
}
|
||||
});
|
||||
|
||||
// Usage numbers are expected to equal ceil(characters / 4) over a text rendering of the
// context (prompt side) and over the response text (output side), with no cache traffic.
it("estimates prompt and output tokens from serialized context", async () => {
	const faux = registerFauxProvider();
	registrations.push(faux);
	faux.setResponses([fauxAssistantMessage("done")]);

	const echoTool = {
		name: "echo",
		description: "Echo back text",
		parameters: Type.Object({ text: Type.String() }),
	};
	const context: Context = {
		systemPrompt: "sys",
		messages: [
			{
				role: "user",
				content: [
					{ type: "text", text: "hello" },
					{ type: "image", mimeType: "image/png", data: "abcd" },
				],
				timestamp: 1,
			},
			fauxAssistantMessage("prior"),
			{
				role: "toolResult",
				toolCallId: "tool-1",
				toolName: "echo",
				content: [{ type: "text", text: "tool out" }],
				isError: false,
				timestamp: 2,
			},
		],
		tools: [echoTool],
	};

	const { usage } = await complete(faux.getModel(), context);

	// One section per context part, in the order the serialized prompt is expected to have.
	const sections = [
		"system:sys",
		"user:hello\n[image:image/png:4]",
		"assistant:prior",
		"toolResult:echo\ntool out",
		`tools:${JSON.stringify([echoTool])}`,
	];
	const estimateTokens = (text: string) => Math.ceil(text.length / 4);
	const promptTokens = estimateTokens(sections.join("\n\n"));
	const outputTokens = estimateTokens("done");

	expect(usage.input).toBe(promptTokens);
	expect(usage.output).toBe(outputTokens);
	expect(usage.cacheRead).toBe(0);
	expect(usage.cacheWrite).toBe(0);
	expect(usage.totalTokens).toBe(promptTokens + outputTokens);
});
|
||||
|
||||
// Cache state written under one sessionId must not be read under a different sessionId,
// and a request made without any options must report no cache reads or writes.
it("does not share cache across sessions or requests without sessionId", async () => {
	const faux = registerFauxProvider();
	registrations.push(faux);
	faux.setResponses([fauxAssistantMessage("first"), fauxAssistantMessage("second"), fauxAssistantMessage("third")]);

	const conversation: Context = {
		messages: [{ role: "user", content: "hello", timestamp: Date.now() }],
	};

	// Request 1: session-1 writes to its cache.
	const viaSessionOne = await complete(faux.getModel(), conversation, {
		sessionId: "session-1",
		cacheRetention: "short",
	});
	expect(viaSessionOne.usage.cacheWrite).toBeGreaterThan(0);
	conversation.messages.push(viaSessionOne, { role: "user", content: "follow up", timestamp: Date.now() + 1 });

	// Request 2: a different session sees nothing cached and writes its own entry.
	const viaSessionTwo = await complete(faux.getModel(), conversation, {
		sessionId: "session-2",
		cacheRetention: "short",
	});
	expect(viaSessionTwo.usage.cacheRead).toBe(0);
	expect(viaSessionTwo.usage.cacheWrite).toBeGreaterThan(0);

	// Request 3: no options at all, so no cache reads or writes are expected.
	const withoutSession = await complete(faux.getModel(), conversation);
	expect(withoutSession.usage.cacheRead).toBe(0);
	expect(withoutSession.usage.cacheWrite).toBe(0);
});
|
||||
|
||||
// Two requests in the same session: the first only writes to the cache, the second reads from it.
it("simulates prompt caching per sessionId", async () => {
	const faux = registerFauxProvider();
	registrations.push(faux);
	faux.setResponses([fauxAssistantMessage("first"), fauxAssistantMessage("second")]);

	const conversation: Context = {
		systemPrompt: "Be concise.",
		messages: [{ role: "user", content: "hello", timestamp: Date.now() }],
	};

	const coldCall = await complete(faux.getModel(), conversation, {
		sessionId: "session-1",
		cacheRetention: "short",
	});
	expect(coldCall.usage.cacheRead).toBe(0);
	expect(coldCall.usage.cacheWrite).toBeGreaterThan(0);

	// Extend the conversation so the earlier turns form a shared prefix for the next request.
	conversation.messages.push(coldCall, { role: "user", content: "follow up", timestamp: Date.now() + 1 });

	const warmCall = await complete(faux.getModel(), conversation, {
		sessionId: "session-1",
		cacheRetention: "short",
	});
	const { input, cacheRead } = warmCall.usage;
	expect(cacheRead).toBeGreaterThan(0);
	// NOTE(review): this holds whenever cacheRead > 0, so it adds nothing beyond the line above.
	expect(input + cacheRead).toBeGreaterThan(input);
});
|
||||
|
||||
// With cacheRetention "none", a follow-up request in the same session must report no cache reads or writes.
it("does not simulate caching when cacheRetention is none", async () => {
	const faux = registerFauxProvider();
	registrations.push(faux);
	faux.setResponses([fauxAssistantMessage("first"), fauxAssistantMessage("second")]);

	const conversation: Context = {
		messages: [{ role: "user", content: "hello", timestamp: Date.now() }],
	};

	// The first reply is intentionally discarded; a fresh faux message stands in for it in the history.
	await complete(faux.getModel(), conversation, { sessionId: "session-1", cacheRetention: "none" });
	conversation.messages.push(fauxAssistantMessage("first"));
	conversation.messages.push({ role: "user", content: "follow up", timestamp: Date.now() + 1 });

	const followUp = await complete(faux.getModel(), conversation, {
		sessionId: "session-1",
		cacheRetention: "none",
	});
	const { cacheRead, cacheWrite } = followUp.usage;
	expect(cacheRead).toBe(0);
	expect(cacheWrite).toBe(0);
});
|
||||
|
||||
// A message with thinking, text, and a tool call must emit start/delta/end events for each,
// and the tool-call argument deltas must concatenate back into the original JSON arguments.
it("streams thinking, text, and partial tool call deltas", async () => {
	const faux = registerFauxProvider();
	registrations.push(faux);
	const blocks = [
		fauxThinking("thinking text"),
		fauxText("answer text"),
		fauxToolCall("echo", { text: "hi", count: 12 }, { id: "tool-1" }),
	];
	faux.setResponses([fauxAssistantMessage(blocks, { stopReason: "toolUse" })]);

	const seenTypes: string[] = [];
	const argumentChunks: string[] = [];
	const eventStream = stream(faux.getModel(), {
		messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
	});
	for await (const event of eventStream) {
		seenTypes.push(event.type);
		if (event.type === "toolcall_delta") {
			argumentChunks.push(event.delta);
		}
	}

	const requiredTypes = [
		"thinking_start",
		"thinking_delta",
		"text_start",
		"text_delta",
		"toolcall_start",
		"toolcall_delta",
		"toolcall_end",
	];
	for (const required of requiredTypes) {
		expect(seenTypes).toContain(required);
	}

	// More than one delta proves the arguments were actually streamed in pieces.
	expect(argumentChunks.length).toBeGreaterThan(1);
	const reassembled = JSON.parse(argumentChunks.join(""));
	expect(reassembled).toEqual({ text: "hi", count: 12 });
});
|
||||
|
||||
// With tokenSize pinned to min 1 / max 1, this short message is expected to produce
// exactly one delta per content block, so the whole event sequence can be asserted.
it("streams an exact event order for fixed-size chunks", async () => {
	const faux = registerFauxProvider({ tokenSize: { min: 1, max: 1 } });
	registrations.push(faux);
	const blocks = [fauxThinking("go"), fauxText("ok"), fauxToolCall("echo", {}, { id: "tool-1" })];
	faux.setResponses([fauxAssistantMessage(blocks, { stopReason: "toolUse" })]);

	const emitted = await collectEvents(
		stream(faux.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }),
	);

	const emittedTypes = emitted.map(({ type }) => type);
	expect(emittedTypes).toEqual([
		"start",
		"thinking_start",
		"thinking_delta",
		"thinking_end",
		"text_start",
		"text_delta",
		"text_end",
		"toolcall_start",
		"toolcall_delta",
		"toolcall_end",
		"done",
	]);
});
|
||||
|
||||
// Two tool calls in a single assistant message must each get their own start and end event.
it("streams multiple tool calls in one message", async () => {
	const faux = registerFauxProvider();
	registrations.push(faux);
	const toolCalls = [
		fauxToolCall("echo", { text: "one" }, { id: "tool-1" }),
		fauxToolCall("echo", { text: "two" }, { id: "tool-2" }),
	];
	faux.setResponses([fauxAssistantMessage(toolCalls, { stopReason: "toolUse" })]);

	const emitted = await collectEvents(
		stream(faux.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }),
	);

	const starts = emitted.filter(({ type }) => type === "toolcall_start");
	const ends = emitted.filter(({ type }) => type === "toolcall_end");
	expect(starts).toHaveLength(2);
	expect(ends).toHaveLength(2);
});
|
||||
|
||||
// A scripted message whose stopReason is "error" still streams its text content,
// then finishes with an "error" event instead of "done".
it("streams an explicit assistant error message as a terminal error", async () => {
	const faux = registerFauxProvider({ tokenSize: { min: 2, max: 2 } });
	registrations.push(faux);
	faux.setResponses([
		{
			...fauxAssistantMessage("partial"),
			stopReason: "error",
			errorMessage: "upstream failed",
		},
	]);

	const emitted = await collectEvents(
		stream(faux.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }),
	);

	const emittedTypes = emitted.map(({ type }) => type);
	expect(emittedTypes).toEqual(["start", "text_start", "text_delta", "text_end", "error"]);

	const last = emitted[emitted.length - 1];
	expect(last.type).toBe("error");
	// The type guard narrows the event union so the error payload can be inspected.
	if (last.type === "error") {
		expect(last.reason).toBe("error");
		const { stopReason, errorMessage } = last.error;
		expect(stopReason).toBe("error");
		expect(errorMessage).toBe("upstream failed");
	}
});
|
||||
|
||||
// A scripted message whose stopReason is "aborted" streams its text content and then
// terminates with an "error" event whose reason is "aborted".
it("streams an explicit assistant aborted message as a terminal error", async () => {
	const faux = registerFauxProvider({ tokenSize: { min: 2, max: 2 } });
	registrations.push(faux);
	faux.setResponses([
		{
			...fauxAssistantMessage("partial"),
			stopReason: "aborted",
			errorMessage: "Request was aborted",
		},
	]);

	const emitted = await collectEvents(
		stream(faux.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }),
	);

	const emittedTypes = emitted.map(({ type }) => type);
	expect(emittedTypes).toEqual(["start", "text_start", "text_delta", "text_end", "error"]);

	const last = emitted[emitted.length - 1];
	expect(last.type).toBe("error");
	// The type guard narrows the event union so the error payload can be inspected.
	if (last.type === "error") {
		expect(last.reason).toBe("aborted");
		const { stopReason, errorMessage } = last.error;
		expect(stopReason).toBe("aborted");
		expect(errorMessage).toBe("Request was aborted");
	}
});
|
||||
|
||||
// When the signal is already aborted at the time stream() is called, the only event
// emitted must be a terminal "error" with reason "aborted".
it("supports aborting before the first chunk", async () => {
	const faux = registerFauxProvider({ tokensPerSecond: 50, tokenSize: { min: 3, max: 3 } });
	registrations.push(faux);
	faux.setResponses([fauxAssistantMessage("abcdefghijklmnopqrstuvwxyz")]);

	// Abort up front so the stream starts with a dead signal.
	const preAborted = new AbortController();
	preAborted.abort();

	const eventStream = stream(
		faux.getModel(),
		{ messages: [{ role: "user", content: "hi", timestamp: Date.now() }] },
		{ signal: preAborted.signal },
	);
	const emitted = await collectEvents(eventStream);

	expect(emitted).toHaveLength(1);
	const [only] = emitted;
	expect(only.type).toBe("error");
	if (only.type === "error") {
		expect(only.reason).toBe("aborted");
		expect(only.error.stopReason).toBe("aborted");
	}
});
|
||||
|
||||
// Aborting on the first text delta of a paced stream must stop further deltas and
// end the stream with "error" before any "text_end" is emitted.
it("supports aborting mid-text stream when paced", async () => {
	const faux = registerFauxProvider({ tokensPerSecond: 100, tokenSize: { min: 3, max: 3 } });
	registrations.push(faux);
	faux.setResponses([fauxAssistantMessage("abcdefghijklmnopqrstuvwxyz")]);

	const abortController = new AbortController();
	const seenTypes: string[] = [];
	let deltasSeen = 0;
	const pacedStream = stream(
		faux.getModel(),
		{ messages: [{ role: "user", content: "hi", timestamp: Date.now() }] },
		{ signal: abortController.signal },
	);
	for await (const { type } of pacedStream) {
		seenTypes.push(type);
		if (type !== "text_delta") continue;
		// Cancel as soon as the first chunk of text arrives.
		deltasSeen += 1;
		abortController.abort();
	}

	expect(deltasSeen).toBe(1);
	for (const required of ["text_start", "text_delta", "error"]) {
		expect(seenTypes).toContain(required);
	}
	expect(seenTypes).not.toContain("text_end");
});
|
||||
|
||||
// Aborting on the first thinking delta of a paced stream must stop further deltas and
// end the stream with "error" before any "thinking_end" is emitted.
it("supports aborting mid-thinking stream when paced", async () => {
	const faux = registerFauxProvider({ tokensPerSecond: 100, tokenSize: { min: 3, max: 3 } });
	registrations.push(faux);
	faux.setResponses([
		{
			// Start from a faux message and swap its content for a single thinking block.
			...fauxAssistantMessage("ignored"),
			content: [{ type: "thinking", thinking: "abcdefghijklmnopqrstuvwxyz" }],
		},
	]);

	const abortController = new AbortController();
	const seenTypes: string[] = [];
	let deltasSeen = 0;
	const pacedStream = stream(
		faux.getModel(),
		{ messages: [{ role: "user", content: "hi", timestamp: Date.now() }] },
		{ signal: abortController.signal },
	);
	for await (const { type } of pacedStream) {
		seenTypes.push(type);
		if (type !== "thinking_delta") continue;
		// Cancel as soon as the first chunk of thinking arrives.
		deltasSeen += 1;
		abortController.abort();
	}

	expect(deltasSeen).toBe(1);
	for (const required of ["thinking_start", "thinking_delta", "error"]) {
		expect(seenTypes).toContain(required);
	}
	expect(seenTypes).not.toContain("thinking_end");
});
|
||||
|
||||
// Aborting on the first tool-call delta of a paced stream must stop further deltas and
// end the stream with "error" before any "toolcall_end" is emitted.
it("supports aborting mid-toolcall stream when paced", async () => {
	const faux = registerFauxProvider({ tokensPerSecond: 100, tokenSize: { min: 3, max: 3 } });
	registrations.push(faux);
	faux.setResponses([
		{
			// Start from a faux message and swap its content for a single tool call.
			...fauxAssistantMessage("done"),
			content: [
				{
					type: "toolCall",
					id: "tool-1",
					name: "echo",
					arguments: { text: "abcdefghijklmnopqrstuvwxyz", count: 123456789 },
				},
			],
			stopReason: "toolUse",
		},
	]);

	const abortController = new AbortController();
	const seenTypes: string[] = [];
	let deltasSeen = 0;
	const pacedStream = stream(
		faux.getModel(),
		{ messages: [{ role: "user", content: "hi", timestamp: Date.now() }] },
		{ signal: abortController.signal },
	);
	for await (const { type } of pacedStream) {
		seenTypes.push(type);
		if (type !== "toolcall_delta") continue;
		// Cancel as soon as the first chunk of arguments arrives.
		deltasSeen += 1;
		abortController.abort();
	}

	expect(deltasSeen).toBe(1);
	for (const required of ["toolcall_start", "toolcall_delta", "error"]) {
		expect(seenTypes).toContain(required);
	}
	expect(seenTypes).not.toContain("toolcall_end");
});
|
||||
|
||||
// After unregister(), completing against the registration's model must reject because
// no provider is registered for its api. The registration is not tracked in `registrations`
// because the test unregisters it itself.
it("unregisters the provider", async () => {
	const faux = registerFauxProvider();
	faux.setResponses([fauxAssistantMessage("hello")]);
	faux.unregister();

	const attempt = complete(faux.getModel(), {
		messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
	});
	await expect(attempt).rejects.toThrow(`No API provider registered for api: ${faux.api}`);
});
|
||||
});
|
||||
@@ -395,7 +395,7 @@ import { AuthStorage, createAgentSession, ModelRegistry, SessionManager } from "
|
||||
const { session } = await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
authStorage: AuthStorage.create(),
|
||||
modelRegistry: new ModelRegistry(authStorage),
|
||||
modelRegistry: ModelRegistry.create(authStorage),
|
||||
});
|
||||
|
||||
await session.prompt("What files are in the current directory?");
|
||||
|
||||
@@ -20,7 +20,7 @@ import { AuthStorage, createAgentSession, ModelRegistry, SessionManager } from "
|
||||
|
||||
// Set up credential storage and model registry
|
||||
const authStorage = AuthStorage.create();
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
const { session } = await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
@@ -289,7 +289,7 @@ import { getModel } from "@mariozechner/pi-ai";
|
||||
import { AuthStorage, ModelRegistry } from "@mariozechner/pi-coding-agent";
|
||||
|
||||
const authStorage = AuthStorage.create();
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
// Find specific built-in model (doesn't check if API key exists)
|
||||
const opus = getModel("anthropic", "claude-opus-4-5");
|
||||
@@ -337,7 +337,7 @@ import { AuthStorage, ModelRegistry } from "@mariozechner/pi-coding-agent";
|
||||
|
||||
// Default: uses ~/.pi/agent/auth.json and ~/.pi/agent/models.json
|
||||
const authStorage = AuthStorage.create();
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
const { session } = await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
@@ -350,7 +350,7 @@ authStorage.setRuntimeApiKey("anthropic", "sk-my-temp-key");
|
||||
|
||||
// Custom auth storage location
|
||||
const customAuth = AuthStorage.create("/my/app/auth.json");
|
||||
const customRegistry = new ModelRegistry(customAuth, "/my/app/models.json");
|
||||
const customRegistry = ModelRegistry.create(customAuth, "/my/app/models.json");
|
||||
|
||||
const { session } = await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
@@ -359,7 +359,7 @@ const { session } = await createAgentSession({
|
||||
});
|
||||
|
||||
// No custom models.json (built-in models only)
|
||||
const simpleRegistry = new ModelRegistry(authStorage);
|
||||
const simpleRegistry = ModelRegistry.inMemory(authStorage);
|
||||
```
|
||||
|
||||
> See [examples/sdk/09-api-keys-and-oauth.ts](../examples/sdk/09-api-keys-and-oauth.ts)
|
||||
@@ -788,7 +788,7 @@ if (process.env.MY_KEY) {
|
||||
}
|
||||
|
||||
// Model registry (no custom models.json)
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.inMemory(authStorage);
|
||||
|
||||
// Inline tool
|
||||
const statusTool: ToolDefinition = {
|
||||
|
||||
@@ -9,7 +9,7 @@ import { AuthStorage, createAgentSession, ModelRegistry } from "@mariozechner/pi
|
||||
|
||||
// Set up auth storage and model registry
|
||||
const authStorage = AuthStorage.create();
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
// Option 1: Find a specific built-in model by provider/id
|
||||
const opus = getModel("anthropic", "claude-opus-4-5");
|
||||
|
||||
@@ -9,7 +9,7 @@ import { AuthStorage, createAgentSession, ModelRegistry, SessionManager } from "
|
||||
// Default: AuthStorage uses ~/.pi/agent/auth.json
|
||||
// ModelRegistry loads built-in + custom models from ~/.pi/agent/models.json
|
||||
const authStorage = AuthStorage.create();
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
@@ -20,7 +20,7 @@ console.log("Session with default auth storage and model registry");
|
||||
|
||||
// Custom auth storage location
|
||||
const customAuthStorage = AuthStorage.create("/tmp/my-app/auth.json");
|
||||
const customModelRegistry = new ModelRegistry(customAuthStorage, "/tmp/my-app/models.json");
|
||||
const customModelRegistry = ModelRegistry.create(customAuthStorage, "/tmp/my-app/models.json");
|
||||
|
||||
await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
@@ -39,7 +39,7 @@ await createAgentSession({
|
||||
console.log("Session with runtime API key override");
|
||||
|
||||
// No models.json - only built-in models
|
||||
const simpleRegistry = new ModelRegistry(authStorage); // null = no models.json
|
||||
const simpleRegistry = ModelRegistry.inMemory(authStorage);
|
||||
await createAgentSession({
|
||||
sessionManager: SessionManager.inMemory(),
|
||||
authStorage,
|
||||
|
||||
@@ -30,7 +30,7 @@ if (process.env.MY_ANTHROPIC_KEY) {
|
||||
}
|
||||
|
||||
// Model registry with no custom models.json
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.inMemory(authStorage);
|
||||
|
||||
const model = getModel("anthropic", "claude-sonnet-4-20250514");
|
||||
if (!model) throw new Error("Model not found");
|
||||
|
||||
@@ -44,7 +44,7 @@ import {
|
||||
|
||||
// Auth and models setup
|
||||
const authStorage = AuthStorage.create();
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
// Minimal
|
||||
const { session } = await createAgentSession({ authStorage, modelRegistry });
|
||||
@@ -73,7 +73,7 @@ const { session } = await createAgentSession({
|
||||
// Full control
|
||||
const customAuth = AuthStorage.create("/my/app/auth.json");
|
||||
customAuth.setRuntimeApiKey("anthropic", process.env.MY_KEY!);
|
||||
const customRegistry = new ModelRegistry(customAuth);
|
||||
const customRegistry = ModelRegistry.create(customAuth);
|
||||
|
||||
const resourceLoader = new DefaultResourceLoader({
|
||||
systemPromptOverride: () => "You are helpful.",
|
||||
@@ -109,7 +109,7 @@ await session.prompt("Hello");
|
||||
| Option | Default | Description |
|
||||
|--------|---------|-------------|
|
||||
| `authStorage` | `AuthStorage.create()` | Credential storage |
|
||||
| `modelRegistry` | `new ModelRegistry(authStorage)` | Model registry |
|
||||
| `modelRegistry` | `ModelRegistry.create(authStorage)` | Model registry |
|
||||
| `cwd` | `process.cwd()` | Working directory |
|
||||
| `agentDir` | `~/.pi/agent` | Config directory |
|
||||
| `model` | From settings/first available | Model to use |
|
||||
|
||||
@@ -259,13 +259,21 @@ export class ModelRegistry {
|
||||
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
|
||||
private loadError: string | undefined = undefined;
|
||||
|
||||
constructor(
|
||||
private constructor(
|
||||
readonly authStorage: AuthStorage,
|
||||
private modelsJsonPath: string | undefined = join(getAgentDir(), "models.json"),
|
||||
private modelsJsonPath: string | undefined,
|
||||
) {
|
||||
this.loadModels();
|
||||
}
|
||||
|
||||
/**
 * Factory for a registry bound to `authStorage`. Construction immediately loads models.
 *
 * @param authStorage Credential storage handed to the registry.
 * @param modelsJsonPath Path passed to the constructor; when omitted it defaults to
 *   `models.json` inside the directory returned by `getAgentDir()`.
 * @returns The newly constructed registry.
 */
static create(authStorage: AuthStorage, modelsJsonPath: string = join(getAgentDir(), "models.json")): ModelRegistry {
	const registry = new ModelRegistry(authStorage, modelsJsonPath);
	return registry;
}
|
||||
|
||||
/**
 * Factory for a registry bound to `authStorage` that is constructed without a
 * models.json path (the path argument is `undefined`).
 *
 * @param authStorage Credential storage handed to the registry.
 * @returns The newly constructed registry.
 */
static inMemory(authStorage: AuthStorage): ModelRegistry {
	const noModelsJsonPath = undefined;
	return new ModelRegistry(authStorage, noModelsJsonPath);
}
|
||||
|
||||
/**
|
||||
* Reload models from disk (built-in + custom from models.json).
|
||||
*/
|
||||
|
||||
@@ -47,7 +47,7 @@ export interface CreateAgentSessionOptions {
|
||||
|
||||
/** Auth storage for credentials. Default: AuthStorage.create(agentDir/auth.json) */
|
||||
authStorage?: AuthStorage;
|
||||
/** Model registry. Default: new ModelRegistry(authStorage, agentDir/models.json) */
|
||||
/** Model registry. Default: ModelRegistry.create(authStorage, agentDir/models.json) */
|
||||
modelRegistry?: ModelRegistry;
|
||||
|
||||
/** Model to use. Default: from settings, else first available */
|
||||
@@ -172,7 +172,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
||||
const authPath = options.agentDir ? join(agentDir, "auth.json") : undefined;
|
||||
const modelsPath = options.agentDir ? join(agentDir, "models.json") : undefined;
|
||||
const authStorage = options.authStorage ?? AuthStorage.create(authPath);
|
||||
const modelRegistry = options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath);
|
||||
const modelRegistry = options.modelRegistry ?? ModelRegistry.create(authStorage, modelsPath);
|
||||
|
||||
const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir);
|
||||
const sessionManager = options.sessionManager ?? SessionManager.create(cwd, getDefaultSessionDir(cwd, agentDir));
|
||||
|
||||
@@ -654,7 +654,7 @@ export async function main(args: string[]) {
|
||||
const settingsManager = SettingsManager.create(cwd, agentDir);
|
||||
reportSettingsErrors(settingsManager, "startup");
|
||||
const authStorage = AuthStorage.create();
|
||||
const modelRegistry = new ModelRegistry(authStorage, getModelsPath());
|
||||
const modelRegistry = ModelRegistry.create(authStorage, getModelsPath());
|
||||
|
||||
const resourceLoader = new DefaultResourceLoader({
|
||||
cwd,
|
||||
|
||||
@@ -76,7 +76,7 @@ describe("AgentSession auto-compaction queue resume", () => {
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
authStorage.setRuntimeApiKey("anthropic", "test-key");
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
|
||||
session = new AgentSession({
|
||||
agent,
|
||||
|
||||
@@ -55,7 +55,7 @@ describe.skipIf(!API_KEY)("AgentSession forking", () => {
|
||||
sessionManager = noSession ? SessionManager.inMemory() : SessionManager.create(tempDir);
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
|
||||
session = new AgentSession({
|
||||
agent,
|
||||
|
||||
@@ -61,7 +61,7 @@ describe.skipIf(!API_KEY)("AgentSession compaction e2e", () => {
|
||||
// Use minimal keepRecentTokens so small test conversations have something to summarize
|
||||
settingsManager.applyOverrides({ compaction: { keepRecentTokens: 1 } });
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
session = new AgentSession({
|
||||
agent,
|
||||
|
||||
@@ -110,7 +110,7 @@ describe("AgentSession concurrent prompt guard", () => {
|
||||
const sessionManager = SessionManager.inMemory();
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
// Set a runtime API key so validation passes
|
||||
authStorage.setRuntimeApiKey("anthropic", "test-key");
|
||||
|
||||
@@ -235,7 +235,7 @@ describe("AgentSession concurrent prompt guard", () => {
|
||||
const sessionManager = SessionManager.inMemory();
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
authStorage.setRuntimeApiKey("anthropic", "test-key");
|
||||
|
||||
const extensionsResult = await createTestExtensionsResult([
|
||||
@@ -313,7 +313,7 @@ describe("AgentSession concurrent prompt guard", () => {
|
||||
const sessionManager = SessionManager.inMemory();
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
authStorage.setRuntimeApiKey("anthropic", "test-key");
|
||||
|
||||
session = new AgentSession({
|
||||
@@ -419,7 +419,7 @@ describe("AgentSession concurrent prompt guard", () => {
|
||||
const sessionManager = SessionManager.inMemory();
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
authStorage.setRuntimeApiKey("anthropic", "test-key");
|
||||
|
||||
session = new AgentSession({
|
||||
@@ -556,7 +556,7 @@ describe("AgentSession concurrent prompt guard", () => {
|
||||
const sessionManager = SessionManager.inMemory();
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
authStorage.setRuntimeApiKey("anthropic", "test-key");
|
||||
|
||||
session = new AgentSession({
|
||||
|
||||
@@ -37,7 +37,7 @@ function createSession({
|
||||
sessionManager,
|
||||
settingsManager,
|
||||
cwd: process.cwd(),
|
||||
modelRegistry: new ModelRegistry(authStorage, undefined),
|
||||
modelRegistry: ModelRegistry.inMemory(authStorage),
|
||||
resourceLoader: createTestResourceLoader(),
|
||||
scopedModels,
|
||||
});
|
||||
|
||||
@@ -102,7 +102,7 @@ describe("AgentSession retry", () => {
|
||||
const sessionManager = SessionManager.inMemory();
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
authStorage.setRuntimeApiKey("anthropic", "test-key");
|
||||
settingsManager.applyOverrides({ retry: { enabled: true, maxRetries, baseDelayMs: 1 } });
|
||||
|
||||
@@ -204,7 +204,7 @@ describe("AgentSession retry", () => {
|
||||
const sessionManager = SessionManager.inMemory();
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
authStorage.setRuntimeApiKey("anthropic", "test-key");
|
||||
settingsManager.applyOverrides({ retry: { enabled: true, maxRetries: 3, baseDelayMs: 1 } });
|
||||
session = new AgentSession({
|
||||
@@ -289,7 +289,7 @@ describe("AgentSession retry", () => {
|
||||
const sessionManager = SessionManager.inMemory();
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
authStorage.setRuntimeApiKey("anthropic", "test-key");
|
||||
settingsManager.applyOverrides({ retry: { enabled: true, maxRetries: 3, baseDelayMs: 1 } });
|
||||
|
||||
|
||||
@@ -66,7 +66,7 @@ function createSession() {
|
||||
sessionManager,
|
||||
settingsManager,
|
||||
cwd: process.cwd(),
|
||||
modelRegistry: new ModelRegistry(authStorage, undefined),
|
||||
modelRegistry: ModelRegistry.inMemory(authStorage),
|
||||
resourceLoader: createTestResourceLoader(),
|
||||
});
|
||||
|
||||
|
||||
@@ -99,7 +99,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => {
|
||||
const sessionManager = SessionManager.create(tempDir);
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
const runtime = createExtensionRuntime();
|
||||
const resourceLoader = {
|
||||
|
||||
@@ -81,7 +81,7 @@ describe.skipIf(!HAS_ANTIGRAVITY_AUTH)("Compaction with thinking models (Antigra
|
||||
// settingsManager.applyOverrides({ compaction: { keepRecentTokens: 1 } });
|
||||
|
||||
const authStorage = getRealAuthStorage();
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
session = new AgentSession({
|
||||
agent,
|
||||
@@ -177,7 +177,7 @@ describe.skipIf(!HAS_ANTHROPIC_AUTH)("Compaction with thinking models (Anthropic
|
||||
const settingsManager = SettingsManager.create(tempDir, tempDir);
|
||||
|
||||
const authStorage = getRealAuthStorage();
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
session = new AgentSession({
|
||||
agent,
|
||||
|
||||
@@ -29,7 +29,7 @@ describe("Input Event", () => {
|
||||
for (let i = 0; i < extensions.length; i++) fs.writeFileSync(path.join(extensionsDir, `e${i}.ts`), extensions[i]);
|
||||
const result = await discoverAndLoadExtensions([], tempDir, tempDir);
|
||||
const sm = SessionManager.inMemory();
|
||||
const mr = new ModelRegistry(AuthStorage.create(path.join(tempDir, "auth.json")));
|
||||
const mr = ModelRegistry.create(AuthStorage.create(path.join(tempDir, "auth.json")));
|
||||
return new ExtensionRunner(result.extensions, result.runtime, tempDir, sm, mr);
|
||||
}
|
||||
|
||||
|
||||
@@ -27,7 +27,7 @@ describe("ExtensionRunner", () => {
|
||||
fs.mkdirSync(extensionsDir);
|
||||
sessionManager = SessionManager.inMemory();
|
||||
const authStorage = AuthStorage.create(path.join(tempDir, "auth.json"));
|
||||
modelRegistry = new ModelRegistry(authStorage);
|
||||
modelRegistry = ModelRegistry.create(authStorage);
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
|
||||
@@ -94,7 +94,7 @@ describe("ModelRegistry", () => {
|
||||
anthropic: overrideConfig("https://my-proxy.example.com/v1"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||
|
||||
// Should have multiple built-in models, not just one
|
||||
@@ -107,7 +107,7 @@ describe("ModelRegistry", () => {
|
||||
anthropic: overrideConfig("https://my-proxy.example.com/v1"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||
|
||||
// All models should have the new baseUrl
|
||||
@@ -123,7 +123,7 @@ describe("ModelRegistry", () => {
|
||||
}),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||
|
||||
for (const model of anthropicModels) {
|
||||
@@ -140,7 +140,7 @@ describe("ModelRegistry", () => {
|
||||
anthropic: overrideConfig("https://my-proxy.example.com/v1"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const googleModels = getModelsForProvider(registry, "google");
|
||||
|
||||
// Google models should still have their original baseUrl
|
||||
@@ -160,7 +160,7 @@ describe("ModelRegistry", () => {
|
||||
),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
|
||||
// Anthropic: multiple built-in models with new baseUrl
|
||||
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||
@@ -177,7 +177,7 @@ describe("ModelRegistry", () => {
|
||||
writeRawModelsJson({
|
||||
anthropic: overrideConfig("https://first-proxy.example.com/v1"),
|
||||
});
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
|
||||
expect(getModelsForProvider(registry, "anthropic")[0].baseUrl).toBe("https://first-proxy.example.com/v1");
|
||||
|
||||
@@ -197,7 +197,7 @@ describe("ModelRegistry", () => {
|
||||
anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||
|
||||
expect(anthropicModels.length).toBeGreaterThan(1);
|
||||
@@ -214,7 +214,7 @@ describe("ModelRegistry", () => {
|
||||
),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const models = getModelsForProvider(registry, "openrouter");
|
||||
const sonnetModels = models.filter((m) => m.id === "anthropic/claude-sonnet-4");
|
||||
|
||||
@@ -227,7 +227,7 @@ describe("ModelRegistry", () => {
|
||||
anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
|
||||
expect(getModelsForProvider(registry, "google").length).toBeGreaterThan(0);
|
||||
expect(getModelsForProvider(registry, "openai").length).toBeGreaterThan(0);
|
||||
@@ -238,7 +238,7 @@ describe("ModelRegistry", () => {
|
||||
anthropic: providerConfig("https://merged-proxy.example.com/v1", [{ id: "claude-custom" }]),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const anthropicModels = getModelsForProvider(registry, "anthropic");
|
||||
|
||||
for (const model of anthropicModels) {
|
||||
@@ -269,7 +269,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const compat = registry.find("demo", "demo-model")?.compat as OpenAICompletionsCompat | undefined;
|
||||
|
||||
expect(compat?.supportsUsageInStreaming).toBe(false);
|
||||
@@ -303,7 +303,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const compat = registry.find("demo", "demo-model")?.compat as OpenAICompletionsCompat | undefined;
|
||||
|
||||
expect(compat?.supportsUsageInStreaming).toBe(true);
|
||||
@@ -320,7 +320,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const models = getModelsForProvider(registry, "openrouter");
|
||||
|
||||
expect(models.length).toBeGreaterThan(0);
|
||||
@@ -357,7 +357,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const compat = registry.find("demo", "demo-model")?.compat as OpenAICompletionsCompat | undefined;
|
||||
|
||||
expect(registry.getError()).toBeUndefined();
|
||||
@@ -394,7 +394,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const m25 = registry.find("opencode-go", "minimax-m2.5");
|
||||
const glm5 = registry.find("opencode-go", "glm-5");
|
||||
|
||||
@@ -427,7 +427,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const models = getModelsForProvider(registry, "openrouter");
|
||||
|
||||
expect(models.some((m) => m.id === "custom/openrouter-model")).toBe(true);
|
||||
@@ -440,7 +440,7 @@ describe("ModelRegistry", () => {
|
||||
writeModelsJson({
|
||||
anthropic: providerConfig("https://first-proxy.example.com/v1", [{ id: "claude-custom" }]),
|
||||
});
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
expect(getModelsForProvider(registry, "anthropic").some((m) => m.id === "claude-custom")).toBe(true);
|
||||
|
||||
// Update and refresh
|
||||
@@ -459,7 +459,7 @@ describe("ModelRegistry", () => {
|
||||
writeModelsJson({
|
||||
anthropic: providerConfig("https://proxy.example.com/v1", [{ id: "claude-custom" }]),
|
||||
});
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
expect(getModelsForProvider(registry, "anthropic").some((m) => m.id === "claude-custom")).toBe(true);
|
||||
|
||||
// Remove custom models and refresh
|
||||
@@ -484,7 +484,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const models = getModelsForProvider(registry, "openrouter");
|
||||
|
||||
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
|
||||
@@ -508,7 +508,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const models = getModelsForProvider(registry, "openrouter");
|
||||
|
||||
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
|
||||
@@ -529,7 +529,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const models = getModelsForProvider(registry, "openrouter");
|
||||
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
|
||||
|
||||
@@ -552,7 +552,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const models = getModelsForProvider(registry, "openrouter");
|
||||
|
||||
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
|
||||
@@ -576,7 +576,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const models = getModelsForProvider(registry, "openrouter");
|
||||
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
|
||||
|
||||
@@ -601,7 +601,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const models = getModelsForProvider(registry, "openrouter");
|
||||
|
||||
// Should not create a new model
|
||||
@@ -621,7 +621,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const models = getModelsForProvider(registry, "openrouter");
|
||||
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
|
||||
|
||||
@@ -642,7 +642,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const models = getModelsForProvider(registry, "openrouter");
|
||||
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
|
||||
expect(sonnet).toBeDefined();
|
||||
@@ -665,7 +665,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
expect(
|
||||
getModelsForProvider(registry, "openrouter").find((m) => m.id === "anthropic/claude-sonnet-4")?.name,
|
||||
).toBe("First Name");
|
||||
@@ -698,7 +698,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const customName = getModelsForProvider(registry, "openrouter").find(
|
||||
(m) => m.id === "anthropic/claude-sonnet-4",
|
||||
)?.name;
|
||||
@@ -717,7 +717,7 @@ describe("ModelRegistry", () => {
|
||||
|
||||
describe("dynamic provider lifecycle", () => {
|
||||
test("failed registerProvider does not persist invalid streamSimple config", () => {
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
|
||||
expect(() =>
|
||||
registry.registerProvider("broken-provider", {
|
||||
@@ -731,7 +731,7 @@ describe("ModelRegistry", () => {
|
||||
});
|
||||
|
||||
test("failed registerProvider does not remove existing provider models", () => {
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
|
||||
registry.registerProvider("demo-provider", {
|
||||
baseUrl: "https://provider.test/v1",
|
||||
@@ -776,7 +776,7 @@ describe("ModelRegistry", () => {
|
||||
});
|
||||
|
||||
test("unregisterProvider removes custom OAuth provider and restores built-in OAuth provider", () => {
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
|
||||
registry.registerProvider("anthropic", {
|
||||
oauth: {
|
||||
@@ -799,7 +799,7 @@ describe("ModelRegistry", () => {
|
||||
});
|
||||
|
||||
test("unregisterProvider removes custom streamSimple override and restores built-in API stream handler", () => {
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
|
||||
registry.registerProvider("stream-override-provider", {
|
||||
api: "openai-completions",
|
||||
@@ -855,7 +855,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey("!echo test-api-key-from-command"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const apiKey = await registry.getApiKeyForProvider("custom-provider");
|
||||
|
||||
expect(apiKey).toBe("test-api-key-from-command");
|
||||
@@ -866,7 +866,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey("!echo ' spaced-key '"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const apiKey = await registry.getApiKeyForProvider("custom-provider");
|
||||
|
||||
expect(apiKey).toBe("spaced-key");
|
||||
@@ -877,7 +877,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey("!printf 'line1\\nline2'"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const apiKey = await registry.getApiKeyForProvider("custom-provider");
|
||||
|
||||
expect(apiKey).toBe("line1\nline2");
|
||||
@@ -888,7 +888,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey("!exit 1"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const apiKey = await registry.getApiKeyForProvider("custom-provider");
|
||||
|
||||
expect(apiKey).toBeUndefined();
|
||||
@@ -899,7 +899,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey("!nonexistent-command-12345"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const apiKey = await registry.getApiKeyForProvider("custom-provider");
|
||||
|
||||
expect(apiKey).toBeUndefined();
|
||||
@@ -910,7 +910,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey("!printf ''"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const apiKey = await registry.getApiKeyForProvider("custom-provider");
|
||||
|
||||
expect(apiKey).toBeUndefined();
|
||||
@@ -925,7 +925,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey("TEST_API_KEY_12345"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const apiKey = await registry.getApiKeyForProvider("custom-provider");
|
||||
|
||||
expect(apiKey).toBe("env-api-key-value");
|
||||
@@ -946,7 +946,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey("literal_api_key_value"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const apiKey = await registry.getApiKeyForProvider("custom-provider");
|
||||
|
||||
expect(apiKey).toBe("literal_api_key_value");
|
||||
@@ -957,7 +957,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey("!echo 'hello world' | tr ' ' '-'"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const apiKey = await registry.getApiKeyForProvider("custom-provider");
|
||||
|
||||
expect(apiKey).toBe("hello-world");
|
||||
@@ -974,7 +974,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey(command),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
await registry.getApiKeyForProvider("custom-provider");
|
||||
await registry.getApiKeyForProvider("custom-provider");
|
||||
await registry.getApiKeyForProvider("custom-provider");
|
||||
@@ -993,10 +993,10 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey(command),
|
||||
});
|
||||
|
||||
const registry1 = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry1 = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
await registry1.getApiKeyForProvider("custom-provider");
|
||||
|
||||
const registry2 = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry2 = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
await registry2.getApiKeyForProvider("custom-provider");
|
||||
|
||||
const count = parseInt(readFileSync(counterFile, "utf-8").trim(), 10);
|
||||
@@ -1009,7 +1009,7 @@ describe("ModelRegistry", () => {
|
||||
"provider-b": providerWithApiKey("!echo key-b"),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
|
||||
const keyA = await registry.getApiKeyForProvider("provider-a");
|
||||
const keyB = await registry.getApiKeyForProvider("provider-b");
|
||||
@@ -1028,7 +1028,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey(command),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const key1 = await registry.getApiKeyForProvider("custom-provider");
|
||||
const key2 = await registry.getApiKeyForProvider("custom-provider");
|
||||
|
||||
@@ -1050,7 +1050,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey(envVarName),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
|
||||
const key1 = await registry.getApiKeyForProvider("custom-provider");
|
||||
expect(key1).toBe("first-value");
|
||||
@@ -1078,7 +1078,7 @@ describe("ModelRegistry", () => {
|
||||
"custom-provider": providerWithApiKey(command),
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const available = registry.getAvailable();
|
||||
|
||||
expect(available.some((m) => m.provider === "custom-provider")).toBe(true);
|
||||
@@ -1098,7 +1098,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const model = registry.find("custom-provider", "test-model");
|
||||
expect(model).toBeDefined();
|
||||
|
||||
@@ -1127,7 +1127,7 @@ describe("ModelRegistry", () => {
|
||||
},
|
||||
});
|
||||
|
||||
const registry = new ModelRegistry(authStorage, modelsJsonPath);
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const model = registry.find("custom-provider", "test-model");
|
||||
expect(model).toBeDefined();
|
||||
|
||||
|
||||
@@ -199,7 +199,7 @@ Project skill`,
|
||||
|
||||
const sessionManager = SessionManager.inMemory();
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
const runner = new ExtensionRunner(
|
||||
extensionsResult.extensions,
|
||||
extensionsResult.runtime,
|
||||
@@ -548,7 +548,7 @@ export default function(pi: ExtensionAPI) {
|
||||
|
||||
const sessionManager = SessionManager.inMemory();
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth-explicit.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
const runner = new ExtensionRunner(
|
||||
extensionsResult.extensions,
|
||||
extensionsResult.runtime,
|
||||
|
||||
@@ -204,7 +204,7 @@ async function main(): Promise<void> {
|
||||
mkdirSync(dirname(args.sessionPath), { recursive: true });
|
||||
|
||||
const authStorage = AuthStorage.create();
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
const model = getModel("openai-codex", "gpt-5.4");
|
||||
if (!model) {
|
||||
|
||||
@@ -0,0 +1,141 @@
|
||||
/**
|
||||
* Local test harness for the new coding-agent test suite.
|
||||
*/
|
||||
|
||||
import { existsSync, mkdirSync, rmSync } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import type { AgentTool } from "@mariozechner/pi-agent-core";
|
||||
import { Agent } from "@mariozechner/pi-agent-core";
|
||||
import type { FauxModelDefinition, FauxProviderRegistration, FauxResponseStep, Model } from "@mariozechner/pi-ai";
|
||||
import { registerFauxProvider } from "@mariozechner/pi-ai";
|
||||
import { AgentSession, type AgentSessionEvent } from "../../src/core/agent-session.js";
|
||||
import { AuthStorage } from "../../src/core/auth-storage.js";
|
||||
import { ModelRegistry } from "../../src/core/model-registry.js";
|
||||
import { SessionManager } from "../../src/core/session-manager.js";
|
||||
import type { Settings } from "../../src/core/settings-manager.js";
|
||||
import { SettingsManager } from "../../src/core/settings-manager.js";
|
||||
import type { ExtensionFactory, ResourceLoader } from "../../src/index.js";
|
||||
import {
|
||||
type CreateTestExtensionsResultInput,
|
||||
createTestExtensionsResult,
|
||||
createTestResourceLoader,
|
||||
} from "../utilities.js";
|
||||
|
||||
/**
 * Optional knobs accepted by `createHarness`. Every field may be omitted.
 */
export interface HarnessOptions {
	/** Model definitions forwarded to `registerFauxProvider`. */
	models?: FauxModelDefinition[];

	/** Settings handed to `SettingsManager.inMemory`. */
	settings?: Partial<Settings>;

	/** System prompt for the agent; the harness falls back to "You are a test assistant." */
	systemPrompt?: string;

	/** Tools keyed by their `name` and passed to the session as its base tool override. */
	tools?: AgentTool[];

	/** Resource loader to use instead of the one built by `createTestResourceLoader`. */
	resourceLoader?: ResourceLoader;

	/** Extension inputs resolved through `createTestExtensionsResult` against the harness temp dir. */
	extensionFactories?: (ExtensionFactory | CreateTestExtensionsResultInput)[];
}
|
||||
|
||||
/**
 * Everything `createHarness` hands back to a test: the session under test, its
 * collaborators, the faux provider controls and the recorded event log.
 */
export interface Harness {
	/** The agent session under test. */
	session: AgentSession;

	/** In-memory session manager backing {@link Harness.session}. */
	sessionManager: SessionManager;

	/** In-memory settings manager backing {@link Harness.session}. */
	settingsManager: SettingsManager;

	/** The faux provider registration created for this harness. */
	faux: FauxProviderRegistration;

	/** Models exposed by the faux provider (typed as a non-empty tuple). */
	models: [Model<string>, ...Model<string>[]];

	/** Without an argument, returns the faux provider's default model. */
	getModel(): Model<string>;

	/** With a model id, returns the matching model or `undefined`. */
	getModel(modelId: string): Model<string> | undefined;

	/** Forwards to the faux provider's `setResponses`. */
	setResponses: (responses: FauxResponseStep[]) => void;

	/** Forwards to the faux provider's `appendResponses`. */
	appendResponses: (responses: FauxResponseStep[]) => void;

	/** Forwards to the faux provider's `getPendingResponseCount`. */
	getPendingResponseCount: () => number;

	/** Every event emitted by the session since the harness was created, in arrival order. */
	events: AgentSessionEvent[];

	/** Returns the recorded events whose `type` equals the given value, narrowed to that variant. */
	eventsOfType<T extends AgentSessionEvent["type"]>(type: T): Extract<AgentSessionEvent, { type: T }>[];

	/** Temporary directory created for this harness; also used as the session `cwd`. */
	tempDir: string;

	/** Disposes the session, unregisters the faux provider and removes {@link Harness.tempDir}. */
	cleanup: () => void;
}
|
||||
|
||||
/**
 * Creates a uniquely named `pi-suite-*` directory under the OS temp dir.
 *
 * The name combines the current timestamp with a random base-36 suffix.
 *
 * @returns Absolute path of the directory that was created.
 */
function createTempDir(): string {
	const uniqueSuffix = `${Date.now()}-${Math.random().toString(36).slice(2)}`;
	const dirPath = join(tmpdir(), `pi-suite-${uniqueSuffix}`);
	mkdirSync(dirPath, { recursive: true });
	return dirPath;
}
|
||||
|
||||
/**
 * Builds an {@link AgentSession} backed by a freshly registered faux provider and
 * in-memory session, settings, auth and model-registry instances.
 *
 * The faux provider starts with an empty response queue; tests script replies through
 * `setResponses` / `appendResponses` on the returned harness. Every session event is
 * appended to `events`.
 *
 * If any setup step throws, the faux provider registered here is unregistered and the
 * temp directory is removed before the error is rethrown, so a failed setup does not
 * leak state into later tests.
 *
 * @param options Optional overrides; see {@link HarnessOptions}.
 * @returns The assembled harness. Callers must invoke `cleanup()` when done.
 * @throws Rethrows whatever a setup step throws, after unwinding.
 */
export async function createHarness(options: HarnessOptions = {}): Promise<Harness> {
	const tempDir = createTempDir();
	// Tracked outside the try block so the catch handler can unwind the registration.
	let registration: FauxProviderRegistration | undefined;

	try {
		const fauxProvider: FauxProviderRegistration = registerFauxProvider({
			models: options.models,
		});
		registration = fauxProvider;
		fauxProvider.setResponses([]);
		const model = fauxProvider.getModel();
		const toolMap = options.tools ? Object.fromEntries(options.tools.map((tool) => [tool.name, tool])) : undefined;

		const agent = new Agent({
			getApiKey: () => "faux-key",
			initialState: {
				model,
				systemPrompt: options.systemPrompt ?? "You are a test assistant.",
				tools: [],
			},
		});

		const sessionManager = SessionManager.inMemory();
		const settingsManager = SettingsManager.inMemory(options.settings);

		const authStorage = AuthStorage.inMemory();
		authStorage.setRuntimeApiKey(model.provider, "faux-key");
		const modelRegistry = ModelRegistry.inMemory(authStorage);
		// Mirror the faux provider's models into the registry under the same provider name.
		modelRegistry.registerProvider(model.provider, {
			baseUrl: model.baseUrl,
			apiKey: "faux-key",
			api: fauxProvider.api,
			models: fauxProvider.models.map((registeredModel) => ({
				id: registeredModel.id,
				name: registeredModel.name,
				api: registeredModel.api,
				reasoning: registeredModel.reasoning,
				input: registeredModel.input,
				cost: registeredModel.cost,
				contextWindow: registeredModel.contextWindow,
				maxTokens: registeredModel.maxTokens,
				baseUrl: registeredModel.baseUrl,
			})),
		});
		const extensionsResult = options.extensionFactories
			? await createTestExtensionsResult(options.extensionFactories, tempDir)
			: undefined;
		// An explicitly supplied resource loader wins over the extension-derived one.
		const resourceLoader =
			options.resourceLoader ?? createTestResourceLoader(extensionsResult ? { extensionsResult } : undefined);

		const session = new AgentSession({
			agent,
			sessionManager,
			settingsManager,
			cwd: tempDir,
			modelRegistry,
			resourceLoader,
			baseToolsOverride: toolMap,
		});

		const events: AgentSessionEvent[] = [];
		session.subscribe((event) => {
			events.push(event);
		});

		return {
			session,
			sessionManager,
			settingsManager,
			faux: fauxProvider,
			models: fauxProvider.models,
			getModel: fauxProvider.getModel,
			setResponses: fauxProvider.setResponses,
			appendResponses: fauxProvider.appendResponses,
			getPendingResponseCount: fauxProvider.getPendingResponseCount,
			events,
			eventsOfType<T extends AgentSessionEvent["type"]>(type: T) {
				return events.filter((event): event is Extract<AgentSessionEvent, { type: T }> => event.type === type);
			},
			tempDir,
			cleanup() {
				try {
					session.dispose();
				} finally {
					// Always release the provider registration and temp dir, even if dispose() throws.
					fauxProvider.unregister();
					if (existsSync(tempDir)) {
						rmSync(tempDir, { recursive: true });
					}
				}
			},
		};
	} catch (error) {
		// Setup failed before the caller received cleanup(); unwind what was created here.
		registration?.unregister();
		rmSync(tempDir, { recursive: true, force: true });
		throw error;
	}
}
|
||||
@@ -390,7 +390,7 @@ function createHarnessWithResourceLoader(
|
||||
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
authStorage.setRuntimeApiKey(model.provider, "faux-key");
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
|
||||
const session = new AgentSession({
|
||||
agent,
|
||||
|
||||
@@ -254,7 +254,7 @@ export function createTestSession(options: TestSessionOptions = {}): TestSession
|
||||
}
|
||||
|
||||
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage, tempDir);
|
||||
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
|
||||
|
||||
const session = new AgentSession({
|
||||
agent,
|
||||
|
||||
@@ -429,7 +429,7 @@ function createRunner(sandboxConfig: SandboxConfig, channelId: string, channelDi
|
||||
// Create AuthStorage and ModelRegistry
|
||||
// Auth stored outside workspace so agent can't access it
|
||||
const authStorage = AuthStorage.create(join(homedir(), ".pi", "mom", "auth.json"));
|
||||
const modelRegistry = new ModelRegistry(authStorage);
|
||||
const modelRegistry = ModelRegistry.create(authStorage);
|
||||
|
||||
// Create agent
|
||||
const agent = new Agent({
|
||||
|
||||
Reference in New Issue
Block a user