feat(ai,coding-agent): add faux provider and ModelRegistry factories

This commit is contained in:
Mario Zechner
2026-03-29 21:08:50 +02:00
Unverified
parent fa890e3f94
commit ef6af5ebbd
31 changed files with 1425 additions and 94 deletions
+86
View File
@@ -636,6 +636,92 @@ The library uses a registry of API implementations. Built-in APIs include:
- **`azure-openai-responses`**: Azure OpenAI Responses API (`streamAzureOpenAIResponses`, `AzureOpenAIResponsesOptions`)
- **`bedrock-converse-stream`**: Amazon Bedrock Converse API (`streamBedrock`, `BedrockOptions`)
### Faux provider for tests
`registerFauxProvider()` registers a temporary in-memory provider for tests and demos. It is opt-in and not part of the built-in provider set.
```typescript
import {
  complete,
  fauxAssistantMessage,
  fauxText,
  fauxThinking,
  fauxToolCall,
  registerFauxProvider,
  stream,
  type Context,
} from '@mariozechner/pi-ai';

// Register a temporary in-memory provider. All options are optional.
const registration = registerFauxProvider({
  tokensPerSecond: 50 // optional
});
const model = registration.getModel();

// Annotate as Context: without it `role` widens to `string` and the assistant /
// tool-result messages pushed below would not type-check.
const context: Context = {
  messages: [{ role: 'user', content: 'Summarize package.json and then call echo', timestamp: Date.now() }]
};

// Turn 1: script a thinking block followed by a tool call.
registration.setResponses([
  fauxAssistantMessage([
    fauxThinking('Need to inspect package metadata first.'),
    fauxToolCall('echo', { text: 'package.json' })
  ], { stopReason: 'toolUse' })
]);

const first = await complete(model, context, {
  sessionId: 'session-1',
  cacheRetention: 'short'
});
context.messages.push(first);

// Feed the tool result back, referencing the id of the scripted tool call.
context.messages.push({
  role: 'toolResult',
  toolCallId: first.content.find((block) => block.type === 'toolCall')!.id,
  toolName: 'echo',
  content: [{ type: 'text', text: 'package.json contents here' }],
  isError: false,
  timestamp: Date.now()
});

// Turn 2: script the final answer and consume it as a stream of events.
registration.setResponses([
  fauxAssistantMessage([
    fauxThinking('Now I can summarize the tool output.'),
    fauxText('Here is the summary.')
  ])
]);

const s = stream(model, context);
for await (const event of s) {
  console.log(event.type);
}

// Optional: register multiple faux models for model-switching tests
const multiModel = registerFauxProvider({
  models: [
    { id: 'faux-fast', reasoning: false },
    { id: 'faux-thinker', reasoning: true }
  ]
});
const thinker = multiModel.getModel('faux-thinker');
console.log(thinker?.reasoning);

console.log(registration.getPendingResponseCount());
console.log(registration.state.callCount);

// Remove both temporary providers from the global API registry.
registration.unregister();
multiModel.unregister();
```
Notes:
- Responses are consumed from a queue in request start order.
- If the queue is empty, the faux provider returns an assistant error message with `errorMessage: "No more faux responses queued"`.
- Use `registration.setResponses([...])` to replace the remaining queue and `registration.appendResponses([...])` to add more responses.
- `registration.models` exposes all registered faux models. `registration.getModel()` returns the first one, and `registration.getModel(id)` returns a specific one.
- Use `fauxAssistantMessage(...)` for scripted assistant replies. Use `fauxText(...)`, `fauxThinking(...)`, and `fauxToolCall(...)` to build content blocks without filling in low-level fields manually.
- `registration.unregister()` removes the temporary provider from the global API registry.
- Usage is estimated at roughly 1 token per 4 characters. When `sessionId` is present and `cacheRetention` is not `"none"`, prompt cache reads and writes are simulated automatically.
- Tool call arguments stream incrementally via `toolcall_delta` chunks.
- By default, each streamed chunk is emitted on its own microtask. Set `tokensPerSecond` to pace chunk delivery in real time.
- The intended use is one deterministic scripted flow per registration. If you need independent concurrent flows, register separate faux providers.
### Providers and Models
A **provider** offers models through a specific API. For example:
+1
View File
@@ -7,6 +7,7 @@ export * from "./models.js";
export type { BedrockOptions } from "./providers/amazon-bedrock.js";
export type { AnthropicOptions } from "./providers/anthropic.js";
export type { AzureOpenAIResponsesOptions } from "./providers/azure-openai-responses.js";
export * from "./providers/faux.js";
export type { GoogleOptions } from "./providers/google.js";
export type { GoogleGeminiCliOptions, GoogleThinkingLevel } from "./providers/google-gemini-cli.js";
export type { GoogleVertexOptions } from "./providers/google-vertex.js";
+498
View File
@@ -0,0 +1,498 @@
import { registerApiProvider, unregisterApiProviders } from "../api-registry.js";
import type {
AssistantMessage,
AssistantMessageEventStream,
Context,
ImageContent,
Message,
Model,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
TextContent,
ThinkingContent,
ToolCall,
ToolResultMessage,
Usage,
} from "../types.js";
import { createAssistantMessageEventStream } from "../utils/event-stream.js";
// Fallback identifiers. fauxAssistantMessage() stamps DEFAULT_API, DEFAULT_PROVIDER and
// DEFAULT_MODEL_ID onto the messages it builds; registerFauxProvider() uses these
// values when the caller does not supply an api, provider, or model list.
const DEFAULT_API = "faux";
const DEFAULT_PROVIDER = "faux";
const DEFAULT_MODEL_ID = "faux-1";
const DEFAULT_MODEL_NAME = "Faux Model";
// Base URL assigned to every faux model created by registerFauxProvider().
const DEFAULT_BASE_URL = "http://localhost:0";
// Default lower/upper bound for the per-chunk token size used when streaming
// (splitStringByTokenSize converts a token size to characters by multiplying by 4).
const DEFAULT_MIN_TOKEN_SIZE = 3;
const DEFAULT_MAX_TOKEN_SIZE = 5;
// All-zero usage record used as the initial/fallback `usage` of faux messages.
// NOTE(review): this is a single shared object; code that assigns it directly
// hands out the same reference each time.
const DEFAULT_USAGE: Usage = {
input: 0,
output: 0,
cacheRead: 0,
cacheWrite: 0,
totalTokens: 0,
cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
};
/**
 * Description of one faux model passed to `registerFauxProvider({ models })`.
 * Only `id` is required; registerFauxProvider() fills in the omitted fields.
 */
export interface FauxModelDefinition {
/** Model id; also what `registration.getModel(id)` looks up. */
id: string;
/** Display name. Falls back to `id` when omitted. */
name?: string;
/** Whether the model reports reasoning support. Falls back to `false`. */
reasoning?: boolean;
/** Accepted input kinds. Falls back to `["text", "image"]`. */
input?: ("text" | "image")[];
/** Cost table. Falls back to all zeros. */
cost?: { input: number; output: number; cacheRead: number; cacheWrite: number };
/** Context window size. Falls back to 128000. */
contextWindow?: number;
/** Maximum output tokens. Falls back to 16384. */
maxTokens?: number;
}
/** Any content block a scripted assistant message may contain: text, thinking, or a tool call. */
export type FauxContentBlock = TextContent | ThinkingContent | ToolCall;
/**
 * Builds a text content block for a scripted assistant reply.
 *
 * @param text - The text the block should carry.
 * @returns A `{ type: "text", text }` block.
 */
export function fauxText(text: string): TextContent {
	const block: TextContent = { type: "text", text };
	return block;
}
/**
 * Builds a thinking content block for a scripted assistant reply.
 *
 * @param thinking - The reasoning text the block should carry.
 * @returns A `{ type: "thinking", thinking }` block.
 */
export function fauxThinking(thinking: string): ThinkingContent {
	const block: ThinkingContent = { type: "thinking", thinking };
	return block;
}
/**
 * Builds a tool-call content block for a scripted assistant reply.
 *
 * @param name - Name of the tool being called.
 * @param arguments_ - Arguments object attached to the call.
 * @param options - Optional `id`; when absent a fresh id is generated with `randomId("tool")`.
 * @returns A `toolCall` block.
 */
export function fauxToolCall(name: string, arguments_: ToolCall["arguments"], options: { id?: string } = {}): ToolCall {
	// `??` (not a destructuring default) so that an explicit null id is also replaced.
	const id = options.id ?? randomId("tool");
	const call: ToolCall = { type: "toolCall", id, name, arguments: arguments_ };
	return call;
}
/**
 * Coerces the flexible `content` argument of fauxAssistantMessage() into a block array.
 *
 * - an array is returned as-is (same reference),
 * - a string becomes a single text block,
 * - a single block is wrapped in a one-element array.
 */
function normalizeFauxAssistantContent(content: string | FauxContentBlock | FauxContentBlock[]): FauxContentBlock[] {
	if (Array.isArray(content)) {
		return content;
	}
	return typeof content === "string" ? [fauxText(content)] : [content];
}
/**
 * Builds a scripted assistant message to queue on a faux provider.
 *
 * The message is stamped with the default faux api/provider/model identifiers and a
 * zeroed usage record.
 *
 * @param content - A string (becomes one text block), a single content block, or an array of blocks.
 * @param options - Optional overrides:
 *   - `stopReason`: defaults to `"stop"`.
 *   - `errorMessage` / `responseId`: copied through unchanged (may be undefined).
 *   - `timestamp`: defaults to `Date.now()`.
 * @returns A new assistant message.
 */
export function fauxAssistantMessage(
	content: string | FauxContentBlock | FauxContentBlock[],
	options: {
		stopReason?: AssistantMessage["stopReason"];
		errorMessage?: string;
		responseId?: string;
		timestamp?: number;
	} = {},
): AssistantMessage {
	return {
		role: "assistant",
		content: normalizeFauxAssistantContent(content),
		api: DEFAULT_API,
		provider: DEFAULT_PROVIDER,
		model: DEFAULT_MODEL_ID,
		// Give every message its own usage object (including the nested `cost`).
		// Handing out the module-level constant by reference would let a mutation of
		// one message's usage leak into every other faux message.
		usage: structuredClone(DEFAULT_USAGE),
		stopReason: options.stopReason ?? "stop",
		errorMessage: options.errorMessage,
		responseId: options.responseId,
		timestamp: options.timestamp ?? Date.now(),
	};
}
/**
 * A queued response computed at request time instead of being fixed up front.
 *
 * The faux provider calls it with the request context, the stream options, the
 * registration's `state` object, and the model the request was made with, and
 * awaits the result. If the factory throws (or rejects), the provider emits an
 * error message carrying the thrown error's message.
 */
