feat: add provider-scoped environment overrides (#5807)

This commit is contained in:
Armin Ronacher
2026-06-16 17:19:08 +02:00
committed by GitHub
Unverified
parent 3039f3e17d
commit 7f29e7a369
33 changed files with 511 additions and 215 deletions
+4
View File
@@ -2,6 +2,10 @@
## [Unreleased]
### Added
- Added provider-scoped `StreamOptions.env` overrides for provider configuration, including Cloudflare endpoint placeholders, Azure OpenAI, Google Vertex, Amazon Bedrock, cache retention, and proxy environment lookups ([#5728](https://github.com/earendil-works/pi/issues/5728)).
### Fixed
- Fixed Z.AI GLM-5.2 thinking requests to send `reasoning_effort` with the provider's `high`/`max` effort mapping ([#5770](https://github.com/earendil-works/pi/issues/5770)).
+19
View File
@@ -38,6 +38,7 @@ Unified LLM API with automatic model discovery, provider configuration, token an
- [Browser Usage](#browser-usage)
- [Browser Compatibility Notes](#browser-compatibility-notes)
- [Environment Variables](#environment-variables-nodejs-only)
- [Provider-Scoped Environment Overrides](#provider-scoped-environment-overrides)
- [Checking Environment Variables](#checking-environment-variables)
- [OAuth Providers](#oauth-providers)
- [Vertex AI](#vertex-ai)
@@ -1145,6 +1146,24 @@ const response = await complete(model, context, {
});
```
### Provider-Scoped Environment Overrides
Pass `env` in stream options to scope provider configuration to a single request. Values in `env` take precedence over process environment variables for API key discovery and for provider configuration such as Cloudflare account and gateway IDs, Azure OpenAI settings, Vertex project/location, Bedrock settings, `PI_CACHE_RETENTION`, and `HTTP_PROXY`/`HTTPS_PROXY`.
```typescript
const model = getModel('cloudflare-ai-gateway', 'workers-ai/@cf/moonshotai/kimi-k2.6');
const response = await complete(model, context, {
env: {
CLOUDFLARE_API_KEY: '...',
CLOUDFLARE_ACCOUNT_ID: 'account-id',
CLOUDFLARE_GATEWAY_ID: 'gateway-id'
}
});
```
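Use this when one process needs different provider settings per request, or when a value from the ambient environment should not be used for a particular provider call. Variables that are not set in `env` are still resolved from the process environment.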
### Checking Environment Variables
```typescript
+27 -63
View File
@@ -23,44 +23,17 @@ if (typeof process !== "undefined" && (process.versions?.node || process.version
});
}
import type { KnownProvider } from "./types.ts";
// Lazily populated snapshot of `/proc/self/environ`; `null` means "not read yet".
let _procEnvCache: Map<string, string> | null = null;

/**
 * Fallback for https://github.com/oven-sh/bun/issues/27802
 * Bun compiled binaries have an empty `process.env` inside sandbox
 * environments on Linux. We can recover the env from `/proc/self/environ`.
 *
 * @param key Name of the environment variable to look up.
 * @returns The value recovered from `/proc/self/environ`, or `undefined` when
 * there is no `process` global, the runtime is not Bun, `process.env` already
 * has entries, the file cannot be read, or the variable is not present.
 */
function getProcEnv(key: string): string | undefined {
	// Guard the `process` global itself before touching any of its properties;
	// reading `process.versions` first would throw where `process` does not exist.
	if (typeof process === "undefined") return undefined;
	if (!process.versions?.bun) return undefined;

	// If process.env already has entries, the bug is not triggered.
	if (Object.keys(process.env).length > 0) return undefined;

	if (_procEnvCache === null) {
		// Assign the cache before reading so a failed read is not retried on every call.
		_procEnvCache = new Map();
		try {
			const { readFileSync } = require("node:fs") as typeof import("node:fs");
			const data = readFileSync("/proc/self/environ", "utf-8");
			// Entries are NUL-separated `NAME=value` pairs.
			for (const entry of data.split("\0")) {
				const idx = entry.indexOf("=");
				// Skip entries without "=" or with an empty name.
				if (idx > 0) {
					_procEnvCache.set(entry.slice(0, idx), entry.slice(idx + 1));
				}
			}
		} catch {
			// /proc/self/environ may not be readable.
		}
	}

	return _procEnvCache.get(key);
}
import type { KnownProvider, ProviderEnv } from "./types.ts";
import { getProviderEnvValue } from "./utils/provider-env.ts";
let cachedVertexAdcCredentialsExists: boolean | null = null;
function hasVertexAdcCredentials(): boolean {
function hasVertexAdcCredentials(env?: ProviderEnv): boolean {
const explicitCredentialsPath = env?.GOOGLE_APPLICATION_CREDENTIALS;
if (explicitCredentialsPath) {
return _existsSync ? _existsSync(explicitCredentialsPath) : false;
}
if (cachedVertexAdcCredentialsExists === null) {
// If node modules haven't loaded yet (async import race at startup),
// return false WITHOUT caching so the next call retries once they're ready.
@@ -75,7 +48,7 @@ function hasVertexAdcCredentials(): boolean {
}
// Check GOOGLE_APPLICATION_CREDENTIALS env var first (standard way)
const gacPath = process.env.GOOGLE_APPLICATION_CREDENTIALS || getProcEnv("GOOGLE_APPLICATION_CREDENTIALS");
const gacPath = getProviderEnvValue("GOOGLE_APPLICATION_CREDENTIALS", env);
if (gacPath) {
cachedVertexAdcCredentialsExists = _existsSync(gacPath);
} else {
@@ -143,13 +116,13 @@ function getApiKeyEnvVars(provider: string): readonly string[] | undefined {
* credential sources such as AWS profiles, AWS IAM credentials, and Google
* Application Default Credentials.
*/
export function findEnvKeys(provider: KnownProvider): string[] | undefined;
export function findEnvKeys(provider: string): string[] | undefined;
export function findEnvKeys(provider: string): string[] | undefined {
export function findEnvKeys(provider: KnownProvider, env?: ProviderEnv): string[] | undefined;
export function findEnvKeys(provider: string, env?: ProviderEnv): string[] | undefined;
export function findEnvKeys(provider: string, env?: ProviderEnv): string[] | undefined {
const envVars = getApiKeyEnvVars(provider);
if (!envVars) return undefined;
const found = envVars.filter((envVar) => !!process.env[envVar] || !!getProcEnv(envVar));
const found = envVars.filter((envVar) => !!getProviderEnvValue(envVar, env));
return found.length > 0 ? found : undefined;
}
@@ -158,25 +131,22 @@ export function findEnvKeys(provider: string): string[] | undefined {
*
* Will not return API keys for providers that require OAuth tokens.
*/
export function getEnvApiKey(provider: KnownProvider): string | undefined;
export function getEnvApiKey(provider: string): string | undefined;
export function getEnvApiKey(provider: string): string | undefined {
const envKeys = findEnvKeys(provider);
export function getEnvApiKey(provider: KnownProvider, env?: ProviderEnv): string | undefined;
export function getEnvApiKey(provider: string, env?: ProviderEnv): string | undefined;
export function getEnvApiKey(provider: string, env?: ProviderEnv): string | undefined {
const envKeys = findEnvKeys(provider, env);
if (envKeys?.[0]) {
return process.env[envKeys[0]] || getProcEnv(envKeys[0]);
return getProviderEnvValue(envKeys[0], env);
}
// Vertex AI supports either an explicit API key or Application Default Credentials.
// Auth is configured via `gcloud auth application-default login`.
if (provider === "google-vertex") {
const hasCredentials = hasVertexAdcCredentials();
const hasCredentials = hasVertexAdcCredentials(env);
const hasProject = !!(
process.env.GOOGLE_CLOUD_PROJECT ||
process.env.GCLOUD_PROJECT ||
getProcEnv("GOOGLE_CLOUD_PROJECT") ||
getProcEnv("GCLOUD_PROJECT")
getProviderEnvValue("GOOGLE_CLOUD_PROJECT", env) || getProviderEnvValue("GCLOUD_PROJECT", env)
);
const hasLocation = !!(process.env.GOOGLE_CLOUD_LOCATION || getProcEnv("GOOGLE_CLOUD_LOCATION"));
const hasLocation = !!getProviderEnvValue("GOOGLE_CLOUD_LOCATION", env);
if (hasCredentials && hasProject && hasLocation) {
return "<authenticated>";
@@ -192,18 +162,12 @@ export function getEnvApiKey(provider: string): string | undefined {
// 5. AWS_CONTAINER_CREDENTIALS_FULL_URI - ECS task roles (full URI)
// 6. AWS_WEB_IDENTITY_TOKEN_FILE - IRSA (IAM Roles for Service Accounts)
if (
process.env.AWS_PROFILE ||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
process.env.AWS_BEARER_TOKEN_BEDROCK ||
process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI ||
process.env.AWS_CONTAINER_CREDENTIALS_FULL_URI ||
process.env.AWS_WEB_IDENTITY_TOKEN_FILE ||
getProcEnv("AWS_PROFILE") ||
(getProcEnv("AWS_ACCESS_KEY_ID") && getProcEnv("AWS_SECRET_ACCESS_KEY")) ||
getProcEnv("AWS_BEARER_TOKEN_BEDROCK") ||
getProcEnv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") ||
getProcEnv("AWS_CONTAINER_CREDENTIALS_FULL_URI") ||
getProcEnv("AWS_WEB_IDENTITY_TOKEN_FILE")
getProviderEnvValue("AWS_PROFILE", env) ||
(getProviderEnvValue("AWS_ACCESS_KEY_ID", env) && getProviderEnvValue("AWS_SECRET_ACCESS_KEY", env)) ||
getProviderEnvValue("AWS_BEARER_TOKEN_BEDROCK", env) ||
getProviderEnvValue("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", env) ||
getProviderEnvValue("AWS_CONTAINER_CREDENTIALS_FULL_URI", env) ||
getProviderEnvValue("AWS_WEB_IDENTITY_TOKEN_FILE", env)
) {
return "<authenticated>";
}
+58 -34
View File
@@ -1,3 +1,4 @@
import type { Agent as HttpsAgent } from "node:https";
import {
BedrockRuntimeClient,
type BedrockRuntimeClientConfig,
@@ -23,6 +24,8 @@ import {
} from "@aws-sdk/client-bedrock-runtime";
import { NodeHttpHandler } from "@smithy/node-http-handler";
import type { BuildMiddleware, DocumentType, MetadataBearer } from "@smithy/types";
import { HttpProxyAgent } from "http-proxy-agent";
import { HttpsProxyAgent } from "https-proxy-agent";
import { calculateCost } from "../models.ts";
import type {
Api,
@@ -31,6 +34,7 @@ import type {
Context,
ImageContent,
Model,
ProviderEnv,
SimpleStreamOptions,
StopReason,
StreamFunction,
@@ -45,7 +49,8 @@ import type {
} from "../types.ts";
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
import { parseStreamingJson } from "../utils/json-parse.ts";
import { createHttpProxyAgentsForTarget } from "../utils/node-http-proxy.ts";
import { resolveHttpProxyUrlForTarget } from "../utils/node-http-proxy.ts";
import { getProviderEnvValue } from "../utils/provider-env.ts";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.ts";
import { adjustMaxTokensForThinking, buildBaseOptions, clampReasoning } from "./simple-options.ts";
import { transformMessages } from "./transform-messages.ts";
@@ -119,18 +124,18 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOpt
const blocks = output.content as Block[];
const config: BedrockRuntimeClientConfig = {
profile: options.profile,
profile: options.profile || getProviderEnvValue("AWS_PROFILE", options.env),
};
const configuredRegion = getConfiguredBedrockRegion(options);
const hasConfiguredProfile = hasConfiguredBedrockProfile();
const hasAmbientConfiguredProfile = Boolean(getProviderEnvValue("AWS_PROFILE"));
const endpointRegion = getStandardBedrockEndpointRegion(model.baseUrl);
const useExplicitEndpoint = shouldUseExplicitBedrockEndpoint(
model.baseUrl,
configuredRegion,
hasConfiguredProfile,
hasAmbientConfiguredProfile,
);
// Only pin standard AWS Bedrock runtime endpoints when no region/profile is configured.
// Only pin standard AWS Bedrock runtime endpoints when no region or ambient AWS_PROFILE is configured.
// This preserves custom endpoints (VPC/proxy) from #3402 without forcing built-in
// catalog defaults such as us-east-1 to override AWS_REGION/AWS_PROFILE.
if (useExplicitEndpoint) {
@@ -138,8 +143,10 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOpt
}
// Resolve bearer token for Bedrock API key auth.
const bearerToken = options.bearerToken || process.env.AWS_BEARER_TOKEN_BEDROCK || undefined;
const useBearerToken = bearerToken !== undefined && process.env.AWS_BEDROCK_SKIP_AUTH !== "1";
const skipAuth = getProviderEnvValue("AWS_BEDROCK_SKIP_AUTH", options.env) === "1";
const bearerToken =
options.bearerToken || getProviderEnvValue("AWS_BEARER_TOKEN_BEDROCK", options.env) || undefined;
const useBearerToken = bearerToken !== undefined && !skipAuth;
// in Node.js/Bun environment only
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
@@ -153,25 +160,33 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOpt
config.region = configuredRegion;
} else if (endpointRegion && useExplicitEndpoint) {
config.region = endpointRegion;
} else if (!hasConfiguredProfile) {
} else if (!hasAmbientConfiguredProfile) {
config.region = "us-east-1";
}
// Support proxies that don't need authentication
if (process.env.AWS_BEDROCK_SKIP_AUTH === "1") {
if (skipAuth) {
config.credentials = {
accessKeyId: "dummy-access-key",
secretAccessKey: "dummy-secret-key",
};
}
const proxyAgents = createHttpProxyAgentsForTarget(model.baseUrl);
if (proxyAgents) {
const credentials = getConfiguredBedrockCredentials(options.env);
if (!skipAuth && credentials) {
config.credentials = credentials;
}
const proxyUrl = resolveHttpProxyUrlForTarget(model.baseUrl, options.env);
if (proxyUrl) {
// Bedrock runtime uses NodeHttp2Handler by default since v3.798.0, which is based
// on `http2` module and has no support for http agent.
// Use NodeHttpHandler to support HTTP(S) proxy agents.
config.requestHandler = new NodeHttpHandler(proxyAgents);
} else if (process.env.AWS_BEDROCK_FORCE_HTTP1 === "1") {
config.requestHandler = new NodeHttpHandler({
httpAgent: new HttpProxyAgent(proxyUrl),
httpsAgent: new HttpsProxyAgent(proxyUrl) as unknown as HttpsAgent,
});
} else if (getProviderEnvValue("AWS_BEDROCK_FORCE_HTTP1", options.env) === "1") {
// Some custom endpoints require HTTP/1.1 instead of HTTP/2
config.requestHandler = new NodeHttpHandler();
}
@@ -192,12 +207,12 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOpt
if (options.headers && Object.keys(options.headers).length > 0) {
addCustomHeadersMiddleware(client, options.headers);
}
const cacheRetention = resolveCacheRetention(options.cacheRetention);
const cacheRetention = resolveCacheRetention(options.cacheRetention, options.env);
const inferenceMaxTokens = options.maxTokens ?? (isAnthropicClaudeModel(model) ? model.maxTokens : undefined);
let commandInput = {
modelId: model.id,
messages: convertMessages(context, model, cacheRetention),
system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
messages: convertMessages(context, model, cacheRetention, options.env),
system: buildSystemPrompt(context.systemPrompt, model, cacheRetention, options.env),
inferenceConfig: {
...(inferenceMaxTokens !== undefined && { maxTokens: inferenceMaxTokens }),
...(options.temperature !== undefined && { temperature: options.temperature }),
@@ -578,11 +593,11 @@ function mapThinkingLevelToEffort(
* Resolve cache retention preference.
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
*/
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
function resolveCacheRetention(cacheRetention?: CacheRetention, env?: ProviderEnv): CacheRetention {
if (cacheRetention) {
return cacheRetention;
}
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
if (getProviderEnvValue("PI_CACHE_RETENTION", env) === "long") {
return "long";
}
return "short";
@@ -617,14 +632,14 @@ function isAnthropicClaudeModel(model: Model<"bedrock-converse-stream">): boolea
* As a last resort, set AWS_BEDROCK_FORCE_CACHE=1 to enable cache points.
* Amazon Nova models have automatic caching and don't need explicit cache points.
*/
function supportsPromptCaching(model: Model<"bedrock-converse-stream">): boolean {
function supportsPromptCaching(model: Model<"bedrock-converse-stream">, env?: ProviderEnv): boolean {
const candidates = getModelMatchCandidates(model.id, model.name);
const hasClaudeRef = candidates.some((s) => s.includes("claude"));
if (!hasClaudeRef) {
// Application inference profiles don't contain the model name in the ARN.
// Allow users to force cache points via environment variable.
if (typeof process !== "undefined" && process.env.AWS_BEDROCK_FORCE_CACHE === "1") return true;
if (getProviderEnvValue("AWS_BEDROCK_FORCE_CACHE", env) === "1") return true;
return false;
}
// Claude 4.x models (opus-4, sonnet-4, haiku-4)
@@ -652,13 +667,14 @@ function buildSystemPrompt(
systemPrompt: string | undefined,
model: Model<"bedrock-converse-stream">,
cacheRetention: CacheRetention,
env?: ProviderEnv,
): SystemContentBlock[] | undefined {
if (!systemPrompt) return undefined;
const blocks: SystemContentBlock[] = [{ text: sanitizeSurrogates(systemPrompt) }];
// Add cache point for supported Claude models when caching is enabled
if (cacheRetention !== "none" && supportsPromptCaching(model)) {
if (cacheRetention !== "none" && supportsPromptCaching(model, env)) {
blocks.push({
cachePoint: { type: CachePointType.DEFAULT, ...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}) },
});
@@ -699,6 +715,7 @@ function convertMessages(
context: Context,
model: Model<"bedrock-converse-stream">,
cacheRetention: CacheRetention,
env?: ProviderEnv,
): Message[] {
const result: Message[] = [];
const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
@@ -844,7 +861,7 @@ function convertMessages(
}
// Add cache point to the last user message for supported Claude models when caching is enabled
if (cacheRetention !== "none" && supportsPromptCaching(model) && result.length > 0) {
if (cacheRetention !== "none" && supportsPromptCaching(model, env) && result.length > 0) {
const lastMessage = result[result.length - 1];
if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
(lastMessage.content as ContentBlock[]).push({
@@ -906,19 +923,26 @@ function mapStopReason(reason: string | undefined): StopReason {
}
function getConfiguredBedrockRegion(options: BedrockOptions): string | undefined {
if (typeof process === "undefined") {
return options.region;
}
return options.region || process.env.AWS_REGION || process.env.AWS_DEFAULT_REGION || undefined;
return (
options.region ||
getProviderEnvValue("AWS_REGION", options.env) ||
getProviderEnvValue("AWS_DEFAULT_REGION", options.env) ||
undefined
);
}
function hasConfiguredBedrockProfile(): boolean {
if (typeof process === "undefined") {
return false;
function getConfiguredBedrockCredentials(env?: ProviderEnv): BedrockRuntimeClientConfig["credentials"] | undefined {
const accessKeyId = getProviderEnvValue("AWS_ACCESS_KEY_ID", env);
const secretAccessKey = getProviderEnvValue("AWS_SECRET_ACCESS_KEY", env);
if (!accessKeyId || !secretAccessKey) {
return undefined;
}
return Boolean(process.env.AWS_PROFILE);
const sessionToken = getProviderEnvValue("AWS_SESSION_TOKEN", env);
return {
accessKeyId,
secretAccessKey,
...(sessionToken ? { sessionToken } : {}),
};
}
function getStandardBedrockEndpointRegion(baseUrl: string | undefined): string | undefined {
@@ -938,14 +962,14 @@ function getStandardBedrockEndpointRegion(baseUrl: string | undefined): string |
function shouldUseExplicitBedrockEndpoint(
baseUrl: string,
configuredRegion: string | undefined,
hasConfiguredProfile: boolean,
hasAmbientConfiguredProfile: boolean,
): boolean {
const endpointRegion = getStandardBedrockEndpointRegion(baseUrl);
if (!endpointRegion) {
return true;
}
return !configuredRegion && !hasConfiguredProfile;
return !configuredRegion && !hasAmbientConfiguredProfile;
}
function isGovCloudBedrockTarget(model: Model<"bedrock-converse-stream">, options: BedrockOptions): boolean {
+11 -6
View File
@@ -17,6 +17,7 @@ import type {
ImageContent,
Message,
Model,
ProviderEnv,
SimpleStreamOptions,
StopReason,
StreamFunction,
@@ -30,6 +31,7 @@ import type {
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
import { headersToRecord } from "../utils/headers.ts";
import { parseJsonWithRepair, parseStreamingJson } from "../utils/json-parse.ts";
import { getProviderEnvValue } from "../utils/provider-env.ts";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.ts";
import { resolveCloudflareBaseUrl } from "./cloudflare.ts";
@@ -41,11 +43,11 @@ import { transformMessages } from "./transform-messages.ts";
* Resolve cache retention preference.
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
*/
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
function resolveCacheRetention(cacheRetention?: CacheRetention, env?: ProviderEnv): CacheRetention {
if (cacheRetention) {
return cacheRetention;
}
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
if (getProviderEnvValue("PI_CACHE_RETENTION", env) === "long") {
return "long";
}
return "short";
@@ -54,8 +56,9 @@ function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention
function getCacheControl(
model: Model<"anthropic-messages">,
cacheRetention?: CacheRetention,
env?: ProviderEnv,
): { retention: CacheRetention; cacheControl?: CacheControlEphemeral } {
const retention = resolveCacheRetention(cacheRetention);
const retention = resolveCacheRetention(cacheRetention, env);
if (retention === "none") {
return { retention };
}
@@ -494,7 +497,7 @@ export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOpti
});
}
const cacheRetention = options?.cacheRetention ?? resolveCacheRetention();
const cacheRetention = resolveCacheRetention(options?.cacheRetention, options?.env);
const cacheSessionId = cacheRetention === "none" ? undefined : options?.sessionId;
const created = createClient(
@@ -505,6 +508,7 @@ export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOpti
options?.headers,
copilotDynamicHeaders,
cacheSessionId,
options?.env,
);
client = created.client;
isOAuth = created.isOAuthToken;
@@ -794,6 +798,7 @@ function createClient(
optionsHeaders?: Record<string, string>,
dynamicHeaders?: Record<string, string>,
sessionId?: string,
env?: ProviderEnv,
): { client: Anthropic; isOAuthToken: boolean } {
// Adaptive thinking models have interleaved thinking built in, so skip the beta header.
const needsInterleavedBeta = interleavedThinking && model.compat?.forceAdaptiveThinking !== true;
@@ -809,7 +814,7 @@ function createClient(
const client = new Anthropic({
apiKey: null,
authToken: null,
baseURL: resolveCloudflareBaseUrl(model),
baseURL: resolveCloudflareBaseUrl(model, env),
dangerouslyAllowBrowser: true,
defaultHeaders: mergeHeaders(
{
@@ -902,7 +907,7 @@ function buildParams(
isOAuthToken: boolean,
options?: AnthropicOptions,
): MessageCreateParamsStreaming {
const { cacheControl } = getCacheControl(model, options?.cacheRetention);
const { cacheControl } = getCacheControl(model, options?.cacheRetention, options?.env);
const compat = getAnthropicCompat(model);
const params: MessageCreateParamsStreaming = {
model: model.id,
@@ -12,6 +12,7 @@ import type {
} from "../types.ts";
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
import { headersToRecord } from "../utils/headers.ts";
import { getProviderEnvValue } from "../utils/provider-env.ts";
import { clampOpenAIPromptCacheKey } from "./openai-prompt-cache.ts";
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.ts";
import { buildBaseOptions } from "./simple-options.ts";
@@ -36,7 +37,9 @@ function resolveDeploymentName(model: Model<"azure-openai-responses">, options?:
if (options?.azureDeploymentName) {
return options.azureDeploymentName;
}
const mappedDeployment = parseDeploymentNameMap(process.env.AZURE_OPENAI_DEPLOYMENT_NAME_MAP).get(model.id);
const mappedDeployment = parseDeploymentNameMap(
getProviderEnvValue("AZURE_OPENAI_DEPLOYMENT_NAME_MAP", options?.env),
).get(model.id);
return mappedDeployment || model.id;
}
@@ -198,10 +201,14 @@ function resolveAzureConfig(
model: Model<"azure-openai-responses">,
options?: AzureOpenAIResponsesOptions,
): { baseUrl: string; apiVersion: string } {
const apiVersion = options?.azureApiVersion || process.env.AZURE_OPENAI_API_VERSION || DEFAULT_AZURE_API_VERSION;
const apiVersion =
options?.azureApiVersion ||
getProviderEnvValue("AZURE_OPENAI_API_VERSION", options?.env) ||
DEFAULT_AZURE_API_VERSION;
const baseUrl = options?.azureBaseUrl?.trim() || process.env.AZURE_OPENAI_BASE_URL?.trim() || undefined;
const resourceName = options?.azureResourceName || process.env.AZURE_OPENAI_RESOURCE_NAME;
const baseUrl =
options?.azureBaseUrl?.trim() || getProviderEnvValue("AZURE_OPENAI_BASE_URL", options?.env)?.trim() || undefined;
const resourceName = options?.azureResourceName || getProviderEnvValue("AZURE_OPENAI_RESOURCE_NAME", options?.env);
let resolvedBaseUrl = baseUrl;
+5 -4
View File
@@ -1,4 +1,5 @@
import type { Api, Model } from "../types.ts";
import type { Api, Model, ProviderEnv } from "../types.ts";
import { getProviderEnvValue } from "../utils/provider-env.ts";
/** Workers AI direct endpoint. */
export const CLOUDFLARE_WORKERS_AI_BASE_URL =
@@ -20,12 +21,12 @@ export function isCloudflareProvider(provider: string): boolean {
return provider === "cloudflare-workers-ai" || provider === "cloudflare-ai-gateway";
}
/** Substitute `{VAR}` placeholders in a Cloudflare baseUrl from process.env. */
export function resolveCloudflareBaseUrl(model: Model<Api>): string {
/** Substitute `{VAR}` placeholders in a Cloudflare baseUrl from provider env or process.env. */
export function resolveCloudflareBaseUrl(model: Model<Api>, env?: ProviderEnv): string {
const url = model.baseUrl;
if (!url.includes("{")) return url;
const baseUrl = url.replace(/\{([A-Z_][A-Z0-9_]*)\}/g, (_match, name: string) => {
const value = process.env[name];
const value = getProviderEnvValue(name, env);
if (!value) {
throw new Error(`${name} is required for provider ${model.provider} but is not set.`);
}
+16 -3
View File
@@ -14,6 +14,7 @@ import type {
Context,
Model,
ThinkingLevel as PiThinkingLevel,
ProviderEnv,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
@@ -23,6 +24,7 @@ import type {
ToolCall,
} from "../types.ts";
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
import { getProviderEnvValue } from "../utils/provider-env.ts";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.ts";
import type { GoogleThinkingLevel } from "./google-shared.ts";
import {
@@ -91,7 +93,7 @@ export const streamGoogleVertex: StreamFunction<"google-vertex", GoogleVertexOpt
// Create the client using either a Vertex API key, if provided, or ADC with project and location
const client = apiKey
? createClientWithApiKey(model, apiKey, options?.headers)
: createClient(model, resolveProject(options), resolveLocation(options), options?.headers);
: createClient(model, resolveProject(options), resolveLocation(options), options?.headers, options?.env);
let params = buildParams(model, context, options);
const nextParams = await options?.onPayload?.(params, model);
if (nextParams !== undefined) {
@@ -333,12 +335,15 @@ function createClient(
project: string,
location: string,
optionsHeaders?: Record<string, string>,
env?: ProviderEnv,
): GoogleGenAI {
const googleAuthOptions = buildGoogleAuthOptions(env);
return new GoogleGenAI({
vertexai: true,
project,
location,
apiVersion: API_VERSION,
...(googleAuthOptions ? { googleAuthOptions } : {}),
httpOptions: buildHttpOptions(model, optionsHeaders),
});
}
@@ -394,6 +399,11 @@ function baseUrlIncludesApiVersion(baseUrl: string): boolean {
}
}
/**
 * Build the auth options object for the ADC-based Vertex client.
 *
 * Looks up `GOOGLE_APPLICATION_CREDENTIALS` through `getProviderEnvValue`,
 * passing along the request-scoped `env`.
 *
 * @param env Optional provider-scoped environment overrides.
 * @returns `{ keyFilename }` when a credentials path is configured, otherwise `undefined`.
 */
function buildGoogleAuthOptions(env?: ProviderEnv): { keyFilename: string } | undefined {
	const credentialsPath = getProviderEnvValue("GOOGLE_APPLICATION_CREDENTIALS", env);
	if (!credentialsPath) {
		return undefined;
	}
	return { keyFilename: credentialsPath };
}
function resolveApiKey(options?: GoogleVertexOptions): string | undefined {
const apiKey = options?.apiKey?.trim();
if (!apiKey || apiKey === GCP_VERTEX_CREDENTIALS_MARKER || isPlaceholderApiKey(apiKey)) {
@@ -407,7 +417,10 @@ function isPlaceholderApiKey(apiKey: string): boolean {
}
function resolveProject(options?: GoogleVertexOptions): string {
const project = options?.project || process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT;
const project =
options?.project ||
getProviderEnvValue("GOOGLE_CLOUD_PROJECT", options?.env) ||
getProviderEnvValue("GCLOUD_PROJECT", options?.env);
if (!project) {
throw new Error(
"Vertex AI requires a project ID. Set GOOGLE_CLOUD_PROJECT/GCLOUD_PROJECT or pass project in options.",
@@ -417,7 +430,7 @@ function resolveProject(options?: GoogleVertexOptions): string {
}
function resolveLocation(options?: GoogleVertexOptions): string {
const location = options?.location || process.env.GOOGLE_CLOUD_LOCATION;
const location = options?.location || getProviderEnvValue("GOOGLE_CLOUD_LOCATION", options?.env);
if (!location) {
throw new Error("Vertex AI requires a location. Set GOOGLE_CLOUD_LOCATION or pass location in options.");
}
@@ -27,6 +27,7 @@ import type {
AssistantMessage,
Context,
Model,
ProviderEnv,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
@@ -40,6 +41,7 @@ import {
} from "../utils/diagnostics.ts";
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
import { headersToRecord } from "../utils/headers.ts";
import { resolveHttpProxyUrlForTarget } from "../utils/node-http-proxy.ts";
import { clampOpenAIPromptCacheKey } from "./openai-prompt-cache.ts";
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.ts";
import { buildBaseOptions } from "./simple-options.ts";
@@ -814,19 +816,13 @@ type WebSocketConstructor = new (
) => WebSocketLike;
let _cachedWebsocket: WebSocketConstructor | null = null;
async function getWebSocketConstructor(): Promise<WebSocketConstructor | null> {
if (_cachedWebsocket) return _cachedWebsocket;
async function getWebSocketConstructor(env?: ProviderEnv): Promise<WebSocketConstructor | null> {
if (!env && _cachedWebsocket) return _cachedWebsocket;
// bun doesn't respect http proxy envs, ref: https://github.com/oven-sh/bun/issues/15489
// TODO: remove this when bun supports proxy envs in websocket.
if (
process?.versions?.bun &&
(process.env.HTTP_PROXY || process.env.HTTPS_PROXY || process.env.http_proxy || process.env.https_proxy)
) {
const m = await dynamicImport("proxy-from-env");
const getProxyForUrl = (m as { getProxyForUrl: (url: string | object | URL) => string }).getProxyForUrl;
_cachedWebsocket = class extends WebSocket {
if (typeof process !== "undefined" && process.versions?.bun) {
const WebSocketWithProxy = class extends WebSocket {
constructor(url: string | URL, options?: string | string[] | Record<string, unknown>) {
let _opts: Record<string, unknown> = {};
if (Array.isArray(options) || typeof options === "string") {
@@ -835,11 +831,17 @@ async function getWebSocketConstructor(): Promise<WebSocketConstructor | null> {
_opts = { ...options };
}
const proxy = getProxyForUrl(url.toString().replace(/^wss:/, "https:").replace(/^ws:/, "http:"));
super(url, { ..._opts, ...(proxy ? { proxy } : {}) } as any);
const proxyUrl = resolveHttpProxyUrlForTarget(
url.toString().replace(/^wss:/, "https:").replace(/^ws:/, "http:"),
env,
);
super(url, { ..._opts, ...(proxyUrl ? { proxy: proxyUrl.toString() } : {}) } as any);
}
};
return _cachedWebsocket;
if (!env) {
_cachedWebsocket = WebSocketWithProxy;
}
return WebSocketWithProxy;
}
const ctor = (globalThis as { WebSocket?: unknown }).WebSocket;
@@ -894,8 +896,9 @@ async function connectWebSocket(
headers: Headers,
signal?: AbortSignal,
connectTimeoutMs = DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS,
env?: ProviderEnv,
): Promise<WebSocketLike> {
const WebSocketCtor = await getWebSocketConstructor();
const WebSocketCtor = await getWebSocketConstructor(env);
if (!WebSocketCtor) {
throw new Error("WebSocket transport is not available in this runtime");
}
@@ -972,6 +975,7 @@ async function acquireWebSocket(
sessionId: string | undefined,
signal?: AbortSignal,
connectTimeoutMs?: number,
env?: ProviderEnv,
): Promise<{
socket: WebSocketLike;
entry?: CachedWebSocketConnection;
@@ -979,7 +983,7 @@ async function acquireWebSocket(
release: (options?: { keep?: boolean }) => void;
}> {
if (!sessionId) {
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs);
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs, env);
return {
socket,
reused: false,
@@ -1011,7 +1015,7 @@ async function acquireWebSocket(
};
}
if (cached.busy) {
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs);
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs, env);
return {
socket,
reused: false,
@@ -1026,7 +1030,7 @@ async function acquireWebSocket(
}
}
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs);
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs, env);
const entry: CachedWebSocketConnection = { socket, busy: true };
websocketSessionCache.set(sessionId, entry);
return {
@@ -1312,6 +1316,7 @@ async function processWebSocketStream(
options?.sessionId,
options?.signal,
websocketConnectTimeoutMs,
options?.env,
);
let keepConnection = true;
const useCachedContext = options?.transport === "websocket-cached" || options?.transport === "auto";
@@ -19,6 +19,7 @@ import type {
Message,
Model,
OpenAICompletionsCompat,
ProviderEnv,
SimpleStreamOptions,
StopReason,
StreamFunction,
@@ -32,6 +33,7 @@ import type {
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
import { headersToRecord } from "../utils/headers.ts";
import { parseStreamingJson } from "../utils/json-parse.ts";
import { getProviderEnvValue } from "../utils/provider-env.ts";
import { sanitizeSurrogates } from "../utils/sanitize-unicode.ts";
import { isCloudflareProvider, resolveCloudflareBaseUrl } from "./cloudflare.ts";
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.ts";
@@ -98,11 +100,11 @@ type ChatCompletionToolWithCacheControl = OpenAI.Chat.Completions.ChatCompletion
cache_control?: OpenAICompatCacheControl;
};
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
function resolveCacheRetention(cacheRetention?: CacheRetention, env?: ProviderEnv): CacheRetention {
if (cacheRetention) {
return cacheRetention;
}
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
if (getProviderEnvValue("PI_CACHE_RETENTION", env) === "long") {
return "long";
}
return "short";
@@ -140,9 +142,9 @@ export const streamOpenAICompletions: StreamFunction<"openai-completions", OpenA
throw new Error(`No API key for provider: ${model.provider}`);
}
const compat = getCompat(model);
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
const cacheRetention = resolveCacheRetention(options?.cacheRetention, options?.env);
const cacheSessionId = cacheRetention === "none" ? undefined : options?.sessionId;
const client = createClient(model, context, apiKey, options?.headers, cacheSessionId, compat);
const client = createClient(model, context, apiKey, options?.headers, cacheSessionId, compat, options?.env);
let params = buildParams(model, context, options, compat, cacheRetention);
const nextParams = await options?.onPayload?.(params, model);
if (nextParams !== undefined) {
@@ -454,6 +456,7 @@ function createClient(
optionsHeaders?: Record<string, string>,
sessionId?: string,
compat: ResolvedOpenAICompletionsCompat = getCompat(model),
env?: ProviderEnv,
) {
const headers = { ...model.headers };
if (model.provider === "github-copilot") {
@@ -487,7 +490,7 @@ function createClient(
return new OpenAI({
apiKey,
baseURL: isCloudflareProvider(model.provider) ? resolveCloudflareBaseUrl(model) : model.baseUrl,
baseURL: isCloudflareProvider(model.provider) ? resolveCloudflareBaseUrl(model, env) : model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders,
});
@@ -498,7 +501,7 @@ function buildParams(
context: Context,
options?: OpenAICompletionsOptions,
compat: ResolvedOpenAICompletionsCompat = getCompat(model),
cacheRetention: CacheRetention = resolveCacheRetention(options?.cacheRetention),
cacheRetention: CacheRetention = resolveCacheRetention(options?.cacheRetention, options?.env),
) {
const messages = convertMessages(model, context, compat);
const cacheControl = getCompatCacheControl(compat, cacheRetention);
@@ -8,6 +8,7 @@ import type {
Context,
Model,
OpenAIResponsesCompat,
ProviderEnv,
SimpleStreamOptions,
StreamFunction,
StreamOptions,
@@ -15,6 +16,7 @@ import type {
} from "../types.ts";
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
import { headersToRecord } from "../utils/headers.ts";
import { getProviderEnvValue } from "../utils/provider-env.ts";
import { isCloudflareProvider, resolveCloudflareBaseUrl } from "./cloudflare.ts";
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.ts";
import { clampOpenAIPromptCacheKey } from "./openai-prompt-cache.ts";
@@ -27,11 +29,11 @@ const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"
* Resolve cache retention preference.
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
*/
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
function resolveCacheRetention(cacheRetention?: CacheRetention, env?: ProviderEnv): CacheRetention {
if (cacheRetention) {
return cacheRetention;
}
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
if (getProviderEnvValue("PI_CACHE_RETENTION", env) === "long") {
return "long";
}
return "short";
@@ -111,9 +113,9 @@ export const streamOpenAIResponses: StreamFunction<"openai-responses", OpenAIRes
if (!apiKey) {
throw new Error(`No API key for provider: ${model.provider}`);
}
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
const cacheRetention = resolveCacheRetention(options?.cacheRetention, options?.env);
const cacheSessionId = cacheRetention === "none" ? undefined : options?.sessionId;
const client = createClient(model, context, apiKey, options?.headers, cacheSessionId);
const client = createClient(model, context, apiKey, options?.headers, cacheSessionId, options?.env);
let params = buildParams(model, context, options);
const nextParams = await options?.onPayload?.(params, model);
if (nextParams !== undefined) {
@@ -185,6 +187,7 @@ function createClient(
apiKey: string,
optionsHeaders?: Record<string, string>,
sessionId?: string,
env?: ProviderEnv,
) {
const compat = getCompat(model);
const headers = { ...model.headers };
@@ -220,7 +223,7 @@ function createClient(
return new OpenAI({
apiKey,
baseURL: isCloudflareProvider(model.provider) ? resolveCloudflareBaseUrl(model) : model.baseUrl,
baseURL: isCloudflareProvider(model.provider) ? resolveCloudflareBaseUrl(model, env) : model.baseUrl,
dangerouslyAllowBrowser: true,
defaultHeaders,
});
@@ -229,7 +232,7 @@ function createClient(
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS);
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
const cacheRetention = resolveCacheRetention(options?.cacheRetention, options?.env);
const compat = getCompat(model);
const params: ResponseCreateParamsStreaming = {
model: model.id,
@@ -17,6 +17,7 @@ export function buildBaseOptions(_model: Model<Api>, options?: SimpleStreamOptio
maxRetries: options?.maxRetries,
maxRetryDelayMs: options?.maxRetryDelayMs,
metadata: options?.metadata,
env: options?.env,
};
}
+1 -1
View File
@@ -24,7 +24,7 @@ function withEnvApiKey<TOptions extends StreamOptions>(
options: TOptions | undefined,
): TOptions | undefined {
if (hasExplicitApiKey(options?.apiKey)) return options;
const apiKey = getEnvApiKey(model.provider);
const apiKey = getEnvApiKey(model.provider, options?.env);
if (!apiKey) return options;
return { ...options, apiKey } as TOptions;
}
+9
View File
@@ -79,6 +79,9 @@ export type CacheRetention = "none" | "short" | "long";
export type Transport = "sse" | "websocket" | "websocket-cached" | "auto";
/** Provider-scoped environment overrides. Values take precedence over process.env. */
export type ProviderEnv = Record<string, string>;
export interface ProviderResponse {
status: number;
headers: Record<string, string>;
@@ -153,6 +156,12 @@ export interface StreamOptions {
* For example, Anthropic uses `user_id` for abuse tracking and rate limiting.
*/
metadata?: Record<string, unknown>;
/**
* Provider-scoped environment values. These take precedence over process.env for
* provider configuration such as regional settings, endpoint placeholders, and
* proxy variables.
*/
env?: ProviderEnv;
}
export type ProviderStreamOptions = StreamOptions & Record<string, unknown>;
+22 -33
View File
@@ -1,7 +1,5 @@
import type { Agent as HttpAgent } from "node:http";
import type { Agent as HttpsAgent } from "node:https";
import { HttpProxyAgent } from "http-proxy-agent";
import { HttpsProxyAgent } from "https-proxy-agent";
import type { ProviderEnv } from "../types.ts";
import { getProviderEnvValue } from "./provider-env.ts";
const DEFAULT_PROXY_PORTS: Record<string, number> = {
ftp: 21,
@@ -12,16 +10,16 @@ const DEFAULT_PROXY_PORTS: Record<string, number> = {
wss: 443,
};
export interface NodeHttpProxyAgents {
httpAgent: HttpAgent;
httpsAgent: HttpsAgent;
}
export const UNSUPPORTED_PROXY_PROTOCOL_MESSAGE =
"Unsupported proxy protocol. SOCKS and PAC proxy URLs are not supported; use an HTTP or HTTPS proxy URL.";
function getProxyEnv(key: string): string {
return process.env[key.toLowerCase()] || process.env[key.toUpperCase()] || "";
/**
 * Look up a proxy-related variable under both its lowercase and uppercase spelling.
 *
 * Request-scoped values in `env` are consulted first (lowercase, then uppercase);
 * only when neither is set does the lookup fall back to getProviderEnvValue()
 * for the same two spellings. Returns an empty string when nothing is set.
 */
function getProxyEnv(key: string, env?: ProviderEnv): string {
	const aliases = [key.toLowerCase(), key.toUpperCase()];
	// Scoped overrides win over every ambient spelling, so check them as a group first.
	for (const alias of aliases) {
		const scoped = env?.[alias];
		if (scoped) {
			return scoped;
		}
	}
	for (const alias of aliases) {
		const ambient = getProviderEnvValue(alias);
		if (ambient) {
			return ambient;
		}
	}
	return "";
}
function parseProxyTargetUrl(targetUrl: string | URL): URL | undefined {
@@ -36,8 +34,8 @@ function parseProxyTargetUrl(targetUrl: string | URL): URL | undefined {
}
}
function shouldProxyHostname(hostname: string, port: number): boolean {
const noProxy = getProxyEnv("no_proxy").toLowerCase();
function shouldProxyHostname(hostname: string, port: number, env?: ProviderEnv): boolean {
const noProxy = getProxyEnv("no_proxy", env).toLowerCase();
if (!noProxy) {
return true;
}
@@ -68,7 +66,7 @@ function shouldProxyHostname(hostname: string, port: number): boolean {
});
}
function getProxyForUrl(targetUrl: string | URL): string {
function getProxyForUrl(targetUrl: string | URL, env?: ProviderEnv): string {
const parsedUrl = parseProxyTargetUrl(targetUrl);
if (!parsedUrl?.protocol || !parsedUrl.host) {
return "";
@@ -77,19 +75,22 @@ function getProxyForUrl(targetUrl: string | URL): string {
const protocol = parsedUrl.protocol.split(":", 1)[0]!;
const hostname = parsedUrl.host.replace(/:\d*$/, "");
const port = Number.parseInt(parsedUrl.port, 10) || DEFAULT_PROXY_PORTS[protocol] || 0;
if (!shouldProxyHostname(hostname, port)) {
if (!shouldProxyHostname(hostname, port, env)) {
return "";
}
let proxy = getProxyEnv(`${protocol}_proxy`) || getProxyEnv("all_proxy");
let proxy = getProxyEnv(`${protocol}_proxy`, env) || getProxyEnv("all_proxy", env);
if (proxy && !proxy.includes("://")) {
proxy = `${protocol}://${proxy}`;
}
return proxy;
}
export function resolveHttpProxyUrlForTarget(targetUrl: string | URL): URL | undefined {
const proxy = getProxyForUrl(targetUrl);
export const UNSUPPORTED_PROXY_PROTOCOL_MESSAGE =
"Unsupported proxy protocol. SOCKS and PAC proxy URLs are not supported; use an HTTP or HTTPS proxy URL.";
export function resolveHttpProxyUrlForTarget(targetUrl: string | URL, env?: ProviderEnv): URL | undefined {
const proxy = getProxyForUrl(targetUrl, env);
if (!proxy) {
return undefined;
}
@@ -109,15 +110,3 @@ export function resolveHttpProxyUrlForTarget(targetUrl: string | URL): URL | und
return proxyUrl;
}
export function createHttpProxyAgentsForTarget(targetUrl: string | URL): NodeHttpProxyAgents | undefined {
const proxyUrl = resolveHttpProxyUrlForTarget(targetUrl);
if (!proxyUrl) {
return undefined;
}
return {
httpAgent: new HttpProxyAgent(proxyUrl),
httpsAgent: new HttpsProxyAgent(proxyUrl) as unknown as HttpsAgent,
};
}
+2 -1
View File
@@ -6,6 +6,7 @@
*/
import type { Server } from "node:http";
import { getProviderEnvValue } from "../provider-env.ts";
import { oauthErrorHtml, oauthSuccessHtml } from "./oauth-page.ts";
import { generatePKCE } from "./pkce.ts";
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthPrompt, OAuthProviderInterface } from "./types.ts";
@@ -28,7 +29,7 @@ const decode = (s: string) => atob(s);
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
const TOKEN_URL = "https://platform.claude.com/v1/oauth/token";
const CALLBACK_HOST = process.env.PI_OAUTH_CALLBACK_HOST || "127.0.0.1";
const CALLBACK_HOST = getProviderEnvValue("PI_OAUTH_CALLBACK_HOST") || "127.0.0.1";
const CALLBACK_PORT = 53692;
const CALLBACK_PATH = "/callback";
const REDIRECT_URI = `http://localhost:${CALLBACK_PORT}${CALLBACK_PATH}`;
+2 -1
View File
@@ -17,6 +17,7 @@ if (typeof process !== "undefined" && (process.versions?.node || process.version
});
}
import { getProviderEnvValue } from "../provider-env.ts";
import { pollOAuthDeviceCodeFlow } from "./device-code.ts";
import { oauthErrorHtml, oauthSuccessHtml } from "./oauth-page.ts";
import { generatePKCE } from "./pkce.ts";
@@ -47,7 +48,7 @@ type OAuthToken = { access: string; refresh: string; expires: number };
type TokenOperation = "exchange" | "refresh";
function getCallbackHost(): string {
return typeof process !== "undefined" ? process.env.PI_OAUTH_CALLBACK_HOST || "127.0.0.1" : "127.0.0.1";
return getProviderEnvValue("PI_OAUTH_CALLBACK_HOST") || "127.0.0.1";
}
type DeviceAuthInfo = {
+52
View File
@@ -0,0 +1,52 @@
import type { ProviderEnv } from "../types.ts";
let procEnvCache: Map<string, string> | null = null;
/**
 * Fallback for https://github.com/oven-sh/bun/issues/27802.
 * Bun compiled binaries can expose an empty process.env inside Linux sandboxes
 * even though /proc/self/environ contains the environment.
 *
 * This intentionally duplicates restoreSandboxEnv() in
 * packages/coding-agent/src/bun/restore-sandbox-env.ts. The ai package can be
 * used directly, without going through that entrypoint, so provider env lookup
 * must not depend on process.env having been patched.
 *
 * Returns undefined unless running under Bun with an empty process.env. The
 * parsed /proc/self/environ contents are cached after the first attempt,
 * including the empty result of a failed read.
 */
function getBunSandboxEnvValue(name: string): string | undefined {
	if (typeof process === "undefined") {
		return undefined;
	}
	if (!process.versions?.bun) {
		return undefined;
	}
	// A populated process.env means the sandbox bug is not in effect.
	if (Object.keys(process.env).length > 0) {
		return undefined;
	}
	if (procEnvCache !== null) {
		return procEnvCache.get(name);
	}
	// Publish the map before reading so a failed read is remembered as "empty".
	const parsed = new Map<string, string>();
	procEnvCache = parsed;
	try {
		const fs = require("node:fs") as {
			readFileSync(path: string, encoding: BufferEncoding): string;
		};
		const raw = fs.readFileSync("/proc/self/environ", "utf-8");
		// Entries are NUL-separated NAME=value records; skip records without a name.
		for (const record of raw.split("\0")) {
			const separator = record.indexOf("=");
			if (separator <= 0) {
				continue;
			}
			parsed.set(record.slice(0, separator), record.slice(separator + 1));
		}
	} catch {
		// /proc/self/environ may not exist or may not be readable.
	}
	return parsed.get(name);
}
/**
 * Resolve a provider env value.
 *
 * Lookup order: the request-scoped `env` overrides, then process.env (when a
 * `process` global exists), then the duplicated Bun sandbox fallback for
 * direct pi-ai consumers. Empty strings count as unset at every step, so the
 * result is either a non-empty string or undefined.
 */
export function getProviderEnvValue(name: string, env?: ProviderEnv): string | undefined {
	const scoped = env?.[name];
	if (scoped) {
		return scoped;
	}
	if (typeof process !== "undefined") {
		const ambient = process.env[name];
		if (ambient) {
			return ambient;
		}
	}
	const sandboxed = getBunSandboxEnvValue(name);
	return sandboxed ? sandboxed : undefined;
}
@@ -45,7 +45,7 @@ vi.mock("@aws-sdk/client-bedrock-runtime", () => {
});
import { getModel } from "../src/models.ts";
import { streamBedrock } from "../src/providers/amazon-bedrock.ts";
import { type BedrockOptions, streamBedrock } from "../src/providers/amazon-bedrock.ts";
import type { Context, Model } from "../src/types.ts";
const context: Context = {
@@ -83,8 +83,12 @@ afterEach(() => {
}
});
async function captureClientConfig(model: Model<"bedrock-converse-stream">): Promise<Record<string, unknown>> {
await streamBedrock(model, context, { cacheRetention: "none" }).result();
async function captureClientConfig(
model: Model<"bedrock-converse-stream">,
options: BedrockOptions = {},
): Promise<Record<string, unknown>> {
bedrockMock.constructorCalls.length = 0;
await streamBedrock(model, context, { cacheRetention: "none", ...options }).result();
expect(bedrockMock.constructorCalls).toHaveLength(1);
return bedrockMock.constructorCalls[0];
}
@@ -115,6 +119,29 @@ describe("bedrock endpoint resolution", () => {
expect(config.region).toBe("eu-central-1");
});
// Checks the SDK client config produced for one model when only an AWS profile is
// provided, via three different sources: the explicit `profile` option, a
// request-scoped `env.AWS_PROFILE`, and the ambient `process.env.AWS_PROFILE`.
it("handles missing regions for explicit, scoped, and ambient profiles", async () => {
const model = getModel("amazon-bedrock", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0");
// Case 1: profile passed directly as a Bedrock option.
let config = await captureClientConfig(model, { profile: "bedrock-profile" });
expect(config.profile).toBe("bedrock-profile");
expect(config.endpoint).toBe("https://bedrock-runtime.eu-central-1.amazonaws.com");
expect(config.region).toBe("eu-central-1");
// Case 2: profile supplied through the request-scoped env override.
config = await captureClientConfig(model, { env: { AWS_PROFILE: "scoped-bedrock-profile" } });
expect(config.profile).toBe("scoped-bedrock-profile");
expect(config.endpoint).toBe("https://bedrock-runtime.eu-central-1.amazonaws.com");
expect(config.region).toBe("eu-central-1");
// Case 3: profile comes only from the process environment.
// NOTE(review): unlike the two cases above, endpoint and region are expected to be
// undefined here — presumably so the SDK resolves them from the ambient profile;
// confirm against the region handling in amazon-bedrock.ts.
process.env.AWS_PROFILE = "ambient-bedrock-profile";
config = await captureClientConfig(model);
expect(config.profile).toBe("ambient-bedrock-profile");
expect(config.endpoint).toBeUndefined();
expect(config.region).toBeUndefined();
});
it("still passes custom Bedrock endpoints through to the SDK client", async () => {
process.env.AWS_REGION = "us-west-2";
const baseModel = getModel("amazon-bedrock", "us.anthropic.claude-opus-4-8");
+11
View File
@@ -54,6 +54,17 @@ describe("node HTTP proxy resolution", () => {
);
});
// The process environment sets the lowercase alias (`https_proxy`) while the scoped
// env sets the uppercase one (`HTTPS_PROXY`). The scoped value must win even though
// the two sides use different spellings of the variable name.
it("prefers scoped proxy env aliases before process env aliases", () => {
resetProxyEnv();
process.env.https_proxy = "http://process-proxy.example:8080";
expect(
resolveHttpProxyUrlForTarget("https://bedrock-runtime.us-east-1.amazonaws.com", {
HTTPS_PROXY: "http://scoped-proxy.example:8080",
})?.toString(),
// The expected value is the serialized URL, hence the trailing slash.
).toBe("http://scoped-proxy.example:8080/");
});
it("rejects SOCKS and PAC proxy URLs explicitly", () => {
resetProxyEnv();
process.env.HTTPS_PROXY = "socks5://proxy.example:1080";
@@ -163,6 +163,31 @@ describe("openai-completions empty tools handling", () => {
expect(clientOptions.defaultHeaders?.["cf-aig-authorization"]).toBe("Bearer test");
});
// Sets one pair of Cloudflare account/gateway IDs in process.env and a different pair
// in the request-scoped `env`, then checks that the base URL handed to the OpenAI
// client is built from the scoped pair.
it("uses provider env before process.env for Cloudflare AI Gateway base URL", async () => {
// Ambient values that must NOT end up in the resolved base URL.
process.env.CLOUDFLARE_ACCOUNT_ID = "process-account";
process.env.CLOUDFLARE_GATEWAY_ID = "process-gateway";
const model = getModel("cloudflare-ai-gateway", "workers-ai/@cf/moonshotai/kimi-k2.6")!;
await streamSimple(
model,
{
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
},
{
apiKey: "test",
// Request-scoped overrides expected to take precedence over process.env.
env: {
CLOUDFLARE_ACCOUNT_ID: "provider-account",
CLOUDFLARE_GATEWAY_ID: "provider-gateway",
},
},
).result();
const clientOptions = mockState.lastClientOptions as { baseURL?: string };
expect(clientOptions.baseURL).toBe(
"https://gateway.ai.cloudflare.com/v1/provider-account/provider-gateway/compat",
);
});
it("preserves inline upstream Authorization for Cloudflare AI Gateway BYOK requests", async () => {
process.env.CLOUDFLARE_ACCOUNT_ID = "account-id";
process.env.CLOUDFLARE_GATEWAY_ID = "gateway-id";
+4
View File
@@ -2,6 +2,10 @@
## [Unreleased]
### Added
- Added `auth.json` API key `env` values so provider-specific environment overrides can be scoped to Pi and propagated to inherited provider configuration ([#5728](https://github.com/earendil-works/pi/issues/5728)).
### Fixed
- Fixed successful `pi update` on Windows to exit naturally instead of calling `process.exit(0)`, avoiding a Node.js/libuv assertion after version-check network requests ([#5805](https://github.com/earendil-works/pi/issues/5805)).
+20 -2
View File
@@ -104,6 +104,24 @@ Store credentials in `~/.pi/agent/auth.json`:
The file is created with `0600` permissions (user read/write only). Auth file credentials take priority over environment variables.
API key credentials can also include provider-scoped environment values. These values are used before process environment variables when resolving the credential key, provider/model headers, and provider configuration such as Cloudflare account IDs, Azure OpenAI settings, Vertex project/location, Bedrock settings, `PI_CACHE_RETENTION`, and `HTTP_PROXY`/`HTTPS_PROXY`.
```json
{
"cloudflare-ai-gateway": {
"type": "api_key",
"key": "$CLOUDFLARE_API_KEY",
"env": {
"CLOUDFLARE_API_KEY": "...",
"CLOUDFLARE_ACCOUNT_ID": "account-id",
"CLOUDFLARE_GATEWAY_ID": "gateway-id"
}
}
}
```
Use this when pi should use different provider settings than the project shell environment.
### Key Resolution
The `key` field supports command execution, environment interpolation, and literals:
@@ -194,7 +212,7 @@ export AWS_BEDROCK_FORCE_HTTP1=1
### Cloudflare AI Gateway
`CLOUDFLARE_API_KEY` can be set via `/login`. The account ID and gateway slug must be set as environment variables.
`CLOUDFLARE_API_KEY` can be set via `/login`. The account ID and gateway slug can be set as environment variables or in the API key credential's `env` object in `auth.json`.
```bash
export CLOUDFLARE_API_KEY=... # or use /login
@@ -218,7 +236,7 @@ For normal pi usage, prefer unified billing or stored BYOK. Inline BYOK requires
### Cloudflare Workers AI
`CLOUDFLARE_API_KEY` can be set via `/login`. `CLOUDFLARE_ACCOUNT_ID` must be set as an environment variable.
`CLOUDFLARE_API_KEY` can be set via `/login`. `CLOUDFLARE_ACCOUNT_ID` can be set as an environment variable or in the API key credential's `env` object in `auth.json`.
```bash
export CLOUDFLARE_API_KEY=... # or use /login
@@ -4,6 +4,10 @@
* Bun compiled binaries have an empty `process.env` when running inside
* sandbox environments (e.g. nono on Linux/macOS). On Linux we can recover
* the environment from `/proc/self/environ`.
*
* Keep this in sync with getBunSandboxEnvValue() in
* packages/ai/src/utils/provider-env.ts. The ai package duplicates the lookup
* for direct consumers that do not go through this coding-agent entrypoint.
*/
import { readFileSync } from "node:fs";
@@ -357,6 +357,7 @@ export class AgentSession {
private async _getRequiredRequestAuth(model: Model<any>): Promise<{
apiKey: string;
headers?: Record<string, string>;
env?: Record<string, string>;
}> {
const result = await this._modelRegistry.getApiKeyAndHeaders(model);
if (!result.ok) {
@@ -366,7 +367,7 @@ export class AgentSession {
throw new Error(result.error);
}
if (result.apiKey) {
return { apiKey: result.apiKey, headers: result.headers };
return { apiKey: result.apiKey, headers: result.headers, env: result.env };
}
const isOAuth = this._modelRegistry.isUsingOAuth(model);
@@ -383,13 +384,14 @@ export class AgentSession {
private async _getCompactionRequestAuth(model: Model<any>): Promise<{
apiKey?: string;
headers?: Record<string, string>;
env?: Record<string, string>;
}> {
if (this.agent.streamFn === streamSimple) {
return this._getRequiredRequestAuth(model);
}
const result = await this._modelRegistry.getApiKeyAndHeaders(model);
return result.ok ? { apiKey: result.apiKey, headers: result.headers } : {};
return result.ok ? { apiKey: result.apiKey, headers: result.headers, env: result.env } : {};
}
/**
@@ -1649,7 +1651,7 @@ export class AgentSession {
throw new Error(formatNoModelSelectedMessage());
}
const { apiKey, headers } = await this._getCompactionRequestAuth(this.model);
const { apiKey, headers, env } = await this._getCompactionRequestAuth(this.model);
const pathEntries = this.sessionManager.getBranch();
const settings = this.settingsManager.getCompactionSettings();
@@ -1708,6 +1710,7 @@ export class AgentSession {
this._compactionAbortController.signal,
this.thinkingLevel,
this.agent.streamFn,
env,
);
summary = result.summary;
firstKeptEntryId = result.firstKeptEntryId;
@@ -1898,6 +1901,7 @@ export class AgentSession {
let apiKey: string | undefined;
let headers: Record<string, string> | undefined;
let env: Record<string, string> | undefined;
if (this.agent.streamFn === streamSimple) {
const authResult = await this._modelRegistry.getApiKeyAndHeaders(this.model);
if (!authResult.ok || !authResult.apiKey) {
@@ -1912,8 +1916,9 @@ export class AgentSession {
}
apiKey = authResult.apiKey;
headers = authResult.headers;
env = authResult.env;
} else {
({ apiKey, headers } = await this._getCompactionRequestAuth(this.model));
({ apiKey, headers, env } = await this._getCompactionRequestAuth(this.model));
}
const pathEntries = this.sessionManager.getBranch();
@@ -1981,6 +1986,7 @@ export class AgentSession {
this._autoCompactionAbortController.signal,
this.thinkingLevel,
this.agent.streamFn,
env,
);
summary = compactResult.summary;
firstKeptEntryId = compactResult.firstKeptEntryId;
@@ -2784,12 +2790,13 @@ export class AgentSession {
let summaryDetails: unknown;
if (options.summarize && entriesToSummarize.length > 0 && !extensionSummary) {
const model = this.model!;
const { apiKey, headers } = await this._getRequiredRequestAuth(model);
const { apiKey, headers, env } = await this._getRequiredRequestAuth(model);
const branchSummarySettings = this.settingsManager.getBranchSummarySettings();
const result = await generateBranchSummary(entriesToSummarize, {
model,
apiKey,
headers,
env,
signal: this._branchSummaryAbortController.signal,
customInstructions,
replaceInstructions,
+10 -1
View File
@@ -24,6 +24,7 @@ import { resolveConfigValue } from "./resolve-config-value.ts";
export type ApiKeyCredential = {
// Discriminator separating API key credentials from OAuth credentials.
type: "api_key";
// Configured key value; it is passed through resolveConfigValue() together with
// `env` when the key is read, so it may reference environment variables.
key: string;
// Optional provider-scoped environment values. They are used before process.env
// when resolving `key` and are exposed to callers via getProviderEnv().
env?: Record<string, string>;
};
export type OAuthCredential = {
@@ -303,6 +304,14 @@ export class AuthStorage {
return this.data[provider] ?? undefined;
}
/**
 * Get provider-scoped environment values for an API key credential.
 *
 * Returns a shallow copy so callers cannot mutate the stored credential, or
 * undefined when the provider has no credential, the credential is not an
 * API key, or it carries no `env` object.
 */
getProviderEnv(provider: string): Record<string, string> | undefined {
	const credential = this.data[provider];
	if (credential?.type !== "api_key" || !credential.env) {
		return undefined;
	}
	return { ...credential.env };
}
/**
* Set credential for a provider.
*/
@@ -471,7 +480,7 @@ export class AuthStorage {
const cred = this.data[providerId];
if (cred?.type === "api_key") {
return resolveConfigValue(cred.key);
return resolveConfigValue(cred.key, cred.env);
}
if (cred?.type === "oauth") {
@@ -69,6 +69,8 @@ export interface GenerateBranchSummaryOptions {
apiKey: string;
/** Request headers for the model */
headers?: Record<string, string>;
/** Provider-scoped environment values for the model */
env?: Record<string, string>;
/** Abort signal for cancellation */
signal: AbortSignal;
/** Optional custom instructions for summarization */
@@ -290,6 +292,7 @@ export async function generateBranchSummary(
model,
apiKey,
headers,
env,
signal,
customInstructions,
replaceInstructions,
@@ -335,7 +338,7 @@ export async function generateBranchSummary(
// request behavior (timeouts, retries, attribution headers) stays consistent
// without running through agent state/events.
const context = { systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages };
const requestOptions: SimpleStreamOptions = { apiKey, headers, signal, maxTokens: 2048 };
const requestOptions: SimpleStreamOptions = { apiKey, headers, env, signal, maxTokens: 2048 };
const response = streamFn
? await (await streamFn(model, context, requestOptions)).result()
: await completeSimple(model, context, requestOptions);
@@ -528,10 +528,11 @@ function createSummarizationOptions(
maxTokens: number,
apiKey: string | undefined,
headers: Record<string, string> | undefined,
env: Record<string, string> | undefined,
signal: AbortSignal | undefined,
thinkingLevel: ThinkingLevel | undefined,
): SimpleStreamOptions {
const options: SimpleStreamOptions = { maxTokens, signal, apiKey, headers };
const options: SimpleStreamOptions = { maxTokens, signal, apiKey, headers, env };
if (model.reasoning && thinkingLevel && thinkingLevel !== "off") {
options.reasoning = thinkingLevel;
}
@@ -566,6 +567,7 @@ export async function generateSummary(
previousSummary?: string,
thinkingLevel?: ThinkingLevel,
streamFn?: StreamFn,
env?: Record<string, string>,
): Promise<string> {
const maxTokens = Math.min(
Math.floor(0.8 * reserveTokens),
@@ -598,7 +600,7 @@ export async function generateSummary(
},
];
const completionOptions = createSummarizationOptions(model, maxTokens, apiKey, headers, signal, thinkingLevel);
const completionOptions = createSummarizationOptions(model, maxTokens, apiKey, headers, env, signal, thinkingLevel);
const response = await completeSummarization(
model,
@@ -753,6 +755,7 @@ export async function compact(
signal?: AbortSignal,
thinkingLevel?: ThinkingLevel,
streamFn?: StreamFn,
env?: Record<string, string>,
): Promise<CompactionResult> {
const {
firstKeptEntryId,
@@ -783,6 +786,7 @@ export async function compact(
previousSummary,
thinkingLevel,
streamFn,
env,
)
: Promise.resolve("No prior history."),
generateTurnPrefixSummary(
@@ -791,6 +795,7 @@ export async function compact(
settings.reserveTokens,
apiKey,
headers,
env,
signal,
thinkingLevel,
streamFn,
@@ -811,6 +816,7 @@ export async function compact(
previousSummary,
thinkingLevel,
streamFn,
env,
);
}
@@ -839,6 +845,7 @@ async function generateTurnPrefixSummary(
reserveTokens: number,
apiKey: string | undefined,
headers?: Record<string, string>,
env?: Record<string, string>,
signal?: AbortSignal,
thinkingLevel?: ThinkingLevel,
streamFn?: StreamFn,
@@ -861,7 +868,7 @@ async function generateTurnPrefixSummary(
const response = await completeSummarization(
model,
{ systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
createSummarizationOptions(model, maxTokens, apiKey, headers, signal, thinkingLevel),
createSummarizationOptions(model, maxTokens, apiKey, headers, env, signal, thinkingLevel),
streamFn,
);
@@ -240,6 +240,7 @@ export type ResolvedRequestAuth =
ok: true;
apiKey?: string;
headers?: Record<string, string>;
env?: Record<string, string>;
}
| {
ok: false;
@@ -684,17 +685,27 @@ export class ModelRegistry {
async getApiKeyAndHeaders(model: Model<Api>): Promise<ResolvedRequestAuth> {
try {
const providerConfig = this.providerRequestConfigs.get(model.provider);
const providerEnv = this.authStorage.getProviderEnv(model.provider);
const apiKeyFromAuthStorage = await this.authStorage.getApiKey(model.provider, { includeFallback: false });
const apiKey =
apiKeyFromAuthStorage ??
(providerConfig?.apiKey
? resolveConfigValueOrThrow(providerConfig.apiKey, `API key for provider "${model.provider}"`)
? resolveConfigValueOrThrow(
providerConfig.apiKey,
`API key for provider "${model.provider}"`,
providerEnv,
)
: undefined);
const providerHeaders = resolveHeadersOrThrow(providerConfig?.headers, `provider "${model.provider}"`);
const providerHeaders = resolveHeadersOrThrow(
providerConfig?.headers,
`provider "${model.provider}"`,
providerEnv,
);
const modelHeaders = resolveHeadersOrThrow(
this.modelRequestHeaders.get(this.getModelRequestKey(model.provider, model.id)),
`model "${model.provider}/${model.id}"`,
providerEnv,
);
let headers =
@@ -713,6 +724,7 @@ export class ModelRegistry {
ok: true,
apiKey,
headers: headers && Object.keys(headers).length > 0 ? headers : undefined,
env: providerEnv && Object.keys(providerEnv).length > 0 ? providerEnv : undefined,
};
} catch (error) {
return {
@@ -777,7 +789,9 @@ export class ModelRegistry {
}
const providerApiKey = this.providerRequestConfigs.get(provider)?.apiKey;
return providerApiKey ? resolveConfigValueUncached(providerApiKey) : undefined;
return providerApiKey
? resolveConfigValueUncached(providerApiKey, this.authStorage.getProviderEnv(provider))
: undefined;
}
/**
@@ -85,8 +85,8 @@ function parseConfigValueReference(config: string): ConfigValueReference {
return { type: "template", parts: parseConfigValueTemplate(config) };
}
function resolveEnvConfigValue(name: string): string | undefined {
return process.env[name] || undefined;
/**
 * Resolve an environment variable for config interpolation.
 * Scoped `env` values are checked before process.env; empty strings count as
 * unset in both places, so the result is a non-empty string or undefined.
 */
function resolveEnvConfigValue(name: string, env?: Record<string, string>): string | undefined {
	const scoped = env?.[name];
	if (scoped) {
		return scoped;
	}
	const ambient = process.env[name];
	return ambient ? ambient : undefined;
}
function getTemplateEnvVarNames(parts: TemplatePart[]): string[] {
@@ -98,14 +98,14 @@ function getTemplateEnvVarNames(parts: TemplatePart[]): string[] {
return names;
}
function resolveTemplate(parts: TemplatePart[]): string | undefined {
function resolveTemplate(parts: TemplatePart[], env?: Record<string, string>): string | undefined {
let resolved = "";
for (const part of parts) {
if (part.type === "literal") {
resolved += part.value;
continue;
}
const envValue = resolveEnvConfigValue(part.name);
const envValue = resolveEnvConfigValue(part.name, env);
if (envValue === undefined) return undefined;
resolved += envValue;
}
@@ -123,16 +123,16 @@ export function getConfigValueEnvVarNames(config: string): string[] {
return reference.type === "template" ? getTemplateEnvVarNames(reference.parts) : [];
}
/**
 * List the environment variable names referenced by `config` that do not
 * currently resolve to a value.
 *
 * Names come from `getConfigValueEnvVarNames(config)`; each one is kept only
 * when `resolveEnvConfigValue` returns `undefined` for it. The updated form
 * forwards the optional `env` map to that lookup.
 */
export function getMissingConfigValueEnvVarNames(config: string): string[] {
return getConfigValueEnvVarNames(config).filter((name) => resolveEnvConfigValue(name) === undefined);
export function getMissingConfigValueEnvVarNames(config: string, env?: Record<string, string>): string[] {
return getConfigValueEnvVarNames(config).filter((name) => resolveEnvConfigValue(name, env) === undefined);
}
/**
 * Report whether `config` parses as a command reference rather than a
 * template.
 */
export function isCommandConfigValue(config: string): boolean {
  const { type } = parseConfigValueReference(config);
  return type === "command";
}
/**
 * Return true when `getMissingConfigValueEnvVarNames` reports no missing
 * environment variables for `config`. The updated form forwards the optional
 * `env` map to that check.
 */
export function isConfigValueConfigured(config: string): boolean {
return getMissingConfigValueEnvVarNames(config).length === 0;
export function isConfigValueConfigured(config: string, env?: Record<string, string>): boolean {
return getMissingConfigValueEnvVarNames(config, env).length === 0;
}
/**
@@ -142,12 +142,12 @@ export function isConfigValueConfigured(config: string): boolean {
* - In non-command values, "$$" escapes a literal "$" and "$!" escapes a literal "!"
* - Otherwise treats the value as a literal
*/
export function resolveConfigValue(config: string): string | undefined {
export function resolveConfigValue(config: string, env?: Record<string, string>): string | undefined {
// Parse once, then dispatch on the kind of reference.
const reference = parseConfigValueReference(config);
if (reference.type === "command") {
// Command references are handed to executeCommand; `env` is not passed on this path.
return executeCommand(reference.config);
}
// Anything else is a template; the updated form forwards `env` to the template resolver.
return resolveTemplate(reference.parts);
return resolveTemplate(reference.parts, env);
}
function executeWithConfiguredShell(command: string): { executed: boolean; value: string | undefined } {
@@ -216,16 +216,16 @@ function executeCommand(commandConfig: string): string | undefined {
/**
 * Resolve a single config value, running command references through
 * `executeCommandUncached` (where `resolveConfigValue` uses `executeCommand`).
 *
 * Non-command values are resolved as templates; the updated form forwards the
 * optional `env` map to `resolveTemplate`. Returns `undefined` when the value
 * cannot be resolved.
 */
export function resolveConfigValueUncached(config: string): string | undefined {
export function resolveConfigValueUncached(config: string, env?: Record<string, string>): string | undefined {
const reference = parseConfigValueReference(config);
if (reference.type === "command") {
// `env` is not passed on the command path.
return executeCommandUncached(reference.config);
}
return resolveTemplate(reference.parts);
return resolveTemplate(reference.parts, env);
}
export function resolveConfigValueOrThrow(config: string, description: string): string {
const resolvedValue = resolveConfigValueUncached(config);
export function resolveConfigValueOrThrow(config: string, description: string, env?: Record<string, string>): string {
const resolvedValue = resolveConfigValueUncached(config, env);
if (resolvedValue !== undefined) {
return resolvedValue;
}
@@ -236,7 +236,7 @@ export function resolveConfigValueOrThrow(config: string, description: string):
}
if (reference.type === "template") {
const missingEnvVars = getMissingConfigValueEnvVarNames(config);
const missingEnvVars = getMissingConfigValueEnvVarNames(config, env);
if (missingEnvVars.length === 1) {
throw new Error(`Failed to resolve ${description} from environment variable: ${missingEnvVars[0]}`);
}
@@ -251,11 +251,14 @@ export function resolveConfigValueOrThrow(config: string, description: string):
/**
* Resolve all header values using the same resolution logic as API keys.
*/
export function resolveHeaders(headers: Record<string, string> | undefined): Record<string, string> | undefined {
export function resolveHeaders(
headers: Record<string, string> | undefined,
env?: Record<string, string>,
): Record<string, string> | undefined {
if (!headers) return undefined;
const resolved: Record<string, string> = {};
for (const [key, value] of Object.entries(headers)) {
const resolvedValue = resolveConfigValue(value);
const resolvedValue = resolveConfigValue(value, env);
if (resolvedValue) {
resolved[key] = resolvedValue;
}
@@ -266,11 +269,12 @@ export function resolveHeaders(headers: Record<string, string> | undefined): Rec
/**
 * Resolve every header value with `resolveConfigValueOrThrow`, so each entry
 * either resolves to a string or the helper raises for that header.
 *
 * `description` is used to label the header being resolved, as
 * `${description} header "${key}"`. The updated call also forwards the
 * optional `env` map. Returns `undefined` when `headers` is missing or has no
 * entries.
 */
export function resolveHeadersOrThrow(
headers: Record<string, string> | undefined,
description: string,
env?: Record<string, string>,
): Record<string, string> | undefined {
if (!headers) return undefined;
const resolved: Record<string, string> = {};
for (const [key, value] of Object.entries(headers)) {
resolved[key] = resolveConfigValueOrThrow(value, `${description} header "${key}"`);
resolved[key] = resolveConfigValueOrThrow(value, `${description} header "${key}"`, env);
}
// Collapse an empty result to undefined.
return Object.keys(resolved).length > 0 ? resolved : undefined;
}
+2
View File
@@ -303,6 +303,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
if (!auth.ok) {
throw new Error(auth.error);
}
const env = auth.env || options?.env ? { ...(auth.env ?? {}), ...(options?.env ?? {}) } : undefined;
const providerRetrySettings = settingsManager.getProviderRetrySettings();
const httpIdleTimeoutMs = settingsManager.getHttpIdleTimeoutMs();
// SDKs treat timeout=0 as 0ms (immediate timeout), not "no timeout".
@@ -314,6 +315,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
return streamSimple(model, context, {
...options,
apiKey: auth.apiKey,
env,
timeoutMs,
websocketConnectTimeoutMs,
maxRetries: options?.maxRetries ?? providerRetrySettings.maxRetries,
@@ -134,6 +134,34 @@ describe("AuthStorage", () => {
}
});
// The credential's own `env` map must win over a same-named process.env variable.
test("apiKey env bag takes precedence over process.env", async () => {
  const previousValue = process.env.TEST_AUTH_SCOPED_API_KEY_12345;
  process.env.TEST_AUTH_SCOPED_API_KEY_12345 = "process-env-value";

  try {
    const scopedCredential = {
      type: "api_key",
      key: "$TEST_AUTH_SCOPED_API_KEY_12345",
      env: { TEST_AUTH_SCOPED_API_KEY_12345: "credential-env-value" },
    };
    writeAuthJson({ anthropic: scopedCredential });

    authStorage = AuthStorage.create(authJsonPath);

    const resolvedKey = await authStorage.getApiKey("anthropic");
    expect(resolvedKey).toBe("credential-env-value");

    const providerEnv = authStorage.getProviderEnv("anthropic");
    expect(providerEnv).toEqual({
      TEST_AUTH_SCOPED_API_KEY_12345: "credential-env-value",
    });
  } finally {
    // Put process.env back exactly as it was before the test ran.
    if (previousValue !== undefined) {
      process.env.TEST_AUTH_SCOPED_API_KEY_12345 = previousValue;
    } else {
      delete process.env.TEST_AUTH_SCOPED_API_KEY_12345;
    }
  }
});
test("apiKey with braced env syntax resolves to env value", async () => {
const originalEnv = process.env.TEST_AUTH_BRACED_API_KEY_12345;
process.env.TEST_AUTH_BRACED_API_KEY_12345 = "braced-env-api-key-value";
@@ -893,6 +893,38 @@ describe("ModelRegistry", () => {
expect(registry.getProviderDisplayName("oauth-provider")).toBe("OAuth Provider");
});
// A stored credential's `env` map should drive API key resolution and header
// resolution, and be returned alongside the request auth.
test("stored API key env propagates to request auth and resolves headers", async () => {
  authStorage.set("cloudflare-ai-gateway", {
    type: "api_key",
    key: "$CLOUDFLARE_API_KEY",
    env: {
      CLOUDFLARE_API_KEY: "stored-cf-token",
      CLOUDFLARE_ACCOUNT_ID: "stored-account",
    },
  });
  writeRawModelsJson({
    "cloudflare-ai-gateway": {
      headers: { "x-account": "$CLOUDFLARE_ACCOUNT_ID" },
    },
  });

  const modelRegistry = ModelRegistry.create(authStorage, modelsJsonPath);
  const gatewayModel = modelRegistry.getAll().find((candidate) => candidate.provider === "cloudflare-ai-gateway");
  expect(gatewayModel).toBeDefined();

  const requestAuth = await modelRegistry.getApiKeyAndHeaders(gatewayModel!);
  const expectedAuth = {
    ok: true,
    apiKey: "stored-cf-token",
    headers: { "x-account": "stored-account" },
    env: {
      CLOUDFLARE_API_KEY: "stored-cf-token",
      CLOUDFLARE_ACCOUNT_ID: "stored-account",
    },
  };
  expect(requestAuth).toEqual(expectedAuth);
});
test("registerProvider treats uppercase apiKey and headers as literals", async () => {
const envKeys = ["CUSTOM_NAME", "BEARER", "MODEL_TOKEN"];
const savedEnv: Record<string, string | undefined> = {};