fix(coding-agent): refresh active model after provider updates closes #2291

This commit is contained in:
Mario Zechner
2026-03-18 01:12:11 +01:00
Unverified
parent 1a9185d3cb
commit 2becbbdff3
3 changed files with 169 additions and 4 deletions
@@ -2057,6 +2057,20 @@ export class AgentSession {
: undefined;
}
private _refreshCurrentModelFromRegistry(): void {
const currentModel = this.model;
if (!currentModel) {
return;
}
const refreshedModel = this._modelRegistry.find(currentModel.provider, currentModel.id);
if (!refreshedModel || refreshedModel === currentModel) {
return;
}
this.agent.setModel(refreshedModel);
}
private _bindExtensionCore(runner: ExtensionRunner): void {
const normalizeLocation = (source: string): SlashCommandLocation | undefined => {
if (source === "user" || source === "project" || source === "path") {
@@ -2165,6 +2179,16 @@ export class AgentSession {
},
getSystemPrompt: () => this.systemPrompt,
},
{
registerProvider: (name, config) => {
this._modelRegistry.registerProvider(name, config);
this._refreshCurrentModelFromRegistry();
},
unregisterProvider: (name) => {
this._modelRegistry.unregisterProvider(name);
this._refreshCurrentModelFromRegistry();
},
},
);
}
@@ -34,6 +34,7 @@ import type {
InputEventResult,
InputSource,
MessageRenderer,
ProviderConfig,
RegisteredCommand,
RegisteredTool,
ResourcesDiscoverEvent,
@@ -235,7 +236,14 @@ export class ExtensionRunner {
this.modelRegistry = modelRegistry;
}
bindCore(actions: ExtensionActions, contextActions: ExtensionContextActions): void {
bindCore(
actions: ExtensionActions,
contextActions: ExtensionContextActions,
providerActions?: {
registerProvider?: (name: string, config: ProviderConfig) => void;
unregisterProvider?: (name: string) => void;
},
): void {
// Copy actions into the shared runtime (all extension APIs reference this)
this.runtime.sendMessage = actions.sendMessage;
this.runtime.sendUserMessage = actions.sendUserMessage;
@@ -264,14 +272,30 @@ export class ExtensionRunner {
// Flush provider registrations queued during extension loading
for (const { name, config } of this.runtime.pendingProviderRegistrations) {
this.modelRegistry.registerProvider(name, config);
if (providerActions?.registerProvider) {
providerActions.registerProvider(name, config);
} else {
this.modelRegistry.registerProvider(name, config);
}
}
this.runtime.pendingProviderRegistrations = [];
// From this point on, provider registration/unregistration takes effect immediately
// without requiring a /reload.
this.runtime.registerProvider = (name, config) => this.modelRegistry.registerProvider(name, config);
this.runtime.unregisterProvider = (name) => this.modelRegistry.unregisterProvider(name);
this.runtime.registerProvider = (name, config) => {
if (providerActions?.registerProvider) {
providerActions.registerProvider(name, config);
return;
}
this.modelRegistry.registerProvider(name, config);
};
this.runtime.unregisterProvider = (name) => {
if (providerActions?.unregisterProvider) {
providerActions.unregisterProvider(name);
return;
}
this.modelRegistry.unregisterProvider(name);
};
}
bindCommandContext(actions?: ExtensionCommandContextActions): void {
@@ -0,0 +1,117 @@
import { existsSync, mkdirSync, rmSync } from "node:fs";
import { tmpdir } from "node:os";
import { join } from "node:path";
import { getModel } from "@mariozechner/pi-ai";
import { afterEach, beforeEach, describe, expect, it } from "vitest";
import { AuthStorage } from "../src/core/auth-storage.js";
import { DefaultResourceLoader } from "../src/core/resource-loader.js";
import type { ExtensionFactory } from "../src/core/sdk.js";
import { createAgentSession } from "../src/core/sdk.js";
import { SessionManager } from "../src/core/session-manager.js";
import { SettingsManager } from "../src/core/settings-manager.js";
describe("AgentSession dynamic provider registration", () => {
let tempDir: string;
let agentDir: string;
beforeEach(() => {
tempDir = join(tmpdir(), `pi-dynamic-provider-test-${Date.now()}-${Math.random().toString(36).slice(2)}`);
agentDir = join(tempDir, "agent");
mkdirSync(agentDir, { recursive: true });
});
afterEach(() => {
if (tempDir && existsSync(tempDir)) {
rmSync(tempDir, { recursive: true, force: true });
}
});
async function createSession(extensionFactories: ExtensionFactory[]) {
const settingsManager = SettingsManager.create(tempDir, agentDir);
const sessionManager = SessionManager.inMemory();
const authStorage = AuthStorage.create(join(agentDir, "auth.json"));
authStorage.setRuntimeApiKey("anthropic", "test-key");
const resourceLoader = new DefaultResourceLoader({
cwd: tempDir,
agentDir,
settingsManager,
extensionFactories,
});
await resourceLoader.reload();
const { session } = await createAgentSession({
cwd: tempDir,
agentDir,
model: getModel("anthropic", "claude-sonnet-4-5")!,
settingsManager,
sessionManager,
authStorage,
resourceLoader,
});
return session;
}
async function capturePromptBaseUrl(
session: Awaited<ReturnType<typeof createSession>>,
): Promise<string | undefined> {
let baseUrl: string | undefined;
session.agent.streamFn = async (model) => {
baseUrl = model.baseUrl;
throw new Error("stop");
};
await session.prompt("hello");
return baseUrl;
}
it("applies top-level registerProvider overrides to the active model", async () => {
const session = await createSession([
(pi) => {
pi.registerProvider("anthropic", { baseUrl: "http://localhost:8080/top-level" });
},
]);
expect(session.model?.baseUrl).toBe("http://localhost:8080/top-level");
expect(await capturePromptBaseUrl(session)).toBe("http://localhost:8080/top-level");
session.dispose();
});
it("applies session_start registerProvider overrides to the active model", async () => {
const session = await createSession([
(pi) => {
pi.on("session_start", () => {
pi.registerProvider("anthropic", { baseUrl: "http://localhost:8080/session-start" });
});
},
]);
await session.bindExtensions({});
expect(session.model?.baseUrl).toBe("http://localhost:8080/session-start");
expect(await capturePromptBaseUrl(session)).toBe("http://localhost:8080/session-start");
session.dispose();
});
it("applies command-time registerProvider overrides without reload", async () => {
const session = await createSession([
(pi) => {
pi.registerCommand("use-proxy", {
description: "Use proxy",
handler: async () => {
pi.registerProvider("anthropic", { baseUrl: "http://localhost:8080/command" });
},
});
},
]);
await session.bindExtensions({});
await session.prompt("/use-proxy");
expect(session.model?.baseUrl).toBe("http://localhost:8080/command");
expect(await capturePromptBaseUrl(session)).toBe("http://localhost:8080/command");
session.dispose();
});
});