[3/4] Add executor-backed RMCP HTTP client (#18583)

### Why
The RMCP layer needs a Streamable HTTP client that can talk either
directly over `reqwest` or through the executor HTTP runner without
duplicating MCP session logic higher in the stack. This PR adds that
client-side transport boundary so remote Streamable HTTP MCP can reuse
the same RMCP flow as the local path.

### What
- Add a shared `rmcp-client/src/streamable_http/` module with:
  - `transport_client.rs` for the local-or-remote transport enum
  - `local_client.rs` for the direct `reqwest` implementation
  - `remote_client.rs` for the executor-backed implementation
  - `common.rs` for the small shared Streamable HTTP helpers
- Teach `RmcpClient` to build Streamable HTTP transports in either local
or remote mode while keeping the existing OAuth ownership in RMCP.
- Translate remote POST, GET, and DELETE session operations into
executor `http/request` calls.
- Preserve RMCP session expiry handling and reconnect behavior for the
remote transport.
- Add remote transport coverage in
`rmcp-client/tests/streamable_http_remote.rs` and keep the shared test
support in `rmcp-client/tests/streamable_http_test_support.rs`.

### Verification
- `cargo check -p codex-rmcp-client`
- online CI

### Stack
1. #18581 protocol
2. #18582 runner
3. #18583 RMCP client
4. #18584 manager wiring and local/remote coverage

---------

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-04-22 17:38:04 -07:00
committed by GitHub
Unverified
parent 83ec1eb5d6
commit 0e78ce80ee
20 changed files with 1595 additions and 972 deletions
+3
View File
@@ -2579,7 +2579,9 @@ dependencies = [
"arc-swap",
"async-trait",
"base64 0.22.1",
"bytes",
"codex-app-server-protocol",
"codex-client",
"codex-config",
"codex-protocol",
"codex-sandboxing",
@@ -3141,6 +3143,7 @@ version = "0.0.0"
dependencies = [
"anyhow",
"axum",
"bytes",
"codex-client",
"codex-config",
"codex-exec-server",
@@ -1589,23 +1589,11 @@ async fn make_rmcp_client(
env_http_headers,
bearer_token_env_var,
} => {
if remote_environment {
if !runtime_environment.environment().is_remote() {
return Err(StartupOutcomeError::from(anyhow!(
"remote MCP server `{server_name}` requires a remote executor environment"
)));
}
if remote_environment && !runtime_environment.environment().is_remote() {
return Err(StartupOutcomeError::from(anyhow!(
// Remote HTTP needs the future low-level executor
// `network/request` API so reqwest runs on the executor side.
// Do not fall back to local HTTP here; the config explicitly
// asked for remote placement.
"remote streamable HTTP MCP server `{server_name}` is not implemented yet"
"remote MCP server `{server_name}` requires a remote environment"
)));
}
// Local streamable HTTP remains the existing reqwest path from
// the orchestrator process.
let resolved_bearer_token =
match resolve_bearer_token(server_name, bearer_token_env_var.as_deref()) {
Ok(token) => token,
@@ -1618,6 +1606,7 @@ async fn make_rmcp_client(
http_headers,
env_http_headers,
store_mode,
runtime_environment.environment().get_http_client(),
)
.await
.map_err(StartupOutcomeError::from)
+2
View File
@@ -14,7 +14,9 @@ workspace = true
arc-swap = { workspace = true }
async-trait = { workspace = true }
base64 = { workspace = true }
bytes = { workspace = true }
codex-app-server-protocol = { workspace = true }
codex-client = { workspace = true }
codex-config = { workspace = true }
codex-protocol = { workspace = true }
codex-sandboxing = { workspace = true }
+24
View File
@@ -8,6 +8,8 @@ use std::time::Duration;
use arc_swap::ArcSwap;
use codex_app_server_protocol::JSONRPCNotification;
use futures::FutureExt;
use futures::future::BoxFuture;
use serde_json::Value;
use tokio::sync::Mutex;
use tokio::sync::OnceCell;
@@ -20,6 +22,7 @@ use tracing::debug;
use crate::ProcessId;
use crate::client_api::ExecServerClientConnectOptions;
use crate::client_api::HttpClient;
use crate::client_api::RemoteExecServerConnectArgs;
use crate::connection::JsonRpcConnection;
use crate::process::ExecProcessEvent;
@@ -206,6 +209,25 @@ impl LazyRemoteExecServerClient {
}
}
impl HttpClient for LazyRemoteExecServerClient {
fn http_request(
&self,
params: crate::HttpRequestParams,
) -> BoxFuture<'_, Result<crate::HttpRequestResponse, ExecServerError>> {
async move { self.get().await?.http_request(params).await }.boxed()
}
fn http_request_stream(
&self,
params: crate::HttpRequestParams,
) -> BoxFuture<
'_,
Result<(crate::HttpRequestResponse, crate::HttpResponseBodyStream), ExecServerError>,
> {
async move { self.get().await?.http_request_stream(params).await }.boxed()
}
}
#[derive(Debug, thiserror::Error)]
pub enum ExecServerError {
#[error("failed to spawn exec-server: {0}")]
@@ -226,6 +248,8 @@ pub enum ExecServerError {
Disconnected(String),
#[error("failed to serialize or deserialize exec-server JSON: {0}")]
Json(#[from] serde_json::Error),
#[error("HTTP request failed: {0}")]
HttpRequest(String),
#[error("exec-server protocol error: {0}")]
Protocol(String),
#[error("exec-server rejected request ({code}): {message}")]
+24 -520
View File
@@ -1,522 +1,26 @@
use std::collections::HashMap;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use std::time::Duration;
//! HTTP client capability implementations shared by local and remote environments.
//!
//! This module is the facade for the environment-owned [`crate::HttpClient`]
//! capability:
//! - [`ReqwestHttpClient`] executes requests directly with `reqwest`
//! - [`ExecServerClient`] forwards requests over the JSON-RPC transport
//! - [`HttpResponseBodyStream`] presents buffered local bodies and streamed
//! remote `http/request/bodyDelta` notifications through one byte-stream API
//!
//! Runtime split:
//! - orchestrator process: holds an `Arc<dyn HttpClient>` and chooses local or
//! remote execution
//! - remote runtime: serves the `http/request` RPC and runs the concrete local
//! HTTP request there when the orchestrator uses [`ExecServerClient`]
use codex_app_server_protocol::JSONRPCErrorError;
use futures::StreamExt;
use reqwest::Method;
use reqwest::Url;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderName;
use reqwest::header::HeaderValue;
use serde_json::Value;
use serde_json::from_value;
use tokio::runtime::Handle;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tracing::debug;
#[path = "reqwest_http_client.rs"]
mod reqwest_http_client;
#[path = "http_response_body_stream.rs"]
pub(crate) mod response_body_stream;
#[path = "rpc_http_client.rs"]
mod rpc_http_client;
use super::ExecServerClient;
use super::ExecServerError;
use super::Inner;
use crate::protocol::HTTP_REQUEST_BODY_DELTA_METHOD;
use crate::protocol::HTTP_REQUEST_METHOD;
use crate::protocol::HttpHeader;
use crate::protocol::HttpRequestBodyDeltaNotification;
use crate::protocol::HttpRequestParams;
use crate::protocol::HttpRequestResponse;
use crate::rpc::RpcNotificationSender;
use crate::rpc::internal_error;
use crate::rpc::invalid_params;
/// Maximum queued body frames per streamed executor HTTP response.
const HTTP_BODY_DELTA_CHANNEL_CAPACITY: usize = 256;
pub(crate) struct ExecutorPendingHttpBodyStream {
pub(crate) request_id: String,
response: reqwest::Response,
}
pub(crate) struct ExecutorHttpRequestRunner {
client: reqwest::Client,
}
/// Request-scoped stream of body chunks for an executor HTTP response.
///
/// The initial `http/request` call returns status and headers. This stream then
/// receives the ordered `http/request/bodyDelta` notifications for that request
/// id until EOF or a terminal error.
pub struct HttpResponseBodyStream {
inner: Arc<Inner>,
request_id: String,
next_seq: u64,
rx: mpsc::Receiver<HttpRequestBodyDeltaNotification>,
// Terminal frames can carry a final chunk; return that once, then EOF.
pending_eof: bool,
closed: bool,
}
impl ExecServerClient {
/// Performs an executor-side HTTP request and buffers the response body.
pub async fn http_request(
&self,
mut params: HttpRequestParams,
) -> Result<HttpRequestResponse, ExecServerError> {
params.stream_response = false;
self.call(HTTP_REQUEST_METHOD, &params).await
}
/// Performs an executor-side HTTP request and returns a body stream.
///
/// The method sets `stream_response` and replaces any caller-supplied
/// `request_id` with a connection-local id, so late deltas from abandoned
/// streams cannot be confused with later requests.
pub async fn http_request_stream(
&self,
mut params: HttpRequestParams,
) -> Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError> {
params.stream_response = true;
let request_id = self.inner.next_http_body_stream_request_id();
params.request_id = request_id.clone();
let (tx, rx) = mpsc::channel(HTTP_BODY_DELTA_CHANNEL_CAPACITY);
self.inner
.insert_http_body_stream(request_id.clone(), tx)
.await?;
let mut registration = HttpBodyStreamRegistration {
inner: Arc::clone(&self.inner),
request_id: request_id.clone(),
active: true,
};
let response = match self.call(HTTP_REQUEST_METHOD, &params).await {
Ok(response) => response,
Err(error) => {
self.inner.remove_http_body_stream(&request_id).await;
registration.active = false;
return Err(error);
}
};
registration.active = false;
Ok((
response,
HttpResponseBodyStream {
inner: Arc::clone(&self.inner),
request_id,
next_seq: 1,
rx,
pending_eof: false,
closed: false,
},
))
}
}
impl HttpResponseBodyStream {
/// Receives the next response-body chunk.
///
/// Returns `Ok(None)` at EOF and converts sequence gaps or executor-side
/// stream errors into protocol errors.
pub async fn recv(&mut self) -> Result<Option<Vec<u8>>, ExecServerError> {
if self.pending_eof {
self.pending_eof = false;
self.finish().await;
return Ok(None);
}
let Some(delta) = self.rx.recv().await else {
self.finish().await;
if let Some(error) = self
.inner
.take_http_body_stream_failure(&self.request_id)
.await
{
return Err(ExecServerError::Protocol(format!(
"http response stream `{}` failed: {error}",
self.request_id
)));
}
return Ok(None);
};
if delta.seq != self.next_seq {
self.finish().await;
return Err(ExecServerError::Protocol(format!(
"http response stream `{}` received seq {}, expected {}",
self.request_id, delta.seq, self.next_seq
)));
}
self.next_seq += 1;
let chunk = delta.delta.into_inner();
if let Some(error) = delta.error {
self.finish().await;
return Err(ExecServerError::Protocol(format!(
"http response stream `{}` failed: {error}",
self.request_id
)));
}
if delta.done {
self.finish().await;
if chunk.is_empty() {
return Ok(None);
}
self.pending_eof = true;
}
Ok(Some(chunk))
}
/// Removes this stream from the connection routing table once it reaches EOF.
async fn finish(&mut self) {
if self.closed {
return;
}
self.closed = true;
self.inner.remove_http_body_stream(&self.request_id).await;
}
}
impl Drop for HttpResponseBodyStream {
/// Schedules stream-route removal if the consumer drops before EOF.
fn drop(&mut self) {
if self.closed {
return;
}
self.closed = true;
spawn_remove_http_body_stream(Arc::clone(&self.inner), self.request_id.clone());
}
}
impl ExecutorHttpRequestRunner {
pub(crate) fn new(timeout_ms: Option<u64>) -> Result<Self, JSONRPCErrorError> {
let client = match timeout_ms {
None => reqwest::Client::builder(),
Some(timeout_ms) => {
reqwest::Client::builder().timeout(Duration::from_millis(timeout_ms))
}
}
.build()
.map_err(|err| internal_error(format!("failed to build http/request client: {err}")))?;
Ok(Self { client })
}
pub(crate) async fn run(
&self,
params: HttpRequestParams,
) -> Result<(HttpRequestResponse, Option<ExecutorPendingHttpBodyStream>), JSONRPCErrorError>
{
let method = Method::from_bytes(params.method.as_bytes())
.map_err(|err| invalid_params(format!("http/request method is invalid: {err}")))?;
let url = Url::parse(&params.url)
.map_err(|err| invalid_params(format!("http/request url is invalid: {err}")))?;
match url.scheme() {
"http" | "https" => {}
scheme => {
return Err(invalid_params(format!(
"http/request only supports http and https URLs, got {scheme}"
)));
}
}
let headers = Self::build_headers(params.headers)?;
let mut request = self.client.request(method, url).headers(headers);
if let Some(body) = params.body {
request = request.body(body.into_inner());
}
let response = request
.send()
.await
.map_err(|err| internal_error(format!("http/request failed: {err}")))?;
let status = response.status().as_u16();
let headers = Self::response_headers(response.headers());
if params.stream_response {
return Ok((
HttpRequestResponse {
status,
headers,
body: Vec::new().into(),
},
Some(ExecutorPendingHttpBodyStream {
request_id: params.request_id,
response,
}),
));
}
let body = response.bytes().await.map_err(|err| {
internal_error(format!("failed to read http/request response body: {err}"))
})?;
Ok((
HttpRequestResponse {
status,
headers,
body: body.to_vec().into(),
},
None,
))
}
fn build_headers(headers: Vec<HttpHeader>) -> Result<HeaderMap, JSONRPCErrorError> {
let mut header_map = HeaderMap::new();
for header in headers {
let name = HeaderName::from_bytes(header.name.as_bytes()).map_err(|err| {
invalid_params(format!("http/request header name is invalid: {err}"))
})?;
let value = HeaderValue::from_str(&header.value).map_err(|err| {
invalid_params(format!(
"http/request header value is invalid for {}: {err}",
header.name
))
})?;
header_map.append(name, value);
}
Ok(header_map)
}
fn response_headers(headers: &HeaderMap) -> Vec<HttpHeader> {
headers
.iter()
.filter_map(|(name, value)| {
Some(HttpHeader {
name: name.as_str().to_string(),
value: value.to_str().ok()?.to_string(),
})
})
.collect()
}
pub(crate) async fn stream_body(
pending_stream: ExecutorPendingHttpBodyStream,
notifications: RpcNotificationSender,
) {
let ExecutorPendingHttpBodyStream {
request_id,
response,
} = pending_stream;
let mut seq = 1;
let mut body = response.bytes_stream();
while let Some(chunk) = body.next().await {
match chunk {
Ok(bytes) => {
if !send_executor_body_delta(
&notifications,
HttpRequestBodyDeltaNotification {
request_id: request_id.clone(),
seq,
delta: bytes.to_vec().into(),
done: false,
error: None,
},
)
.await
{
return;
}
seq += 1;
}
Err(err) => {
let _ = send_executor_body_delta(
&notifications,
HttpRequestBodyDeltaNotification {
request_id,
seq,
delta: Vec::new().into(),
done: true,
error: Some(err.to_string()),
},
)
.await;
return;
}
}
}
let _ = send_executor_body_delta(
&notifications,
HttpRequestBodyDeltaNotification {
request_id,
seq,
delta: Vec::new().into(),
done: true,
error: None,
},
)
.await;
}
}
impl Inner {
/// Routes one streamed HTTP body notification into its request-local receiver.
pub(super) async fn handle_http_body_delta_notification(
&self,
params: Option<Value>,
) -> Result<(), ExecServerError> {
let params: HttpRequestBodyDeltaNotification = from_value(params.unwrap_or(Value::Null))?;
// Unknown request ids are ignored intentionally: a stream may have already
// reached EOF and released its route.
if let Some(tx) = self
.http_body_streams
.load()
.get(&params.request_id)
.cloned()
{
let request_id = params.request_id.clone();
let terminal_delta = params.done || params.error.is_some();
match tx.try_send(params) {
Ok(()) => {
if terminal_delta {
self.remove_http_body_stream(&request_id).await;
}
}
Err(TrySendError::Closed(_)) => {
self.remove_http_body_stream(&request_id).await;
debug!("http response stream receiver dropped before body delta delivery");
}
Err(TrySendError::Full(_)) => {
self.record_http_body_stream_failure(
&request_id,
"body delta channel filled before delivery".to_string(),
)
.await;
self.remove_http_body_stream(&request_id).await;
debug!(
"closing http response stream `{request_id}` after body delta backpressure"
);
}
}
}
Ok(())
}
/// Fails active streamed HTTP bodies so callers do not wait forever after a
/// transport disconnect or notification handling failure.
pub(super) async fn fail_all_http_body_streams(&self, message: String) {
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
let streams = self.http_body_streams.load();
let streams = streams.as_ref().clone();
self.http_body_streams.store(Arc::new(HashMap::new()));
for (request_id, tx) in streams {
if tx
.try_send(HttpRequestBodyDeltaNotification {
request_id: request_id.clone(),
seq: 1,
delta: Vec::new().into(),
done: true,
error: Some(message.clone()),
})
.is_err()
{
let mut next_failures = self.http_body_stream_failures.load().as_ref().clone();
next_failures.insert(request_id, message.clone());
self.http_body_stream_failures
.store(Arc::new(next_failures));
}
}
}
/// Allocates a connection-local streamed HTTP response id.
fn next_http_body_stream_request_id(&self) -> String {
let id = self
.http_body_stream_next_id
.fetch_add(1, Ordering::Relaxed);
format!("http-{id}")
}
/// Registers a request id before issuing an executor streaming HTTP call.
async fn insert_http_body_stream(
&self,
request_id: String,
tx: mpsc::Sender<HttpRequestBodyDeltaNotification>,
) -> Result<(), ExecServerError> {
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
let streams = self.http_body_streams.load();
if streams.contains_key(&request_id) {
return Err(ExecServerError::Protocol(format!(
"http response stream already registered for request {request_id}"
)));
}
let mut next_streams = streams.as_ref().clone();
next_streams.insert(request_id.clone(), tx);
self.http_body_streams.store(Arc::new(next_streams));
let failures = self.http_body_stream_failures.load();
if failures.contains_key(&request_id) {
let mut next_failures = failures.as_ref().clone();
next_failures.remove(&request_id);
self.http_body_stream_failures
.store(Arc::new(next_failures));
}
Ok(())
}
/// Removes a request id after EOF, terminal error, or request failure.
async fn remove_http_body_stream(
&self,
request_id: &str,
) -> Option<mpsc::Sender<HttpRequestBodyDeltaNotification>> {
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
let streams = self.http_body_streams.load();
let stream = streams.get(request_id).cloned();
stream.as_ref()?;
let mut next_streams = streams.as_ref().clone();
next_streams.remove(request_id);
self.http_body_streams.store(Arc::new(next_streams));
stream
}
async fn record_http_body_stream_failure(&self, request_id: &str, message: String) {
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
let failures = self.http_body_stream_failures.load();
let mut next_failures = failures.as_ref().clone();
next_failures.insert(request_id.to_string(), message);
self.http_body_stream_failures
.store(Arc::new(next_failures));
}
async fn take_http_body_stream_failure(&self, request_id: &str) -> Option<String> {
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
let failures = self.http_body_stream_failures.load();
let error = failures.get(request_id).cloned();
error.as_ref()?;
let mut next_failures = failures.as_ref().clone();
next_failures.remove(request_id);
self.http_body_stream_failures
.store(Arc::new(next_failures));
error
}
}
/// Active route registration owned while `http_request_stream` awaits headers.
struct HttpBodyStreamRegistration {
inner: Arc<Inner>,
request_id: String,
active: bool,
}
impl Drop for HttpBodyStreamRegistration {
/// Removes the route if the stream request future is cancelled before headers return.
fn drop(&mut self) {
if self.active {
spawn_remove_http_body_stream(Arc::clone(&self.inner), self.request_id.clone());
}
}
}
/// Schedules HTTP body route removal from synchronous drop paths.
fn spawn_remove_http_body_stream(inner: Arc<Inner>, request_id: String) {
if let Ok(handle) = Handle::try_current() {
handle.spawn(async move {
inner.remove_http_body_stream(&request_id).await;
});
}
}
async fn send_executor_body_delta(
notifications: &RpcNotificationSender,
delta: HttpRequestBodyDeltaNotification,
) -> bool {
notifications
.notify(HTTP_REQUEST_BODY_DELTA_METHOD, &delta)
.await
.is_ok()
}
pub(crate) use reqwest_http_client::PendingReqwestHttpBodyStream;
pub use reqwest_http_client::ReqwestHttpClient;
pub(crate) use reqwest_http_client::ReqwestHttpRequestRunner;
pub use response_body_stream::HttpResponseBodyStream;
@@ -0,0 +1,355 @@
//! Shared HTTP response-body stream plumbing for local and remote execution.
//!
//! This module owns the byte-stream type exposed by the `HttpClient`
//! capability plus the remote-side routing table used to turn
//! `http/request/bodyDelta` notifications back into per-request streams.
use std::collections::HashMap;
use std::pin::Pin;
use std::sync::Arc;
use std::sync::atomic::Ordering;
use bytes::Bytes;
use futures::StreamExt;
use reqwest::Response;
use serde_json::Value;
use serde_json::from_value;
use tokio::runtime::Handle;
use tokio::sync::mpsc;
use tokio::sync::mpsc::error::TrySendError;
use tracing::debug;
use crate::client::ExecServerError;
use crate::client::Inner;
use crate::protocol::HTTP_REQUEST_BODY_DELTA_METHOD;
use crate::protocol::HttpRequestBodyDeltaNotification;
use crate::rpc::RpcNotificationSender;
pub(super) struct HttpBodyStreamRegistration {
inner: Arc<Inner>,
request_id: String,
active: bool,
}
enum HttpResponseBodyStreamInner {
Local {
body: Pin<Box<dyn futures::Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
},
Remote {
inner: Arc<Inner>,
request_id: String,
next_seq: u64,
rx: mpsc::Receiver<HttpRequestBodyDeltaNotification>,
pending_eof: bool,
closed: bool,
},
}
/// Request-scoped stream of body chunks for an HTTP response.
///
/// The initial `http/request` call returns status and headers. This stream then
/// receives the ordered `http/request/bodyDelta` notifications for that request
/// id until EOF or a terminal error.
pub struct HttpResponseBodyStream {
inner: HttpResponseBodyStreamInner,
}
impl HttpResponseBodyStream {
pub(super) fn local(response: Response) -> Self {
Self {
inner: HttpResponseBodyStreamInner::Local {
body: Box::pin(response.bytes_stream()),
},
}
}
pub(super) fn remote(
inner: Arc<Inner>,
request_id: String,
rx: mpsc::Receiver<HttpRequestBodyDeltaNotification>,
) -> Self {
Self {
inner: HttpResponseBodyStreamInner::Remote {
inner,
request_id,
next_seq: 1,
rx,
pending_eof: false,
closed: false,
},
}
}
/// Receives the next response-body chunk.
///
/// Returns `Ok(None)` at EOF and converts sequence gaps or stream-side
/// stream errors into protocol errors.
pub async fn recv(&mut self) -> Result<Option<Vec<u8>>, ExecServerError> {
match &mut self.inner {
HttpResponseBodyStreamInner::Local { body } => match body.next().await {
Some(chunk) => match chunk {
Ok(bytes) => Ok(Some(bytes.to_vec())),
Err(error) => Err(ExecServerError::HttpRequest(error.to_string())),
},
None => Ok(None),
},
HttpResponseBodyStreamInner::Remote {
inner,
request_id,
next_seq,
rx,
pending_eof,
closed,
} => {
if *pending_eof {
*pending_eof = false;
finish_remote_stream(inner, request_id, closed).await;
return Ok(None);
}
let Some(delta) = rx.recv().await else {
finish_remote_stream(inner, request_id, closed).await;
if let Some(error) = inner.take_http_body_stream_failure(request_id).await {
return Err(ExecServerError::Protocol(format!(
"http response stream `{request_id}` failed: {error}",
)));
}
return Ok(None);
};
if delta.seq != *next_seq {
finish_remote_stream(inner, request_id, closed).await;
return Err(ExecServerError::Protocol(format!(
"http response stream `{request_id}` received seq {}, expected {}",
delta.seq, *next_seq
)));
}
*next_seq += 1;
let chunk = delta.delta.into_inner();
if let Some(error) = delta.error {
finish_remote_stream(inner, request_id, closed).await;
return Err(ExecServerError::Protocol(format!(
"http response stream `{request_id}` failed: {error}",
)));
}
if delta.done {
finish_remote_stream(inner, request_id, closed).await;
if chunk.is_empty() {
return Ok(None);
}
*pending_eof = true;
}
Ok(Some(chunk))
}
}
}
}
impl Drop for HttpResponseBodyStream {
/// Schedules stream-route removal if the consumer drops before EOF.
fn drop(&mut self) {
if let HttpResponseBodyStreamInner::Remote {
inner,
request_id,
closed,
..
} = &mut self.inner
{
if *closed {
return;
}
*closed = true;
spawn_remove_http_body_stream(Arc::clone(inner), request_id.clone());
}
}
}
impl HttpBodyStreamRegistration {
pub(super) fn new(inner: Arc<Inner>, request_id: String) -> Self {
Self {
inner,
request_id,
active: true,
}
}
pub(super) fn disarm(&mut self) {
self.active = false;
}
}
impl Drop for HttpBodyStreamRegistration {
/// Removes the route if the stream request future is cancelled before headers return.
fn drop(&mut self) {
if self.active {
spawn_remove_http_body_stream(Arc::clone(&self.inner), self.request_id.clone());
}
}
}
async fn finish_remote_stream(inner: &Arc<Inner>, request_id: &str, closed: &mut bool) {
if *closed {
return;
}
*closed = true;
inner.remove_http_body_stream(request_id).await;
}
/// Schedules HTTP body route removal from synchronous drop paths.
fn spawn_remove_http_body_stream(inner: Arc<Inner>, request_id: String) {
if let Ok(handle) = Handle::try_current() {
handle.spawn(async move {
inner.remove_http_body_stream(&request_id).await;
});
}
}
pub(super) async fn send_body_delta(
notifications: &RpcNotificationSender,
delta: HttpRequestBodyDeltaNotification,
) -> bool {
notifications
.notify(HTTP_REQUEST_BODY_DELTA_METHOD, &delta)
.await
.is_ok()
}
impl Inner {
/// Routes one streamed HTTP body notification into its request-local receiver.
pub(crate) async fn handle_http_body_delta_notification(
&self,
params: Option<Value>,
) -> Result<(), ExecServerError> {
let params: HttpRequestBodyDeltaNotification = from_value(params.unwrap_or(Value::Null))?;
// Unknown request ids are ignored intentionally: a stream may have already
// reached EOF and released its route.
if let Some(tx) = self
.http_body_streams
.load()
.get(&params.request_id)
.cloned()
{
let request_id = params.request_id.clone();
let terminal_delta = params.done || params.error.is_some();
match tx.try_send(params) {
Ok(()) => {
if terminal_delta {
self.remove_http_body_stream(&request_id).await;
}
}
Err(TrySendError::Closed(_)) => {
self.remove_http_body_stream(&request_id).await;
debug!("http response stream receiver dropped before body delta delivery");
}
Err(TrySendError::Full(_)) => {
self.record_http_body_stream_failure(
&request_id,
"body delta channel filled before delivery".to_string(),
)
.await;
self.remove_http_body_stream(&request_id).await;
debug!(
"closing http response stream `{request_id}` after body delta backpressure"
);
}
}
}
Ok(())
}
/// Fails active streamed HTTP bodies so callers do not wait forever after a
/// transport disconnect or notification handling failure.
pub(crate) async fn fail_all_http_body_streams(&self, message: String) {
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
let streams = self.http_body_streams.load();
let streams = streams.as_ref().clone();
self.http_body_streams.store(Arc::new(HashMap::new()));
for (request_id, tx) in streams {
if tx
.try_send(HttpRequestBodyDeltaNotification {
request_id: request_id.clone(),
seq: 1,
delta: Vec::new().into(),
done: true,
error: Some(message.clone()),
})
.is_err()
{
let mut next_failures = self.http_body_stream_failures.load().as_ref().clone();
next_failures.insert(request_id, message.clone());
self.http_body_stream_failures
.store(Arc::new(next_failures));
}
}
}
/// Allocates a connection-local streamed HTTP response id.
pub(super) fn next_http_body_stream_request_id(&self) -> String {
let id = self
.http_body_stream_next_id
.fetch_add(1, Ordering::Relaxed);
format!("http-{id}")
}
/// Registers a request id before issuing a streaming HTTP call.
pub(super) async fn insert_http_body_stream(
&self,
request_id: String,
tx: mpsc::Sender<HttpRequestBodyDeltaNotification>,
) -> Result<(), ExecServerError> {
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
let streams = self.http_body_streams.load();
if streams.contains_key(&request_id) {
return Err(ExecServerError::Protocol(format!(
"http response stream already registered for request {request_id}"
)));
}
let mut next_streams = streams.as_ref().clone();
next_streams.insert(request_id.clone(), tx);
self.http_body_streams.store(Arc::new(next_streams));
let failures = self.http_body_stream_failures.load();
if failures.contains_key(&request_id) {
let mut next_failures = failures.as_ref().clone();
next_failures.remove(&request_id);
self.http_body_stream_failures
.store(Arc::new(next_failures));
}
Ok(())
}
/// Removes a request id after EOF, terminal error, or request failure.
pub(super) async fn remove_http_body_stream(
&self,
request_id: &str,
) -> Option<mpsc::Sender<HttpRequestBodyDeltaNotification>> {
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
let streams = self.http_body_streams.load();
let stream = streams.get(request_id).cloned();
stream.as_ref()?;
let mut next_streams = streams.as_ref().clone();
next_streams.remove(request_id);
self.http_body_streams.store(Arc::new(next_streams));
stream
}
async fn record_http_body_stream_failure(&self, request_id: &str, message: String) {
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
let failures = self.http_body_stream_failures.load();
let mut next_failures = failures.as_ref().clone();
next_failures.insert(request_id.to_string(), message);
self.http_body_stream_failures
.store(Arc::new(next_failures));
}
async fn take_http_body_stream_failure(&self, request_id: &str) -> Option<String> {
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
let failures = self.http_body_stream_failures.load();
let error = failures.get(request_id).cloned();
error.as_ref()?;
let mut next_failures = failures.as_ref().clone();
next_failures.remove(request_id);
self.http_body_stream_failures
.store(Arc::new(next_failures));
error
}
}
@@ -0,0 +1,267 @@
//! `reqwest`-backed `HttpClient` implementation.
//!
//! This code runs wherever the real network request should originate:
//! - in a local environment, that means the orchestrator process
//! - in a remote environment, that means the remote runtime after the
//! orchestrator has forwarded `http/request` over JSON-RPC
use std::time::Duration;
use codex_app_server_protocol::JSONRPCErrorError;
use codex_client::build_reqwest_client_with_custom_ca;
use futures::FutureExt;
use futures::StreamExt;
use futures::future::BoxFuture;
use reqwest::Method;
use reqwest::Url;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderName;
use reqwest::header::HeaderValue;
use super::HttpResponseBodyStream;
use super::response_body_stream::send_body_delta;
use crate::HttpClient;
use crate::client::ExecServerError;
use crate::protocol::HttpHeader;
use crate::protocol::HttpRequestBodyDeltaNotification;
use crate::protocol::HttpRequestParams;
use crate::protocol::HttpRequestResponse;
use crate::rpc::RpcNotificationSender;
use crate::rpc::internal_error;
use crate::rpc::invalid_params;
/// `HttpClient` implementation that performs the actual HTTP request with
/// `reqwest`.
#[derive(Clone, Default)]
pub struct ReqwestHttpClient;
/// Streaming response state held between the initial HTTP response and
/// downstream body-delta forwarding.
pub(crate) struct PendingReqwestHttpBodyStream {
pub(crate) request_id: String,
pub(crate) response: reqwest::Response,
}
/// Validates `http/request` parameters and runs the actual `reqwest` call used
/// by the exec-server route and the local [`HttpClient`] backend.
pub(crate) struct ReqwestHttpRequestRunner {
client: reqwest::Client,
}
impl ReqwestHttpClient {
fn build_client(timeout_ms: Option<u64>) -> Result<reqwest::Client, ExecServerError> {
let builder = match timeout_ms {
None => reqwest::Client::builder(),
Some(timeout_ms) => {
reqwest::Client::builder().timeout(Duration::from_millis(timeout_ms))
}
};
build_reqwest_client_with_custom_ca(builder)
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))
}
}
impl HttpClient for ReqwestHttpClient {
fn http_request(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
async move {
let runner = ReqwestHttpRequestRunner::new(params.timeout_ms)
.map_err(|error| ExecServerError::HttpRequest(error.message))?;
let (response, _) = runner
.run(HttpRequestParams {
stream_response: false,
..params
})
.await
.map_err(|error| ExecServerError::HttpRequest(error.message))?;
Ok(response)
}
.boxed()
}
fn http_request_stream(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
async move {
let runner = ReqwestHttpRequestRunner::new(params.timeout_ms)
.map_err(|error| ExecServerError::HttpRequest(error.message))?;
let (response, pending_stream) = runner
.run(HttpRequestParams {
stream_response: true,
..params
})
.await
.map_err(|error| ExecServerError::HttpRequest(error.message))?;
let pending_stream = pending_stream.ok_or_else(|| {
ExecServerError::Protocol(
"http request stream did not return a response body stream".to_string(),
)
})?;
Ok((
response,
HttpResponseBodyStream::local(pending_stream.response),
))
}
.boxed()
}
}
impl ReqwestHttpRequestRunner {
pub(crate) fn new(timeout_ms: Option<u64>) -> Result<Self, JSONRPCErrorError> {
let client = ReqwestHttpClient::build_client(timeout_ms)
.map_err(|error| internal_error(error.to_string()))?;
Ok(Self { client })
}
pub(crate) async fn run(
&self,
params: HttpRequestParams,
) -> Result<(HttpRequestResponse, Option<PendingReqwestHttpBodyStream>), JSONRPCErrorError>
{
let method = Method::from_bytes(params.method.as_bytes())
.map_err(|error| invalid_params(format!("http/request method is invalid: {error}")))?;
let url = Url::parse(&params.url)
.map_err(|error| invalid_params(format!("http/request url is invalid: {error}")))?;
match url.scheme() {
"http" | "https" => {}
scheme => {
return Err(invalid_params(format!(
"http/request only supports http and https URLs, got {scheme}"
)));
}
}
let headers = Self::build_headers(params.headers)?;
let mut request = self.client.request(method, url).headers(headers);
if let Some(body) = params.body {
request = request.body(body.into_inner());
}
let response = request
.send()
.await
.map_err(|error| internal_error(format!("http/request failed: {error}")))?;
let status = response.status().as_u16();
let headers = Self::response_headers(response.headers());
if params.stream_response {
return Ok((
HttpRequestResponse {
status,
headers,
body: Vec::new().into(),
},
Some(PendingReqwestHttpBodyStream {
request_id: params.request_id,
response,
}),
));
}
let body = response.bytes().await.map_err(|error| {
internal_error(format!(
"failed to read http/request response body: {error}"
))
})?;
Ok((
HttpRequestResponse {
status,
headers,
body: body.to_vec().into(),
},
None,
))
}
pub(crate) async fn stream_body(
pending_stream: PendingReqwestHttpBodyStream,
notifications: RpcNotificationSender,
) {
let PendingReqwestHttpBodyStream {
request_id,
response,
} = pending_stream;
let mut seq = 1;
let mut body = response.bytes_stream();
while let Some(chunk) = body.next().await {
match chunk {
Ok(bytes) => {
if !send_body_delta(
&notifications,
HttpRequestBodyDeltaNotification {
request_id: request_id.clone(),
seq,
delta: bytes.to_vec().into(),
done: false,
error: None,
},
)
.await
{
return;
}
seq += 1;
}
Err(error) => {
let _ = send_body_delta(
&notifications,
HttpRequestBodyDeltaNotification {
request_id,
seq,
delta: Vec::new().into(),
done: true,
error: Some(error.to_string()),
},
)
.await;
return;
}
}
}
let _ = send_body_delta(
&notifications,
HttpRequestBodyDeltaNotification {
request_id,
seq,
delta: Vec::new().into(),
done: true,
error: None,
},
)
.await;
}
fn build_headers(headers: Vec<HttpHeader>) -> Result<HeaderMap, JSONRPCErrorError> {
let mut header_map = HeaderMap::new();
for header in headers {
let name = HeaderName::from_bytes(header.name.as_bytes()).map_err(|error| {
invalid_params(format!("http/request header name is invalid: {error}"))
})?;
let value = HeaderValue::from_str(&header.value).map_err(|error| {
invalid_params(format!(
"http/request header value is invalid for {}: {error}",
header.name
))
})?;
header_map.append(name, value);
}
Ok(header_map)
}
fn response_headers(headers: &HeaderMap) -> Vec<HttpHeader> {
headers
.iter()
.filter_map(|(name, value)| {
Some(HttpHeader {
name: name.as_str().to_string(),
value: value.to_str().ok()?.to_string(),
})
})
.collect()
}
}
@@ -0,0 +1,88 @@
//! JSON-RPC-backed `HttpClient` implementation.
//!
//! This code runs in the orchestrator process. It does not issue network
//! requests directly; instead it forwards `http/request` to the remote runtime
//! and then reconstructs streamed bodies from `http/request/bodyDelta`
//! notifications on the shared connection.
use std::sync::Arc;
use futures::FutureExt;
use futures::future::BoxFuture;
use tokio::sync::mpsc;
use super::HttpResponseBodyStream;
use super::response_body_stream::HttpBodyStreamRegistration;
use crate::HttpClient;
use crate::client::ExecServerClient;
use crate::client::ExecServerError;
use crate::protocol::HTTP_REQUEST_METHOD;
use crate::protocol::HttpRequestParams;
use crate::protocol::HttpRequestResponse;
/// Maximum queued body frames per streamed HTTP response.
const HTTP_BODY_DELTA_CHANNEL_CAPACITY: usize = 256;
impl ExecServerClient {
/// Performs an HTTP request and buffers the response body.
pub async fn http_request(
&self,
mut params: HttpRequestParams,
) -> Result<HttpRequestResponse, ExecServerError> {
params.stream_response = false;
self.call(HTTP_REQUEST_METHOD, &params).await
}
/// Performs an HTTP request and returns a body stream.
///
/// The method sets `stream_response` and replaces any caller-supplied
/// `request_id` with a connection-local id, so late deltas from abandoned
/// streams cannot be confused with later requests.
pub async fn http_request_stream(
&self,
mut params: HttpRequestParams,
) -> Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError> {
params.stream_response = true;
let request_id = self.inner.next_http_body_stream_request_id();
params.request_id = request_id.clone();
let (tx, rx) = mpsc::channel(HTTP_BODY_DELTA_CHANNEL_CAPACITY);
self.inner
.insert_http_body_stream(request_id.clone(), tx)
.await?;
let mut registration =
HttpBodyStreamRegistration::new(Arc::clone(&self.inner), request_id.clone());
let response = match self.call(HTTP_REQUEST_METHOD, &params).await {
Ok(response) => response,
Err(error) => {
self.inner.remove_http_body_stream(&request_id).await;
registration.disarm();
return Err(error);
}
};
registration.disarm();
Ok((
response,
HttpResponseBodyStream::remote(Arc::clone(&self.inner), request_id, rx),
))
}
}
impl HttpClient for ExecServerClient {
/// Orchestrator-side adapter that forwards buffered HTTP requests to the
/// remote runtime over the shared JSON-RPC connection.
fn http_request(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
async move { ExecServerClient::http_request(self, params).await }.boxed()
}
/// Orchestrator-side adapter that forwards streamed HTTP requests to the
/// remote runtime and exposes body deltas as a byte stream.
fn http_request_stream(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
async move { ExecServerClient::http_request_stream(self, params).await }.boxed()
}
}
+26
View File
@@ -1,5 +1,12 @@
use std::time::Duration;
use futures::future::BoxFuture;
use crate::ExecServerError;
use crate::HttpRequestParams;
use crate::HttpRequestResponse;
use crate::HttpResponseBodyStream;
/// Connection options for any exec-server client transport.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct ExecServerClientConnectOptions {
@@ -17,3 +24,22 @@ pub struct RemoteExecServerConnectArgs {
pub initialize_timeout: Duration,
pub resume_session_id: Option<String>,
}
/// Sends HTTP requests through a runtime-selected transport.
///
/// This is the HTTP capability counterpart to [`crate::ExecBackend`]. Callers
/// use it when they need environment-owned network requests but should not
/// depend on the concrete connection type or how that connection is established.
pub trait HttpClient: Send + Sync {
/// Perform an HTTP request and buffer the response body.
fn http_request(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>>;
/// Perform an HTTP request and return a streamed body handle.
fn http_request_stream(
&self,
params: HttpRequestParams,
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>>;
}
+12 -1
View File
@@ -3,7 +3,9 @@ use std::sync::Arc;
use crate::ExecServerError;
use crate::ExecServerRuntimePaths;
use crate::HttpClient;
use crate::client::LazyRemoteExecServerClient;
use crate::client::http_client::ReqwestHttpClient;
use crate::file_system::ExecutorFileSystem;
use crate::local_file_system::LocalFileSystem;
use crate::local_process::LocalProcess;
@@ -136,6 +138,7 @@ pub struct Environment {
exec_server_url: Option<String>,
exec_backend: Arc<dyn ExecBackend>,
filesystem: Arc<dyn ExecutorFileSystem>,
http_client: Arc<dyn HttpClient>,
local_runtime_paths: Option<ExecServerRuntimePaths>,
}
@@ -146,6 +149,7 @@ impl Environment {
exec_server_url: None,
exec_backend: Arc::new(LocalProcess::default()),
filesystem: Arc::new(LocalFileSystem::unsandboxed()),
http_client: Arc::new(ReqwestHttpClient),
local_runtime_paths: None,
}
}
@@ -202,6 +206,7 @@ impl Environment {
filesystem: Arc::new(LocalFileSystem::with_runtime_paths(
local_runtime_paths.clone(),
)),
http_client: Arc::new(ReqwestHttpClient),
local_runtime_paths: Some(local_runtime_paths),
}
}
@@ -216,12 +221,14 @@ impl Environment {
) -> Self {
let client = LazyRemoteExecServerClient::new(exec_server_url.clone());
let exec_backend: Arc<dyn ExecBackend> = Arc::new(RemoteProcess::new(client.clone()));
let filesystem: Arc<dyn ExecutorFileSystem> = Arc::new(RemoteFileSystem::new(client));
let filesystem: Arc<dyn ExecutorFileSystem> =
Arc::new(RemoteFileSystem::new(client.clone()));
Self {
exec_server_url: Some(exec_server_url),
exec_backend,
filesystem,
http_client: Arc::new(client),
local_runtime_paths,
}
}
@@ -243,6 +250,10 @@ impl Environment {
Arc::clone(&self.exec_backend)
}
pub fn get_http_client(&self) -> Arc<dyn HttpClient> {
Arc::clone(&self.http_client)
}
pub fn get_filesystem(&self) -> Arc<dyn ExecutorFileSystem> {
Arc::clone(&self.filesystem)
}
+3
View File
@@ -20,7 +20,10 @@ mod server;
pub use client::ExecServerClient;
pub use client::ExecServerError;
pub use client::http_client::HttpResponseBodyStream;
pub use client::http_client::ReqwestHttpClient;
pub use client_api::ExecServerClientConnectOptions;
pub use client_api::HttpClient;
pub use client_api::RemoteExecServerConnectArgs;
pub use environment::CODEX_EXEC_SERVER_URL_ENV_VAR;
pub use environment::Environment;
+5 -5
View File
@@ -12,8 +12,8 @@ use tokio_util::sync::CancellationToken;
use tokio_util::task::TaskTracker;
use crate::ExecServerRuntimePaths;
use crate::client::http_client::ExecutorHttpRequestRunner;
use crate::client::http_client::ExecutorPendingHttpBodyStream;
use crate::client::http_client::PendingReqwestHttpBodyStream;
use crate::client::http_client::ReqwestHttpRequestRunner;
use crate::protocol::ExecParams;
use crate::protocol::ExecResponse;
use crate::protocol::FsCopyParams;
@@ -178,7 +178,7 @@ impl ExecServerHandler {
if stream_response {
self.reserve_http_body_stream(&http_request_id).await?;
}
let response = ExecutorHttpRequestRunner::new(params.timeout_ms)?
let response = ReqwestHttpRequestRunner::new(params.timeout_ms)?
.run(params)
.await;
if response.is_err() && stream_response {
@@ -306,7 +306,7 @@ impl ExecServerHandler {
async fn start_http_body_stream(
self: &Arc<Self>,
pending_stream: ExecutorPendingHttpBodyStream,
pending_stream: PendingReqwestHttpBodyStream,
) {
let request_id = pending_stream.request_id.clone();
if self.background_task_shutdown.is_cancelled() {
@@ -320,7 +320,7 @@ impl ExecServerHandler {
self.background_tasks.spawn(async move {
tokio::select! {
_ = shutdown.cancelled() => {}
_ = ExecutorHttpRequestRunner::stream_body(pending_stream, notifications) => {}
_ = ReqwestHttpRequestRunner::stream_body(pending_stream, notifications) => {}
}
handler.release_http_body_stream(&finished_request_id).await;
});
+3
View File
@@ -3,4 +3,7 @@ load("//:defs.bzl", "codex_rust_crate")
codex_rust_crate(
name = "rmcp-client",
crate_name = "codex_rmcp_client",
extra_binaries = [
"//codex-rs/cli:codex",
],
)
+1
View File
@@ -20,6 +20,7 @@ codex-keyring-store = { workspace = true }
codex-protocol = { workspace = true }
codex-utils-pty = { workspace = true }
codex-utils-home-dir = { workspace = true }
bytes = { workspace = true }
futures = { workspace = true, default-features = false, features = ["std"] }
keyring = { workspace = true, features = ["crypto-rust"] }
oauth2 = "5"
@@ -0,0 +1,391 @@
//! RMCP Streamable HTTP adapter built on top of the shared `HttpClient`
//! capability.
//!
//! This module runs in the orchestrator process. It turns high-level RMCP
//! operations like `post_message` and `get_stream` into calls on
//! `Arc<dyn HttpClient>`, which may be:
//! - a local HTTP client that issues requests from the orchestrator, or
//! - a remote HTTP client that forwards requests to the remote runtime
use std::io;
use std::sync::Arc;
use bytes::Bytes;
use codex_exec_server::ExecServerError;
use codex_exec_server::HttpClient;
use codex_exec_server::HttpHeader;
use codex_exec_server::HttpRequestParams;
use codex_exec_server::HttpResponseBodyStream;
use futures::StreamExt;
use futures::stream;
use futures::stream::BoxStream;
use reqwest::StatusCode;
use reqwest::header::ACCEPT;
use reqwest::header::AUTHORIZATION;
use reqwest::header::CONTENT_TYPE;
use reqwest::header::HeaderMap;
use reqwest::header::HeaderName;
use rmcp::model::ClientJsonRpcMessage;
use rmcp::model::ServerJsonRpcMessage;
use rmcp::transport::streamable_http_client::AuthRequiredError;
use rmcp::transport::streamable_http_client::StreamableHttpClient;
use rmcp::transport::streamable_http_client::StreamableHttpError;
use rmcp::transport::streamable_http_client::StreamableHttpPostResponse;
use sse_stream::Sse;
use sse_stream::SseStream;
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
const JSON_MIME_TYPE: &str = "application/json";
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
#[derive(Clone)]
pub(crate) struct StreamableHttpClientAdapter {
http_client: Arc<dyn HttpClient>,
default_headers: HeaderMap,
}
#[derive(Debug, thiserror::Error)]
pub(crate) enum StreamableHttpClientAdapterError {
#[error("streamable HTTP session expired with 404 Not Found")]
SessionExpired404,
#[error(transparent)]
HttpRequest(#[from] ExecServerError),
#[error("invalid HTTP header: {0}")]
Header(String),
}
impl StreamableHttpClientAdapter {
pub(crate) fn new(http_client: Arc<dyn HttpClient>, default_headers: HeaderMap) -> Self {
Self {
http_client,
default_headers,
}
}
}
impl StreamableHttpClient for StreamableHttpClientAdapter {
type Error = StreamableHttpClientAdapterError;
async fn post_message(
&self,
uri: Arc<str>,
message: ClientJsonRpcMessage,
session_id: Option<Arc<str>>,
auth_token: Option<String>,
) -> std::result::Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
let mut headers = self.default_headers.clone();
insert_header(
&mut headers,
ACCEPT,
[EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "),
StreamableHttpClientAdapterError::Header,
)?;
insert_header(
&mut headers,
CONTENT_TYPE,
JSON_MIME_TYPE.to_string(),
StreamableHttpClientAdapterError::Header,
)?;
if let Some(auth_token) = auth_token {
insert_header(
&mut headers,
AUTHORIZATION,
format!("Bearer {auth_token}"),
StreamableHttpClientAdapterError::Header,
)?;
}
if let Some(session_id_value) = session_id.as_ref() {
insert_header(
&mut headers,
HeaderName::from_static("mcp-session-id"),
session_id_value.to_string(),
StreamableHttpClientAdapterError::Header,
)?;
}
let body = serde_json::to_vec(&message).map_err(StreamableHttpError::Deserialize)?;
let (response, mut body_stream) = self
.http_client
.http_request_stream(HttpRequestParams {
method: "POST".to_string(),
url: uri.to_string(),
headers: protocol_headers(&headers),
body: Some(body.into()),
timeout_ms: None,
request_id: "buffered-request".to_string(),
stream_response: true,
})
.await
.map_err(StreamableHttpClientAdapterError::from)
.map_err(StreamableHttpError::Client)?;
if response.status == StatusCode::NOT_FOUND.as_u16() && session_id.is_some() {
return Err(StreamableHttpError::Client(
StreamableHttpClientAdapterError::SessionExpired404,
));
}
if response.status == StatusCode::UNAUTHORIZED.as_u16()
&& let Some(header) =
response_header(&response.headers, reqwest::header::WWW_AUTHENTICATE)
{
return Err(StreamableHttpError::AuthRequired(AuthRequiredError {
www_authenticate_header: header,
}));
}
if matches!(
StatusCode::from_u16(response.status).ok(),
Some(StatusCode::ACCEPTED | StatusCode::NO_CONTENT)
) {
return Ok(StreamableHttpPostResponse::Accepted);
}
let content_type = response_header(&response.headers, CONTENT_TYPE);
let session_id = response_header(&response.headers, HEADER_SESSION_ID);
match content_type.as_deref() {
Some(content_type) if content_type.starts_with(EVENT_STREAM_MIME_TYPE) => {
let event_stream = sse_stream_from_body(body_stream);
Ok(StreamableHttpPostResponse::Sse(event_stream, session_id))
}
Some(content_type) if content_type.starts_with(JSON_MIME_TYPE) => {
let body = collect_body(&mut body_stream).await?;
let message: ServerJsonRpcMessage =
serde_json::from_slice(&body).map_err(StreamableHttpError::Deserialize)?;
Ok(StreamableHttpPostResponse::Json(message, session_id))
}
_ => {
let body = collect_body(&mut body_stream).await?;
let content_type = content_type.unwrap_or_else(|| "missing-content-type".into());
Err(StreamableHttpError::UnexpectedContentType(Some(format!(
"{content_type}; body: {}",
body_preview(String::from_utf8_lossy(&body).to_string())
))))
}
}
}
async fn delete_session(
&self,
uri: Arc<str>,
session: Arc<str>,
auth_token: Option<String>,
) -> std::result::Result<(), StreamableHttpError<Self::Error>> {
let mut headers = self.default_headers.clone();
if let Some(auth_token) = auth_token {
insert_header(
&mut headers,
AUTHORIZATION,
format!("Bearer {auth_token}"),
StreamableHttpClientAdapterError::Header,
)?;
}
insert_header(
&mut headers,
HeaderName::from_static("mcp-session-id"),
session.to_string(),
StreamableHttpClientAdapterError::Header,
)?;
let response = self
.http_client
.http_request(HttpRequestParams {
method: "DELETE".to_string(),
url: uri.to_string(),
headers: protocol_headers(&headers),
body: None,
timeout_ms: None,
request_id: "buffered-request".to_string(),
stream_response: false,
})
.await
.map_err(StreamableHttpClientAdapterError::from)
.map_err(StreamableHttpError::Client)?;
if response.status == StatusCode::METHOD_NOT_ALLOWED.as_u16() {
return Ok(());
}
if !status_is_success(response.status) {
return Err(StreamableHttpError::UnexpectedServerResponse(
format!("DELETE returned HTTP {}", response.status).into(),
));
}
Ok(())
}
async fn get_stream(
&self,
uri: Arc<str>,
session_id: Arc<str>,
last_event_id: Option<String>,
auth_token: Option<String>,
) -> std::result::Result<
BoxStream<'static, std::result::Result<Sse, sse_stream::Error>>,
StreamableHttpError<Self::Error>,
> {
let mut headers = self.default_headers.clone();
insert_header(
&mut headers,
ACCEPT,
[EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "),
StreamableHttpClientAdapterError::Header,
)?;
insert_header(
&mut headers,
HeaderName::from_static("mcp-session-id"),
session_id.to_string(),
StreamableHttpClientAdapterError::Header,
)?;
if let Some(last_event_id) = last_event_id {
insert_header(
&mut headers,
HeaderName::from_static("last-event-id"),
last_event_id,
StreamableHttpClientAdapterError::Header,
)?;
}
if let Some(auth_token) = auth_token {
insert_header(
&mut headers,
AUTHORIZATION,
format!("Bearer {auth_token}"),
StreamableHttpClientAdapterError::Header,
)?;
}
let (response, body_stream) = self
.http_client
.http_request_stream(HttpRequestParams {
method: "GET".to_string(),
url: uri.to_string(),
headers: protocol_headers(&headers),
body: None,
timeout_ms: None,
request_id: "buffered-request".to_string(),
stream_response: true,
})
.await
.map_err(StreamableHttpClientAdapterError::from)
.map_err(StreamableHttpError::Client)?;
if response.status == StatusCode::METHOD_NOT_ALLOWED.as_u16() {
return Err(StreamableHttpError::ServerDoesNotSupportSse);
}
if response.status == StatusCode::NOT_FOUND.as_u16() {
return Err(StreamableHttpError::Client(
StreamableHttpClientAdapterError::SessionExpired404,
));
}
if !status_is_success(response.status) {
return Err(StreamableHttpError::UnexpectedServerResponse(
format!("GET returned HTTP {}", response.status).into(),
));
}
match response_header(&response.headers, CONTENT_TYPE).as_deref() {
Some(content_type) if is_streamable_http_content_type(content_type) => {}
Some(content_type) => {
return Err(StreamableHttpError::UnexpectedContentType(Some(
content_type.to_string(),
)));
}
None => {
return Err(StreamableHttpError::UnexpectedContentType(None));
}
}
Ok(sse_stream_from_body(body_stream))
}
}
fn body_preview(body: impl Into<String>) -> String {
let mut body_preview = body.into();
let body_len = body_preview.len();
if body_len > NON_JSON_RESPONSE_BODY_PREVIEW_BYTES {
let mut boundary = NON_JSON_RESPONSE_BODY_PREVIEW_BYTES;
while !body_preview.is_char_boundary(boundary) {
boundary = boundary.saturating_sub(1);
}
body_preview.truncate(boundary);
body_preview.push_str(&format!(
"... (truncated {} bytes)",
body_len.saturating_sub(boundary)
));
}
body_preview
}
fn insert_header<Error>(
headers: &mut HeaderMap,
name: HeaderName,
value: String,
map_error: impl FnOnce(String) -> Error,
) -> std::result::Result<(), StreamableHttpError<Error>>
where
Error: std::error::Error + Send + Sync + 'static,
{
let value = reqwest::header::HeaderValue::from_str(&value)
.map_err(|error| StreamableHttpError::Client(map_error(error.to_string())))?;
headers.insert(name, value);
Ok(())
}
fn is_streamable_http_content_type(content_type: &str) -> bool {
content_type
.as_bytes()
.starts_with(EVENT_STREAM_MIME_TYPE.as_bytes())
|| content_type
.as_bytes()
.starts_with(JSON_MIME_TYPE.as_bytes())
}
fn protocol_headers(headers: &HeaderMap) -> Vec<HttpHeader> {
headers
.iter()
.filter_map(|(name, value)| {
Some(HttpHeader {
name: name.as_str().to_string(),
value: value.to_str().ok()?.to_string(),
})
})
.collect()
}
fn response_header(headers: &[HttpHeader], name: impl AsRef<str>) -> Option<String> {
let name = name.as_ref();
headers
.iter()
.find(|header| header.name.eq_ignore_ascii_case(name))
.map(|header| header.value.clone())
}
fn status_is_success(status: u16) -> bool {
StatusCode::from_u16(status).is_ok_and(|status| status.is_success())
}
async fn collect_body(
body_stream: &mut HttpResponseBodyStream,
) -> std::result::Result<Vec<u8>, StreamableHttpError<StreamableHttpClientAdapterError>> {
let mut body = Vec::new();
while let Some(chunk) = body_stream
.recv()
.await
.map_err(StreamableHttpClientAdapterError::from)
.map_err(StreamableHttpError::Client)?
{
body.extend_from_slice(&chunk);
}
Ok(body)
}
fn sse_stream_from_body(
body_stream: HttpResponseBodyStream,
) -> BoxStream<'static, std::result::Result<Sse, sse_stream::Error>> {
SseStream::from_byte_stream(stream::unfold(body_stream, |mut body_stream| async move {
match body_stream.recv().await {
Ok(Some(bytes)) => Some((Ok(Bytes::from(bytes)), body_stream)),
Ok(None) => None,
Err(error) => Some((Err(io::Error::other(error)), body_stream)),
}
}))
.boxed()
}
+1
View File
@@ -1,6 +1,7 @@
mod auth_status;
mod elicitation_client_service;
mod executor_process_transport;
mod http_client_adapter;
mod logging_client_handler;
mod oauth;
mod perform_oauth_login;
+30 -248
View File
@@ -1,4 +1,3 @@
use std::borrow::Cow;
use std::collections::HashMap;
use std::ffi::OsString;
use std::future::Future;
@@ -14,16 +13,12 @@ use anyhow::Result;
use anyhow::anyhow;
use codex_client::build_reqwest_client_with_custom_ca;
use codex_config::types::McpServerEnvVar;
use codex_exec_server::HttpClient;
use futures::FutureExt;
use futures::StreamExt;
use futures::future::BoxFuture;
use futures::stream::BoxStream;
use oauth2::TokenResponse;
use reqwest::header::ACCEPT;
use reqwest::header::AUTHORIZATION;
use reqwest::header::CONTENT_TYPE;
use reqwest::header::HeaderMap;
use reqwest::header::WWW_AUTHENTICATE;
use rmcp::model::CallToolRequestParams;
use rmcp::model::CallToolResult;
use rmcp::model::ClientNotification;
@@ -52,16 +47,11 @@ use rmcp::transport::StreamableHttpClientTransport;
use rmcp::transport::auth::AuthClient;
use rmcp::transport::auth::AuthError;
use rmcp::transport::auth::OAuthState;
use rmcp::transport::streamable_http_client::AuthRequiredError;
use rmcp::transport::streamable_http_client::StreamableHttpClient;
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
use rmcp::transport::streamable_http_client::StreamableHttpError;
use rmcp::transport::streamable_http_client::StreamableHttpPostResponse;
use serde::Deserialize;
use serde::Serialize;
use serde_json::Value;
use sse_stream::Sse;
use sse_stream::SseStream;
use tokio::sync::Mutex;
use tokio::sync::Semaphore;
use tokio::sync::watch;
@@ -69,6 +59,8 @@ use tokio::time;
use tracing::warn;
use crate::elicitation_client_service::ElicitationClientService;
use crate::http_client_adapter::StreamableHttpClientAdapter;
use crate::http_client_adapter::StreamableHttpClientAdapterError;
use crate::load_oauth_tokens;
use crate::oauth::OAuthPersistor;
use crate::oauth::StoredOAuthTokens;
@@ -79,239 +71,15 @@ use crate::utils::apply_default_headers;
use crate::utils::build_default_headers;
use codex_config::types::OAuthCredentialsStoreMode;
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
const JSON_MIME_TYPE: &str = "application/json";
const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
#[derive(Clone)]
struct StreamableHttpResponseClient {
inner: reqwest::Client,
}
impl StreamableHttpResponseClient {
fn new(inner: reqwest::Client) -> Self {
Self { inner }
}
fn reqwest_error(
error: reqwest::Error,
) -> StreamableHttpError<StreamableHttpResponseClientError> {
StreamableHttpError::Client(StreamableHttpResponseClientError::from(error))
}
}
fn build_http_client(default_headers: &HeaderMap) -> Result<reqwest::Client> {
let builder = apply_default_headers(reqwest::Client::builder(), default_headers);
Ok(build_reqwest_client_with_custom_ca(builder)?)
}
#[derive(Debug, thiserror::Error)]
enum StreamableHttpResponseClientError {
#[error("streamable HTTP session expired with 404 Not Found")]
SessionExpired404,
#[error(transparent)]
Reqwest(#[from] reqwest::Error),
}
impl StreamableHttpClient for StreamableHttpResponseClient {
type Error = StreamableHttpResponseClientError;
async fn post_message(
&self,
uri: Arc<str>,
message: rmcp::model::ClientJsonRpcMessage,
session_id: Option<Arc<str>>,
auth_token: Option<String>,
) -> std::result::Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
let mut request = self
.inner
.post(uri.as_ref())
.header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "));
if let Some(auth_header) = auth_token {
request = request.bearer_auth(auth_header);
}
if let Some(session_id_value) = session_id.as_ref() {
request = request.header(HEADER_SESSION_ID, session_id_value.as_ref());
}
let response = request
.json(&message)
.send()
.await
.map_err(StreamableHttpResponseClient::reqwest_error)?;
if response.status() == reqwest::StatusCode::NOT_FOUND && session_id.is_some() {
return Err(StreamableHttpError::Client(
StreamableHttpResponseClientError::SessionExpired404,
));
}
if response.status() == reqwest::StatusCode::UNAUTHORIZED
&& let Some(header) = response.headers().get(WWW_AUTHENTICATE)
{
let header = header
.to_str()
.map_err(|_| {
StreamableHttpError::UnexpectedServerResponse(Cow::Borrowed(
"invalid www-authenticate header value",
))
})?
.to_string();
return Err(StreamableHttpError::AuthRequired(AuthRequiredError {
www_authenticate_header: header,
}));
}
let status = response.status();
if matches!(
status,
reqwest::StatusCode::ACCEPTED | reqwest::StatusCode::NO_CONTENT
) {
return Ok(StreamableHttpPostResponse::Accepted);
}
let content_type = response
.headers()
.get(CONTENT_TYPE)
.and_then(|value| value.to_str().ok())
.map(str::to_string);
let session_id = response
.headers()
.get(HEADER_SESSION_ID)
.and_then(|value| value.to_str().ok())
.map(str::to_string);
match content_type.as_deref() {
Some(ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => {
let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed();
Ok(StreamableHttpPostResponse::Sse(event_stream, session_id))
}
Some(ct) if ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => {
let message = response
.json()
.await
.map_err(StreamableHttpResponseClient::reqwest_error)?;
Ok(StreamableHttpPostResponse::Json(message, session_id))
}
_ => {
let body = response
.text()
.await
.map_err(StreamableHttpResponseClient::reqwest_error)?;
let mut body_preview = body;
let body_len = body_preview.len();
if body_len > NON_JSON_RESPONSE_BODY_PREVIEW_BYTES {
let mut boundary = NON_JSON_RESPONSE_BODY_PREVIEW_BYTES;
while !body_preview.is_char_boundary(boundary) {
boundary = boundary.saturating_sub(1);
}
body_preview.truncate(boundary);
body_preview.push_str(&format!(
"... (truncated {} bytes)",
body_len.saturating_sub(boundary)
));
}
let content_type = content_type.unwrap_or_else(|| "missing-content-type".into());
Err(StreamableHttpError::UnexpectedContentType(Some(format!(
"{content_type}; body: {body_preview}"
))))
}
}
}
async fn delete_session(
&self,
uri: Arc<str>,
session: Arc<str>,
auth_token: Option<String>,
) -> std::result::Result<(), StreamableHttpError<Self::Error>> {
let mut request_builder = self.inner.delete(uri.as_ref());
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
}
let response = request_builder
.header(HEADER_SESSION_ID, session.as_ref())
.send()
.await
.map_err(StreamableHttpResponseClient::reqwest_error)?;
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
return Ok(());
}
response
.error_for_status()
.map_err(StreamableHttpResponseClient::reqwest_error)?;
Ok(())
}
async fn get_stream(
&self,
uri: Arc<str>,
session_id: Arc<str>,
last_event_id: Option<String>,
auth_token: Option<String>,
) -> std::result::Result<
BoxStream<'static, std::result::Result<Sse, sse_stream::Error>>,
StreamableHttpError<Self::Error>,
> {
let mut request_builder = self
.inner
.get(uri.as_ref())
.header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "))
.header(HEADER_SESSION_ID, session_id.as_ref());
if let Some(last_event_id) = last_event_id {
request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id);
}
if let Some(auth_header) = auth_token {
request_builder = request_builder.bearer_auth(auth_header);
}
let response = request_builder
.send()
.await
.map_err(StreamableHttpResponseClient::reqwest_error)?;
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
return Err(StreamableHttpError::ServerDoesNotSupportSse);
}
if response.status() == reqwest::StatusCode::NOT_FOUND {
return Err(StreamableHttpError::Client(
StreamableHttpResponseClientError::SessionExpired404,
));
}
let response = response
.error_for_status()
.map_err(StreamableHttpResponseClient::reqwest_error)?;
match response.headers().get(CONTENT_TYPE) {
Some(ct)
if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes())
|| ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => {}
Some(ct) => {
return Err(StreamableHttpError::UnexpectedContentType(Some(
String::from_utf8_lossy(ct.as_bytes()).to_string(),
)));
}
None => {
return Err(StreamableHttpError::UnexpectedContentType(None));
}
}
let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed();
Ok(event_stream)
}
}
enum PendingTransport {
Stdio {
transport: StdioServerTransport,
},
StreamableHttp {
transport: StreamableHttpClientTransport<StreamableHttpResponseClient>,
transport: StreamableHttpClientTransport<StreamableHttpClientAdapter>,
},
StreamableHttpWithOAuth {
transport: StreamableHttpClientTransport<AuthClient<StreamableHttpResponseClient>>,
transport: StreamableHttpClientTransport<AuthClient<StreamableHttpClientAdapter>>,
oauth_persistor: OAuthPersistor,
},
}
@@ -339,6 +107,7 @@ enum TransportRecipe {
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
store_mode: OAuthCredentialsStoreMode,
http_client: Arc<dyn HttpClient>,
},
}
@@ -536,6 +305,7 @@ impl RmcpClient {
http_headers: Option<HashMap<String, String>>,
env_http_headers: Option<HashMap<String, String>>,
store_mode: OAuthCredentialsStoreMode,
http_client: Arc<dyn HttpClient>,
) -> Result<Self> {
let transport_recipe = TransportRecipe::StreamableHttp {
server_name: server_name.to_string(),
@@ -544,6 +314,7 @@ impl RmcpClient {
http_headers,
env_http_headers,
store_mode,
http_client,
};
let transport = Self::create_pending_transport(&transport_recipe).await?;
Ok(Self {
@@ -895,6 +666,7 @@ impl RmcpClient {
http_headers,
env_http_headers,
store_mode,
http_client,
} => {
let default_headers =
build_default_headers(http_headers.clone(), env_http_headers.clone())?;
@@ -919,6 +691,7 @@ impl RmcpClient {
initial_tokens.clone(),
*store_mode,
default_headers.clone(),
Arc::clone(http_client),
)
.await
{
@@ -945,9 +718,11 @@ impl RmcpClient {
let http_config =
StreamableHttpClientTransportConfig::with_uri(url.clone())
.auth_header(access_token);
let http_client = build_http_client(&default_headers)?;
let transport = StreamableHttpClientTransport::with_client(
StreamableHttpResponseClient::new(http_client),
StreamableHttpClientAdapter::new(
Arc::clone(http_client),
default_headers,
),
http_config,
);
Ok(PendingTransport::StreamableHttp { transport })
@@ -961,10 +736,8 @@ impl RmcpClient {
http_config = http_config.auth_header(bearer_token);
}
let http_client = build_http_client(&default_headers)?;
let transport = StreamableHttpClientTransport::with_client(
StreamableHttpResponseClient::new(http_client),
StreamableHttpClientAdapter::new(Arc::clone(http_client), default_headers),
http_config,
);
Ok(PendingTransport::StreamableHttp { transport })
@@ -1084,12 +857,12 @@ impl RmcpClient {
error
.error
.downcast_ref::<StreamableHttpError<StreamableHttpResponseClientError>>()
.downcast_ref::<StreamableHttpError<StreamableHttpClientAdapterError>>()
.is_some_and(|error| {
matches!(
error,
StreamableHttpError::Client(
StreamableHttpResponseClientError::SessionExpired404
StreamableHttpClientAdapterError::SessionExpired404
)
)
})
@@ -1156,12 +929,18 @@ async fn create_oauth_transport_and_runtime(
initial_tokens: StoredOAuthTokens,
credentials_store: OAuthCredentialsStoreMode,
default_headers: HeaderMap,
http_client: Arc<dyn HttpClient>,
) -> Result<(
StreamableHttpClientTransport<AuthClient<StreamableHttpResponseClient>>,
StreamableHttpClientTransport<AuthClient<StreamableHttpClientAdapter>>,
OAuthPersistor,
)> {
let http_client = build_http_client(&default_headers)?;
let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;
let builder = apply_default_headers(reqwest::Client::builder(), &default_headers);
let oauth_metadata_client = build_reqwest_client_with_custom_ca(builder)?;
// TODO(aibrahim): teach OAuth bootstrap and refresh to use the same
// shared HTTP client abstraction instead of always creating the local
// reqwest metadata client here.
let mut oauth_state =
OAuthState::new(url.to_string(), Some(oauth_metadata_client.clone())).await?;
oauth_state
.set_credentials(
@@ -1178,7 +957,10 @@ async fn create_oauth_transport_and_runtime(
}
};
let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager);
let auth_client = AuthClient::new(
StreamableHttpClientAdapter::new(http_client, default_headers),
manager,
);
let auth_manager = auth_client.auth_manager.clone();
let transport = StreamableHttpClientTransport::with_client(
@@ -1,190 +1,12 @@
use std::net::TcpListener;
use std::path::PathBuf;
use std::time::Duration;
use std::time::Instant;
mod streamable_http_test_support;
use codex_config::types::OAuthCredentialsStoreMode;
use codex_rmcp_client::ElicitationAction;
use codex_rmcp_client::ElicitationResponse;
use codex_rmcp_client::RmcpClient;
use codex_utils_cargo_bin::CargoBinError;
use futures::FutureExt as _;
use pretty_assertions::assert_eq;
use rmcp::model::CallToolResult;
use rmcp::model::ClientCapabilities;
use rmcp::model::ElicitationCapability;
use rmcp::model::FormElicitationCapability;
use rmcp::model::Implementation;
use rmcp::model::InitializeRequestParams;
use rmcp::model::ProtocolVersion;
use serde_json::json;
use tokio::net::TcpStream;
use tokio::process::Child;
use tokio::process::Command;
use tokio::time::sleep;
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
fn streamable_http_server_bin() -> Result<PathBuf, CargoBinError> {
codex_utils_cargo_bin::cargo_bin("test_streamable_http_server")
}
fn init_params() -> InitializeRequestParams {
InitializeRequestParams {
meta: None,
capabilities: ClientCapabilities {
experimental: None,
extensions: None,
roots: None,
sampling: None,
elicitation: Some(ElicitationCapability {
form: Some(FormElicitationCapability {
schema_validation: None,
}),
url: None,
}),
tasks: None,
},
client_info: Implementation {
name: "codex-test".into(),
version: "0.0.0-test".into(),
title: Some("Codex rmcp recovery test".into()),
description: None,
icons: None,
website_url: None,
},
protocol_version: ProtocolVersion::V_2025_06_18,
}
}
fn expected_echo_result(message: &str) -> CallToolResult {
CallToolResult {
content: Vec::new(),
structured_content: Some(json!({
"echo": format!("ECHOING: {message}"),
"env": null,
})),
is_error: Some(false),
meta: None,
}
}
async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
let client = RmcpClient::new_streamable_http_client(
"test-streamable-http",
&format!("{base_url}/mcp"),
Some("test-bearer".to_string()),
/*http_headers*/ None,
/*env_http_headers*/ None,
OAuthCredentialsStoreMode::File,
)
.await?;
client
.initialize(
init_params(),
Some(Duration::from_secs(5)),
Box::new(|_, _| {
async {
Ok(ElicitationResponse {
action: ElicitationAction::Accept,
content: Some(json!({})),
meta: None,
})
}
.boxed()
}),
)
.await?;
Ok(client)
}
async fn call_echo_tool(client: &RmcpClient, message: &str) -> anyhow::Result<CallToolResult> {
client
.call_tool(
"echo".to_string(),
Some(json!({ "message": message })),
/*meta*/ None,
Some(Duration::from_secs(5)),
)
.await
}
async fn arm_session_post_failure(
base_url: &str,
status: u16,
remaining: usize,
) -> anyhow::Result<()> {
let response = reqwest::Client::new()
.post(format!("{base_url}{SESSION_POST_FAILURE_CONTROL_PATH}"))
.json(&json!({
"status": status,
"remaining": remaining,
}))
.send()
.await?;
assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT);
Ok(())
}
async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> {
let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();
drop(listener);
let bind_addr = format!("127.0.0.1:{port}");
let base_url = format!("http://{bind_addr}");
let mut child = Command::new(streamable_http_server_bin()?)
.kill_on_drop(true)
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
.spawn()?;
wait_for_streamable_http_server(&mut child, &bind_addr, Duration::from_secs(5)).await?;
Ok((child, base_url))
}
async fn wait_for_streamable_http_server(
server_child: &mut Child,
address: &str,
timeout: Duration,
) -> anyhow::Result<()> {
let deadline = Instant::now() + timeout;
loop {
if let Some(status) = server_child.try_wait()? {
return Err(anyhow::anyhow!(
"streamable HTTP server exited early with status {status}"
));
}
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
return Err(anyhow::anyhow!(
"timed out waiting for streamable HTTP server at {address}: deadline reached"
));
}
match tokio::time::timeout(remaining, TcpStream::connect(address)).await {
Ok(Ok(_)) => return Ok(()),
Ok(Err(error)) => {
if Instant::now() >= deadline {
return Err(anyhow::anyhow!(
"timed out waiting for streamable HTTP server at {address}: {error}"
));
}
}
Err(_) => {
return Err(anyhow::anyhow!(
"timed out waiting for streamable HTTP server at {address}: connect call timed out"
));
}
}
sleep(Duration::from_millis(50)).await;
}
}
use streamable_http_test_support::arm_session_post_failure;
use streamable_http_test_support::call_echo_tool;
use streamable_http_test_support::create_client;
use streamable_http_test_support::expected_echo_result;
use streamable_http_test_support::spawn_streamable_http_server;
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
async fn streamable_http_404_session_expiry_recovers_and_retries_once() -> anyhow::Result<()> {
@@ -0,0 +1,37 @@
//! Integration coverage for the remote Streamable HTTP RMCP path.
//!
//! These tests exercise the orchestrator-side RMCP adapter against a real
//! `exec-server` process so HTTP requests go through the remote runtime path
//! instead of direct local `reqwest` calls.
mod streamable_http_test_support;
use pretty_assertions::assert_eq;
use streamable_http_test_support::call_echo_tool;
use streamable_http_test_support::create_remote_client;
use streamable_http_test_support::expected_echo_result;
use streamable_http_test_support::spawn_exec_server;
use streamable_http_test_support::spawn_streamable_http_server;
/// What this tests: the RMCP remote Streamable HTTP adapter can initialize
/// a server and call a tool while every MCP HTTP request goes through a real
/// exec-server process instead of a direct reqwest transport.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn streamable_http_remote_client_round_trips_through_exec_server() -> anyhow::Result<()> {
// Phase 1: start the MCP Streamable HTTP test server and a local
// exec-server process that will own the HTTP network calls.
let (_server, base_url) = spawn_streamable_http_server().await?;
let exec_server = spawn_exec_server().await?;
// Phase 2: create and initialize the RMCP client using the executor-backed
// Streamable HTTP transport.
let client = create_remote_client(&base_url, exec_server.client.clone()).await?;
// Phase 3: prove the initialized client can complete a tool call and
// preserve the normal RMCP response shape.
let result = call_echo_tool(&client, "remote").await?;
assert_eq!(result, expected_echo_result("remote"));
Ok(())
}
@@ -0,0 +1,314 @@
//! Shared helpers for Streamable HTTP RMCP integration tests.
//!
//! This support module starts the test HTTP server, launches a real
//! `exec-server` when remote coverage is needed, and provides small helpers for
//! creating RMCP clients and asserting round-trip behavior.
// This support module is included by multiple integration-test crates. Each
// crate uses a different subset of the helpers, so dead-code warnings would
// otherwise depend on which test file compiled the module.
#![allow(dead_code)]
use std::net::TcpListener;
use std::path::PathBuf;
use std::process::Stdio;
use std::sync::Arc;
use std::time::Duration;
use std::time::Instant;
use anyhow::Context as _;
use codex_config::types::OAuthCredentialsStoreMode;
use codex_exec_server::Environment;
use codex_exec_server::ExecServerClient;
use codex_exec_server::RemoteExecServerConnectArgs;
use codex_rmcp_client::ElicitationAction;
use codex_rmcp_client::ElicitationResponse;
use codex_rmcp_client::RmcpClient;
use codex_utils_cargo_bin::CargoBinError;
use futures::FutureExt as _;
use pretty_assertions::assert_eq;
use rmcp::model::CallToolResult;
use rmcp::model::ClientCapabilities;
use rmcp::model::ElicitationCapability;
use rmcp::model::FormElicitationCapability;
use rmcp::model::Implementation;
use rmcp::model::InitializeRequestParams;
use rmcp::model::ProtocolVersion;
use serde_json::json;
use tempfile::TempDir;
use tokio::io::AsyncBufReadExt;
use tokio::io::BufReader;
use tokio::net::TcpStream;
use tokio::process::Child;
use tokio::process::Command;
use tokio::time::sleep;
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
fn streamable_http_server_bin() -> Result<PathBuf, CargoBinError> {
codex_utils_cargo_bin::cargo_bin("test_streamable_http_server")
}
fn init_params() -> InitializeRequestParams {
InitializeRequestParams {
meta: None,
capabilities: ClientCapabilities {
experimental: None,
extensions: None,
roots: None,
sampling: None,
elicitation: Some(ElicitationCapability {
form: Some(FormElicitationCapability {
schema_validation: None,
}),
url: None,
}),
tasks: None,
},
client_info: Implementation {
name: "codex-test".into(),
version: "0.0.0-test".into(),
title: Some("Codex rmcp recovery test".into()),
description: None,
icons: None,
website_url: None,
},
protocol_version: ProtocolVersion::V_2025_06_18,
}
}
pub(crate) fn expected_echo_result(message: &str) -> CallToolResult {
CallToolResult {
content: Vec::new(),
structured_content: Some(json!({
"echo": format!("ECHOING: {message}"),
"env": null,
})),
is_error: Some(false),
meta: None,
}
}
pub(crate) async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
let client = RmcpClient::new_streamable_http_client(
"test-streamable-http",
&format!("{base_url}/mcp"),
Some("test-bearer".to_string()),
/*http_headers*/ None,
/*env_http_headers*/ None,
OAuthCredentialsStoreMode::File,
Environment::default_for_tests().get_http_client(),
)
.await?;
client
.initialize(
init_params(),
Some(Duration::from_secs(5)),
Box::new(|_, _| {
async {
Ok(ElicitationResponse {
action: ElicitationAction::Accept,
content: Some(json!({})),
meta: None,
})
}
.boxed()
}),
)
.await?;
Ok(client)
}
/// Creates a Streamable HTTP RMCP client that sends traffic through the remote
/// runtime HTTP API.
pub(crate) async fn create_remote_client(
base_url: &str,
http_client: ExecServerClient,
) -> anyhow::Result<RmcpClient> {
let client = RmcpClient::new_streamable_http_client(
"test-streamable-http-remote",
&format!("{base_url}/mcp"),
Some("test-bearer".to_string()),
/*http_headers*/ None,
/*env_http_headers*/ None,
OAuthCredentialsStoreMode::File,
Arc::new(http_client),
)
.await?;
client
.initialize(
init_params(),
Some(Duration::from_secs(5)),
Box::new(|_, _| {
async {
Ok(ElicitationResponse {
action: ElicitationAction::Accept,
content: Some(json!({})),
meta: None,
})
}
.boxed()
}),
)
.await?;
Ok(client)
}
pub(crate) async fn call_echo_tool(
client: &RmcpClient,
message: &str,
) -> anyhow::Result<CallToolResult> {
client
.call_tool(
"echo".to_string(),
Some(json!({ "message": message })),
/*meta*/ None,
Some(Duration::from_secs(5)),
)
.await
}
pub(crate) async fn arm_session_post_failure(
base_url: &str,
status: u16,
remaining: usize,
) -> anyhow::Result<()> {
let response = reqwest::Client::new()
.post(format!("{base_url}{SESSION_POST_FAILURE_CONTROL_PATH}"))
.json(&json!({
"status": status,
"remaining": remaining,
}))
.send()
.await?;
assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT);
Ok(())
}
pub(crate) async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> {
let listener = TcpListener::bind("127.0.0.1:0")?;
let port = listener.local_addr()?.port();
drop(listener);
let bind_addr = format!("127.0.0.1:{port}");
let base_url = format!("http://{bind_addr}");
let mut child = Command::new(streamable_http_server_bin()?)
.kill_on_drop(true)
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
.spawn()?;
wait_for_streamable_http_server(&mut child, &bind_addr, Duration::from_secs(5)).await?;
Ok((child, base_url))
}
/// Owns the exec-server process used by the remote-client integration test.
pub(crate) struct ExecServerProcess {
_codex_home: TempDir,
child: Child,
pub(crate) client: ExecServerClient,
}
impl Drop for ExecServerProcess {
/// Stops the local exec-server process best-effort when the test exits.
fn drop(&mut self) {
let _ = self.child.start_kill();
}
}
/// Starts a local exec-server and connects an initialized `ExecServerClient`.
pub(crate) async fn spawn_exec_server() -> anyhow::Result<ExecServerProcess> {
let codex_home = TempDir::new()?;
let mut child = Command::new(codex_utils_cargo_bin::cargo_bin("codex")?)
.args(["exec-server", "--listen", "ws://127.0.0.1:0"])
.stdin(Stdio::null())
.stdout(Stdio::piped())
.stderr(Stdio::inherit())
.kill_on_drop(true)
.env("CODEX_HOME", codex_home.path())
.spawn()?;
let websocket_url = read_exec_server_listen_url(&mut child).await?;
let client = ExecServerClient::connect_websocket(RemoteExecServerConnectArgs::new(
websocket_url,
"rmcp-client-remote-http-test".to_string(),
))
.await?;
Ok(ExecServerProcess {
_codex_home: codex_home,
child,
client,
})
}
/// Reads the websocket URL printed by `codex exec-server --listen`.
async fn read_exec_server_listen_url(child: &mut Child) -> anyhow::Result<String> {
let stdout = child
.stdout
.take()
.context("failed to capture exec-server stdout")?;
let mut lines = BufReader::new(stdout).lines();
let deadline = Instant::now() + Duration::from_secs(10);
loop {
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
anyhow::bail!("timed out waiting for exec-server listen URL");
}
let line = tokio::time::timeout(remaining, lines.next_line())
.await
.context("timed out waiting for exec-server stdout")??
.context("exec-server stdout closed before emitting listen URL")?;
let listen_url = line.trim();
if listen_url.starts_with("ws://") {
return Ok(listen_url.to_string());
}
}
}
async fn wait_for_streamable_http_server(
server_child: &mut Child,
address: &str,
timeout: Duration,
) -> anyhow::Result<()> {
let deadline = Instant::now() + timeout;
loop {
if let Some(status) = server_child.try_wait()? {
return Err(anyhow::anyhow!(
"streamable HTTP server exited early with status {status}"
));
}
let remaining = deadline.saturating_duration_since(Instant::now());
if remaining.is_zero() {
return Err(anyhow::anyhow!(
"timed out waiting for streamable HTTP server at {address}: deadline reached"
));
}
match tokio::time::timeout(remaining, TcpStream::connect(address)).await {
Ok(Ok(_)) => return Ok(()),
Ok(Err(error)) => {
if Instant::now() >= deadline {
return Err(anyhow::anyhow!(
"timed out waiting for streamable HTTP server at {address}: {error}"
));
}
}
Err(_) => {
return Err(anyhow::anyhow!(
"timed out waiting for streamable HTTP server at {address}: connect call timed out"
));
}
}
sleep(Duration::from_millis(50)).await;
}
}