mirror of
https://github.com/earendil-works/pi.git
synced 2026-06-18 15:54:04 +08:00
fix(coding-agent): refresh active model after provider updates closes #2291
This commit is contained in:
@@ -2057,6 +2057,20 @@ export class AgentSession {
|
||||
: undefined;
|
||||
}
|
||||
|
||||
private _refreshCurrentModelFromRegistry(): void {
|
||||
const currentModel = this.model;
|
||||
if (!currentModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
const refreshedModel = this._modelRegistry.find(currentModel.provider, currentModel.id);
|
||||
if (!refreshedModel || refreshedModel === currentModel) {
|
||||
return;
|
||||
}
|
||||
|
||||
this.agent.setModel(refreshedModel);
|
||||
}
|
||||
|
||||
private _bindExtensionCore(runner: ExtensionRunner): void {
|
||||
const normalizeLocation = (source: string): SlashCommandLocation | undefined => {
|
||||
if (source === "user" || source === "project" || source === "path") {
|
||||
@@ -2165,6 +2179,16 @@ export class AgentSession {
|
||||
},
|
||||
getSystemPrompt: () => this.systemPrompt,
|
||||
},
|
||||
{
|
||||
registerProvider: (name, config) => {
|
||||
this._modelRegistry.registerProvider(name, config);
|
||||
this._refreshCurrentModelFromRegistry();
|
||||
},
|
||||
unregisterProvider: (name) => {
|
||||
this._modelRegistry.unregisterProvider(name);
|
||||
this._refreshCurrentModelFromRegistry();
|
||||
},
|
||||
},
|
||||
);
|
||||
}
|
||||
|
||||
|
||||
@@ -34,6 +34,7 @@ import type {
|
||||
InputEventResult,
|
||||
InputSource,
|
||||
MessageRenderer,
|
||||
ProviderConfig,
|
||||
RegisteredCommand,
|
||||
RegisteredTool,
|
||||
ResourcesDiscoverEvent,
|
||||
@@ -235,7 +236,14 @@ export class ExtensionRunner {
|
||||
this.modelRegistry = modelRegistry;
|
||||
}
|
||||
|
||||
bindCore(actions: ExtensionActions, contextActions: ExtensionContextActions): void {
|
||||
bindCore(
|
||||
actions: ExtensionActions,
|
||||
contextActions: ExtensionContextActions,
|
||||
providerActions?: {
|
||||
registerProvider?: (name: string, config: ProviderConfig) => void;
|
||||
unregisterProvider?: (name: string) => void;
|
||||
},
|
||||
): void {
|
||||
// Copy actions into the shared runtime (all extension APIs reference this)
|
||||
this.runtime.sendMessage = actions.sendMessage;
|
||||
this.runtime.sendUserMessage = actions.sendUserMessage;
|
||||
@@ -264,14 +272,30 @@ export class ExtensionRunner {
|
||||
|
||||
// Flush provider registrations queued during extension loading
|
||||
for (const { name, config } of this.runtime.pendingProviderRegistrations) {
|
||||
this.modelRegistry.registerProvider(name, config);
|
||||
if (providerActions?.registerProvider) {
|
||||
providerActions.registerProvider(name, config);
|
||||
} else {
|
||||
this.modelRegistry.registerProvider(name, config);
|
||||
}
|
||||
}
|
||||
this.runtime.pendingProviderRegistrations = [];
|
||||
|
||||
// From this point on, provider registration/unregistration takes effect immediately
|
||||
// without requiring a /reload.
|
||||
this.runtime.registerProvider = (name, config) => this.modelRegistry.registerProvider(name, config);
|
||||
this.runtime.unregisterProvider = (name) => this.modelRegistry.unregisterProvider(name);
|
||||
this.runtime.registerProvider = (name, config) => {
|
||||
if (providerActions?.registerProvider) {
|
||||
providerActions.registerProvider(name, config);
|
||||
return;
|
||||
}
|
||||
this.modelRegistry.registerProvider(name, config);
|
||||
};
|
||||
this.runtime.unregisterProvider = (name) => {
|
||||
if (providerActions?.unregisterProvider) {
|
||||
providerActions.unregisterProvider(name);
|
||||
return;
|
||||
}
|
||||
this.modelRegistry.unregisterProvider(name);
|
||||
};
|
||||
}
|
||||
|
||||
bindCommandContext(actions?: ExtensionCommandContextActions): void {
|
||||
|
||||
@@ -0,0 +1,117 @@
|
||||
import { existsSync, mkdirSync, rmSync } from "node:fs";
|
||||
import { tmpdir } from "node:os";
|
||||
import { join } from "node:path";
|
||||
import { getModel } from "@mariozechner/pi-ai";
|
||||
import { afterEach, beforeEach, describe, expect, it } from "vitest";
|
||||
import { AuthStorage } from "../src/core/auth-storage.js";
|
||||
import { DefaultResourceLoader } from "../src/core/resource-loader.js";
|
||||
import type { ExtensionFactory } from "../src/core/sdk.js";
|
||||
import { createAgentSession } from "../src/core/sdk.js";
|
||||
import { SessionManager } from "../src/core/session-manager.js";
|
||||
import { SettingsManager } from "../src/core/settings-manager.js";
|
||||
|
||||
describe("AgentSession dynamic provider registration", () => {
|
||||
let tempDir: string;
|
||||
let agentDir: string;
|
||||
|
||||
beforeEach(() => {
|
||||
tempDir = join(tmpdir(), `pi-dynamic-provider-test-${Date.now()}-${Math.random().toString(36).slice(2)}`);
|
||||
agentDir = join(tempDir, "agent");
|
||||
mkdirSync(agentDir, { recursive: true });
|
||||
});
|
||||
|
||||
afterEach(() => {
|
||||
if (tempDir && existsSync(tempDir)) {
|
||||
rmSync(tempDir, { recursive: true, force: true });
|
||||
}
|
||||
});
|
||||
|
||||
async function createSession(extensionFactories: ExtensionFactory[]) {
|
||||
const settingsManager = SettingsManager.create(tempDir, agentDir);
|
||||
const sessionManager = SessionManager.inMemory();
|
||||
const authStorage = AuthStorage.create(join(agentDir, "auth.json"));
|
||||
authStorage.setRuntimeApiKey("anthropic", "test-key");
|
||||
const resourceLoader = new DefaultResourceLoader({
|
||||
cwd: tempDir,
|
||||
agentDir,
|
||||
settingsManager,
|
||||
extensionFactories,
|
||||
});
|
||||
await resourceLoader.reload();
|
||||
|
||||
const { session } = await createAgentSession({
|
||||
cwd: tempDir,
|
||||
agentDir,
|
||||
model: getModel("anthropic", "claude-sonnet-4-5")!,
|
||||
settingsManager,
|
||||
sessionManager,
|
||||
authStorage,
|
||||
resourceLoader,
|
||||
});
|
||||
|
||||
return session;
|
||||
}
|
||||
|
||||
async function capturePromptBaseUrl(
|
||||
session: Awaited<ReturnType<typeof createSession>>,
|
||||
): Promise<string | undefined> {
|
||||
let baseUrl: string | undefined;
|
||||
session.agent.streamFn = async (model) => {
|
||||
baseUrl = model.baseUrl;
|
||||
throw new Error("stop");
|
||||
};
|
||||
await session.prompt("hello");
|
||||
return baseUrl;
|
||||
}
|
||||
|
||||
it("applies top-level registerProvider overrides to the active model", async () => {
|
||||
const session = await createSession([
|
||||
(pi) => {
|
||||
pi.registerProvider("anthropic", { baseUrl: "http://localhost:8080/top-level" });
|
||||
},
|
||||
]);
|
||||
|
||||
expect(session.model?.baseUrl).toBe("http://localhost:8080/top-level");
|
||||
expect(await capturePromptBaseUrl(session)).toBe("http://localhost:8080/top-level");
|
||||
|
||||
session.dispose();
|
||||
});
|
||||
|
||||
it("applies session_start registerProvider overrides to the active model", async () => {
|
||||
const session = await createSession([
|
||||
(pi) => {
|
||||
pi.on("session_start", () => {
|
||||
pi.registerProvider("anthropic", { baseUrl: "http://localhost:8080/session-start" });
|
||||
});
|
||||
},
|
||||
]);
|
||||
|
||||
await session.bindExtensions({});
|
||||
|
||||
expect(session.model?.baseUrl).toBe("http://localhost:8080/session-start");
|
||||
expect(await capturePromptBaseUrl(session)).toBe("http://localhost:8080/session-start");
|
||||
|
||||
session.dispose();
|
||||
});
|
||||
|
||||
it("applies command-time registerProvider overrides without reload", async () => {
|
||||
const session = await createSession([
|
||||
(pi) => {
|
||||
pi.registerCommand("use-proxy", {
|
||||
description: "Use proxy",
|
||||
handler: async () => {
|
||||
pi.registerProvider("anthropic", { baseUrl: "http://localhost:8080/command" });
|
||||
},
|
||||
});
|
||||
},
|
||||
]);
|
||||
|
||||
await session.bindExtensions({});
|
||||
await session.prompt("/use-proxy");
|
||||
|
||||
expect(session.model?.baseUrl).toBe("http://localhost:8080/command");
|
||||
expect(await capturePromptBaseUrl(session)).toBe("http://localhost:8080/command");
|
||||
|
||||
session.dispose();
|
||||
});
|
||||
});
|
||||
Reference in New Issue
Block a user