Files
codex/codex-rs/codex-api/tests/clients.rs
T
Owen Lin 040dafa32d feat(core): add metadata field to ResponseItem (#28355)
## Description

This PR adds an optional `metadata` field to `ResponseItem` for
Responses API calls. Only mechanical plumbing, no actual values
populated and sent yet. Turns out just adding a new field to
`ResponseItem` has quite a large blast radius already.

This change is backwards compatible because `metadata` is optional and
omitted when absent, so existing response items and rollout history
without it still deserialize and requests that do not set it keep the
same wire shape. For provider compatibility, we strip out `metadata`
before non-OpenAI Responses requests so Azure and AWS Bedrock never see
this field.

My followup PR here will actually make use of it to start storing and
passing along `turn_id`: https://github.com/openai/codex/pull/28360

## What changed

- Added `ResponseItemMetadata` with optional `turn_id`, plus optional
`metadata` on Responses API item variants and inter-agent communication.
- Preserved item metadata through response-item rewrites such as
truncation, missing tool-output synthesis, compaction history
rebuilding, visible-history conversion, rollout/resume, and generated
app-server schemas/types.
- Strip item metadata from non-OpenAI Responses requests while
preserving it for OpenAI-shaped requests.
- Updated the mechanical fixture/test construction churn required by the
new optional field.
2026-06-15 15:05:28 -07:00

589 lines
17 KiB
Rust

use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use anyhow::Result;
use bytes::Bytes;
use codex_api::ApiError;
use codex_api::AuthError;
use codex_api::AuthProvider;
use codex_api::Compression;
use codex_api::Provider;
use codex_api::ResponsesApiRequest;
use codex_api::ResponsesClient;
use codex_api::ResponsesOptions;
use codex_client::HttpTransport;
use codex_client::Request;
use codex_client::RequestBody;
use codex_client::Response;
use codex_client::StreamResponse;
use codex_client::TransportError;
use codex_protocol::models::ContentItem;
use codex_protocol::models::ResponseItem;
use codex_protocol::protocol::SessionSource;
use codex_protocol::protocol::SubAgentSource;
use http::HeaderMap;
use http::HeaderValue;
use http::StatusCode;
use pretty_assertions::assert_eq;
fn assert_path_ends_with(requests: &[Request], suffix: &str) {
assert_eq!(requests.len(), 1);
let url = &requests[0].url;
assert!(
url.ends_with(suffix),
"expected url to end with {suffix}, got {url}"
);
}
fn request_body_bytes(request: &Request) -> &[u8] {
let Some(RequestBody::EncodedJson(body)) = request.body.as_ref() else {
panic!("expected a prepared request body");
};
body.as_bytes()
}
#[derive(Debug, Default, Clone)]
struct RecordingState {
stream_requests: Arc<Mutex<Vec<Request>>>,
}
impl RecordingState {
fn record(&self, req: Request) {
let mut guard = self
.stream_requests
.lock()
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"));
guard.push(req);
}
fn take_stream_requests(&self) -> Vec<Request> {
let mut guard = self
.stream_requests
.lock()
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"));
std::mem::take(&mut *guard)
}
}
#[derive(Clone)]
struct RecordingTransport {
state: RecordingState,
}
impl RecordingTransport {
fn new(state: RecordingState) -> Self {
Self { state }
}
}
impl HttpTransport for RecordingTransport {
async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
Err(TransportError::Build("execute should not run".to_string()))
}
async fn stream(&self, req: Request) -> Result<StreamResponse, TransportError> {
self.state.record(req);
let stream = futures::stream::iter(Vec::<Result<Bytes, TransportError>>::new());
Ok(StreamResponse {
status: StatusCode::OK,
headers: HeaderMap::new(),
bytes: Box::pin(stream),
})
}
}
#[derive(Clone, Default)]
struct NoAuth;
impl AuthProvider for NoAuth {
fn add_auth_headers(&self, _headers: &mut HeaderMap) {}
}
#[derive(Clone)]
struct StaticAuth {
token: String,
account_id: String,
}
impl StaticAuth {
fn new(token: &str, account_id: &str) -> Self {
Self {
token: token.to_string(),
account_id: account_id.to_string(),
}
}
}
impl AuthProvider for StaticAuth {
fn add_auth_headers(&self, headers: &mut HeaderMap) {
let token = &self.token;
if let Ok(header) = HeaderValue::from_str(&format!("Bearer {token}")) {
headers.insert(http::header::AUTHORIZATION, header);
}
if let Ok(header) = HeaderValue::from_str(&self.account_id) {
headers.insert("ChatGPT-Account-ID", header);
}
}
}
fn provider(name: &str) -> Provider {
Provider {
name: name.to_string(),
base_url: "https://example.com/v1".to_string(),
query_params: None,
headers: HeaderMap::new(),
retry: codex_api::RetryConfig {
max_attempts: 1,
base_delay: Duration::from_millis(1),
retry_429: false,
retry_5xx: false,
retry_transport: true,
},
stream_idle_timeout: Duration::from_millis(10),
}
}
#[derive(Debug, Default)]
struct FlakyTransportState {
attempts: i64,
requests: Vec<(RequestBody, HeaderMap, codex_client::RequestCompression)>,
}
#[derive(Clone)]
struct FlakyTransport {
state: Arc<Mutex<FlakyTransportState>>,
}
impl Default for FlakyTransport {
fn default() -> Self {
Self::new()
}
}
impl FlakyTransport {
fn new() -> Self {
Self {
state: Arc::new(Mutex::new(FlakyTransportState::default())),
}
}
fn attempts(&self) -> i64 {
self.state
.lock()
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"))
.attempts
}
fn requests(&self) -> Vec<(RequestBody, HeaderMap, codex_client::RequestCompression)> {
self.state
.lock()
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"))
.requests
.clone()
}
}
#[derive(Clone)]
struct FailsOnceAuth {
attempts: Arc<Mutex<i64>>,
error: Arc<AuthError>,
}
impl FailsOnceAuth {
fn transient() -> Self {
Self {
attempts: Arc::new(Mutex::new(0)),
error: Arc::new(AuthError::Transient(
"sts temporarily unavailable".to_string(),
)),
}
}
fn build() -> Self {
Self {
attempts: Arc::new(Mutex::new(0)),
error: Arc::new(AuthError::Build("invalid auth configuration".to_string())),
}
}
fn attempts(&self) -> i64 {
*self
.attempts
.lock()
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"))
}
async fn apply_auth(&self, request: Request) -> Result<Request, AuthError> {
let mut attempts = self
.attempts
.lock()
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"));
*attempts += 1;
if *attempts == 1 {
return match self.error.as_ref() {
AuthError::Build(message) => Err(AuthError::Build(message.clone())),
AuthError::Transient(message) => Err(AuthError::Transient(message.clone())),
};
}
Ok(request)
}
}
impl AuthProvider for FailsOnceAuth {
fn add_auth_headers(&self, _headers: &mut HeaderMap) {}
fn apply_auth(&self, request: Request) -> codex_api::AuthProviderFuture<'_> {
Box::pin(FailsOnceAuth::apply_auth(self, request))
}
}
impl HttpTransport for FlakyTransport {
async fn execute(&self, _req: Request) -> Result<Response, TransportError> {
Err(TransportError::Build("execute should not run".to_string()))
}
async fn stream(&self, req: Request) -> Result<StreamResponse, TransportError> {
let Some(body) = req.body.clone() else {
panic!("request should have a body");
};
let mut state = self
.state
.lock()
.unwrap_or_else(|err| panic!("mutex poisoned: {err}"));
state.attempts += 1;
state
.requests
.push((body, req.headers.clone(), req.compression));
if state.attempts == 1 {
return Err(TransportError::Network("first attempt fails".to_string()));
}
let stream = futures::stream::iter(vec![Ok(Bytes::from(
r#"event: message
data: {"id":"resp-1","output":[{"type":"message","role":"assistant","content":[{"type":"output_text","text":"hi"}]}]}
"#,
))]);
Ok(StreamResponse {
status: StatusCode::OK,
headers: HeaderMap::new(),
bytes: Box::pin(stream),
})
}
}
#[tokio::test]
async fn responses_client_uses_responses_path() -> Result<()> {
let state = RecordingState::default();
let transport = RecordingTransport::new(state.clone());
let client = ResponsesClient::new(transport, provider("openai"), Arc::new(NoAuth));
let body = serde_json::json!({ "echo": true });
let _stream = client
.stream(
body,
HeaderMap::new(),
Compression::None,
/*turn_state*/ None,
)
.await?;
let requests = state.take_stream_requests();
assert_path_ends_with(&requests, "/responses");
Ok(())
}
#[tokio::test]
async fn responses_client_stream_request_preserves_exact_json_body() -> Result<()> {
let state = RecordingState::default();
let transport = RecordingTransport::new(state.clone());
let client = ResponsesClient::new(transport, provider("openai"), Arc::new(NoAuth));
let request = ResponsesApiRequest {
model: "gpt-test".into(),
instructions: "Say hi".into(),
input: vec![ResponseItem::Message {
id: Some("msg_1".into()),
role: "user".into(),
content: vec![ContentItem::InputText { text: "hi".into() }],
phase: None,
}],
tools: Vec::new(),
tool_choice: "auto".into(),
parallel_tool_calls: false,
reasoning: None,
store: false,
stream: true,
include: Vec::new(),
service_tier: None,
prompt_cache_key: None,
text: None,
client_metadata: None,
};
let expected = serde_json::to_vec(&request)?;
let _stream = client
.stream_request(request, ResponsesOptions::default())
.await?;
let requests = state.take_stream_requests();
assert_eq!(requests.len(), 1);
let prepared = requests[0]
.prepare_body_for_send()
.expect("body should prepare");
assert_eq!(prepared.body.as_deref(), Some(expected.as_slice()));
assert_eq!(
prepared.headers.get(http::header::CONTENT_TYPE),
Some(&HeaderValue::from_static("application/json"))
);
Ok(())
}
#[tokio::test]
async fn streaming_client_adds_auth_headers() -> Result<()> {
let state = RecordingState::default();
let transport = RecordingTransport::new(state.clone());
let auth = Arc::new(StaticAuth::new("secret-token", "acct-1"));
let client = ResponsesClient::new(transport, provider("openai"), auth);
let body = serde_json::json!({ "model": "gpt-test" });
let _stream = client
.stream(
body,
HeaderMap::new(),
Compression::None,
/*turn_state*/ None,
)
.await?;
let requests = state.take_stream_requests();
assert_eq!(requests.len(), 1);
let req = &requests[0];
let auth_header = req.headers.get(http::header::AUTHORIZATION);
assert!(auth_header.is_some(), "missing auth header");
assert_eq!(
auth_header.unwrap().to_str().ok(),
Some("Bearer secret-token")
);
let account_header = req.headers.get("ChatGPT-Account-ID");
assert!(account_header.is_some(), "missing account header");
assert_eq!(account_header.unwrap().to_str().ok(), Some("acct-1"));
let accept_header = req.headers.get(http::header::ACCEPT);
assert!(accept_header.is_some(), "missing Accept header");
assert_eq!(
accept_header.unwrap().to_str().ok(),
Some("text/event-stream")
);
Ok(())
}
#[tokio::test]
async fn streaming_client_retries_on_transport_error() -> Result<()> {
let transport = FlakyTransport::new();
let mut provider = provider("openai");
provider.retry.max_attempts = 2;
let request = ResponsesApiRequest {
model: "gpt-test".into(),
instructions: "Say hi".into(),
input: Vec::new(),
tools: Vec::new(),
tool_choice: "auto".into(),
parallel_tool_calls: false,
reasoning: None,
store: false,
stream: true,
include: Vec::new(),
service_tier: None,
prompt_cache_key: None,
text: None,
client_metadata: None,
};
let client = ResponsesClient::new(transport.clone(), provider, Arc::new(NoAuth));
let _stream = client
.stream_request(
request,
ResponsesOptions {
compression: Compression::Zstd,
..Default::default()
},
)
.await?;
assert_eq!(transport.attempts(), 2);
let requests = transport.requests();
assert_eq!(requests.len(), 2);
assert_eq!(requests[0], requests[1]);
let RequestBody::EncodedJson(first_body) = &requests[0].0 else {
panic!("expected an encoded JSON body");
};
let RequestBody::EncodedJson(second_body) = &requests[1].0 else {
panic!("expected an encoded JSON body");
};
assert_eq!(
first_body.as_bytes().as_ptr(),
second_body.as_bytes().as_ptr()
);
assert_eq!(
requests[0].1.get(http::header::CONTENT_ENCODING),
Some(&HeaderValue::from_static("zstd"))
);
assert_eq!(requests[0].2, codex_client::RequestCompression::None);
Ok(())
}
#[tokio::test]
async fn streaming_client_retries_on_transient_auth_error() -> Result<()> {
let state = RecordingState::default();
let transport = RecordingTransport::new(state.clone());
let auth = FailsOnceAuth::transient();
let mut provider = provider("openai");
provider.retry.max_attempts = 2;
let client = ResponsesClient::new(transport, provider, Arc::new(auth.clone()));
let body = serde_json::json!({ "model": "gpt-test" });
let _stream = client
.stream(
body,
HeaderMap::new(),
Compression::None,
/*turn_state*/ None,
)
.await?;
assert_eq!(auth.attempts(), 2);
assert_eq!(state.take_stream_requests().len(), 1);
Ok(())
}
#[tokio::test]
async fn streaming_client_does_not_retry_auth_build_error() -> Result<()> {
let state = RecordingState::default();
let transport = RecordingTransport::new(state.clone());
let auth = FailsOnceAuth::build();
let mut provider = provider("openai");
provider.retry.max_attempts = 2;
let client = ResponsesClient::new(transport, provider, Arc::new(auth.clone()));
let body = serde_json::json!({ "model": "gpt-test" });
let result = client
.stream(
body,
HeaderMap::new(),
Compression::None,
/*turn_state*/ None,
)
.await;
let err = match result {
Ok(_) => panic!("auth build errors should fail without retry"),
Err(err) => err,
};
assert!(matches!(
err,
ApiError::Transport(TransportError::Build(message))
if message == "invalid auth configuration"
));
assert_eq!(auth.attempts(), 1);
assert_eq!(state.take_stream_requests().len(), 0);
Ok(())
}
#[tokio::test]
async fn azure_default_store_attaches_ids_and_headers() -> Result<()> {
let state = RecordingState::default();
let transport = RecordingTransport::new(state.clone());
let client = ResponsesClient::new(transport, provider("azure"), Arc::new(NoAuth));
let request = ResponsesApiRequest {
model: "gpt-test".into(),
instructions: "Say hi".into(),
input: vec![ResponseItem::Message {
id: Some("msg_1".into()),
role: "user".into(),
content: vec![ContentItem::InputText { text: "hi".into() }],
phase: None,
metadata: None,
}],
tools: Vec::new(),
tool_choice: "auto".into(),
parallel_tool_calls: false,
reasoning: None,
store: true,
stream: true,
include: Vec::new(),
service_tier: None,
prompt_cache_key: None,
text: None,
client_metadata: None,
};
let mut extra_headers = HeaderMap::new();
extra_headers.insert("x-test-header", HeaderValue::from_static("present"));
let _stream = client
.stream_request(
request,
ResponsesOptions {
session_id: Some("sess_123".into()),
thread_id: Some("thread_123".into()),
session_source: Some(SessionSource::SubAgent(SubAgentSource::Review)),
extra_headers,
compression: Compression::None,
turn_state: None,
},
)
.await?;
let requests = state.take_stream_requests();
assert_eq!(requests.len(), 1);
let req = &requests[0];
assert_eq!(
req.headers.get("session-id").and_then(|v| v.to_str().ok()),
Some("sess_123")
);
assert_eq!(
req.headers.get("thread-id").and_then(|v| v.to_str().ok()),
Some("thread_123")
);
assert_eq!(
req.headers
.get("x-client-request-id")
.and_then(|v| v.to_str().ok()),
Some("thread_123")
);
assert_eq!(
req.headers
.get("x-openai-subagent")
.and_then(|v| v.to_str().ok()),
Some("review")
);
assert_eq!(
req.headers
.get("x-test-header")
.and_then(|v| v.to_str().ok()),
Some("present")
);
let body: serde_json::Value = serde_json::from_slice(request_body_bytes(req))?;
let input_id = body
.get("input")
.and_then(|input| input.get(0))
.and_then(|item| item.get("id"))
.and_then(|id| id.as_str());
assert_eq!(input_id, Some("msg_1"));
Ok(())
}