From 0e78ce80eebbb08856e4a212ecd045ab42d72948 Mon Sep 17 00:00:00 2001 From: Ahmed Ibrahim Date: Wed, 22 Apr 2026 17:38:04 -0700 Subject: [PATCH] [3/4] Add executor-backed RMCP HTTP client (#18583) ### Why The RMCP layer needs a Streamable HTTP client that can talk either directly over `reqwest` or through the executor HTTP runner without duplicating MCP session logic higher in the stack. This PR adds that client-side transport boundary so remote Streamable HTTP MCP can reuse the same RMCP flow as the local path. ### What - Add a shared `rmcp-client/src/streamable_http/` module with: - `transport_client.rs` for the local-or-remote transport enum - `local_client.rs` for the direct `reqwest` implementation - `remote_client.rs` for the executor-backed implementation - `common.rs` for the small shared Streamable HTTP helpers - Teach `RmcpClient` to build Streamable HTTP transports in either local or remote mode while keeping the existing OAuth ownership in RMCP. - Translate remote POST, GET, and DELETE session operations into executor `http/request` calls. - Preserve RMCP session expiry handling and reconnect behavior for the remote transport. - Add remote transport coverage in `rmcp-client/tests/streamable_http_remote.rs` and keep the shared test support in `rmcp-client/tests/streamable_http_test_support.rs`. ### Verification - `cargo check -p codex-rmcp-client` - online CI ### Stack 1. #18581 protocol 2. #18582 runner 3. #18583 RMCP client 4. #18584 manager wiring and local/remote coverage --------- Co-authored-by: Codex --- codex-rs/Cargo.lock | 3 + .../codex-mcp/src/mcp_connection_manager.rs | 17 +- codex-rs/exec-server/Cargo.toml | 2 + codex-rs/exec-server/src/client.rs | 24 + .../exec-server/src/client/http_client.rs | 544 +----------------- .../src/client/http_response_body_stream.rs | 355 ++++++++++++ .../src/client/reqwest_http_client.rs | 267 +++++++++ .../exec-server/src/client/rpc_http_client.rs | 88 +++ codex-rs/exec-server/src/client_api.rs | 26 + codex-rs/exec-server/src/environment.rs | 13 +- codex-rs/exec-server/src/lib.rs | 3 + codex-rs/exec-server/src/server/handler.rs | 10 +- codex-rs/rmcp-client/BUILD.bazel | 3 + codex-rs/rmcp-client/Cargo.toml | 1 + .../rmcp-client/src/http_client_adapter.rs | 391 +++++++++++++ codex-rs/rmcp-client/src/lib.rs | 1 + codex-rs/rmcp-client/src/rmcp_client.rs | 278 +-------- .../tests/streamable_http_recovery.rs | 190 +----- .../tests/streamable_http_remote.rs | 37 ++ .../tests/streamable_http_test_support.rs | 314 ++++++++++ 20 files changed, 1595 insertions(+), 972 deletions(-) create mode 100644 codex-rs/exec-server/src/client/http_response_body_stream.rs create mode 100644 codex-rs/exec-server/src/client/reqwest_http_client.rs create mode 100644 codex-rs/exec-server/src/client/rpc_http_client.rs create mode 100644 codex-rs/rmcp-client/src/http_client_adapter.rs create mode 100644 codex-rs/rmcp-client/tests/streamable_http_remote.rs create mode 100644 codex-rs/rmcp-client/tests/streamable_http_test_support.rs diff --git a/codex-rs/Cargo.lock b/codex-rs/Cargo.lock index 6bc53a49f..74473e174 100644 --- a/codex-rs/Cargo.lock +++ b/codex-rs/Cargo.lock @@ -2579,7 +2579,9 @@ dependencies = [ "arc-swap", "async-trait", "base64 0.22.1", + "bytes", "codex-app-server-protocol", + "codex-client", "codex-config", "codex-protocol", "codex-sandboxing", @@ -3141,6 +3143,7 @@ version = "0.0.0" dependencies = [ "anyhow", "axum", + "bytes", "codex-client", "codex-config", "codex-exec-server", diff --git a/codex-rs/codex-mcp/src/mcp_connection_manager.rs b/codex-rs/codex-mcp/src/mcp_connection_manager.rs index 10c48e040..fe8bc523e 100644 --- a/codex-rs/codex-mcp/src/mcp_connection_manager.rs +++ b/codex-rs/codex-mcp/src/mcp_connection_manager.rs @@ -1589,23 +1589,11 @@ async fn make_rmcp_client( env_http_headers, bearer_token_env_var, } => { - if remote_environment { - if !runtime_environment.environment().is_remote() { - return Err(StartupOutcomeError::from(anyhow!( - "remote MCP server `{server_name}` requires a remote executor environment" - ))); - } + if remote_environment && !runtime_environment.environment().is_remote() { return Err(StartupOutcomeError::from(anyhow!( - // Remote HTTP needs the future low-level executor - // `network/request` API so reqwest runs on the executor side. - // Do not fall back to local HTTP here; the config explicitly - // asked for remote placement. - "remote streamable HTTP MCP server `{server_name}` is not implemented yet" + "remote MCP server `{server_name}` requires a remote environment" ))); } - - // Local streamable HTTP remains the existing reqwest path from - // the orchestrator process. let resolved_bearer_token = match resolve_bearer_token(server_name, bearer_token_env_var.as_deref()) { Ok(token) => token, @@ -1618,6 +1606,7 @@ async fn make_rmcp_client( http_headers, env_http_headers, store_mode, + runtime_environment.environment().get_http_client(), ) .await .map_err(StartupOutcomeError::from) diff --git a/codex-rs/exec-server/Cargo.toml b/codex-rs/exec-server/Cargo.toml index 968806327..a1a25e6e9 100644 --- a/codex-rs/exec-server/Cargo.toml +++ b/codex-rs/exec-server/Cargo.toml @@ -14,7 +14,9 @@ workspace = true arc-swap = { workspace = true } async-trait = { workspace = true } base64 = { workspace = true } +bytes = { workspace = true } codex-app-server-protocol = { workspace = true } +codex-client = { workspace = true } codex-config = { workspace = true } codex-protocol = { workspace = true } codex-sandboxing = { workspace = true } diff --git a/codex-rs/exec-server/src/client.rs b/codex-rs/exec-server/src/client.rs index 375571b77..f26069ac7 100644 --- a/codex-rs/exec-server/src/client.rs +++ b/codex-rs/exec-server/src/client.rs @@ -8,6 +8,8 @@ use std::time::Duration; use arc_swap::ArcSwap; use codex_app_server_protocol::JSONRPCNotification; +use futures::FutureExt; +use futures::future::BoxFuture; use serde_json::Value; use tokio::sync::Mutex; use tokio::sync::OnceCell; @@ -20,6 +22,7 @@ use tracing::debug; use crate::ProcessId; use crate::client_api::ExecServerClientConnectOptions; +use crate::client_api::HttpClient; use crate::client_api::RemoteExecServerConnectArgs; use crate::connection::JsonRpcConnection; use crate::process::ExecProcessEvent; @@ -206,6 +209,25 @@ impl LazyRemoteExecServerClient { } } +impl HttpClient for LazyRemoteExecServerClient { + fn http_request( + &self, + params: crate::HttpRequestParams, + ) -> BoxFuture<'_, Result> { + async move { self.get().await?.http_request(params).await }.boxed() + } + + fn http_request_stream( + &self, + params: crate::HttpRequestParams, + ) -> BoxFuture< + '_, + Result<(crate::HttpRequestResponse, crate::HttpResponseBodyStream), ExecServerError>, + > { + async move { self.get().await?.http_request_stream(params).await }.boxed() + } +} + #[derive(Debug, thiserror::Error)] pub enum ExecServerError { #[error("failed to spawn exec-server: {0}")] @@ -226,6 +248,8 @@ pub enum ExecServerError { Disconnected(String), #[error("failed to serialize or deserialize exec-server JSON: {0}")] Json(#[from] serde_json::Error), + #[error("HTTP request failed: {0}")] + HttpRequest(String), #[error("exec-server protocol error: {0}")] Protocol(String), #[error("exec-server rejected request ({code}): {message}")] diff --git a/codex-rs/exec-server/src/client/http_client.rs b/codex-rs/exec-server/src/client/http_client.rs index 1e91fa448..cfbb3a60b 100644 --- a/codex-rs/exec-server/src/client/http_client.rs +++ b/codex-rs/exec-server/src/client/http_client.rs @@ -1,522 +1,26 @@ -use std::collections::HashMap; -use std::sync::Arc; -use std::sync::atomic::Ordering; -use std::time::Duration; +//! HTTP client capability implementations shared by local and remote environments. +//! +//! This module is the facade for the environment-owned [`crate::HttpClient`] +//! capability: +//! - [`ReqwestHttpClient`] executes requests directly with `reqwest` +//! - [`ExecServerClient`] forwards requests over the JSON-RPC transport +//! - [`HttpResponseBodyStream`] presents buffered local bodies and streamed +//! remote `http/request/bodyDelta` notifications through one byte-stream API +//! +//! Runtime split: +//! - orchestrator process: holds an `Arc` and chooses local or +//! remote execution +//! - remote runtime: serves the `http/request` RPC and runs the concrete local +//! HTTP request there when the orchestrator uses [`ExecServerClient`] -use codex_app_server_protocol::JSONRPCErrorError; -use futures::StreamExt; -use reqwest::Method; -use reqwest::Url; -use reqwest::header::HeaderMap; -use reqwest::header::HeaderName; -use reqwest::header::HeaderValue; -use serde_json::Value; -use serde_json::from_value; -use tokio::runtime::Handle; -use tokio::sync::mpsc; -use tokio::sync::mpsc::error::TrySendError; -use tracing::debug; +#[path = "reqwest_http_client.rs"] +mod reqwest_http_client; +#[path = "http_response_body_stream.rs"] +pub(crate) mod response_body_stream; +#[path = "rpc_http_client.rs"] +mod rpc_http_client; -use super::ExecServerClient; -use super::ExecServerError; -use super::Inner; -use crate::protocol::HTTP_REQUEST_BODY_DELTA_METHOD; -use crate::protocol::HTTP_REQUEST_METHOD; -use crate::protocol::HttpHeader; -use crate::protocol::HttpRequestBodyDeltaNotification; -use crate::protocol::HttpRequestParams; -use crate::protocol::HttpRequestResponse; -use crate::rpc::RpcNotificationSender; -use crate::rpc::internal_error; -use crate::rpc::invalid_params; - -/// Maximum queued body frames per streamed executor HTTP response. -const HTTP_BODY_DELTA_CHANNEL_CAPACITY: usize = 256; - -pub(crate) struct ExecutorPendingHttpBodyStream { - pub(crate) request_id: String, - response: reqwest::Response, -} - -pub(crate) struct ExecutorHttpRequestRunner { - client: reqwest::Client, -} - -/// Request-scoped stream of body chunks for an executor HTTP response. -/// -/// The initial `http/request` call returns status and headers. This stream then -/// receives the ordered `http/request/bodyDelta` notifications for that request -/// id until EOF or a terminal error. -pub struct HttpResponseBodyStream { - inner: Arc, - request_id: String, - next_seq: u64, - rx: mpsc::Receiver, - // Terminal frames can carry a final chunk; return that once, then EOF. - pending_eof: bool, - closed: bool, -} - -impl ExecServerClient { - /// Performs an executor-side HTTP request and buffers the response body. - pub async fn http_request( - &self, - mut params: HttpRequestParams, - ) -> Result { - params.stream_response = false; - self.call(HTTP_REQUEST_METHOD, ¶ms).await - } - - /// Performs an executor-side HTTP request and returns a body stream. - /// - /// The method sets `stream_response` and replaces any caller-supplied - /// `request_id` with a connection-local id, so late deltas from abandoned - /// streams cannot be confused with later requests. - pub async fn http_request_stream( - &self, - mut params: HttpRequestParams, - ) -> Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError> { - params.stream_response = true; - let request_id = self.inner.next_http_body_stream_request_id(); - params.request_id = request_id.clone(); - let (tx, rx) = mpsc::channel(HTTP_BODY_DELTA_CHANNEL_CAPACITY); - self.inner - .insert_http_body_stream(request_id.clone(), tx) - .await?; - let mut registration = HttpBodyStreamRegistration { - inner: Arc::clone(&self.inner), - request_id: request_id.clone(), - active: true, - }; - let response = match self.call(HTTP_REQUEST_METHOD, ¶ms).await { - Ok(response) => response, - Err(error) => { - self.inner.remove_http_body_stream(&request_id).await; - registration.active = false; - return Err(error); - } - }; - registration.active = false; - Ok(( - response, - HttpResponseBodyStream { - inner: Arc::clone(&self.inner), - request_id, - next_seq: 1, - rx, - pending_eof: false, - closed: false, - }, - )) - } -} - -impl HttpResponseBodyStream { - /// Receives the next response-body chunk. - /// - /// Returns `Ok(None)` at EOF and converts sequence gaps or executor-side - /// stream errors into protocol errors. - pub async fn recv(&mut self) -> Result>, ExecServerError> { - if self.pending_eof { - self.pending_eof = false; - self.finish().await; - return Ok(None); - } - - let Some(delta) = self.rx.recv().await else { - self.finish().await; - if let Some(error) = self - .inner - .take_http_body_stream_failure(&self.request_id) - .await - { - return Err(ExecServerError::Protocol(format!( - "http response stream `{}` failed: {error}", - self.request_id - ))); - } - return Ok(None); - }; - if delta.seq != self.next_seq { - self.finish().await; - return Err(ExecServerError::Protocol(format!( - "http response stream `{}` received seq {}, expected {}", - self.request_id, delta.seq, self.next_seq - ))); - } - self.next_seq += 1; - let chunk = delta.delta.into_inner(); - - if let Some(error) = delta.error { - self.finish().await; - return Err(ExecServerError::Protocol(format!( - "http response stream `{}` failed: {error}", - self.request_id - ))); - } - if delta.done { - self.finish().await; - if chunk.is_empty() { - return Ok(None); - } - self.pending_eof = true; - } - Ok(Some(chunk)) - } - - /// Removes this stream from the connection routing table once it reaches EOF. - async fn finish(&mut self) { - if self.closed { - return; - } - self.closed = true; - self.inner.remove_http_body_stream(&self.request_id).await; - } -} - -impl Drop for HttpResponseBodyStream { - /// Schedules stream-route removal if the consumer drops before EOF. - fn drop(&mut self) { - if self.closed { - return; - } - self.closed = true; - spawn_remove_http_body_stream(Arc::clone(&self.inner), self.request_id.clone()); - } -} - -impl ExecutorHttpRequestRunner { - pub(crate) fn new(timeout_ms: Option) -> Result { - let client = match timeout_ms { - None => reqwest::Client::builder(), - Some(timeout_ms) => { - reqwest::Client::builder().timeout(Duration::from_millis(timeout_ms)) - } - } - .build() - .map_err(|err| internal_error(format!("failed to build http/request client: {err}")))?; - Ok(Self { client }) - } - - pub(crate) async fn run( - &self, - params: HttpRequestParams, - ) -> Result<(HttpRequestResponse, Option), JSONRPCErrorError> - { - let method = Method::from_bytes(params.method.as_bytes()) - .map_err(|err| invalid_params(format!("http/request method is invalid: {err}")))?; - let url = Url::parse(¶ms.url) - .map_err(|err| invalid_params(format!("http/request url is invalid: {err}")))?; - match url.scheme() { - "http" | "https" => {} - scheme => { - return Err(invalid_params(format!( - "http/request only supports http and https URLs, got {scheme}" - ))); - } - } - - let headers = Self::build_headers(params.headers)?; - let mut request = self.client.request(method, url).headers(headers); - if let Some(body) = params.body { - request = request.body(body.into_inner()); - } - - let response = request - .send() - .await - .map_err(|err| internal_error(format!("http/request failed: {err}")))?; - let status = response.status().as_u16(); - let headers = Self::response_headers(response.headers()); - - if params.stream_response { - return Ok(( - HttpRequestResponse { - status, - headers, - body: Vec::new().into(), - }, - Some(ExecutorPendingHttpBodyStream { - request_id: params.request_id, - response, - }), - )); - } - - let body = response.bytes().await.map_err(|err| { - internal_error(format!("failed to read http/request response body: {err}")) - })?; - - Ok(( - HttpRequestResponse { - status, - headers, - body: body.to_vec().into(), - }, - None, - )) - } - - fn build_headers(headers: Vec) -> Result { - let mut header_map = HeaderMap::new(); - for header in headers { - let name = HeaderName::from_bytes(header.name.as_bytes()).map_err(|err| { - invalid_params(format!("http/request header name is invalid: {err}")) - })?; - let value = HeaderValue::from_str(&header.value).map_err(|err| { - invalid_params(format!( - "http/request header value is invalid for {}: {err}", - header.name - )) - })?; - header_map.append(name, value); - } - Ok(header_map) - } - - fn response_headers(headers: &HeaderMap) -> Vec { - headers - .iter() - .filter_map(|(name, value)| { - Some(HttpHeader { - name: name.as_str().to_string(), - value: value.to_str().ok()?.to_string(), - }) - }) - .collect() - } - - pub(crate) async fn stream_body( - pending_stream: ExecutorPendingHttpBodyStream, - notifications: RpcNotificationSender, - ) { - let ExecutorPendingHttpBodyStream { - request_id, - response, - } = pending_stream; - let mut seq = 1; - let mut body = response.bytes_stream(); - while let Some(chunk) = body.next().await { - match chunk { - Ok(bytes) => { - if !send_executor_body_delta( - ¬ifications, - HttpRequestBodyDeltaNotification { - request_id: request_id.clone(), - seq, - delta: bytes.to_vec().into(), - done: false, - error: None, - }, - ) - .await - { - return; - } - seq += 1; - } - Err(err) => { - let _ = send_executor_body_delta( - ¬ifications, - HttpRequestBodyDeltaNotification { - request_id, - seq, - delta: Vec::new().into(), - done: true, - error: Some(err.to_string()), - }, - ) - .await; - return; - } - } - } - - let _ = send_executor_body_delta( - ¬ifications, - HttpRequestBodyDeltaNotification { - request_id, - seq, - delta: Vec::new().into(), - done: true, - error: None, - }, - ) - .await; - } -} - -impl Inner { - /// Routes one streamed HTTP body notification into its request-local receiver. - pub(super) async fn handle_http_body_delta_notification( - &self, - params: Option, - ) -> Result<(), ExecServerError> { - let params: HttpRequestBodyDeltaNotification = from_value(params.unwrap_or(Value::Null))?; - // Unknown request ids are ignored intentionally: a stream may have already - // reached EOF and released its route. - if let Some(tx) = self - .http_body_streams - .load() - .get(¶ms.request_id) - .cloned() - { - let request_id = params.request_id.clone(); - let terminal_delta = params.done || params.error.is_some(); - match tx.try_send(params) { - Ok(()) => { - if terminal_delta { - self.remove_http_body_stream(&request_id).await; - } - } - Err(TrySendError::Closed(_)) => { - self.remove_http_body_stream(&request_id).await; - debug!("http response stream receiver dropped before body delta delivery"); - } - Err(TrySendError::Full(_)) => { - self.record_http_body_stream_failure( - &request_id, - "body delta channel filled before delivery".to_string(), - ) - .await; - self.remove_http_body_stream(&request_id).await; - debug!( - "closing http response stream `{request_id}` after body delta backpressure" - ); - } - } - } - Ok(()) - } - - /// Fails active streamed HTTP bodies so callers do not wait forever after a - /// transport disconnect or notification handling failure. - pub(super) async fn fail_all_http_body_streams(&self, message: String) { - let _streams_write_guard = self.http_body_streams_write_lock.lock().await; - let streams = self.http_body_streams.load(); - let streams = streams.as_ref().clone(); - self.http_body_streams.store(Arc::new(HashMap::new())); - for (request_id, tx) in streams { - if tx - .try_send(HttpRequestBodyDeltaNotification { - request_id: request_id.clone(), - seq: 1, - delta: Vec::new().into(), - done: true, - error: Some(message.clone()), - }) - .is_err() - { - let mut next_failures = self.http_body_stream_failures.load().as_ref().clone(); - next_failures.insert(request_id, message.clone()); - self.http_body_stream_failures - .store(Arc::new(next_failures)); - } - } - } - - /// Allocates a connection-local streamed HTTP response id. - fn next_http_body_stream_request_id(&self) -> String { - let id = self - .http_body_stream_next_id - .fetch_add(1, Ordering::Relaxed); - format!("http-{id}") - } - - /// Registers a request id before issuing an executor streaming HTTP call. - async fn insert_http_body_stream( - &self, - request_id: String, - tx: mpsc::Sender, - ) -> Result<(), ExecServerError> { - let _streams_write_guard = self.http_body_streams_write_lock.lock().await; - let streams = self.http_body_streams.load(); - if streams.contains_key(&request_id) { - return Err(ExecServerError::Protocol(format!( - "http response stream already registered for request {request_id}" - ))); - } - let mut next_streams = streams.as_ref().clone(); - next_streams.insert(request_id.clone(), tx); - self.http_body_streams.store(Arc::new(next_streams)); - let failures = self.http_body_stream_failures.load(); - if failures.contains_key(&request_id) { - let mut next_failures = failures.as_ref().clone(); - next_failures.remove(&request_id); - self.http_body_stream_failures - .store(Arc::new(next_failures)); - } - Ok(()) - } - - /// Removes a request id after EOF, terminal error, or request failure. - async fn remove_http_body_stream( - &self, - request_id: &str, - ) -> Option> { - let _streams_write_guard = self.http_body_streams_write_lock.lock().await; - let streams = self.http_body_streams.load(); - let stream = streams.get(request_id).cloned(); - stream.as_ref()?; - let mut next_streams = streams.as_ref().clone(); - next_streams.remove(request_id); - self.http_body_streams.store(Arc::new(next_streams)); - stream - } - - async fn record_http_body_stream_failure(&self, request_id: &str, message: String) { - let _streams_write_guard = self.http_body_streams_write_lock.lock().await; - let failures = self.http_body_stream_failures.load(); - let mut next_failures = failures.as_ref().clone(); - next_failures.insert(request_id.to_string(), message); - self.http_body_stream_failures - .store(Arc::new(next_failures)); - } - - async fn take_http_body_stream_failure(&self, request_id: &str) -> Option { - let _streams_write_guard = self.http_body_streams_write_lock.lock().await; - let failures = self.http_body_stream_failures.load(); - let error = failures.get(request_id).cloned(); - error.as_ref()?; - let mut next_failures = failures.as_ref().clone(); - next_failures.remove(request_id); - self.http_body_stream_failures - .store(Arc::new(next_failures)); - error - } -} - -/// Active route registration owned while `http_request_stream` awaits headers. -struct HttpBodyStreamRegistration { - inner: Arc, - request_id: String, - active: bool, -} - -impl Drop for HttpBodyStreamRegistration { - /// Removes the route if the stream request future is cancelled before headers return. - fn drop(&mut self) { - if self.active { - spawn_remove_http_body_stream(Arc::clone(&self.inner), self.request_id.clone()); - } - } -} - -/// Schedules HTTP body route removal from synchronous drop paths. -fn spawn_remove_http_body_stream(inner: Arc, request_id: String) { - if let Ok(handle) = Handle::try_current() { - handle.spawn(async move { - inner.remove_http_body_stream(&request_id).await; - }); - } -} - -async fn send_executor_body_delta( - notifications: &RpcNotificationSender, - delta: HttpRequestBodyDeltaNotification, -) -> bool { - notifications - .notify(HTTP_REQUEST_BODY_DELTA_METHOD, &delta) - .await - .is_ok() -} +pub(crate) use reqwest_http_client::PendingReqwestHttpBodyStream; +pub use reqwest_http_client::ReqwestHttpClient; +pub(crate) use reqwest_http_client::ReqwestHttpRequestRunner; +pub use response_body_stream::HttpResponseBodyStream; diff --git a/codex-rs/exec-server/src/client/http_response_body_stream.rs b/codex-rs/exec-server/src/client/http_response_body_stream.rs new file mode 100644 index 000000000..3435af7d5 --- /dev/null +++ b/codex-rs/exec-server/src/client/http_response_body_stream.rs @@ -0,0 +1,355 @@ +//! Shared HTTP response-body stream plumbing for local and remote execution. +//! +//! This module owns the byte-stream type exposed by the `HttpClient` +//! capability plus the remote-side routing table used to turn +//! `http/request/bodyDelta` notifications back into per-request streams. + +use std::collections::HashMap; +use std::pin::Pin; +use std::sync::Arc; +use std::sync::atomic::Ordering; + +use bytes::Bytes; +use futures::StreamExt; +use reqwest::Response; +use serde_json::Value; +use serde_json::from_value; +use tokio::runtime::Handle; +use tokio::sync::mpsc; +use tokio::sync::mpsc::error::TrySendError; +use tracing::debug; + +use crate::client::ExecServerError; +use crate::client::Inner; +use crate::protocol::HTTP_REQUEST_BODY_DELTA_METHOD; +use crate::protocol::HttpRequestBodyDeltaNotification; +use crate::rpc::RpcNotificationSender; + +pub(super) struct HttpBodyStreamRegistration { + inner: Arc, + request_id: String, + active: bool, +} + +enum HttpResponseBodyStreamInner { + Local { + body: Pin> + Send>>, + }, + Remote { + inner: Arc, + request_id: String, + next_seq: u64, + rx: mpsc::Receiver, + pending_eof: bool, + closed: bool, + }, +} + +/// Request-scoped stream of body chunks for an HTTP response. +/// +/// The initial `http/request` call returns status and headers. This stream then +/// receives the ordered `http/request/bodyDelta` notifications for that request +/// id until EOF or a terminal error. +pub struct HttpResponseBodyStream { + inner: HttpResponseBodyStreamInner, +} + +impl HttpResponseBodyStream { + pub(super) fn local(response: Response) -> Self { + Self { + inner: HttpResponseBodyStreamInner::Local { + body: Box::pin(response.bytes_stream()), + }, + } + } + + pub(super) fn remote( + inner: Arc, + request_id: String, + rx: mpsc::Receiver, + ) -> Self { + Self { + inner: HttpResponseBodyStreamInner::Remote { + inner, + request_id, + next_seq: 1, + rx, + pending_eof: false, + closed: false, + }, + } + } + + /// Receives the next response-body chunk. + /// + /// Returns `Ok(None)` at EOF and converts sequence gaps or stream-side + /// stream errors into protocol errors. + pub async fn recv(&mut self) -> Result>, ExecServerError> { + match &mut self.inner { + HttpResponseBodyStreamInner::Local { body } => match body.next().await { + Some(chunk) => match chunk { + Ok(bytes) => Ok(Some(bytes.to_vec())), + Err(error) => Err(ExecServerError::HttpRequest(error.to_string())), + }, + None => Ok(None), + }, + HttpResponseBodyStreamInner::Remote { + inner, + request_id, + next_seq, + rx, + pending_eof, + closed, + } => { + if *pending_eof { + *pending_eof = false; + finish_remote_stream(inner, request_id, closed).await; + return Ok(None); + } + + let Some(delta) = rx.recv().await else { + finish_remote_stream(inner, request_id, closed).await; + if let Some(error) = inner.take_http_body_stream_failure(request_id).await { + return Err(ExecServerError::Protocol(format!( + "http response stream `{request_id}` failed: {error}", + ))); + } + return Ok(None); + }; + if delta.seq != *next_seq { + finish_remote_stream(inner, request_id, closed).await; + return Err(ExecServerError::Protocol(format!( + "http response stream `{request_id}` received seq {}, expected {}", + delta.seq, *next_seq + ))); + } + *next_seq += 1; + let chunk = delta.delta.into_inner(); + + if let Some(error) = delta.error { + finish_remote_stream(inner, request_id, closed).await; + return Err(ExecServerError::Protocol(format!( + "http response stream `{request_id}` failed: {error}", + ))); + } + if delta.done { + finish_remote_stream(inner, request_id, closed).await; + if chunk.is_empty() { + return Ok(None); + } + *pending_eof = true; + } + Ok(Some(chunk)) + } + } + } +} + +impl Drop for HttpResponseBodyStream { + /// Schedules stream-route removal if the consumer drops before EOF. + fn drop(&mut self) { + if let HttpResponseBodyStreamInner::Remote { + inner, + request_id, + closed, + .. + } = &mut self.inner + { + if *closed { + return; + } + *closed = true; + spawn_remove_http_body_stream(Arc::clone(inner), request_id.clone()); + } + } +} + +impl HttpBodyStreamRegistration { + pub(super) fn new(inner: Arc, request_id: String) -> Self { + Self { + inner, + request_id, + active: true, + } + } + + pub(super) fn disarm(&mut self) { + self.active = false; + } +} + +impl Drop for HttpBodyStreamRegistration { + /// Removes the route if the stream request future is cancelled before headers return. + fn drop(&mut self) { + if self.active { + spawn_remove_http_body_stream(Arc::clone(&self.inner), self.request_id.clone()); + } + } +} + +async fn finish_remote_stream(inner: &Arc, request_id: &str, closed: &mut bool) { + if *closed { + return; + } + *closed = true; + inner.remove_http_body_stream(request_id).await; +} + +/// Schedules HTTP body route removal from synchronous drop paths. +fn spawn_remove_http_body_stream(inner: Arc, request_id: String) { + if let Ok(handle) = Handle::try_current() { + handle.spawn(async move { + inner.remove_http_body_stream(&request_id).await; + }); + } +} + +pub(super) async fn send_body_delta( + notifications: &RpcNotificationSender, + delta: HttpRequestBodyDeltaNotification, +) -> bool { + notifications + .notify(HTTP_REQUEST_BODY_DELTA_METHOD, &delta) + .await + .is_ok() +} + +impl Inner { + /// Routes one streamed HTTP body notification into its request-local receiver. + pub(crate) async fn handle_http_body_delta_notification( + &self, + params: Option, + ) -> Result<(), ExecServerError> { + let params: HttpRequestBodyDeltaNotification = from_value(params.unwrap_or(Value::Null))?; + // Unknown request ids are ignored intentionally: a stream may have already + // reached EOF and released its route. + if let Some(tx) = self + .http_body_streams + .load() + .get(¶ms.request_id) + .cloned() + { + let request_id = params.request_id.clone(); + let terminal_delta = params.done || params.error.is_some(); + match tx.try_send(params) { + Ok(()) => { + if terminal_delta { + self.remove_http_body_stream(&request_id).await; + } + } + Err(TrySendError::Closed(_)) => { + self.remove_http_body_stream(&request_id).await; + debug!("http response stream receiver dropped before body delta delivery"); + } + Err(TrySendError::Full(_)) => { + self.record_http_body_stream_failure( + &request_id, + "body delta channel filled before delivery".to_string(), + ) + .await; + self.remove_http_body_stream(&request_id).await; + debug!( + "closing http response stream `{request_id}` after body delta backpressure" + ); + } + } + } + Ok(()) + } + + /// Fails active streamed HTTP bodies so callers do not wait forever after a + /// transport disconnect or notification handling failure. + pub(crate) async fn fail_all_http_body_streams(&self, message: String) { + let _streams_write_guard = self.http_body_streams_write_lock.lock().await; + let streams = self.http_body_streams.load(); + let streams = streams.as_ref().clone(); + self.http_body_streams.store(Arc::new(HashMap::new())); + for (request_id, tx) in streams { + if tx + .try_send(HttpRequestBodyDeltaNotification { + request_id: request_id.clone(), + seq: 1, + delta: Vec::new().into(), + done: true, + error: Some(message.clone()), + }) + .is_err() + { + let mut next_failures = self.http_body_stream_failures.load().as_ref().clone(); + next_failures.insert(request_id, message.clone()); + self.http_body_stream_failures + .store(Arc::new(next_failures)); + } + } + } + + /// Allocates a connection-local streamed HTTP response id. + pub(super) fn next_http_body_stream_request_id(&self) -> String { + let id = self + .http_body_stream_next_id + .fetch_add(1, Ordering::Relaxed); + format!("http-{id}") + } + + /// Registers a request id before issuing a streaming HTTP call. + pub(super) async fn insert_http_body_stream( + &self, + request_id: String, + tx: mpsc::Sender, + ) -> Result<(), ExecServerError> { + let _streams_write_guard = self.http_body_streams_write_lock.lock().await; + let streams = self.http_body_streams.load(); + if streams.contains_key(&request_id) { + return Err(ExecServerError::Protocol(format!( + "http response stream already registered for request {request_id}" + ))); + } + let mut next_streams = streams.as_ref().clone(); + next_streams.insert(request_id.clone(), tx); + self.http_body_streams.store(Arc::new(next_streams)); + let failures = self.http_body_stream_failures.load(); + if failures.contains_key(&request_id) { + let mut next_failures = failures.as_ref().clone(); + next_failures.remove(&request_id); + self.http_body_stream_failures + .store(Arc::new(next_failures)); + } + Ok(()) + } + + /// Removes a request id after EOF, terminal error, or request failure. + pub(super) async fn remove_http_body_stream( + &self, + request_id: &str, + ) -> Option> { + let _streams_write_guard = self.http_body_streams_write_lock.lock().await; + let streams = self.http_body_streams.load(); + let stream = streams.get(request_id).cloned(); + stream.as_ref()?; + let mut next_streams = streams.as_ref().clone(); + next_streams.remove(request_id); + self.http_body_streams.store(Arc::new(next_streams)); + stream + } + + async fn record_http_body_stream_failure(&self, request_id: &str, message: String) { + let _streams_write_guard = self.http_body_streams_write_lock.lock().await; + let failures = self.http_body_stream_failures.load(); + let mut next_failures = failures.as_ref().clone(); + next_failures.insert(request_id.to_string(), message); + self.http_body_stream_failures + .store(Arc::new(next_failures)); + } + + async fn take_http_body_stream_failure(&self, request_id: &str) -> Option { + let _streams_write_guard = self.http_body_streams_write_lock.lock().await; + let failures = self.http_body_stream_failures.load(); + let error = failures.get(request_id).cloned(); + error.as_ref()?; + let mut next_failures = failures.as_ref().clone(); + next_failures.remove(request_id); + self.http_body_stream_failures + .store(Arc::new(next_failures)); + error + } +} diff --git a/codex-rs/exec-server/src/client/reqwest_http_client.rs b/codex-rs/exec-server/src/client/reqwest_http_client.rs new file mode 100644 index 000000000..44305e466 --- /dev/null +++ b/codex-rs/exec-server/src/client/reqwest_http_client.rs @@ -0,0 +1,267 @@ +//! `reqwest`-backed `HttpClient` implementation. +//! +//! This code runs wherever the real network request should originate: +//! - in a local environment, that means the orchestrator process +//! - in a remote environment, that means the remote runtime after the +//! orchestrator has forwarded `http/request` over JSON-RPC + +use std::time::Duration; + +use codex_app_server_protocol::JSONRPCErrorError; +use codex_client::build_reqwest_client_with_custom_ca; +use futures::FutureExt; +use futures::StreamExt; +use futures::future::BoxFuture; +use reqwest::Method; +use reqwest::Url; +use reqwest::header::HeaderMap; +use reqwest::header::HeaderName; +use reqwest::header::HeaderValue; + +use super::HttpResponseBodyStream; +use super::response_body_stream::send_body_delta; +use crate::HttpClient; +use crate::client::ExecServerError; +use crate::protocol::HttpHeader; +use crate::protocol::HttpRequestBodyDeltaNotification; +use crate::protocol::HttpRequestParams; +use crate::protocol::HttpRequestResponse; +use crate::rpc::RpcNotificationSender; +use crate::rpc::internal_error; +use crate::rpc::invalid_params; + +/// `HttpClient` implementation that performs the actual HTTP request with +/// `reqwest`. +#[derive(Clone, Default)] +pub struct ReqwestHttpClient; + +/// Streaming response state held between the initial HTTP response and +/// downstream body-delta forwarding. +pub(crate) struct PendingReqwestHttpBodyStream { + pub(crate) request_id: String, + pub(crate) response: reqwest::Response, +} + +/// Validates `http/request` parameters and runs the actual `reqwest` call used +/// by the exec-server route and the local [`HttpClient`] backend. +pub(crate) struct ReqwestHttpRequestRunner { + client: reqwest::Client, +} + +impl ReqwestHttpClient { + fn build_client(timeout_ms: Option) -> Result { + let builder = match timeout_ms { + None => reqwest::Client::builder(), + Some(timeout_ms) => { + reqwest::Client::builder().timeout(Duration::from_millis(timeout_ms)) + } + }; + build_reqwest_client_with_custom_ca(builder) + .map_err(|error| ExecServerError::HttpRequest(error.to_string())) + } +} + +impl HttpClient for ReqwestHttpClient { + fn http_request( + &self, + params: HttpRequestParams, + ) -> BoxFuture<'_, Result> { + async move { + let runner = ReqwestHttpRequestRunner::new(params.timeout_ms) + .map_err(|error| ExecServerError::HttpRequest(error.message))?; + let (response, _) = runner + .run(HttpRequestParams { + stream_response: false, + ..params + }) + .await + .map_err(|error| ExecServerError::HttpRequest(error.message))?; + Ok(response) + } + .boxed() + } + + fn http_request_stream( + &self, + params: HttpRequestParams, + ) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> { + async move { + let runner = ReqwestHttpRequestRunner::new(params.timeout_ms) + .map_err(|error| ExecServerError::HttpRequest(error.message))?; + let (response, pending_stream) = runner + .run(HttpRequestParams { + stream_response: true, + ..params + }) + .await + .map_err(|error| ExecServerError::HttpRequest(error.message))?; + let pending_stream = pending_stream.ok_or_else(|| { + ExecServerError::Protocol( + "http request stream did not return a response body stream".to_string(), + ) + })?; + Ok(( + response, + HttpResponseBodyStream::local(pending_stream.response), + )) + } + .boxed() + } +} + +impl ReqwestHttpRequestRunner { + pub(crate) fn new(timeout_ms: Option) -> Result { + let client = ReqwestHttpClient::build_client(timeout_ms) + .map_err(|error| internal_error(error.to_string()))?; + Ok(Self { client }) + } + + pub(crate) async fn run( + &self, + params: HttpRequestParams, + ) -> Result<(HttpRequestResponse, Option), JSONRPCErrorError> + { + let method = Method::from_bytes(params.method.as_bytes()) + .map_err(|error| invalid_params(format!("http/request method is invalid: {error}")))?; + let url = Url::parse(¶ms.url) + .map_err(|error| invalid_params(format!("http/request url is invalid: {error}")))?; + match url.scheme() { + "http" | "https" => {} + scheme => { + return Err(invalid_params(format!( + "http/request only supports http and https URLs, got {scheme}" + ))); + } + } + + let headers = Self::build_headers(params.headers)?; + let mut request = self.client.request(method, url).headers(headers); + if let Some(body) = params.body { + request = request.body(body.into_inner()); + } + + let response = request + .send() + .await + .map_err(|error| internal_error(format!("http/request failed: {error}")))?; + let status = response.status().as_u16(); + let headers = Self::response_headers(response.headers()); + + if params.stream_response { + return Ok(( + HttpRequestResponse { + status, + headers, + body: Vec::new().into(), + }, + Some(PendingReqwestHttpBodyStream { + request_id: params.request_id, + response, + }), + )); + } + + let body = response.bytes().await.map_err(|error| { + internal_error(format!( + "failed to read http/request response body: {error}" + )) + })?; + + Ok(( + HttpRequestResponse { + status, + headers, + body: body.to_vec().into(), + }, + None, + )) + } + + pub(crate) async fn stream_body( + pending_stream: PendingReqwestHttpBodyStream, + notifications: RpcNotificationSender, + ) { + let PendingReqwestHttpBodyStream { + request_id, + response, + } = pending_stream; + let mut seq = 1; + let mut body = response.bytes_stream(); + while let Some(chunk) = body.next().await { + match chunk { + Ok(bytes) => { + if !send_body_delta( + ¬ifications, + HttpRequestBodyDeltaNotification { + request_id: request_id.clone(), + seq, + delta: bytes.to_vec().into(), + done: false, + error: None, + }, + ) + .await + { + return; + } + seq += 1; + } + Err(error) => { + let _ = send_body_delta( + ¬ifications, + HttpRequestBodyDeltaNotification { + request_id, + seq, + delta: Vec::new().into(), + done: true, + error: Some(error.to_string()), + }, + ) + .await; + return; + } + } + } + + let _ = send_body_delta( + ¬ifications, + HttpRequestBodyDeltaNotification { + request_id, + seq, + delta: Vec::new().into(), + done: true, + error: None, + }, + ) + .await; + } + + fn build_headers(headers: Vec) -> Result { + let mut header_map = HeaderMap::new(); + for header in headers { + let name = HeaderName::from_bytes(header.name.as_bytes()).map_err(|error| { + invalid_params(format!("http/request header name is invalid: {error}")) + })?; + let value = HeaderValue::from_str(&header.value).map_err(|error| { + invalid_params(format!( + "http/request header value is invalid for {}: {error}", + header.name + )) + })?; + header_map.append(name, value); + } + Ok(header_map) + } + + fn response_headers(headers: &HeaderMap) -> Vec { + headers + .iter() + .filter_map(|(name, value)| { + Some(HttpHeader { + name: name.as_str().to_string(), + value: value.to_str().ok()?.to_string(), + }) + }) + .collect() + } +} diff --git a/codex-rs/exec-server/src/client/rpc_http_client.rs b/codex-rs/exec-server/src/client/rpc_http_client.rs new file mode 100644 index 000000000..d2ce842ca --- /dev/null +++ b/codex-rs/exec-server/src/client/rpc_http_client.rs @@ -0,0 +1,88 @@ +//! JSON-RPC-backed `HttpClient` implementation. +//! +//! This code runs in the orchestrator process. It does not issue network +//! requests directly; instead it forwards `http/request` to the remote runtime +//! and then reconstructs streamed bodies from `http/request/bodyDelta` +//! notifications on the shared connection. + +use std::sync::Arc; + +use futures::FutureExt; +use futures::future::BoxFuture; +use tokio::sync::mpsc; + +use super::HttpResponseBodyStream; +use super::response_body_stream::HttpBodyStreamRegistration; +use crate::HttpClient; +use crate::client::ExecServerClient; +use crate::client::ExecServerError; +use crate::protocol::HTTP_REQUEST_METHOD; +use crate::protocol::HttpRequestParams; +use crate::protocol::HttpRequestResponse; + +/// Maximum queued body frames per streamed HTTP response. +const HTTP_BODY_DELTA_CHANNEL_CAPACITY: usize = 256; + +impl ExecServerClient { + /// Performs an HTTP request and buffers the response body. + pub async fn http_request( + &self, + mut params: HttpRequestParams, + ) -> Result { + params.stream_response = false; + self.call(HTTP_REQUEST_METHOD, ¶ms).await + } + + /// Performs an HTTP request and returns a body stream. + /// + /// The method sets `stream_response` and replaces any caller-supplied + /// `request_id` with a connection-local id, so late deltas from abandoned + /// streams cannot be confused with later requests. + pub async fn http_request_stream( + &self, + mut params: HttpRequestParams, + ) -> Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError> { + params.stream_response = true; + let request_id = self.inner.next_http_body_stream_request_id(); + params.request_id = request_id.clone(); + let (tx, rx) = mpsc::channel(HTTP_BODY_DELTA_CHANNEL_CAPACITY); + self.inner + .insert_http_body_stream(request_id.clone(), tx) + .await?; + let mut registration = + HttpBodyStreamRegistration::new(Arc::clone(&self.inner), request_id.clone()); + let response = match self.call(HTTP_REQUEST_METHOD, ¶ms).await { + Ok(response) => response, + Err(error) => { + self.inner.remove_http_body_stream(&request_id).await; + registration.disarm(); + return Err(error); + } + }; + registration.disarm(); + Ok(( + response, + HttpResponseBodyStream::remote(Arc::clone(&self.inner), request_id, rx), + )) + } +} + +impl HttpClient for ExecServerClient { + /// Orchestrator-side adapter that forwards buffered HTTP requests to the + /// remote runtime over the shared JSON-RPC connection. + fn http_request( + &self, + params: HttpRequestParams, + ) -> BoxFuture<'_, Result> { + async move { ExecServerClient::http_request(self, params).await }.boxed() + } + + /// Orchestrator-side adapter that forwards streamed HTTP requests to the + /// remote runtime and exposes body deltas as a byte stream. + fn http_request_stream( + &self, + params: HttpRequestParams, + ) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> { + async move { ExecServerClient::http_request_stream(self, params).await }.boxed() + } +} diff --git a/codex-rs/exec-server/src/client_api.rs b/codex-rs/exec-server/src/client_api.rs index ac4371e2e..b1761b69f 100644 --- a/codex-rs/exec-server/src/client_api.rs +++ b/codex-rs/exec-server/src/client_api.rs @@ -1,5 +1,12 @@ use std::time::Duration; +use futures::future::BoxFuture; + +use crate::ExecServerError; +use crate::HttpRequestParams; +use crate::HttpRequestResponse; +use crate::HttpResponseBodyStream; + /// Connection options for any exec-server client transport. #[derive(Debug, Clone, PartialEq, Eq)] pub struct ExecServerClientConnectOptions { @@ -17,3 +24,22 @@ pub struct RemoteExecServerConnectArgs { pub initialize_timeout: Duration, pub resume_session_id: Option, } + +/// Sends HTTP requests through a runtime-selected transport. +/// +/// This is the HTTP capability counterpart to [`crate::ExecBackend`]. Callers +/// use it when they need environment-owned network requests but should not +/// depend on the concrete connection type or how that connection is established. +pub trait HttpClient: Send + Sync { + /// Perform an HTTP request and buffer the response body. + fn http_request( + &self, + params: HttpRequestParams, + ) -> BoxFuture<'_, Result>; + + /// Perform an HTTP request and return a streamed body handle. + fn http_request_stream( + &self, + params: HttpRequestParams, + ) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>>; +} diff --git a/codex-rs/exec-server/src/environment.rs b/codex-rs/exec-server/src/environment.rs index 02fd493b1..9e4c69c41 100644 --- a/codex-rs/exec-server/src/environment.rs +++ b/codex-rs/exec-server/src/environment.rs @@ -3,7 +3,9 @@ use std::sync::Arc; use crate::ExecServerError; use crate::ExecServerRuntimePaths; +use crate::HttpClient; use crate::client::LazyRemoteExecServerClient; +use crate::client::http_client::ReqwestHttpClient; use crate::file_system::ExecutorFileSystem; use crate::local_file_system::LocalFileSystem; use crate::local_process::LocalProcess; @@ -136,6 +138,7 @@ pub struct Environment { exec_server_url: Option, exec_backend: Arc, filesystem: Arc, + http_client: Arc, local_runtime_paths: Option, } @@ -146,6 +149,7 @@ impl Environment { exec_server_url: None, exec_backend: Arc::new(LocalProcess::default()), filesystem: Arc::new(LocalFileSystem::unsandboxed()), + http_client: Arc::new(ReqwestHttpClient), local_runtime_paths: None, } } @@ -202,6 +206,7 @@ impl Environment { filesystem: Arc::new(LocalFileSystem::with_runtime_paths( local_runtime_paths.clone(), )), + http_client: Arc::new(ReqwestHttpClient), local_runtime_paths: Some(local_runtime_paths), } } @@ -216,12 +221,14 @@ impl Environment { ) -> Self { let client = LazyRemoteExecServerClient::new(exec_server_url.clone()); let exec_backend: Arc = Arc::new(RemoteProcess::new(client.clone())); - let filesystem: Arc = Arc::new(RemoteFileSystem::new(client)); + let filesystem: Arc = + Arc::new(RemoteFileSystem::new(client.clone())); Self { exec_server_url: Some(exec_server_url), exec_backend, filesystem, + http_client: Arc::new(client), local_runtime_paths, } } @@ -243,6 +250,10 @@ impl Environment { Arc::clone(&self.exec_backend) } + pub fn get_http_client(&self) -> Arc { + Arc::clone(&self.http_client) + } + pub fn get_filesystem(&self) -> Arc { Arc::clone(&self.filesystem) } diff --git a/codex-rs/exec-server/src/lib.rs b/codex-rs/exec-server/src/lib.rs index fc6a86f50..3da7aa73f 100644 --- a/codex-rs/exec-server/src/lib.rs +++ b/codex-rs/exec-server/src/lib.rs @@ -20,7 +20,10 @@ mod server; pub use client::ExecServerClient; pub use client::ExecServerError; +pub use client::http_client::HttpResponseBodyStream; +pub use client::http_client::ReqwestHttpClient; pub use client_api::ExecServerClientConnectOptions; +pub use client_api::HttpClient; pub use client_api::RemoteExecServerConnectArgs; pub use environment::CODEX_EXEC_SERVER_URL_ENV_VAR; pub use environment::Environment; diff --git a/codex-rs/exec-server/src/server/handler.rs b/codex-rs/exec-server/src/server/handler.rs index 15035335e..d0645724c 100644 --- a/codex-rs/exec-server/src/server/handler.rs +++ b/codex-rs/exec-server/src/server/handler.rs @@ -12,8 +12,8 @@ use tokio_util::sync::CancellationToken; use tokio_util::task::TaskTracker; use crate::ExecServerRuntimePaths; -use crate::client::http_client::ExecutorHttpRequestRunner; -use crate::client::http_client::ExecutorPendingHttpBodyStream; +use crate::client::http_client::PendingReqwestHttpBodyStream; +use crate::client::http_client::ReqwestHttpRequestRunner; use crate::protocol::ExecParams; use crate::protocol::ExecResponse; use crate::protocol::FsCopyParams; @@ -178,7 +178,7 @@ impl ExecServerHandler { if stream_response { self.reserve_http_body_stream(&http_request_id).await?; } - let response = ExecutorHttpRequestRunner::new(params.timeout_ms)? + let response = ReqwestHttpRequestRunner::new(params.timeout_ms)? .run(params) .await; if response.is_err() && stream_response { @@ -306,7 +306,7 @@ impl ExecServerHandler { async fn start_http_body_stream( self: &Arc, - pending_stream: ExecutorPendingHttpBodyStream, + pending_stream: PendingReqwestHttpBodyStream, ) { let request_id = pending_stream.request_id.clone(); if self.background_task_shutdown.is_cancelled() { @@ -320,7 +320,7 @@ impl ExecServerHandler { self.background_tasks.spawn(async move { tokio::select! { _ = shutdown.cancelled() => {} - _ = ExecutorHttpRequestRunner::stream_body(pending_stream, notifications) => {} + _ = ReqwestHttpRequestRunner::stream_body(pending_stream, notifications) => {} } handler.release_http_body_stream(&finished_request_id).await; }); diff --git a/codex-rs/rmcp-client/BUILD.bazel b/codex-rs/rmcp-client/BUILD.bazel index ad5b62603..89a1963ec 100644 --- a/codex-rs/rmcp-client/BUILD.bazel +++ b/codex-rs/rmcp-client/BUILD.bazel @@ -3,4 +3,7 @@ load("//:defs.bzl", "codex_rust_crate") codex_rust_crate( name = "rmcp-client", crate_name = "codex_rmcp_client", + extra_binaries = [ + "//codex-rs/cli:codex", + ], ) diff --git a/codex-rs/rmcp-client/Cargo.toml b/codex-rs/rmcp-client/Cargo.toml index b81c224a4..40e461314 100644 --- a/codex-rs/rmcp-client/Cargo.toml +++ b/codex-rs/rmcp-client/Cargo.toml @@ -20,6 +20,7 @@ codex-keyring-store = { workspace = true } codex-protocol = { workspace = true } codex-utils-pty = { workspace = true } codex-utils-home-dir = { workspace = true } +bytes = { workspace = true } futures = { workspace = true, default-features = false, features = ["std"] } keyring = { workspace = true, features = ["crypto-rust"] } oauth2 = "5" diff --git a/codex-rs/rmcp-client/src/http_client_adapter.rs b/codex-rs/rmcp-client/src/http_client_adapter.rs new file mode 100644 index 000000000..0656b8ce3 --- /dev/null +++ b/codex-rs/rmcp-client/src/http_client_adapter.rs @@ -0,0 +1,391 @@ +//! RMCP Streamable HTTP adapter built on top of the shared `HttpClient` +//! capability. +//! +//! This module runs in the orchestrator process. It turns high-level RMCP +//! operations like `post_message` and `get_stream` into calls on +//! `Arc`, which may be: +//! - a local HTTP client that issues requests from the orchestrator, or +//! - a remote HTTP client that forwards requests to the remote runtime + +use std::io; +use std::sync::Arc; + +use bytes::Bytes; +use codex_exec_server::ExecServerError; +use codex_exec_server::HttpClient; +use codex_exec_server::HttpHeader; +use codex_exec_server::HttpRequestParams; +use codex_exec_server::HttpResponseBodyStream; +use futures::StreamExt; +use futures::stream; +use futures::stream::BoxStream; +use reqwest::StatusCode; +use reqwest::header::ACCEPT; +use reqwest::header::AUTHORIZATION; +use reqwest::header::CONTENT_TYPE; +use reqwest::header::HeaderMap; +use reqwest::header::HeaderName; +use rmcp::model::ClientJsonRpcMessage; +use rmcp::model::ServerJsonRpcMessage; +use rmcp::transport::streamable_http_client::AuthRequiredError; +use rmcp::transport::streamable_http_client::StreamableHttpClient; +use rmcp::transport::streamable_http_client::StreamableHttpError; +use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; +use sse_stream::Sse; +use sse_stream::SseStream; + +const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream"; +const JSON_MIME_TYPE: &str = "application/json"; +const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; +const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192; + +#[derive(Clone)] +pub(crate) struct StreamableHttpClientAdapter { + http_client: Arc, + default_headers: HeaderMap, +} + +#[derive(Debug, thiserror::Error)] +pub(crate) enum StreamableHttpClientAdapterError { + #[error("streamable HTTP session expired with 404 Not Found")] + SessionExpired404, + #[error(transparent)] + HttpRequest(#[from] ExecServerError), + #[error("invalid HTTP header: {0}")] + Header(String), +} + +impl StreamableHttpClientAdapter { + pub(crate) fn new(http_client: Arc, default_headers: HeaderMap) -> Self { + Self { + http_client, + default_headers, + } + } +} + +impl StreamableHttpClient for StreamableHttpClientAdapter { + type Error = StreamableHttpClientAdapterError; + + async fn post_message( + &self, + uri: Arc, + message: ClientJsonRpcMessage, + session_id: Option>, + auth_token: Option, + ) -> std::result::Result> { + let mut headers = self.default_headers.clone(); + insert_header( + &mut headers, + ACCEPT, + [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "), + StreamableHttpClientAdapterError::Header, + )?; + insert_header( + &mut headers, + CONTENT_TYPE, + JSON_MIME_TYPE.to_string(), + StreamableHttpClientAdapterError::Header, + )?; + if let Some(auth_token) = auth_token { + insert_header( + &mut headers, + AUTHORIZATION, + format!("Bearer {auth_token}"), + StreamableHttpClientAdapterError::Header, + )?; + } + if let Some(session_id_value) = session_id.as_ref() { + insert_header( + &mut headers, + HeaderName::from_static("mcp-session-id"), + session_id_value.to_string(), + StreamableHttpClientAdapterError::Header, + )?; + } + + let body = serde_json::to_vec(&message).map_err(StreamableHttpError::Deserialize)?; + let (response, mut body_stream) = self + .http_client + .http_request_stream(HttpRequestParams { + method: "POST".to_string(), + url: uri.to_string(), + headers: protocol_headers(&headers), + body: Some(body.into()), + timeout_ms: None, + request_id: "buffered-request".to_string(), + stream_response: true, + }) + .await + .map_err(StreamableHttpClientAdapterError::from) + .map_err(StreamableHttpError::Client)?; + + if response.status == StatusCode::NOT_FOUND.as_u16() && session_id.is_some() { + return Err(StreamableHttpError::Client( + StreamableHttpClientAdapterError::SessionExpired404, + )); + } + if response.status == StatusCode::UNAUTHORIZED.as_u16() + && let Some(header) = + response_header(&response.headers, reqwest::header::WWW_AUTHENTICATE) + { + return Err(StreamableHttpError::AuthRequired(AuthRequiredError { + www_authenticate_header: header, + })); + } + if matches!( + StatusCode::from_u16(response.status).ok(), + Some(StatusCode::ACCEPTED | StatusCode::NO_CONTENT) + ) { + return Ok(StreamableHttpPostResponse::Accepted); + } + + let content_type = response_header(&response.headers, CONTENT_TYPE); + let session_id = response_header(&response.headers, HEADER_SESSION_ID); + match content_type.as_deref() { + Some(content_type) if content_type.starts_with(EVENT_STREAM_MIME_TYPE) => { + let event_stream = sse_stream_from_body(body_stream); + Ok(StreamableHttpPostResponse::Sse(event_stream, session_id)) + } + Some(content_type) if content_type.starts_with(JSON_MIME_TYPE) => { + let body = collect_body(&mut body_stream).await?; + let message: ServerJsonRpcMessage = + serde_json::from_slice(&body).map_err(StreamableHttpError::Deserialize)?; + Ok(StreamableHttpPostResponse::Json(message, session_id)) + } + _ => { + let body = collect_body(&mut body_stream).await?; + let content_type = content_type.unwrap_or_else(|| "missing-content-type".into()); + Err(StreamableHttpError::UnexpectedContentType(Some(format!( + "{content_type}; body: {}", + body_preview(String::from_utf8_lossy(&body).to_string()) + )))) + } + } + } + + async fn delete_session( + &self, + uri: Arc, + session: Arc, + auth_token: Option, + ) -> std::result::Result<(), StreamableHttpError> { + let mut headers = self.default_headers.clone(); + if let Some(auth_token) = auth_token { + insert_header( + &mut headers, + AUTHORIZATION, + format!("Bearer {auth_token}"), + StreamableHttpClientAdapterError::Header, + )?; + } + insert_header( + &mut headers, + HeaderName::from_static("mcp-session-id"), + session.to_string(), + StreamableHttpClientAdapterError::Header, + )?; + + let response = self + .http_client + .http_request(HttpRequestParams { + method: "DELETE".to_string(), + url: uri.to_string(), + headers: protocol_headers(&headers), + body: None, + timeout_ms: None, + request_id: "buffered-request".to_string(), + stream_response: false, + }) + .await + .map_err(StreamableHttpClientAdapterError::from) + .map_err(StreamableHttpError::Client)?; + + if response.status == StatusCode::METHOD_NOT_ALLOWED.as_u16() { + return Ok(()); + } + if !status_is_success(response.status) { + return Err(StreamableHttpError::UnexpectedServerResponse( + format!("DELETE returned HTTP {}", response.status).into(), + )); + } + Ok(()) + } + + async fn get_stream( + &self, + uri: Arc, + session_id: Arc, + last_event_id: Option, + auth_token: Option, + ) -> std::result::Result< + BoxStream<'static, std::result::Result>, + StreamableHttpError, + > { + let mut headers = self.default_headers.clone(); + insert_header( + &mut headers, + ACCEPT, + [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "), + StreamableHttpClientAdapterError::Header, + )?; + insert_header( + &mut headers, + HeaderName::from_static("mcp-session-id"), + session_id.to_string(), + StreamableHttpClientAdapterError::Header, + )?; + if let Some(last_event_id) = last_event_id { + insert_header( + &mut headers, + HeaderName::from_static("last-event-id"), + last_event_id, + StreamableHttpClientAdapterError::Header, + )?; + } + if let Some(auth_token) = auth_token { + insert_header( + &mut headers, + AUTHORIZATION, + format!("Bearer {auth_token}"), + StreamableHttpClientAdapterError::Header, + )?; + } + + let (response, body_stream) = self + .http_client + .http_request_stream(HttpRequestParams { + method: "GET".to_string(), + url: uri.to_string(), + headers: protocol_headers(&headers), + body: None, + timeout_ms: None, + request_id: "buffered-request".to_string(), + stream_response: true, + }) + .await + .map_err(StreamableHttpClientAdapterError::from) + .map_err(StreamableHttpError::Client)?; + + if response.status == StatusCode::METHOD_NOT_ALLOWED.as_u16() { + return Err(StreamableHttpError::ServerDoesNotSupportSse); + } + if response.status == StatusCode::NOT_FOUND.as_u16() { + return Err(StreamableHttpError::Client( + StreamableHttpClientAdapterError::SessionExpired404, + )); + } + if !status_is_success(response.status) { + return Err(StreamableHttpError::UnexpectedServerResponse( + format!("GET returned HTTP {}", response.status).into(), + )); + } + + match response_header(&response.headers, CONTENT_TYPE).as_deref() { + Some(content_type) if is_streamable_http_content_type(content_type) => {} + Some(content_type) => { + return Err(StreamableHttpError::UnexpectedContentType(Some( + content_type.to_string(), + ))); + } + None => { + return Err(StreamableHttpError::UnexpectedContentType(None)); + } + } + + Ok(sse_stream_from_body(body_stream)) + } +} + +fn body_preview(body: impl Into) -> String { + let mut body_preview = body.into(); + let body_len = body_preview.len(); + if body_len > NON_JSON_RESPONSE_BODY_PREVIEW_BYTES { + let mut boundary = NON_JSON_RESPONSE_BODY_PREVIEW_BYTES; + while !body_preview.is_char_boundary(boundary) { + boundary = boundary.saturating_sub(1); + } + body_preview.truncate(boundary); + body_preview.push_str(&format!( + "... (truncated {} bytes)", + body_len.saturating_sub(boundary) + )); + } + body_preview +} + +fn insert_header( + headers: &mut HeaderMap, + name: HeaderName, + value: String, + map_error: impl FnOnce(String) -> Error, +) -> std::result::Result<(), StreamableHttpError> +where + Error: std::error::Error + Send + Sync + 'static, +{ + let value = reqwest::header::HeaderValue::from_str(&value) + .map_err(|error| StreamableHttpError::Client(map_error(error.to_string())))?; + headers.insert(name, value); + Ok(()) +} + +fn is_streamable_http_content_type(content_type: &str) -> bool { + content_type + .as_bytes() + .starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) + || content_type + .as_bytes() + .starts_with(JSON_MIME_TYPE.as_bytes()) +} + +fn protocol_headers(headers: &HeaderMap) -> Vec { + headers + .iter() + .filter_map(|(name, value)| { + Some(HttpHeader { + name: name.as_str().to_string(), + value: value.to_str().ok()?.to_string(), + }) + }) + .collect() +} + +fn response_header(headers: &[HttpHeader], name: impl AsRef) -> Option { + let name = name.as_ref(); + headers + .iter() + .find(|header| header.name.eq_ignore_ascii_case(name)) + .map(|header| header.value.clone()) +} + +fn status_is_success(status: u16) -> bool { + StatusCode::from_u16(status).is_ok_and(|status| status.is_success()) +} + +async fn collect_body( + body_stream: &mut HttpResponseBodyStream, +) -> std::result::Result, StreamableHttpError> { + let mut body = Vec::new(); + while let Some(chunk) = body_stream + .recv() + .await + .map_err(StreamableHttpClientAdapterError::from) + .map_err(StreamableHttpError::Client)? + { + body.extend_from_slice(&chunk); + } + Ok(body) +} + +fn sse_stream_from_body( + body_stream: HttpResponseBodyStream, +) -> BoxStream<'static, std::result::Result> { + SseStream::from_byte_stream(stream::unfold(body_stream, |mut body_stream| async move { + match body_stream.recv().await { + Ok(Some(bytes)) => Some((Ok(Bytes::from(bytes)), body_stream)), + Ok(None) => None, + Err(error) => Some((Err(io::Error::other(error)), body_stream)), + } + })) + .boxed() +} diff --git a/codex-rs/rmcp-client/src/lib.rs b/codex-rs/rmcp-client/src/lib.rs index f3870b931..57e9f0e80 100644 --- a/codex-rs/rmcp-client/src/lib.rs +++ b/codex-rs/rmcp-client/src/lib.rs @@ -1,6 +1,7 @@ mod auth_status; mod elicitation_client_service; mod executor_process_transport; +mod http_client_adapter; mod logging_client_handler; mod oauth; mod perform_oauth_login; diff --git a/codex-rs/rmcp-client/src/rmcp_client.rs b/codex-rs/rmcp-client/src/rmcp_client.rs index 50891ec8c..0608e00d7 100644 --- a/codex-rs/rmcp-client/src/rmcp_client.rs +++ b/codex-rs/rmcp-client/src/rmcp_client.rs @@ -1,4 +1,3 @@ -use std::borrow::Cow; use std::collections::HashMap; use std::ffi::OsString; use std::future::Future; @@ -14,16 +13,12 @@ use anyhow::Result; use anyhow::anyhow; use codex_client::build_reqwest_client_with_custom_ca; use codex_config::types::McpServerEnvVar; +use codex_exec_server::HttpClient; use futures::FutureExt; -use futures::StreamExt; use futures::future::BoxFuture; -use futures::stream::BoxStream; use oauth2::TokenResponse; -use reqwest::header::ACCEPT; use reqwest::header::AUTHORIZATION; -use reqwest::header::CONTENT_TYPE; use reqwest::header::HeaderMap; -use reqwest::header::WWW_AUTHENTICATE; use rmcp::model::CallToolRequestParams; use rmcp::model::CallToolResult; use rmcp::model::ClientNotification; @@ -52,16 +47,11 @@ use rmcp::transport::StreamableHttpClientTransport; use rmcp::transport::auth::AuthClient; use rmcp::transport::auth::AuthError; use rmcp::transport::auth::OAuthState; -use rmcp::transport::streamable_http_client::AuthRequiredError; -use rmcp::transport::streamable_http_client::StreamableHttpClient; use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig; use rmcp::transport::streamable_http_client::StreamableHttpError; -use rmcp::transport::streamable_http_client::StreamableHttpPostResponse; use serde::Deserialize; use serde::Serialize; use serde_json::Value; -use sse_stream::Sse; -use sse_stream::SseStream; use tokio::sync::Mutex; use tokio::sync::Semaphore; use tokio::sync::watch; @@ -69,6 +59,8 @@ use tokio::time; use tracing::warn; use crate::elicitation_client_service::ElicitationClientService; +use crate::http_client_adapter::StreamableHttpClientAdapter; +use crate::http_client_adapter::StreamableHttpClientAdapterError; use crate::load_oauth_tokens; use crate::oauth::OAuthPersistor; use crate::oauth::StoredOAuthTokens; @@ -79,239 +71,15 @@ use crate::utils::apply_default_headers; use crate::utils::build_default_headers; use codex_config::types::OAuthCredentialsStoreMode; -const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream"; -const JSON_MIME_TYPE: &str = "application/json"; -const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id"; -const HEADER_SESSION_ID: &str = "Mcp-Session-Id"; -const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192; - -#[derive(Clone)] -struct StreamableHttpResponseClient { - inner: reqwest::Client, -} - -impl StreamableHttpResponseClient { - fn new(inner: reqwest::Client) -> Self { - Self { inner } - } - - fn reqwest_error( - error: reqwest::Error, - ) -> StreamableHttpError { - StreamableHttpError::Client(StreamableHttpResponseClientError::from(error)) - } -} - -fn build_http_client(default_headers: &HeaderMap) -> Result { - let builder = apply_default_headers(reqwest::Client::builder(), default_headers); - Ok(build_reqwest_client_with_custom_ca(builder)?) -} - -#[derive(Debug, thiserror::Error)] -enum StreamableHttpResponseClientError { - #[error("streamable HTTP session expired with 404 Not Found")] - SessionExpired404, - #[error(transparent)] - Reqwest(#[from] reqwest::Error), -} - -impl StreamableHttpClient for StreamableHttpResponseClient { - type Error = StreamableHttpResponseClientError; - - async fn post_message( - &self, - uri: Arc, - message: rmcp::model::ClientJsonRpcMessage, - session_id: Option>, - auth_token: Option, - ) -> std::result::Result> { - let mut request = self - .inner - .post(uri.as_ref()) - .header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", ")); - if let Some(auth_header) = auth_token { - request = request.bearer_auth(auth_header); - } - if let Some(session_id_value) = session_id.as_ref() { - request = request.header(HEADER_SESSION_ID, session_id_value.as_ref()); - } - - let response = request - .json(&message) - .send() - .await - .map_err(StreamableHttpResponseClient::reqwest_error)?; - if response.status() == reqwest::StatusCode::NOT_FOUND && session_id.is_some() { - return Err(StreamableHttpError::Client( - StreamableHttpResponseClientError::SessionExpired404, - )); - } - if response.status() == reqwest::StatusCode::UNAUTHORIZED - && let Some(header) = response.headers().get(WWW_AUTHENTICATE) - { - let header = header - .to_str() - .map_err(|_| { - StreamableHttpError::UnexpectedServerResponse(Cow::Borrowed( - "invalid www-authenticate header value", - )) - })? - .to_string(); - return Err(StreamableHttpError::AuthRequired(AuthRequiredError { - www_authenticate_header: header, - })); - } - - let status = response.status(); - if matches!( - status, - reqwest::StatusCode::ACCEPTED | reqwest::StatusCode::NO_CONTENT - ) { - return Ok(StreamableHttpPostResponse::Accepted); - } - - let content_type = response - .headers() - .get(CONTENT_TYPE) - .and_then(|value| value.to_str().ok()) - .map(str::to_string); - let session_id = response - .headers() - .get(HEADER_SESSION_ID) - .and_then(|value| value.to_str().ok()) - .map(str::to_string); - - match content_type.as_deref() { - Some(ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => { - let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); - Ok(StreamableHttpPostResponse::Sse(event_stream, session_id)) - } - Some(ct) if ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => { - let message = response - .json() - .await - .map_err(StreamableHttpResponseClient::reqwest_error)?; - Ok(StreamableHttpPostResponse::Json(message, session_id)) - } - _ => { - let body = response - .text() - .await - .map_err(StreamableHttpResponseClient::reqwest_error)?; - let mut body_preview = body; - let body_len = body_preview.len(); - if body_len > NON_JSON_RESPONSE_BODY_PREVIEW_BYTES { - let mut boundary = NON_JSON_RESPONSE_BODY_PREVIEW_BYTES; - while !body_preview.is_char_boundary(boundary) { - boundary = boundary.saturating_sub(1); - } - body_preview.truncate(boundary); - body_preview.push_str(&format!( - "... (truncated {} bytes)", - body_len.saturating_sub(boundary) - )); - } - - let content_type = content_type.unwrap_or_else(|| "missing-content-type".into()); - Err(StreamableHttpError::UnexpectedContentType(Some(format!( - "{content_type}; body: {body_preview}" - )))) - } - } - } - - async fn delete_session( - &self, - uri: Arc, - session: Arc, - auth_token: Option, - ) -> std::result::Result<(), StreamableHttpError> { - let mut request_builder = self.inner.delete(uri.as_ref()); - if let Some(auth_header) = auth_token { - request_builder = request_builder.bearer_auth(auth_header); - } - let response = request_builder - .header(HEADER_SESSION_ID, session.as_ref()) - .send() - .await - .map_err(StreamableHttpResponseClient::reqwest_error)?; - - if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { - return Ok(()); - } - - response - .error_for_status() - .map_err(StreamableHttpResponseClient::reqwest_error)?; - Ok(()) - } - - async fn get_stream( - &self, - uri: Arc, - session_id: Arc, - last_event_id: Option, - auth_token: Option, - ) -> std::result::Result< - BoxStream<'static, std::result::Result>, - StreamableHttpError, - > { - let mut request_builder = self - .inner - .get(uri.as_ref()) - .header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", ")) - .header(HEADER_SESSION_ID, session_id.as_ref()); - if let Some(last_event_id) = last_event_id { - request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id); - } - if let Some(auth_header) = auth_token { - request_builder = request_builder.bearer_auth(auth_header); - } - - let response = request_builder - .send() - .await - .map_err(StreamableHttpResponseClient::reqwest_error)?; - if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED { - return Err(StreamableHttpError::ServerDoesNotSupportSse); - } - if response.status() == reqwest::StatusCode::NOT_FOUND { - return Err(StreamableHttpError::Client( - StreamableHttpResponseClientError::SessionExpired404, - )); - } - - let response = response - .error_for_status() - .map_err(StreamableHttpResponseClient::reqwest_error)?; - match response.headers().get(CONTENT_TYPE) { - Some(ct) - if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) - || ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => {} - Some(ct) => { - return Err(StreamableHttpError::UnexpectedContentType(Some( - String::from_utf8_lossy(ct.as_bytes()).to_string(), - ))); - } - None => { - return Err(StreamableHttpError::UnexpectedContentType(None)); - } - } - - let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed(); - Ok(event_stream) - } -} - enum PendingTransport { Stdio { transport: StdioServerTransport, }, StreamableHttp { - transport: StreamableHttpClientTransport, + transport: StreamableHttpClientTransport, }, StreamableHttpWithOAuth { - transport: StreamableHttpClientTransport>, + transport: StreamableHttpClientTransport>, oauth_persistor: OAuthPersistor, }, } @@ -339,6 +107,7 @@ enum TransportRecipe { http_headers: Option>, env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, + http_client: Arc, }, } @@ -536,6 +305,7 @@ impl RmcpClient { http_headers: Option>, env_http_headers: Option>, store_mode: OAuthCredentialsStoreMode, + http_client: Arc, ) -> Result { let transport_recipe = TransportRecipe::StreamableHttp { server_name: server_name.to_string(), @@ -544,6 +314,7 @@ impl RmcpClient { http_headers, env_http_headers, store_mode, + http_client, }; let transport = Self::create_pending_transport(&transport_recipe).await?; Ok(Self { @@ -895,6 +666,7 @@ impl RmcpClient { http_headers, env_http_headers, store_mode, + http_client, } => { let default_headers = build_default_headers(http_headers.clone(), env_http_headers.clone())?; @@ -919,6 +691,7 @@ impl RmcpClient { initial_tokens.clone(), *store_mode, default_headers.clone(), + Arc::clone(http_client), ) .await { @@ -945,9 +718,11 @@ impl RmcpClient { let http_config = StreamableHttpClientTransportConfig::with_uri(url.clone()) .auth_header(access_token); - let http_client = build_http_client(&default_headers)?; let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new(http_client), + StreamableHttpClientAdapter::new( + Arc::clone(http_client), + default_headers, + ), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -961,10 +736,8 @@ impl RmcpClient { http_config = http_config.auth_header(bearer_token); } - let http_client = build_http_client(&default_headers)?; - let transport = StreamableHttpClientTransport::with_client( - StreamableHttpResponseClient::new(http_client), + StreamableHttpClientAdapter::new(Arc::clone(http_client), default_headers), http_config, ); Ok(PendingTransport::StreamableHttp { transport }) @@ -1084,12 +857,12 @@ impl RmcpClient { error .error - .downcast_ref::>() + .downcast_ref::>() .is_some_and(|error| { matches!( error, StreamableHttpError::Client( - StreamableHttpResponseClientError::SessionExpired404 + StreamableHttpClientAdapterError::SessionExpired404 ) ) }) @@ -1156,12 +929,18 @@ async fn create_oauth_transport_and_runtime( initial_tokens: StoredOAuthTokens, credentials_store: OAuthCredentialsStoreMode, default_headers: HeaderMap, + http_client: Arc, ) -> Result<( - StreamableHttpClientTransport>, + StreamableHttpClientTransport>, OAuthPersistor, )> { - let http_client = build_http_client(&default_headers)?; - let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?; + let builder = apply_default_headers(reqwest::Client::builder(), &default_headers); + let oauth_metadata_client = build_reqwest_client_with_custom_ca(builder)?; + // TODO(aibrahim): teach OAuth bootstrap and refresh to use the same + // shared HTTP client abstraction instead of always creating the local + // reqwest metadata client here. + let mut oauth_state = + OAuthState::new(url.to_string(), Some(oauth_metadata_client.clone())).await?; oauth_state .set_credentials( @@ -1178,7 +957,10 @@ async fn create_oauth_transport_and_runtime( } }; - let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager); + let auth_client = AuthClient::new( + StreamableHttpClientAdapter::new(http_client, default_headers), + manager, + ); let auth_manager = auth_client.auth_manager.clone(); let transport = StreamableHttpClientTransport::with_client( diff --git a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs index c0525aafa..4be21f6cf 100644 --- a/codex-rs/rmcp-client/tests/streamable_http_recovery.rs +++ b/codex-rs/rmcp-client/tests/streamable_http_recovery.rs @@ -1,190 +1,12 @@ -use std::net::TcpListener; -use std::path::PathBuf; -use std::time::Duration; -use std::time::Instant; +mod streamable_http_test_support; -use codex_config::types::OAuthCredentialsStoreMode; -use codex_rmcp_client::ElicitationAction; -use codex_rmcp_client::ElicitationResponse; -use codex_rmcp_client::RmcpClient; -use codex_utils_cargo_bin::CargoBinError; -use futures::FutureExt as _; use pretty_assertions::assert_eq; -use rmcp::model::CallToolResult; -use rmcp::model::ClientCapabilities; -use rmcp::model::ElicitationCapability; -use rmcp::model::FormElicitationCapability; -use rmcp::model::Implementation; -use rmcp::model::InitializeRequestParams; -use rmcp::model::ProtocolVersion; -use serde_json::json; -use tokio::net::TcpStream; -use tokio::process::Child; -use tokio::process::Command; -use tokio::time::sleep; -const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure"; - -fn streamable_http_server_bin() -> Result { - codex_utils_cargo_bin::cargo_bin("test_streamable_http_server") -} - -fn init_params() -> InitializeRequestParams { - InitializeRequestParams { - meta: None, - capabilities: ClientCapabilities { - experimental: None, - extensions: None, - roots: None, - sampling: None, - elicitation: Some(ElicitationCapability { - form: Some(FormElicitationCapability { - schema_validation: None, - }), - url: None, - }), - tasks: None, - }, - client_info: Implementation { - name: "codex-test".into(), - version: "0.0.0-test".into(), - title: Some("Codex rmcp recovery test".into()), - description: None, - icons: None, - website_url: None, - }, - protocol_version: ProtocolVersion::V_2025_06_18, - } -} - -fn expected_echo_result(message: &str) -> CallToolResult { - CallToolResult { - content: Vec::new(), - structured_content: Some(json!({ - "echo": format!("ECHOING: {message}"), - "env": null, - })), - is_error: Some(false), - meta: None, - } -} - -async fn create_client(base_url: &str) -> anyhow::Result { - let client = RmcpClient::new_streamable_http_client( - "test-streamable-http", - &format!("{base_url}/mcp"), - Some("test-bearer".to_string()), - /*http_headers*/ None, - /*env_http_headers*/ None, - OAuthCredentialsStoreMode::File, - ) - .await?; - - client - .initialize( - init_params(), - Some(Duration::from_secs(5)), - Box::new(|_, _| { - async { - Ok(ElicitationResponse { - action: ElicitationAction::Accept, - content: Some(json!({})), - meta: None, - }) - } - .boxed() - }), - ) - .await?; - - Ok(client) -} - -async fn call_echo_tool(client: &RmcpClient, message: &str) -> anyhow::Result { - client - .call_tool( - "echo".to_string(), - Some(json!({ "message": message })), - /*meta*/ None, - Some(Duration::from_secs(5)), - ) - .await -} - -async fn arm_session_post_failure( - base_url: &str, - status: u16, - remaining: usize, -) -> anyhow::Result<()> { - let response = reqwest::Client::new() - .post(format!("{base_url}{SESSION_POST_FAILURE_CONTROL_PATH}")) - .json(&json!({ - "status": status, - "remaining": remaining, - })) - .send() - .await?; - - assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT); - Ok(()) -} - -async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> { - let listener = TcpListener::bind("127.0.0.1:0")?; - let port = listener.local_addr()?.port(); - drop(listener); - - let bind_addr = format!("127.0.0.1:{port}"); - let base_url = format!("http://{bind_addr}"); - let mut child = Command::new(streamable_http_server_bin()?) - .kill_on_drop(true) - .env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr) - .spawn()?; - - wait_for_streamable_http_server(&mut child, &bind_addr, Duration::from_secs(5)).await?; - Ok((child, base_url)) -} - -async fn wait_for_streamable_http_server( - server_child: &mut Child, - address: &str, - timeout: Duration, -) -> anyhow::Result<()> { - let deadline = Instant::now() + timeout; - - loop { - if let Some(status) = server_child.try_wait()? { - return Err(anyhow::anyhow!( - "streamable HTTP server exited early with status {status}" - )); - } - - let remaining = deadline.saturating_duration_since(Instant::now()); - if remaining.is_zero() { - return Err(anyhow::anyhow!( - "timed out waiting for streamable HTTP server at {address}: deadline reached" - )); - } - - match tokio::time::timeout(remaining, TcpStream::connect(address)).await { - Ok(Ok(_)) => return Ok(()), - Ok(Err(error)) => { - if Instant::now() >= deadline { - return Err(anyhow::anyhow!( - "timed out waiting for streamable HTTP server at {address}: {error}" - )); - } - } - Err(_) => { - return Err(anyhow::anyhow!( - "timed out waiting for streamable HTTP server at {address}: connect call timed out" - )); - } - } - - sleep(Duration::from_millis(50)).await; - } -} +use streamable_http_test_support::arm_session_post_failure; +use streamable_http_test_support::call_echo_tool; +use streamable_http_test_support::create_client; +use streamable_http_test_support::expected_echo_result; +use streamable_http_test_support::spawn_streamable_http_server; #[tokio::test(flavor = "multi_thread", worker_threads = 1)] async fn streamable_http_404_session_expiry_recovers_and_retries_once() -> anyhow::Result<()> { diff --git a/codex-rs/rmcp-client/tests/streamable_http_remote.rs b/codex-rs/rmcp-client/tests/streamable_http_remote.rs new file mode 100644 index 000000000..0d4690a82 --- /dev/null +++ b/codex-rs/rmcp-client/tests/streamable_http_remote.rs @@ -0,0 +1,37 @@ +//! Integration coverage for the remote Streamable HTTP RMCP path. +//! +//! These tests exercise the orchestrator-side RMCP adapter against a real +//! `exec-server` process so HTTP requests go through the remote runtime path +//! instead of direct local `reqwest` calls. + +mod streamable_http_test_support; + +use pretty_assertions::assert_eq; + +use streamable_http_test_support::call_echo_tool; +use streamable_http_test_support::create_remote_client; +use streamable_http_test_support::expected_echo_result; +use streamable_http_test_support::spawn_exec_server; +use streamable_http_test_support::spawn_streamable_http_server; + +/// What this tests: the RMCP remote Streamable HTTP adapter can initialize +/// a server and call a tool while every MCP HTTP request goes through a real +/// exec-server process instead of a direct reqwest transport. +#[tokio::test(flavor = "multi_thread", worker_threads = 2)] +async fn streamable_http_remote_client_round_trips_through_exec_server() -> anyhow::Result<()> { + // Phase 1: start the MCP Streamable HTTP test server and a local + // exec-server process that will own the HTTP network calls. + let (_server, base_url) = spawn_streamable_http_server().await?; + let exec_server = spawn_exec_server().await?; + + // Phase 2: create and initialize the RMCP client using the executor-backed + // Streamable HTTP transport. + let client = create_remote_client(&base_url, exec_server.client.clone()).await?; + + // Phase 3: prove the initialized client can complete a tool call and + // preserve the normal RMCP response shape. + let result = call_echo_tool(&client, "remote").await?; + assert_eq!(result, expected_echo_result("remote")); + + Ok(()) +} diff --git a/codex-rs/rmcp-client/tests/streamable_http_test_support.rs b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs new file mode 100644 index 000000000..ec7f7dc6f --- /dev/null +++ b/codex-rs/rmcp-client/tests/streamable_http_test_support.rs @@ -0,0 +1,314 @@ +//! Shared helpers for Streamable HTTP RMCP integration tests. +//! +//! This support module starts the test HTTP server, launches a real +//! `exec-server` when remote coverage is needed, and provides small helpers for +//! creating RMCP clients and asserting round-trip behavior. + +// This support module is included by multiple integration-test crates. Each +// crate uses a different subset of the helpers, so dead-code warnings would +// otherwise depend on which test file compiled the module. +#![allow(dead_code)] + +use std::net::TcpListener; +use std::path::PathBuf; +use std::process::Stdio; +use std::sync::Arc; +use std::time::Duration; +use std::time::Instant; + +use anyhow::Context as _; +use codex_config::types::OAuthCredentialsStoreMode; +use codex_exec_server::Environment; +use codex_exec_server::ExecServerClient; +use codex_exec_server::RemoteExecServerConnectArgs; +use codex_rmcp_client::ElicitationAction; +use codex_rmcp_client::ElicitationResponse; +use codex_rmcp_client::RmcpClient; +use codex_utils_cargo_bin::CargoBinError; +use futures::FutureExt as _; +use pretty_assertions::assert_eq; +use rmcp::model::CallToolResult; +use rmcp::model::ClientCapabilities; +use rmcp::model::ElicitationCapability; +use rmcp::model::FormElicitationCapability; +use rmcp::model::Implementation; +use rmcp::model::InitializeRequestParams; +use rmcp::model::ProtocolVersion; +use serde_json::json; +use tempfile::TempDir; +use tokio::io::AsyncBufReadExt; +use tokio::io::BufReader; +use tokio::net::TcpStream; +use tokio::process::Child; +use tokio::process::Command; +use tokio::time::sleep; + +const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure"; + +fn streamable_http_server_bin() -> Result { + codex_utils_cargo_bin::cargo_bin("test_streamable_http_server") +} + +fn init_params() -> InitializeRequestParams { + InitializeRequestParams { + meta: None, + capabilities: ClientCapabilities { + experimental: None, + extensions: None, + roots: None, + sampling: None, + elicitation: Some(ElicitationCapability { + form: Some(FormElicitationCapability { + schema_validation: None, + }), + url: None, + }), + tasks: None, + }, + client_info: Implementation { + name: "codex-test".into(), + version: "0.0.0-test".into(), + title: Some("Codex rmcp recovery test".into()), + description: None, + icons: None, + website_url: None, + }, + protocol_version: ProtocolVersion::V_2025_06_18, + } +} + +pub(crate) fn expected_echo_result(message: &str) -> CallToolResult { + CallToolResult { + content: Vec::new(), + structured_content: Some(json!({ + "echo": format!("ECHOING: {message}"), + "env": null, + })), + is_error: Some(false), + meta: None, + } +} + +pub(crate) async fn create_client(base_url: &str) -> anyhow::Result { + let client = RmcpClient::new_streamable_http_client( + "test-streamable-http", + &format!("{base_url}/mcp"), + Some("test-bearer".to_string()), + /*http_headers*/ None, + /*env_http_headers*/ None, + OAuthCredentialsStoreMode::File, + Environment::default_for_tests().get_http_client(), + ) + .await?; + + client + .initialize( + init_params(), + Some(Duration::from_secs(5)), + Box::new(|_, _| { + async { + Ok(ElicitationResponse { + action: ElicitationAction::Accept, + content: Some(json!({})), + meta: None, + }) + } + .boxed() + }), + ) + .await?; + + Ok(client) +} + +/// Creates a Streamable HTTP RMCP client that sends traffic through the remote +/// runtime HTTP API. +pub(crate) async fn create_remote_client( + base_url: &str, + http_client: ExecServerClient, +) -> anyhow::Result { + let client = RmcpClient::new_streamable_http_client( + "test-streamable-http-remote", + &format!("{base_url}/mcp"), + Some("test-bearer".to_string()), + /*http_headers*/ None, + /*env_http_headers*/ None, + OAuthCredentialsStoreMode::File, + Arc::new(http_client), + ) + .await?; + + client + .initialize( + init_params(), + Some(Duration::from_secs(5)), + Box::new(|_, _| { + async { + Ok(ElicitationResponse { + action: ElicitationAction::Accept, + content: Some(json!({})), + meta: None, + }) + } + .boxed() + }), + ) + .await?; + + Ok(client) +} + +pub(crate) async fn call_echo_tool( + client: &RmcpClient, + message: &str, +) -> anyhow::Result { + client + .call_tool( + "echo".to_string(), + Some(json!({ "message": message })), + /*meta*/ None, + Some(Duration::from_secs(5)), + ) + .await +} + +pub(crate) async fn arm_session_post_failure( + base_url: &str, + status: u16, + remaining: usize, +) -> anyhow::Result<()> { + let response = reqwest::Client::new() + .post(format!("{base_url}{SESSION_POST_FAILURE_CONTROL_PATH}")) + .json(&json!({ + "status": status, + "remaining": remaining, + })) + .send() + .await?; + + assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT); + Ok(()) +} + +pub(crate) async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> { + let listener = TcpListener::bind("127.0.0.1:0")?; + let port = listener.local_addr()?.port(); + drop(listener); + + let bind_addr = format!("127.0.0.1:{port}"); + let base_url = format!("http://{bind_addr}"); + let mut child = Command::new(streamable_http_server_bin()?) + .kill_on_drop(true) + .env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr) + .spawn()?; + + wait_for_streamable_http_server(&mut child, &bind_addr, Duration::from_secs(5)).await?; + Ok((child, base_url)) +} + +/// Owns the exec-server process used by the remote-client integration test. +pub(crate) struct ExecServerProcess { + _codex_home: TempDir, + child: Child, + pub(crate) client: ExecServerClient, +} + +impl Drop for ExecServerProcess { + /// Stops the local exec-server process best-effort when the test exits. + fn drop(&mut self) { + let _ = self.child.start_kill(); + } +} + +/// Starts a local exec-server and connects an initialized `ExecServerClient`. +pub(crate) async fn spawn_exec_server() -> anyhow::Result { + let codex_home = TempDir::new()?; + let mut child = Command::new(codex_utils_cargo_bin::cargo_bin("codex")?) + .args(["exec-server", "--listen", "ws://127.0.0.1:0"]) + .stdin(Stdio::null()) + .stdout(Stdio::piped()) + .stderr(Stdio::inherit()) + .kill_on_drop(true) + .env("CODEX_HOME", codex_home.path()) + .spawn()?; + + let websocket_url = read_exec_server_listen_url(&mut child).await?; + let client = ExecServerClient::connect_websocket(RemoteExecServerConnectArgs::new( + websocket_url, + "rmcp-client-remote-http-test".to_string(), + )) + .await?; + + Ok(ExecServerProcess { + _codex_home: codex_home, + child, + client, + }) +} + +/// Reads the websocket URL printed by `codex exec-server --listen`. +async fn read_exec_server_listen_url(child: &mut Child) -> anyhow::Result { + let stdout = child + .stdout + .take() + .context("failed to capture exec-server stdout")?; + let mut lines = BufReader::new(stdout).lines(); + let deadline = Instant::now() + Duration::from_secs(10); + + loop { + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + anyhow::bail!("timed out waiting for exec-server listen URL"); + } + + let line = tokio::time::timeout(remaining, lines.next_line()) + .await + .context("timed out waiting for exec-server stdout")?? + .context("exec-server stdout closed before emitting listen URL")?; + let listen_url = line.trim(); + if listen_url.starts_with("ws://") { + return Ok(listen_url.to_string()); + } + } +} + +async fn wait_for_streamable_http_server( + server_child: &mut Child, + address: &str, + timeout: Duration, +) -> anyhow::Result<()> { + let deadline = Instant::now() + timeout; + + loop { + if let Some(status) = server_child.try_wait()? { + return Err(anyhow::anyhow!( + "streamable HTTP server exited early with status {status}" + )); + } + + let remaining = deadline.saturating_duration_since(Instant::now()); + if remaining.is_zero() { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: deadline reached" + )); + } + + match tokio::time::timeout(remaining, TcpStream::connect(address)).await { + Ok(Ok(_)) => return Ok(()), + Ok(Err(error)) => { + if Instant::now() >= deadline { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: {error}" + )); + } + } + Err(_) => { + return Err(anyhow::anyhow!( + "timed out waiting for streamable HTTP server at {address}: connect call timed out" + )); + } + } + + sleep(Duration::from_millis(50)).await; + } +}