diff --git a/codex-rs/app-server-transport/src/transport/remote_control/auth.rs b/codex-rs/app-server-transport/src/transport/remote_control/auth.rs index dbd231027..ec4eb8131 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/auth.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/auth.rs @@ -1,3 +1,5 @@ +use axum::http::HeaderMap; +use axum::http::HeaderValue; use codex_api::SharedAuthProvider; use codex_login::AuthManager; use codex_login::UnauthorizedRecovery; @@ -8,11 +10,30 @@ use tokio::sync::watch; use tracing::info; use tracing::warn; +pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; + pub(super) struct RemoteControlConnectionAuth { pub(super) auth_provider: SharedAuthProvider, pub(super) account_id: String, } +impl RemoteControlConnectionAuth { + pub(super) fn request_headers(&self) -> io::Result { + let mut headers = HeaderMap::new(); + self.auth_provider.add_auth_headers(&mut headers); + headers.insert( + REMOTE_CONTROL_ACCOUNT_ID_HEADER, + HeaderValue::from_str(&self.account_id).map_err(|err| { + io::Error::new( + ErrorKind::InvalidInput, + format!("invalid remote control account id header: {err}"), + ) + })?, + ); + Ok(headers) + } +} + pub(super) async fn load_remote_control_auth( auth_manager: &Arc, ) -> io::Result { @@ -103,3 +124,100 @@ pub(super) fn mark_recovery_auth_change_seen( auth_change_rx.borrow_and_update(); } } + +#[cfg(test)] +mod tests { + use super::*; + use codex_api::AuthProvider; + use pretty_assertions::assert_eq; + + #[derive(Debug)] + struct TestAuthProvider { + account_ids: Vec<&'static str>, + } + + impl AuthProvider for TestAuthProvider { + fn add_auth_headers(&self, headers: &mut HeaderMap) { + headers.insert( + axum::http::header::AUTHORIZATION, + HeaderValue::from_static("Bearer test-token"), + ); + headers.insert("x-openai-fedramp", HeaderValue::from_static("true")); + for account_id in &self.account_ids { + headers.append("ChatGPT-Account-ID", HeaderValue::from_static(account_id)); + } + } + } + + fn remote_control_auth( + account_id: &str, + provider_account_ids: Vec<&'static str>, + ) -> RemoteControlConnectionAuth { + RemoteControlConnectionAuth { + auth_provider: Arc::new(TestAuthProvider { + account_ids: provider_account_ids, + }), + account_id: account_id.to_string(), + } + } + + #[test] + fn request_headers_adds_account_header_when_provider_omits_it() { + let headers = remote_control_auth("selected-account", Vec::new()) + .request_headers() + .expect("request headers should build"); + + assert_eq!( + headers + .get_all(REMOTE_CONTROL_ACCOUNT_ID_HEADER) + .iter() + .map(|value| value.to_str().expect("account header should be text")) + .collect::>(), + vec!["selected-account"] + ); + } + + #[test] + fn request_headers_replaces_provider_accounts_and_preserves_other_headers() { + let headers = remote_control_auth( + "selected-account", + vec!["provider-account-a", "provider-account-b"], + ) + .request_headers() + .expect("request headers should build"); + + assert_eq!( + headers + .get_all(REMOTE_CONTROL_ACCOUNT_ID_HEADER) + .iter() + .map(|value| value.to_str().expect("account header should be text")) + .collect::>(), + vec!["selected-account"] + ); + assert_eq!( + headers + .get(axum::http::header::AUTHORIZATION) + .and_then(|value| value.to_str().ok()), + Some("Bearer test-token") + ); + assert_eq!( + headers + .get("x-openai-fedramp") + .and_then(|value| value.to_str().ok()), + Some("true") + ); + } + + #[test] + fn request_headers_rejects_invalid_account_header_value() { + let err = remote_control_auth("invalid\naccount", Vec::new()) + .request_headers() + .expect_err("invalid account header should fail"); + + assert_eq!(err.kind(), ErrorKind::InvalidInput); + assert!( + err.to_string() + .starts_with("invalid remote control account id header:") + ); + } +} diff --git a/codex-rs/app-server-transport/src/transport/remote_control/clients.rs b/codex-rs/app-server-transport/src/transport/remote_control/clients.rs index 25316c609..121a7685d 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/clients.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/clients.rs @@ -1,7 +1,6 @@ use super::auth::RemoteControlConnectionAuth; use super::auth::load_remote_control_auth; use super::auth::recover_remote_control_auth; -use super::enroll::REMOTE_CONTROL_ACCOUNT_ID_HEADER; use super::enroll::format_headers; use super::enroll::preview_remote_control_response_body; use super::protocol::normalize_remote_control_base_url; @@ -188,8 +187,7 @@ async fn send_client_management_request_once( action: &str, ) -> io::Result { let client = build_reqwest_client(); - let mut auth_headers = HeaderMap::new(); - auth.auth_provider.add_auth_headers(&mut auth_headers); + let auth_headers = auth.request_headers()?; let request = match request { ClientManagementRequest::List { url, params } => { let mut query = Vec::new(); @@ -216,7 +214,6 @@ async fn send_client_management_request_once( let response = request .timeout(REMOTE_CONTROL_CLIENT_MANAGEMENT_TIMEOUT) .headers(auth_headers) - .header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, &auth.account_id) .send() .await .map_err(|err| io::Error::other(format!("failed to {action}: {err}")))?; diff --git a/codex-rs/app-server-transport/src/transport/remote_control/enroll.rs b/codex-rs/app-server-transport/src/transport/remote_control/enroll.rs index 3d6c74cce..e327503a9 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/enroll.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/enroll.rs @@ -31,7 +31,6 @@ const REMOTE_CONTROL_SERVER_TOKEN_REFRESH_SKEW_SECS: i64 = 30; const REQUEST_ID_HEADER: &str = "x-request-id"; const OAI_REQUEST_ID_HEADER: &str = "x-oai-request-id"; const CF_RAY_HEADER: &str = "cf-ray"; -pub(super) const REMOTE_CONTROL_ACCOUNT_ID_HEADER: &str = "chatgpt-account-id"; pub(super) const REMOTE_CONTROL_INSTALLATION_ID_HEADER: &str = "x-codex-installation-id"; #[derive(Debug, Clone, PartialEq, Eq)] @@ -492,13 +491,11 @@ where Response: DeserializeOwned, { let client = build_reqwest_client(); - let mut auth_headers = HeaderMap::new(); - auth.auth_provider.add_auth_headers(&mut auth_headers); + let auth_headers = auth.request_headers()?; let response = client .post(url) .timeout(REMOTE_CONTROL_ENROLL_TIMEOUT) .headers(auth_headers) - .header(REMOTE_CONTROL_ACCOUNT_ID_HEADER, &auth.account_id) .header(REMOTE_CONTROL_INSTALLATION_ID_HEADER, installation_id) .json(request) .send() diff --git a/codex-rs/app-server-transport/src/transport/remote_control/tests.rs b/codex-rs/app-server-transport/src/transport/remote_control/tests.rs index a506b1884..3a2d84465 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/tests.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/tests.rs @@ -1,4 +1,4 @@ -use super::enroll::REMOTE_CONTROL_ACCOUNT_ID_HEADER; +use super::auth::REMOTE_CONTROL_ACCOUNT_ID_HEADER; use super::enroll::REMOTE_CONTROL_INSTALLATION_ID_HEADER; use super::enroll::RemoteControlEnrollment; use super::enroll::load_persisted_remote_control_enrollment; @@ -1473,14 +1473,16 @@ async fn remote_control_http_mode_enrolls_before_connecting() { Some(&"Bearer Access Token".to_string()) ); assert_eq!( - enroll_request.headers.get(REMOTE_CONTROL_ACCOUNT_ID_HEADER), - Some(&"account_id".to_string()) + enroll_request + .headers + .get_all(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + vec!["account_id"] ); assert_eq!( enroll_request .headers - .get(REMOTE_CONTROL_INSTALLATION_ID_HEADER), - Some(&TEST_INSTALLATION_ID.to_string()) + .get_all(REMOTE_CONTROL_INSTALLATION_ID_HEADER), + vec![TEST_INSTALLATION_ID] ); assert_eq!( serde_json::from_str::(&enroll_request.body) @@ -1720,6 +1722,18 @@ async fn remote_control_http_mode_refreshes_persisted_enrollment_before_connecti refresh_request.headers.get("authorization"), Some(&"Bearer Access Token".to_string()) ); + assert_eq!( + refresh_request + .headers + .get_all(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + vec!["account_id"] + ); + assert_eq!( + refresh_request + .headers + .get_all(REMOTE_CONTROL_INSTALLATION_ID_HEADER), + vec![TEST_INSTALLATION_ID] + ); assert_eq!( serde_json::from_str::(&refresh_request.body) .expect("refresh body should deserialize"), @@ -2578,10 +2592,35 @@ async fn remote_control_http_mode_preserves_enrollment_after_generic_websocket_4 struct CapturedHttpRequest { stream: TcpStream, request_line: String, - headers: BTreeMap, + headers: CapturedHttpHeaders, body: String, } +#[derive(Debug, Default)] +struct CapturedHttpHeaders(Vec<(String, String)>); + +impl CapturedHttpHeaders { + fn append(&mut self, name: String, value: String) { + self.0.push((name, value)); + } + + fn get(&self, name: &str) -> Option<&String> { + self.0 + .iter() + .rev() + .find(|(candidate, _value)| candidate.eq_ignore_ascii_case(name)) + .map(|(_name, value)| value) + } + + fn get_all(&self, name: &str) -> Vec<&str> { + self.0 + .iter() + .filter(|(candidate, _value)| candidate.eq_ignore_ascii_case(name)) + .map(|(_name, value)| value.as_str()) + .collect() + } +} + #[derive(Clone, Debug, PartialEq, Eq)] struct CapturedWebSocketRequest { path: String, @@ -2612,7 +2651,7 @@ async fn accept_http_request(listener: &TcpListener) -> CapturedHttpRequest { .expect("request line should read"); let request_line = request_line.trim_end_matches("\r\n").to_string(); - let mut headers = BTreeMap::new(); + let mut headers = CapturedHttpHeaders::default(); loop { let mut line = String::new(); reader @@ -2624,7 +2663,7 @@ async fn accept_http_request(listener: &TcpListener) -> CapturedHttpRequest { } let line = line.trim_end_matches("\r\n"); let (name, value) = line.split_once(':').expect("header should contain colon"); - headers.insert(name.to_ascii_lowercase(), value.trim().to_string()); + headers.append(name.to_ascii_lowercase(), value.trim().to_string()); } let content_length = headers diff --git a/codex-rs/app-server-transport/src/transport/remote_control/tests/clients_tests.rs b/codex-rs/app-server-transport/src/transport/remote_control/tests/clients_tests.rs index 21ec02ab0..fafbed3f7 100644 --- a/codex-rs/app-server-transport/src/transport/remote_control/tests/clients_tests.rs +++ b/codex-rs/app-server-transport/src/transport/remote_control/tests/clients_tests.rs @@ -60,8 +60,8 @@ async fn remote_control_handle_lists_clients_while_disabled() { Some(&"Bearer Access Token".to_string()) ); assert_eq!( - request.headers.get(REMOTE_CONTROL_ACCOUNT_ID_HEADER), - Some(&"account_id".to_string()) + request.headers.get_all(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + vec!["account_id"] ); respond_with_json( request.stream, @@ -127,6 +127,14 @@ async fn remote_control_handle_revokes_client_while_disabled() { request.request_line, "DELETE /backend-api/wham/remote/control/environments/env%20%2F%3F/clients/client%20%2F%3F HTTP/1.1" ); + assert_eq!( + request.headers.get("authorization"), + Some(&"Bearer Access Token".to_string()) + ); + assert_eq!( + request.headers.get_all(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + vec!["account_id"] + ); respond_with_status(request.stream, "204 No Content", "").await; }); let handle = client_management_handle(remote_control_url, remote_control_auth_manager()); @@ -155,6 +163,12 @@ async fn list_remote_control_clients_recovers_auth_after_unauthorized() { stale_request.headers.get("authorization"), Some(&"Bearer stale-token".to_string()) ); + assert_eq!( + stale_request + .headers + .get_all(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + vec!["account_id"] + ); respond_with_status(stale_request.stream, "401 Unauthorized", "").await; let recovered_request = accept_http_request(&listener).await; @@ -162,6 +176,12 @@ async fn list_remote_control_clients_recovers_auth_after_unauthorized() { recovered_request.headers.get("authorization"), Some(&"Bearer fresh-token".to_string()) ); + assert_eq!( + recovered_request + .headers + .get_all(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + vec!["account_id"] + ); respond_with_json(recovered_request.stream, empty_client_list()).await; }); let codex_home = TempDir::new().expect("temp dir should create"); @@ -235,6 +255,12 @@ async fn list_remote_control_clients_retries_unauthorized_only_once() { stale_request.headers.get("authorization"), Some(&"Bearer stale-token".to_string()) ); + assert_eq!( + stale_request + .headers + .get_all(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + vec!["account_id"] + ); respond_with_status(stale_request.stream, "401 Unauthorized", "").await; let recovered_request = accept_http_request(&listener).await; @@ -242,6 +268,12 @@ async fn list_remote_control_clients_retries_unauthorized_only_once() { recovered_request.headers.get("authorization"), Some(&"Bearer fresh-token".to_string()) ); + assert_eq!( + recovered_request + .headers + .get_all(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + vec!["account_id"] + ); respond_with_status(recovered_request.stream, "401 Unauthorized", "").await; assert!( @@ -311,6 +343,14 @@ async fn revoke_remote_control_client_does_not_retry_forbidden() { let remote_control_url = remote_control_url_for_listener(&listener); let server_task = tokio::spawn(async move { let request = accept_http_request(&listener).await; + assert_eq!( + request.headers.get("authorization"), + Some(&"Bearer Access Token".to_string()) + ); + assert_eq!( + request.headers.get_all(REMOTE_CONTROL_ACCOUNT_ID_HEADER), + vec!["account_id"] + ); respond_with_status_and_headers( request.stream, "403 Forbidden",