From bf5ac0011e4fd49c7d4bc4a4035e51549e4a75db Mon Sep 17 00:00:00 2001 From: Vegard Stikbakke Date: Wed, 20 May 2026 09:30:58 +0200 Subject: [PATCH] feat(ai): add device code login callback and use for copilot --- packages/ai/CHANGELOG.md | 4 + packages/ai/src/cli.ts | 5 + packages/ai/src/index.ts | 1 + packages/ai/src/utils/oauth/device-code.ts | 80 +++++++++ packages/ai/src/utils/oauth/github-copilot.ts | 155 ++++++------------ packages/ai/src/utils/oauth/index.ts | 1 + packages/ai/src/utils/oauth/types.ts | 8 + packages/ai/test/github-copilot-oauth.test.ts | 95 ++++++++--- packages/ai/test/oauth-device-code.test.ts | 55 +++++++ .../interactive/components/login-dialog.ts | 33 +++- .../src/modes/interactive/interactive-mode.ts | 8 +- 11 files changed, 316 insertions(+), 129 deletions(-) create mode 100644 packages/ai/src/utils/oauth/device-code.ts create mode 100644 packages/ai/test/oauth-device-code.test.ts diff --git a/packages/ai/CHANGELOG.md b/packages/ai/CHANGELOG.md index d4560aa0d..782adee11 100644 --- a/packages/ai/CHANGELOG.md +++ b/packages/ai/CHANGELOG.md @@ -7,6 +7,10 @@ - Changed source syntax to avoid TypeScript constructs that require JavaScript emit, keeping the package compatible with Node.js strip-only TypeScript checks. - Removed the package-level development watch scripts now that the root TypeScript check validates strip-only-compatible sources. +### Added + +- Added first-class OAuth device-code callback metadata, shared polling support, and GitHub Copilot OAuth integration. + ### Fixed - Fixed OpenAI-compatible `streamSimple()` requests to stop sending model-derived default output token caps, avoiding context-window reservation failures on servers such as vLLM while preserving explicit `maxTokens` and required Anthropic `max_tokens` handling ([#4675](https://github.com/earendil-works/pi/issues/4675)). diff --git a/packages/ai/src/cli.ts b/packages/ai/src/cli.ts index 38ee7e346..442ee8ea6 100644 --- a/packages/ai/src/cli.ts +++ b/packages/ai/src/cli.ts @@ -42,6 +42,11 @@ async function login(providerId: OAuthProviderId): Promise { if (info.instructions) console.log(info.instructions); console.log(); }, + onDeviceCode: (info) => { + console.log(`\nOpen this URL in your browser:\n${info.verificationUri}`); + console.log(`Enter code: ${info.userCode}`); + console.log(); + }, onPrompt: async (p) => { return await promptFn(`${p.message}${p.placeholder ? ` (${p.placeholder})` : ""}:`); }, diff --git a/packages/ai/src/index.ts b/packages/ai/src/index.ts index fd06fe81b..ed7aeaa87 100644 --- a/packages/ai/src/index.ts +++ b/packages/ai/src/index.ts @@ -32,6 +32,7 @@ export * from "./utils/json-parse.ts"; export type { OAuthAuthInfo, OAuthCredentials, + OAuthDeviceCodeInfo, OAuthLoginCallbacks, OAuthPrompt, OAuthProvider, diff --git a/packages/ai/src/utils/oauth/device-code.ts b/packages/ai/src/utils/oauth/device-code.ts new file mode 100644 index 000000000..95dba2e78 --- /dev/null +++ b/packages/ai/src/utils/oauth/device-code.ts @@ -0,0 +1,80 @@ +const CANCEL_MESSAGE = "Login cancelled"; +const TIMEOUT_MESSAGE = "Device flow timed out"; +const SLOW_DOWN_TIMEOUT_MESSAGE = + "Device flow timed out after one or more slow_down responses. This is often caused by clock drift in WSL or VM environments. Please sync or restart the VM clock and try again."; +const MINIMUM_INTERVAL_MS = 1000; +// RFC 8628 section 3.2: if the authorization server omits `interval`, the client must use 5 seconds. +const DEFAULT_POLL_INTERVAL_SECONDS = 5; +// RFC 8628 section 3.5: `slow_down` means the polling interval must increase by 5 seconds. +const SLOW_DOWN_INTERVAL_INCREMENT_MS = 5000; + +export type OAuthDeviceCodePollResult = + | { status: "pending" } + | { status: "slow_down" } + | { status: "complete"; accessToken: string } + | { status: "failed"; message: string }; + +export type OAuthDeviceCodePollOptions = { + intervalSeconds?: number; + expiresInSeconds?: number; + poll: () => Promise; + signal?: AbortSignal; +}; + +function abortableSleep(ms: number, signal: AbortSignal | undefined, cancelMessage: string): Promise { + return new Promise((resolve, reject) => { + if (signal?.aborted) { + reject(new Error(cancelMessage)); + return; + } + + const onAbort = () => { + clearTimeout(timeout); + reject(new Error(cancelMessage)); + }; + const timeout = setTimeout(() => { + signal?.removeEventListener("abort", onAbort); + resolve(); + }, ms); + + signal?.addEventListener("abort", onAbort, { once: true }); + }); +} + +export async function pollOAuthDeviceCodeFlow(options: OAuthDeviceCodePollOptions): Promise { + const deadline = + typeof options.expiresInSeconds === "number" + ? Date.now() + options.expiresInSeconds * 1000 + : Number.POSITIVE_INFINITY; + let intervalMs = Math.max( + MINIMUM_INTERVAL_MS, + Math.floor((options.intervalSeconds ?? DEFAULT_POLL_INTERVAL_SECONDS) * 1000), + ); + + let slowDownResponses = 0; + while (Date.now() < deadline) { + if (options.signal?.aborted) { + throw new Error(CANCEL_MESSAGE); + } + + const remainingMs = deadline - Date.now(); + await abortableSleep(Math.min(intervalMs, remainingMs), options.signal, CANCEL_MESSAGE); + + const result = await options.poll(); + if (result.status === "complete") { + return result.accessToken; + } + if (result.status === "pending") { + continue; + } + if (result.status === "slow_down") { + slowDownResponses += 1; + // RFC 8628 section 3.5: apply this increase to this and all subsequent requests. + intervalMs = Math.max(MINIMUM_INTERVAL_MS, intervalMs + SLOW_DOWN_INTERVAL_INCREMENT_MS); + continue; + } + throw new Error(result.message); + } + + throw new Error(slowDownResponses > 0 ? SLOW_DOWN_TIMEOUT_MESSAGE : TIMEOUT_MESSAGE); +} diff --git a/packages/ai/src/utils/oauth/github-copilot.ts b/packages/ai/src/utils/oauth/github-copilot.ts index 6182ab910..c4ce36358 100644 --- a/packages/ai/src/utils/oauth/github-copilot.ts +++ b/packages/ai/src/utils/oauth/github-copilot.ts @@ -4,7 +4,8 @@ import { getModels } from "../../models.ts"; import type { Api, Model } from "../../types.ts"; -import type { OAuthCredentials, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.ts"; +import { pollOAuthDeviceCodeFlow } from "./device-code.ts"; +import type { OAuthCredentials, OAuthDeviceCodeInfo, OAuthLoginCallbacks, OAuthProviderInterface } from "./types.ts"; type CopilotCredentials = OAuthCredentials & { enterpriseUrl?: string; @@ -20,14 +21,11 @@ const COPILOT_HEADERS = { "Copilot-Integration-Id": "vscode-chat", } as const; -const INITIAL_POLL_INTERVAL_MULTIPLIER = 1.2; -const SLOW_DOWN_POLL_INTERVAL_MULTIPLIER = 1.4; - type DeviceCodeResponse = { device_code: string; user_code: string; verification_uri: string; - interval: number; + interval?: number; expires_in: number; }; @@ -40,7 +38,6 @@ type DeviceTokenSuccessResponse = { type DeviceTokenErrorResponse = { error: string; error_description?: string; - interval?: number; }; export function normalizeDomain(input: string): string | null { @@ -129,7 +126,7 @@ async function startDeviceFlow(domain: string): Promise { typeof deviceCode !== "string" || typeof userCode !== "string" || typeof verificationUri !== "string" || - typeof interval !== "number" || + (interval !== undefined && typeof interval !== "number") || typeof expiresIn !== "number" ) { throw new Error("Invalid device code response fields"); @@ -144,95 +141,48 @@ async function startDeviceFlow(domain: string): Promise { }; } -/** - * Sleep that can be interrupted by an AbortSignal - */ -function abortableSleep(ms: number, signal?: AbortSignal): Promise { - return new Promise((resolve, reject) => { - if (signal?.aborted) { - reject(new Error("Login cancelled")); - return; - } - - const timeout = setTimeout(resolve, ms); - - signal?.addEventListener( - "abort", - () => { - clearTimeout(timeout); - reject(new Error("Login cancelled")); - }, - { once: true }, - ); - }); -} - -async function pollForGitHubAccessToken( - domain: string, - deviceCode: string, - intervalSeconds: number, - expiresIn: number, - signal?: AbortSignal, -) { +async function pollForGitHubAccessToken(domain: string, device: DeviceCodeResponse, signal?: AbortSignal) { const urls = getUrls(domain); - const deadline = Date.now() + expiresIn * 1000; - let intervalMs = Math.max(1000, Math.floor(intervalSeconds * 1000)); - let intervalMultiplier = INITIAL_POLL_INTERVAL_MULTIPLIER; - let slowDownResponses = 0; + return pollOAuthDeviceCodeFlow({ + intervalSeconds: device.interval, + expiresInSeconds: device.expires_in, + signal, + poll: async () => { + const raw = await fetchJson(urls.accessTokenUrl, { + method: "POST", + headers: { + Accept: "application/json", + "Content-Type": "application/x-www-form-urlencoded", + "User-Agent": "GitHubCopilotChat/0.35.0", + }, + body: new URLSearchParams({ + client_id: CLIENT_ID, + device_code: device.device_code, + grant_type: "urn:ietf:params:oauth:grant-type:device_code", + }), + }); - while (Date.now() < deadline) { - if (signal?.aborted) { - throw new Error("Login cancelled"); - } - - const remainingMs = deadline - Date.now(); - const waitMs = Math.min(Math.ceil(intervalMs * intervalMultiplier), remainingMs); - await abortableSleep(waitMs, signal); - - const raw = await fetchJson(urls.accessTokenUrl, { - method: "POST", - headers: { - Accept: "application/json", - "Content-Type": "application/x-www-form-urlencoded", - "User-Agent": "GitHubCopilotChat/0.35.0", - }, - body: new URLSearchParams({ - client_id: CLIENT_ID, - device_code: deviceCode, - grant_type: "urn:ietf:params:oauth:grant-type:device_code", - }), - }); - - if (raw && typeof raw === "object" && typeof (raw as DeviceTokenSuccessResponse).access_token === "string") { - return (raw as DeviceTokenSuccessResponse).access_token; - } - - if (raw && typeof raw === "object" && typeof (raw as DeviceTokenErrorResponse).error === "string") { - const { error, error_description: description, interval } = raw as DeviceTokenErrorResponse; - if (error === "authorization_pending") { - continue; + if (raw && typeof raw === "object" && typeof (raw as DeviceTokenSuccessResponse).access_token === "string") { + return { status: "complete", accessToken: (raw as DeviceTokenSuccessResponse).access_token }; } - if (error === "slow_down") { - slowDownResponses += 1; - intervalMs = - typeof interval === "number" && interval > 0 ? interval * 1000 : Math.max(1000, intervalMs + 5000); - intervalMultiplier = SLOW_DOWN_POLL_INTERVAL_MULTIPLIER; - continue; + if (raw && typeof raw === "object" && typeof (raw as DeviceTokenErrorResponse).error === "string") { + const { error, error_description: description } = raw as DeviceTokenErrorResponse; + if (error === "authorization_pending") { + return { status: "pending" }; + } + + if (error === "slow_down") { + return { status: "slow_down" }; + } + + const descriptionSuffix = description ? `: ${description}` : ""; + return { status: "failed", message: `Device flow failed: ${error}${descriptionSuffix}` }; } - const descriptionSuffix = description ? `: ${description}` : ""; - throw new Error(`Device flow failed: ${error}${descriptionSuffix}`); - } - } - - if (slowDownResponses > 0) { - throw new Error( - "Device flow timed out after one or more slow_down responses. This is often caused by clock drift in WSL or VM environments. Please sync or restart the VM clock and try again.", - ); - } - - throw new Error("Device flow timed out"); + return { status: "failed", message: "Invalid device token response" }; + }, + }); } /** @@ -319,13 +269,13 @@ async function enableAllGitHubCopilotModels( /** * Login with GitHub Copilot OAuth (device code flow) * - * @param options.onAuth - Callback with URL and optional instructions (user code) + * @param options.onDeviceCode - Callback with URL and user code * @param options.onPrompt - Callback to prompt user for input * @param options.onProgress - Optional progress callback * @param options.signal - Optional AbortSignal for cancellation */ export async function loginGitHubCopilot(options: { - onAuth: (url: string, instructions?: string) => void; + onDeviceCode: (info: OAuthDeviceCodeInfo) => void; onPrompt: (prompt: { message: string; placeholder?: string; allowEmpty?: boolean }) => Promise; onProgress?: (message: string) => void; signal?: AbortSignal; @@ -348,15 +298,14 @@ export async function loginGitHubCopilot(options: { const domain = enterpriseDomain || "github.com"; const device = await startDeviceFlow(domain); - options.onAuth(device.verification_uri, `Enter code: ${device.user_code}`); + options.onDeviceCode({ + userCode: device.user_code, + verificationUri: device.verification_uri, + intervalSeconds: device.interval, + expiresInSeconds: device.expires_in, + }); - const githubAccessToken = await pollForGitHubAccessToken( - domain, - device.device_code, - device.interval, - device.expires_in, - options.signal, - ); + const githubAccessToken = await pollForGitHubAccessToken(domain, device, options.signal); const credentials = await refreshGitHubCopilotToken(githubAccessToken, enterpriseDomain ?? undefined); // Enable all models after successful login @@ -370,8 +319,12 @@ export const githubCopilotOAuthProvider: OAuthProviderInterface = { name: "GitHub Copilot", async login(callbacks: OAuthLoginCallbacks): Promise { + if (!callbacks.onDeviceCode) { + throw new Error("GitHub Copilot OAuth requires a device code callback"); + } + return loginGitHubCopilot({ - onAuth: (url, instructions) => callbacks.onAuth({ url, instructions }), + onDeviceCode: callbacks.onDeviceCode, onPrompt: callbacks.onPrompt, onProgress: callbacks.onProgress, signal: callbacks.signal, diff --git a/packages/ai/src/utils/oauth/index.ts b/packages/ai/src/utils/oauth/index.ts index 3a3a01b16..55910322b 100644 --- a/packages/ai/src/utils/oauth/index.ts +++ b/packages/ai/src/utils/oauth/index.ts @@ -9,6 +9,7 @@ // Anthropic export { anthropicOAuthProvider, loginAnthropic, refreshAnthropicToken } from "./anthropic.ts"; +export * from "./device-code.ts"; // GitHub Copilot export { getGitHubCopilotBaseUrl, diff --git a/packages/ai/src/utils/oauth/types.ts b/packages/ai/src/utils/oauth/types.ts index a1426d815..3220dcf9d 100644 --- a/packages/ai/src/utils/oauth/types.ts +++ b/packages/ai/src/utils/oauth/types.ts @@ -23,6 +23,13 @@ export type OAuthAuthInfo = { instructions?: string; }; +export type OAuthDeviceCodeInfo = { + userCode: string; + verificationUri: string; + intervalSeconds?: number; + expiresInSeconds?: number; +}; + export type OAuthSelectOption = { id: string; label: string; @@ -35,6 +42,7 @@ export type OAuthSelectPrompt = { export interface OAuthLoginCallbacks { onAuth: (info: OAuthAuthInfo) => void; + onDeviceCode?: (info: OAuthDeviceCodeInfo) => void; onPrompt: (prompt: OAuthPrompt) => Promise; onProgress?: (message: string) => void; onManualCodeInput?: () => Promise; diff --git a/packages/ai/test/github-copilot-oauth.test.ts b/packages/ai/test/github-copilot-oauth.test.ts index 0367892cb..5c13eb1aa 100644 --- a/packages/ai/test/github-copilot-oauth.test.ts +++ b/packages/ai/test/github-copilot-oauth.test.ts @@ -29,7 +29,62 @@ describe("GitHub Copilot OAuth device flow", () => { vi.useRealTimers(); }); - it("waits before the first poll and increases the safety margin after slow_down", async () => { + it("reports device-code details through onDeviceCode", async () => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-03-09T00:00:00Z")); + + const fetchMock = vi.fn(async (input: unknown): Promise => { + const url = getUrl(input); + + if (url.endsWith("/login/device/code")) { + return jsonResponse({ + device_code: "device-code", + user_code: "ABCD-EFGH", + verification_uri: "https://github.com/login/device", + interval: 1, + expires_in: 900, + }); + } + + if (url.endsWith("/login/oauth/access_token")) { + return jsonResponse({ access_token: "ghu_refresh_token" }); + } + + if (url.includes("/copilot_internal/v2/token")) { + return jsonResponse({ + token: "tid=test;exp=9999999999;proxy-ep=proxy.individual.githubcopilot.com;", + expires_at: 9999999999, + }); + } + + if (url.includes("/models/") && url.endsWith("/policy")) { + return new Response("", { status: 200 }); + } + + throw new Error(`Unexpected fetch URL: ${url}`); + }); + + vi.stubGlobal("fetch", fetchMock); + + const onDeviceCode = vi.fn(); + const loginPromise = loginGitHubCopilot({ + onDeviceCode, + onPrompt: async () => "", + }); + + await vi.advanceTimersByTimeAsync(0); + + expect(onDeviceCode).toHaveBeenCalledWith({ + userCode: "ABCD-EFGH", + verificationUri: "https://github.com/login/device", + intervalSeconds: 1, + expiresInSeconds: 900, + }); + await vi.advanceTimersByTimeAsync(1000); + await loginPromise; + }); + + it("waits before the first poll and increases the interval after slow_down", async () => { vi.useFakeTimers(); const startTime = new Date("2026-03-09T00:00:00Z"); vi.setSystemTime(startTime); @@ -37,7 +92,7 @@ describe("GitHub Copilot OAuth device flow", () => { const accessTokenPollTimes: number[] = []; const accessTokenResponses = [ jsonResponse({ error: "authorization_pending", error_description: "pending" }), - jsonResponse({ error: "slow_down", error_description: "slow down", interval: 10 }), + jsonResponse({ error: "slow_down", error_description: "slow down" }), jsonResponse({ access_token: "ghu_refresh_token" }), ]; @@ -95,7 +150,7 @@ describe("GitHub Copilot OAuth device flow", () => { vi.stubGlobal("fetch", fetchMock); const loginPromise = loginGitHubCopilot({ - onAuth: () => {}, + onDeviceCode: () => {}, onPrompt: async () => "", onProgress: () => {}, }); @@ -103,28 +158,28 @@ describe("GitHub Copilot OAuth device flow", () => { await vi.advanceTimersByTimeAsync(0); expect(accessTokenPollTimes).toHaveLength(0); - await vi.advanceTimersByTimeAsync(5999); + await vi.advanceTimersByTimeAsync(4999); expect(accessTokenPollTimes).toHaveLength(0); await vi.advanceTimersByTimeAsync(1); expect(accessTokenPollTimes).toHaveLength(1); - await vi.advanceTimersByTimeAsync(5999); + await vi.advanceTimersByTimeAsync(4999); expect(accessTokenPollTimes).toHaveLength(1); await vi.advanceTimersByTimeAsync(1); expect(accessTokenPollTimes).toHaveLength(2); - await vi.advanceTimersByTimeAsync(13999); + await vi.advanceTimersByTimeAsync(9999); expect(accessTokenPollTimes).toHaveLength(2); await vi.advanceTimersByTimeAsync(1); await loginPromise; expect(accessTokenPollTimes).toEqual([ - startTime.getTime() + 6000, - startTime.getTime() + 12000, - startTime.getTime() + 26000, + startTime.getTime() + 5000, + startTime.getTime() + 10000, + startTime.getTime() + 20000, ]); }); @@ -135,8 +190,8 @@ describe("GitHub Copilot OAuth device flow", () => { const accessTokenPollTimes: number[] = []; const accessTokenResponses = [ - jsonResponse({ error: "slow_down", error_description: "slow down", interval: 10 }), - jsonResponse({ error: "slow_down", error_description: "still too fast", interval: 15 }), + jsonResponse({ error: "slow_down", error_description: "slow down" }), + jsonResponse({ error: "slow_down", error_description: "still too fast" }), jsonResponse({ error: "authorization_pending", error_description: "pending" }), ]; @@ -168,28 +223,28 @@ describe("GitHub Copilot OAuth device flow", () => { vi.stubGlobal("fetch", fetchMock); const loginPromise = loginGitHubCopilot({ - onAuth: () => {}, + onDeviceCode: () => {}, onPrompt: async () => "", }); const rejection = expect(loginPromise).rejects.toThrow( /Device flow timed out after one or more slow_down responses/, ); - await vi.advanceTimersByTimeAsync(6000); - expect(accessTokenPollTimes).toEqual([startTime.getTime() + 6000]); + await vi.advanceTimersByTimeAsync(5000); + expect(accessTokenPollTimes).toEqual([startTime.getTime() + 5000]); - await vi.advanceTimersByTimeAsync(14000); - expect(accessTokenPollTimes).toEqual([startTime.getTime() + 6000, startTime.getTime() + 20000]); + await vi.advanceTimersByTimeAsync(10000); + expect(accessTokenPollTimes).toEqual([startTime.getTime() + 5000, startTime.getTime() + 15000]); - await vi.advanceTimersByTimeAsync(4999); - expect(accessTokenPollTimes).toEqual([startTime.getTime() + 6000, startTime.getTime() + 20000]); + await vi.advanceTimersByTimeAsync(9999); + expect(accessTokenPollTimes).toEqual([startTime.getTime() + 5000, startTime.getTime() + 15000]); await vi.advanceTimersByTimeAsync(1); await rejection; expect(accessTokenPollTimes).toEqual([ - startTime.getTime() + 6000, - startTime.getTime() + 20000, + startTime.getTime() + 5000, + startTime.getTime() + 15000, startTime.getTime() + 25000, ]); }); diff --git a/packages/ai/test/oauth-device-code.test.ts b/packages/ai/test/oauth-device-code.test.ts new file mode 100644 index 000000000..4b937320a --- /dev/null +++ b/packages/ai/test/oauth-device-code.test.ts @@ -0,0 +1,55 @@ +import { afterEach, describe, expect, it, vi } from "vitest"; +import { pollOAuthDeviceCodeFlow } from "../src/utils/oauth/device-code.js"; + +describe("OAuth device-code polling", () => { + afterEach(() => { + vi.useRealTimers(); + }); + + it("waits before the first poll and returns the completed value", async () => { + vi.useFakeTimers(); + vi.setSystemTime(new Date("2026-03-09T00:00:00Z")); + + const pollTimes: number[] = []; + const poll = vi.fn(async () => { + pollTimes.push(Date.now()); + return pollTimes.length === 1 + ? { status: "pending" as const } + : { status: "complete" as const, accessToken: "token" }; + }); + + const resultPromise = pollOAuthDeviceCodeFlow({ + intervalSeconds: 2, + expiresInSeconds: 30, + poll, + }); + + await vi.advanceTimersByTimeAsync(1999); + expect(pollTimes).toEqual([]); + + await vi.advanceTimersByTimeAsync(1); + expect(pollTimes).toEqual([new Date("2026-03-09T00:00:02Z").getTime()]); + + await vi.advanceTimersByTimeAsync(2000); + await expect(resultPromise).resolves.toBe("token"); + expect(pollTimes).toEqual([ + new Date("2026-03-09T00:00:02Z").getTime(), + new Date("2026-03-09T00:00:04Z").getTime(), + ]); + }); + + it("cancels an in-flight wait", async () => { + vi.useFakeTimers(); + const controller = new AbortController(); + + const resultPromise = pollOAuthDeviceCodeFlow({ + intervalSeconds: 5, + expiresInSeconds: 30, + poll: async () => ({ status: "pending" }), + signal: controller.signal, + }); + + controller.abort(); + await expect(resultPromise).rejects.toThrow("Login cancelled"); + }); +}); diff --git a/packages/coding-agent/src/modes/interactive/components/login-dialog.ts b/packages/coding-agent/src/modes/interactive/components/login-dialog.ts index 80560f142..5a2c365af 100644 --- a/packages/coding-agent/src/modes/interactive/components/login-dialog.ts +++ b/packages/coding-agent/src/modes/interactive/components/login-dialog.ts @@ -1,4 +1,4 @@ -import { getOAuthProviders } from "@earendil-works/pi-ai/oauth"; +import { getOAuthProviders, type OAuthDeviceCodeInfo } from "@earendil-works/pi-ai/oauth"; import { Container, type Focusable, getKeybindings, Input, Spacer, Text, type TUI } from "@earendil-works/pi-tui"; import { exec } from "child_process"; import { theme } from "../theme/theme.ts"; @@ -86,7 +86,7 @@ export class LoginDialogComponent extends Container implements Focusable { /** * Called by onAuth callback - show URL and optional instructions */ - showAuth(url: string, instructions?: string): void { + showAuth(url: string, instructions?: string, options: { autoOpenBrowser?: boolean } = {}): void { this.contentContainer.clear(); this.contentContainer.addChild(new Spacer(1)); const linkedUrl = `\x1b]8;;${url}\x07${url}\x1b]8;;\x07`; @@ -101,11 +101,34 @@ export class LoginDialogComponent extends Container implements Focusable { this.contentContainer.addChild(new Text(theme.fg("warning", instructions), 1, 0)); } - // Try to open browser + if (options.autoOpenBrowser ?? true) { + this.openUrl(url); + } + this.tui.requestRender(); + } + + /** + * Called by onDeviceCode callback - show URL and user code. + */ + showDeviceCode(info: OAuthDeviceCodeInfo): void { + this.contentContainer.clear(); + this.contentContainer.addChild(new Spacer(1)); + const linkedUrl = `\x1b]8;;${info.verificationUri}\x07${info.verificationUri}\x1b]8;;\x07`; + this.contentContainer.addChild(new Text(theme.fg("accent", linkedUrl), 1, 0)); + + const clickHint = process.platform === "darwin" ? "Cmd+click to open" : "Ctrl+click to open"; + const hyperlink = `\x1b]8;;${info.verificationUri}\x07${clickHint}\x1b]8;;\x07`; + this.contentContainer.addChild(new Text(theme.fg("dim", hyperlink), 1, 0)); + this.contentContainer.addChild(new Spacer(1)); + this.contentContainer.addChild(new Text(theme.fg("warning", `Enter code: ${info.userCode}`), 1, 0)); + + // Do not open device-code URLs automatically. These flows need to work in headless environments. + this.tui.requestRender(); + } + + private openUrl(url: string): void { const openCmd = process.platform === "darwin" ? "open" : process.platform === "win32" ? "start" : "xdg-open"; exec(`${openCmd} "${url}"`); - - this.tui.requestRender(); } /** diff --git a/packages/coding-agent/src/modes/interactive/interactive-mode.ts b/packages/coding-agent/src/modes/interactive/interactive-mode.ts index c93beb93a..7616d7e0c 100644 --- a/packages/coding-agent/src/modes/interactive/interactive-mode.ts +++ b/packages/coding-agent/src/modes/interactive/interactive-mode.ts @@ -4815,13 +4815,15 @@ export class InteractiveMode { manualCodeReject = undefined; } }); - } else if (providerId === "github-copilot") { - // GitHub Copilot polls after onAuth - dialog.showWaiting("Waiting for browser authentication..."); } // For Anthropic: onPrompt is called immediately after }, + onDeviceCode: (info) => { + dialog.showDeviceCode(info); + dialog.showWaiting("Waiting for authentication..."); + }, + onPrompt: async (prompt: { message: string; placeholder?: string }) => { return dialog.showPrompt(prompt.message, prompt.placeholder); },