mirror of
https://github.com/earendil-works/pi.git
synced 2026-06-18 15:54:04 +08:00
feat: add provider-scoped environment overrides (#5807)
This commit is contained in:
committed by
GitHub
Unverified
parent
3039f3e17d
commit
7f29e7a369
@@ -2,6 +2,10 @@
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Added provider-scoped `StreamOptions.env` overrides for provider configuration, including Cloudflare endpoint placeholders, Azure OpenAI, Google Vertex, Amazon Bedrock, cache retention, and proxy environment lookups ([#5728](https://github.com/earendil-works/pi/issues/5728)).
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed Z.AI GLM-5.2 thinking requests to send `reasoning_effort` with the provider's `high`/`max` effort mapping ([#5770](https://github.com/earendil-works/pi/issues/5770)).
|
||||
|
||||
@@ -38,6 +38,7 @@ Unified LLM API with automatic model discovery, provider configuration, token an
|
||||
- [Browser Usage](#browser-usage)
|
||||
- [Browser Compatibility Notes](#browser-compatibility-notes)
|
||||
- [Environment Variables](#environment-variables-nodejs-only)
|
||||
- [Provider-Scoped Environment Overrides](#provider-scoped-environment-overrides)
|
||||
- [Checking Environment Variables](#checking-environment-variables)
|
||||
- [OAuth Providers](#oauth-providers)
|
||||
- [Vertex AI](#vertex-ai)
|
||||
@@ -1145,6 +1146,24 @@ const response = await complete(model, context, {
|
||||
});
|
||||
```
|
||||
|
||||
### Provider-Scoped Environment Overrides
|
||||
|
||||
Pass `env` in stream options to scope provider configuration to a single request. Values in `env` take precedence over process environment variables, both for API key discovery and for provider configuration such as Cloudflare account IDs, Azure OpenAI settings, Vertex project/location, Bedrock settings, `PI_CACHE_RETENTION`, and `HTTP_PROXY`/`HTTPS_PROXY`.
|
||||
|
||||
```typescript
|
||||
const model = getModel('cloudflare-ai-gateway', 'workers-ai/@cf/moonshotai/kimi-k2.6');
|
||||
|
||||
const response = await complete(model, context, {
|
||||
env: {
|
||||
CLOUDFLARE_API_KEY: '...',
|
||||
CLOUDFLARE_ACCOUNT_ID: 'account-id',
|
||||
CLOUDFLARE_GATEWAY_ID: 'gateway-id'
|
||||
}
|
||||
});
|
||||
```
|
||||
|
||||
Use this when one process needs different provider settings per request, or when ambient environment variables should not leak into a provider call.
|
||||
|
||||
### Checking Environment Variables
|
||||
|
||||
```typescript
|
||||
|
||||
@@ -23,44 +23,17 @@ if (typeof process !== "undefined" && (process.versions?.node || process.version
|
||||
});
|
||||
}
|
||||
|
||||
import type { KnownProvider } from "./types.ts";
|
||||
|
||||
/** Parsed contents of `/proc/self/environ`; stays `null` until the first fallback read is attempted. */
let _procEnvCache: Map<string, string> | null = null;

/**
 * Fallback for https://github.com/oven-sh/bun/issues/27802
 * Bun compiled binaries have an empty `process.env` inside sandbox
 * environments on Linux. We can recover the env from `/proc/self/environ`.
 *
 * @param key - Name of the environment variable to look up.
 * @returns The value parsed from `/proc/self/environ`, or `undefined` when there is no
 * `process` global, the runtime is not Bun, `process.env` already has entries, the file
 * could not be read, or the key is not present.
 */
function getProcEnv(key: string): string | undefined {
	// Guard the global first: touching `process.versions` when `process` does not exist
	// (e.g. in a browser bundle) throws a ReferenceError instead of returning undefined.
	if (typeof process === "undefined") return undefined;
	if (!process.versions?.bun) return undefined;

	// If process.env already has entries, the bug is not triggered.
	if (Object.keys(process.env).length > 0) return undefined;

	if (_procEnvCache === null) {
		// Assign the cache before reading so a failed read is not retried on every lookup.
		_procEnvCache = new Map();
		try {
			const { readFileSync } = require("node:fs") as typeof import("node:fs");
			const data = readFileSync("/proc/self/environ", "utf-8");
			// The file is split on NUL bytes into `NAME=value` entries.
			for (const entry of data.split("\0")) {
				const idx = entry.indexOf("=");
				// `idx > 0` skips entries without `=` as well as entries with an empty name.
				if (idx > 0) {
					_procEnvCache.set(entry.slice(0, idx), entry.slice(idx + 1));
				}
			}
		} catch {
			// /proc/self/environ may not be readable.
		}
	}

	return _procEnvCache.get(key);
}
|
||||
import type { KnownProvider, ProviderEnv } from "./types.ts";
|
||||
import { getProviderEnvValue } from "./utils/provider-env.ts";
|
||||
|
||||
let cachedVertexAdcCredentialsExists: boolean | null = null;
|
||||
|
||||
function hasVertexAdcCredentials(): boolean {
|
||||
function hasVertexAdcCredentials(env?: ProviderEnv): boolean {
|
||||
const explicitCredentialsPath = env?.GOOGLE_APPLICATION_CREDENTIALS;
|
||||
if (explicitCredentialsPath) {
|
||||
return _existsSync ? _existsSync(explicitCredentialsPath) : false;
|
||||
}
|
||||
|
||||
if (cachedVertexAdcCredentialsExists === null) {
|
||||
// If node modules haven't loaded yet (async import race at startup),
|
||||
// return false WITHOUT caching so the next call retries once they're ready.
|
||||
@@ -75,7 +48,7 @@ function hasVertexAdcCredentials(): boolean {
|
||||
}
|
||||
|
||||
// Check GOOGLE_APPLICATION_CREDENTIALS env var first (standard way)
|
||||
const gacPath = process.env.GOOGLE_APPLICATION_CREDENTIALS || getProcEnv("GOOGLE_APPLICATION_CREDENTIALS");
|
||||
const gacPath = getProviderEnvValue("GOOGLE_APPLICATION_CREDENTIALS", env);
|
||||
if (gacPath) {
|
||||
cachedVertexAdcCredentialsExists = _existsSync(gacPath);
|
||||
} else {
|
||||
@@ -143,13 +116,13 @@ function getApiKeyEnvVars(provider: string): readonly string[] | undefined {
|
||||
* credential sources such as AWS profiles, AWS IAM credentials, and Google
|
||||
* Application Default Credentials.
|
||||
*/
|
||||
export function findEnvKeys(provider: KnownProvider): string[] | undefined;
|
||||
export function findEnvKeys(provider: string): string[] | undefined;
|
||||
export function findEnvKeys(provider: string): string[] | undefined {
|
||||
export function findEnvKeys(provider: KnownProvider, env?: ProviderEnv): string[] | undefined;
|
||||
export function findEnvKeys(provider: string, env?: ProviderEnv): string[] | undefined;
|
||||
export function findEnvKeys(provider: string, env?: ProviderEnv): string[] | undefined {
|
||||
const envVars = getApiKeyEnvVars(provider);
|
||||
if (!envVars) return undefined;
|
||||
|
||||
const found = envVars.filter((envVar) => !!process.env[envVar] || !!getProcEnv(envVar));
|
||||
const found = envVars.filter((envVar) => !!getProviderEnvValue(envVar, env));
|
||||
return found.length > 0 ? found : undefined;
|
||||
}
|
||||
|
||||
@@ -158,25 +131,22 @@ export function findEnvKeys(provider: string): string[] | undefined {
|
||||
*
|
||||
* Will not return API keys for providers that require OAuth tokens.
|
||||
*/
|
||||
export function getEnvApiKey(provider: KnownProvider): string | undefined;
|
||||
export function getEnvApiKey(provider: string): string | undefined;
|
||||
export function getEnvApiKey(provider: string): string | undefined {
|
||||
const envKeys = findEnvKeys(provider);
|
||||
export function getEnvApiKey(provider: KnownProvider, env?: ProviderEnv): string | undefined;
|
||||
export function getEnvApiKey(provider: string, env?: ProviderEnv): string | undefined;
|
||||
export function getEnvApiKey(provider: string, env?: ProviderEnv): string | undefined {
|
||||
const envKeys = findEnvKeys(provider, env);
|
||||
if (envKeys?.[0]) {
|
||||
return process.env[envKeys[0]] || getProcEnv(envKeys[0]);
|
||||
return getProviderEnvValue(envKeys[0], env);
|
||||
}
|
||||
|
||||
// Vertex AI supports either an explicit API key or Application Default Credentials.
|
||||
// Auth is configured via `gcloud auth application-default login`.
|
||||
if (provider === "google-vertex") {
|
||||
const hasCredentials = hasVertexAdcCredentials();
|
||||
const hasCredentials = hasVertexAdcCredentials(env);
|
||||
const hasProject = !!(
|
||||
process.env.GOOGLE_CLOUD_PROJECT ||
|
||||
process.env.GCLOUD_PROJECT ||
|
||||
getProcEnv("GOOGLE_CLOUD_PROJECT") ||
|
||||
getProcEnv("GCLOUD_PROJECT")
|
||||
getProviderEnvValue("GOOGLE_CLOUD_PROJECT", env) || getProviderEnvValue("GCLOUD_PROJECT", env)
|
||||
);
|
||||
const hasLocation = !!(process.env.GOOGLE_CLOUD_LOCATION || getProcEnv("GOOGLE_CLOUD_LOCATION"));
|
||||
const hasLocation = !!getProviderEnvValue("GOOGLE_CLOUD_LOCATION", env);
|
||||
|
||||
if (hasCredentials && hasProject && hasLocation) {
|
||||
return "<authenticated>";
|
||||
@@ -192,18 +162,12 @@ export function getEnvApiKey(provider: string): string | undefined {
|
||||
// 5. AWS_CONTAINER_CREDENTIALS_FULL_URI - ECS task roles (full URI)
|
||||
// 6. AWS_WEB_IDENTITY_TOKEN_FILE - IRSA (IAM Roles for Service Accounts)
|
||||
if (
|
||||
process.env.AWS_PROFILE ||
|
||||
(process.env.AWS_ACCESS_KEY_ID && process.env.AWS_SECRET_ACCESS_KEY) ||
|
||||
process.env.AWS_BEARER_TOKEN_BEDROCK ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_RELATIVE_URI ||
|
||||
process.env.AWS_CONTAINER_CREDENTIALS_FULL_URI ||
|
||||
process.env.AWS_WEB_IDENTITY_TOKEN_FILE ||
|
||||
getProcEnv("AWS_PROFILE") ||
|
||||
(getProcEnv("AWS_ACCESS_KEY_ID") && getProcEnv("AWS_SECRET_ACCESS_KEY")) ||
|
||||
getProcEnv("AWS_BEARER_TOKEN_BEDROCK") ||
|
||||
getProcEnv("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI") ||
|
||||
getProcEnv("AWS_CONTAINER_CREDENTIALS_FULL_URI") ||
|
||||
getProcEnv("AWS_WEB_IDENTITY_TOKEN_FILE")
|
||||
getProviderEnvValue("AWS_PROFILE", env) ||
|
||||
(getProviderEnvValue("AWS_ACCESS_KEY_ID", env) && getProviderEnvValue("AWS_SECRET_ACCESS_KEY", env)) ||
|
||||
getProviderEnvValue("AWS_BEARER_TOKEN_BEDROCK", env) ||
|
||||
getProviderEnvValue("AWS_CONTAINER_CREDENTIALS_RELATIVE_URI", env) ||
|
||||
getProviderEnvValue("AWS_CONTAINER_CREDENTIALS_FULL_URI", env) ||
|
||||
getProviderEnvValue("AWS_WEB_IDENTITY_TOKEN_FILE", env)
|
||||
) {
|
||||
return "<authenticated>";
|
||||
}
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import type { Agent as HttpsAgent } from "node:https";
|
||||
import {
|
||||
BedrockRuntimeClient,
|
||||
type BedrockRuntimeClientConfig,
|
||||
@@ -23,6 +24,8 @@ import {
|
||||
} from "@aws-sdk/client-bedrock-runtime";
|
||||
import { NodeHttpHandler } from "@smithy/node-http-handler";
|
||||
import type { BuildMiddleware, DocumentType, MetadataBearer } from "@smithy/types";
|
||||
import { HttpProxyAgent } from "http-proxy-agent";
|
||||
import { HttpsProxyAgent } from "https-proxy-agent";
|
||||
import { calculateCost } from "../models.ts";
|
||||
import type {
|
||||
Api,
|
||||
@@ -31,6 +34,7 @@ import type {
|
||||
Context,
|
||||
ImageContent,
|
||||
Model,
|
||||
ProviderEnv,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
@@ -45,7 +49,8 @@ import type {
|
||||
} from "../types.ts";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
|
||||
import { parseStreamingJson } from "../utils/json-parse.ts";
|
||||
import { createHttpProxyAgentsForTarget } from "../utils/node-http-proxy.ts";
|
||||
import { resolveHttpProxyUrlForTarget } from "../utils/node-http-proxy.ts";
|
||||
import { getProviderEnvValue } from "../utils/provider-env.ts";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.ts";
|
||||
import { adjustMaxTokensForThinking, buildBaseOptions, clampReasoning } from "./simple-options.ts";
|
||||
import { transformMessages } from "./transform-messages.ts";
|
||||
@@ -119,18 +124,18 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOpt
|
||||
const blocks = output.content as Block[];
|
||||
|
||||
const config: BedrockRuntimeClientConfig = {
|
||||
profile: options.profile,
|
||||
profile: options.profile || getProviderEnvValue("AWS_PROFILE", options.env),
|
||||
};
|
||||
const configuredRegion = getConfiguredBedrockRegion(options);
|
||||
const hasConfiguredProfile = hasConfiguredBedrockProfile();
|
||||
const hasAmbientConfiguredProfile = Boolean(getProviderEnvValue("AWS_PROFILE"));
|
||||
const endpointRegion = getStandardBedrockEndpointRegion(model.baseUrl);
|
||||
const useExplicitEndpoint = shouldUseExplicitBedrockEndpoint(
|
||||
model.baseUrl,
|
||||
configuredRegion,
|
||||
hasConfiguredProfile,
|
||||
hasAmbientConfiguredProfile,
|
||||
);
|
||||
|
||||
// Only pin standard AWS Bedrock runtime endpoints when no region/profile is configured.
|
||||
// Only pin standard AWS Bedrock runtime endpoints when no region or ambient AWS_PROFILE is configured.
|
||||
// This preserves custom endpoints (VPC/proxy) from #3402 without forcing built-in
|
||||
// catalog defaults such as us-east-1 to override AWS_REGION/AWS_PROFILE.
|
||||
if (useExplicitEndpoint) {
|
||||
@@ -138,8 +143,10 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOpt
|
||||
}
|
||||
|
||||
// Resolve bearer token for Bedrock API key auth.
|
||||
const bearerToken = options.bearerToken || process.env.AWS_BEARER_TOKEN_BEDROCK || undefined;
|
||||
const useBearerToken = bearerToken !== undefined && process.env.AWS_BEDROCK_SKIP_AUTH !== "1";
|
||||
const skipAuth = getProviderEnvValue("AWS_BEDROCK_SKIP_AUTH", options.env) === "1";
|
||||
const bearerToken =
|
||||
options.bearerToken || getProviderEnvValue("AWS_BEARER_TOKEN_BEDROCK", options.env) || undefined;
|
||||
const useBearerToken = bearerToken !== undefined && !skipAuth;
|
||||
|
||||
// in Node.js/Bun environment only
|
||||
if (typeof process !== "undefined" && (process.versions?.node || process.versions?.bun)) {
|
||||
@@ -153,25 +160,33 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOpt
|
||||
config.region = configuredRegion;
|
||||
} else if (endpointRegion && useExplicitEndpoint) {
|
||||
config.region = endpointRegion;
|
||||
} else if (!hasConfiguredProfile) {
|
||||
} else if (!hasAmbientConfiguredProfile) {
|
||||
config.region = "us-east-1";
|
||||
}
|
||||
|
||||
// Support proxies that don't need authentication
|
||||
if (process.env.AWS_BEDROCK_SKIP_AUTH === "1") {
|
||||
if (skipAuth) {
|
||||
config.credentials = {
|
||||
accessKeyId: "dummy-access-key",
|
||||
secretAccessKey: "dummy-secret-key",
|
||||
};
|
||||
}
|
||||
|
||||
const proxyAgents = createHttpProxyAgentsForTarget(model.baseUrl);
|
||||
if (proxyAgents) {
|
||||
const credentials = getConfiguredBedrockCredentials(options.env);
|
||||
if (!skipAuth && credentials) {
|
||||
config.credentials = credentials;
|
||||
}
|
||||
|
||||
const proxyUrl = resolveHttpProxyUrlForTarget(model.baseUrl, options.env);
|
||||
if (proxyUrl) {
|
||||
// Bedrock runtime uses NodeHttp2Handler by default since v3.798.0, which is based
|
||||
// on `http2` module and has no support for http agent.
|
||||
// Use NodeHttpHandler to support HTTP(S) proxy agents.
|
||||
config.requestHandler = new NodeHttpHandler(proxyAgents);
|
||||
} else if (process.env.AWS_BEDROCK_FORCE_HTTP1 === "1") {
|
||||
config.requestHandler = new NodeHttpHandler({
|
||||
httpAgent: new HttpProxyAgent(proxyUrl),
|
||||
httpsAgent: new HttpsProxyAgent(proxyUrl) as unknown as HttpsAgent,
|
||||
});
|
||||
} else if (getProviderEnvValue("AWS_BEDROCK_FORCE_HTTP1", options.env) === "1") {
|
||||
// Some custom endpoints require HTTP/1.1 instead of HTTP/2
|
||||
config.requestHandler = new NodeHttpHandler();
|
||||
}
|
||||
@@ -192,12 +207,12 @@ export const streamBedrock: StreamFunction<"bedrock-converse-stream", BedrockOpt
|
||||
if (options.headers && Object.keys(options.headers).length > 0) {
|
||||
addCustomHeadersMiddleware(client, options.headers);
|
||||
}
|
||||
const cacheRetention = resolveCacheRetention(options.cacheRetention);
|
||||
const cacheRetention = resolveCacheRetention(options.cacheRetention, options.env);
|
||||
const inferenceMaxTokens = options.maxTokens ?? (isAnthropicClaudeModel(model) ? model.maxTokens : undefined);
|
||||
let commandInput = {
|
||||
modelId: model.id,
|
||||
messages: convertMessages(context, model, cacheRetention),
|
||||
system: buildSystemPrompt(context.systemPrompt, model, cacheRetention),
|
||||
messages: convertMessages(context, model, cacheRetention, options.env),
|
||||
system: buildSystemPrompt(context.systemPrompt, model, cacheRetention, options.env),
|
||||
inferenceConfig: {
|
||||
...(inferenceMaxTokens !== undefined && { maxTokens: inferenceMaxTokens }),
|
||||
...(options.temperature !== undefined && { temperature: options.temperature }),
|
||||
@@ -578,11 +593,11 @@ function mapThinkingLevelToEffort(
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention, env?: ProviderEnv): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
if (getProviderEnvValue("PI_CACHE_RETENTION", env) === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
@@ -617,14 +632,14 @@ function isAnthropicClaudeModel(model: Model<"bedrock-converse-stream">): boolea
|
||||
* As a last resort, set AWS_BEDROCK_FORCE_CACHE=1 to enable cache points.
|
||||
* Amazon Nova models have automatic caching and don't need explicit cache points.
|
||||
*/
|
||||
function supportsPromptCaching(model: Model<"bedrock-converse-stream">): boolean {
|
||||
function supportsPromptCaching(model: Model<"bedrock-converse-stream">, env?: ProviderEnv): boolean {
|
||||
const candidates = getModelMatchCandidates(model.id, model.name);
|
||||
|
||||
const hasClaudeRef = candidates.some((s) => s.includes("claude"));
|
||||
if (!hasClaudeRef) {
|
||||
// Application inference profiles don't contain the model name in the ARN.
|
||||
// Allow users to force cache points via environment variable.
|
||||
if (typeof process !== "undefined" && process.env.AWS_BEDROCK_FORCE_CACHE === "1") return true;
|
||||
if (getProviderEnvValue("AWS_BEDROCK_FORCE_CACHE", env) === "1") return true;
|
||||
return false;
|
||||
}
|
||||
// Claude 4.x models (opus-4, sonnet-4, haiku-4)
|
||||
@@ -652,13 +667,14 @@ function buildSystemPrompt(
|
||||
systemPrompt: string | undefined,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
cacheRetention: CacheRetention,
|
||||
env?: ProviderEnv,
|
||||
): SystemContentBlock[] | undefined {
|
||||
if (!systemPrompt) return undefined;
|
||||
|
||||
const blocks: SystemContentBlock[] = [{ text: sanitizeSurrogates(systemPrompt) }];
|
||||
|
||||
// Add cache point for supported Claude models when caching is enabled
|
||||
if (cacheRetention !== "none" && supportsPromptCaching(model)) {
|
||||
if (cacheRetention !== "none" && supportsPromptCaching(model, env)) {
|
||||
blocks.push({
|
||||
cachePoint: { type: CachePointType.DEFAULT, ...(cacheRetention === "long" ? { ttl: CacheTTL.ONE_HOUR } : {}) },
|
||||
});
|
||||
@@ -699,6 +715,7 @@ function convertMessages(
|
||||
context: Context,
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
cacheRetention: CacheRetention,
|
||||
env?: ProviderEnv,
|
||||
): Message[] {
|
||||
const result: Message[] = [];
|
||||
const transformedMessages = transformMessages(context.messages, model, normalizeToolCallId);
|
||||
@@ -844,7 +861,7 @@ function convertMessages(
|
||||
}
|
||||
|
||||
// Add cache point to the last user message for supported Claude models when caching is enabled
|
||||
if (cacheRetention !== "none" && supportsPromptCaching(model) && result.length > 0) {
|
||||
if (cacheRetention !== "none" && supportsPromptCaching(model, env) && result.length > 0) {
|
||||
const lastMessage = result[result.length - 1];
|
||||
if (lastMessage.role === ConversationRole.USER && lastMessage.content) {
|
||||
(lastMessage.content as ContentBlock[]).push({
|
||||
@@ -906,19 +923,26 @@ function mapStopReason(reason: string | undefined): StopReason {
|
||||
}
|
||||
|
||||
function getConfiguredBedrockRegion(options: BedrockOptions): string | undefined {
|
||||
if (typeof process === "undefined") {
|
||||
return options.region;
|
||||
}
|
||||
|
||||
return options.region || process.env.AWS_REGION || process.env.AWS_DEFAULT_REGION || undefined;
|
||||
return (
|
||||
options.region ||
|
||||
getProviderEnvValue("AWS_REGION", options.env) ||
|
||||
getProviderEnvValue("AWS_DEFAULT_REGION", options.env) ||
|
||||
undefined
|
||||
);
|
||||
}
|
||||
|
||||
function hasConfiguredBedrockProfile(): boolean {
|
||||
if (typeof process === "undefined") {
|
||||
return false;
|
||||
function getConfiguredBedrockCredentials(env?: ProviderEnv): BedrockRuntimeClientConfig["credentials"] | undefined {
|
||||
const accessKeyId = getProviderEnvValue("AWS_ACCESS_KEY_ID", env);
|
||||
const secretAccessKey = getProviderEnvValue("AWS_SECRET_ACCESS_KEY", env);
|
||||
if (!accessKeyId || !secretAccessKey) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return Boolean(process.env.AWS_PROFILE);
|
||||
const sessionToken = getProviderEnvValue("AWS_SESSION_TOKEN", env);
|
||||
return {
|
||||
accessKeyId,
|
||||
secretAccessKey,
|
||||
...(sessionToken ? { sessionToken } : {}),
|
||||
};
|
||||
}
|
||||
|
||||
function getStandardBedrockEndpointRegion(baseUrl: string | undefined): string | undefined {
|
||||
@@ -938,14 +962,14 @@ function getStandardBedrockEndpointRegion(baseUrl: string | undefined): string |
|
||||
function shouldUseExplicitBedrockEndpoint(
|
||||
baseUrl: string,
|
||||
configuredRegion: string | undefined,
|
||||
hasConfiguredProfile: boolean,
|
||||
hasAmbientConfiguredProfile: boolean,
|
||||
): boolean {
|
||||
const endpointRegion = getStandardBedrockEndpointRegion(baseUrl);
|
||||
if (!endpointRegion) {
|
||||
return true;
|
||||
}
|
||||
|
||||
return !configuredRegion && !hasConfiguredProfile;
|
||||
return !configuredRegion && !hasAmbientConfiguredProfile;
|
||||
}
|
||||
|
||||
function isGovCloudBedrockTarget(model: Model<"bedrock-converse-stream">, options: BedrockOptions): boolean {
|
||||
|
||||
@@ -17,6 +17,7 @@ import type {
|
||||
ImageContent,
|
||||
Message,
|
||||
Model,
|
||||
ProviderEnv,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
@@ -30,6 +31,7 @@ import type {
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
|
||||
import { headersToRecord } from "../utils/headers.ts";
|
||||
import { parseJsonWithRepair, parseStreamingJson } from "../utils/json-parse.ts";
|
||||
import { getProviderEnvValue } from "../utils/provider-env.ts";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.ts";
|
||||
|
||||
import { resolveCloudflareBaseUrl } from "./cloudflare.ts";
|
||||
@@ -41,11 +43,11 @@ import { transformMessages } from "./transform-messages.ts";
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention, env?: ProviderEnv): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
if (getProviderEnvValue("PI_CACHE_RETENTION", env) === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
@@ -54,8 +56,9 @@ function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention
|
||||
function getCacheControl(
|
||||
model: Model<"anthropic-messages">,
|
||||
cacheRetention?: CacheRetention,
|
||||
env?: ProviderEnv,
|
||||
): { retention: CacheRetention; cacheControl?: CacheControlEphemeral } {
|
||||
const retention = resolveCacheRetention(cacheRetention);
|
||||
const retention = resolveCacheRetention(cacheRetention, env);
|
||||
if (retention === "none") {
|
||||
return { retention };
|
||||
}
|
||||
@@ -494,7 +497,7 @@ export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOpti
|
||||
});
|
||||
}
|
||||
|
||||
const cacheRetention = options?.cacheRetention ?? resolveCacheRetention();
|
||||
const cacheRetention = resolveCacheRetention(options?.cacheRetention, options?.env);
|
||||
const cacheSessionId = cacheRetention === "none" ? undefined : options?.sessionId;
|
||||
|
||||
const created = createClient(
|
||||
@@ -505,6 +508,7 @@ export const streamAnthropic: StreamFunction<"anthropic-messages", AnthropicOpti
|
||||
options?.headers,
|
||||
copilotDynamicHeaders,
|
||||
cacheSessionId,
|
||||
options?.env,
|
||||
);
|
||||
client = created.client;
|
||||
isOAuth = created.isOAuthToken;
|
||||
@@ -794,6 +798,7 @@ function createClient(
|
||||
optionsHeaders?: Record<string, string>,
|
||||
dynamicHeaders?: Record<string, string>,
|
||||
sessionId?: string,
|
||||
env?: ProviderEnv,
|
||||
): { client: Anthropic; isOAuthToken: boolean } {
|
||||
// Adaptive thinking models have interleaved thinking built in, so skip the beta header.
|
||||
const needsInterleavedBeta = interleavedThinking && model.compat?.forceAdaptiveThinking !== true;
|
||||
@@ -809,7 +814,7 @@ function createClient(
|
||||
const client = new Anthropic({
|
||||
apiKey: null,
|
||||
authToken: null,
|
||||
baseURL: resolveCloudflareBaseUrl(model),
|
||||
baseURL: resolveCloudflareBaseUrl(model, env),
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders: mergeHeaders(
|
||||
{
|
||||
@@ -902,7 +907,7 @@ function buildParams(
|
||||
isOAuthToken: boolean,
|
||||
options?: AnthropicOptions,
|
||||
): MessageCreateParamsStreaming {
|
||||
const { cacheControl } = getCacheControl(model, options?.cacheRetention);
|
||||
const { cacheControl } = getCacheControl(model, options?.cacheRetention, options?.env);
|
||||
const compat = getAnthropicCompat(model);
|
||||
const params: MessageCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
|
||||
@@ -12,6 +12,7 @@ import type {
|
||||
} from "../types.ts";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
|
||||
import { headersToRecord } from "../utils/headers.ts";
|
||||
import { getProviderEnvValue } from "../utils/provider-env.ts";
|
||||
import { clampOpenAIPromptCacheKey } from "./openai-prompt-cache.ts";
|
||||
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.ts";
|
||||
import { buildBaseOptions } from "./simple-options.ts";
|
||||
@@ -36,7 +37,9 @@ function resolveDeploymentName(model: Model<"azure-openai-responses">, options?:
|
||||
if (options?.azureDeploymentName) {
|
||||
return options.azureDeploymentName;
|
||||
}
|
||||
const mappedDeployment = parseDeploymentNameMap(process.env.AZURE_OPENAI_DEPLOYMENT_NAME_MAP).get(model.id);
|
||||
const mappedDeployment = parseDeploymentNameMap(
|
||||
getProviderEnvValue("AZURE_OPENAI_DEPLOYMENT_NAME_MAP", options?.env),
|
||||
).get(model.id);
|
||||
return mappedDeployment || model.id;
|
||||
}
|
||||
|
||||
@@ -198,10 +201,14 @@ function resolveAzureConfig(
|
||||
model: Model<"azure-openai-responses">,
|
||||
options?: AzureOpenAIResponsesOptions,
|
||||
): { baseUrl: string; apiVersion: string } {
|
||||
const apiVersion = options?.azureApiVersion || process.env.AZURE_OPENAI_API_VERSION || DEFAULT_AZURE_API_VERSION;
|
||||
const apiVersion =
|
||||
options?.azureApiVersion ||
|
||||
getProviderEnvValue("AZURE_OPENAI_API_VERSION", options?.env) ||
|
||||
DEFAULT_AZURE_API_VERSION;
|
||||
|
||||
const baseUrl = options?.azureBaseUrl?.trim() || process.env.AZURE_OPENAI_BASE_URL?.trim() || undefined;
|
||||
const resourceName = options?.azureResourceName || process.env.AZURE_OPENAI_RESOURCE_NAME;
|
||||
const baseUrl =
|
||||
options?.azureBaseUrl?.trim() || getProviderEnvValue("AZURE_OPENAI_BASE_URL", options?.env)?.trim() || undefined;
|
||||
const resourceName = options?.azureResourceName || getProviderEnvValue("AZURE_OPENAI_RESOURCE_NAME", options?.env);
|
||||
|
||||
let resolvedBaseUrl = baseUrl;
|
||||
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import type { Api, Model } from "../types.ts";
|
||||
import type { Api, Model, ProviderEnv } from "../types.ts";
|
||||
import { getProviderEnvValue } from "../utils/provider-env.ts";
|
||||
|
||||
/** Workers AI direct endpoint. */
|
||||
export const CLOUDFLARE_WORKERS_AI_BASE_URL =
|
||||
@@ -20,12 +21,12 @@ export function isCloudflareProvider(provider: string): boolean {
|
||||
return provider === "cloudflare-workers-ai" || provider === "cloudflare-ai-gateway";
|
||||
}
|
||||
|
||||
/** Substitute `{VAR}` placeholders in a Cloudflare baseUrl from process.env. */
|
||||
export function resolveCloudflareBaseUrl(model: Model<Api>): string {
|
||||
/** Substitute `{VAR}` placeholders in a Cloudflare baseUrl from provider env or process.env. */
|
||||
export function resolveCloudflareBaseUrl(model: Model<Api>, env?: ProviderEnv): string {
|
||||
const url = model.baseUrl;
|
||||
if (!url.includes("{")) return url;
|
||||
const baseUrl = url.replace(/\{([A-Z_][A-Z0-9_]*)\}/g, (_match, name: string) => {
|
||||
const value = process.env[name];
|
||||
const value = getProviderEnvValue(name, env);
|
||||
if (!value) {
|
||||
throw new Error(`${name} is required for provider ${model.provider} but is not set.`);
|
||||
}
|
||||
|
||||
@@ -14,6 +14,7 @@ import type {
|
||||
Context,
|
||||
Model,
|
||||
ThinkingLevel as PiThinkingLevel,
|
||||
ProviderEnv,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
@@ -23,6 +24,7 @@ import type {
|
||||
ToolCall,
|
||||
} from "../types.ts";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
|
||||
import { getProviderEnvValue } from "../utils/provider-env.ts";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.ts";
|
||||
import type { GoogleThinkingLevel } from "./google-shared.ts";
|
||||
import {
|
||||
@@ -91,7 +93,7 @@ export const streamGoogleVertex: StreamFunction<"google-vertex", GoogleVertexOpt
|
||||
// Create the client using either a Vertex API key, if provided, or ADC with project and location
|
||||
const client = apiKey
|
||||
? createClientWithApiKey(model, apiKey, options?.headers)
|
||||
: createClient(model, resolveProject(options), resolveLocation(options), options?.headers);
|
||||
: createClient(model, resolveProject(options), resolveLocation(options), options?.headers, options?.env);
|
||||
let params = buildParams(model, context, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
@@ -333,12 +335,15 @@ function createClient(
|
||||
project: string,
|
||||
location: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
env?: ProviderEnv,
|
||||
): GoogleGenAI {
|
||||
const googleAuthOptions = buildGoogleAuthOptions(env);
|
||||
return new GoogleGenAI({
|
||||
vertexai: true,
|
||||
project,
|
||||
location,
|
||||
apiVersion: API_VERSION,
|
||||
...(googleAuthOptions ? { googleAuthOptions } : {}),
|
||||
httpOptions: buildHttpOptions(model, optionsHeaders),
|
||||
});
|
||||
}
|
||||
@@ -394,6 +399,11 @@ function baseUrlIncludesApiVersion(baseUrl: string): boolean {
|
||||
}
|
||||
}
|
||||
|
||||
/**
 * Build the auth options object passed to the Vertex client as `googleAuthOptions`.
 *
 * Looks up `GOOGLE_APPLICATION_CREDENTIALS` through `getProviderEnvValue`, handing it the
 * optional provider-scoped `env`.
 *
 * @param env - Optional provider-scoped environment overrides.
 * @returns `{ keyFilename }` when the lookup yields a non-empty value, otherwise `undefined`
 * so the caller can omit the option entirely.
 */
function buildGoogleAuthOptions(env?: ProviderEnv): { keyFilename: string } | undefined {
	const credentialsPath = getProviderEnvValue("GOOGLE_APPLICATION_CREDENTIALS", env);
	if (!credentialsPath) {
		return undefined;
	}
	return { keyFilename: credentialsPath };
}
|
||||
|
||||
function resolveApiKey(options?: GoogleVertexOptions): string | undefined {
|
||||
const apiKey = options?.apiKey?.trim();
|
||||
if (!apiKey || apiKey === GCP_VERTEX_CREDENTIALS_MARKER || isPlaceholderApiKey(apiKey)) {
|
||||
@@ -407,7 +417,10 @@ function isPlaceholderApiKey(apiKey: string): boolean {
|
||||
}
|
||||
|
||||
function resolveProject(options?: GoogleVertexOptions): string {
|
||||
const project = options?.project || process.env.GOOGLE_CLOUD_PROJECT || process.env.GCLOUD_PROJECT;
|
||||
const project =
|
||||
options?.project ||
|
||||
getProviderEnvValue("GOOGLE_CLOUD_PROJECT", options?.env) ||
|
||||
getProviderEnvValue("GCLOUD_PROJECT", options?.env);
|
||||
if (!project) {
|
||||
throw new Error(
|
||||
"Vertex AI requires a project ID. Set GOOGLE_CLOUD_PROJECT/GCLOUD_PROJECT or pass project in options.",
|
||||
@@ -417,7 +430,7 @@ function resolveProject(options?: GoogleVertexOptions): string {
|
||||
}
|
||||
|
||||
function resolveLocation(options?: GoogleVertexOptions): string {
|
||||
const location = options?.location || process.env.GOOGLE_CLOUD_LOCATION;
|
||||
const location = options?.location || getProviderEnvValue("GOOGLE_CLOUD_LOCATION", options?.env);
|
||||
if (!location) {
|
||||
throw new Error("Vertex AI requires a location. Set GOOGLE_CLOUD_LOCATION or pass location in options.");
|
||||
}
|
||||
|
||||
@@ -27,6 +27,7 @@ import type {
|
||||
AssistantMessage,
|
||||
Context,
|
||||
Model,
|
||||
ProviderEnv,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
@@ -40,6 +41,7 @@ import {
|
||||
} from "../utils/diagnostics.ts";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
|
||||
import { headersToRecord } from "../utils/headers.ts";
|
||||
import { resolveHttpProxyUrlForTarget } from "../utils/node-http-proxy.ts";
|
||||
import { clampOpenAIPromptCacheKey } from "./openai-prompt-cache.ts";
|
||||
import { convertResponsesMessages, convertResponsesTools, processResponsesStream } from "./openai-responses-shared.ts";
|
||||
import { buildBaseOptions } from "./simple-options.ts";
|
||||
@@ -814,19 +816,13 @@ type WebSocketConstructor = new (
|
||||
) => WebSocketLike;
|
||||
|
||||
let _cachedWebsocket: WebSocketConstructor | null = null;
|
||||
async function getWebSocketConstructor(): Promise<WebSocketConstructor | null> {
|
||||
if (_cachedWebsocket) return _cachedWebsocket;
|
||||
async function getWebSocketConstructor(env?: ProviderEnv): Promise<WebSocketConstructor | null> {
|
||||
if (!env && _cachedWebsocket) return _cachedWebsocket;
|
||||
|
||||
// bun doesn't respect http proxy envs, ref: https://github.com/oven-sh/bun/issues/15489
|
||||
// TODO: remove this when bun supports proxy envs in websocket.
|
||||
if (
|
||||
process?.versions?.bun &&
|
||||
(process.env.HTTP_PROXY || process.env.HTTPS_PROXY || process.env.http_proxy || process.env.https_proxy)
|
||||
) {
|
||||
const m = await dynamicImport("proxy-from-env");
|
||||
const getProxyForUrl = (m as { getProxyForUrl: (url: string | object | URL) => string }).getProxyForUrl;
|
||||
|
||||
_cachedWebsocket = class extends WebSocket {
|
||||
if (typeof process !== "undefined" && process.versions?.bun) {
|
||||
const WebSocketWithProxy = class extends WebSocket {
|
||||
constructor(url: string | URL, options?: string | string[] | Record<string, unknown>) {
|
||||
let _opts: Record<string, unknown> = {};
|
||||
if (Array.isArray(options) || typeof options === "string") {
|
||||
@@ -835,11 +831,17 @@ async function getWebSocketConstructor(): Promise<WebSocketConstructor | null> {
|
||||
_opts = { ...options };
|
||||
}
|
||||
|
||||
const proxy = getProxyForUrl(url.toString().replace(/^wss:/, "https:").replace(/^ws:/, "http:"));
|
||||
super(url, { ..._opts, ...(proxy ? { proxy } : {}) } as any);
|
||||
const proxyUrl = resolveHttpProxyUrlForTarget(
|
||||
url.toString().replace(/^wss:/, "https:").replace(/^ws:/, "http:"),
|
||||
env,
|
||||
);
|
||||
super(url, { ..._opts, ...(proxyUrl ? { proxy: proxyUrl.toString() } : {}) } as any);
|
||||
}
|
||||
};
|
||||
return _cachedWebsocket;
|
||||
if (!env) {
|
||||
_cachedWebsocket = WebSocketWithProxy;
|
||||
}
|
||||
return WebSocketWithProxy;
|
||||
}
|
||||
|
||||
const ctor = (globalThis as { WebSocket?: unknown }).WebSocket;
|
||||
@@ -894,8 +896,9 @@ async function connectWebSocket(
|
||||
headers: Headers,
|
||||
signal?: AbortSignal,
|
||||
connectTimeoutMs = DEFAULT_WEBSOCKET_CONNECT_TIMEOUT_MS,
|
||||
env?: ProviderEnv,
|
||||
): Promise<WebSocketLike> {
|
||||
const WebSocketCtor = await getWebSocketConstructor();
|
||||
const WebSocketCtor = await getWebSocketConstructor(env);
|
||||
if (!WebSocketCtor) {
|
||||
throw new Error("WebSocket transport is not available in this runtime");
|
||||
}
|
||||
@@ -972,6 +975,7 @@ async function acquireWebSocket(
|
||||
sessionId: string | undefined,
|
||||
signal?: AbortSignal,
|
||||
connectTimeoutMs?: number,
|
||||
env?: ProviderEnv,
|
||||
): Promise<{
|
||||
socket: WebSocketLike;
|
||||
entry?: CachedWebSocketConnection;
|
||||
@@ -979,7 +983,7 @@ async function acquireWebSocket(
|
||||
release: (options?: { keep?: boolean }) => void;
|
||||
}> {
|
||||
if (!sessionId) {
|
||||
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs);
|
||||
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs, env);
|
||||
return {
|
||||
socket,
|
||||
reused: false,
|
||||
@@ -1011,7 +1015,7 @@ async function acquireWebSocket(
|
||||
};
|
||||
}
|
||||
if (cached.busy) {
|
||||
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs);
|
||||
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs, env);
|
||||
return {
|
||||
socket,
|
||||
reused: false,
|
||||
@@ -1026,7 +1030,7 @@ async function acquireWebSocket(
|
||||
}
|
||||
}
|
||||
|
||||
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs);
|
||||
const socket = await connectWebSocket(url, headers, signal, connectTimeoutMs, env);
|
||||
const entry: CachedWebSocketConnection = { socket, busy: true };
|
||||
websocketSessionCache.set(sessionId, entry);
|
||||
return {
|
||||
@@ -1312,6 +1316,7 @@ async function processWebSocketStream(
|
||||
options?.sessionId,
|
||||
options?.signal,
|
||||
websocketConnectTimeoutMs,
|
||||
options?.env,
|
||||
);
|
||||
let keepConnection = true;
|
||||
const useCachedContext = options?.transport === "websocket-cached" || options?.transport === "auto";
|
||||
|
||||
@@ -19,6 +19,7 @@ import type {
|
||||
Message,
|
||||
Model,
|
||||
OpenAICompletionsCompat,
|
||||
ProviderEnv,
|
||||
SimpleStreamOptions,
|
||||
StopReason,
|
||||
StreamFunction,
|
||||
@@ -32,6 +33,7 @@ import type {
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
|
||||
import { headersToRecord } from "../utils/headers.ts";
|
||||
import { parseStreamingJson } from "../utils/json-parse.ts";
|
||||
import { getProviderEnvValue } from "../utils/provider-env.ts";
|
||||
import { sanitizeSurrogates } from "../utils/sanitize-unicode.ts";
|
||||
import { isCloudflareProvider, resolveCloudflareBaseUrl } from "./cloudflare.ts";
|
||||
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.ts";
|
||||
@@ -98,11 +100,11 @@ type ChatCompletionToolWithCacheControl = OpenAI.Chat.Completions.ChatCompletion
|
||||
cache_control?: OpenAICompatCacheControl;
|
||||
};
|
||||
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention, env?: ProviderEnv): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
if (getProviderEnvValue("PI_CACHE_RETENTION", env) === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
@@ -140,9 +142,9 @@ export const streamOpenAICompletions: StreamFunction<"openai-completions", OpenA
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
const compat = getCompat(model);
|
||||
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
|
||||
const cacheRetention = resolveCacheRetention(options?.cacheRetention, options?.env);
|
||||
const cacheSessionId = cacheRetention === "none" ? undefined : options?.sessionId;
|
||||
const client = createClient(model, context, apiKey, options?.headers, cacheSessionId, compat);
|
||||
const client = createClient(model, context, apiKey, options?.headers, cacheSessionId, compat, options?.env);
|
||||
let params = buildParams(model, context, options, compat, cacheRetention);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
@@ -454,6 +456,7 @@ function createClient(
|
||||
optionsHeaders?: Record<string, string>,
|
||||
sessionId?: string,
|
||||
compat: ResolvedOpenAICompletionsCompat = getCompat(model),
|
||||
env?: ProviderEnv,
|
||||
) {
|
||||
const headers = { ...model.headers };
|
||||
if (model.provider === "github-copilot") {
|
||||
@@ -487,7 +490,7 @@ function createClient(
|
||||
|
||||
return new OpenAI({
|
||||
apiKey,
|
||||
baseURL: isCloudflareProvider(model.provider) ? resolveCloudflareBaseUrl(model) : model.baseUrl,
|
||||
baseURL: isCloudflareProvider(model.provider) ? resolveCloudflareBaseUrl(model, env) : model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders,
|
||||
});
|
||||
@@ -498,7 +501,7 @@ function buildParams(
|
||||
context: Context,
|
||||
options?: OpenAICompletionsOptions,
|
||||
compat: ResolvedOpenAICompletionsCompat = getCompat(model),
|
||||
cacheRetention: CacheRetention = resolveCacheRetention(options?.cacheRetention),
|
||||
cacheRetention: CacheRetention = resolveCacheRetention(options?.cacheRetention, options?.env),
|
||||
) {
|
||||
const messages = convertMessages(model, context, compat);
|
||||
const cacheControl = getCompatCacheControl(compat, cacheRetention);
|
||||
|
||||
@@ -8,6 +8,7 @@ import type {
|
||||
Context,
|
||||
Model,
|
||||
OpenAIResponsesCompat,
|
||||
ProviderEnv,
|
||||
SimpleStreamOptions,
|
||||
StreamFunction,
|
||||
StreamOptions,
|
||||
@@ -15,6 +16,7 @@ import type {
|
||||
} from "../types.ts";
|
||||
import { AssistantMessageEventStream } from "../utils/event-stream.ts";
|
||||
import { headersToRecord } from "../utils/headers.ts";
|
||||
import { getProviderEnvValue } from "../utils/provider-env.ts";
|
||||
import { isCloudflareProvider, resolveCloudflareBaseUrl } from "./cloudflare.ts";
|
||||
import { buildCopilotDynamicHeaders, hasCopilotVisionInput } from "./github-copilot-headers.ts";
|
||||
import { clampOpenAIPromptCacheKey } from "./openai-prompt-cache.ts";
|
||||
@@ -27,11 +29,11 @@ const OPENAI_TOOL_CALL_PROVIDERS = new Set(["openai", "openai-codex", "opencode"
|
||||
* Resolve cache retention preference.
|
||||
* Defaults to "short" and uses PI_CACHE_RETENTION for backward compatibility.
|
||||
*/
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention): CacheRetention {
|
||||
function resolveCacheRetention(cacheRetention?: CacheRetention, env?: ProviderEnv): CacheRetention {
|
||||
if (cacheRetention) {
|
||||
return cacheRetention;
|
||||
}
|
||||
if (typeof process !== "undefined" && process.env.PI_CACHE_RETENTION === "long") {
|
||||
if (getProviderEnvValue("PI_CACHE_RETENTION", env) === "long") {
|
||||
return "long";
|
||||
}
|
||||
return "short";
|
||||
@@ -111,9 +113,9 @@ export const streamOpenAIResponses: StreamFunction<"openai-responses", OpenAIRes
|
||||
if (!apiKey) {
|
||||
throw new Error(`No API key for provider: ${model.provider}`);
|
||||
}
|
||||
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
|
||||
const cacheRetention = resolveCacheRetention(options?.cacheRetention, options?.env);
|
||||
const cacheSessionId = cacheRetention === "none" ? undefined : options?.sessionId;
|
||||
const client = createClient(model, context, apiKey, options?.headers, cacheSessionId);
|
||||
const client = createClient(model, context, apiKey, options?.headers, cacheSessionId, options?.env);
|
||||
let params = buildParams(model, context, options);
|
||||
const nextParams = await options?.onPayload?.(params, model);
|
||||
if (nextParams !== undefined) {
|
||||
@@ -185,6 +187,7 @@ function createClient(
|
||||
apiKey: string,
|
||||
optionsHeaders?: Record<string, string>,
|
||||
sessionId?: string,
|
||||
env?: ProviderEnv,
|
||||
) {
|
||||
const compat = getCompat(model);
|
||||
const headers = { ...model.headers };
|
||||
@@ -220,7 +223,7 @@ function createClient(
|
||||
|
||||
return new OpenAI({
|
||||
apiKey,
|
||||
baseURL: isCloudflareProvider(model.provider) ? resolveCloudflareBaseUrl(model) : model.baseUrl,
|
||||
baseURL: isCloudflareProvider(model.provider) ? resolveCloudflareBaseUrl(model, env) : model.baseUrl,
|
||||
dangerouslyAllowBrowser: true,
|
||||
defaultHeaders,
|
||||
});
|
||||
@@ -229,7 +232,7 @@ function createClient(
|
||||
function buildParams(model: Model<"openai-responses">, context: Context, options?: OpenAIResponsesOptions) {
|
||||
const messages = convertResponsesMessages(model, context, OPENAI_TOOL_CALL_PROVIDERS);
|
||||
|
||||
const cacheRetention = resolveCacheRetention(options?.cacheRetention);
|
||||
const cacheRetention = resolveCacheRetention(options?.cacheRetention, options?.env);
|
||||
const compat = getCompat(model);
|
||||
const params: ResponseCreateParamsStreaming = {
|
||||
model: model.id,
|
||||
|
||||
@@ -17,6 +17,7 @@ export function buildBaseOptions(_model: Model<Api>, options?: SimpleStreamOptio
|
||||
maxRetries: options?.maxRetries,
|
||||
maxRetryDelayMs: options?.maxRetryDelayMs,
|
||||
metadata: options?.metadata,
|
||||
env: options?.env,
|
||||
};
|
||||
}
|
||||
|
||||
|
||||
@@ -24,7 +24,7 @@ function withEnvApiKey<TOptions extends StreamOptions>(
|
||||
options: TOptions | undefined,
|
||||
): TOptions | undefined {
|
||||
if (hasExplicitApiKey(options?.apiKey)) return options;
|
||||
const apiKey = getEnvApiKey(model.provider);
|
||||
const apiKey = getEnvApiKey(model.provider, options?.env);
|
||||
if (!apiKey) return options;
|
||||
return { ...options, apiKey } as TOptions;
|
||||
}
|
||||
|
||||
@@ -79,6 +79,9 @@ export type CacheRetention = "none" | "short" | "long";
|
||||
|
||||
export type Transport = "sse" | "websocket" | "websocket-cached" | "auto";
|
||||
|
||||
/** Provider-scoped environment overrides. Values take precedence over process.env. */
|
||||
export type ProviderEnv = Record<string, string>;
|
||||
|
||||
export interface ProviderResponse {
|
||||
status: number;
|
||||
headers: Record<string, string>;
|
||||
@@ -153,6 +156,12 @@ export interface StreamOptions {
|
||||
* For example, Anthropic uses `user_id` for abuse tracking and rate limiting.
|
||||
*/
|
||||
metadata?: Record<string, unknown>;
|
||||
/**
|
||||
* Provider-scoped environment values. These take precedence over process.env for
|
||||
* provider configuration such as regional settings, endpoint placeholders, and
|
||||
* proxy variables.
|
||||
*/
|
||||
env?: ProviderEnv;
|
||||
}
|
||||
|
||||
export type ProviderStreamOptions = StreamOptions & Record<string, unknown>;
|
||||
|
||||
@@ -1,7 +1,5 @@
|
||||
import type { Agent as HttpAgent } from "node:http";
|
||||
import type { Agent as HttpsAgent } from "node:https";
|
||||
import { HttpProxyAgent } from "http-proxy-agent";
|
||||
import { HttpsProxyAgent } from "https-proxy-agent";
|
||||
import type { ProviderEnv } from "../types.ts";
|
||||
import { getProviderEnvValue } from "./provider-env.ts";
|
||||
|
||||
const DEFAULT_PROXY_PORTS: Record<string, number> = {
|
||||
ftp: 21,
|
||||
@@ -12,16 +10,16 @@ const DEFAULT_PROXY_PORTS: Record<string, number> = {
|
||||
wss: 443,
|
||||
};
|
||||
|
||||
export interface NodeHttpProxyAgents {
|
||||
httpAgent: HttpAgent;
|
||||
httpsAgent: HttpsAgent;
|
||||
}
|
||||
|
||||
export const UNSUPPORTED_PROXY_PROTOCOL_MESSAGE =
|
||||
"Unsupported proxy protocol. SOCKS and PAC proxy URLs are not supported; use an HTTP or HTTPS proxy URL.";
|
||||
|
||||
function getProxyEnv(key: string): string {
|
||||
return process.env[key.toLowerCase()] || process.env[key.toUpperCase()] || "";
|
||||
/**
 * Look up a proxy-related variable under both its lowercase and uppercase
 * spelling.
 *
 * Lookup order (first non-empty value wins):
 *   1. scoped `env`, lowercase key
 *   2. scoped `env`, uppercase key
 *   3. getProviderEnvValue(lowercase key)
 *   4. getProviderEnvValue(uppercase key)
 *
 * @param key Variable name in any casing, e.g. "https_proxy" or "no_proxy".
 * @param env Optional provider-scoped environment overrides.
 * @returns The resolved value, or "" when no candidate is set.
 */
function getProxyEnv(key: string, env?: ProviderEnv): string {
	const aliases = [key.toLowerCase(), key.toUpperCase()];

	// Scoped overrides are checked for every alias before any fallback lookup.
	for (const alias of aliases) {
		const scoped = env?.[alias];
		if (scoped) {
			return scoped;
		}
	}

	for (const alias of aliases) {
		const fallback = getProviderEnvValue(alias);
		if (fallback) {
			return fallback;
		}
	}

	return "";
}
|
||||
|
||||
function parseProxyTargetUrl(targetUrl: string | URL): URL | undefined {
|
||||
@@ -36,8 +34,8 @@ function parseProxyTargetUrl(targetUrl: string | URL): URL | undefined {
|
||||
}
|
||||
}
|
||||
|
||||
function shouldProxyHostname(hostname: string, port: number): boolean {
|
||||
const noProxy = getProxyEnv("no_proxy").toLowerCase();
|
||||
function shouldProxyHostname(hostname: string, port: number, env?: ProviderEnv): boolean {
|
||||
const noProxy = getProxyEnv("no_proxy", env).toLowerCase();
|
||||
if (!noProxy) {
|
||||
return true;
|
||||
}
|
||||
@@ -68,7 +66,7 @@ function shouldProxyHostname(hostname: string, port: number): boolean {
|
||||
});
|
||||
}
|
||||
|
||||
function getProxyForUrl(targetUrl: string | URL): string {
|
||||
function getProxyForUrl(targetUrl: string | URL, env?: ProviderEnv): string {
|
||||
const parsedUrl = parseProxyTargetUrl(targetUrl);
|
||||
if (!parsedUrl?.protocol || !parsedUrl.host) {
|
||||
return "";
|
||||
@@ -77,19 +75,22 @@ function getProxyForUrl(targetUrl: string | URL): string {
|
||||
const protocol = parsedUrl.protocol.split(":", 1)[0]!;
|
||||
const hostname = parsedUrl.host.replace(/:\d*$/, "");
|
||||
const port = Number.parseInt(parsedUrl.port, 10) || DEFAULT_PROXY_PORTS[protocol] || 0;
|
||||
if (!shouldProxyHostname(hostname, port)) {
|
||||
if (!shouldProxyHostname(hostname, port, env)) {
|
||||
return "";
|
||||
}
|
||||
|
||||
let proxy = getProxyEnv(`${protocol}_proxy`) || getProxyEnv("all_proxy");
|
||||
let proxy = getProxyEnv(`${protocol}_proxy`, env) || getProxyEnv("all_proxy", env);
|
||||
if (proxy && !proxy.includes("://")) {
|
||||
proxy = `${protocol}://${proxy}`;
|
||||
}
|
||||
return proxy;
|
||||
}
|
||||
|
||||
export function resolveHttpProxyUrlForTarget(targetUrl: string | URL): URL | undefined {
|
||||
const proxy = getProxyForUrl(targetUrl);
|
||||
export const UNSUPPORTED_PROXY_PROTOCOL_MESSAGE =
|
||||
"Unsupported proxy protocol. SOCKS and PAC proxy URLs are not supported; use an HTTP or HTTPS proxy URL.";
|
||||
|
||||
export function resolveHttpProxyUrlForTarget(targetUrl: string | URL, env?: ProviderEnv): URL | undefined {
|
||||
const proxy = getProxyForUrl(targetUrl, env);
|
||||
if (!proxy) {
|
||||
return undefined;
|
||||
}
|
||||
@@ -109,15 +110,3 @@ export function resolveHttpProxyUrlForTarget(targetUrl: string | URL): URL | und
|
||||
|
||||
return proxyUrl;
|
||||
}
|
||||
|
||||
export function createHttpProxyAgentsForTarget(targetUrl: string | URL): NodeHttpProxyAgents | undefined {
|
||||
const proxyUrl = resolveHttpProxyUrlForTarget(targetUrl);
|
||||
if (!proxyUrl) {
|
||||
return undefined;
|
||||
}
|
||||
|
||||
return {
|
||||
httpAgent: new HttpProxyAgent(proxyUrl),
|
||||
httpsAgent: new HttpsProxyAgent(proxyUrl) as unknown as HttpsAgent,
|
||||
};
|
||||
}
|
||||
|
||||
@@ -6,6 +6,7 @@
|
||||
*/
|
||||
|
||||
import type { Server } from "node:http";
|
||||
import { getProviderEnvValue } from "../provider-env.ts";
|
||||
import { oauthErrorHtml, oauthSuccessHtml } from "./oauth-page.ts";
|
||||
import { generatePKCE } from "./pkce.ts";
|
||||
import type { OAuthCredentials, OAuthLoginCallbacks, OAuthPrompt, OAuthProviderInterface } from "./types.ts";
|
||||
@@ -28,7 +29,7 @@ const decode = (s: string) => atob(s);
|
||||
const CLIENT_ID = decode("OWQxYzI1MGEtZTYxYi00NGQ5LTg4ZWQtNTk0NGQxOTYyZjVl");
|
||||
const AUTHORIZE_URL = "https://claude.ai/oauth/authorize";
|
||||
const TOKEN_URL = "https://platform.claude.com/v1/oauth/token";
|
||||
const CALLBACK_HOST = process.env.PI_OAUTH_CALLBACK_HOST || "127.0.0.1";
|
||||
const CALLBACK_HOST = getProviderEnvValue("PI_OAUTH_CALLBACK_HOST") || "127.0.0.1";
|
||||
const CALLBACK_PORT = 53692;
|
||||
const CALLBACK_PATH = "/callback";
|
||||
const REDIRECT_URI = `http://localhost:${CALLBACK_PORT}${CALLBACK_PATH}`;
|
||||
|
||||
@@ -17,6 +17,7 @@ if (typeof process !== "undefined" && (process.versions?.node || process.version
|
||||
});
|
||||
}
|
||||
|
||||
import { getProviderEnvValue } from "../provider-env.ts";
|
||||
import { pollOAuthDeviceCodeFlow } from "./device-code.ts";
|
||||
import { oauthErrorHtml, oauthSuccessHtml } from "./oauth-page.ts";
|
||||
import { generatePKCE } from "./pkce.ts";
|
||||
@@ -47,7 +48,7 @@ type OAuthToken = { access: string; refresh: string; expires: number };
|
||||
type TokenOperation = "exchange" | "refresh";
|
||||
|
||||
function getCallbackHost(): string {
|
||||
return typeof process !== "undefined" ? process.env.PI_OAUTH_CALLBACK_HOST || "127.0.0.1" : "127.0.0.1";
|
||||
return getProviderEnvValue("PI_OAUTH_CALLBACK_HOST") || "127.0.0.1";
|
||||
}
|
||||
|
||||
type DeviceAuthInfo = {
|
||||
|
||||
@@ -0,0 +1,52 @@
|
||||
import type { ProviderEnv } from "../types.ts";
|
||||
|
||||
let procEnvCache: Map<string, string> | null = null;
|
||||
|
||||
/**
 * Workaround for https://github.com/oven-sh/bun/issues/27802: Bun compiled
 * binaries can see an empty process.env inside Linux sandboxes even though
 * /proc/self/environ still holds the environment.
 *
 * The lookup only runs under Bun and only while process.env has no keys. The
 * first such call reads and parses /proc/self/environ into a module-level
 * cache; if the read fails the cache stays empty and is not retried.
 *
 * This is a deliberate copy of restoreSandboxEnv() in
 * packages/coding-agent/src/bun/restore-sandbox-env.ts. The ai package can be
 * consumed directly, without that entrypoint, so provider env lookup must not
 * rely on process.env having been patched beforehand.
 *
 * @param name Environment variable name to look up.
 * @returns The value recovered from /proc/self/environ, or `undefined`.
 */
function getBunSandboxEnvValue(name: string): string | undefined {
	const runningUnderBun = typeof process !== "undefined" && Boolean(process.versions?.bun);
	if (!runningUnderBun || Object.keys(process.env).length !== 0) {
		return undefined;
	}

	let cache = procEnvCache;
	if (cache === null) {
		// Publish the (possibly empty) map before reading so a failed read is not retried.
		cache = new Map<string, string>();
		procEnvCache = cache;
		try {
			const { readFileSync: readFile } = require("node:fs") as {
				readFileSync(path: string, encoding: BufferEncoding): string;
			};
			const raw = readFile("/proc/self/environ", "utf-8");
			// Entries are NUL-separated "NAME=value" pairs; skip anything without a name.
			for (const pair of raw.split("\0")) {
				const separator = pair.indexOf("=");
				if (separator <= 0) {
					continue;
				}
				cache.set(pair.slice(0, separator), pair.slice(separator + 1));
			}
		} catch {
			// /proc/self/environ may not exist or may not be readable.
		}
	}

	return cache.get(name);
}
|
||||
|
||||
/**
 * Resolve a provider environment value.
 *
 * Sources are tried in order and the first non-empty value wins:
 *   1. the scoped `env` overrides,
 *   2. process.env (when a `process` global exists),
 *   3. the duplicated Bun sandbox fallback for direct pi-ai consumers.
 *
 * Empty strings count as "not set" at every step.
 *
 * @param name Environment variable name.
 * @param env Optional provider-scoped environment overrides.
 * @returns The resolved value, or `undefined` when no source provides one.
 */
export function getProviderEnvValue(name: string, env?: ProviderEnv): string | undefined {
	const scoped = env?.[name];
	if (scoped) {
		return scoped;
	}

	const ambient = typeof process !== "undefined" ? process.env[name] : undefined;
	if (ambient) {
		return ambient;
	}

	return getBunSandboxEnvValue(name) || undefined;
}
|
||||
@@ -45,7 +45,7 @@ vi.mock("@aws-sdk/client-bedrock-runtime", () => {
|
||||
});
|
||||
|
||||
import { getModel } from "../src/models.ts";
|
||||
import { streamBedrock } from "../src/providers/amazon-bedrock.ts";
|
||||
import { type BedrockOptions, streamBedrock } from "../src/providers/amazon-bedrock.ts";
|
||||
import type { Context, Model } from "../src/types.ts";
|
||||
|
||||
const context: Context = {
|
||||
@@ -83,8 +83,12 @@ afterEach(() => {
|
||||
}
|
||||
});
|
||||
|
||||
async function captureClientConfig(model: Model<"bedrock-converse-stream">): Promise<Record<string, unknown>> {
|
||||
await streamBedrock(model, context, { cacheRetention: "none" }).result();
|
||||
async function captureClientConfig(
|
||||
model: Model<"bedrock-converse-stream">,
|
||||
options: BedrockOptions = {},
|
||||
): Promise<Record<string, unknown>> {
|
||||
bedrockMock.constructorCalls.length = 0;
|
||||
await streamBedrock(model, context, { cacheRetention: "none", ...options }).result();
|
||||
expect(bedrockMock.constructorCalls).toHaveLength(1);
|
||||
return bedrockMock.constructorCalls[0];
|
||||
}
|
||||
@@ -115,6 +119,29 @@ describe("bedrock endpoint resolution", () => {
|
||||
expect(config.region).toBe("eu-central-1");
|
||||
});
|
||||
|
||||
it("handles missing regions for explicit, scoped, and ambient profiles", async () => {
|
||||
const model = getModel("amazon-bedrock", "eu.anthropic.claude-sonnet-4-5-20250929-v1:0");
|
||||
|
||||
let config = await captureClientConfig(model, { profile: "bedrock-profile" });
|
||||
|
||||
expect(config.profile).toBe("bedrock-profile");
|
||||
expect(config.endpoint).toBe("https://bedrock-runtime.eu-central-1.amazonaws.com");
|
||||
expect(config.region).toBe("eu-central-1");
|
||||
|
||||
config = await captureClientConfig(model, { env: { AWS_PROFILE: "scoped-bedrock-profile" } });
|
||||
|
||||
expect(config.profile).toBe("scoped-bedrock-profile");
|
||||
expect(config.endpoint).toBe("https://bedrock-runtime.eu-central-1.amazonaws.com");
|
||||
expect(config.region).toBe("eu-central-1");
|
||||
|
||||
process.env.AWS_PROFILE = "ambient-bedrock-profile";
|
||||
config = await captureClientConfig(model);
|
||||
|
||||
expect(config.profile).toBe("ambient-bedrock-profile");
|
||||
expect(config.endpoint).toBeUndefined();
|
||||
expect(config.region).toBeUndefined();
|
||||
});
|
||||
|
||||
it("still passes custom Bedrock endpoints through to the SDK client", async () => {
|
||||
process.env.AWS_REGION = "us-west-2";
|
||||
const baseModel = getModel("amazon-bedrock", "us.anthropic.claude-opus-4-8");
|
||||
|
||||
@@ -54,6 +54,17 @@ describe("node HTTP proxy resolution", () => {
|
||||
);
|
||||
});
|
||||
|
||||
it("prefers scoped proxy env aliases before process env aliases", () => {
|
||||
resetProxyEnv();
|
||||
process.env.https_proxy = "http://process-proxy.example:8080";
|
||||
|
||||
expect(
|
||||
resolveHttpProxyUrlForTarget("https://bedrock-runtime.us-east-1.amazonaws.com", {
|
||||
HTTPS_PROXY: "http://scoped-proxy.example:8080",
|
||||
})?.toString(),
|
||||
).toBe("http://scoped-proxy.example:8080/");
|
||||
});
|
||||
|
||||
it("rejects SOCKS and PAC proxy URLs explicitly", () => {
|
||||
resetProxyEnv();
|
||||
process.env.HTTPS_PROXY = "socks5://proxy.example:1080";
|
||||
|
||||
@@ -163,6 +163,31 @@ describe("openai-completions empty tools handling", () => {
|
||||
expect(clientOptions.defaultHeaders?.["cf-aig-authorization"]).toBe("Bearer test");
|
||||
});
|
||||
|
||||
it("uses provider env before process.env for Cloudflare AI Gateway base URL", async () => {
|
||||
process.env.CLOUDFLARE_ACCOUNT_ID = "process-account";
|
||||
process.env.CLOUDFLARE_GATEWAY_ID = "process-gateway";
|
||||
const model = getModel("cloudflare-ai-gateway", "workers-ai/@cf/moonshotai/kimi-k2.6")!;
|
||||
|
||||
await streamSimple(
|
||||
model,
|
||||
{
|
||||
messages: [{ role: "user", content: "hi", timestamp: Date.now() }],
|
||||
},
|
||||
{
|
||||
apiKey: "test",
|
||||
env: {
|
||||
CLOUDFLARE_ACCOUNT_ID: "provider-account",
|
||||
CLOUDFLARE_GATEWAY_ID: "provider-gateway",
|
||||
},
|
||||
},
|
||||
).result();
|
||||
|
||||
const clientOptions = mockState.lastClientOptions as { baseURL?: string };
|
||||
expect(clientOptions.baseURL).toBe(
|
||||
"https://gateway.ai.cloudflare.com/v1/provider-account/provider-gateway/compat",
|
||||
);
|
||||
});
|
||||
|
||||
it("preserves inline upstream Authorization for Cloudflare AI Gateway BYOK requests", async () => {
|
||||
process.env.CLOUDFLARE_ACCOUNT_ID = "account-id";
|
||||
process.env.CLOUDFLARE_GATEWAY_ID = "gateway-id";
|
||||
|
||||
@@ -2,6 +2,10 @@
|
||||
|
||||
## [Unreleased]
|
||||
|
||||
### Added
|
||||
|
||||
- Added `auth.json` API key `env` values so provider-specific environment overrides can be scoped to Pi and propagated to inherited provider configuration ([#5728](https://github.com/earendil-works/pi/issues/5728)).
|
||||
|
||||
### Fixed
|
||||
|
||||
- Fixed successful `pi update` on Windows to exit naturally instead of calling `process.exit(0)`, avoiding a Node.js/libuv assertion after version-check network requests ([#5805](https://github.com/earendil-works/pi/issues/5805)).
|
||||
|
||||
@@ -104,6 +104,24 @@ Store credentials in `~/.pi/agent/auth.json`:
|
||||
|
||||
The file is created with `0600` permissions (user read/write only). Auth file credentials take priority over environment variables.
|
||||
|
||||
API key credentials can also include provider-scoped environment values. These values are used before process environment variables when resolving the credential key, provider/model headers, and provider configuration such as Cloudflare account IDs, Azure OpenAI settings, Vertex project/location, Bedrock settings, `PI_CACHE_RETENTION`, and `HTTP_PROXY`/`HTTPS_PROXY`.
|
||||
|
||||
```json
|
||||
{
|
||||
"cloudflare-ai-gateway": {
|
||||
"type": "api_key",
|
||||
"key": "$CLOUDFLARE_API_KEY",
|
||||
"env": {
|
||||
"CLOUDFLARE_API_KEY": "...",
|
||||
"CLOUDFLARE_ACCOUNT_ID": "account-id",
|
||||
"CLOUDFLARE_GATEWAY_ID": "gateway-id"
|
||||
}
|
||||
}
|
||||
}
|
||||
```
|
||||
|
||||
Use this when pi should use different provider settings than the project shell environment.
|
||||
|
||||
### Key Resolution
|
||||
|
||||
The `key` field supports command execution, environment interpolation, and literals:
|
||||
@@ -194,7 +212,7 @@ export AWS_BEDROCK_FORCE_HTTP1=1
|
||||
|
||||
### Cloudflare AI Gateway
|
||||
|
||||
`CLOUDFLARE_API_KEY` can be set via `/login`. The account ID and gateway slug must be set as environment variables.
|
||||
`CLOUDFLARE_API_KEY` can be set via `/login`. The account ID and gateway slug can be set as environment variables or in the API key credential's `env` object in `auth.json`.
|
||||
|
||||
```bash
|
||||
export CLOUDFLARE_API_KEY=... # or use /login
|
||||
@@ -218,7 +236,7 @@ For normal pi usage, prefer unified billing or stored BYOK. Inline BYOK requires
|
||||
|
||||
### Cloudflare Workers AI
|
||||
|
||||
`CLOUDFLARE_API_KEY` can be set via `/login`. `CLOUDFLARE_ACCOUNT_ID` must be set as an environment variable.
|
||||
`CLOUDFLARE_API_KEY` can be set via `/login`. `CLOUDFLARE_ACCOUNT_ID` can be set as an environment variable or in the API key credential's `env` object in `auth.json`.
|
||||
|
||||
```bash
|
||||
export CLOUDFLARE_API_KEY=... # or use /login
|
||||
|
||||
@@ -4,6 +4,10 @@
|
||||
* Bun compiled binaries have an empty `process.env` when running inside
|
||||
* sandbox environments (e.g. nono on Linux/macOS). On Linux we can recover
|
||||
* the environment from `/proc/self/environ`.
|
||||
*
|
||||
* Keep this in sync with getBunSandboxEnvValue() in
|
||||
* packages/ai/src/utils/provider-env.ts. The ai package duplicates the lookup
|
||||
* for direct consumers that do not go through this coding-agent entrypoint.
|
||||
*/
|
||||
|
||||
import { readFileSync } from "node:fs";
|
||||
|
||||
@@ -357,6 +357,7 @@ export class AgentSession {
|
||||
private async _getRequiredRequestAuth(model: Model<any>): Promise<{
|
||||
apiKey: string;
|
||||
headers?: Record<string, string>;
|
||||
env?: Record<string, string>;
|
||||
}> {
|
||||
const result = await this._modelRegistry.getApiKeyAndHeaders(model);
|
||||
if (!result.ok) {
|
||||
@@ -366,7 +367,7 @@ export class AgentSession {
|
||||
throw new Error(result.error);
|
||||
}
|
||||
if (result.apiKey) {
|
||||
return { apiKey: result.apiKey, headers: result.headers };
|
||||
return { apiKey: result.apiKey, headers: result.headers, env: result.env };
|
||||
}
|
||||
|
||||
const isOAuth = this._modelRegistry.isUsingOAuth(model);
|
||||
@@ -383,13 +384,14 @@ export class AgentSession {
|
||||
private async _getCompactionRequestAuth(model: Model<any>): Promise<{
|
||||
apiKey?: string;
|
||||
headers?: Record<string, string>;
|
||||
env?: Record<string, string>;
|
||||
}> {
|
||||
if (this.agent.streamFn === streamSimple) {
|
||||
return this._getRequiredRequestAuth(model);
|
||||
}
|
||||
|
||||
const result = await this._modelRegistry.getApiKeyAndHeaders(model);
|
||||
return result.ok ? { apiKey: result.apiKey, headers: result.headers } : {};
|
||||
return result.ok ? { apiKey: result.apiKey, headers: result.headers, env: result.env } : {};
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -1649,7 +1651,7 @@ export class AgentSession {
|
||||
throw new Error(formatNoModelSelectedMessage());
|
||||
}
|
||||
|
||||
const { apiKey, headers } = await this._getCompactionRequestAuth(this.model);
|
||||
const { apiKey, headers, env } = await this._getCompactionRequestAuth(this.model);
|
||||
|
||||
const pathEntries = this.sessionManager.getBranch();
|
||||
const settings = this.settingsManager.getCompactionSettings();
|
||||
@@ -1708,6 +1710,7 @@ export class AgentSession {
|
||||
this._compactionAbortController.signal,
|
||||
this.thinkingLevel,
|
||||
this.agent.streamFn,
|
||||
env,
|
||||
);
|
||||
summary = result.summary;
|
||||
firstKeptEntryId = result.firstKeptEntryId;
|
||||
@@ -1898,6 +1901,7 @@ export class AgentSession {
|
||||
|
||||
let apiKey: string | undefined;
|
||||
let headers: Record<string, string> | undefined;
|
||||
let env: Record<string, string> | undefined;
|
||||
if (this.agent.streamFn === streamSimple) {
|
||||
const authResult = await this._modelRegistry.getApiKeyAndHeaders(this.model);
|
||||
if (!authResult.ok || !authResult.apiKey) {
|
||||
@@ -1912,8 +1916,9 @@ export class AgentSession {
|
||||
}
|
||||
apiKey = authResult.apiKey;
|
||||
headers = authResult.headers;
|
||||
env = authResult.env;
|
||||
} else {
|
||||
({ apiKey, headers } = await this._getCompactionRequestAuth(this.model));
|
||||
({ apiKey, headers, env } = await this._getCompactionRequestAuth(this.model));
|
||||
}
|
||||
|
||||
const pathEntries = this.sessionManager.getBranch();
|
||||
@@ -1981,6 +1986,7 @@ export class AgentSession {
|
||||
this._autoCompactionAbortController.signal,
|
||||
this.thinkingLevel,
|
||||
this.agent.streamFn,
|
||||
env,
|
||||
);
|
||||
summary = compactResult.summary;
|
||||
firstKeptEntryId = compactResult.firstKeptEntryId;
|
||||
@@ -2784,12 +2790,13 @@ export class AgentSession {
|
||||
let summaryDetails: unknown;
|
||||
if (options.summarize && entriesToSummarize.length > 0 && !extensionSummary) {
|
||||
const model = this.model!;
|
||||
const { apiKey, headers } = await this._getRequiredRequestAuth(model);
|
||||
const { apiKey, headers, env } = await this._getRequiredRequestAuth(model);
|
||||
const branchSummarySettings = this.settingsManager.getBranchSummarySettings();
|
||||
const result = await generateBranchSummary(entriesToSummarize, {
|
||||
model,
|
||||
apiKey,
|
||||
headers,
|
||||
env,
|
||||
signal: this._branchSummaryAbortController.signal,
|
||||
customInstructions,
|
||||
replaceInstructions,
|
||||
|
||||
@@ -24,6 +24,7 @@ import { resolveConfigValue } from "./resolve-config-value.ts";
|
||||
export type ApiKeyCredential = {
|
||||
type: "api_key";
|
||||
key: string;
|
||||
env?: Record<string, string>;
|
||||
};
|
||||
|
||||
export type OAuthCredential = {
|
||||
@@ -303,6 +304,14 @@ export class AuthStorage {
|
||||
return this.data[provider] ?? undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
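* Get a copy of the provider-scoped environment values stored on a provider's API key credential, or undefined when the provider has no API key credential or that credential has no env.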
|
||||
*/
|
||||
getProviderEnv(provider: string): Record<string, string> | undefined {
|
||||
const cred = this.data[provider];
|
||||
return cred?.type === "api_key" && cred.env ? { ...cred.env } : undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
* Set credential for a provider.
|
||||
*/
|
||||
@@ -471,7 +480,7 @@ export class AuthStorage {
|
||||
const cred = this.data[providerId];
|
||||
|
||||
if (cred?.type === "api_key") {
|
||||
return resolveConfigValue(cred.key);
|
||||
return resolveConfigValue(cred.key, cred.env);
|
||||
}
|
||||
|
||||
if (cred?.type === "oauth") {
|
||||
|
||||
@@ -69,6 +69,8 @@ export interface GenerateBranchSummaryOptions {
|
||||
apiKey: string;
|
||||
/** Request headers for the model */
|
||||
headers?: Record<string, string>;
|
||||
/** Provider-scoped environment values for the model */
|
||||
env?: Record<string, string>;
|
||||
/** Abort signal for cancellation */
|
||||
signal: AbortSignal;
|
||||
/** Optional custom instructions for summarization */
|
||||
@@ -290,6 +292,7 @@ export async function generateBranchSummary(
|
||||
model,
|
||||
apiKey,
|
||||
headers,
|
||||
env,
|
||||
signal,
|
||||
customInstructions,
|
||||
replaceInstructions,
|
||||
@@ -335,7 +338,7 @@ export async function generateBranchSummary(
|
||||
// request behavior (timeouts, retries, attribution headers) stays consistent
|
||||
// without running through agent state/events.
|
||||
const context = { systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages };
|
||||
const requestOptions: SimpleStreamOptions = { apiKey, headers, signal, maxTokens: 2048 };
|
||||
const requestOptions: SimpleStreamOptions = { apiKey, headers, env, signal, maxTokens: 2048 };
|
||||
const response = streamFn
|
||||
? await (await streamFn(model, context, requestOptions)).result()
|
||||
: await completeSimple(model, context, requestOptions);
|
||||
|
||||
@@ -528,10 +528,11 @@ function createSummarizationOptions(
|
||||
maxTokens: number,
|
||||
apiKey: string | undefined,
|
||||
headers: Record<string, string> | undefined,
|
||||
env: Record<string, string> | undefined,
|
||||
signal: AbortSignal | undefined,
|
||||
thinkingLevel: ThinkingLevel | undefined,
|
||||
): SimpleStreamOptions {
|
||||
const options: SimpleStreamOptions = { maxTokens, signal, apiKey, headers };
|
||||
const options: SimpleStreamOptions = { maxTokens, signal, apiKey, headers, env };
|
||||
if (model.reasoning && thinkingLevel && thinkingLevel !== "off") {
|
||||
options.reasoning = thinkingLevel;
|
||||
}
|
||||
@@ -566,6 +567,7 @@ export async function generateSummary(
|
||||
previousSummary?: string,
|
||||
thinkingLevel?: ThinkingLevel,
|
||||
streamFn?: StreamFn,
|
||||
env?: Record<string, string>,
|
||||
): Promise<string> {
|
||||
const maxTokens = Math.min(
|
||||
Math.floor(0.8 * reserveTokens),
|
||||
@@ -598,7 +600,7 @@ export async function generateSummary(
|
||||
},
|
||||
];
|
||||
|
||||
const completionOptions = createSummarizationOptions(model, maxTokens, apiKey, headers, signal, thinkingLevel);
|
||||
const completionOptions = createSummarizationOptions(model, maxTokens, apiKey, headers, env, signal, thinkingLevel);
|
||||
|
||||
const response = await completeSummarization(
|
||||
model,
|
||||
@@ -753,6 +755,7 @@ export async function compact(
|
||||
signal?: AbortSignal,
|
||||
thinkingLevel?: ThinkingLevel,
|
||||
streamFn?: StreamFn,
|
||||
env?: Record<string, string>,
|
||||
): Promise<CompactionResult> {
|
||||
const {
|
||||
firstKeptEntryId,
|
||||
@@ -783,6 +786,7 @@ export async function compact(
|
||||
previousSummary,
|
||||
thinkingLevel,
|
||||
streamFn,
|
||||
env,
|
||||
)
|
||||
: Promise.resolve("No prior history."),
|
||||
generateTurnPrefixSummary(
|
||||
@@ -791,6 +795,7 @@ export async function compact(
|
||||
settings.reserveTokens,
|
||||
apiKey,
|
||||
headers,
|
||||
env,
|
||||
signal,
|
||||
thinkingLevel,
|
||||
streamFn,
|
||||
@@ -811,6 +816,7 @@ export async function compact(
|
||||
previousSummary,
|
||||
thinkingLevel,
|
||||
streamFn,
|
||||
env,
|
||||
);
|
||||
}
|
||||
|
||||
@@ -839,6 +845,7 @@ async function generateTurnPrefixSummary(
|
||||
reserveTokens: number,
|
||||
apiKey: string | undefined,
|
||||
headers?: Record<string, string>,
|
||||
env?: Record<string, string>,
|
||||
signal?: AbortSignal,
|
||||
thinkingLevel?: ThinkingLevel,
|
||||
streamFn?: StreamFn,
|
||||
@@ -861,7 +868,7 @@ async function generateTurnPrefixSummary(
|
||||
const response = await completeSummarization(
|
||||
model,
|
||||
{ systemPrompt: SUMMARIZATION_SYSTEM_PROMPT, messages: summarizationMessages },
|
||||
createSummarizationOptions(model, maxTokens, apiKey, headers, signal, thinkingLevel),
|
||||
createSummarizationOptions(model, maxTokens, apiKey, headers, env, signal, thinkingLevel),
|
||||
streamFn,
|
||||
);
|
||||
|
||||
|
||||
@@ -240,6 +240,7 @@ export type ResolvedRequestAuth =
|
||||
ok: true;
|
||||
apiKey?: string;
|
||||
headers?: Record<string, string>;
|
||||
env?: Record<string, string>;
|
||||
}
|
||||
| {
|
||||
ok: false;
|
||||
@@ -684,17 +685,27 @@ export class ModelRegistry {
|
||||
async getApiKeyAndHeaders(model: Model<Api>): Promise<ResolvedRequestAuth> {
|
||||
try {
|
||||
const providerConfig = this.providerRequestConfigs.get(model.provider);
|
||||
const providerEnv = this.authStorage.getProviderEnv(model.provider);
|
||||
const apiKeyFromAuthStorage = await this.authStorage.getApiKey(model.provider, { includeFallback: false });
|
||||
const apiKey =
|
||||
apiKeyFromAuthStorage ??
|
||||
(providerConfig?.apiKey
|
||||
? resolveConfigValueOrThrow(providerConfig.apiKey, `API key for provider "${model.provider}"`)
|
||||
? resolveConfigValueOrThrow(
|
||||
providerConfig.apiKey,
|
||||
`API key for provider "${model.provider}"`,
|
||||
providerEnv,
|
||||
)
|
||||
: undefined);
|
||||
|
||||
const providerHeaders = resolveHeadersOrThrow(providerConfig?.headers, `provider "${model.provider}"`);
|
||||
const providerHeaders = resolveHeadersOrThrow(
|
||||
providerConfig?.headers,
|
||||
`provider "${model.provider}"`,
|
||||
providerEnv,
|
||||
);
|
||||
const modelHeaders = resolveHeadersOrThrow(
|
||||
this.modelRequestHeaders.get(this.getModelRequestKey(model.provider, model.id)),
|
||||
`model "${model.provider}/${model.id}"`,
|
||||
providerEnv,
|
||||
);
|
||||
|
||||
let headers =
|
||||
@@ -713,6 +724,7 @@ export class ModelRegistry {
|
||||
ok: true,
|
||||
apiKey,
|
||||
headers: headers && Object.keys(headers).length > 0 ? headers : undefined,
|
||||
env: providerEnv && Object.keys(providerEnv).length > 0 ? providerEnv : undefined,
|
||||
};
|
||||
} catch (error) {
|
||||
return {
|
||||
@@ -777,7 +789,9 @@ export class ModelRegistry {
|
||||
}
|
||||
|
||||
const providerApiKey = this.providerRequestConfigs.get(provider)?.apiKey;
|
||||
return providerApiKey ? resolveConfigValueUncached(providerApiKey) : undefined;
|
||||
return providerApiKey
|
||||
? resolveConfigValueUncached(providerApiKey, this.authStorage.getProviderEnv(provider))
|
||||
: undefined;
|
||||
}
|
||||
|
||||
/**
|
||||
|
||||
@@ -85,8 +85,8 @@ function parseConfigValueReference(config: string): ConfigValueReference {
|
||||
return { type: "template", parts: parseConfigValueTemplate(config) };
|
||||
}
|
||||
|
||||
function resolveEnvConfigValue(name: string): string | undefined {
|
||||
return process.env[name] || undefined;
|
||||
function resolveEnvConfigValue(name: string, env?: Record<string, string>): string | undefined {
|
||||
return env?.[name] || process.env[name] || undefined;
|
||||
}
|
||||
|
||||
function getTemplateEnvVarNames(parts: TemplatePart[]): string[] {
|
||||
@@ -98,14 +98,14 @@ function getTemplateEnvVarNames(parts: TemplatePart[]): string[] {
|
||||
return names;
|
||||
}
|
||||
|
||||
function resolveTemplate(parts: TemplatePart[]): string | undefined {
|
||||
function resolveTemplate(parts: TemplatePart[], env?: Record<string, string>): string | undefined {
|
||||
let resolved = "";
|
||||
for (const part of parts) {
|
||||
if (part.type === "literal") {
|
||||
resolved += part.value;
|
||||
continue;
|
||||
}
|
||||
const envValue = resolveEnvConfigValue(part.name);
|
||||
const envValue = resolveEnvConfigValue(part.name, env);
|
||||
if (envValue === undefined) return undefined;
|
||||
resolved += envValue;
|
||||
}
|
||||
@@ -123,16 +123,16 @@ export function getConfigValueEnvVarNames(config: string): string[] {
|
||||
return reference.type === "template" ? getTemplateEnvVarNames(reference.parts) : [];
|
||||
}
|
||||
|
||||
export function getMissingConfigValueEnvVarNames(config: string): string[] {
|
||||
return getConfigValueEnvVarNames(config).filter((name) => resolveEnvConfigValue(name) === undefined);
|
||||
export function getMissingConfigValueEnvVarNames(config: string, env?: Record<string, string>): string[] {
|
||||
return getConfigValueEnvVarNames(config).filter((name) => resolveEnvConfigValue(name, env) === undefined);
|
||||
}
|
||||
|
||||
export function isCommandConfigValue(config: string): boolean {
|
||||
return parseConfigValueReference(config).type === "command";
|
||||
}
|
||||
|
||||
export function isConfigValueConfigured(config: string): boolean {
|
||||
return getMissingConfigValueEnvVarNames(config).length === 0;
|
||||
export function isConfigValueConfigured(config: string, env?: Record<string, string>): boolean {
|
||||
return getMissingConfigValueEnvVarNames(config, env).length === 0;
|
||||
}
|
||||
|
||||
/**
|
||||
@@ -142,12 +142,12 @@ export function isConfigValueConfigured(config: string): boolean {
|
||||
* - In non-command values, "$$" escapes a literal "$" and "$!" escapes a literal "!"
|
||||
* - Otherwise treats the value as a literal
|
||||
*/
|
||||
export function resolveConfigValue(config: string): string | undefined {
|
||||
export function resolveConfigValue(config: string, env?: Record<string, string>): string | undefined {
|
||||
const reference = parseConfigValueReference(config);
|
||||
if (reference.type === "command") {
|
||||
return executeCommand(reference.config);
|
||||
}
|
||||
return resolveTemplate(reference.parts);
|
||||
return resolveTemplate(reference.parts, env);
|
||||
}
|
||||
|
||||
function executeWithConfiguredShell(command: string): { executed: boolean; value: string | undefined } {
|
||||
@@ -216,16 +216,16 @@ function executeCommand(commandConfig: string): string | undefined {
|
||||
/**
|
||||
* Resolve a single config value like resolveConfigValue(), except that command values are run through executeCommandUncached().
|
||||
*/
|
||||
export function resolveConfigValueUncached(config: string): string | undefined {
|
||||
export function resolveConfigValueUncached(config: string, env?: Record<string, string>): string | undefined {
|
||||
const reference = parseConfigValueReference(config);
|
||||
if (reference.type === "command") {
|
||||
return executeCommandUncached(reference.config);
|
||||
}
|
||||
return resolveTemplate(reference.parts);
|
||||
return resolveTemplate(reference.parts, env);
|
||||
}
|
||||
|
||||
export function resolveConfigValueOrThrow(config: string, description: string): string {
|
||||
const resolvedValue = resolveConfigValueUncached(config);
|
||||
export function resolveConfigValueOrThrow(config: string, description: string, env?: Record<string, string>): string {
|
||||
const resolvedValue = resolveConfigValueUncached(config, env);
|
||||
if (resolvedValue !== undefined) {
|
||||
return resolvedValue;
|
||||
}
|
||||
@@ -236,7 +236,7 @@ export function resolveConfigValueOrThrow(config: string, description: string):
|
||||
}
|
||||
|
||||
if (reference.type === "template") {
|
||||
const missingEnvVars = getMissingConfigValueEnvVarNames(config);
|
||||
const missingEnvVars = getMissingConfigValueEnvVarNames(config, env);
|
||||
if (missingEnvVars.length === 1) {
|
||||
throw new Error(`Failed to resolve ${description} from environment variable: ${missingEnvVars[0]}`);
|
||||
}
|
||||
@@ -251,11 +251,14 @@ export function resolveConfigValueOrThrow(config: string, description: string):
|
||||
/**
|
||||
* Resolve all header values using the same resolution logic as API keys.
|
||||
*/
|
||||
export function resolveHeaders(headers: Record<string, string> | undefined): Record<string, string> | undefined {
|
||||
export function resolveHeaders(
|
||||
headers: Record<string, string> | undefined,
|
||||
env?: Record<string, string>,
|
||||
): Record<string, string> | undefined {
|
||||
if (!headers) return undefined;
|
||||
const resolved: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
const resolvedValue = resolveConfigValue(value);
|
||||
const resolvedValue = resolveConfigValue(value, env);
|
||||
if (resolvedValue) {
|
||||
resolved[key] = resolvedValue;
|
||||
}
|
||||
@@ -266,11 +269,12 @@ export function resolveHeaders(headers: Record<string, string> | undefined): Rec
|
||||
export function resolveHeadersOrThrow(
|
||||
headers: Record<string, string> | undefined,
|
||||
description: string,
|
||||
env?: Record<string, string>,
|
||||
): Record<string, string> | undefined {
|
||||
if (!headers) return undefined;
|
||||
const resolved: Record<string, string> = {};
|
||||
for (const [key, value] of Object.entries(headers)) {
|
||||
resolved[key] = resolveConfigValueOrThrow(value, `${description} header "${key}"`);
|
||||
resolved[key] = resolveConfigValueOrThrow(value, `${description} header "${key}"`, env);
|
||||
}
|
||||
return Object.keys(resolved).length > 0 ? resolved : undefined;
|
||||
}
|
||||
|
||||
@@ -303,6 +303,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
||||
if (!auth.ok) {
|
||||
throw new Error(auth.error);
|
||||
}
|
||||
const env = auth.env || options?.env ? { ...(auth.env ?? {}), ...(options?.env ?? {}) } : undefined;
|
||||
const providerRetrySettings = settingsManager.getProviderRetrySettings();
|
||||
const httpIdleTimeoutMs = settingsManager.getHttpIdleTimeoutMs();
|
||||
// SDKs treat timeout=0 as 0ms (immediate timeout), not "no timeout".
|
||||
@@ -314,6 +315,7 @@ export async function createAgentSession(options: CreateAgentSessionOptions = {}
|
||||
return streamSimple(model, context, {
|
||||
...options,
|
||||
apiKey: auth.apiKey,
|
||||
env,
|
||||
timeoutMs,
|
||||
websocketConnectTimeoutMs,
|
||||
maxRetries: options?.maxRetries ?? providerRetrySettings.maxRetries,
|
||||
|
||||
@@ -134,6 +134,34 @@ describe("AuthStorage", () => {
|
||||
}
|
||||
});
|
||||
|
||||
test("apiKey env bag takes precedence over process.env", async () => {
|
||||
const originalEnv = process.env.TEST_AUTH_SCOPED_API_KEY_12345;
|
||||
process.env.TEST_AUTH_SCOPED_API_KEY_12345 = "process-env-value";
|
||||
|
||||
try {
|
||||
writeAuthJson({
|
||||
anthropic: {
|
||||
type: "api_key",
|
||||
key: "$TEST_AUTH_SCOPED_API_KEY_12345",
|
||||
env: { TEST_AUTH_SCOPED_API_KEY_12345: "credential-env-value" },
|
||||
},
|
||||
});
|
||||
|
||||
authStorage = AuthStorage.create(authJsonPath);
|
||||
|
||||
expect(await authStorage.getApiKey("anthropic")).toBe("credential-env-value");
|
||||
expect(authStorage.getProviderEnv("anthropic")).toEqual({
|
||||
TEST_AUTH_SCOPED_API_KEY_12345: "credential-env-value",
|
||||
});
|
||||
} finally {
|
||||
if (originalEnv === undefined) {
|
||||
delete process.env.TEST_AUTH_SCOPED_API_KEY_12345;
|
||||
} else {
|
||||
process.env.TEST_AUTH_SCOPED_API_KEY_12345 = originalEnv;
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
test("apiKey with braced env syntax resolves to env value", async () => {
|
||||
const originalEnv = process.env.TEST_AUTH_BRACED_API_KEY_12345;
|
||||
process.env.TEST_AUTH_BRACED_API_KEY_12345 = "braced-env-api-key-value";
|
||||
|
||||
@@ -893,6 +893,38 @@ describe("ModelRegistry", () => {
|
||||
expect(registry.getProviderDisplayName("oauth-provider")).toBe("OAuth Provider");
|
||||
});
|
||||
|
||||
test("stored API key env propagates to request auth and resolves headers", async () => {
|
||||
authStorage.set("cloudflare-ai-gateway", {
|
||||
type: "api_key",
|
||||
key: "$CLOUDFLARE_API_KEY",
|
||||
env: {
|
||||
CLOUDFLARE_API_KEY: "stored-cf-token",
|
||||
CLOUDFLARE_ACCOUNT_ID: "stored-account",
|
||||
},
|
||||
});
|
||||
writeRawModelsJson({
|
||||
"cloudflare-ai-gateway": {
|
||||
headers: { "x-account": "$CLOUDFLARE_ACCOUNT_ID" },
|
||||
},
|
||||
});
|
||||
|
||||
const registry = ModelRegistry.create(authStorage, modelsJsonPath);
|
||||
const model = registry.getAll().find((m) => m.provider === "cloudflare-ai-gateway");
|
||||
expect(model).toBeDefined();
|
||||
|
||||
const auth = await registry.getApiKeyAndHeaders(model!);
|
||||
|
||||
expect(auth).toEqual({
|
||||
ok: true,
|
||||
apiKey: "stored-cf-token",
|
||||
headers: { "x-account": "stored-account" },
|
||||
env: {
|
||||
CLOUDFLARE_API_KEY: "stored-cf-token",
|
||||
CLOUDFLARE_ACCOUNT_ID: "stored-account",
|
||||
},
|
||||
});
|
||||
});
|
||||
|
||||
test("registerProvider treats uppercase apiKey and headers as literals", async () => {
|
||||
const envKeys = ["CUSTOM_NAME", "BEARER", "MODEL_TOKEN"];
|
||||
const savedEnv: Record<string, string | undefined> = {};
|
||||
|
||||
Reference in New Issue
Block a user