diff --git a/packages/ai/README.md b/packages/ai/README.md index 9c65c78fa..a54f06648 100644 --- a/packages/ai/README.md +++ b/packages/ai/README.md @@ -636,6 +636,92 @@ The library uses a registry of API implementations. Built-in APIs include: - **`azure-openai-responses`**: Azure OpenAI Responses API (`streamAzureOpenAIResponses`, `AzureOpenAIResponsesOptions`) - **`bedrock-converse-stream`**: Amazon Bedrock Converse API (`streamBedrock`, `BedrockOptions`) +### Faux provider for tests + +`registerFauxProvider()` registers a temporary in-memory provider for tests and demos. It is opt-in and not part of the built-in provider set. + +```typescript +import { + complete, + fauxAssistantMessage, + fauxText, + fauxThinking, + fauxToolCall, + registerFauxProvider, + stream, +} from '@mariozechner/pi-ai'; + +const registration = registerFauxProvider({ + tokensPerSecond: 50 // optional +}); + +const model = registration.getModel(); +const context = { + messages: [{ role: 'user', content: 'Summarize package.json and then call echo', timestamp: Date.now() }] +}; + +registration.setResponses([ + fauxAssistantMessage([ + fauxThinking('Need to inspect package metadata first.'), + fauxToolCall('echo', { text: 'package.json' }) + ], { stopReason: 'toolUse' }) +]); + +const first = await complete(model, context, { + sessionId: 'session-1', + cacheRetention: 'short' +}); +context.messages.push(first); + +context.messages.push({ + role: 'toolResult', + toolCallId: first.content.find((block) => block.type === 'toolCall')!.id, + toolName: 'echo', + content: [{ type: 'text', text: 'package.json contents here' }], + isError: false, + timestamp: Date.now() +}); + +registration.setResponses([ + fauxAssistantMessage([ + fauxThinking('Now I can summarize the tool output.'), + fauxText('Here is the summary.') + ]) +]); + +const s = stream(model, context); +for await (const event of s) { + console.log(event.type); +} + +// Optional: register multiple faux models for model-switching tests 
+const multiModel = registerFauxProvider({ + models: [ + { id: 'faux-fast', reasoning: false }, + { id: 'faux-thinker', reasoning: true } + ] +}); +const thinker = multiModel.getModel('faux-thinker'); + +console.log(thinker?.reasoning); +console.log(registration.getPendingResponseCount()); +console.log(registration.state.callCount); +registration.unregister(); +multiModel.unregister(); +``` + +Notes: +- Responses are consumed from a queue in request start order. +- If the queue is empty, the faux provider returns an assistant error message with `errorMessage: "No more faux responses queued"`. +- Use `registration.setResponses([...])` to replace the remaining queue and `registration.appendResponses([...])` to add more responses. +- `registration.models` exposes all registered faux models. `registration.getModel()` returns the first one, and `registration.getModel(id)` returns a specific one. +- Use `fauxAssistantMessage(...)` for scripted assistant replies. Use `fauxText(...)`, `fauxThinking(...)`, and `fauxToolCall(...)` to build content blocks without filling in low-level fields manually. +- `registration.unregister()` removes the temporary provider from the global API registry. +- Usage is estimated at roughly 1 token per 4 characters. When `sessionId` is present and `cacheRetention` is not `"none"`, prompt cache reads and writes are simulated automatically. +- Tool call arguments stream incrementally via `toolcall_delta` chunks. +- By default, each streamed chunk is emitted on its own microtask. Set `tokensPerSecond` to pace chunk delivery in real time. +- The intended use is one deterministic scripted flow per registration. If you need independent concurrent flows, register separate faux providers. + ### Providers and Models A **provider** offers models through a specific API. 
// Companion change in packages/ai/src/index.ts: the package entry point adds
//   export * from "./providers/faux.js";
// so every symbol exported from this module is re-exported by the package.

import { registerApiProvider, unregisterApiProviders } from "../api-registry.js";
import type {
  AssistantMessage,
  AssistantMessageEventStream,
  Context,
  ImageContent,
  Message,
  Model,
  SimpleStreamOptions,
  StreamFunction,
  StreamOptions,
  TextContent,
  ThinkingContent,
  ToolCall,
  ToolResultMessage,
  Usage,
} from "../types.js";
import { createAssistantMessageEventStream } from "../utils/event-stream.js";

/** API and provider identifiers stamped on messages built by `fauxAssistantMessage`. */
const DEFAULT_API = "faux";
const DEFAULT_PROVIDER = "faux";
/** Id and display name of the single model registered when no models are configured. */
const DEFAULT_MODEL_ID = "faux-1";
const DEFAULT_MODEL_NAME = "Faux Model";
/** Placeholder base URL assigned to every faux model. */
const DEFAULT_BASE_URL = "http://localhost:0";
/** Default bounds for the random streamed chunk size; one unit is worth four characters. */
const DEFAULT_MIN_TOKEN_SIZE = 3;
const DEFAULT_MAX_TOKEN_SIZE = 5;

/**
 * All-zero usage template. It is a single mutable object, so hand out copies
 * rather than this instance.
 */
const DEFAULT_USAGE: Usage = {
  input: 0,
  output: 0,
  cacheRead: 0,
  cacheWrite: 0,
  totalTokens: 0,
  cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
};

/** Returns a fresh all-zero usage record that shares nothing with `DEFAULT_USAGE`. */
function emptyUsage(): Usage {
  return { ...DEFAULT_USAGE, cost: { ...DEFAULT_USAGE.cost } };
}

/**
 * Description of one model exposed by a faux provider. Only `id` is required;
 * `registerFauxProvider` fills in every omitted field.
 */
export interface FauxModelDefinition {
  id: string;
  /** Display name; the id is used when omitted. */
  name?: string;
  reasoning?: boolean;
  input?: ("text" | "image")[];
  cost?: { input: number; output: number; cacheRead: number; cacheWrite: number };
  contextWindow?: number;
  maxTokens?: number;
}

/** Any content block a scripted assistant message may contain. */
export type FauxContentBlock = TextContent | ThinkingContent | ToolCall;

/** Builds a text content block. */
export function fauxText(text: string): TextContent {
  return { type: "text", text };
}

/** Builds a thinking content block. */
export function fauxThinking(thinking: string): ThinkingContent {
  return { type: "thinking", thinking };
}

/**
 * Builds a tool-call content block.
 *
 * @param name - Name of the tool being called.
 * @param arguments_ - Arguments object for the call.
 * @param options - Pass `id` for a fixed tool-call id; otherwise a random `tool:…` id is generated.
 */
export function fauxToolCall(name: string, arguments_: ToolCall["arguments"], options: { id?: string } = {}): ToolCall {
  return {
    type: "toolCall",
    id: options.id ?? randomId("tool"),
    name,
    arguments: arguments_,
  };
}

/** Normalizes the accepted content shorthands (string, single block, block array) to a block array. */
function normalizeFauxAssistantContent(content: string | FauxContentBlock | FauxContentBlock[]): FauxContentBlock[] {
  if (typeof content === "string") {
    return [fauxText(content)];
  }
  return Array.isArray(content) ? content : [content];
}

/**
 * Builds a scripted assistant message.
 *
 * @param content - A plain string (becomes one text block), a single block, or an array of blocks.
 * @param options - Optional `stopReason` (defaults to `"stop"`), `errorMessage`, `responseId`
 *   and `timestamp` (defaults to `Date.now()`).
 * @returns A message tagged with the default faux api, provider and model id, with zeroed usage.
 */
export function fauxAssistantMessage(
  content: string | FauxContentBlock | FauxContentBlock[],
  options: {
    stopReason?: AssistantMessage["stopReason"];
    errorMessage?: string;
    responseId?: string;
    timestamp?: number;
  } = {},
): AssistantMessage {
  return {
    role: "assistant",
    content: normalizeFauxAssistantContent(content),
    api: DEFAULT_API,
    provider: DEFAULT_PROVIDER,
    model: DEFAULT_MODEL_ID,
    // A fresh copy per message: sharing DEFAULT_USAGE would let a mutation of one
    // message's usage leak into every other message and into the template itself.
    usage: emptyUsage(),
    stopReason: options.stopReason ?? "stop",
    errorMessage: options.errorMessage,
    responseId: options.responseId,
    timestamp: options.timestamp ?? Date.now(),
  };
}
// NOTE(review): this copy of the source shows `Model` without type arguments. If `Model`
// is generic in ../types.js, its arguments were lost and need restoring — confirm.

/**
 * Produces a response lazily when a request arrives. It receives the request context,
 * the stream options, the registration's shared call counter and the requested model.
 * May return the message directly or a promise of it.
 */
export type FauxResponseFactory = (
  context: Context,
  options: StreamOptions | undefined,
  state: { callCount: number },
  model: Model,
) => AssistantMessage | Promise<AssistantMessage>;

/** One queued response: a ready-made message or a factory that builds one per request. */
export type FauxResponseStep = AssistantMessage | FauxResponseFactory;

/** Options accepted by `registerFauxProvider`. */
export interface RegisterFauxProviderOptions {
  /** API identifier to register under; a random `faux:…` id is generated when omitted. */
  api?: string;
  /** Provider name stamped on models and messages; defaults to `"faux"`. */
  provider?: string;
  /** Models to expose; a single default model is registered when omitted or empty. */
  models?: FauxModelDefinition[];
  /** When set to a positive number, chunk delivery is paced with timers at this rate. */
  tokensPerSecond?: number;
  /** Bounds for the random size of each streamed chunk; one unit is worth four characters. */
  tokenSize?: {
    min?: number;
    max?: number;
  };
}

/** Handle returned by `registerFauxProvider`. */
export interface FauxProviderRegistration {
  /** API identifier the provider was registered under. */
  api: string;
  /** All registered faux models; never empty. */
  models: [Model, ...Model[]];
  /** Returns the first registered model. */
  getModel(): Model;
  /** Returns the model with the given id, or `undefined` when none matches. */
  getModel(modelId: string): Model | undefined;
  /** Mutable counter, incremented once per request. */
  state: { callCount: number };
  /** Replaces the remaining response queue. */
  setResponses: (responses: FauxResponseStep[]) => void;
  /** Adds responses to the end of the queue. */
  appendResponses: (responses: FauxResponseStep[]) => void;
  /** Number of responses still queued. */
  getPendingResponseCount: () => number;
  /** Removes the provider from the API registry. */
  unregister: () => void;
}

/** Rough token estimate: one token per four characters, rounded up. */
function estimateTokens(text: string): number {
  return Math.ceil(text.length / 4);
}

/** Builds a `prefix:timestamp:random` identifier. */
function randomId(prefix: string): string {
  return `${prefix}:${Date.now()}:${Math.random().toString(36).slice(2)}`;
}

/**
 * Flattens user or tool-result content to text. Text blocks contribute their text;
 * image blocks contribute an `[image:<mimeType>:<data length>]` placeholder.
 */
function contentToText(content: string | Array<TextContent | ImageContent>): string {
  if (typeof content === "string") {
    return content;
  }
  return content
    .map((block) => {
      if (block.type === "text") {
        return block.text;
      }
      return `[image:${block.mimeType}:${block.data.length}]`;
    })
    .join("\n");
}

/** Flattens assistant content to text; tool calls are rendered as `name:<JSON arguments>`. */
function assistantContentToText(content: Array<TextContent | ThinkingContent | ToolCall>): string {
  return content
    .map((block) => {
      if (block.type === "text") {
        return block.text;
      }
      if (block.type === "thinking") {
        return block.thinking;
      }
      return `${block.name}:${JSON.stringify(block.arguments)}`;
    })
    .join("\n");
}

/** Renders a tool result as the tool name followed by one line group per content block. */
function toolResultToText(message: ToolResultMessage): string {
  return [message.toolName, ...message.content.map((block) => contentToText([block]))].join("\n");
}

/** Dispatches on the message role to the matching text renderer. */
function messageToText(message: Message): string {
  if (message.role === "user") {
    return contentToText(message.content);
  }
  if (message.role === "assistant") {
    return assistantContentToText(message.content);
  }
  return toolResultToText(message);
}

/**
 * Serializes a request context into the single string used for token estimation and
 * cache simulation: optional `system:` part, one `<role>:` part per message, and an
 * optional `tools:` part holding the JSON of the tool list, joined by blank lines.
 */
function serializeContext(context: Context): string {
  const parts: string[] = [];
  if (context.systemPrompt) {
    parts.push(`system:${context.systemPrompt}`);
  }
  for (const message of context.messages) {
    parts.push(`${message.role}:${messageToText(message)}`);
  }
  if (context.tools?.length) {
    parts.push(`tools:${JSON.stringify(context.tools)}`);
  }
  return parts.join("\n\n");
}

/** Length of the shared prefix of two strings, compared one UTF-16 code unit at a time. */
function commonPrefixLength(a: string, b: string): number {
  const length = Math.min(a.length, b.length);
  let index = 0;
  while (index < length && a[index] === b[index]) {
    index++;
  }
  return index;
}

/**
 * Returns a copy of `message` whose usage is estimated from the serialized context
 * and the message's own content. Costs are always zero.
 *
 * When `options.sessionId` is set and `cacheRetention` is not `"none"`, prompt caching
 * is simulated per session: the prefix shared with the session's previous prompt counts
 * as `cacheRead`, the remainder as `cacheWrite`, and the prompt is remembered in
 * `promptCache` for the next request.
 *
 * NOTE(review): on a session's first request the whole prompt is reported as both
 * `input` and `cacheWrite`, so `totalTokens` counts it twice. On later requests the
 * uncached remainder is likewise in both `input` and `cacheWrite`. Confirm this is
 * the intended accounting.
 *
 * @param promptCache - Session id to last serialized prompt; mutated by this call.
 */
function withUsageEstimate(
  message: AssistantMessage,
  context: Context,
  options: StreamOptions | undefined,
  promptCache: Map<string, string>,
): AssistantMessage {
  const promptText = serializeContext(context);
  const promptTokens = estimateTokens(promptText);
  const outputTokens = estimateTokens(assistantContentToText(message.content));
  let input = promptTokens;
  let cacheRead = 0;
  let cacheWrite = 0;
  const sessionId = options?.sessionId;

  if (sessionId && options?.cacheRetention !== "none") {
    const previousPrompt = promptCache.get(sessionId);
    if (previousPrompt) {
      const cachedChars = commonPrefixLength(previousPrompt, promptText);
      cacheRead = estimateTokens(previousPrompt.slice(0, cachedChars));
      cacheWrite = estimateTokens(promptText.slice(cachedChars));
      input = Math.max(0, promptTokens - cacheRead);
    } else {
      cacheWrite = promptTokens;
    }
    promptCache.set(sessionId, promptText);
  }

  return {
    ...message,
    usage: {
      input,
      output: outputTokens,
      cacheRead,
      cacheWrite,
      totalTokens: input + outputTokens + cacheRead + cacheWrite,
      cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0, total: 0 },
    },
  };
}
/**
 * Splits `text` into consecutive chunks whose sizes are drawn at random from
 * `[minTokenSize, maxTokenSize]`, with one size unit worth four characters.
 * Concatenating the chunks always yields `text` again.
 *
 * @returns The chunks in order; `[""]` for empty text so callers still emit one delta.
 */
function splitStringByTokenSize(text: string, minTokenSize: number, maxTokenSize: number): string[] {
  const chunks: string[] = [];
  let index = 0;
  while (index < text.length) {
    const tokenSize = minTokenSize + Math.floor(Math.random() * (maxTokenSize - minTokenSize + 1));
    const requestedChars = tokenSize * 4;
    // A NaN size would make slice() return "" and end the loop, silently dropping the
    // whole text; fall back to one character per chunk instead.
    const charSize = Number.isNaN(requestedChars) ? 1 : Math.max(1, requestedChars);
    chunks.push(text.slice(index, index + charSize));
    index += charSize;
  }
  return chunks.length > 0 ? chunks : [""];
}

/**
 * Deep-copies a scripted message and re-stamps it with the api, provider and model of
 * the request being served, so the queued original is never mutated.
 */
function cloneMessage(message: AssistantMessage, api: string, provider: string, modelId: string): AssistantMessage {
  const cloned = structuredClone(message);
  return {
    ...cloned,
    api,
    provider,
    model: modelId,
    timestamp: cloned.timestamp ?? Date.now(),
    // Copy the template instead of sharing it, so later mutation cannot leak across messages.
    usage: cloned.usage ?? { ...DEFAULT_USAGE, cost: { ...DEFAULT_USAGE.cost } },
  };
}

/**
 * Builds an empty assistant message with `stopReason: "error"`. The error message is
 * taken from `error.message` for `Error` instances and from `String(error)` otherwise.
 */
function createErrorMessage(error: unknown, api: string, provider: string, modelId: string): AssistantMessage {
  return {
    role: "assistant",
    content: [],
    api,
    provider,
    model: modelId,
    // Fresh copy for the same reason as in cloneMessage.
    usage: { ...DEFAULT_USAGE, cost: { ...DEFAULT_USAGE.cost } },
    stopReason: "error",
    errorMessage: error instanceof Error ? error.message : String(error),
    timestamp: Date.now(),
  };
}

/** Copies the partially streamed message and marks it as aborted at the current time. */
function createAbortedMessage(partial: AssistantMessage): AssistantMessage {
  return {
    ...partial,
    stopReason: "aborted",
    errorMessage: "Request was aborted",
    timestamp: Date.now(),
  };
}

/**
 * Waits before a chunk is emitted. Without a positive `tokensPerSecond` the wait is a
 * single microtask; otherwise it is a timer sized from the chunk's estimated tokens.
 */
function scheduleChunk(chunk: string, tokensPerSecond: number | undefined): Promise<void> {
  if (!tokensPerSecond || tokensPerSecond <= 0) {
    return new Promise<void>((resolve) => queueMicrotask(() => resolve()));
  }
  const delayMs = (estimateTokens(chunk) / tokensPerSecond) * 1000;
  return new Promise<void>((resolve) => setTimeout(() => resolve(), delayMs));
}

/** Emits the terminal `error` event for an aborted request and closes the stream. */
function emitAborted(stream: AssistantMessageEventStream, partial: AssistantMessage): void {
  const aborted = createAbortedMessage(partial);
  stream.push({ type: "error", reason: "aborted", error: aborted });
  stream.end(aborted);
}

/**
 * Delivers `text` chunk by chunk, waiting per chunk and calling `onChunk` for each one.
 *
 * @returns `false` as soon as the signal is found aborted after a wait (the pending
 *   chunk is not delivered), `true` once every chunk has been delivered.
 */
async function emitChunks(
  text: string,
  minTokenSize: number,
  maxTokenSize: number,
  tokensPerSecond: number | undefined,
  signal: AbortSignal | undefined,
  onChunk: (chunk: string) => void,
): Promise<boolean> {
  for (const chunk of splitStringByTokenSize(text, minTokenSize, maxTokenSize)) {
    await scheduleChunk(chunk, tokensPerSecond);
    if (signal?.aborted) {
      return false;
    }
    onChunk(chunk);
  }
  return true;
}

/**
 * Replays a complete assistant message onto `stream` as incremental events: `start`,
 * then per content block a `*_start`, one `*_delta` per chunk and a `*_end`, then a
 * terminal `done` (or `error` when the message's own stop reason is `"error"` or
 * `"aborted"`).
 *
 * The abort signal is checked before starting, before each block and after each chunk
 * wait. On abort a terminal `error` event carrying the partial message is emitted and
 * the stream is closed.
 *
 * Each event's `partial` is a shallow copy, so the content block objects are shared
 * between snapshots. Tool-call arguments are streamed as chunks of their JSON text,
 * and the partial's `arguments` stay `{}` until the block ends.
 */
async function streamWithDeltas(
  stream: AssistantMessageEventStream,
  message: AssistantMessage,
  minTokenSize: number,
  maxTokenSize: number,
  tokensPerSecond: number | undefined,
  signal: AbortSignal | undefined,
): Promise<void> {
  const partial: AssistantMessage = { ...message, content: [] };
  if (signal?.aborted) {
    emitAborted(stream, partial);
    return;
  }

  stream.push({ type: "start", partial: { ...partial } });

  for (let index = 0; index < message.content.length; index++) {
    if (signal?.aborted) {
      emitAborted(stream, partial);
      return;
    }

    const block = message.content[index];
    let completed: boolean;

    if (block.type === "thinking") {
      const streamed: ThinkingContent = { type: "thinking", thinking: "" };
      partial.content = [...partial.content, streamed];
      stream.push({ type: "thinking_start", contentIndex: index, partial: { ...partial } });
      completed = await emitChunks(block.thinking, minTokenSize, maxTokenSize, tokensPerSecond, signal, (chunk) => {
        streamed.thinking += chunk;
        stream.push({ type: "thinking_delta", contentIndex: index, delta: chunk, partial: { ...partial } });
      });
      if (completed) {
        stream.push({
          type: "thinking_end",
          contentIndex: index,
          content: block.thinking,
          partial: { ...partial },
        });
      }
    } else if (block.type === "text") {
      const streamed: TextContent = { type: "text", text: "" };
      partial.content = [...partial.content, streamed];
      stream.push({ type: "text_start", contentIndex: index, partial: { ...partial } });
      completed = await emitChunks(block.text, minTokenSize, maxTokenSize, tokensPerSecond, signal, (chunk) => {
        streamed.text += chunk;
        stream.push({ type: "text_delta", contentIndex: index, delta: chunk, partial: { ...partial } });
      });
      if (completed) {
        stream.push({ type: "text_end", contentIndex: index, content: block.text, partial: { ...partial } });
      }
    } else {
      const streamed: ToolCall = { type: "toolCall", id: block.id, name: block.name, arguments: {} };
      partial.content = [...partial.content, streamed];
      stream.push({ type: "toolcall_start", contentIndex: index, partial: { ...partial } });
      completed = await emitChunks(
        JSON.stringify(block.arguments),
        minTokenSize,
        maxTokenSize,
        tokensPerSecond,
        signal,
        (chunk) => {
          stream.push({ type: "toolcall_delta", contentIndex: index, delta: chunk, partial: { ...partial } });
        },
      );
      if (completed) {
        // The parsed arguments only appear on the partial once the whole JSON text was streamed.
        streamed.arguments = block.arguments;
        stream.push({ type: "toolcall_end", contentIndex: index, toolCall: block, partial: { ...partial } });
      }
    }

    if (!completed) {
      emitAborted(stream, partial);
      return;
    }
  }

  if (message.stopReason === "error" || message.stopReason === "aborted") {
    stream.push({ type: "error", reason: message.stopReason, error: message });
    stream.end(message);
    return;
  }

  stream.push({ type: "done", reason: message.stopReason, message });
  stream.end(message);
}
/** Returns `value` unless it is missing or NaN, in which case `fallback` is used. */
function resolveTokenSize(value: number | undefined, fallback: number): number {
  return value == null || Number.isNaN(value) ? fallback : value;
}

/**
 * Registers a temporary in-memory provider that answers requests from a scripted queue.
 *
 * Each request takes the next queued step when it starts and increments
 * `state.callCount`. The step (or the message its factory resolves to) is deep-copied,
 * re-stamped with this registration's api/provider and the requested model id, given
 * estimated usage, and replayed as incremental stream events. An empty queue or a
 * throwing factory produces a single terminal `error` event instead.
 *
 * NOTE(review): this copy of the source shows `Model` and `StreamFunction` without
 * type arguments, and `SimpleStreamOptions` is imported but never referenced. Generic
 * arguments were probably lost — confirm against ../types.js.
 *
 * @param options - See `RegisterFauxProviderOptions`; every field is optional.
 * @returns A handle for scripting responses, inspecting state and unregistering.
 */
export function registerFauxProvider(options: RegisterFauxProviderOptions = {}): FauxProviderRegistration {
  const api = options.api ?? randomId(DEFAULT_API);
  const provider = options.provider ?? DEFAULT_PROVIDER;
  // Registry key that lets unregister() remove exactly this registration.
  const sourceId = randomId("faux-provider");

  // Math.min/Math.max propagate NaN, which would reach the chunker as a NaN chunk
  // size, so NaN bounds fall back to the defaults before clamping.
  const requestedMin = resolveTokenSize(options.tokenSize?.min, DEFAULT_MIN_TOKEN_SIZE);
  const requestedMax = resolveTokenSize(options.tokenSize?.max, DEFAULT_MAX_TOKEN_SIZE);
  const minTokenSize = Math.max(1, Math.min(requestedMin, requestedMax));
  const maxTokenSize = Math.max(minTokenSize, requestedMax);

  let pendingResponses: FauxResponseStep[] = [];
  const tokensPerSecond = options.tokensPerSecond;
  const state = { callCount: 0 };
  // Session id to last serialized prompt, used by the simulated prompt cache.
  const promptCache = new Map<string, string>();

  const modelDefinitions = options.models?.length
    ? options.models
    : [
        {
          id: DEFAULT_MODEL_ID,
          name: DEFAULT_MODEL_NAME,
          reasoning: false,
          input: ["text", "image"] as ("text" | "image")[],
          cost: { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
          contextWindow: 128000,
          maxTokens: 16384,
        },
      ];
  // modelDefinitions is never empty (see above), which is what the tuple assertion relies on.
  const models = modelDefinitions.map((definition) => ({
    id: definition.id,
    name: definition.name ?? definition.id,
    api,
    provider,
    baseUrl: DEFAULT_BASE_URL,
    reasoning: definition.reasoning ?? false,
    input: definition.input ?? ["text", "image"],
    cost: definition.cost ?? { input: 0, output: 0, cacheRead: 0, cacheWrite: 0 },
    contextWindow: definition.contextWindow ?? 128000,
    maxTokens: definition.maxTokens ?? 16384,
  })) as [Model, ...Model[]];

  const stream: StreamFunction = (requestModel, context, streamOptions) => {
    const outer = createAssistantMessageEventStream();
    // Dequeue synchronously so responses are matched to requests in request start order.
    const step = pendingResponses.shift();
    state.callCount++;

    queueMicrotask(async () => {
      try {
        if (!step) {
          let message = createErrorMessage(
            new Error("No more faux responses queued"),
            api,
            provider,
            requestModel.id,
          );
          message = withUsageEstimate(message, context, streamOptions, promptCache);
          outer.push({ type: "error", reason: "error", error: message });
          outer.end(message);
          return;
        }

        const resolved =
          typeof step === "function" ? await step(context, streamOptions, state, requestModel) : step;
        let message = cloneMessage(resolved, api, provider, requestModel.id);
        message = withUsageEstimate(message, context, streamOptions, promptCache);
        await streamWithDeltas(outer, message, minTokenSize, maxTokenSize, tokensPerSecond, streamOptions?.signal);
      } catch (error) {
        // Covers throwing or rejecting factories as well as failures while cloning or streaming.
        const message = createErrorMessage(error, api, provider, requestModel.id);
        outer.push({ type: "error", reason: "error", error: message });
        outer.end(message);
      }
    });

    return outer;
  };

  const streamSimple: StreamFunction = (streamModel, context, streamOptions) =>
    stream(streamModel, context, streamOptions);

  registerApiProvider({ api, stream, streamSimple }, sourceId);

  function getModel(): Model;
  function getModel(requestedModelId: string): Model | undefined;
  function getModel(requestedModelId?: string): Model | undefined {
    // Any falsy id, including "", selects the first model.
    if (!requestedModelId) {
      return models[0];
    }
    return models.find((candidate) => candidate.id === requestedModelId);
  }

  return {
    api,
    models,
    getModel,
    state,
    setResponses(responses) {
      // Copy so later changes to the caller's array do not alter the queue.
      pendingResponses = [...responses];
    },
    appendResponses(responses) {
      pendingResponses.push(...responses);
    },
    getPendingResponseCount() {
      return pendingResponses.length;
    },
    unregister() {
      unregisterApiProviders(sourceId);
    },
  };
}
(const event of streamResult) { + events.push(event); + } + return events; +} + +const registrations: Array<{ unregister: () => void }> = []; + +afterEach(() => { + for (const registration of registrations.splice(0)) { + registration.unregister(); + } +}); + +describe("faux provider", () => { + it("registers a custom provider and estimates usage", async () => { + const registration = registerFauxProvider(); + registrations.push(registration); + registration.setResponses([fauxAssistantMessage("hello world")]); + + const context: Context = { + systemPrompt: "Be concise.", + messages: [{ role: "user", content: "hi there", timestamp: Date.now() }], + }; + + const response = await complete(registration.getModel(), context); + expect(response.content).toEqual([{ type: "text", text: "hello world" }]); + expect(response.usage.input).toBeGreaterThan(0); + expect(response.usage.output).toBeGreaterThan(0); + expect(response.usage.totalTokens).toBe(response.usage.input + response.usage.output); + expect(registration.state.callCount).toBe(1); + }); + + it("supports helper blocks for text, thinking, and tool calls", async () => { + const registration = registerFauxProvider(); + registrations.push(registration); + registration.setResponses([ + fauxAssistantMessage([fauxThinking("think"), fauxToolCall("echo", { text: "hi" }), fauxText("done")], { + stopReason: "toolUse", + }), + ]); + + const response = await complete(registration.getModel(), { + messages: [{ role: "user", content: "hi", timestamp: Date.now() }], + }); + + expect(response.content).toEqual([ + { type: "thinking", thinking: "think" }, + { type: "toolCall", id: expect.any(String), name: "echo", arguments: { text: "hi" } }, + { type: "text", text: "done" }, + ]); + expect(response.stopReason).toBe("toolUse"); + }); + + it("supports multiple models with per-model reasoning and model-aware factories", async () => { + const registration = registerFauxProvider({ + models: [ + { id: "faux-fast", name: "Faux Fast", 
reasoning: false }, + { id: "faux-thinker", name: "Faux Thinker", reasoning: true }, + ], + }); + registrations.push(registration); + registration.setResponses([ + (_context, _options, _state, model) => fauxAssistantMessage(`${model.id}:${String(model.reasoning)}`), + (_context, _options, _state, model) => fauxAssistantMessage(`${model.id}:${String(model.reasoning)}`), + ]); + + expect(registration.models.map((model) => model.id)).toEqual(["faux-fast", "faux-thinker"]); + expect(registration.getModel()).toBe(registration.models[0]); + expect(registration.getModel("faux-fast")?.reasoning).toBe(false); + expect(registration.getModel("faux-thinker")?.reasoning).toBe(true); + + const fast = await complete(registration.getModel("faux-fast")!, { + messages: [{ role: "user", content: "hi", timestamp: Date.now() }], + }); + const thinker = await complete(registration.getModel("faux-thinker")!, { + messages: [{ role: "user", content: "hi", timestamp: Date.now() }], + }); + + expect(fast.content).toEqual([{ type: "text", text: "faux-fast:false" }]); + expect(thinker.content).toEqual([{ type: "text", text: "faux-thinker:true" }]); + }); + + it("rewrites api, provider, and model on returned messages", async () => { + const registration = registerFauxProvider({ + api: "faux:test", + provider: "faux-provider", + models: [{ id: "faux-model" }], + }); + registrations.push(registration); + registration.setResponses([fauxAssistantMessage("hello")]); + + const response = await complete(registration.getModel(), { + messages: [{ role: "user", content: "hi", timestamp: Date.now() }], + }); + + expect(response.api).toBe("faux:test"); + expect(response.provider).toBe("faux-provider"); + expect(response.model).toBe("faux-model"); + }); + + it("consumes queued responses in order and errors when exhausted", async () => { + const registration = registerFauxProvider(); + registrations.push(registration); + registration.setResponses([fauxAssistantMessage("first"), 
fauxAssistantMessage("second")]); + + const context: Context = { + messages: [{ role: "user", content: "hi", timestamp: Date.now() }], + }; + + const first = await complete(registration.getModel(), context); + const second = await complete(registration.getModel(), context); + const exhausted = await complete(registration.getModel(), context); + + expect(first.content).toEqual([{ type: "text", text: "first" }]); + expect(second.content).toEqual([{ type: "text", text: "second" }]); + expect(exhausted.stopReason).toBe("error"); + expect(exhausted.errorMessage).toBe("No more faux responses queued"); + expect(registration.getPendingResponseCount()).toBe(0); + expect(registration.state.callCount).toBe(3); + }); + + it("can replace and append queued responses", async () => { + const registration = registerFauxProvider(); + registrations.push(registration); + registration.setResponses([fauxAssistantMessage("first")]); + + const context: Context = { + messages: [{ role: "user", content: "hi", timestamp: Date.now() }], + }; + + expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "first" }]); + expect(registration.getPendingResponseCount()).toBe(0); + + registration.setResponses([fauxAssistantMessage("second")]); + expect(registration.getPendingResponseCount()).toBe(1); + expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "second" }]); + + registration.appendResponses([fauxAssistantMessage("third"), fauxAssistantMessage("fourth")]); + expect(registration.getPendingResponseCount()).toBe(2); + expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "third" }]); + expect((await complete(registration.getModel(), context)).content).toEqual([{ type: "text", text: "fourth" }]); + expect(registration.getPendingResponseCount()).toBe(0); + }); + + it("supports async response factories", async () => { + const registration = registerFauxProvider(); + 
registrations.push(registration); + registration.setResponses([ + async (context, _options, state) => fauxAssistantMessage(`${context.messages.length}:${state.callCount}`), + ]); + + const response = await complete(registration.getModel(), { + messages: [{ role: "user", content: "hi", timestamp: Date.now() }], + }); + + expect(response.content).toEqual([{ type: "text", text: "1:1" }]); + }); + + it("emits an error when a response factory throws", async () => { + const registration = registerFauxProvider(); + registrations.push(registration); + registration.setResponses([ + () => { + throw new Error("boom"); + }, + ]); + + const events = await collectEvents( + stream(registration.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }), + ); + + expect(events).toHaveLength(1); + expect(events[0].type).toBe("error"); + if (events[0].type === "error") { + expect(events[0].error.stopReason).toBe("error"); + expect(events[0].error.errorMessage).toBe("boom"); + } + }); + + it("estimates prompt and output tokens from serialized context", async () => { + const registration = registerFauxProvider(); + registrations.push(registration); + registration.setResponses([fauxAssistantMessage("done")]); + + const tool = { + name: "echo", + description: "Echo back text", + parameters: Type.Object({ text: Type.String() }), + }; + const context: Context = { + systemPrompt: "sys", + messages: [ + { + role: "user", + content: [ + { type: "text", text: "hello" }, + { type: "image", mimeType: "image/png", data: "abcd" }, + ], + timestamp: 1, + }, + fauxAssistantMessage("prior"), + { + role: "toolResult", + toolCallId: "tool-1", + toolName: "echo", + content: [{ type: "text", text: "tool out" }], + isError: false, + timestamp: 2, + }, + ], + tools: [tool], + }; + + const response = await complete(registration.getModel(), context); + const promptText = [ + "system:sys", + "user:hello\n[image:image/png:4]", + "assistant:prior", + "toolResult:echo\ntool out", + 
`tools:${JSON.stringify([tool])}`, + ].join("\n\n"); + const expectedPromptTokens = Math.ceil(promptText.length / 4); + const expectedOutputTokens = Math.ceil("done".length / 4); + + expect(response.usage.input).toBe(expectedPromptTokens); + expect(response.usage.output).toBe(expectedOutputTokens); + expect(response.usage.cacheRead).toBe(0); + expect(response.usage.cacheWrite).toBe(0); + expect(response.usage.totalTokens).toBe(expectedPromptTokens + expectedOutputTokens); + }); + + it("does not share cache across sessions or requests without sessionId", async () => { + const registration = registerFauxProvider(); + registrations.push(registration); + registration.setResponses([ + fauxAssistantMessage("first"), + fauxAssistantMessage("second"), + fauxAssistantMessage("third"), + ]); + + const context: Context = { + messages: [{ role: "user", content: "hello", timestamp: Date.now() }], + }; + + const first = await complete(registration.getModel(), context, { + sessionId: "session-1", + cacheRetention: "short", + }); + expect(first.usage.cacheWrite).toBeGreaterThan(0); + context.messages.push(first); + context.messages.push({ role: "user", content: "follow up", timestamp: Date.now() + 1 }); + + const second = await complete(registration.getModel(), context, { + sessionId: "session-2", + cacheRetention: "short", + }); + expect(second.usage.cacheRead).toBe(0); + expect(second.usage.cacheWrite).toBeGreaterThan(0); + + const third = await complete(registration.getModel(), context); + expect(third.usage.cacheRead).toBe(0); + expect(third.usage.cacheWrite).toBe(0); + }); + + it("simulates prompt caching per sessionId", async () => { + const registration = registerFauxProvider(); + registrations.push(registration); + registration.setResponses([fauxAssistantMessage("first"), fauxAssistantMessage("second")]); + + const context: Context = { + systemPrompt: "Be concise.", + messages: [{ role: "user", content: "hello", timestamp: Date.now() }], + }; + + const first = await 
complete(registration.getModel(), context, { + sessionId: "session-1", + cacheRetention: "short", + }); + expect(first.usage.cacheRead).toBe(0); + expect(first.usage.cacheWrite).toBeGreaterThan(0); + + context.messages.push(first); + context.messages.push({ role: "user", content: "follow up", timestamp: Date.now() + 1 }); + + const second = await complete(registration.getModel(), context, { + sessionId: "session-1", + cacheRetention: "short", + }); + expect(second.usage.cacheRead).toBeGreaterThan(0); + expect(second.usage.input + second.usage.cacheRead).toBeGreaterThan(second.usage.input); + }); + + it("does not simulate caching when cacheRetention is none", async () => { + const registration = registerFauxProvider(); + registrations.push(registration); + registration.setResponses([fauxAssistantMessage("first"), fauxAssistantMessage("second")]); + + const context: Context = { + messages: [{ role: "user", content: "hello", timestamp: Date.now() }], + }; + + await complete(registration.getModel(), context, { sessionId: "session-1", cacheRetention: "none" }); + context.messages.push(fauxAssistantMessage("first")); + context.messages.push({ role: "user", content: "follow up", timestamp: Date.now() + 1 }); + const second = await complete(registration.getModel(), context, { + sessionId: "session-1", + cacheRetention: "none", + }); + expect(second.usage.cacheRead).toBe(0); + expect(second.usage.cacheWrite).toBe(0); + }); + + it("streams thinking, text, and partial tool call deltas", async () => { + const registration = registerFauxProvider(); + registrations.push(registration); + registration.setResponses([ + fauxAssistantMessage( + [ + fauxThinking("thinking text"), + fauxText("answer text"), + fauxToolCall("echo", { text: "hi", count: 12 }, { id: "tool-1" }), + ], + { stopReason: "toolUse" }, + ), + ]); + + const events: string[] = []; + const toolCallDeltas: string[] = []; + const s = stream(registration.getModel(), { messages: [{ role: "user", content: "hi", 
timestamp: Date.now() }] }); + for await (const event of s) { + events.push(event.type); + if (event.type === "toolcall_delta") { + toolCallDeltas.push(event.delta); + } + } + + expect(events).toContain("thinking_start"); + expect(events).toContain("thinking_delta"); + expect(events).toContain("text_start"); + expect(events).toContain("text_delta"); + expect(events).toContain("toolcall_start"); + expect(events).toContain("toolcall_delta"); + expect(events).toContain("toolcall_end"); + expect(toolCallDeltas.length).toBeGreaterThan(1); + expect(JSON.parse(toolCallDeltas.join(""))).toEqual({ text: "hi", count: 12 }); + }); + + it("streams an exact event order for fixed-size chunks", async () => { + const registration = registerFauxProvider({ tokenSize: { min: 1, max: 1 } }); + registrations.push(registration); + registration.setResponses([ + fauxAssistantMessage([fauxThinking("go"), fauxText("ok"), fauxToolCall("echo", {}, { id: "tool-1" })], { + stopReason: "toolUse", + }), + ]); + + const events = await collectEvents( + stream(registration.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }), + ); + + expect(events.map((event) => event.type)).toEqual([ + "start", + "thinking_start", + "thinking_delta", + "thinking_end", + "text_start", + "text_delta", + "text_end", + "toolcall_start", + "toolcall_delta", + "toolcall_end", + "done", + ]); + }); + + it("streams multiple tool calls in one message", async () => { + const registration = registerFauxProvider(); + registrations.push(registration); + registration.setResponses([ + fauxAssistantMessage( + [ + fauxToolCall("echo", { text: "one" }, { id: "tool-1" }), + fauxToolCall("echo", { text: "two" }, { id: "tool-2" }), + ], + { stopReason: "toolUse" }, + ), + ]); + + const events = await collectEvents( + stream(registration.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }), + ); + + expect(events.filter((event) => event.type === 
"toolcall_start")).toHaveLength(2); + expect(events.filter((event) => event.type === "toolcall_end")).toHaveLength(2); + }); + + it("streams an explicit assistant error message as a terminal error", async () => { + const registration = registerFauxProvider({ tokenSize: { min: 2, max: 2 } }); + registrations.push(registration); + registration.setResponses([ + { + ...fauxAssistantMessage("partial"), + stopReason: "error", + errorMessage: "upstream failed", + }, + ]); + + const events = await collectEvents( + stream(registration.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }), + ); + + expect(events.map((event) => event.type)).toEqual(["start", "text_start", "text_delta", "text_end", "error"]); + const terminal = events[events.length - 1]; + expect(terminal.type).toBe("error"); + if (terminal.type === "error") { + expect(terminal.reason).toBe("error"); + expect(terminal.error.stopReason).toBe("error"); + expect(terminal.error.errorMessage).toBe("upstream failed"); + } + }); + + it("streams an explicit assistant aborted message as a terminal error", async () => { + const registration = registerFauxProvider({ tokenSize: { min: 2, max: 2 } }); + registrations.push(registration); + registration.setResponses([ + { + ...fauxAssistantMessage("partial"), + stopReason: "aborted", + errorMessage: "Request was aborted", + }, + ]); + + const events = await collectEvents( + stream(registration.getModel(), { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }), + ); + + expect(events.map((event) => event.type)).toEqual(["start", "text_start", "text_delta", "text_end", "error"]); + const terminal = events[events.length - 1]; + expect(terminal.type).toBe("error"); + if (terminal.type === "error") { + expect(terminal.reason).toBe("aborted"); + expect(terminal.error.stopReason).toBe("aborted"); + expect(terminal.error.errorMessage).toBe("Request was aborted"); + } + }); + + it("supports aborting before the first chunk", async () => { 
+ const registration = registerFauxProvider({ tokensPerSecond: 50, tokenSize: { min: 3, max: 3 } }); + registrations.push(registration); + registration.setResponses([fauxAssistantMessage("abcdefghijklmnopqrstuvwxyz")]); + + const controller = new AbortController(); + controller.abort(); + const events = await collectEvents( + stream( + registration.getModel(), + { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }, + { signal: controller.signal }, + ), + ); + + expect(events).toHaveLength(1); + expect(events[0].type).toBe("error"); + if (events[0].type === "error") { + expect(events[0].reason).toBe("aborted"); + expect(events[0].error.stopReason).toBe("aborted"); + } + }); + + it("supports aborting mid-text stream when paced", async () => { + const registration = registerFauxProvider({ tokensPerSecond: 100, tokenSize: { min: 3, max: 3 } }); + registrations.push(registration); + registration.setResponses([fauxAssistantMessage("abcdefghijklmnopqrstuvwxyz")]); + + const controller = new AbortController(); + const events: string[] = []; + let textDeltaCount = 0; + const s = stream( + registration.getModel(), + { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }, + { signal: controller.signal }, + ); + for await (const event of s) { + events.push(event.type); + if (event.type === "text_delta") { + textDeltaCount++; + controller.abort(); + } + } + + expect(textDeltaCount).toBe(1); + expect(events).toContain("text_start"); + expect(events).toContain("text_delta"); + expect(events).toContain("error"); + expect(events).not.toContain("text_end"); + }); + + it("supports aborting mid-thinking stream when paced", async () => { + const registration = registerFauxProvider({ tokensPerSecond: 100, tokenSize: { min: 3, max: 3 } }); + registrations.push(registration); + registration.setResponses([ + { + ...fauxAssistantMessage("ignored"), + content: [{ type: "thinking", thinking: "abcdefghijklmnopqrstuvwxyz" }], + }, + ]); + + const controller = 
new AbortController(); + const events: string[] = []; + let thinkingDeltaCount = 0; + const s = stream( + registration.getModel(), + { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }, + { signal: controller.signal }, + ); + for await (const event of s) { + events.push(event.type); + if (event.type === "thinking_delta") { + thinkingDeltaCount++; + controller.abort(); + } + } + + expect(thinkingDeltaCount).toBe(1); + expect(events).toContain("thinking_start"); + expect(events).toContain("thinking_delta"); + expect(events).toContain("error"); + expect(events).not.toContain("thinking_end"); + }); + + it("supports aborting mid-toolcall stream when paced", async () => { + const registration = registerFauxProvider({ tokensPerSecond: 100, tokenSize: { min: 3, max: 3 } }); + registrations.push(registration); + registration.setResponses([ + { + ...fauxAssistantMessage("done"), + content: [ + { + type: "toolCall", + id: "tool-1", + name: "echo", + arguments: { text: "abcdefghijklmnopqrstuvwxyz", count: 123456789 }, + }, + ], + stopReason: "toolUse", + }, + ]); + + const controller = new AbortController(); + const events: string[] = []; + let toolCallDeltaCount = 0; + const s = stream( + registration.getModel(), + { messages: [{ role: "user", content: "hi", timestamp: Date.now() }] }, + { signal: controller.signal }, + ); + for await (const event of s) { + events.push(event.type); + if (event.type === "toolcall_delta") { + toolCallDeltaCount++; + controller.abort(); + } + } + + expect(toolCallDeltaCount).toBe(1); + expect(events).toContain("toolcall_start"); + expect(events).toContain("toolcall_delta"); + expect(events).toContain("error"); + expect(events).not.toContain("toolcall_end"); + }); + + it("unregisters the provider", async () => { + const registration = registerFauxProvider(); + registration.setResponses([fauxAssistantMessage("hello")]); + registration.unregister(); + + await expect( + complete(registration.getModel(), { messages: [{ role: 
"user", content: "hi", timestamp: Date.now() }] }), + ).rejects.toThrow(`No API provider registered for api: ${registration.api}`); + }); +}); diff --git a/packages/coding-agent/README.md b/packages/coding-agent/README.md index c818ac22d..7efa75e23 100644 --- a/packages/coding-agent/README.md +++ b/packages/coding-agent/README.md @@ -395,7 +395,7 @@ import { AuthStorage, createAgentSession, ModelRegistry, SessionManager } from " const { session } = await createAgentSession({ sessionManager: SessionManager.inMemory(), authStorage: AuthStorage.create(), - modelRegistry: new ModelRegistry(authStorage), + modelRegistry: ModelRegistry.create(authStorage), }); await session.prompt("What files are in the current directory?"); diff --git a/packages/coding-agent/docs/sdk.md b/packages/coding-agent/docs/sdk.md index 6371cfec3..d8382d708 100644 --- a/packages/coding-agent/docs/sdk.md +++ b/packages/coding-agent/docs/sdk.md @@ -20,7 +20,7 @@ import { AuthStorage, createAgentSession, ModelRegistry, SessionManager } from " // Set up credential storage and model registry const authStorage = AuthStorage.create(); -const modelRegistry = new ModelRegistry(authStorage); +const modelRegistry = ModelRegistry.create(authStorage); const { session } = await createAgentSession({ sessionManager: SessionManager.inMemory(), @@ -289,7 +289,7 @@ import { getModel } from "@mariozechner/pi-ai"; import { AuthStorage, ModelRegistry } from "@mariozechner/pi-coding-agent"; const authStorage = AuthStorage.create(); -const modelRegistry = new ModelRegistry(authStorage); +const modelRegistry = ModelRegistry.create(authStorage); // Find specific built-in model (doesn't check if API key exists) const opus = getModel("anthropic", "claude-opus-4-5"); @@ -337,7 +337,7 @@ import { AuthStorage, ModelRegistry } from "@mariozechner/pi-coding-agent"; // Default: uses ~/.pi/agent/auth.json and ~/.pi/agent/models.json const authStorage = AuthStorage.create(); -const modelRegistry = new ModelRegistry(authStorage); 
+const modelRegistry = ModelRegistry.create(authStorage); const { session } = await createAgentSession({ sessionManager: SessionManager.inMemory(), @@ -350,7 +350,7 @@ authStorage.setRuntimeApiKey("anthropic", "sk-my-temp-key"); // Custom auth storage location const customAuth = AuthStorage.create("/my/app/auth.json"); -const customRegistry = new ModelRegistry(customAuth, "/my/app/models.json"); +const customRegistry = ModelRegistry.create(customAuth, "/my/app/models.json"); const { session } = await createAgentSession({ sessionManager: SessionManager.inMemory(), @@ -359,7 +359,7 @@ const { session } = await createAgentSession({ }); // No custom models.json (built-in models only) -const simpleRegistry = new ModelRegistry(authStorage); +const simpleRegistry = ModelRegistry.inMemory(authStorage); ``` > See [examples/sdk/09-api-keys-and-oauth.ts](../examples/sdk/09-api-keys-and-oauth.ts) @@ -788,7 +788,7 @@ if (process.env.MY_KEY) { } // Model registry (no custom models.json) -const modelRegistry = new ModelRegistry(authStorage); +const modelRegistry = ModelRegistry.inMemory(authStorage); // Inline tool const statusTool: ToolDefinition = { diff --git a/packages/coding-agent/examples/sdk/02-custom-model.ts b/packages/coding-agent/examples/sdk/02-custom-model.ts index ccac7bce5..52c2b73e1 100644 --- a/packages/coding-agent/examples/sdk/02-custom-model.ts +++ b/packages/coding-agent/examples/sdk/02-custom-model.ts @@ -9,7 +9,7 @@ import { AuthStorage, createAgentSession, ModelRegistry } from "@mariozechner/pi // Set up auth storage and model registry const authStorage = AuthStorage.create(); -const modelRegistry = new ModelRegistry(authStorage); +const modelRegistry = ModelRegistry.create(authStorage); // Option 1: Find a specific built-in model by provider/id const opus = getModel("anthropic", "claude-opus-4-5"); diff --git a/packages/coding-agent/examples/sdk/09-api-keys-and-oauth.ts b/packages/coding-agent/examples/sdk/09-api-keys-and-oauth.ts index a41f1e8b9..c735fd466
100644 --- a/packages/coding-agent/examples/sdk/09-api-keys-and-oauth.ts +++ b/packages/coding-agent/examples/sdk/09-api-keys-and-oauth.ts @@ -9,7 +9,7 @@ import { AuthStorage, createAgentSession, ModelRegistry, SessionManager } from " // Default: AuthStorage uses ~/.pi/agent/auth.json // ModelRegistry loads built-in + custom models from ~/.pi/agent/models.json const authStorage = AuthStorage.create(); -const modelRegistry = new ModelRegistry(authStorage); +const modelRegistry = ModelRegistry.create(authStorage); await createAgentSession({ sessionManager: SessionManager.inMemory(), @@ -20,7 +20,7 @@ console.log("Session with default auth storage and model registry"); // Custom auth storage location const customAuthStorage = AuthStorage.create("/tmp/my-app/auth.json"); -const customModelRegistry = new ModelRegistry(customAuthStorage, "/tmp/my-app/models.json"); +const customModelRegistry = ModelRegistry.create(customAuthStorage, "/tmp/my-app/models.json"); await createAgentSession({ sessionManager: SessionManager.inMemory(), @@ -39,7 +39,7 @@ await createAgentSession({ console.log("Session with runtime API key override"); // No models.json - only built-in models -const simpleRegistry = new ModelRegistry(authStorage); // null = no models.json +const simpleRegistry = ModelRegistry.inMemory(authStorage); await createAgentSession({ sessionManager: SessionManager.inMemory(), authStorage, diff --git a/packages/coding-agent/examples/sdk/12-full-control.ts b/packages/coding-agent/examples/sdk/12-full-control.ts index 135a52c4f..2387fa3a5 100644 --- a/packages/coding-agent/examples/sdk/12-full-control.ts +++ b/packages/coding-agent/examples/sdk/12-full-control.ts @@ -30,7 +30,7 @@ if (process.env.MY_ANTHROPIC_KEY) { } // Model registry with no custom models.json -const modelRegistry = new ModelRegistry(authStorage); +const modelRegistry = ModelRegistry.inMemory(authStorage); const model = getModel("anthropic", "claude-sonnet-4-20250514"); if (!model) throw new Error("Model 
not found"); diff --git a/packages/coding-agent/examples/sdk/README.md b/packages/coding-agent/examples/sdk/README.md index 4a3005185..357fd7448 100644 --- a/packages/coding-agent/examples/sdk/README.md +++ b/packages/coding-agent/examples/sdk/README.md @@ -44,7 +44,7 @@ import { // Auth and models setup const authStorage = AuthStorage.create(); -const modelRegistry = new ModelRegistry(authStorage); +const modelRegistry = ModelRegistry.create(authStorage); // Minimal const { session } = await createAgentSession({ authStorage, modelRegistry }); @@ -73,7 +73,7 @@ const { session } = await createAgentSession({ // Full control const customAuth = AuthStorage.create("/my/app/auth.json"); customAuth.setRuntimeApiKey("anthropic", process.env.MY_KEY!); -const customRegistry = new ModelRegistry(customAuth); +const customRegistry = ModelRegistry.create(customAuth); const resourceLoader = new DefaultResourceLoader({ systemPromptOverride: () => "You are helpful.", @@ -109,7 +109,7 @@ await session.prompt("Hello"); | Option | Default | Description | |--------|---------|-------------| | `authStorage` | `AuthStorage.create()` | Credential storage | -| `modelRegistry` | `new ModelRegistry(authStorage)` | Model registry | +| `modelRegistry` | `ModelRegistry.create(authStorage)` | Model registry | | `cwd` | `process.cwd()` | Working directory | | `agentDir` | `~/.pi/agent` | Config directory | | `model` | From settings/first available | Model to use | diff --git a/packages/coding-agent/src/core/model-registry.ts b/packages/coding-agent/src/core/model-registry.ts index c23139adb..cdad12156 100644 --- a/packages/coding-agent/src/core/model-registry.ts +++ b/packages/coding-agent/src/core/model-registry.ts @@ -259,13 +259,21 @@ export class ModelRegistry { private registeredProviders: Map = new Map(); private loadError: string | undefined = undefined; - constructor( + private constructor( readonly authStorage: AuthStorage, - private modelsJsonPath: string | undefined = 
join(getAgentDir(), "models.json"), + private modelsJsonPath: string | undefined, ) { this.loadModels(); } + static create(authStorage: AuthStorage, modelsJsonPath: string = join(getAgentDir(), "models.json")): ModelRegistry { + return new ModelRegistry(authStorage, modelsJsonPath); + } + + static inMemory(authStorage: AuthStorage): ModelRegistry { + return new ModelRegistry(authStorage, undefined); + } + /** * Reload models from disk (built-in + custom from models.json). */ diff --git a/packages/coding-agent/src/core/sdk.ts b/packages/coding-agent/src/core/sdk.ts index d9ee45c66..02bc8e35b 100644 --- a/packages/coding-agent/src/core/sdk.ts +++ b/packages/coding-agent/src/core/sdk.ts @@ -47,7 +47,7 @@ export interface CreateAgentSessionOptions { /** Auth storage for credentials. Default: AuthStorage.create(agentDir/auth.json) */ authStorage?: AuthStorage; - /** Model registry. Default: new ModelRegistry(authStorage, agentDir/models.json) */ + /** Model registry. Default: ModelRegistry.create(authStorage, agentDir/models.json) */ modelRegistry?: ModelRegistry; /** Model to use. Default: from settings, else first available */ @@ -172,7 +172,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {} const authPath = options.agentDir ? join(agentDir, "auth.json") : undefined; const modelsPath = options.agentDir ? join(agentDir, "models.json") : undefined; const authStorage = options.authStorage ?? AuthStorage.create(authPath); - const modelRegistry = options.modelRegistry ?? new ModelRegistry(authStorage, modelsPath); + const modelRegistry = options.modelRegistry ?? ModelRegistry.create(authStorage, modelsPath); const settingsManager = options.settingsManager ?? SettingsManager.create(cwd, agentDir); const sessionManager = options.sessionManager ?? 
SessionManager.create(cwd, getDefaultSessionDir(cwd, agentDir)); diff --git a/packages/coding-agent/src/main.ts b/packages/coding-agent/src/main.ts index 2569eb65c..3b4bff643 100644 --- a/packages/coding-agent/src/main.ts +++ b/packages/coding-agent/src/main.ts @@ -654,7 +654,7 @@ export async function main(args: string[]) { const settingsManager = SettingsManager.create(cwd, agentDir); reportSettingsErrors(settingsManager, "startup"); const authStorage = AuthStorage.create(); - const modelRegistry = new ModelRegistry(authStorage, getModelsPath()); + const modelRegistry = ModelRegistry.create(authStorage, getModelsPath()); const resourceLoader = new DefaultResourceLoader({ cwd, diff --git a/packages/coding-agent/test/agent-session-auto-compaction-queue.test.ts b/packages/coding-agent/test/agent-session-auto-compaction-queue.test.ts index f24f60ce4..209373710 100644 --- a/packages/coding-agent/test/agent-session-auto-compaction-queue.test.ts +++ b/packages/coding-agent/test/agent-session-auto-compaction-queue.test.ts @@ -76,7 +76,7 @@ describe("AgentSession auto-compaction queue resume", () => { const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); authStorage.setRuntimeApiKey("anthropic", "test-key"); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = ModelRegistry.create(authStorage, tempDir); session = new AgentSession({ agent, diff --git a/packages/coding-agent/test/agent-session-branching.test.ts b/packages/coding-agent/test/agent-session-branching.test.ts index 2c89b4bd7..fdd0d62d6 100644 --- a/packages/coding-agent/test/agent-session-branching.test.ts +++ b/packages/coding-agent/test/agent-session-branching.test.ts @@ -55,7 +55,7 @@ describe.skipIf(!API_KEY)("AgentSession forking", () => { sessionManager = noSession ? 
SessionManager.inMemory() : SessionManager.create(tempDir); const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = ModelRegistry.create(authStorage, tempDir); session = new AgentSession({ agent, diff --git a/packages/coding-agent/test/agent-session-compaction.test.ts b/packages/coding-agent/test/agent-session-compaction.test.ts index cfa62ed7a..feba45eef 100644 --- a/packages/coding-agent/test/agent-session-compaction.test.ts +++ b/packages/coding-agent/test/agent-session-compaction.test.ts @@ -61,7 +61,7 @@ describe.skipIf(!API_KEY)("AgentSession compaction e2e", () => { // Use minimal keepRecentTokens so small test conversations have something to summarize settingsManager.applyOverrides({ compaction: { keepRecentTokens: 1 } }); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage); + const modelRegistry = ModelRegistry.create(authStorage); session = new AgentSession({ agent, diff --git a/packages/coding-agent/test/agent-session-concurrent.test.ts b/packages/coding-agent/test/agent-session-concurrent.test.ts index 20aa17df2..18ff57a91 100644 --- a/packages/coding-agent/test/agent-session-concurrent.test.ts +++ b/packages/coding-agent/test/agent-session-concurrent.test.ts @@ -110,7 +110,7 @@ describe("AgentSession concurrent prompt guard", () => { const sessionManager = SessionManager.inMemory(); const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = ModelRegistry.create(authStorage, tempDir); // Set a runtime API key so validation passes authStorage.setRuntimeApiKey("anthropic", "test-key"); @@ -235,7 +235,7 @@ describe("AgentSession concurrent prompt guard", () => { 
const sessionManager = SessionManager.inMemory(); const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = ModelRegistry.create(authStorage, tempDir); authStorage.setRuntimeApiKey("anthropic", "test-key"); const extensionsResult = await createTestExtensionsResult([ @@ -313,7 +313,7 @@ describe("AgentSession concurrent prompt guard", () => { const sessionManager = SessionManager.inMemory(); const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = ModelRegistry.create(authStorage, tempDir); authStorage.setRuntimeApiKey("anthropic", "test-key"); session = new AgentSession({ @@ -419,7 +419,7 @@ describe("AgentSession concurrent prompt guard", () => { const sessionManager = SessionManager.inMemory(); const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = ModelRegistry.create(authStorage, tempDir); authStorage.setRuntimeApiKey("anthropic", "test-key"); session = new AgentSession({ @@ -556,7 +556,7 @@ describe("AgentSession concurrent prompt guard", () => { const sessionManager = SessionManager.inMemory(); const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = ModelRegistry.create(authStorage, tempDir); authStorage.setRuntimeApiKey("anthropic", "test-key"); session = new AgentSession({ diff --git a/packages/coding-agent/test/agent-session-model-switch-thinking.test.ts 
b/packages/coding-agent/test/agent-session-model-switch-thinking.test.ts index f77785dcd..28ff93ebc 100644 --- a/packages/coding-agent/test/agent-session-model-switch-thinking.test.ts +++ b/packages/coding-agent/test/agent-session-model-switch-thinking.test.ts @@ -37,7 +37,7 @@ function createSession({ sessionManager, settingsManager, cwd: process.cwd(), - modelRegistry: new ModelRegistry(authStorage, undefined), + modelRegistry: ModelRegistry.inMemory(authStorage), resourceLoader: createTestResourceLoader(), scopedModels, }); diff --git a/packages/coding-agent/test/agent-session-retry.test.ts b/packages/coding-agent/test/agent-session-retry.test.ts index 1697f2aeb..bedeb81ac 100644 --- a/packages/coding-agent/test/agent-session-retry.test.ts +++ b/packages/coding-agent/test/agent-session-retry.test.ts @@ -102,7 +102,7 @@ describe("AgentSession retry", () => { const sessionManager = SessionManager.inMemory(); const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = ModelRegistry.create(authStorage, tempDir); authStorage.setRuntimeApiKey("anthropic", "test-key"); settingsManager.applyOverrides({ retry: { enabled: true, maxRetries, baseDelayMs: 1 } }); @@ -204,7 +204,7 @@ describe("AgentSession retry", () => { const sessionManager = SessionManager.inMemory(); const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = ModelRegistry.create(authStorage, tempDir); authStorage.setRuntimeApiKey("anthropic", "test-key"); settingsManager.applyOverrides({ retry: { enabled: true, maxRetries: 3, baseDelayMs: 1 } }); session = new AgentSession({ @@ -289,7 +289,7 @@ describe("AgentSession retry", () => { const sessionManager = SessionManager.inMemory(); 
const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = ModelRegistry.create(authStorage, tempDir); authStorage.setRuntimeApiKey("anthropic", "test-key"); settingsManager.applyOverrides({ retry: { enabled: true, maxRetries: 3, baseDelayMs: 1 } }); diff --git a/packages/coding-agent/test/agent-session-stats.test.ts b/packages/coding-agent/test/agent-session-stats.test.ts index d9cd192b2..7c4d4ff25 100644 --- a/packages/coding-agent/test/agent-session-stats.test.ts +++ b/packages/coding-agent/test/agent-session-stats.test.ts @@ -66,7 +66,7 @@ function createSession() { sessionManager, settingsManager, cwd: process.cwd(), - modelRegistry: new ModelRegistry(authStorage, undefined), + modelRegistry: ModelRegistry.inMemory(authStorage), resourceLoader: createTestResourceLoader(), }); diff --git a/packages/coding-agent/test/compaction-extensions.test.ts b/packages/coding-agent/test/compaction-extensions.test.ts index 23ca726bf..b366a4793 100644 --- a/packages/coding-agent/test/compaction-extensions.test.ts +++ b/packages/coding-agent/test/compaction-extensions.test.ts @@ -99,7 +99,7 @@ describe.skipIf(!API_KEY)("Compaction extensions", () => { const sessionManager = SessionManager.create(tempDir); const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage); + const modelRegistry = ModelRegistry.create(authStorage); const runtime = createExtensionRuntime(); const resourceLoader = { diff --git a/packages/coding-agent/test/compaction-thinking-model.test.ts b/packages/coding-agent/test/compaction-thinking-model.test.ts index 0f9edb98d..9e4ce7dc2 100644 --- a/packages/coding-agent/test/compaction-thinking-model.test.ts +++ b/packages/coding-agent/test/compaction-thinking-model.test.ts 
@@ -81,7 +81,7 @@ describe.skipIf(!HAS_ANTIGRAVITY_AUTH)("Compaction with thinking models (Antigra // settingsManager.applyOverrides({ compaction: { keepRecentTokens: 1 } }); const authStorage = getRealAuthStorage(); - const modelRegistry = new ModelRegistry(authStorage); + const modelRegistry = ModelRegistry.create(authStorage); session = new AgentSession({ agent, @@ -177,7 +177,7 @@ describe.skipIf(!HAS_ANTHROPIC_AUTH)("Compaction with thinking models (Anthropic const settingsManager = SettingsManager.create(tempDir, tempDir); const authStorage = getRealAuthStorage(); - const modelRegistry = new ModelRegistry(authStorage); + const modelRegistry = ModelRegistry.create(authStorage); session = new AgentSession({ agent, diff --git a/packages/coding-agent/test/extensions-input-event.test.ts b/packages/coding-agent/test/extensions-input-event.test.ts index 562378619..5c08fd2ce 100644 --- a/packages/coding-agent/test/extensions-input-event.test.ts +++ b/packages/coding-agent/test/extensions-input-event.test.ts @@ -29,7 +29,7 @@ describe("Input Event", () => { for (let i = 0; i < extensions.length; i++) fs.writeFileSync(path.join(extensionsDir, `e${i}.ts`), extensions[i]); const result = await discoverAndLoadExtensions([], tempDir, tempDir); const sm = SessionManager.inMemory(); - const mr = new ModelRegistry(AuthStorage.create(path.join(tempDir, "auth.json"))); + const mr = ModelRegistry.create(AuthStorage.create(path.join(tempDir, "auth.json"))); return new ExtensionRunner(result.extensions, result.runtime, tempDir, sm, mr); } diff --git a/packages/coding-agent/test/extensions-runner.test.ts b/packages/coding-agent/test/extensions-runner.test.ts index 1b5841739..e60572d6d 100644 --- a/packages/coding-agent/test/extensions-runner.test.ts +++ b/packages/coding-agent/test/extensions-runner.test.ts @@ -27,7 +27,7 @@ describe("ExtensionRunner", () => { fs.mkdirSync(extensionsDir); sessionManager = SessionManager.inMemory(); const authStorage = 
AuthStorage.create(path.join(tempDir, "auth.json")); - modelRegistry = new ModelRegistry(authStorage); + modelRegistry = ModelRegistry.create(authStorage); }); afterEach(() => { diff --git a/packages/coding-agent/test/model-registry.test.ts b/packages/coding-agent/test/model-registry.test.ts index 40d5fe20e..75fd1b5e7 100644 --- a/packages/coding-agent/test/model-registry.test.ts +++ b/packages/coding-agent/test/model-registry.test.ts @@ -94,7 +94,7 @@ describe("ModelRegistry", () => { anthropic: overrideConfig("https://my-proxy.example.com/v1"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const anthropicModels = getModelsForProvider(registry, "anthropic"); // Should have multiple built-in models, not just one @@ -107,7 +107,7 @@ describe("ModelRegistry", () => { anthropic: overrideConfig("https://my-proxy.example.com/v1"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const anthropicModels = getModelsForProvider(registry, "anthropic"); // All models should have the new baseUrl @@ -123,7 +123,7 @@ describe("ModelRegistry", () => { }), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const anthropicModels = getModelsForProvider(registry, "anthropic"); for (const model of anthropicModels) { @@ -140,7 +140,7 @@ describe("ModelRegistry", () => { anthropic: overrideConfig("https://my-proxy.example.com/v1"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const googleModels = getModelsForProvider(registry, "google"); // Google models should still have their original baseUrl @@ -160,7 +160,7 @@ describe("ModelRegistry", () => { ), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + 
const registry = ModelRegistry.create(authStorage, modelsJsonPath); // Anthropic: multiple built-in models with new baseUrl const anthropicModels = getModelsForProvider(registry, "anthropic"); @@ -177,7 +177,7 @@ describe("ModelRegistry", () => { writeRawModelsJson({ anthropic: overrideConfig("https://first-proxy.example.com/v1"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); expect(getModelsForProvider(registry, "anthropic")[0].baseUrl).toBe("https://first-proxy.example.com/v1"); @@ -197,7 +197,7 @@ describe("ModelRegistry", () => { anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const anthropicModels = getModelsForProvider(registry, "anthropic"); expect(anthropicModels.length).toBeGreaterThan(1); @@ -214,7 +214,7 @@ describe("ModelRegistry", () => { ), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const models = getModelsForProvider(registry, "openrouter"); const sonnetModels = models.filter((m) => m.id === "anthropic/claude-sonnet-4"); @@ -227,7 +227,7 @@ describe("ModelRegistry", () => { anthropic: providerConfig("https://my-proxy.example.com/v1", [{ id: "claude-custom" }]), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); expect(getModelsForProvider(registry, "google").length).toBeGreaterThan(0); expect(getModelsForProvider(registry, "openai").length).toBeGreaterThan(0); @@ -238,7 +238,7 @@ describe("ModelRegistry", () => { anthropic: providerConfig("https://merged-proxy.example.com/v1", [{ id: "claude-custom" }]), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = 
ModelRegistry.create(authStorage, modelsJsonPath); const anthropicModels = getModelsForProvider(registry, "anthropic"); for (const model of anthropicModels) { @@ -269,7 +269,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const compat = registry.find("demo", "demo-model")?.compat as OpenAICompletionsCompat | undefined; expect(compat?.supportsUsageInStreaming).toBe(false); @@ -303,7 +303,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const compat = registry.find("demo", "demo-model")?.compat as OpenAICompletionsCompat | undefined; expect(compat?.supportsUsageInStreaming).toBe(true); @@ -320,7 +320,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const models = getModelsForProvider(registry, "openrouter"); expect(models.length).toBeGreaterThan(0); @@ -357,7 +357,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const compat = registry.find("demo", "demo-model")?.compat as OpenAICompletionsCompat | undefined; expect(registry.getError()).toBeUndefined(); @@ -394,7 +394,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const m25 = registry.find("opencode-go", "minimax-m2.5"); const glm5 = registry.find("opencode-go", "glm-5"); @@ -427,7 +427,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, 
modelsJsonPath); const models = getModelsForProvider(registry, "openrouter"); expect(models.some((m) => m.id === "custom/openrouter-model")).toBe(true); @@ -440,7 +440,7 @@ describe("ModelRegistry", () => { writeModelsJson({ anthropic: providerConfig("https://first-proxy.example.com/v1", [{ id: "claude-custom" }]), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); expect(getModelsForProvider(registry, "anthropic").some((m) => m.id === "claude-custom")).toBe(true); // Update and refresh @@ -459,7 +459,7 @@ describe("ModelRegistry", () => { writeModelsJson({ anthropic: providerConfig("https://proxy.example.com/v1", [{ id: "claude-custom" }]), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); expect(getModelsForProvider(registry, "anthropic").some((m) => m.id === "claude-custom")).toBe(true); // Remove custom models and refresh @@ -484,7 +484,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const models = getModelsForProvider(registry, "openrouter"); const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4"); @@ -508,7 +508,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const models = getModelsForProvider(registry, "openrouter"); const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4"); @@ -529,7 +529,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const models = getModelsForProvider(registry, "openrouter"); const sonnet = models.find((m) => m.id === 
"anthropic/claude-sonnet-4"); @@ -552,7 +552,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const models = getModelsForProvider(registry, "openrouter"); const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4"); @@ -576,7 +576,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const models = getModelsForProvider(registry, "openrouter"); const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4"); @@ -601,7 +601,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const models = getModelsForProvider(registry, "openrouter"); // Should not create a new model @@ -621,7 +621,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const models = getModelsForProvider(registry, "openrouter"); const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4"); @@ -642,7 +642,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const models = getModelsForProvider(registry, "openrouter"); const sonnet = models.find((m) => m.id === "anthropic/claude-sonnet-4"); expect(sonnet).toBeDefined(); @@ -665,7 +665,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); expect( getModelsForProvider(registry, "openrouter").find((m) => m.id === "anthropic/claude-sonnet-4")?.name, 
).toBe("First Name"); @@ -698,7 +698,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const customName = getModelsForProvider(registry, "openrouter").find( (m) => m.id === "anthropic/claude-sonnet-4", )?.name; @@ -717,7 +717,7 @@ describe("ModelRegistry", () => { describe("dynamic provider lifecycle", () => { test("failed registerProvider does not persist invalid streamSimple config", () => { - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); expect(() => registry.registerProvider("broken-provider", { @@ -731,7 +731,7 @@ describe("ModelRegistry", () => { }); test("failed registerProvider does not remove existing provider models", () => { - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); registry.registerProvider("demo-provider", { baseUrl: "https://provider.test/v1", @@ -776,7 +776,7 @@ describe("ModelRegistry", () => { }); test("unregisterProvider removes custom OAuth provider and restores built-in OAuth provider", () => { - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); registry.registerProvider("anthropic", { oauth: { @@ -799,7 +799,7 @@ describe("ModelRegistry", () => { }); test("unregisterProvider removes custom streamSimple override and restores built-in API stream handler", () => { - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); registry.registerProvider("stream-override-provider", { api: "openai-completions", @@ -855,7 +855,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey("!echo test-api-key-from-command"), }); - const registry = new 
ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const apiKey = await registry.getApiKeyForProvider("custom-provider"); expect(apiKey).toBe("test-api-key-from-command"); @@ -866,7 +866,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey("!echo ' spaced-key '"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const apiKey = await registry.getApiKeyForProvider("custom-provider"); expect(apiKey).toBe("spaced-key"); @@ -877,7 +877,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey("!printf 'line1\\nline2'"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const apiKey = await registry.getApiKeyForProvider("custom-provider"); expect(apiKey).toBe("line1\nline2"); @@ -888,7 +888,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey("!exit 1"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const apiKey = await registry.getApiKeyForProvider("custom-provider"); expect(apiKey).toBeUndefined(); @@ -899,7 +899,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey("!nonexistent-command-12345"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const apiKey = await registry.getApiKeyForProvider("custom-provider"); expect(apiKey).toBeUndefined(); @@ -910,7 +910,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey("!printf ''"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const apiKey = await 
registry.getApiKeyForProvider("custom-provider"); expect(apiKey).toBeUndefined(); @@ -925,7 +925,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey("TEST_API_KEY_12345"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const apiKey = await registry.getApiKeyForProvider("custom-provider"); expect(apiKey).toBe("env-api-key-value"); @@ -946,7 +946,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey("literal_api_key_value"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const apiKey = await registry.getApiKeyForProvider("custom-provider"); expect(apiKey).toBe("literal_api_key_value"); @@ -957,7 +957,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey("!echo 'hello world' | tr ' ' '-'"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const apiKey = await registry.getApiKeyForProvider("custom-provider"); expect(apiKey).toBe("hello-world"); @@ -974,7 +974,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey(command), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); await registry.getApiKeyForProvider("custom-provider"); await registry.getApiKeyForProvider("custom-provider"); await registry.getApiKeyForProvider("custom-provider"); @@ -993,10 +993,10 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey(command), }); - const registry1 = new ModelRegistry(authStorage, modelsJsonPath); + const registry1 = ModelRegistry.create(authStorage, modelsJsonPath); await registry1.getApiKeyForProvider("custom-provider"); - const registry2 = new ModelRegistry(authStorage, modelsJsonPath); + const 
registry2 = ModelRegistry.create(authStorage, modelsJsonPath); await registry2.getApiKeyForProvider("custom-provider"); const count = parseInt(readFileSync(counterFile, "utf-8").trim(), 10); @@ -1009,7 +1009,7 @@ describe("ModelRegistry", () => { "provider-b": providerWithApiKey("!echo key-b"), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const keyA = await registry.getApiKeyForProvider("provider-a"); const keyB = await registry.getApiKeyForProvider("provider-b"); @@ -1028,7 +1028,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey(command), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const key1 = await registry.getApiKeyForProvider("custom-provider"); const key2 = await registry.getApiKeyForProvider("custom-provider"); @@ -1050,7 +1050,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey(envVarName), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const key1 = await registry.getApiKeyForProvider("custom-provider"); expect(key1).toBe("first-value"); @@ -1078,7 +1078,7 @@ describe("ModelRegistry", () => { "custom-provider": providerWithApiKey(command), }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const available = registry.getAvailable(); expect(available.some((m) => m.provider === "custom-provider")).toBe(true); @@ -1098,7 +1098,7 @@ describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const model = registry.find("custom-provider", "test-model"); expect(model).toBeDefined(); @@ -1127,7 +1127,7 @@ 
describe("ModelRegistry", () => { }, }); - const registry = new ModelRegistry(authStorage, modelsJsonPath); + const registry = ModelRegistry.create(authStorage, modelsJsonPath); const model = registry.find("custom-provider", "test-model"); expect(model).toBeDefined(); diff --git a/packages/coding-agent/test/resource-loader.test.ts b/packages/coding-agent/test/resource-loader.test.ts index 13678653b..29877ba50 100644 --- a/packages/coding-agent/test/resource-loader.test.ts +++ b/packages/coding-agent/test/resource-loader.test.ts @@ -199,7 +199,7 @@ Project skill`, const sessionManager = SessionManager.inMemory(); const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage); + const modelRegistry = ModelRegistry.create(authStorage); const runner = new ExtensionRunner( extensionsResult.extensions, extensionsResult.runtime, @@ -548,7 +548,7 @@ export default function(pi: ExtensionAPI) { const sessionManager = SessionManager.inMemory(); const authStorage = AuthStorage.create(join(tempDir, "auth-explicit.json")); - const modelRegistry = new ModelRegistry(authStorage); + const modelRegistry = ModelRegistry.create(authStorage); const runner = new ExtensionRunner( extensionsResult.extensions, extensionsResult.runtime, diff --git a/packages/coding-agent/test/sdk-codex-cache-probe-tool-loop.ts b/packages/coding-agent/test/sdk-codex-cache-probe-tool-loop.ts index 9918ccf6d..96a9b77d7 100644 --- a/packages/coding-agent/test/sdk-codex-cache-probe-tool-loop.ts +++ b/packages/coding-agent/test/sdk-codex-cache-probe-tool-loop.ts @@ -204,7 +204,7 @@ async function main(): Promise { mkdirSync(dirname(args.sessionPath), { recursive: true }); const authStorage = AuthStorage.create(); - const modelRegistry = new ModelRegistry(authStorage); + const modelRegistry = ModelRegistry.create(authStorage); const model = getModel("openai-codex", "gpt-5.4"); if (!model) { diff --git a/packages/coding-agent/test/suite/harness.ts 
b/packages/coding-agent/test/suite/harness.ts new file mode 100644 index 000000000..9262f0130 --- /dev/null +++ b/packages/coding-agent/test/suite/harness.ts @@ -0,0 +1,141 @@ +/** + * Local test harness for the new coding-agent test suite. + */ + +import { existsSync, mkdirSync, rmSync } from "node:fs"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import type { AgentTool } from "@mariozechner/pi-agent-core"; +import { Agent } from "@mariozechner/pi-agent-core"; +import type { FauxModelDefinition, FauxProviderRegistration, FauxResponseStep, Model } from "@mariozechner/pi-ai"; +import { registerFauxProvider } from "@mariozechner/pi-ai"; +import { AgentSession, type AgentSessionEvent } from "../../src/core/agent-session.js"; +import { AuthStorage } from "../../src/core/auth-storage.js"; +import { ModelRegistry } from "../../src/core/model-registry.js"; +import { SessionManager } from "../../src/core/session-manager.js"; +import type { Settings } from "../../src/core/settings-manager.js"; +import { SettingsManager } from "../../src/core/settings-manager.js"; +import type { ExtensionFactory, ResourceLoader } from "../../src/index.js"; +import { + type CreateTestExtensionsResultInput, + createTestExtensionsResult, + createTestResourceLoader, +} from "../utilities.js"; + +export interface HarnessOptions { + models?: FauxModelDefinition[]; + settings?: Partial; + systemPrompt?: string; + tools?: AgentTool[]; + resourceLoader?: ResourceLoader; + extensionFactories?: Array; +} + +export interface Harness { + session: AgentSession; + sessionManager: SessionManager; + settingsManager: SettingsManager; + faux: FauxProviderRegistration; + models: [Model, ...Model[]]; + getModel(): Model; + getModel(modelId: string): Model | undefined; + setResponses: (responses: FauxResponseStep[]) => void; + appendResponses: (responses: FauxResponseStep[]) => void; + getPendingResponseCount: () => number; + events: AgentSessionEvent[]; + eventsOfType(type: T): 
Extract[]; + tempDir: string; + cleanup: () => void; +} + +function createTempDir(): string { + const tempDir = join(tmpdir(), `pi-suite-${Date.now()}-${Math.random().toString(36).slice(2)}`); + mkdirSync(tempDir, { recursive: true }); + return tempDir; +} + +export async function createHarness(options: HarnessOptions = {}): Promise { + const tempDir = createTempDir(); + const fauxProvider: FauxProviderRegistration = registerFauxProvider({ + models: options.models, + }); + fauxProvider.setResponses([]); + const model = fauxProvider.getModel(); + const toolMap = options.tools ? Object.fromEntries(options.tools.map((tool) => [tool.name, tool])) : undefined; + + const agent = new Agent({ + getApiKey: () => "faux-key", + initialState: { + model, + systemPrompt: options.systemPrompt ?? "You are a test assistant.", + tools: [], + }, + }); + + const sessionManager = SessionManager.inMemory(); + const settingsManager = SettingsManager.inMemory(options.settings); + + const authStorage = AuthStorage.inMemory(); + authStorage.setRuntimeApiKey(model.provider, "faux-key"); + const modelRegistry = ModelRegistry.inMemory(authStorage); + modelRegistry.registerProvider(model.provider, { + baseUrl: model.baseUrl, + apiKey: "faux-key", + api: fauxProvider.api, + models: fauxProvider.models.map((registeredModel) => ({ + id: registeredModel.id, + name: registeredModel.name, + api: registeredModel.api, + reasoning: registeredModel.reasoning, + input: registeredModel.input, + cost: registeredModel.cost, + contextWindow: registeredModel.contextWindow, + maxTokens: registeredModel.maxTokens, + baseUrl: registeredModel.baseUrl, + })), + }); + const extensionsResult = options.extensionFactories + ? await createTestExtensionsResult(options.extensionFactories, tempDir) + : undefined; + const resourceLoader = + options.resourceLoader ?? createTestResourceLoader(extensionsResult ? 
{ extensionsResult } : undefined); + + const session = new AgentSession({ + agent, + sessionManager, + settingsManager, + cwd: tempDir, + modelRegistry, + resourceLoader, + baseToolsOverride: toolMap, + }); + + const events: AgentSessionEvent[] = []; + session.subscribe((event) => { + events.push(event); + }); + + return { + session, + sessionManager, + settingsManager, + faux: fauxProvider, + models: fauxProvider.models, + getModel: fauxProvider.getModel, + setResponses: fauxProvider.setResponses, + appendResponses: fauxProvider.appendResponses, + getPendingResponseCount: fauxProvider.getPendingResponseCount, + events, + eventsOfType(type: T) { + return events.filter((event): event is Extract => event.type === type); + }, + tempDir, + cleanup() { + session.dispose(); + fauxProvider.unregister(); + if (existsSync(tempDir)) { + rmSync(tempDir, { recursive: true }); + } + }, + }; +} diff --git a/packages/coding-agent/test/test-harness.ts b/packages/coding-agent/test/test-harness.ts index 300ea2b55..9b4672f19 100644 --- a/packages/coding-agent/test/test-harness.ts +++ b/packages/coding-agent/test/test-harness.ts @@ -390,7 +390,7 @@ function createHarnessWithResourceLoader( const authStorage = AuthStorage.create(join(tempDir, "auth.json")); authStorage.setRuntimeApiKey(model.provider, "faux-key"); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = ModelRegistry.create(authStorage, tempDir); const session = new AgentSession({ agent, diff --git a/packages/coding-agent/test/utilities.ts b/packages/coding-agent/test/utilities.ts index 6a2f5d6a3..5ede0a574 100644 --- a/packages/coding-agent/test/utilities.ts +++ b/packages/coding-agent/test/utilities.ts @@ -254,7 +254,7 @@ export function createTestSession(options: TestSessionOptions = {}): TestSession } const authStorage = AuthStorage.create(join(tempDir, "auth.json")); - const modelRegistry = new ModelRegistry(authStorage, tempDir); + const modelRegistry = 
ModelRegistry.create(authStorage, tempDir); const session = new AgentSession({ agent, diff --git a/packages/mom/src/agent.ts b/packages/mom/src/agent.ts index 88d2d1879..142290732 100644 --- a/packages/mom/src/agent.ts +++ b/packages/mom/src/agent.ts @@ -429,7 +429,7 @@ function createRunner(sandboxConfig: SandboxConfig, channelId: string, channelDi // Create AuthStorage and ModelRegistry // Auth stored outside workspace so agent can't access it const authStorage = AuthStorage.create(join(homedir(), ".pi", "mom", "auth.json")); - const modelRegistry = new ModelRegistry(authStorage); + const modelRegistry = ModelRegistry.create(authStorage); // Create agent const agent = new Agent({