export type FauxResponseFactory = (
context: Context,
options: StreamOptions | undefined,
state: { callCount: number },
model: Model<string>,
) => AssistantMessage | Promise<AssistantMessage>;
/** One entry of the response queue: either a ready-made message or a factory producing one. */
export type FauxResponseStep = AssistantMessage | FauxResponseFactory;
/** Options accepted by registerFauxProvider(). Every field is optional. */
export interface RegisterFauxProviderOptions {
/** API identifier to register under. When omitted, a random id derived from "faux" is generated. */
api?: string;
/** Provider name stamped on models and returned messages. Defaults to "faux". */
provider?: string;
/** Models to expose. When omitted or empty, a single default model ("faux-1") is created. */
models?: FauxModelDefinition[];
/**
 * Pacing for streamed chunks. When unset, zero, or negative, each chunk is
 * released on a microtask; otherwise each chunk is delayed via a timer in
 * proportion to its estimated token count.
 */
tokensPerSecond?: number;
/**
 * Bounds for the randomly chosen size of each streamed chunk, in estimated
 * tokens (one token is treated as 4 characters). Defaults are 3 and 5;
 * registerFauxProvider() clamps `min` to at least 1 and at most `max`.
 */
tokenSize?: {
min?: number;
max?: number;
};
}
/** Handle returned by registerFauxProvider() for scripting and inspecting the faux provider. */
export interface FauxProviderRegistration {
/** API identifier the provider was registered under. */
api: string;
/** All registered faux models; the tuple type guarantees at least one entry. */
models: [Model<string>, ...Model<string>[]];
/** Without an argument, returns the first registered model. */
getModel(): Model<string>;
/** With an id, returns the model with that id, or `undefined` when none matches. */
getModel(modelId: string): Model<string> | undefined;
/** Mutable counters; `callCount` is incremented once per stream request. */
state: { callCount: number };
/** Replaces the pending response queue with a copy of `responses`. */
setResponses: (responses: FauxResponseStep[]) => void;
/** Appends `responses` to the end of the pending queue. */
appendResponses: (responses: FauxResponseStep[]) => void;
/** Number of responses still waiting in the queue. */
getPendingResponseCount: () => number;
/** Removes this provider from the API registry. */
unregister: () => void;
}
/**
 * Rough token estimate: one token per four UTF-16 code units, rounded up.
 *
 * @param text - Text to measure.
 * @returns The estimated token count (0 for an empty string).
 */
function estimateTokens(text: string): number {
	const charsPerToken = 4;
	const charCount = text.length;
	return Math.ceil(charCount / charsPerToken);
}
/**
 * Generates an identifier of the form `<prefix>:<epoch millis>:<random base-36 suffix>`.
 *
 * @param prefix - Leading segment of the id.
 * @returns The generated id string.
 */
function randomId(prefix: string): string {
	const stamp = Date.now();
	// Drop the leading "0." of the base-36 rendering to keep only the random digits.
	const suffix = Math.random().toString(36).slice(2);
	return [prefix, stamp, suffix].join(":");
}
/**
 * Flattens user/tool-result content into plain text for token estimation.
 *
 * A string is returned unchanged. For a block array, text blocks contribute their
 * text and image blocks contribute a `[image:<mimeType>:<data length>]` placeholder;
 * the pieces are joined with newlines.
 */
function contentToText(content: string | Array<TextContent | ImageContent>): string {
	if (typeof content === "string") {
		return content;
	}
	const lines: string[] = [];
	for (const block of content) {
		lines.push(block.type === "text" ? block.text : `[image:${block.mimeType}:${block.data.length}]`);
	}
	return lines.join("\n");
}
/**
 * Flattens assistant content blocks into plain text for token estimation.
 *
 * Text and thinking blocks contribute their text; any other block is rendered as
 * `<name>:<JSON of arguments>`. Pieces are joined with newlines.
 */
function assistantContentToText(content: Array<TextContent | ThinkingContent | ToolCall>): string {
	const lines: string[] = [];
	for (const block of content) {
		switch (block.type) {
			case "text":
				lines.push(block.text);
				break;
			case "thinking":
				lines.push(block.thinking);
				break;
			default:
				lines.push(`${block.name}:${JSON.stringify(block.arguments)}`);
		}
	}
	return lines.join("\n");
}
/**
 * Flattens a tool-result message into plain text: the tool name on the first line,
 * followed by one line per content block (rendered via contentToText).
 */
function toolResultToText(message: ToolResultMessage): string {
	const lines = [message.toolName];
	for (const block of message.content) {
		lines.push(contentToText([block]));
	}
	return lines.join("\n");
}
/**
 * Renders any conversation message as plain text by dispatching on its role:
 * user and assistant messages use their dedicated renderers, everything else is
 * treated as a tool result.
 */
function messageToText(message: Message): string {
	switch (message.role) {
		case "user":
			return contentToText(message.content);
		case "assistant":
			return assistantContentToText(message.content);
		default:
			return toolResultToText(message);
	}
}
/**
 * Serializes a request context into one string used for prompt-size estimation
 * and prompt-cache prefix comparison.
 *
 * Sections, in order and separated by blank lines: `system:<prompt>` (only when the
 * system prompt is non-empty), one `<role>:<text>` entry per message, and
 * `tools:<JSON>` (only when at least one tool is present).
 */
function serializeContext(context: Context): string {
	const { systemPrompt, messages, tools } = context;
	const sections: string[] = [
		...(systemPrompt ? [`system:${systemPrompt}`] : []),
		...messages.map((message) => `${message.role}:${messageToText(message)}`),
		...(tools?.length ? [`tools:${JSON.stringify(tools)}`] : []),
	];
	return sections.join("\n\n");
}
/**
 * Counts how many leading UTF-16 code units two strings have in common.
 *
 * @returns The index of the first difference, or the shorter length when one
 *          string is a prefix of the other.
 */
function commonPrefixLength(a: string, b: string): number {
	const limit = Math.min(a.length, b.length);
	for (let position = 0; position < limit; position++) {
		if (a[position] !== b[position]) {
			return position;
		}
	}
	return limit;
}
/**
 * Returns a copy of `message` whose `usage` is estimated from the request context
 * and the message content (roughly one token per four characters; costs are zero).
 *
 * Prompt caching is simulated only when a `sessionId` is given and
 * `cacheRetention` is not `"none"`:
 * - first request of a session: the whole prompt counts as a cache write;
 * - later requests: the prefix shared with the session's previous prompt counts as
 *   a cache read, the remainder as a cache write, and `input` is reduced by the
 *   cache read.
 * In that case the session's entry in `promptCache` is updated to the current prompt.
 *
 * @param message - Message to annotate (not mutated).
 * @param context - Request context the prompt size is derived from.
 * @param options - Stream options supplying `sessionId` / `cacheRetention`.
 * @param promptCache - Per-session store of the last serialized prompt (mutated).
 */
function withUsageEstimate(
	message: AssistantMessage,
	context: Context,
	options: StreamOptions | undefined,
	promptCache: Map<string, string>,
): AssistantMessage {
	const promptText = serializeContext(context);
	const promptTokens = estimateTokens(promptText);
	const output = estimateTokens(assistantContentToText(message.content));

	// Start from the uncached case and adjust if caching applies.
	const tally = { input: promptTokens, cacheRead: 0, cacheWrite: 0 };

	const sessionId = options?.sessionId;
	if (sessionId && options?.cacheRetention !== "none") {
		const cachedPrompt = promptCache.get(sessionId);
		if (!cachedPrompt) {
			tally.cacheWrite = promptTokens;
		} else {
			const sharedChars = commonPrefixLength(cachedPrompt, promptText);
			tally.cacheRead = estimateTokens(cachedPrompt.slice(0, sharedChars));
			tally.cacheWrite = estimateTokens(promptText.slice(sharedChars));
			tally.input = Math.max(0, promptTokens - tally.cacheRead);
		}
		promptCache.set(sessionId, promptText);
	}

	const { input, cacheRead, cacheWrite } = tally;
	return {
		...message,
		usage: {
			input,
			output,
			cacheRead,
			cacheWrite,
			totalTokens: input + output + cacheRead + cacheWrite,
			cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
		},
	};
}
/**
 * Splits `text` into consecutive chunks that simulate streamed tokens.
 *
 * Each chunk's size is drawn uniformly from `[minTokenSize, maxTokenSize]` tokens,
 * with one token treated as four UTF-16 code units (never less than one code unit).
 * Concatenating the chunks always reproduces `text` exactly.
 *
 * Two safeguards beyond plain slicing:
 * - a chunk never ends between the two halves of a surrogate pair, so no delta
 *   carries a lone surrogate (which would turn into U+FFFD when encoded as UTF-8);
 * - if the size bounds yield NaN, the remainder of the text is emitted as one chunk
 *   rather than being silently dropped.
 *
 * @param text - Text to split.
 * @param minTokenSize - Smallest chunk size, in estimated tokens.
 * @param maxTokenSize - Largest chunk size, in estimated tokens.
 * @returns The chunks in order; `[""]` for empty input so callers always get one delta.
 */
function splitStringByTokenSize(text: string, minTokenSize: number, maxTokenSize: number): string[] {
	const chunks: string[] = [];
	let index = 0;
	while (index < text.length) {
		const tokenSize = minTokenSize + Math.floor(Math.random() * (maxTokenSize - minTokenSize + 1));
		const charSize = Math.max(1, tokenSize * 4);
		let end = Number.isNaN(charSize) ? text.length : Math.min(text.length, index + charSize);
		if (end < text.length) {
			const lastUnit = text.charCodeAt(end - 1);
			const nextUnit = text.charCodeAt(end);
			const splitsPair = lastUnit >= 0xd800 && lastUnit <= 0xdbff && nextUnit >= 0xdc00 && nextUnit <= 0xdfff;
			if (splitsPair) {
				// Pull the low surrogate into this chunk to keep the code point whole.
				end += 1;
			}
		}
		chunks.push(text.slice(index, end));
		index = end;
	}
	return chunks.length > 0 ? chunks : [""];
}
/**
 * Deep-copies a scripted message and rewrites its identity fields so the returned
 * message reports the api, provider, and model of the actual request.
 *
 * A missing timestamp is replaced with the current time and a missing usage with
 * the default zero usage.
 */
function cloneMessage(message: AssistantMessage, api: string, provider: string, modelId: string): AssistantMessage {
	const copy = structuredClone(message);
	const rewritten: AssistantMessage = { ...copy, api, provider, model: modelId };
	rewritten.timestamp = copy.timestamp ?? Date.now();
	rewritten.usage = copy.usage ?? DEFAULT_USAGE;
	return rewritten;
}
/**
 * Builds an empty assistant message with `stopReason: "error"` describing a failure.
 *
 * @param error - The failure; an `Error` contributes its `message`, anything else is
 *                converted with `String()`.
 * @param api - API identifier to stamp on the message.
 * @param provider - Provider name to stamp on the message.
 * @param modelId - Model id to stamp on the message.
 * @returns A new error message with zeroed usage and the current timestamp.
 */
