diff --git a/packages/coding-agent/src/core/agent-session.ts b/packages/coding-agent/src/core/agent-session.ts index 111adcd2c..4517dc318 100644 --- a/packages/coding-agent/src/core/agent-session.ts +++ b/packages/coding-agent/src/core/agent-session.ts @@ -2057,6 +2057,20 @@ export class AgentSession { : undefined; } + private _refreshCurrentModelFromRegistry(): void { + const currentModel = this.model; + if (!currentModel) { + return; + } + + const refreshedModel = this._modelRegistry.find(currentModel.provider, currentModel.id); + if (!refreshedModel || refreshedModel === currentModel) { + return; + } + + this.agent.setModel(refreshedModel); + } + private _bindExtensionCore(runner: ExtensionRunner): void { const normalizeLocation = (source: string): SlashCommandLocation | undefined => { if (source === "user" || source === "project" || source === "path") { @@ -2165,6 +2179,16 @@ export class AgentSession { }, getSystemPrompt: () => this.systemPrompt, }, + { + registerProvider: (name, config) => { + this._modelRegistry.registerProvider(name, config); + this._refreshCurrentModelFromRegistry(); + }, + unregisterProvider: (name) => { + this._modelRegistry.unregisterProvider(name); + this._refreshCurrentModelFromRegistry(); + }, + }, ); } diff --git a/packages/coding-agent/src/core/extensions/runner.ts b/packages/coding-agent/src/core/extensions/runner.ts index 05f2d786f..65234bb4a 100644 --- a/packages/coding-agent/src/core/extensions/runner.ts +++ b/packages/coding-agent/src/core/extensions/runner.ts @@ -34,6 +34,7 @@ import type { InputEventResult, InputSource, MessageRenderer, + ProviderConfig, RegisteredCommand, RegisteredTool, ResourcesDiscoverEvent, @@ -235,7 +236,14 @@ export class ExtensionRunner { this.modelRegistry = modelRegistry; } - bindCore(actions: ExtensionActions, contextActions: ExtensionContextActions): void { + bindCore( + actions: ExtensionActions, + contextActions: ExtensionContextActions, + providerActions?: { + registerProvider?: (name: string, config: ProviderConfig) => void; + unregisterProvider?: (name: string) => void; + }, + ): void { // Copy actions into the shared runtime (all extension APIs reference this) this.runtime.sendMessage = actions.sendMessage; this.runtime.sendUserMessage = actions.sendUserMessage; @@ -264,14 +272,30 @@ export class ExtensionRunner { // Flush provider registrations queued during extension loading for (const { name, config } of this.runtime.pendingProviderRegistrations) { - this.modelRegistry.registerProvider(name, config); + if (providerActions?.registerProvider) { + providerActions.registerProvider(name, config); + } else { + this.modelRegistry.registerProvider(name, config); + } } this.runtime.pendingProviderRegistrations = []; // From this point on, provider registration/unregistration takes effect immediately // without requiring a /reload. - this.runtime.registerProvider = (name, config) => this.modelRegistry.registerProvider(name, config); - this.runtime.unregisterProvider = (name) => this.modelRegistry.unregisterProvider(name); + this.runtime.registerProvider = (name, config) => { + if (providerActions?.registerProvider) { + providerActions.registerProvider(name, config); + return; + } + this.modelRegistry.registerProvider(name, config); + }; + this.runtime.unregisterProvider = (name) => { + if (providerActions?.unregisterProvider) { + providerActions.unregisterProvider(name); + return; + } + this.modelRegistry.unregisterProvider(name); + }; } bindCommandContext(actions?: ExtensionCommandContextActions): void { diff --git a/packages/coding-agent/test/agent-session-dynamic-provider.test.ts b/packages/coding-agent/test/agent-session-dynamic-provider.test.ts new file mode 100644 index 000000000..6c3b81352 --- /dev/null +++ b/packages/coding-agent/test/agent-session-dynamic-provider.test.ts @@ -0,0 +1,117 @@ +import { existsSync, mkdirSync, rmSync } from "node:fs"; +import { tmpdir } from "node:os"; +import { join } from "node:path"; +import { getModel } from "@mariozechner/pi-ai"; +import { afterEach, beforeEach, describe, expect, it } from "vitest"; +import { AuthStorage } from "../src/core/auth-storage.js"; +import { DefaultResourceLoader } from "../src/core/resource-loader.js"; +import type { ExtensionFactory } from "../src/core/sdk.js"; +import { createAgentSession } from "../src/core/sdk.js"; +import { SessionManager } from "../src/core/session-manager.js"; +import { SettingsManager } from "../src/core/settings-manager.js"; + +describe("AgentSession dynamic provider registration", () => { + let tempDir: string; + let agentDir: string; + + beforeEach(() => { + tempDir = join(tmpdir(), `pi-dynamic-provider-test-${Date.now()}-${Math.random().toString(36).slice(2)}`); + agentDir = join(tempDir, "agent"); + mkdirSync(agentDir, { recursive: true }); + }); + + afterEach(() => { + if (tempDir && existsSync(tempDir)) { + rmSync(tempDir, { recursive: true, force: true }); + } + }); + + async function createSession(extensionFactories: ExtensionFactory[]) { + const settingsManager = SettingsManager.create(tempDir, agentDir); + const sessionManager = SessionManager.inMemory(); + const authStorage = AuthStorage.create(join(agentDir, "auth.json")); + authStorage.setRuntimeApiKey("anthropic", "test-key"); + const resourceLoader = new DefaultResourceLoader({ + cwd: tempDir, + agentDir, + settingsManager, + extensionFactories, + }); + await resourceLoader.reload(); + + const { session } = await createAgentSession({ + cwd: tempDir, + agentDir, + model: getModel("anthropic", "claude-sonnet-4-5")!, + settingsManager, + sessionManager, + authStorage, + resourceLoader, + }); + + return session; + } + + async function capturePromptBaseUrl( + session: Awaited>, + ): Promise { + let baseUrl: string | undefined; + session.agent.streamFn = async (model) => { + baseUrl = model.baseUrl; + throw new Error("stop"); + }; + await session.prompt("hello"); + return baseUrl; + } + + it("applies top-level registerProvider overrides to the active model", async () => { + const session = await createSession([ + (pi) => { + pi.registerProvider("anthropic", { baseUrl: "http://localhost:8080/top-level" }); + }, + ]); + + expect(session.model?.baseUrl).toBe("http://localhost:8080/top-level"); + expect(await capturePromptBaseUrl(session)).toBe("http://localhost:8080/top-level"); + + session.dispose(); + }); + + it("applies session_start registerProvider overrides to the active model", async () => { + const session = await createSession([ + (pi) => { + pi.on("session_start", () => { + pi.registerProvider("anthropic", { baseUrl: "http://localhost:8080/session-start" }); + }); + }, + ]); + + await session.bindExtensions({}); + + expect(session.model?.baseUrl).toBe("http://localhost:8080/session-start"); + expect(await capturePromptBaseUrl(session)).toBe("http://localhost:8080/session-start"); + + session.dispose(); + }); + + it("applies command-time registerProvider overrides without reload", async () => { + const session = await createSession([ + (pi) => { + pi.registerCommand("use-proxy", { + description: "Use proxy", + handler: async () => { + pi.registerProvider("anthropic", { baseUrl: "http://localhost:8080/command" }); + }, + }); + }, + ]); + + await session.bindExtensions({}); + await session.prompt("/use-proxy"); + + expect(session.model?.baseUrl).toBe("http://localhost:8080/command"); + expect(await capturePromptBaseUrl(session)).toBe("http://localhost:8080/command"); + + session.dispose(); + }); +});