feat: add AWS SigV4 auth for OpenAI-compatible model providers (#17820)

## Summary

Add first-class Amazon Bedrock Mantle provider support so Codex can keep
using its existing Responses API transport with OpenAI-compatible
AWS-hosted endpoints such as AOA/Mantle.

This is needed for the AWS launch path, where provider traffic should
authenticate with AWS credentials instead of OpenAI bearer credentials.
Requests are authenticated immediately before transport send, so SigV4
signs the final method, URL, headers, and body bytes that `reqwest` will
send.

## What Changed

- Added a new `codex-aws-auth` crate for loading AWS SDK config,
resolving credentials, and signing finalized HTTP requests with AWS
SigV4.
- Added a built-in `amazon-bedrock` provider that targets Bedrock Mantle
Responses endpoints, defaults to `us-east-1`, supports region/profile
overrides, disables WebSockets, and does not require OpenAI auth.
- Added Amazon Bedrock auth resolution in `codex-model-provider`: prefer
`AWS_BEARER_TOKEN_BEDROCK` when set, otherwise use AWS SDK credentials
and SigV4 signing.
- Added `AuthProvider::apply_auth` and `Request::prepare_body_for_send`
so request-signing providers can sign the exact outbound request after
JSON serialization/compression.
- Determine the region by taking the `aws.region` config first (required
for bearer token codepath), and fallback to SDK default region.

## Testing
Amazon Bedrock Mantle Responses paths:

- Built the local Codex binary with `cargo build`.
- Verified the custom proxy-backed `aws` provider using `env_key =
"AWS_BEARER_TOKEN_BEDROCK"` streamed raw `responses` output with
`response.output_text.delta`, `response.completed`, and `mantle-env-ok`.
- Verified a full `codex exec --profile aws` turn returned
`mantle-env-ok`.
- Confirmed the custom provider used the bearer env var, not AWS profile
auth: bogus `AWS_PROFILE` still passed, empty env var failed locally,
and malformed env var reached Mantle and failed with `401
invalid_api_key`.
- Verified built-in `amazon-bedrock` with `AWS_BEARER_TOKEN_BEDROCK` set
passed despite bogus AWS profiles, returning `amazon-bedrock-env-ok`.
- Verified built-in `amazon-bedrock` SDK/SigV4 auth passed with
`AWS_BEARER_TOKEN_BEDROCK` unset and temporary AWS session env
credentials, returning `amazon-bedrock-sdk-env-ok`.
This commit is contained in:
Celia Chen
2026-04-21 18:11:17 -07:00
committed by GitHub
Unverified
parent e18fe7a07f
commit 1cd3ad1f49
25 changed files with 1676 additions and 94 deletions
+24
View File
File diff suppressed because one or more lines are too long
+404 -11
View File
@@ -749,6 +749,48 @@ version = "1.5.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "c08606f8c3cbf4ce6ec8e28fb0014a2c086708fe954eaa885384a6165172e7e8"
[[package]]
name = "aws-config"
version = "1.8.12"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "96571e6996817bf3d58f6b569e4b9fd2e9d2fcf9f7424eed07b2ce9bb87535e5"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-sdk-sso",
"aws-sdk-ssooidc",
"aws-sdk-sts",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"fastrand",
"hex",
"http 1.4.0",
"ring",
"time",
"tokio",
"tracing",
"url",
"zeroize",
]
[[package]]
name = "aws-credential-types"
version = "1.2.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "3cd362783681b15d136480ad555a099e82ecd8e2d10a841e14dfd0078d67fee3"
dependencies = [
"aws-smithy-async",
"aws-smithy-runtime-api",
"aws-smithy-types",
"zeroize",
]
[[package]]
name = "aws-lc-rs"
version = "1.16.2"
@@ -772,6 +814,290 @@ dependencies = [
"fs_extra",
]
[[package]]
name = "aws-runtime"
version = "1.5.17"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "d81b5b2898f6798ad58f484856768bca817e3cd9de0974c24ae0f1113fe88f1b"
dependencies = [
"aws-credential-types",
"aws-sigv4",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"fastrand",
"http 0.2.12",
"http-body 0.4.6",
"percent-encoding",
"pin-project-lite",
"tracing",
"uuid",
]
[[package]]
name = "aws-sdk-sso"
version = "1.91.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "8ee6402a36f27b52fe67661c6732d684b2635152b676aa2babbfb5204f99115d"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"fastrand",
"http 0.2.12",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-ssooidc"
version = "1.93.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "a45a7f750bbd170ee3677671ad782d90b894548f4e4ae168302c57ec9de5cb3e"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-types",
"bytes",
"fastrand",
"http 0.2.12",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sdk-sts"
version = "1.95.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "55542378e419558e6b1f398ca70adb0b2088077e79ad9f14eb09441f2f7b2164"
dependencies = [
"aws-credential-types",
"aws-runtime",
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-json",
"aws-smithy-query",
"aws-smithy-runtime",
"aws-smithy-runtime-api",
"aws-smithy-types",
"aws-smithy-xml",
"aws-types",
"fastrand",
"http 0.2.12",
"regex-lite",
"tracing",
]
[[package]]
name = "aws-sigv4"
version = "1.3.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "69e523e1c4e8e7e8ff219d732988e22bfeae8a1cafdbe6d9eca1546fa080be7c"
dependencies = [
"aws-credential-types",
"aws-smithy-http",
"aws-smithy-runtime-api",
"aws-smithy-types",
"bytes",
"form_urlencoded",
"hex",
"hmac",
"http 0.2.12",
"http 1.4.0",
"percent-encoding",
"sha2",
"time",
"tracing",
]
[[package]]
name = "aws-smithy-async"
version = "1.2.7"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "9ee19095c7c4dda59f1697d028ce704c24b2d33c6718790c7f1d5a3015b4107c"
dependencies = [
"futures-util",
"pin-project-lite",
"tokio",
]
[[package]]
name = "aws-smithy-http"
version = "0.62.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "826141069295752372f8203c17f28e30c464d22899a43a0c9fd9c458d469c88b"
dependencies = [
"aws-smithy-runtime-api",
"aws-smithy-types",
"bytes",
"bytes-utils",
"futures-core",
"futures-util",
"http 0.2.12",
"http 1.4.0",
"http-body 0.4.6",
"percent-encoding",
"pin-project-lite",
"pin-utils",
"tracing",
]
[[package]]
name = "aws-smithy-http-client"
version = "1.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "59e62db736db19c488966c8d787f52e6270be565727236fd5579eaa301e7bc4a"
dependencies = [
"aws-smithy-async",
"aws-smithy-runtime-api",
"aws-smithy-types",
"h2",
"http 1.4.0",
"hyper",
"hyper-rustls",
"hyper-util",
"pin-project-lite",
"rustls",
"rustls-native-certs",
"rustls-pki-types",
"tokio",
"tokio-rustls",
"tower",
"tracing",
]
[[package]]
name = "aws-smithy-json"
version = "0.61.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "49fa1213db31ac95288d981476f78d05d9cbb0353d22cdf3472cc05bb02f6551"
dependencies = [
"aws-smithy-types",
]
[[package]]
name = "aws-smithy-observability"
version = "0.1.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "17f616c3f2260612fe44cede278bafa18e73e6479c4e393e2c4518cf2a9a228a"
dependencies = [
"aws-smithy-runtime-api",
]
[[package]]
name = "aws-smithy-query"
version = "0.60.9"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ae5d689cf437eae90460e944a58b5668530d433b4ff85789e69d2f2a556e057d"
dependencies = [
"aws-smithy-types",
"urlencoding",
]
[[package]]
name = "aws-smithy-runtime"
version = "1.9.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "65fda37911905ea4d3141a01364bc5509a0f32ae3f3b22d6e330c0abfb62d247"
dependencies = [
"aws-smithy-async",
"aws-smithy-http",
"aws-smithy-http-client",
"aws-smithy-observability",
"aws-smithy-runtime-api",
"aws-smithy-types",
"bytes",
"fastrand",
"http 0.2.12",
"http 1.4.0",
"http-body 0.4.6",
"http-body 1.0.1",
"pin-project-lite",
"pin-utils",
"tokio",
"tracing",
]
[[package]]
name = "aws-smithy-runtime-api"
version = "1.9.3"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "ab0d43d899f9e508300e587bf582ba54c27a452dd0a9ea294690669138ae14a2"
dependencies = [
"aws-smithy-async",
"aws-smithy-types",
"bytes",
"http 0.2.12",
"http 1.4.0",
"pin-project-lite",
"tokio",
"tracing",
"zeroize",
]
[[package]]
name = "aws-smithy-types"
version = "1.3.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "905cb13a9895626d49cf2ced759b062d913834c7482c38e49557eac4e6193f01"
dependencies = [
"base64-simd",
"bytes",
"bytes-utils",
"http 0.2.12",
"http 1.4.0",
"http-body 0.4.6",
"http-body 1.0.1",
"http-body-util",
"itoa",
"num-integer",
"pin-project-lite",
"pin-utils",
"ryu",
"serde",
"time",
]
[[package]]
name = "aws-smithy-xml"
version = "0.60.13"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "11b2f670422ff42bf7065031e72b45bc52a3508bd089f743ea90731ca2b6ea57"
dependencies = [
"xmlparser",
]
[[package]]
name = "aws-types"
version = "1.3.11"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1d980627d2dd7bfc32a3c025685a033eeab8d365cc840c631ef59d1b8f428164"
dependencies = [
"aws-credential-types",
"aws-smithy-async",
"aws-smithy-runtime-api",
"aws-smithy-types",
"rustc_version",
"tracing",
]
[[package]]
name = "axum"
version = "0.8.8"
@@ -784,7 +1110,7 @@ dependencies = [
"form_urlencoded",
"futures-util",
"http 1.4.0",
"http-body",
"http-body 1.0.1",
"http-body-util",
"hyper",
"hyper-util",
@@ -817,7 +1143,7 @@ dependencies = [
"bytes",
"futures-core",
"http 1.4.0",
"http-body",
"http-body 1.0.1",
"http-body-util",
"mime",
"pin-project-lite",
@@ -860,6 +1186,16 @@ version = "0.22.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "72b3254f16251a8381aa12e40e3c4d2f0199f8c6508fbecb9d91f575e0fbb8c6"
[[package]]
name = "base64-simd"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "339abbe78e73178762e23bea9dfd08e697eb3f3301cd4be981c0f78ba5859195"
dependencies = [
"outref",
"vsimd",
]
[[package]]
name = "base64ct"
version = "1.8.3"
@@ -1050,6 +1386,16 @@ version = "1.11.1"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1e748733b7cbc798e1434b6ac524f0c1ff2ab456fe201501e6497c8417a4fc33"
[[package]]
name = "bytes-utils"
version = "0.1.4"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7dafe3a8757b027e2be6e4e5601ed563c55989fcf1546e933c66c8eb3a058d35"
dependencies = [
"bytes",
"either",
]
[[package]]
name = "bytestring"
version = "1.5.0"
@@ -1639,6 +1985,21 @@ dependencies = [
"tokio-util",
]
[[package]]
name = "codex-aws-auth"
version = "0.0.0"
dependencies = [
"aws-config",
"aws-credential-types",
"aws-sigv4",
"aws-types",
"bytes",
"http 1.4.0",
"pretty_assertions",
"thiserror 2.0.18",
"tokio",
]
[[package]]
name = "codex-backend-client"
version = "0.0.0"
@@ -2496,11 +2857,14 @@ version = "0.0.0"
dependencies = [
"async-trait",
"codex-api",
"codex-aws-auth",
"codex-client",
"codex-login",
"codex-model-provider-info",
"codex-protocol",
"http 1.4.0",
"pretty_assertions",
"tokio",
]
[[package]]
@@ -6422,6 +6786,17 @@ dependencies = [
"itoa",
]
[[package]]
name = "http-body"
version = "0.4.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "7ceab25649e9960c0311ea418d17bee82c0dcec1bd053b5f9a66e265a693bed2"
dependencies = [
"bytes",
"http 0.2.12",
"pin-project-lite",
]
[[package]]
name = "http-body"
version = "1.0.1"
@@ -6441,7 +6816,7 @@ dependencies = [
"bytes",
"futures-core",
"http 1.4.0",
"http-body",
"http-body 1.0.1",
"pin-project-lite",
]
@@ -6475,7 +6850,7 @@ dependencies = [
"futures-core",
"h2",
"http 1.4.0",
"http-body",
"http-body 1.0.1",
"httparse",
"httpdate",
"itoa",
@@ -6544,7 +6919,7 @@ dependencies = [
"futures-channel",
"futures-util",
"http 1.4.0",
"http-body",
"http-body 1.0.1",
"hyper",
"ipnet",
"libc",
@@ -8636,6 +9011,12 @@ dependencies = [
"windows-sys 0.61.2",
]
[[package]]
name = "outref"
version = "0.5.2"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "1a80800c0488c3a21695ea981a54918fbb37abf04f4d0720c453632255e2ff0e"
[[package]]
name = "owo-colors"
version = "4.3.0"
@@ -9594,7 +9975,7 @@ dependencies = [
"const_format",
"fnv",
"http 1.4.0",
"http-body",
"http-body 1.0.1",
"http-body-util",
"itoa",
"memchr",
@@ -9995,7 +10376,7 @@ dependencies = [
"futures-util",
"h2",
"http 1.4.0",
"http-body",
"http-body 1.0.1",
"http-body-util",
"hyper",
"hyper-rustls",
@@ -10083,7 +10464,7 @@ dependencies = [
"chrono",
"futures",
"http 1.4.0",
"http-body",
"http-body 1.0.1",
"http-body-util",
"oauth2",
"pastey",
@@ -11362,7 +11743,7 @@ checksum = "eb4dc4d33c68ec1f27d386b5610a351922656e1fdf5c05bbaad930cd1519479a"
dependencies = [
"bytes",
"futures-util",
"http-body",
"http-body 1.0.1",
"http-body-util",
"pin-project-lite",
]
@@ -12238,7 +12619,7 @@ dependencies = [
"bytes",
"h2",
"http 1.4.0",
"http-body",
"http-body 1.0.1",
"http-body-util",
"hyper",
"hyper-timeout",
@@ -12325,7 +12706,7 @@ dependencies = [
"bytes",
"futures-util",
"http 1.4.0",
"http-body",
"http-body 1.0.1",
"iri-string",
"pin-project-lite",
"tower",
@@ -12869,6 +13250,12 @@ version = "0.9.5"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "0b928f33d975fc6ad9f86c8f283853ad26bdd5b10b7f1542aa2fa15e2289105a"
[[package]]
name = "vsimd"
version = "0.8.0"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "5c3082ca00d5a5ef149bb8b555a72ae84c9c59f7250f013ac822ac2e49b19c64"
[[package]]
name = "vt100"
version = "0.16.2"
@@ -13935,6 +14322,12 @@ dependencies = [
"windows-sys 0.59.0",
]
[[package]]
name = "xmlparser"
version = "0.13.6"
source = "registry+https://github.com/rust-lang/crates.io-index"
checksum = "66fee0b777b0f5ac1c69bb06d361268faafa61cd4682ae064a171c16c433e9e4"
[[package]]
name = "xz2"
version = "0.1.7"
+6
View File
@@ -1,5 +1,6 @@
[workspace]
members = [
"aws-auth",
"analytics",
"backend-client",
"ansi-escape",
@@ -113,6 +114,7 @@ app_test_support = { path = "app-server/tests/common" }
codex-analytics = { path = "analytics" }
codex-ansi-escape = { path = "ansi-escape" }
codex-api = { path = "codex-api" }
codex-aws-auth = { path = "aws-auth" }
codex-app-server = { path = "app-server" }
codex-app-server-client = { path = "app-server-client" }
codex-app-server-protocol = { path = "app-server-protocol" }
@@ -218,6 +220,10 @@ async-channel = "2.3.1"
async-io = "2.6.0"
async-stream = "0.3.6"
async-trait = "0.1.89"
aws-config = "1"
aws-credential-types = "1"
aws-sigv4 = "1"
aws-types = "1"
axum = { version = "0.8", default-features = false }
base64 = "0.22.1"
bm25 = "2.3.2"
+6
View File
@@ -0,0 +1,6 @@
load("//:defs.bzl", "codex_rust_crate")
codex_rust_crate(
name = "aws-auth",
crate_name = "codex_aws_auth",
)
+26
View File
@@ -0,0 +1,26 @@
[package]
edition.workspace = true
license.workspace = true
name = "codex-aws-auth"
version.workspace = true
[lib]
doctest = false
name = "codex_aws_auth"
path = "src/lib.rs"
[lints]
workspace = true
[dependencies]
aws-config = { workspace = true }
aws-credential-types = { workspace = true }
aws-sigv4 = { workspace = true }
aws-types = { workspace = true }
bytes = { workspace = true }
http = { workspace = true }
thiserror = { workspace = true }
[dev-dependencies]
pretty_assertions = { workspace = true }
tokio = { workspace = true, features = ["macros", "rt-multi-thread"] }
+38
View File
@@ -0,0 +1,38 @@
use aws_config::BehaviorVersion;
use aws_config::SdkConfig;
use aws_credential_types::provider::SharedCredentialsProvider;
use aws_types::region::Region;
use crate::AwsAuthConfig;
use crate::AwsAuthError;
pub(crate) async fn load_sdk_config(config: &AwsAuthConfig) -> Result<SdkConfig, AwsAuthError> {
if config.service.trim().is_empty() {
return Err(AwsAuthError::EmptyService);
}
let mut loader = aws_config::defaults(BehaviorVersion::latest());
if let Some(profile) = config.profile.as_ref() {
loader = loader.profile_name(profile);
}
if let Some(region) = config.region.as_ref() {
loader = loader.region(Region::new(region.clone()));
}
Ok(loader.load().await)
}
pub(crate) fn credentials_provider(
sdk_config: &SdkConfig,
) -> Result<SharedCredentialsProvider, AwsAuthError> {
sdk_config
.credentials_provider()
.ok_or(AwsAuthError::MissingCredentialsProvider)
}
pub(crate) fn resolved_region(sdk_config: &SdkConfig) -> Result<String, AwsAuthError> {
sdk_config
.region()
.map(ToString::to_string)
.ok_or(AwsAuthError::MissingRegion)
}
+261
View File
@@ -0,0 +1,261 @@
mod config;
mod signing;
use std::time::SystemTime;
use aws_credential_types::provider::ProvideCredentials;
use aws_credential_types::provider::SharedCredentialsProvider;
use bytes::Bytes;
use http::HeaderMap;
use http::Method;
use thiserror::Error;
/// AWS auth configuration used to resolve credentials and sign requests.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AwsAuthConfig {
pub profile: Option<String>,
pub region: Option<String>,
pub service: String,
}
/// Generic HTTP request shape consumed by SigV4 signing.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AwsRequestToSign {
pub method: Method,
pub url: String,
pub headers: HeaderMap,
pub body: Bytes,
}
/// Signed request parts returned to the caller.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct AwsSignedRequest {
pub url: String,
pub headers: HeaderMap,
}
/// Errors returned by credential loading or SigV4 signing.
#[derive(Debug, Error)]
pub enum AwsAuthError {
#[error("AWS service name must not be empty")]
EmptyService,
#[error("AWS SDK config did not resolve a credentials provider")]
MissingCredentialsProvider,
#[error("AWS SDK config did not resolve a region")]
MissingRegion,
#[error("failed to load AWS credentials: {0}")]
Credentials(#[from] aws_credential_types::provider::error::CredentialsError),
#[error("request URL is not a valid URI: {0}")]
InvalidUri(#[source] http::uri::InvalidUri),
#[error("failed to construct HTTP request for signing: {0}")]
BuildHttpRequest(#[source] http::Error),
#[error("request contains a non-UTF8 header value: {0}")]
InvalidHeaderValue(#[source] http::header::ToStrError),
#[error("failed to build signable request: {0}")]
SigningRequest(#[source] aws_sigv4::http_request::SigningError),
#[error("failed to build SigV4 signing params: {0}")]
SigningParams(String),
#[error("SigV4 signing failed: {0}")]
SigningFailure(#[source] aws_sigv4::http_request::SigningError),
}
/// Loaded AWS auth context that can sign outbound HTTP requests.
#[derive(Clone)]
pub struct AwsAuthContext {
credentials_provider: SharedCredentialsProvider,
region: String,
service: String,
}
impl std::fmt::Debug for AwsAuthContext {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.debug_struct("AwsAuthContext")
.field("region", &self.region)
.field("service", &self.service)
.finish_non_exhaustive()
}
}
impl AwsAuthContext {
pub async fn load(config: AwsAuthConfig) -> Result<Self, AwsAuthError> {
let sdk_config = config::load_sdk_config(&config).await?;
let credentials_provider = config::credentials_provider(&sdk_config)?;
let region = config::resolved_region(&sdk_config)?;
Ok(Self {
credentials_provider,
region,
service: config.service.trim().to_string(),
})
}
pub fn region(&self) -> &str {
&self.region
}
pub fn service(&self) -> &str {
&self.service
}
pub async fn sign(&self, request: AwsRequestToSign) -> Result<AwsSignedRequest, AwsAuthError> {
self.sign_at(request, SystemTime::now()).await
}
async fn sign_at(
&self,
request: AwsRequestToSign,
time: SystemTime,
) -> Result<AwsSignedRequest, AwsAuthError> {
let credentials = self.credentials_provider.provide_credentials().await?;
signing::sign_request(&credentials, &self.region, &self.service, request, time)
}
}
impl AwsAuthError {
/// Returns whether retrying the outbound request can reasonably recover from this auth error.
pub fn is_retryable(&self) -> bool {
match self {
AwsAuthError::Credentials(error) => matches!(
error,
aws_credential_types::provider::error::CredentialsError::ProviderTimedOut(_)
| aws_credential_types::provider::error::CredentialsError::ProviderError(_)
),
AwsAuthError::EmptyService
| AwsAuthError::MissingCredentialsProvider
| AwsAuthError::MissingRegion
| AwsAuthError::InvalidUri(_)
| AwsAuthError::BuildHttpRequest(_)
| AwsAuthError::InvalidHeaderValue(_)
| AwsAuthError::SigningRequest(_)
| AwsAuthError::SigningParams(_)
| AwsAuthError::SigningFailure(_) => false,
}
}
}
#[cfg(test)]
mod tests {
use std::time::Duration;
use std::time::UNIX_EPOCH;
use aws_credential_types::Credentials;
use aws_credential_types::provider::error::CredentialsError;
use pretty_assertions::assert_eq;
use super::*;
fn test_context(session_token: Option<&str>) -> AwsAuthContext {
AwsAuthContext {
credentials_provider: SharedCredentialsProvider::new(Credentials::new(
"AKIDEXAMPLE",
"wJalrXUtnFEMI/K7MDENG+bPxRfiCYEXAMPLEKEY",
session_token.map(str::to_string),
/*expires_after*/ None,
"unit-test",
)),
region: "us-east-1".to_string(),
service: "bedrock".to_string(),
}
}
fn test_request() -> AwsRequestToSign {
let mut headers = HeaderMap::new();
headers.insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/json"),
);
headers.insert("x-test-header", http::HeaderValue::from_static("present"));
AwsRequestToSign {
method: Method::POST,
url: "https://bedrock-runtime.us-east-1.amazonaws.com/v1/responses".to_string(),
headers,
body: Bytes::from_static(br#"{"model":"openai.gpt-oss-120b-1:0"}"#),
}
}
#[tokio::test]
async fn sign_adds_sigv4_headers_and_preserves_existing_headers() {
let signed = test_context(/*session_token*/ None)
.sign_at(
test_request(),
UNIX_EPOCH + Duration::from_secs(1_700_000_000),
)
.await
.expect("request should sign");
assert_eq!(
signing::header_value(&signed.headers, http::header::CONTENT_TYPE.as_str()),
Some("application/json".to_string())
);
assert_eq!(
signing::header_value(&signed.headers, "x-test-header"),
Some("present".to_string())
);
assert_eq!(
signed.url,
"https://bedrock-runtime.us-east-1.amazonaws.com/v1/responses"
);
assert!(
signing::header_value(&signed.headers, http::header::AUTHORIZATION.as_str())
.is_some_and(|value| value.starts_with("AWS4-HMAC-SHA256 "))
);
assert!(signing::header_value(&signed.headers, "x-amz-date").is_some());
}
#[test]
fn credentials_provider_failures_are_retryable() {
assert!(
AwsAuthError::Credentials(CredentialsError::provider_error("temporarily unavailable"))
.is_retryable()
);
assert!(
AwsAuthError::Credentials(CredentialsError::provider_timed_out(Duration::from_secs(1)))
.is_retryable()
);
}
#[test]
fn deterministic_aws_auth_errors_are_not_retryable() {
assert!(!AwsAuthError::EmptyService.is_retryable());
assert!(
!AwsAuthError::Credentials(CredentialsError::not_loaded_no_source()).is_retryable()
);
assert!(
!AwsAuthError::Credentials(CredentialsError::invalid_configuration("bad profile"))
.is_retryable()
);
assert!(
!AwsAuthError::Credentials(CredentialsError::unhandled("unexpected response"))
.is_retryable()
);
}
#[tokio::test]
async fn sign_includes_session_token_when_credentials_have_one() {
let signed = test_context(Some("session-token"))
.sign_at(
test_request(),
UNIX_EPOCH + Duration::from_secs(1_700_000_000),
)
.await
.expect("request should sign");
assert_eq!(
signing::header_value(&signed.headers, "x-amz-security-token"),
Some("session-token".to_string())
);
}
#[tokio::test]
async fn load_rejects_empty_service_name() {
let err = AwsAuthContext::load(AwsAuthConfig {
profile: None,
region: None,
service: " ".to_string(),
})
.await
.expect_err("empty service should be rejected");
assert_eq!(err.to_string(), "AWS service name must not be empty");
}
}
+76
View File
@@ -0,0 +1,76 @@
use std::str::FromStr;
use std::time::SystemTime;
use aws_credential_types::Credentials;
use aws_sigv4::http_request::SignableBody;
use aws_sigv4::http_request::SignableRequest;
use aws_sigv4::http_request::SigningSettings;
use aws_sigv4::http_request::sign;
use aws_sigv4::sign::v4;
use http::Request;
use http::Uri;
use crate::AwsAuthError;
use crate::AwsRequestToSign;
use crate::AwsSignedRequest;
pub(crate) fn sign_request(
credentials: &Credentials,
region: &str,
service: &str,
request: AwsRequestToSign,
time: SystemTime,
) -> Result<AwsSignedRequest, AwsAuthError> {
let signable_headers = request
.headers
.iter()
.map(|(name, value)| {
Ok::<_, AwsAuthError>((
name.as_str(),
value.to_str().map_err(AwsAuthError::InvalidHeaderValue)?,
))
})
.collect::<Result<Vec<_>, _>>()?;
let signable_request = SignableRequest::new(
request.method.as_str(),
request.url.as_str(),
signable_headers.into_iter(),
SignableBody::Bytes(request.body.as_ref()),
)
.map_err(AwsAuthError::SigningRequest)?;
let identity = credentials.clone().into();
let signing_params = v4::SigningParams::builder()
.identity(&identity)
.region(region)
.name(service)
.time(time)
.settings(SigningSettings::default())
.build()
.map_err(|err| AwsAuthError::SigningParams(err.to_string()))?;
let (instructions, _signature) = sign(signable_request, &signing_params.into())
.map_err(AwsAuthError::SigningFailure)?
.into_parts();
let uri = Uri::from_str(&request.url).map_err(AwsAuthError::InvalidUri)?;
let mut http_request = Request::builder()
.method(request.method)
.uri(uri)
.body(())
.map_err(AwsAuthError::BuildHttpRequest)?;
*http_request.headers_mut() = request.headers;
instructions.apply_to_request_http1x(&mut http_request);
Ok(AwsSignedRequest {
url: http_request.uri().to_string(),
headers: http_request.headers().clone(),
})
}
#[cfg(test)]
pub(crate) fn header_value(headers: &http::HeaderMap, name: &str) -> Option<String> {
headers
.get(name)
.and_then(|value| value.to_str().ok())
.map(str::to_string)
}
+46 -4
View File
@@ -1,13 +1,55 @@
use async_trait::async_trait;
use codex_client::Request;
use codex_client::TransportError;
use http::HeaderMap;
use std::sync::Arc;
/// Adds authentication headers to API requests.
/// Error returned while applying authentication to an outbound request.
#[derive(Debug, thiserror::Error)]
pub enum AuthError {
#[error("request auth build error: {0}")]
Build(String),
#[error("transient auth error: {0}")]
Transient(String),
}
impl From<AuthError> for TransportError {
fn from(error: AuthError) -> Self {
match error {
AuthError::Build(message) => TransportError::Build(message),
AuthError::Transient(message) => TransportError::Network(message),
}
}
}
/// Applies authentication to API requests.
///
/// Implementations should be cheap and non-blocking; any asynchronous
/// refresh or I/O should be handled by higher layers before requests
/// reach this interface.
/// Header-only providers can implement `add_auth_headers`; providers that sign
/// complete requests can override `apply_auth`.
#[async_trait]
pub trait AuthProvider: Send + Sync {
/// Adds any auth headers that are available without request body access.
///
/// Implementations should be cheap and non-blocking. This method is also
/// used by telemetry and non-HTTP request paths.
fn add_auth_headers(&self, headers: &mut HeaderMap);
/// Applies auth to a complete outbound request and returns the request to send.
///
/// The input `request` is moved into this method. Implementations may mutate
/// the owned request, or replace it entirely, before returning.
///
/// Header-only auth providers can rely on the default implementation.
/// Request-signing providers can override this to inspect the final URL,
/// headers, and body bytes before the transport sends the request.
///
/// Callers must always use the returned request as authoritative.
/// If this returns [`AuthError`], the request should not be sent.
async fn apply_auth(&self, request: Request) -> Result<Request, AuthError> {
let mut request = request;
self.add_auth_headers(&mut request.headers);
Ok(request)
}
}
/// Shared auth handle passed through API clients.
+17 -3
View File
@@ -8,6 +8,7 @@ use codex_client::RequestBody;
use codex_client::RequestTelemetry;
use codex_client::Response;
use codex_client::StreamResponse;
use codex_client::TransportError;
use http::HeaderMap;
use http::Method;
use serde_json::Value;
@@ -55,7 +56,6 @@ impl<T: HttpTransport> EndpointSession<T> {
if let Some(body) = body {
req.body = Some(RequestBody::Json(body.clone()));
}
self.auth.add_auth_headers(&mut req.headers);
req
}
@@ -97,7 +97,14 @@ impl<T: HttpTransport> EndpointSession<T> {
self.provider.retry.to_policy(),
self.request_telemetry.clone(),
make_request,
|req| self.transport.execute(req),
|req| {
let auth = self.auth.clone();
let transport = &self.transport;
async move {
let req = auth.apply_auth(req).await.map_err(TransportError::from)?;
transport.execute(req).await
}
},
)
.await?;
@@ -131,7 +138,14 @@ impl<T: HttpTransport> EndpointSession<T> {
self.provider.retry.to_policy(),
self.request_telemetry.clone(),
make_request,
|req| self.transport.stream(req),
|req| {
let auth = self.auth.clone();
let transport = &self.transport;
async move {
let req = auth.apply_auth(req).await.map_err(TransportError::from)?;
transport.stream(req).await
}
},
)
.await?;
+1
View File
@@ -16,6 +16,7 @@ pub use codex_client::ReqwestTransport;
pub use codex_client::TransportError;
pub use crate::api_bridge::map_api_error;
pub use crate::auth::AuthError;
pub use crate::auth::AuthHeaderTelemetry;
pub use crate::auth::AuthProvider;
pub use crate::auth::SharedAuthProvider;
+114
View File
@@ -5,6 +5,8 @@ use std::time::Duration;
use anyhow::Result;
use async_trait::async_trait;
use bytes::Bytes;
use codex_api::ApiError;
use codex_api::AuthError;
use codex_api::AuthProvider;
use codex_api::Compression;
use codex_api::Provider;
@@ -164,6 +166,59 @@ impl FlakyTransport {
}
}
#[derive(Clone)]
struct FailsOnceAuth {
attempts: Arc<Mutex<i64>>,
error: Arc<AuthError>,
}
impl FailsOnceAuth {
fn transient() -> Self {
Self {
attempts: Arc::new(Mutex::new(0)),
error: Arc::new(AuthError::Transient(
"sts temporarily unavailable".to_string(),
)),
}
}
fn build() -> Self {
Self {
attempts: Arc::new(Mutex::new(0)),
error: Arc::new(AuthError::Build("invalid auth configuration".to_string())),
}
}
fn attempts(&self) -> i64 {
*self
.attempts
.lock()
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"))
}
}
#[async_trait]
impl AuthProvider for FailsOnceAuth {
fn add_auth_headers(&self, _headers: &mut HeaderMap) {}
async fn apply_auth(&self, request: Request) -> Result<Request, AuthError> {
let mut attempts = self
.attempts
.lock()
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"));
*attempts += 1;
if *attempts == 1 {
return match self.error.as_ref() {
AuthError::Build(message) => Err(AuthError::Build(message.clone())),
AuthError::Transient(message) => Err(AuthError::Transient(message.clone())),
};
}
Ok(request)
}
}
#[async_trait]
impl HttpTransport for FlakyTransport {
async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
@@ -296,6 +351,65 @@ async fn streaming_client_retries_on_transport_error() -> Result<()> {
Ok(())
}
#[tokio::test]
async fn streaming_client_retries_on_transient_auth_error() -> Result<()> {
let state = RecordingState::default();
let transport = RecordingTransport::new(state.clone());
let auth = FailsOnceAuth::transient();
let mut provider = provider("openai");
provider.retry.max_attempts = 2;
let client = ResponsesClient::new(transport, provider, Arc::new(auth.clone()));
let body = serde_json::json!({ "model": "gpt-test" });
let _stream = client
.stream(
body,
HeaderMap::new(),
Compression::None,
/*turn_state*/ None,
)
.await?;
assert_eq!(auth.attempts(), 2);
assert_eq!(state.take_stream_requests().len(), 1);
Ok(())
}
#[tokio::test]
async fn streaming_client_does_not_retry_auth_build_error() -> Result<()> {
let state = RecordingState::default();
let transport = RecordingTransport::new(state.clone());
let auth = FailsOnceAuth::build();
let mut provider = provider("openai");
provider.retry.max_attempts = 2;
let client = ResponsesClient::new(transport, provider, Arc::new(auth.clone()));
let body = serde_json::json!({ "model": "gpt-test" });
let result = client
.stream(
body,
HeaderMap::new(),
Compression::None,
/*turn_state*/ None,
)
.await;
let err = match result {
Ok(_) => panic!("auth build errors should fail without retry"),
Err(err) => err,
};
assert!(matches!(
err,
ApiError::Transport(TransportError::Build(message))
if message == "invalid auth configuration"
));
assert_eq!(auth.attempts(), 1);
assert_eq!(state.take_stream_requests().len(), 0);
Ok(())
}
#[tokio::test]
async fn azure_default_store_attaches_ids_and_headers() -> Result<()> {
let state = RecordingState::default();
+1
View File
@@ -25,6 +25,7 @@ pub use crate::default_client::CodexHttpClient;
pub use crate::default_client::CodexRequestBuilder;
pub use crate::error::StreamError;
pub use crate::error::TransportError;
pub use crate::request::PreparedRequestBody;
pub use crate::request::Request;
pub use crate::request::RequestBody;
pub use crate::request::RequestCompression;
+142
View File
@@ -1,6 +1,7 @@
use bytes::Bytes;
use http::Method;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderValue;
use serde::Serialize;
use serde_json::Value;
use std::time::Duration;
@@ -27,6 +28,18 @@ impl RequestBody {
}
}
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct PreparedRequestBody {
pub headers: HeaderMap,
pub body: Option<Bytes>,
}
impl PreparedRequestBody {
pub fn body_bytes(&self) -> Bytes {
self.body.clone().unwrap_or_default()
}
}
#[derive(Debug, Clone)]
pub struct Request {
pub method: Method,
@@ -63,6 +76,135 @@ impl Request {
self.compression = compression;
self
}
/// Convert the request body into the exact bytes that will be sent.
///
/// Auth schemes such as AWS SigV4 need to sign the final body bytes, including
/// compression and content headers. Calling this method does not mutate the
/// request.
pub fn prepare_body_for_send(&self) -> Result<PreparedRequestBody, String> {
let mut headers = self.headers.clone();
match self.body.as_ref() {
Some(RequestBody::Raw(raw_body)) => {
if self.compression != RequestCompression::None {
return Err("request compression cannot be used with raw bodies".to_string());
}
Ok(PreparedRequestBody {
headers,
body: Some(raw_body.clone()),
})
}
Some(RequestBody::Json(body)) => {
let json = serde_json::to_vec(&body).map_err(|err| err.to_string())?;
let bytes = if self.compression != RequestCompression::None {
if headers.contains_key(http::header::CONTENT_ENCODING) {
return Err(
"request compression was requested but content-encoding is already set"
.to_string(),
);
}
let pre_compression_bytes = json.len();
let compression_start = std::time::Instant::now();
let (compressed, content_encoding) = match self.compression {
RequestCompression::None => unreachable!("guarded by compression != None"),
RequestCompression::Zstd => (
zstd::stream::encode_all(std::io::Cursor::new(json), 3)
.map_err(|err| err.to_string())?,
HeaderValue::from_static("zstd"),
),
};
let post_compression_bytes = compressed.len();
let compression_duration = compression_start.elapsed();
headers.insert(http::header::CONTENT_ENCODING, content_encoding);
tracing::debug!(
pre_compression_bytes,
post_compression_bytes,
compression_duration_ms = compression_duration.as_millis(),
"Compressed request body with zstd"
);
compressed
} else {
json
};
if !headers.contains_key(http::header::CONTENT_TYPE) {
headers.insert(
http::header::CONTENT_TYPE,
HeaderValue::from_static("application/json"),
);
}
Ok(PreparedRequestBody {
headers,
body: Some(Bytes::from(bytes)),
})
}
None => Ok(PreparedRequestBody {
headers,
body: None,
}),
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use http::HeaderValue;
use pretty_assertions::assert_eq;
use serde_json::json;
#[test]
fn prepare_body_for_send_serializes_json_and_sets_content_type() {
let request = Request::new(Method::POST, "https://example.com/v1/responses".to_string())
.with_json(&json!({"model": "test-model"}));
let prepared = request
.prepare_body_for_send()
.expect("body should prepare");
assert_eq!(
prepared.body,
Some(Bytes::from_static(br#"{"model":"test-model"}"#))
);
assert_eq!(
prepared
.headers
.get(http::header::CONTENT_TYPE)
.and_then(|value| value.to_str().ok()),
Some("application/json")
);
assert_eq!(
request.body,
Some(RequestBody::Json(json!({"model": "test-model"})))
);
assert_eq!(request.compression, RequestCompression::None);
}
#[test]
fn prepare_body_for_send_rejects_existing_content_encoding_when_compressing() {
let mut request =
Request::new(Method::POST, "https://example.com/v1/responses".to_string())
.with_json(&json!({"model": "test-model"}))
.with_compression(RequestCompression::Zstd);
request.headers.insert(
http::header::CONTENT_ENCODING,
HeaderValue::from_static("gzip"),
);
let err = request
.prepare_body_for_send()
.expect_err("conflicting content-encoding should fail");
assert_eq!(
err,
"request compression was requested but content-encoding is already set"
);
}
}
#[derive(Debug, Clone)]
+8 -61
View File
@@ -3,7 +3,6 @@ use crate::default_client::CodexRequestBuilder;
use crate::error::TransportError;
use crate::request::Request;
use crate::request::RequestBody;
use crate::request::RequestCompression;
use crate::request::Response;
use async_trait::async_trait;
use bytes::Bytes;
@@ -43,12 +42,14 @@ impl ReqwestTransport {
}
fn build(&self, req: Request) -> Result<CodexRequestBuilder, TransportError> {
let prepared = req.prepare_body_for_send().map_err(TransportError::Build)?;
let Request {
method,
url,
mut headers,
body,
compression,
headers: _,
body: _,
compression: _,
timeout,
} = req;
@@ -61,63 +62,9 @@ impl ReqwestTransport {
builder = builder.timeout(timeout);
}
match body {
Some(RequestBody::Raw(raw_body)) => {
if compression != RequestCompression::None {
return Err(TransportError::Build(
"request compression cannot be used with raw bodies".to_string(),
));
}
builder = builder.headers(headers).body(raw_body);
}
Some(RequestBody::Json(body)) => {
if compression != RequestCompression::None {
if headers.contains_key(http::header::CONTENT_ENCODING) {
return Err(TransportError::Build(
"request compression was requested but content-encoding is already set"
.to_string(),
));
}
let json = serde_json::to_vec(&body)
.map_err(|err| TransportError::Build(err.to_string()))?;
let pre_compression_bytes = json.len();
let compression_start = std::time::Instant::now();
let (compressed, content_encoding) = match compression {
RequestCompression::None => unreachable!("guarded by compression != None"),
RequestCompression::Zstd => (
zstd::stream::encode_all(std::io::Cursor::new(json), 3)
.map_err(|err| TransportError::Build(err.to_string()))?,
http::HeaderValue::from_static("zstd"),
),
};
let post_compression_bytes = compressed.len();
let compression_duration = compression_start.elapsed();
// Ensure the server knows to unpack the request body.
headers.insert(http::header::CONTENT_ENCODING, content_encoding);
if !headers.contains_key(http::header::CONTENT_TYPE) {
headers.insert(
http::header::CONTENT_TYPE,
http::HeaderValue::from_static("application/json"),
);
}
tracing::info!(
pre_compression_bytes,
post_compression_bytes,
compression_duration_ms = compression_duration.as_millis(),
"Compressed request body with zstd"
);
builder = builder.headers(headers).body(compressed);
} else {
builder = builder.headers(headers).json(&body);
}
}
None => {
builder = builder.headers(headers);
}
builder = builder.headers(prepared.headers);
if let Some(body) = prepared.body {
builder = builder.body(body);
}
Ok(builder)
}
+4
View File
@@ -1031,6 +1031,10 @@
"profile": {
"description": "AWS profile name to use. When unset, the AWS SDK default chain decides.",
"type": "string"
},
"region": {
"description": "AWS region to use for provider-specific endpoints.",
"type": "string"
}
},
"type": "object"
+21 -3
View File
@@ -395,9 +395,10 @@ fn accepts_amazon_bedrock_aws_profile_override() {
r#"
[model_providers.amazon-bedrock.aws]
profile = "codex-bedrock"
region = "us-west-2"
"#,
)
.expect("Amazon Bedrock AWS profile override should deserialize");
.expect("Amazon Bedrock AWS overrides should deserialize");
assert_eq!(
cfg.model_providers
@@ -406,6 +407,13 @@ profile = "codex-bedrock"
.and_then(|aws| aws.profile.as_deref()),
Some("codex-bedrock")
);
assert_eq!(
cfg.model_providers
.get("amazon-bedrock")
.and_then(|provider| provider.aws.as_ref())
.and_then(|aws| aws.region.as_deref()),
Some("us-west-2")
);
}
#[tokio::test]
@@ -416,9 +424,10 @@ model_provider = "amazon-bedrock"
[model_providers.amazon-bedrock.aws]
profile = "codex-bedrock"
region = "us-west-2"
"#,
)
.expect("Amazon Bedrock AWS profile override should deserialize");
.expect("Amazon Bedrock AWS overrides should deserialize");
let config = Config::load_from_base_config_with_overrides(
cfg,
@@ -437,6 +446,14 @@ profile = "codex-bedrock"
.and_then(|aws| aws.profile.as_deref()),
Some("codex-bedrock")
);
assert_eq!(
config
.model_provider
.aws
.as_ref()
.and_then(|aws| aws.region.as_deref()),
Some("us-west-2")
);
}
#[tokio::test]
@@ -453,6 +470,7 @@ supports_websockets = true
[model_providers.amazon-bedrock.aws]
profile = "codex-bedrock"
region = "us-west-2"
"#,
)
.expect("Amazon Bedrock unsupported overrides should deserialize");
@@ -467,7 +485,7 @@ profile = "codex-bedrock"
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
assert!(err.to_string().contains(
"model_providers.amazon-bedrock only supports changing `aws.profile`; other non-default provider fields are not supported"
"model_providers.amazon-bedrock only supports changing `aws.profile` and `aws.region`; other non-default provider fields are not supported"
));
}
+17 -7
View File
@@ -136,6 +136,8 @@ pub struct ModelProviderInfo {
pub struct ModelProviderAwsAuthInfo {
/// AWS profile name to use. When unset, the AWS SDK default chain decides.
pub profile: Option<String>,
/// AWS region to use for provider-specific endpoints.
pub region: Option<String>,
}
impl ModelProviderInfo {
@@ -352,7 +354,10 @@ impl ModelProviderInfo {
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
aws: Some(aws.unwrap_or(ModelProviderAwsAuthInfo { profile: None })),
aws: Some(aws.unwrap_or(ModelProviderAwsAuthInfo {
profile: None,
region: None,
})),
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
@@ -422,7 +427,7 @@ pub fn built_in_model_providers(
///
/// Configured providers extend the built-in set. Built-in providers are not
/// generally overridable, but the built-in Amazon Bedrock provider allows the
/// user to set `aws.profile`.
/// user to set `aws.profile` and `aws.region`.
pub fn merge_configured_model_providers(
mut model_providers: HashMap<String, ModelProviderInfo>,
configured_model_providers: HashMap<String, ModelProviderInfo>,
@@ -433,15 +438,20 @@ pub fn merge_configured_model_providers(
if provider != ModelProviderInfo::default() {
return Err(format!(
"model_providers.{AMAZON_BEDROCK_PROVIDER_ID} only supports changing \
`aws.profile`; other non-default provider fields are not supported"
`aws.profile` and `aws.region`; other non-default provider fields are not supported"
));
}
if let Some(profile) = aws_override.and_then(|aws| aws.profile)
&& let Some(built_in) = model_providers.get_mut(AMAZON_BEDROCK_PROVIDER_ID)
&& let Some(aws) = built_in.aws.as_mut()
if let Some(aws_override) = aws_override
&& let Some(built_in_provider) = model_providers.get_mut(AMAZON_BEDROCK_PROVIDER_ID)
&& let Some(built_in_aws) = built_in_provider.aws.as_mut()
{
aws.profile = Some(profile);
if let Some(profile) = aws_override.profile {
built_in_aws.profile = Some(profile);
}
if let Some(region) = aws_override.region {
built_in_aws.region = Some(region);
}
}
} else {
model_providers.entry(key).or_insert(provider);
@@ -225,6 +225,7 @@ base_url = "https://bedrock.example.com/v1"
[aws]
profile = "codex-bedrock"
region = "us-west-2"
"#;
let provider: ModelProviderInfo = toml::from_str(provider_toml).unwrap();
@@ -233,6 +234,7 @@ profile = "codex-bedrock"
provider.aws,
Some(ModelProviderAwsAuthInfo {
profile: Some("codex-bedrock".to_string()),
region: Some("us-west-2".to_string()),
})
);
}
@@ -248,7 +250,10 @@ fn test_create_amazon_bedrock_provider() {
env_key_instructions: None,
experimental_bearer_token: None,
auth: None,
aws: Some(ModelProviderAwsAuthInfo { profile: None }),
aws: Some(ModelProviderAwsAuthInfo {
profile: None,
region: None,
}),
wire_api: WireApi::Responses,
query_params: None,
http_headers: None,
@@ -304,6 +309,7 @@ fn test_merge_configured_model_providers_applies_amazon_bedrock_profile_override
ModelProviderInfo {
aws: Some(ModelProviderAwsAuthInfo {
profile: Some("codex-bedrock".to_string()),
region: Some("us-west-2".to_string()),
}),
..ModelProviderInfo::default()
},
@@ -315,6 +321,7 @@ fn test_merge_configured_model_providers_applies_amazon_bedrock_profile_override
.expect("Amazon Bedrock provider should be built in")
.aws = Some(ModelProviderAwsAuthInfo {
profile: Some("codex-bedrock".to_string()),
region: Some("us-west-2".to_string()),
});
assert_eq!(
@@ -334,6 +341,7 @@ fn test_merge_configured_model_providers_rejects_amazon_bedrock_non_default_fiel
name: "Custom Bedrock".to_string(),
aws: Some(ModelProviderAwsAuthInfo {
profile: Some("codex-bedrock".to_string()),
region: None,
}),
..ModelProviderInfo::default()
},
@@ -345,7 +353,7 @@ fn test_merge_configured_model_providers_rejects_amazon_bedrock_non_default_fiel
configured_model_providers,
),
Err(
"model_providers.amazon-bedrock only supports changing `aws.profile`; other non-default provider fields are not supported"
"model_providers.amazon-bedrock only supports changing `aws.profile` and `aws.region`; other non-default provider fields are not supported"
.to_string()
)
);
@@ -356,7 +364,10 @@ fn test_merge_configured_model_providers_allows_amazon_bedrock_default_fields()
let configured_model_providers = std::collections::HashMap::from([(
AMAZON_BEDROCK_PROVIDER_ID.to_string(),
ModelProviderInfo {
aws: Some(ModelProviderAwsAuthInfo { profile: None }),
aws: Some(ModelProviderAwsAuthInfo {
profile: None,
region: None,
}),
wire_api: WireApi::Responses,
..ModelProviderInfo::default()
},
@@ -374,7 +385,10 @@ fn test_merge_configured_model_providers_allows_amazon_bedrock_default_fields()
#[test]
fn test_validate_provider_aws_rejects_conflicting_auth() {
let provider = ModelProviderInfo {
aws: Some(ModelProviderAwsAuthInfo { profile: None }),
aws: Some(ModelProviderAwsAuthInfo {
profile: None,
region: None,
}),
env_key: Some("AWS_BEARER_TOKEN_BEDROCK".to_string()),
supports_websockets: false,
..ModelProviderInfo::create_openai_provider(/*base_url*/ None)
@@ -389,7 +403,10 @@ fn test_validate_provider_aws_rejects_conflicting_auth() {
#[test]
fn test_validate_provider_aws_rejects_websockets() {
let provider = ModelProviderInfo {
aws: Some(ModelProviderAwsAuthInfo { profile: None }),
aws: Some(ModelProviderAwsAuthInfo {
profile: None,
region: None,
}),
requires_openai_auth: false,
supports_websockets: true,
..ModelProviderInfo::create_openai_provider(/*base_url*/ None)
+3
View File
@@ -15,10 +15,13 @@ workspace = true
[dependencies]
async-trait = { workspace = true }
codex-api = { workspace = true }
codex-aws-auth = { workspace = true }
codex-client = { workspace = true }
codex-login = { workspace = true }
codex-model-provider-info = { workspace = true }
codex-protocol = { workspace = true }
http = { workspace = true }
tokio = { workspace = true, features = ["sync"] }
[dev-dependencies]
pretty_assertions = { workspace = true }
@@ -0,0 +1,232 @@
use std::sync::Arc;
use codex_api::AuthError;
use codex_api::AuthProvider;
use codex_api::SharedAuthProvider;
use codex_aws_auth::AwsAuthConfig;
use codex_aws_auth::AwsAuthContext;
use codex_aws_auth::AwsAuthError;
use codex_aws_auth::AwsRequestToSign;
use codex_client::Request;
use codex_client::RequestBody;
use codex_client::RequestCompression;
use codex_model_provider_info::ModelProviderAwsAuthInfo;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result;
use http::HeaderMap;
use tokio::sync::OnceCell;
use crate::BearerAuthProvider;
use super::mantle::aws_auth_config;
use super::mantle::region_from_config;
const AWS_BEARER_TOKEN_BEDROCK_ENV_VAR: &str = "AWS_BEARER_TOKEN_BEDROCK";
const LEGACY_SESSION_ID_HEADER: &str = "session_id";
enum BedrockAuthMethod {
EnvBearerToken {
token: String,
region: String,
},
AwsSdkAuth {
config: AwsAuthConfig,
context: AwsAuthContext,
},
}
async fn resolve_auth_method(aws: &ModelProviderAwsAuthInfo) -> Result<BedrockAuthMethod> {
if let Some(token) = bearer_token_from_env() {
let region = bearer_token_region_from_config(aws)?;
return Ok(BedrockAuthMethod::EnvBearerToken { token, region });
}
let config = aws_auth_config(aws);
let context = AwsAuthContext::load(config.clone())
.await
.map_err(aws_auth_error_to_codex_error)?;
Ok(BedrockAuthMethod::AwsSdkAuth { config, context })
}
pub(super) async fn resolve_provider_auth(
aws: &ModelProviderAwsAuthInfo,
) -> Result<SharedAuthProvider> {
match resolve_auth_method(aws).await? {
BedrockAuthMethod::EnvBearerToken { token, .. } => Ok(Arc::new(BearerAuthProvider {
token: Some(token),
account_id: None,
is_fedramp_account: false,
})),
BedrockAuthMethod::AwsSdkAuth { config, context } => Ok(Arc::new(
BedrockMantleSigV4AuthProvider::with_context(config, context),
)),
}
}
pub(super) async fn resolve_region(aws: &ModelProviderAwsAuthInfo) -> Result<String> {
match resolve_auth_method(aws).await? {
BedrockAuthMethod::EnvBearerToken { region, .. } => Ok(region),
BedrockAuthMethod::AwsSdkAuth { context, .. } => Ok(context.region().to_string()),
}
}
fn bearer_token_from_env() -> Option<String> {
std::env::var(AWS_BEARER_TOKEN_BEDROCK_ENV_VAR)
.ok()
.map(|token| token.trim().to_string())
.filter(|token| !token.is_empty())
}
fn bearer_token_region_from_config(aws: &ModelProviderAwsAuthInfo) -> Result<String> {
region_from_config(aws).ok_or_else(|| {
CodexErr::Fatal(
"Amazon Bedrock bearer token auth requires \
`model_providers.amazon-bedrock.aws.region`"
.to_string(),
)
})
}
fn aws_auth_error_to_codex_error(error: AwsAuthError) -> CodexErr {
CodexErr::Fatal(format!("failed to resolve Amazon Bedrock auth: {error}"))
}
fn aws_auth_error_to_auth_error(error: AwsAuthError) -> AuthError {
if error.is_retryable() {
AuthError::Transient(error.to_string())
} else {
AuthError::Build(error.to_string())
}
}
fn remove_headers_not_preserved_by_bedrock_mantle(headers: &mut HeaderMap) {
// The Bedrock Mantle front door does not preserve this legacy OpenAI header
// for SigV4 verification. Signing it makes the richer Codex agent request
// fail even though raw Responses requests work.
headers.remove(LEGACY_SESSION_ID_HEADER);
}
/// AWS SigV4 auth provider for Bedrock Mantle OpenAI-compatible requests.
#[derive(Debug)]
struct BedrockMantleSigV4AuthProvider {
config: AwsAuthConfig,
context: OnceCell<AwsAuthContext>,
}
impl BedrockMantleSigV4AuthProvider {
fn with_context(config: AwsAuthConfig, context: AwsAuthContext) -> Self {
let cell = OnceCell::new();
let _ = cell.set(context);
Self {
config,
context: cell,
}
}
async fn context(&self) -> std::result::Result<&AwsAuthContext, AuthError> {
self.context
.get_or_try_init(|| AwsAuthContext::load(self.config.clone()))
.await
.map_err(aws_auth_error_to_auth_error)
}
}
#[async_trait::async_trait]
impl AuthProvider for BedrockMantleSigV4AuthProvider {
fn add_auth_headers(&self, _headers: &mut HeaderMap) {}
async fn apply_auth(&self, request: Request) -> std::result::Result<Request, AuthError> {
let mut request = request;
remove_headers_not_preserved_by_bedrock_mantle(&mut request.headers);
let prepared = request.prepare_body_for_send().map_err(AuthError::Build)?;
let context = self.context().await?;
let signed = context
.sign(AwsRequestToSign {
method: request.method.clone(),
url: request.url.clone(),
headers: prepared.headers.clone(),
body: prepared.body_bytes(),
})
.await
.map_err(aws_auth_error_to_auth_error)?;
request.url = signed.url;
request.headers = signed.headers;
request.body = prepared.body.map(RequestBody::Raw);
request.compression = RequestCompression::None;
Ok(request)
}
}
#[cfg(test)]
mod tests {
use codex_api::AuthProvider;
use http::HeaderValue;
use pretty_assertions::assert_eq;
use super::*;
#[test]
fn bedrock_bearer_auth_uses_configured_region_and_header() {
let token = "bedrock-api-key-test".to_string();
let region = bearer_token_region_from_config(&ModelProviderAwsAuthInfo {
profile: None,
region: Some(" us-west-2 ".to_string()),
})
.expect("configured region should resolve");
let provider = BearerAuthProvider {
token: Some(token),
account_id: None,
is_fedramp_account: false,
};
let mut headers = http::HeaderMap::new();
provider.add_auth_headers(&mut headers);
assert_eq!(region, "us-west-2");
assert!(
headers
.get(http::header::AUTHORIZATION)
.and_then(|value| value.to_str().ok())
.is_some_and(|value| value.starts_with("Bearer bedrock-api-key-"))
);
}
#[test]
fn bedrock_bearer_auth_rejects_missing_configured_region() {
let err = bearer_token_region_from_config(&ModelProviderAwsAuthInfo {
profile: None,
region: None,
})
.expect_err("missing region should fail");
assert_eq!(
err.to_string(),
"Fatal error: Amazon Bedrock bearer token auth requires \
`model_providers.amazon-bedrock.aws.region`"
);
}
#[test]
fn bedrock_mantle_sigv4_strips_legacy_session_id_header() {
let mut headers = HeaderMap::new();
headers.insert(
LEGACY_SESSION_ID_HEADER,
HeaderValue::from_static("019dae79-15c3-70c3-8736-3219b8602b37"),
);
headers.insert(
"x-client-request-id",
HeaderValue::from_static("request-id"),
);
remove_headers_not_preserved_by_bedrock_mantle(&mut headers);
assert!(!headers.contains_key(LEGACY_SESSION_ID_HEADER));
assert_eq!(
headers
.get("x-client-request-id")
.and_then(|value| value.to_str().ok()),
Some("request-id")
);
}
}
@@ -0,0 +1,101 @@
use codex_aws_auth::AwsAuthConfig;
use codex_model_provider_info::ModelProviderAwsAuthInfo;
use codex_protocol::error::CodexErr;
use codex_protocol::error::Result;
const BEDROCK_MANTLE_SERVICE_NAME: &str = "bedrock-mantle";
const BEDROCK_MANTLE_SUPPORTED_REGIONS: [&str; 12] = [
"us-east-2",
"us-east-1",
"us-west-2",
"ap-southeast-3",
"ap-south-1",
"ap-northeast-1",
"eu-central-1",
"eu-west-1",
"eu-west-2",
"eu-south-1",
"eu-north-1",
"sa-east-1",
];
pub(super) fn aws_auth_config(aws: &ModelProviderAwsAuthInfo) -> AwsAuthConfig {
AwsAuthConfig {
profile: aws.profile.clone(),
region: region_from_config(aws),
service: BEDROCK_MANTLE_SERVICE_NAME.to_string(),
}
}
pub(super) fn region_from_config(aws: &ModelProviderAwsAuthInfo) -> Option<String> {
aws.region
.as_deref()
.map(str::trim)
.filter(|region| !region.is_empty())
.map(str::to_string)
}
pub(super) fn base_url(region: &str) -> Result<String> {
if BEDROCK_MANTLE_SUPPORTED_REGIONS.contains(&region) {
Ok(format!("https://bedrock-mantle.{region}.api.aws/v1"))
} else {
Err(CodexErr::Fatal(format!(
"Amazon Bedrock Mantle does not support region `{region}`"
)))
}
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use super::*;
#[test]
fn base_url_uses_region_endpoint() {
assert_eq!(
base_url("ap-northeast-1").expect("supported region"),
"https://bedrock-mantle.ap-northeast-1.api.aws/v1"
);
}
#[test]
fn base_url_rejects_unsupported_region() {
let err = base_url("us-west-1").expect_err("unsupported region");
assert_eq!(
err.to_string(),
"Fatal error: Amazon Bedrock Mantle does not support region `us-west-1`"
);
}
#[test]
fn aws_auth_config_uses_profile_and_mantle_service() {
assert_eq!(
aws_auth_config(&ModelProviderAwsAuthInfo {
profile: Some("codex-bedrock".to_string()),
region: None,
}),
AwsAuthConfig {
profile: Some("codex-bedrock".to_string()),
region: None,
service: "bedrock-mantle".to_string(),
}
);
}
#[test]
fn aws_auth_config_uses_configured_region() {
assert_eq!(
aws_auth_config(&ModelProviderAwsAuthInfo {
profile: None,
region: Some(" us-west-2 ".to_string()),
}),
AwsAuthConfig {
profile: None,
region: Some("us-west-2".to_string()),
service: "bedrock-mantle".to_string(),
}
);
}
}
@@ -0,0 +1,73 @@
mod auth;
mod mantle;
use std::sync::Arc;
use codex_api::Provider;
use codex_api::SharedAuthProvider;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_model_provider_info::ModelProviderAwsAuthInfo;
use codex_model_provider_info::ModelProviderInfo;
use codex_protocol::error::Result;
use crate::provider::ModelProvider;
use auth::resolve_provider_auth;
use auth::resolve_region;
use mantle::base_url;
/// Runtime provider for Amazon Bedrock's OpenAI-compatible Mantle endpoint.
#[derive(Clone, Debug)]
pub(crate) struct AmazonBedrockModelProvider {
pub(crate) info: ModelProviderInfo,
pub(crate) aws: ModelProviderAwsAuthInfo,
}
#[async_trait::async_trait]
impl ModelProvider for AmazonBedrockModelProvider {
fn info(&self) -> &ModelProviderInfo {
&self.info
}
fn auth_manager(&self) -> Option<Arc<AuthManager>> {
None
}
async fn auth(&self) -> Option<CodexAuth> {
None
}
async fn api_provider(&self) -> Result<Provider> {
let region = resolve_region(&self.aws).await?;
let mut api_provider_info = self.info.clone();
api_provider_info.base_url = Some(base_url(&region)?);
api_provider_info.to_api_provider(/*auth_mode*/ None)
}
async fn api_auth(&self) -> Result<SharedAuthProvider> {
resolve_provider_auth(&self.aws).await
}
}
#[cfg(test)]
mod tests {
use pretty_assertions::assert_eq;
use super::*;
#[test]
fn api_provider_for_bedrock_bearer_token_uses_configured_region_endpoint() {
let region = "eu-central-1";
let mut api_provider_info =
ModelProviderInfo::create_amazon_bedrock_provider(/*aws*/ None);
api_provider_info.base_url = Some(base_url(region).expect("supported region"));
let api_provider = api_provider_info
.to_api_provider(/*auth_mode*/ None)
.expect("api provider should build");
assert_eq!(
api_provider.base_url,
"https://bedrock-mantle.eu-central-1.api.aws/v1"
);
}
}
+1
View File
@@ -1,3 +1,4 @@
mod amazon_bedrock;
mod auth;
mod bearer_auth_provider;
mod provider;
+32
View File
@@ -5,8 +5,10 @@ use codex_api::Provider;
use codex_api::SharedAuthProvider;
use codex_login::AuthManager;
use codex_login::CodexAuth;
use codex_model_provider_info::ModelProviderAwsAuthInfo;
use codex_model_provider_info::ModelProviderInfo;
use crate::amazon_bedrock::AmazonBedrockModelProvider;
use crate::auth::auth_manager_for_provider;
use crate::auth::resolve_provider_auth;
@@ -53,6 +55,20 @@ pub fn create_model_provider(
provider_info: ModelProviderInfo,
auth_manager: Option<Arc<AuthManager>>,
) -> SharedModelProvider {
if provider_info.is_amazon_bedrock() {
let aws = provider_info
.aws
.clone()
.unwrap_or(ModelProviderAwsAuthInfo {
profile: None,
region: None,
});
return Arc::new(AmazonBedrockModelProvider {
info: provider_info,
aws,
});
}
let auth_manager = auth_manager_for_provider(auth_manager, &provider_info);
Arc::new(ConfiguredModelProvider {
info: provider_info,
@@ -89,6 +105,7 @@ impl ModelProvider for ConfiguredModelProvider {
mod tests {
use std::num::NonZeroU64;
use codex_model_provider_info::ModelProviderAwsAuthInfo;
use codex_protocol::config_types::ModelProviderAuthInfo;
use super::*;
@@ -123,4 +140,19 @@ mod tests {
assert!(auth_manager.has_external_auth());
}
#[test]
fn create_model_provider_does_not_use_openai_auth_manager_for_amazon_bedrock_provider() {
let provider = create_model_provider(
ModelProviderInfo::create_amazon_bedrock_provider(Some(ModelProviderAwsAuthInfo {
profile: Some("codex-bedrock".to_string()),
region: None,
})),
Some(AuthManager::from_auth_for_testing(CodexAuth::from_api_key(
"openai-api-key",
))),
);
assert!(provider.auth_manager().is_none());
}
}