function createErrorMessage(error: unknown, api: string, provider: string, modelId: string): AssistantMessage {
	return {
		role: "assistant",
		content: [],
		api,
		provider,
		model: modelId,
		// Fresh copy: this message is handed to callers, and sharing the module-level
		// constant would let their mutations bleed into later messages.
		usage: structuredClone(DEFAULT_USAGE),
		stopReason: "error",
		errorMessage: error instanceof Error ? error.message : String(error),
		timestamp: Date.now(),
	};
}
/**
 * Produces the terminal message for an aborted request: a shallow copy of the
 * partial message marked as aborted and re-stamped with the current time.
 *
 * @param partial - Message accumulated so far (not mutated).
 */
function createAbortedMessage(partial: AssistantMessage): AssistantMessage {
	const aborted: AssistantMessage = { ...partial };
	aborted.stopReason = "aborted";
	aborted.errorMessage = "Request was aborted";
	aborted.timestamp = Date.now();
	return aborted;
}
/**
 * Returns a promise that resolves when the given chunk may be emitted.
 *
 * With a positive `tokensPerSecond` the wait is a timer proportional to the chunk's
 * estimated token count; otherwise (unset, zero, negative, NaN) the promise is
 * resolved from a queued microtask.
 */
function scheduleChunk(chunk: string, tokensPerSecond: number | undefined): Promise<void> {
	const paced = typeof tokensPerSecond === "number" && tokensPerSecond > 0;
	if (paced) {
		const delayMs = (estimateTokens(chunk) / tokensPerSecond) * 1000;
		return new Promise((resolve) => setTimeout(resolve, delayMs));
	}
	return new Promise((resolve) => queueMicrotask(resolve));
}
/**
 * Replays a complete assistant message onto `stream` as incremental events.
 *
 * Event sequence: `start`, then for each content block a `*_start`, one `*_delta`
 * per chunk, and a `*_end`; finally `done` — or `error` when the scripted message's
 * own stop reason is `"error"` or `"aborted"` — followed by `stream.end(message)`.
 *
 * The abort signal is checked before `start`, before every block, and after every
 * chunk wait. On abort an `error` event with reason `"aborted"` is pushed, the
 * stream is ended with the partial message built so far, and nothing else is emitted.
 *
 * Tool-call argument deltas carry slices of the JSON-serialized arguments; the
 * partial's `arguments` stays `{}` until the block is complete.
 *
 * @param stream - Event stream to push onto and end.
 * @param message - Fully resolved message to replay.
 * @param minTokenSize - Lower bound for chunk size, in estimated tokens.
 * @param maxTokenSize - Upper bound for chunk size, in estimated tokens.
 * @param tokensPerSecond - Optional pacing passed to scheduleChunk().
 * @param signal - Optional abort signal.
 */
async function streamWithDeltas(
	stream: AssistantMessageEventStream,
	message: AssistantMessage,
	minTokenSize: number,
	maxTokenSize: number,
	tokensPerSecond: number | undefined,
	signal: AbortSignal | undefined,
): Promise<void> {
	const partial: AssistantMessage = { ...message, content: [] };

	// Terminates the stream with the current partial if the request was aborted.
	const endIfAborted = (): boolean => {
		if (!signal?.aborted) {
			return false;
		}
		const aborted = createAbortedMessage(partial);
		stream.push({ type: "error", reason: "aborted", error: aborted });
		stream.end(aborted);
		return true;
	};

	if (endIfAborted()) {
		return;
	}
	stream.push({ type: "start", partial: { ...partial } });

	for (let index = 0; index < message.content.length; index++) {
		if (endIfAborted()) {
			return;
		}
		const block = message.content[index];
		switch (block.type) {
			case "thinking": {
				// `live` is the object stored at partial.content[index]; it grows per chunk.
				const live: ThinkingContent = { type: "thinking", thinking: "" };
				partial.content = [...partial.content, live];
				stream.push({ type: "thinking_start", contentIndex: index, partial: { ...partial } });
				for (const piece of splitStringByTokenSize(block.thinking, minTokenSize, maxTokenSize)) {
					await scheduleChunk(piece, tokensPerSecond);
					if (endIfAborted()) {
						return;
					}
					live.thinking += piece;
					stream.push({ type: "thinking_delta", contentIndex: index, delta: piece, partial: { ...partial } });
				}
				stream.push({
					type: "thinking_end",
					contentIndex: index,
					content: block.thinking,
					partial: { ...partial },
				});
				break;
			}
			case "text": {
				const live: TextContent = { type: "text", text: "" };
				partial.content = [...partial.content, live];
				stream.push({ type: "text_start", contentIndex: index, partial: { ...partial } });
				for (const piece of splitStringByTokenSize(block.text, minTokenSize, maxTokenSize)) {
					await scheduleChunk(piece, tokensPerSecond);
					if (endIfAborted()) {
						return;
					}
					live.text += piece;
					stream.push({ type: "text_delta", contentIndex: index, delta: piece, partial: { ...partial } });
				}
				stream.push({ type: "text_end", contentIndex: index, content: block.text, partial: { ...partial } });
				break;
			}
			default: {
				const live: ToolCall = { type: "toolCall", id: block.id, name: block.name, arguments: {} };
				partial.content = [...partial.content, live];
				stream.push({ type: "toolcall_start", contentIndex: index, partial: { ...partial } });
				for (const piece of splitStringByTokenSize(JSON.stringify(block.arguments), minTokenSize, maxTokenSize)) {
					await scheduleChunk(piece, tokensPerSecond);
					if (endIfAborted()) {
						return;
					}
					stream.push({ type: "toolcall_delta", contentIndex: index, delta: piece, partial: { ...partial } });
				}
				// Arguments are attached only once all JSON slices have been delivered.
				live.arguments = block.arguments;
				stream.push({ type: "toolcall_end", contentIndex: index, toolCall: block, partial: { ...partial } });
			}
		}
	}

	const { stopReason } = message;
	if (stopReason === "error" || stopReason === "aborted") {
		stream.push({ type: "error", reason: stopReason, error: message });
	} else {
		stream.push({ type: "done", reason: stopReason, message });
	}
	stream.end(message);
}
/**
 * Registers a temporary in-memory provider that answers requests from a scripted
 * queue of responses, for use in tests and demos.
 *
 * Each stream request synchronously takes the next queued step and increments
 * `state.callCount`; the response itself is produced asynchronously. A factory step
 * is invoked with the context, stream options, state, and requested model. The
 * resolved message is deep-copied, re-stamped with this registration's
 * api/provider and the requested model id, given estimated usage, and replayed as
 * incremental events. An empty queue yields an error message
 * ("No more faux responses queued"); a throwing factory yields an error message
 * carrying the thrown error's message.
 *
 * @param options - Optional api/provider names, model definitions, pacing, and chunk-size bounds.
 * @returns A registration handle for scripting responses, inspecting state, and unregistering.
 */
export function registerFauxProvider(options: RegisterFauxProviderOptions = {}): FauxProviderRegistration {
	const api = options.api ?? randomId(DEFAULT_API);
	const provider = options.provider ?? DEFAULT_PROVIDER;
	// Registration key; unregister() passes it back to remove exactly this provider.
	const sourceId = randomId("faux-provider");
	// Clamp the chunk-size bounds so that 1 <= min <= max.
	const minTokenSize = Math.max(
		1,
		Math.min(options.tokenSize?.min ?? DEFAULT_MIN_TOKEN_SIZE, options.tokenSize?.max ?? DEFAULT_MAX_TOKEN_SIZE),
	);
	const maxTokenSize = Math.max(minTokenSize, options.tokenSize?.max ?? DEFAULT_MAX_TOKEN_SIZE);
	let pendingResponses: FauxResponseStep[] = [];
	const tokensPerSecond = options.tokensPerSecond;
	const state = { callCount: 0 };
	// Last serialized prompt per sessionId, used to simulate prompt-cache hits.
	const promptCache = new Map<string, string>();

	const modelDefinitions = options.models?.length
		? options.models
		: [
				{
					id: DEFAULT_MODEL_ID,
					name: DEFAULT_MODEL_NAME,
					reasoning: false,
					input: ["text", "image"] as ("text" | "image")[],
					cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
					contextWindow: 128000,
					maxTokens: 16384,
				},
			];
	const models = modelDefinitions.map((definition) => ({
		id: definition.id,
		name: definition.name ?? definition.id,
		api,
		provider,
		baseUrl: DEFAULT_BASE_URL,
		reasoning: definition.reasoning ?? false,
		input: definition.input ?? ["text", "image"],
		cost: definition.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
		contextWindow: definition.contextWindow ?? 128000,
		maxTokens: definition.maxTokens ?? 16384,
	})) as [Model<string>, ...Model<string>[]];

	const stream: StreamFunction<string, StreamOptions> = (requestModel, context, streamOptions) => {
		const outer = createAssistantMessageEventStream();
		// Dequeue and count synchronously so responses are consumed in request start order.
		const step = pendingResponses.shift();
		state.callCount++;
		queueMicrotask(async () => {
			try {
				if (!step) {
					let message = createErrorMessage(
						new Error("No more faux responses queued"),
						api,
						provider,
						requestModel.id,
					);
					message = withUsageEstimate(message, context, streamOptions, promptCache);
					outer.push({ type: "error", reason: "error", error: message });
					outer.end(message);
					return;
				}
				const resolved =
					typeof step === "function" ? await step(context, streamOptions, state, requestModel) : step;
				let message = cloneMessage(resolved, api, provider, requestModel.id);
				message = withUsageEstimate(message, context, streamOptions, promptCache);
				await streamWithDeltas(outer, message, minTokenSize, maxTokenSize, tokensPerSecond, streamOptions?.signal);
			} catch (error) {
				// Factory failures (and any failure while replaying) surface as a terminal error event.
				const message = createErrorMessage(error, api, provider, requestModel.id);
				outer.push({ type: "error", reason: "error", error: message });
				outer.end(message);
			}
		});
		return outer;
	};
	const streamSimple: StreamFunction<string, SimpleStreamOptions> = (streamModel, context, streamOptions) =>
		stream(streamModel, context, streamOptions);

	registerApiProvider({ api, stream, streamSimple }, sourceId);

	function getModel(): Model<string>;
	function getModel(requestedModelId: string): Model<string> | undefined;
	function getModel(requestedModelId?: string): Model<string> | undefined {
		// Only an omitted id selects the first model. A truthiness check would also
		// treat "" as "omitted" and silently hand back the first model for a bad id.
		if (requestedModelId === undefined) {
			return models[0];
		}
		return models.find((candidate) => candidate.id === requestedModelId);
	}

	return {
		api,
		models,
		getModel,
		state,
		setResponses(responses) {
			// Copy so later edits to the caller's array do not affect the queue.
			pendingResponses = [...responses];
		},
		appendResponses(responses) {
			pendingResponses.push(...responses);
		},
		getPendingResponseCount() {
			return pendingResponses.length;
		},
		unregister() {
			unregisterApiProviders(sourceId);
		},
	};
}
+597
View File
@@ -0,0 +1,597 @@
import { afterEach, describe, expect, it } from "vitest";
import {
complete,
fauxAssistantMessage,
fauxText,
fauxThinking,
fauxToolCall,
registerFauxProvider,
stream,
Type,
} from "../src/index.js";
import type { AssistantMessageEvent, Context } from "../src/types.js";
/** Drains an assistant event stream and returns every event in arrival order. */
async function collectEvents(streamResult: ReturnType<typeof stream>): Promise<AssistantMessageEvent[]> {
	const collected: AssistantMessageEvent[] = [];
	for await (const streamEvent of streamResult) {
		collected[collected.length] = streamEvent;
	}
	return collected;
}
// Registrations created by the current test; each test pushes its own here so the
// hook below can remove them afterwards.
const registrations: Array<{ unregister: () => void }> = [];

afterEach(() => {
	// splice(0) empties the shared list and returns what it held.
	const active = registrations.splice(0);
	active.forEach((registration) => {
		registration.unregister();
	});
});
describe("faux provider", () => {
it("registers a custom provider and estimates usage", async () => {
const registration = registerFauxProvider();
registrations.push(registration);
registration.setResponses([fauxAssistantMessage("hello world")]);
const context: Context = {
systemPrompt: "Be concise.",
messages: [{ role: "user", content: "hi there", timestamp: Date.now() }],
};
const response = await complete(registration.getModel(), context);
expect(response.content).toEqual([{ type: "text", text: "hello world" }]);
expect(response.usage.input).toBeGreaterThan(0);
expect(response.usage.output).toBeGreaterThan(0);
expect(response.usage.totalTokens).toBe(response.usage.input + response.usage.output);
expect(registration.state.callCount).toBe(1);
});
it("supports helper blocks for text, thinking, and tool calls", async () => {
const registration = registerFauxProvider();
registrations.push(registration);
registration.setResponses([
fauxAssistantMessage([fauxThinking("think"), fauxToolCall("echo", { text: "hi" }), fauxText("done")], {
stopReason: "toolUse",
}),
]);
const response = await complete(registration.getModel(), {
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
});
expect(response.content).toEqual([
{ type: "thinking", thinking: "think" },
{ type: "toolCall", id: expect.any(String), name: "echo", arguments: { text: "hi" } },
{ type: "text", text: "done" },
]);
expect(response.stopReason).toBe("toolUse");
});
it("supports multiple models with per-model reasoning and model-aware factories", async () => {
const registration = registerFauxProvider({
models: [
{ id: "faux-fast", name: "Faux Fast", reasoning: false },
{ id: "faux-thinker", name: "Faux Thinker", reasoning: true },
],
});
registrations.push(registration);
registration.setResponses([
(_context, _options, _state, model) => fauxAssistantMessage(`${model.id}:${String(model.reasoning)}`),
(_context, _options, _state, model) => fauxAssistantMessage(`${model.id}:${String(model.reasoning)}`),
]);
expect(registration.models.map((model) => model.id)).toEqual(["faux-fast", "faux-thinker"]);
expect(registration.getModel()).toBe(registration.models[0]);
expect(registration.getModel("faux-fast")?.reasoning).toBe(false);
expect(registration.getModel("faux-thinker")?.reasoning).toBe(true);
const fast = await complete(registration.getModel("faux-fast")!, {
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
});
const thinker = await complete(registration.getModel("faux-thinker")!, {
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
});
expect(fast.content).toEqual([{ type: "text", text: "faux-fast:false" }]);
expect(thinker.content).toEqual([{ type: "text", text: "faux-thinker:true" }]);
});
it("rewrites api, provider, and model on returned messages", async () => {
const registration = registerFauxProvider({
api: "faux:test",
provider: "faux-provider",
models: [{ id: "faux-model" }],
});
registrations.push(registration);
registration.setResponses([fauxAssistantMessage("hello")]);
const response = await complete(registration.getModel(), {
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
});
expect(response.api).toBe("faux:test");
expect(response.provider).toBe("faux-provider");
expect(response.model).toBe("faux-model");
});
it("consumes queued responses in order and errors when exhausted", async () => {
const registration = registerFauxProvider();
registrations.push(registration);
registration.setResponses([fauxAssistantMessage("first"), fauxAssistantMessage("second")]);
const context: Context = {
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
};
const first = await complete(registration.getModel(), context);
const second = await complete(registration.getModel(), context);
const exhausted = await complete(registration.getModel(), context);
expect(first.content).toEqual([{ type: "text", text: "first" }]);
expect(second.content).toEqual([{ type: "text", text: "second" }]);
expect(exhausted.stopReason).toBe("error");
expect(exhausted.errorMessage).toBe("No more faux responses queued");
expect(registration.getPendingResponseCount()).toBe(0);
expect(registration.state.callCount).toBe(3);
});
it("can replace and append queued responses", async () => {
const registration = registerFauxProvider();
registrations.push(registration);
registration.setResponses([fauxAssistantMessage("first")]);
const context: Context = {
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
};
expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "first" }]);
expect(registration.getPendingResponseCount()).toBe(0);
registration.setResponses([fauxAssistantMessage("second")]);
expect(registration.getPendingResponseCount()).toBe(1);
expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "second" }]);
registration.appendResponses([fauxAssistantMessage("third"), fauxAssistantMessage("fourth")]);
expect(registration.getPendingResponseCount()).toBe(2);
expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "third" }]);
expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "fourth" }]);
expect(registration.getPendingResponseCount()).toBe(0);
});
it("supports async response factories", async () => {
const registration = registerFauxProvider();
registrations.push(registration);
registration.setResponses([
async (context, _options, state) => fauxAssistantMessage(`${context.messages.length}:${state.callCount}`),
]);
const response = await complete(registration.getModel(), {
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
});
expect(response.content).toEqual([{ type: "text", text: "1:1" }]);
});
it("emits an error when a response factory throws", async () => {
const registration = registerFauxProvider();
registrations.push(registration);
registration.setResponses([
() => {
throw new Error("boom");
},
]);
const events = await collectEvents(
stream(registration.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }),
);
expect(events).toHaveLength(1);
expect(events[0].type).toBe("error");
if (events[0].type === "error") {
expect(events[0].error.stopReason).toBe("error");
expect(events[0].error.errorMessage).toBe("boom");
}
});
it("estimates prompt and output tokens from serialized context", async () => {
const registration = registerFauxProvider();
registrations.push(registration);
registration.setResponses([fauxAssistantMessage("done")]);
const tool = {
name: "echo",
description: "Echo back text",
parameters: Type.Object({ text: Type.String() }),
};
const context: Context = {
systemPrompt: "sys",
messages: [
{
role: "user",
content: [
{ type: "text", text: "hello" },
{ type: "image", mimeType: "image/png", data: "abcd" },
],
timestamp: 1,
},
fauxAssistantMessage("prior"),
{
role: "toolResult",
toolCallId: "tool-1",
toolName: "echo",
content: [{ type: "text", text: "tool out" }],
isError: false,
timestamp: 2,
},
],
tools: [tool],
};
const response = await complete(registration.getModel(), context);
const promptText = [
"system:sys",
"user:hello\n[image:image/png:4]",
"assistant:prior",
"toolResult:echo\ntool out",
`tools:${JSON.stringify([tool])}`,
].join("\n\n");
const expectedPromptTokens = Math.ceil(promptText.length / 4);
const expectedOutputTokens = Math.ceil("done".length / 4);
expect(response.usage.input).toBe(expectedPromptTokens);
expect(response.usage.output).toBe(expectedOutputTokens);
expect(response.usage.cacheRead).toBe(0);
expect(response.usage.cacheWrite).toBe(0);
expect(response.usage.totalTokens).toBe(expectedPromptTokens + expectedOutputTokens);
});
it("does not share cache across sessions or requests without sessionId", async () => {
const registration = registerFauxProvider();
registrations.push(registration);
registration.setResponses([
fauxAssistantMessage("first"),
fauxAssistantMessage("second"),
fauxAssistantMessage("third"),
]);
const context: Context = {
messages: [{ role: "user", content: "hello", timestamp: Date.now() }],
};
const first = await complete(registration.getModel(), context, {
sessionId: "session-1",
cacheRetention: "short",
});
expect(first.usage.cacheWrite).toBeGreaterThan(0);
context.messages.push(first);
context.messages.push({ role: "user", content: "follow up", timestamp: Date.now() + 1 });
const second = await complete(registration.getModel(), context, {
sessionId: "session-2",
cacheRetention: "short",
});
expect(second.usage.cacheRead).toBe(0);
expect(second.usage.cacheWrite).toBeGreaterThan(0);
const third = await complete(registration.getModel(), context);
expect(third.usage.cacheRead).toBe(0);
expect(third.usage.cacheWrite).toBe(0);
});
it("simulates prompt caching per sessionId", async () => {
const registration = registerFauxProvider();
registrations.push(registration);
registration.setResponses([fauxAssistantMessage("first"), fauxAssistantMessage("second")]);
const context: Context = {
systemPrompt: "Be concise.",
messages: [{ role: "user", content: "hello", timestamp: Date.now() }],
};
const first = await complete(registration.getModel(), context, {
sessionId: "session-1",
cacheRetention: "short",
});
expect(first.usage.cacheRead).toBe(0);
expect(first.usage.cacheWrite).toBeGreaterThan(0);
context.messages.push(first);
context.messages.push({ role: "user", content: "follow up", timestamp: Date.now() + 1 });
const second = await complete(registration.getModel(), context, {
sessionId: "session-1",
cacheRetention: "short",
});
expect(second.usage.cacheRead).toBeGreaterThan(0);
expect(second.usage.input + second.usage.cacheRead).toBeGreaterThan(second.usage.input);
});
it("does not simulate caching when cacheRetention is none", async () => {
const registration = registerFauxProvider();
registrations.push(registration);
registration.setResponses([fauxAssistantMessage("first"), fauxAssistantMessage("second")]);
const context: Context = {
messages: [{ role: "user", content: "hello", timestamp: Date.now() }],
};
await complete(registration.getModel(), context, { sessionId: "session-1", cacheRetention: "none" });
context.messages.push(fauxAssistantMessage("first"));
context.messages.push({ role: "user", content: "follow up", timestamp: Date.now() + 1 });
const second = await complete(registration.getModel(), context, {
sessionId: "session-1",
cacheRetention: "none",
});
expect(second.usage.cacheRead).toBe(0);
expect(second.usage.cacheWrite).toBe(0);
});
it("streams thinking, text, and partial tool call deltas", async () => {
const registration = registerFauxProvider();
registrations.push(registration);
registration.setResponses([
fauxAssistantMessage(
[
fauxThinking("thinking text"),
fauxText("answer text"),
fauxToolCall("echo", { text: "hi", count: 12 }, { id: "tool-1" }),
],
{ stopReason: "toolUse" },
),
]);
const events: string[] = [];
const toolCallDeltas: string[] = [];
const s = stream(registration.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] });
for await (const event of s) {
events.push(event.type);
if (event.type === "toolcall_delta") {
toolCallDeltas.push(event.delta);
}
}
expect(events).toContain("thinking_start");
expect(events).toContain("thinking_delta");
expect(events).toContain("text_start");
expect(events).toContain("text_delta");
expect(events).toContain("toolcall_start");
expect(events).toContain("toolcall_delta");
expect(events).toContain("toolcall_end");
expect(toolCallDeltas.length).toBeGreaterThan(1);
expect(JSON.parse(toolCallDeltas.join(""))).toEqual({ text: "hi", count: 12 });
});
it("streams an exact event order for fixed-size chunks", async () => {
const registration = registerFauxProvider({ tokenSize: { min: 1, max: 1 } });
registrations.push(registration);
registration.setResponses([
fauxAssistantMessage([fauxThinking("go"), fauxText("ok"), fauxToolCall("echo", {}, { id: "tool-1" })], {
stopReason: "toolUse",
}),
]);
const events = await collectEvents(
stream(registration.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }),
);
expect(events.map((event) => event.type)).toEqual([
"start",
"thinking_start",
"thinking_delta",
"thinking_end",
"text_start",
"text_delta",
"text_end",
"toolcall_start",
"toolcall_delta",
"toolcall_end",
"done",
]);
});
it("streams multiple tool calls in one message", async () => {
	const faux = registerFauxProvider();
	registrations.push(faux);

	// A single assistant message carrying two separate tool calls.
	const twoCalls = [
		fauxToolCall("echo", { text: "one" }, { id: "tool-1" }),
		fauxToolCall("echo", { text: "two" }, { id: "tool-2" }),
	];
	faux.setResponses([fauxAssistantMessage(twoCalls, { stopReason: "toolUse" })]);

	const eventStream = stream(faux.getModel(), {
		messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
	});
	const events = await collectEvents(eventStream);

	// Each tool call must open and close exactly once.
	const starts = events.filter(({ type }) => type === "toolcall_start");
	const ends = events.filter(({ type }) => type === "toolcall_end");
	expect(starts).toHaveLength(2);
	expect(ends).toHaveLength(2);
});
it("streams an explicit assistant error message as a terminal error", async () => {
	const faux = registerFauxProvider({ tokenSize: { min: 2, max: 2 } });
	registrations.push(faux);

	// Script a reply that has text content but is explicitly marked as failed.
	faux.setResponses([
		{
			...fauxAssistantMessage("partial"),
			stopReason: "error",
			errorMessage: "upstream failed",
		},
	]);

	const eventStream = stream(faux.getModel(), {
		messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
	});
	const events = await collectEvents(eventStream);

	// The text block is streamed in full, then the stream ends with "error" rather than "done".
	const typeSequence = events.map(({ type }) => type);
	expect(typeSequence).toEqual(["start", "text_start", "text_delta", "text_end", "error"]);

	const lastEvent = events[events.length - 1];
	expect(lastEvent.type).toBe("error");
	if (lastEvent.type !== "error") {
		// Unreachable when the expectation above holds; narrows the event type for the checks below.
		return;
	}
	expect(lastEvent.reason).toBe("error");
	expect(lastEvent.error.stopReason).toBe("error");
	expect(lastEvent.error.errorMessage).toBe("upstream failed");
});
it("streams an explicit assistant aborted message as a terminal error", async () => {
	const faux = registerFauxProvider({ tokenSize: { min: 2, max: 2 } });
	registrations.push(faux);

	// Script a reply that has text content but is explicitly marked as aborted.
	faux.setResponses([
		{
			...fauxAssistantMessage("partial"),
			stopReason: "aborted",
			errorMessage: "Request was aborted",
		},
	]);

	const eventStream = stream(faux.getModel(), {
		messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
	});
	const events = await collectEvents(eventStream);

	// The text block is streamed in full, then the stream terminates with an "error" event.
	const typeSequence = events.map(({ type }) => type);
	expect(typeSequence).toEqual(["start", "text_start", "text_delta", "text_end", "error"]);

	const lastEvent = events[events.length - 1];
	expect(lastEvent.type).toBe("error");
	if (lastEvent.type !== "error") {
		// Unreachable when the expectation above holds; narrows the event type for the checks below.
		return;
	}
	// The terminal event carries the "aborted" reason and the scripted message unchanged.
	expect(lastEvent.reason).toBe("aborted");
	expect(lastEvent.error.stopReason).toBe("aborted");
	expect(lastEvent.error.errorMessage).toBe("Request was aborted");
});
it("supports aborting before the first chunk", async () => {
	const faux = registerFauxProvider({ tokensPerSecond: 50, tokenSize: { min: 3, max: 3 } });
	registrations.push(faux);
	faux.setResponses([fauxAssistantMessage("abcdefghijklmnopqrstuvwxyz")]);

	// The signal is already aborted by the time the stream is requested.
	const preAborted = new AbortController();
	preAborted.abort();

	const eventStream = stream(
		faux.getModel(),
		{ messages: [{ role: "user", content: "hi", timestamp: Date.now() }] },
		{ signal: preAborted.signal },
	);
	const events = await collectEvents(eventStream);

	// Nothing but a single terminal error event is expected.
	expect(events).toHaveLength(1);
	const [onlyEvent] = events;
	expect(onlyEvent.type).toBe("error");
	if (onlyEvent.type !== "error") {
		// Unreachable when the expectation above holds; narrows the event type for the checks below.
		return;
	}
	expect(onlyEvent.reason).toBe("aborted");
	expect(onlyEvent.error.stopReason).toBe("aborted");
});
it("supports aborting mid-text stream when paced", async () => {
	const faux = registerFauxProvider({ tokensPerSecond: 100, tokenSize: { min: 3, max: 3 } });
	registrations.push(faux);
	faux.setResponses([fauxAssistantMessage("abcdefghijklmnopqrstuvwxyz")]);

	const abortController = new AbortController();
	const seenTypes: string[] = [];
	let deltasSeen = 0;

	const eventStream = stream(
		faux.getModel(),
		{ messages: [{ role: "user", content: "hi", timestamp: Date.now() }] },
		{ signal: abortController.signal },
	);
	for await (const { type } of eventStream) {
		seenTypes.push(type);
		if (type !== "text_delta") {
			continue;
		}
		// Abort from inside the consumer as soon as the first text delta arrives.
		deltasSeen += 1;
		abortController.abort();
	}

	// Exactly one delta got through, and the text block was never closed.
	expect(deltasSeen).toBe(1);
	for (const required of ["text_start", "text_delta", "error"]) {
		expect(seenTypes).toContain(required);
	}
	expect(seenTypes).not.toContain("text_end");
});
it("supports aborting mid-thinking stream when paced", async () => {
	const faux = registerFauxProvider({ tokensPerSecond: 100, tokenSize: { min: 3, max: 3 } });
	registrations.push(faux);

	// Start from a text message and replace its content with a single thinking block.
	faux.setResponses([
		{
			...fauxAssistantMessage("ignored"),
			content: [{ type: "thinking", thinking: "abcdefghijklmnopqrstuvwxyz" }],
		},
	]);

	const abortController = new AbortController();
	const seenTypes: string[] = [];
	let deltasSeen = 0;

	const eventStream = stream(
		faux.getModel(),
		{ messages: [{ role: "user", content: "hi", timestamp: Date.now() }] },
		{ signal: abortController.signal },
	);
	for await (const { type } of eventStream) {
		seenTypes.push(type);
		if (type !== "thinking_delta") {
			continue;
		}
		// Abort from inside the consumer as soon as the first thinking delta arrives.
		deltasSeen += 1;
		abortController.abort();
	}

	// Exactly one delta got through, and the thinking block was never closed.
	expect(deltasSeen).toBe(1);
	for (const required of ["thinking_start", "thinking_delta", "error"]) {
		expect(seenTypes).toContain(required);
	}
	expect(seenTypes).not.toContain("thinking_end");
});
// Verifies that aborting while tool call arguments are still streaming ends the
// stream with an "error" event and never emits "toolcall_end".
it("supports aborting mid-toolcall stream when paced", async () => {
const registration = registerFauxProvider({ tokensPerSecond: 100, tokenSize: { min: 3, max: 3 } });
// Track the registration in the suite-level list (presumably cleaned up in a teardown hook — confirm).
registrations.push(registration);
// Start from a text message and replace its content with one hand-built tool call
// whose arguments are long enough that more than one delta would be needed.
registration.setResponses([
{
...fauxAssistantMessage("done"),
content: [
{
type: "toolCall",
id: "tool-1",
name: "echo",
arguments: { text: "abcdefghijklmnopqrstuvwxyz", count: 123456789 },
},
],
stopReason: "toolUse",
},
]);
const controller = new AbortController();
const events: string[] = [];
let toolCallDeltaCount = 0;
const s = stream(
registration.getModel(),
{ messages: [{ role: "user", content: "hi", timestamp: Date.now() }] },
{ signal: controller.signal },
);
for await (const event of s) {
events.push(event.type);
if (event.type === "toolcall_delta") {
// Abort from inside the consumer as soon as the first argument delta arrives.
toolCallDeltaCount++;
controller.abort();
}
}
// Exactly one delta got through before the abort took effect.
expect(toolCallDeltaCount).toBe(1);
expect(events).toContain("toolcall_start");
expect(events).toContain("toolcall_delta");
expect(events).toContain("error");
// The tool call was interrupted, so it must not have been closed.
expect(events).not.toContain("toolcall_end");
});
it("unregisters the provider", async () => {
	const faux = registerFauxProvider();
	faux.setResponses([fauxAssistantMessage("hello")]);

	// Remove the provider before using its model; this test does not add the
	// registration to the suite-level list because it unregisters it itself.
	faux.unregister();

	// The model handle is still obtainable, but completing against it must reject.
	const pendingCompletion = complete(faux.getModel(), {
		messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
	});
	await expect(pendingCompletion).rejects.toThrow(`No API provider registered for api: ${faux.api}`);
});
});
+1 -1
View File
@@ -395,7 +395,7 @@ import { AuthStorage, createAgentSession, ModelRegistry, SessionManager } from "
const { session } = await createAgentSession({
sessionManager: SessionManager.inMemory(),
authStorage: AuthStorage.create(),
modelRegistry: new ModelRegistry(authStorage),
modelRegistry: ModelRegistry.create(authStorage),
});
await session.prompt("What files are in the current directory?");
+6 -6
View File
@@ -20,7 +20,7 @@ import { AuthStorage, createAgentSession, ModelRegistry, SessionManager } from "
// Set up credential storage and model registry
const authStorage = AuthStorage.create();
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
const { session } = await createAgentSession({
sessionManager: SessionManager.inMemory(),
@@ -289,7 +289,7 @@ import { getModel } from "@mariozechner/pi-ai";
import { AuthStorage, ModelRegistry } from "@mariozechner/pi-coding-agent";
const authStorage = AuthStorage.create();
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
// Find specific built-in model (doesn't check if API key exists)
const opus = getModel("anthropic", "claude-opus-4-5");
@@ -337,7 +337,7 @@ import { AuthStorage, ModelRegistry } from "@mariozechner/pi-coding-agent";
// Default: uses ~/.pi/agent/auth.json and ~/.pi/agent/models.json
const authStorage = AuthStorage.create();
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
const { session } = await createAgentSession({
sessionManager: SessionManager.inMemory(),
@@ -350,7 +350,7 @@ authStorage.setRuntimeApiKey("anthropic", "sk-my-temp-key");
// Custom auth storage location
const customAuth = AuthStorage.create("/my/app/auth.json");
const customRegistry = new ModelRegistry(customAuth, "/my/app/models.json");
const customRegistry = ModelRegistry.create(customAuth, "/my/app/models.json");
const { session } = await createAgentSession({
sessionManager: SessionManager.inMemory(),
@@ -359,7 +359,7 @@ const { session } = await createAgentSession({
});
// No custom models.json (built-in models only)
const simpleRegistry = new ModelRegistry(authStorage);
const simpleRegistry = ModelRegistry.inMemory(authStorage);
```
> See [examples/sdk/09-api-keys-and-oauth.ts](../examples/sdk/09-api-keys-and-oauth.ts)
@@ -788,7 +788,7 @@ if (process.env.MY_KEY) {
}
// Model registry (built-in models plus custom models from the default models.json)
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
// Inline tool
const statusTool: ToolDefinition = {
@@ -9,7 +9,7 @@ import { AuthStorage, createAgentSession, ModelRegistry } from "@mariozechner/pi
// Set up auth storage and model registry
const authStorage = AuthStorage.create();
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
// Option 1: Find a specific built-in model by provider/id
const opus = getModel("anthropic", "claude-opus-4-5");
@@ -9,7 +9,7 @@ import { AuthStorage, createAgentSession, ModelRegistry, SessionManager } from "
// Default: AuthStorage uses ~/.pi/agent/auth.json
// ModelRegistry loads built-in + custom models from ~/.pi/agent/models.json
const authStorage = AuthStorage.create();
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
await createAgentSession({
sessionManager: SessionManager.inMemory(),
@@ -20,7 +20,7 @@ console.log("Session with default auth storage and model registry");
// Custom auth storage location
const customAuthStorage = AuthStorage.create("/tmp/my-app/auth.json");
const customModelRegistry = new ModelRegistry(customAuthStorage, "/tmp/my-app/models.json");
const customModelRegistry = ModelRegistry.create(customAuthStorage, "/tmp/my-app/models.json");
await createAgentSession({
sessionManager: SessionManager.inMemory(),
@@ -39,7 +39,7 @@ await createAgentSession({
console.log("Session with runtime API key override");
// No models.json - only built-in models
const simpleRegistry = new ModelRegistry(authStorage); // null = no models.json
const simpleRegistry = ModelRegistry.inMemory(authStorage);
await createAgentSession({
sessionManager: SessionManager.inMemory(),
authStorage,
@@ -30,7 +30,7 @@ if (process.env.MY_ANTHROPIC_KEY) {
}
// Model registry with no custom models.json
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.inMemory(authStorage);
const model = getModel("anthropic", "claude-sonnet-4-20250514");
if (!model) throw new Error("Model not found");
+3 -3
View File
@@ -44,7 +44,7 @@ import {
// Auth and models setup
const authStorage = AuthStorage.create();
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
// Minimal
const { session } = await createAgentSession({ authStorage, modelRegistry });
@@ -73,7 +73,7 @@ const { session } = await createAgentSession({
// Full control
const customAuth = AuthStorage.create("/my/app/auth.json");
customAuth.setRuntimeApiKey("anthropic", process.env.MY_KEY!);
const customRegistry = new ModelRegistry(customAuth);
const customRegistry = ModelRegistry.create(customAuth);
const resourceLoader = new DefaultResourceLoader({
systemPromptOverride: () => "You are helpful.",
@@ -109,7 +109,7 @@ await session.prompt("Hello");
| Option | Default | Description |
|--------|---------|-------------|
| `authStorage` | `AuthStorage.create()` | Credential storage |
| `modelRegistry` | `new ModelRegistry(authStorage)` | Model registry |
| `modelRegistry` | `ModelRegistry.create(authStorage)` | Model registry |
| `cwd` | `process.cwd()` | Working directory |
| `agentDir` | `~/.pi/agent` | Config directory |
| `model` | From settings/first available | Model to use |
@@ -259,13 +259,21 @@ export class ModelRegistry {
private registeredProviders: Map<string, ProviderConfigInput> = new Map();
private loadError: string | undefined = undefined;
constructor(
private constructor(
readonly authStorage: AuthStorage,
private modelsJsonPath: string | undefined = join(getAgentDir(), "models.json"),
private modelsJsonPath: string | undefined,
) {
this.loadModels();
}
static create(authStorage: AuthStorage, modelsJsonPath: string = join(getAgentDir(), "models.json")): ModelRegistry {
return new ModelRegistry(authStorage, modelsJsonPath);
}
static inMemory(authStorage: AuthStorage): ModelRegistry {
return new ModelRegistry(authStorage, undefined);
}
/**
* Reload models from disk (built-in + custom from models.json).
*/
+2 -2
View File
@@ -47,7 +47,7 @@ export interface CreateAgentSessionOptions {
/** Auth storage for credentials. Default: AuthStorage.create(agentDir/auth.json) */
authStorage?: AuthStorage;
/** Model registry. Default: new ModelRegistry(authStorage, agentDir/models.json) */
/** Model registry. Default: ModelRegistry.create(authStorage, agentDir/models.json) */
modelRegistry?: ModelRegistry;
/** Model to use. Default: from settings, else first available */
@@ -172,7 +172,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
const authPath = options.agentDir ? join(agentDir, "auth.json") : undefined;
const modelsPath = options.agentDir ? join(agentDir, "models.json") : undefined;
const authStorage = options.authStorage ?? AuthStorage.create(authPath);
const modelRegistry = options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath);
const modelRegistry = options.modelRegistry ?? ModelRegistry.create(authStorage, modelsPath);
const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir);
const sessionManager = options.sessionManager ?? SessionManager.create(cwd, getDefaultSessionDir(cwd, agentDir));
+1 -1
View File
@@ -654,7 +654,7 @@ export async function main(args: string[]) {
const settingsManager = SettingsManager.create(cwd, agentDir);
reportSettingsErrors(settingsManager, "startup");
const authStorage = AuthStorage.create();
const modelRegistry = new ModelRegistry(authStorage, getModelsPath());
const modelRegistry = ModelRegistry.create(authStorage, getModelsPath());
const resourceLoader = new DefaultResourceLoader({
cwd,
@@ -76,7 +76,7 @@ describe("AgentSession auto-compaction queue resume", () => {
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
authStorage.setRuntimeApiKey("anthropic", "test-key");
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
session = new AgentSession({
agent,
@@ -55,7 +55,7 @@ describe.skipIf(!API_KEY)("AgentSession forking", () => {
sessionManager = noSession ? SessionManager.inMemory() : SessionManager.create(tempDir);
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
session = new AgentSession({
agent,
@@ -61,7 +61,7 @@ describe.skipIf(!API_KEY)("AgentSession compaction e2e", () => {
// Use minimal keepRecentTokens so small test conversations have something to summarize
settingsManager.applyOverrides({ compaction: { keepRecentTokens: 1 } });
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
session = new AgentSession({
agent,
@@ -110,7 +110,7 @@ describe("AgentSession concurrent prompt guard", () => {
const sessionManager = SessionManager.inMemory();
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
// Set a runtime API key so validation passes
authStorage.setRuntimeApiKey("anthropic", "test-key");
@@ -235,7 +235,7 @@ describe("AgentSession concurrent prompt guard", () => {
const sessionManager = SessionManager.inMemory();
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
authStorage.setRuntimeApiKey("anthropic", "test-key");
const extensionsResult = await createTestExtensionsResult([
@@ -313,7 +313,7 @@ describe("AgentSession concurrent prompt guard", () => {
const sessionManager = SessionManager.inMemory();
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
authStorage.setRuntimeApiKey("anthropic", "test-key");
session = new AgentSession({
@@ -419,7 +419,7 @@ describe("AgentSession concurrent prompt guard", () => {
const sessionManager = SessionManager.inMemory();
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
authStorage.setRuntimeApiKey("anthropic", "test-key");
session = new AgentSession({
@@ -556,7 +556,7 @@ describe("AgentSession concurrent prompt guard", () => {
const sessionManager = SessionManager.inMemory();
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
authStorage.setRuntimeApiKey("anthropic", "test-key");
session = new AgentSession({
@@ -37,7 +37,7 @@ function createSession({
sessionManager,
settingsManager,
cwd: process.cwd(),
modelRegistry: new ModelRegistry(authStorage, undefined),
modelRegistry: ModelRegistry.inMemory(authStorage),
resourceLoader: createTestResourceLoader(),
scopedModels,
});
@@ -102,7 +102,7 @@ describe("AgentSession retry", () => {
const sessionManager = SessionManager.inMemory();
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
authStorage.setRuntimeApiKey("anthropic", "test-key");
settingsManager.applyOverrides({ retry: { enabled: true, maxRetries, baseDelayMs: 1 } });
@@ -204,7 +204,7 @@ describe("AgentSession retry", () => {
const sessionManager = SessionManager.inMemory();
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
authStorage.setRuntimeApiKey("anthropic", "test-key");
settingsManager.applyOverrides({ retry: { enabled: true, maxRetries: 3, baseDelayMs: 1 } });
session = new AgentSession({
@@ -289,7 +289,7 @@ describe("AgentSession retry", () => {
const sessionManager = SessionManager.inMemory();
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
authStorage.setRuntimeApiKey("anthropic", "test-key");
settingsManager.applyOverrides({ retry: { enabled: true, maxRetries: 3, baseDelayMs: 1 } });
@@ -66,7 +66,7 @@ function createSession() {
sessionManager,
settingsManager,
cwd: process.cwd(),
modelRegistry: new ModelRegistry(authStorage, undefined),
modelRegistry: ModelRegistry.inMemory(authStorage),
resourceLoader: createTestResourceLoader(),
});
@@ -99,7 +99,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => {
const sessionManager = SessionManager.create(tempDir);
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
const runtime = createExtensionRuntime();
const resourceLoader = {
@@ -81,7 +81,7 @@ describe.skipIf(!HAS_ANTIGRAVITY_AUTH)("Compaction with thinking models (Antigra
// settingsManager.applyOverrides({ compaction: { keepRecentTokens: 1 } });
const authStorage = getRealAuthStorage();
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
session = new AgentSession({
agent,
@@ -177,7 +177,7 @@ describe.skipIf(!HAS_ANTHROPIC_AUTH)("Compaction with thinking models (Anthropic
const settingsManager = SettingsManager.create(tempDir, tempDir);
const authStorage = getRealAuthStorage();
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
session = new AgentSession({
agent,
@@ -29,7 +29,7 @@ describe("Input Event", () => {
for (let i = 0; i < extensions.length; i++) fs.writeFileSync(path.join(extensionsDir, `e${i}.ts`), extensions[i]);
const result = await discoverAndLoadExtensions([], tempDir, tempDir);
const sm = SessionManager.inMemory();
const mr = new ModelRegistry(AuthStorage.create(path.join(tempDir, "auth.json")));
const mr = ModelRegistry.create(AuthStorage.create(path.join(tempDir, "auth.json")));
return new ExtensionRunner(result.extensions, result.runtime, tempDir, sm, mr);
}
@@ -27,7 +27,7 @@ describe("ExtensionRunner", () => {
fs.mkdirSync(extensionsDir);
sessionManager = SessionManager.inMemory();
const authStorage = AuthStorage.create(path.join(tempDir, "auth.json"));
modelRegistry = new ModelRegistry(authStorage);
modelRegistry = ModelRegistry.create(authStorage);
});
afterEach(() => {
@@ -94,7 +94,7 @@ describe("ModelRegistry", () => {
anthropic: overrideConfig("https://my-proxy.example.com/v1"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const anthropicModels = getModelsForProvider(registry, "anthropic");
// Should have multiple built-in models, not just one
@@ -107,7 +107,7 @@ describe("ModelRegistry", () => {
anthropic: overrideConfig("https://my-proxy.example.com/v1"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const anthropicModels = getModelsForProvider(registry, "anthropic");
// All models should have the new baseUrl
@@ -123,7 +123,7 @@ describe("ModelRegistry", () => {
}),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const anthropicModels = getModelsForProvider(registry, "anthropic");
for (const model of anthropicModels) {
@@ -140,7 +140,7 @@ describe("ModelRegistry", () => {
anthropic: overrideConfig("https://my-proxy.example.com/v1"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const googleModels = getModelsForProvider(registry, "google");
// Google models should still have their original baseUrl
@@ -160,7 +160,7 @@ describe("ModelRegistry", () => {
),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
// Anthropic: multiple built-in models with new baseUrl
const anthropicModels = getModelsForProvider(registry, "anthropic");
@@ -177,7 +177,7 @@ describe("ModelRegistry", () => {
writeRawModelsJson({
anthropic: overrideConfig("https://first-proxy.example.com/v1"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
expect(getModelsForProvider(registry, "anthropic")[0].baseUrl).toBe("https://first-proxy.example.com/v1");
@@ -197,7 +197,7 @@ describe("ModelRegistry", () => {
anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const anthropicModels = getModelsForProvider(registry, "anthropic");
expect(anthropicModels.length).toBeGreaterThan(1);
@@ -214,7 +214,7 @@ describe("ModelRegistry", () => {
),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const models = getModelsForProvider(registry, "openrouter");
const sonnetModels = models.filter((m) => m.id === "anthropic/claude-sonnet-4");
@@ -227,7 +227,7 @@ describe("ModelRegistry", () => {
anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
expect(getModelsForProvider(registry, "google").length).toBeGreaterThan(0);
expect(getModelsForProvider(registry, "openai").length).toBeGreaterThan(0);
@@ -238,7 +238,7 @@ describe("ModelRegistry", () => {
anthropic: providerConfig("https://merged-proxy.example.com/v1", [{ id: "claude-custom" }]),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const anthropicModels = getModelsForProvider(registry, "anthropic");
for (const model of anthropicModels) {
@@ -269,7 +269,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const compat = registry.find("demo", "demo-model")?.compat as OpenAICompletionsCompat | undefined;
expect(compat?.supportsUsageInStreaming).toBe(false);
@@ -303,7 +303,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const compat = registry.find("demo", "demo-model")?.compat as OpenAICompletionsCompat | undefined;
expect(compat?.supportsUsageInStreaming).toBe(true);
@@ -320,7 +320,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const models = getModelsForProvider(registry, "openrouter");
expect(models.length).toBeGreaterThan(0);
@@ -357,7 +357,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const compat = registry.find("demo", "demo-model")?.compat as OpenAICompletionsCompat | undefined;
expect(registry.getError()).toBeUndefined();
@@ -394,7 +394,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const m25 = registry.find("opencode-go", "minimax-m2.5");
const glm5 = registry.find("opencode-go", "glm-5");
@@ -427,7 +427,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const models = getModelsForProvider(registry, "openrouter");
expect(models.some((m) => m.id === "custom/openrouter-model")).toBe(true);
@@ -440,7 +440,7 @@ describe("ModelRegistry", () => {
writeModelsJson({
anthropic: providerConfig("https://first-proxy.example.com/v1", [{ id: "claude-custom" }]),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
expect(getModelsForProvider(registry, "anthropic").some((m) => m.id === "claude-custom")).toBe(true);
// Update and refresh
@@ -459,7 +459,7 @@ describe("ModelRegistry", () => {
writeModelsJson({
anthropic: providerConfig("https://proxy.example.com/v1", [{ id: "claude-custom" }]),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
expect(getModelsForProvider(registry, "anthropic").some((m) => m.id === "claude-custom")).toBe(true);
// Remove custom models and refresh
@@ -484,7 +484,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const models = getModelsForProvider(registry, "openrouter");
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
@@ -508,7 +508,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const models = getModelsForProvider(registry, "openrouter");
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
@@ -529,7 +529,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const models = getModelsForProvider(registry, "openrouter");
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
@@ -552,7 +552,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const models = getModelsForProvider(registry, "openrouter");
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
@@ -576,7 +576,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const models = getModelsForProvider(registry, "openrouter");
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
@@ -601,7 +601,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const models = getModelsForProvider(registry, "openrouter");
// Should not create a new model
@@ -621,7 +621,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const models = getModelsForProvider(registry, "openrouter");
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
@@ -642,7 +642,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const models = getModelsForProvider(registry, "openrouter");
const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4");
expect(sonnet).toBeDefined();
@@ -665,7 +665,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
expect(
getModelsForProvider(registry, "openrouter").find((m) => m.id === "anthropic/claude-sonnet-4")?.name,
).toBe("First Name");
@@ -698,7 +698,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const customName = getModelsForProvider(registry, "openrouter").find(
(m) => m.id === "anthropic/claude-sonnet-4",
)?.name;
@@ -717,7 +717,7 @@ describe("ModelRegistry", () => {
describe("dynamic provider lifecycle", () => {
test("failed registerProvider does not persist invalid streamSimple config", () => {
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
expect(() =>
registry.registerProvider("broken-provider", {
@@ -731,7 +731,7 @@ describe("ModelRegistry", () => {
});
test("failed registerProvider does not remove existing provider models", () => {
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
registry.registerProvider("demo-provider", {
baseUrl: "https://provider.test/v1",
@@ -776,7 +776,7 @@ describe("ModelRegistry", () => {
});
test("unregisterProvider removes custom OAuth provider and restores built-in OAuth provider", () => {
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
registry.registerProvider("anthropic", {
oauth: {
@@ -799,7 +799,7 @@ describe("ModelRegistry", () => {
});
test("unregisterProvider removes custom streamSimple override and restores built-in API stream handler", () => {
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
registry.registerProvider("stream-override-provider", {
api: "openai-completions",
@@ -855,7 +855,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey("!echo test-api-key-from-command"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const apiKey = await registry.getApiKeyForProvider("custom-provider");
expect(apiKey).toBe("test-api-key-from-command");
@@ -866,7 +866,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey("!echo ' spaced-key '"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const apiKey = await registry.getApiKeyForProvider("custom-provider");
expect(apiKey).toBe("spaced-key");
@@ -877,7 +877,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey("!printf 'line1\\nline2'"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const apiKey = await registry.getApiKeyForProvider("custom-provider");
expect(apiKey).toBe("line1\nline2");
@@ -888,7 +888,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey("!exit 1"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const apiKey = await registry.getApiKeyForProvider("custom-provider");
expect(apiKey).toBeUndefined();
@@ -899,7 +899,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey("!nonexistent-command-12345"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const apiKey = await registry.getApiKeyForProvider("custom-provider");
expect(apiKey).toBeUndefined();
@@ -910,7 +910,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey("!printf ''"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const apiKey = await registry.getApiKeyForProvider("custom-provider");
expect(apiKey).toBeUndefined();
@@ -925,7 +925,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey("TEST_API_KEY_12345"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const apiKey = await registry.getApiKeyForProvider("custom-provider");
expect(apiKey).toBe("env-api-key-value");
@@ -946,7 +946,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey("literal_api_key_value"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const apiKey = await registry.getApiKeyForProvider("custom-provider");
expect(apiKey).toBe("literal_api_key_value");
@@ -957,7 +957,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey("!echo 'hello world' | tr ' ' '-'"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const apiKey = await registry.getApiKeyForProvider("custom-provider");
expect(apiKey).toBe("hello-world");
@@ -974,7 +974,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey(command),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
await registry.getApiKeyForProvider("custom-provider");
await registry.getApiKeyForProvider("custom-provider");
await registry.getApiKeyForProvider("custom-provider");
@@ -993,10 +993,10 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey(command),
});
const registry1 = new ModelRegistry(authStorage, modelsJsonPath);
const registry1 = ModelRegistry.create(authStorage, modelsJsonPath);
await registry1.getApiKeyForProvider("custom-provider");
const registry2 = new ModelRegistry(authStorage, modelsJsonPath);
const registry2 = ModelRegistry.create(authStorage, modelsJsonPath);
await registry2.getApiKeyForProvider("custom-provider");
const count = parseInt(readFileSync(counterFile, "utf-8").trim(), 10);
@@ -1009,7 +1009,7 @@ describe("ModelRegistry", () => {
"provider-b": providerWithApiKey("!echo key-b"),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const keyA = await registry.getApiKeyForProvider("provider-a");
const keyB = await registry.getApiKeyForProvider("provider-b");
@@ -1028,7 +1028,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey(command),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const key1 = await registry.getApiKeyForProvider("custom-provider");
const key2 = await registry.getApiKeyForProvider("custom-provider");
@@ -1050,7 +1050,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey(envVarName),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const key1 = await registry.getApiKeyForProvider("custom-provider");
expect(key1).toBe("first-value");
@@ -1078,7 +1078,7 @@ describe("ModelRegistry", () => {
"custom-provider": providerWithApiKey(command),
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const available = registry.getAvailable();
expect(available.some((m) => m.provider === "custom-provider")).toBe(true);
@@ -1098,7 +1098,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const model = registry.find("custom-provider", "test-model");
expect(model).toBeDefined();
@@ -1127,7 +1127,7 @@ describe("ModelRegistry", () => {
},
});
const registry = new ModelRegistry(authStorage, modelsJsonPath);
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
const model = registry.find("custom-provider", "test-model");
expect(model).toBeDefined();
@@ -199,7 +199,7 @@ Project skill`,
const sessionManager = SessionManager.inMemory();
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
const runner = new ExtensionRunner(
extensionsResult.extensions,
extensionsResult.runtime,
@@ -548,7 +548,7 @@ export default function(pi: ExtensionAPI) {
const sessionManager = SessionManager.inMemory();
const authStorage = AuthStorage.create(join(tempDir, "auth-explicit.json"));
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
const runner = new ExtensionRunner(
extensionsResult.extensions,
extensionsResult.runtime,
@@ -204,7 +204,7 @@ async function main(): Promise<void> {
mkdirSync(dirname(args.sessionPath), { recursive: true });
const authStorage = AuthStorage.create();
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
const model = getModel("openai-codex", "gpt-5.4");
if (!model) {
+141
View File
@@ -0,0 +1,141 @@
/**
* Local test harness for the new coding-agent test suite.
*/
import { existsSync, mkdirSync, rmSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import type { AgentTool } from "@mariozechner/pi-agent-core";
import { Agent } from "@mariozechner/pi-agent-core";
import type { FauxModelDefinition, FauxProviderRegistration, FauxResponseStep, Model } from "@mariozechner/pi-ai";
import { registerFauxProvider } from "@mariozechner/pi-ai";
import { AgentSession, type AgentSessionEvent } from "../../src/core/agent-session.js";
import { AuthStorage } from "../../src/core/auth-storage.js";
import { ModelRegistry } from "../../src/core/model-registry.js";
import { SessionManager } from "../../src/core/session-manager.js";
import type { Settings } from "../../src/core/settings-manager.js";
import { SettingsManager } from "../../src/core/settings-manager.js";
import type { ExtensionFactory, ResourceLoader } from "../../src/index.js";
import {
type CreateTestExtensionsResultInput,
createTestExtensionsResult,
createTestResourceLoader,
} from "../utilities.js";
/**
 * Optional knobs accepted by `createHarness`; every field may be omitted.
 */
export interface HarnessOptions {
/** Faux model definitions forwarded as `models` to `registerFauxProvider`. */
models?: FauxModelDefinition[];
/** Settings passed to `SettingsManager.inMemory`. */
settings?: Partial<Settings>;
/** System prompt for the agent; `createHarness` falls back to "You are a test assistant." */
systemPrompt?: string;
/** Tools keyed by their `name` and handed to the session as `baseToolsOverride`. */
tools?: AgentTool[];
/** Resource loader to use instead of the one built by `createTestResourceLoader`. */
resourceLoader?: ResourceLoader;
/**
 * Inputs passed to `createTestExtensionsResult` together with the harness temp dir.
 * NOTE(review): the result is only handed to the default resource loader, so it looks
 * unused when `resourceLoader` is also supplied — confirm this is intended.
 */
extensionFactories?: Array<ExtensionFactory | CreateTestExtensionsResultInput>;
}
/**
 * Everything `createHarness` returns: the session under test, its collaborators,
 * pass-through controls for the faux provider, and captured session events.
 */
export interface Harness {
/** The `AgentSession` constructed by the harness. */
session: AgentSession;
/** Session manager created with `SessionManager.inMemory()`. */
sessionManager: SessionManager;
/** Settings manager created with `SettingsManager.inMemory(options.settings)`. */
settingsManager: SettingsManager;
/** The registration object returned by `registerFauxProvider`. */
faux: FauxProviderRegistration;
/** The faux provider's models (`faux.models`); the tuple type guarantees at least one entry. */
models: [Model<string>, ...Model<string>[]];
/** Same function as `faux.getModel`, called without an id. */
getModel(): Model<string>;
/** Same function as `faux.getModel`; may return `undefined` when called with an id. */
getModel(modelId: string): Model<string> | undefined;
/** Same function as `faux.setResponses`. */
setResponses: (responses: FauxResponseStep[]) => void;
/** Same function as `faux.appendResponses`. */
appendResponses: (responses: FauxResponseStep[]) => void;
/** Same function as `faux.getPendingResponseCount`. */
getPendingResponseCount: () => number;
/** Every event delivered to the harness's `session.subscribe` listener, in arrival order. */
events: AgentSessionEvent[];
/** Returns the entries of `events` whose `type` equals the given type, narrowed accordingly. */
eventsOfType<T extends AgentSessionEvent["type"]>(type: T): Extract<AgentSessionEvent, { type: T }>[];
/** Temp directory created for this harness; also used as the session's `cwd`. */
tempDir: string;
/** Disposes the session, unregisters the faux provider, and removes `tempDir` if it still exists. */
cleanup: () => void;
}
/**
 * Creates a fresh scratch directory under the OS temp dir and returns its path.
 *
 * The directory name is `pi-suite-<epoch millis>-<random base-36 suffix>`.
 */
function createTempDir(): string {
	const stamp = Date.now();
	const suffix = Math.random().toString(36).slice(2);
	const dirPath = join(tmpdir(), `pi-suite-${stamp}-${suffix}`);
	mkdirSync(dirPath, { recursive: true });
	return dirPath;
}
/**
 * Builds a self-contained `AgentSession` test harness backed by a faux provider.
 *
 * Setup steps, in order: create a temp directory, register a faux provider with
 * `options.models`, reset its responses to an empty list, construct an `Agent` on the
 * provider's model, create in-memory session/settings/auth managers, mirror the faux
 * models into an in-memory `ModelRegistry`, optionally build test extensions, construct
 * the `AgentSession` (with the temp dir as `cwd`), and subscribe a listener that records
 * every session event.
 *
 * If any setup step throws or rejects, the caller never receives a harness and therefore
 * cannot call `cleanup()`. To avoid leaking state in that case, the session (if already
 * created) is disposed, the faux provider is unregistered, and the temp directory is
 * removed before the original error is rethrown.
 *
 * @param options Optional overrides; see `HarnessOptions`.
 * @returns The assembled harness. Callers should invoke `cleanup()` when done.
 * @throws Whatever a setup step throws, after the rollback described above.
 */
export async function createHarness(options: HarnessOptions = {}): Promise<Harness> {
	const tempDir = createTempDir();

	// Shared by cleanup() and the setup-failure rollback.
	const removeTempDir = (): void => {
		if (existsSync(tempDir)) {
			rmSync(tempDir, { recursive: true });
		}
	};

	// Tracked outside the try block so the catch handler knows what needs undoing.
	let registeredProvider: FauxProviderRegistration | undefined;
	let createdSession: AgentSession | undefined;

	try {
		const fauxProvider: FauxProviderRegistration = registerFauxProvider({
			models: options.models,
		});
		registeredProvider = fauxProvider;
		fauxProvider.setResponses([]);
		const model = fauxProvider.getModel();

		const toolMap = options.tools ? Object.fromEntries(options.tools.map((tool) => [tool.name, tool])) : undefined;

		const agent = new Agent({
			getApiKey: () => "faux-key",
			initialState: {
				model,
				systemPrompt: options.systemPrompt ?? "You are a test assistant.",
				tools: [],
			},
		});

		const sessionManager = SessionManager.inMemory();
		const settingsManager = SettingsManager.inMemory(options.settings);
		const authStorage = AuthStorage.inMemory();
		authStorage.setRuntimeApiKey(model.provider, "faux-key");

		// Mirror the faux provider's models into the registry handed to the session.
		const modelRegistry = ModelRegistry.inMemory(authStorage);
		modelRegistry.registerProvider(model.provider, {
			baseUrl: model.baseUrl,
			apiKey: "faux-key",
			api: fauxProvider.api,
			models: fauxProvider.models.map((registeredModel) => ({
				id: registeredModel.id,
				name: registeredModel.name,
				api: registeredModel.api,
				reasoning: registeredModel.reasoning,
				input: registeredModel.input,
				cost: registeredModel.cost,
				contextWindow: registeredModel.contextWindow,
				maxTokens: registeredModel.maxTokens,
				baseUrl: registeredModel.baseUrl,
			})),
		});

		const extensionsResult = options.extensionFactories
			? await createTestExtensionsResult(options.extensionFactories, tempDir)
			: undefined;
		// An explicit resourceLoader wins; extensionsResult only feeds the default loader.
		const resourceLoader =
			options.resourceLoader ?? createTestResourceLoader(extensionsResult ? { extensionsResult } : undefined);

		const session = new AgentSession({
			agent,
			sessionManager,
			settingsManager,
			cwd: tempDir,
			modelRegistry,
			resourceLoader,
			baseToolsOverride: toolMap,
		});
		createdSession = session;

		const events: AgentSessionEvent[] = [];
		session.subscribe((event) => {
			events.push(event);
		});

		return {
			session,
			sessionManager,
			settingsManager,
			faux: fauxProvider,
			models: fauxProvider.models,
			// NOTE(review): these are passed through unbound; assumes the registration's
			// functions do not rely on `this` — confirm against registerFauxProvider.
			getModel: fauxProvider.getModel,
			setResponses: fauxProvider.setResponses,
			appendResponses: fauxProvider.appendResponses,
			getPendingResponseCount: fauxProvider.getPendingResponseCount,
			events,
			eventsOfType<T extends AgentSessionEvent["type"]>(type: T) {
				return events.filter((event): event is Extract<AgentSessionEvent, { type: T }> => event.type === type);
			},
			tempDir,
			cleanup() {
				session.dispose();
				fauxProvider.unregister();
				removeTempDir();
			},
		};
	} catch (error) {
		try {
			createdSession?.dispose();
			registeredProvider?.unregister();
			removeTempDir();
		} catch {
			// Best-effort rollback: the original setup error is the one worth reporting.
		}
		throw error;
	}
}
+1 -1
View File
@@ -390,7 +390,7 @@ function createHarnessWithResourceLoader(
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
authStorage.setRuntimeApiKey(model.provider, "faux-key");
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
const session = new AgentSession({
agent,
+1 -1
View File
@@ -254,7 +254,7 @@ export function createTestSession(options: TestSessionOptions = {}): TestSession
}
const authStorage = AuthStorage.create(join(tempDir, "auth.json"));
const modelRegistry = new ModelRegistry(authStorage, tempDir);
const modelRegistry = ModelRegistry.create(authStorage, tempDir);
const session = new AgentSession({
agent,
+1 -1
View File
@@ -429,7 +429,7 @@ function createRunner(sandboxConfig: SandboxConfig, channelId: string, channelDi
// Create AuthStorage and ModelRegistry
// Auth stored outside workspace so agent can't access it
const authStorage = AuthStorage.create(join(homedir(), ".pi", "mom", "auth.json"));
const modelRegistry = new ModelRegistry(authStorage);
const modelRegistry = ModelRegistry.create(authStorage);
// Create agent
const agent = new Agent({