mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
[3/4] Add executor-backed RMCP HTTP client (#18583)
### Why The RMCP layer needs a Streamable HTTP client that can talk either directly over `reqwest` or through the executor HTTP runner without duplicating MCP session logic higher in the stack. This PR adds that client-side transport boundary so remote Streamable HTTP MCP can reuse the same RMCP flow as the local path. ### What - Add a shared `rmcp-client/src/streamable_http/` module with: - `transport_client.rs` for the local-or-remote transport enum - `local_client.rs` for the direct `reqwest` implementation - `remote_client.rs` for the executor-backed implementation - `common.rs` for the small shared Streamable HTTP helpers - Teach `RmcpClient` to build Streamable HTTP transports in either local or remote mode while keeping the existing OAuth ownership in RMCP. - Translate remote POST, GET, and DELETE session operations into executor `http/request` calls. - Preserve RMCP session expiry handling and reconnect behavior for the remote transport. - Add remote transport coverage in `rmcp-client/tests/streamable_http_remote.rs` and keep the shared test support in `rmcp-client/tests/streamable_http_test_support.rs`. ### Verification - `cargo check -p codex-rmcp-client` - online CI ### Stack 1. #18581 protocol 2. #18582 runner 3. #18583 RMCP client 4. #18584 manager wiring and local/remote coverage --------- Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
83ec1eb5d6
commit
0e78ce80ee
Generated
+3
@@ -2579,7 +2579,9 @@ dependencies = [
|
||||
"arc-swap",
|
||||
"async-trait",
|
||||
"base64 0.22.1",
|
||||
"bytes",
|
||||
"codex-app-server-protocol",
|
||||
"codex-client",
|
||||
"codex-config",
|
||||
"codex-protocol",
|
||||
"codex-sandboxing",
|
||||
@@ -3141,6 +3143,7 @@ version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"axum",
|
||||
"bytes",
|
||||
"codex-client",
|
||||
"codex-config",
|
||||
"codex-exec-server",
|
||||
|
||||
@@ -1589,23 +1589,11 @@ async fn make_rmcp_client(
|
||||
env_http_headers,
|
||||
bearer_token_env_var,
|
||||
} => {
|
||||
if remote_environment {
|
||||
if !runtime_environment.environment().is_remote() {
|
||||
return Err(StartupOutcomeError::from(anyhow!(
|
||||
"remote MCP server `{server_name}` requires a remote executor environment"
|
||||
)));
|
||||
}
|
||||
if remote_environment && !runtime_environment.environment().is_remote() {
|
||||
return Err(StartupOutcomeError::from(anyhow!(
|
||||
// Remote HTTP needs the future low-level executor
|
||||
// `network/request` API so reqwest runs on the executor side.
|
||||
// Do not fall back to local HTTP here; the config explicitly
|
||||
// asked for remote placement.
|
||||
"remote streamable HTTP MCP server `{server_name}` is not implemented yet"
|
||||
"remote MCP server `{server_name}` requires a remote environment"
|
||||
)));
|
||||
}
|
||||
|
||||
// Local streamable HTTP remains the existing reqwest path from
|
||||
// the orchestrator process.
|
||||
let resolved_bearer_token =
|
||||
match resolve_bearer_token(server_name, bearer_token_env_var.as_deref()) {
|
||||
Ok(token) => token,
|
||||
@@ -1618,6 +1606,7 @@ async fn make_rmcp_client(
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
runtime_environment.environment().get_http_client(),
|
||||
)
|
||||
.await
|
||||
.map_err(StartupOutcomeError::from)
|
||||
|
||||
@@ -14,7 +14,9 @@ workspace = true
|
||||
arc-swap = { workspace = true }
|
||||
async-trait = { workspace = true }
|
||||
base64 = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
codex-app-server-protocol = { workspace = true }
|
||||
codex-client = { workspace = true }
|
||||
codex-config = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-sandboxing = { workspace = true }
|
||||
|
||||
@@ -8,6 +8,8 @@ use std::time::Duration;
|
||||
|
||||
use arc_swap::ArcSwap;
|
||||
use codex_app_server_protocol::JSONRPCNotification;
|
||||
use futures::FutureExt;
|
||||
use futures::future::BoxFuture;
|
||||
use serde_json::Value;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::OnceCell;
|
||||
@@ -20,6 +22,7 @@ use tracing::debug;
|
||||
|
||||
use crate::ProcessId;
|
||||
use crate::client_api::ExecServerClientConnectOptions;
|
||||
use crate::client_api::HttpClient;
|
||||
use crate::client_api::RemoteExecServerConnectArgs;
|
||||
use crate::connection::JsonRpcConnection;
|
||||
use crate::process::ExecProcessEvent;
|
||||
@@ -206,6 +209,25 @@ impl LazyRemoteExecServerClient {
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient for LazyRemoteExecServerClient {
|
||||
fn http_request(
|
||||
&self,
|
||||
params: crate::HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<crate::HttpRequestResponse, ExecServerError>> {
|
||||
async move { self.get().await?.http_request(params).await }.boxed()
|
||||
}
|
||||
|
||||
fn http_request_stream(
|
||||
&self,
|
||||
params: crate::HttpRequestParams,
|
||||
) -> BoxFuture<
|
||||
'_,
|
||||
Result<(crate::HttpRequestResponse, crate::HttpResponseBodyStream), ExecServerError>,
|
||||
> {
|
||||
async move { self.get().await?.http_request_stream(params).await }.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub enum ExecServerError {
|
||||
#[error("failed to spawn exec-server: {0}")]
|
||||
@@ -226,6 +248,8 @@ pub enum ExecServerError {
|
||||
Disconnected(String),
|
||||
#[error("failed to serialize or deserialize exec-server JSON: {0}")]
|
||||
Json(#[from] serde_json::Error),
|
||||
#[error("HTTP request failed: {0}")]
|
||||
HttpRequest(String),
|
||||
#[error("exec-server protocol error: {0}")]
|
||||
Protocol(String),
|
||||
#[error("exec-server rejected request ({code}): {message}")]
|
||||
|
||||
@@ -1,522 +1,26 @@
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
//! HTTP client capability implementations shared by local and remote environments.
|
||||
//!
|
||||
//! This module is the facade for the environment-owned [`crate::HttpClient`]
|
||||
//! capability:
|
||||
//! - [`ReqwestHttpClient`] executes requests directly with `reqwest`
|
||||
//! - [`ExecServerClient`] forwards requests over the JSON-RPC transport
|
||||
//! - [`HttpResponseBodyStream`] presents buffered local bodies and streamed
|
||||
//! remote `http/request/bodyDelta` notifications through one byte-stream API
|
||||
//!
|
||||
//! Runtime split:
|
||||
//! - orchestrator process: holds an `Arc<dyn HttpClient>` and chooses local or
|
||||
//! remote execution
|
||||
//! - remote runtime: serves the `http/request` RPC and runs the concrete local
|
||||
//! HTTP request there when the orchestrator uses [`ExecServerClient`]
|
||||
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use futures::StreamExt;
|
||||
use reqwest::Method;
|
||||
use reqwest::Url;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderName;
|
||||
use reqwest::header::HeaderValue;
|
||||
use serde_json::Value;
|
||||
use serde_json::from_value;
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TrySendError;
|
||||
use tracing::debug;
|
||||
#[path = "reqwest_http_client.rs"]
|
||||
mod reqwest_http_client;
|
||||
#[path = "http_response_body_stream.rs"]
|
||||
pub(crate) mod response_body_stream;
|
||||
#[path = "rpc_http_client.rs"]
|
||||
mod rpc_http_client;
|
||||
|
||||
use super::ExecServerClient;
|
||||
use super::ExecServerError;
|
||||
use super::Inner;
|
||||
use crate::protocol::HTTP_REQUEST_BODY_DELTA_METHOD;
|
||||
use crate::protocol::HTTP_REQUEST_METHOD;
|
||||
use crate::protocol::HttpHeader;
|
||||
use crate::protocol::HttpRequestBodyDeltaNotification;
|
||||
use crate::protocol::HttpRequestParams;
|
||||
use crate::protocol::HttpRequestResponse;
|
||||
use crate::rpc::RpcNotificationSender;
|
||||
use crate::rpc::internal_error;
|
||||
use crate::rpc::invalid_params;
|
||||
|
||||
/// Maximum queued body frames per streamed executor HTTP response.
|
||||
const HTTP_BODY_DELTA_CHANNEL_CAPACITY: usize = 256;
|
||||
|
||||
pub(crate) struct ExecutorPendingHttpBodyStream {
|
||||
pub(crate) request_id: String,
|
||||
response: reqwest::Response,
|
||||
}
|
||||
|
||||
pub(crate) struct ExecutorHttpRequestRunner {
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
/// Request-scoped stream of body chunks for an executor HTTP response.
|
||||
///
|
||||
/// The initial `http/request` call returns status and headers. This stream then
|
||||
/// receives the ordered `http/request/bodyDelta` notifications for that request
|
||||
/// id until EOF or a terminal error.
|
||||
pub struct HttpResponseBodyStream {
|
||||
inner: Arc<Inner>,
|
||||
request_id: String,
|
||||
next_seq: u64,
|
||||
rx: mpsc::Receiver<HttpRequestBodyDeltaNotification>,
|
||||
// Terminal frames can carry a final chunk; return that once, then EOF.
|
||||
pending_eof: bool,
|
||||
closed: bool,
|
||||
}
|
||||
|
||||
impl ExecServerClient {
|
||||
/// Performs an executor-side HTTP request and buffers the response body.
|
||||
pub async fn http_request(
|
||||
&self,
|
||||
mut params: HttpRequestParams,
|
||||
) -> Result<HttpRequestResponse, ExecServerError> {
|
||||
params.stream_response = false;
|
||||
self.call(HTTP_REQUEST_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
/// Performs an executor-side HTTP request and returns a body stream.
|
||||
///
|
||||
/// The method sets `stream_response` and replaces any caller-supplied
|
||||
/// `request_id` with a connection-local id, so late deltas from abandoned
|
||||
/// streams cannot be confused with later requests.
|
||||
pub async fn http_request_stream(
|
||||
&self,
|
||||
mut params: HttpRequestParams,
|
||||
) -> Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError> {
|
||||
params.stream_response = true;
|
||||
let request_id = self.inner.next_http_body_stream_request_id();
|
||||
params.request_id = request_id.clone();
|
||||
let (tx, rx) = mpsc::channel(HTTP_BODY_DELTA_CHANNEL_CAPACITY);
|
||||
self.inner
|
||||
.insert_http_body_stream(request_id.clone(), tx)
|
||||
.await?;
|
||||
let mut registration = HttpBodyStreamRegistration {
|
||||
inner: Arc::clone(&self.inner),
|
||||
request_id: request_id.clone(),
|
||||
active: true,
|
||||
};
|
||||
let response = match self.call(HTTP_REQUEST_METHOD, ¶ms).await {
|
||||
Ok(response) => response,
|
||||
Err(error) => {
|
||||
self.inner.remove_http_body_stream(&request_id).await;
|
||||
registration.active = false;
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
registration.active = false;
|
||||
Ok((
|
||||
response,
|
||||
HttpResponseBodyStream {
|
||||
inner: Arc::clone(&self.inner),
|
||||
request_id,
|
||||
next_seq: 1,
|
||||
rx,
|
||||
pending_eof: false,
|
||||
closed: false,
|
||||
},
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpResponseBodyStream {
|
||||
/// Receives the next response-body chunk.
|
||||
///
|
||||
/// Returns `Ok(None)` at EOF and converts sequence gaps or executor-side
|
||||
/// stream errors into protocol errors.
|
||||
pub async fn recv(&mut self) -> Result<Option<Vec<u8>>, ExecServerError> {
|
||||
if self.pending_eof {
|
||||
self.pending_eof = false;
|
||||
self.finish().await;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let Some(delta) = self.rx.recv().await else {
|
||||
self.finish().await;
|
||||
if let Some(error) = self
|
||||
.inner
|
||||
.take_http_body_stream_failure(&self.request_id)
|
||||
.await
|
||||
{
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
"http response stream `{}` failed: {error}",
|
||||
self.request_id
|
||||
)));
|
||||
}
|
||||
return Ok(None);
|
||||
};
|
||||
if delta.seq != self.next_seq {
|
||||
self.finish().await;
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
"http response stream `{}` received seq {}, expected {}",
|
||||
self.request_id, delta.seq, self.next_seq
|
||||
)));
|
||||
}
|
||||
self.next_seq += 1;
|
||||
let chunk = delta.delta.into_inner();
|
||||
|
||||
if let Some(error) = delta.error {
|
||||
self.finish().await;
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
"http response stream `{}` failed: {error}",
|
||||
self.request_id
|
||||
)));
|
||||
}
|
||||
if delta.done {
|
||||
self.finish().await;
|
||||
if chunk.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
self.pending_eof = true;
|
||||
}
|
||||
Ok(Some(chunk))
|
||||
}
|
||||
|
||||
/// Removes this stream from the connection routing table once it reaches EOF.
|
||||
async fn finish(&mut self) {
|
||||
if self.closed {
|
||||
return;
|
||||
}
|
||||
self.closed = true;
|
||||
self.inner.remove_http_body_stream(&self.request_id).await;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for HttpResponseBodyStream {
|
||||
/// Schedules stream-route removal if the consumer drops before EOF.
|
||||
fn drop(&mut self) {
|
||||
if self.closed {
|
||||
return;
|
||||
}
|
||||
self.closed = true;
|
||||
spawn_remove_http_body_stream(Arc::clone(&self.inner), self.request_id.clone());
|
||||
}
|
||||
}
|
||||
|
||||
impl ExecutorHttpRequestRunner {
|
||||
pub(crate) fn new(timeout_ms: Option<u64>) -> Result<Self, JSONRPCErrorError> {
|
||||
let client = match timeout_ms {
|
||||
None => reqwest::Client::builder(),
|
||||
Some(timeout_ms) => {
|
||||
reqwest::Client::builder().timeout(Duration::from_millis(timeout_ms))
|
||||
}
|
||||
}
|
||||
.build()
|
||||
.map_err(|err| internal_error(format!("failed to build http/request client: {err}")))?;
|
||||
Ok(Self { client })
|
||||
}
|
||||
|
||||
pub(crate) async fn run(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> Result<(HttpRequestResponse, Option<ExecutorPendingHttpBodyStream>), JSONRPCErrorError>
|
||||
{
|
||||
let method = Method::from_bytes(params.method.as_bytes())
|
||||
.map_err(|err| invalid_params(format!("http/request method is invalid: {err}")))?;
|
||||
let url = Url::parse(¶ms.url)
|
||||
.map_err(|err| invalid_params(format!("http/request url is invalid: {err}")))?;
|
||||
match url.scheme() {
|
||||
"http" | "https" => {}
|
||||
scheme => {
|
||||
return Err(invalid_params(format!(
|
||||
"http/request only supports http and https URLs, got {scheme}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let headers = Self::build_headers(params.headers)?;
|
||||
let mut request = self.client.request(method, url).headers(headers);
|
||||
if let Some(body) = params.body {
|
||||
request = request.body(body.into_inner());
|
||||
}
|
||||
|
||||
let response = request
|
||||
.send()
|
||||
.await
|
||||
.map_err(|err| internal_error(format!("http/request failed: {err}")))?;
|
||||
let status = response.status().as_u16();
|
||||
let headers = Self::response_headers(response.headers());
|
||||
|
||||
if params.stream_response {
|
||||
return Ok((
|
||||
HttpRequestResponse {
|
||||
status,
|
||||
headers,
|
||||
body: Vec::new().into(),
|
||||
},
|
||||
Some(ExecutorPendingHttpBodyStream {
|
||||
request_id: params.request_id,
|
||||
response,
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
let body = response.bytes().await.map_err(|err| {
|
||||
internal_error(format!("failed to read http/request response body: {err}"))
|
||||
})?;
|
||||
|
||||
Ok((
|
||||
HttpRequestResponse {
|
||||
status,
|
||||
headers,
|
||||
body: body.to_vec().into(),
|
||||
},
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
fn build_headers(headers: Vec<HttpHeader>) -> Result<HeaderMap, JSONRPCErrorError> {
|
||||
let mut header_map = HeaderMap::new();
|
||||
for header in headers {
|
||||
let name = HeaderName::from_bytes(header.name.as_bytes()).map_err(|err| {
|
||||
invalid_params(format!("http/request header name is invalid: {err}"))
|
||||
})?;
|
||||
let value = HeaderValue::from_str(&header.value).map_err(|err| {
|
||||
invalid_params(format!(
|
||||
"http/request header value is invalid for {}: {err}",
|
||||
header.name
|
||||
))
|
||||
})?;
|
||||
header_map.append(name, value);
|
||||
}
|
||||
Ok(header_map)
|
||||
}
|
||||
|
||||
fn response_headers(headers: &HeaderMap) -> Vec<HttpHeader> {
|
||||
headers
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
Some(HttpHeader {
|
||||
name: name.as_str().to_string(),
|
||||
value: value.to_str().ok()?.to_string(),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
pub(crate) async fn stream_body(
|
||||
pending_stream: ExecutorPendingHttpBodyStream,
|
||||
notifications: RpcNotificationSender,
|
||||
) {
|
||||
let ExecutorPendingHttpBodyStream {
|
||||
request_id,
|
||||
response,
|
||||
} = pending_stream;
|
||||
let mut seq = 1;
|
||||
let mut body = response.bytes_stream();
|
||||
while let Some(chunk) = body.next().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
if !send_executor_body_delta(
|
||||
¬ifications,
|
||||
HttpRequestBodyDeltaNotification {
|
||||
request_id: request_id.clone(),
|
||||
seq,
|
||||
delta: bytes.to_vec().into(),
|
||||
done: false,
|
||||
error: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
seq += 1;
|
||||
}
|
||||
Err(err) => {
|
||||
let _ = send_executor_body_delta(
|
||||
¬ifications,
|
||||
HttpRequestBodyDeltaNotification {
|
||||
request_id,
|
||||
seq,
|
||||
delta: Vec::new().into(),
|
||||
done: true,
|
||||
error: Some(err.to_string()),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = send_executor_body_delta(
|
||||
¬ifications,
|
||||
HttpRequestBodyDeltaNotification {
|
||||
request_id,
|
||||
seq,
|
||||
delta: Vec::new().into(),
|
||||
done: true,
|
||||
error: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
}
|
||||
|
||||
impl Inner {
|
||||
/// Routes one streamed HTTP body notification into its request-local receiver.
|
||||
pub(super) async fn handle_http_body_delta_notification(
|
||||
&self,
|
||||
params: Option<Value>,
|
||||
) -> Result<(), ExecServerError> {
|
||||
let params: HttpRequestBodyDeltaNotification = from_value(params.unwrap_or(Value::Null))?;
|
||||
// Unknown request ids are ignored intentionally: a stream may have already
|
||||
// reached EOF and released its route.
|
||||
if let Some(tx) = self
|
||||
.http_body_streams
|
||||
.load()
|
||||
.get(¶ms.request_id)
|
||||
.cloned()
|
||||
{
|
||||
let request_id = params.request_id.clone();
|
||||
let terminal_delta = params.done || params.error.is_some();
|
||||
match tx.try_send(params) {
|
||||
Ok(()) => {
|
||||
if terminal_delta {
|
||||
self.remove_http_body_stream(&request_id).await;
|
||||
}
|
||||
}
|
||||
Err(TrySendError::Closed(_)) => {
|
||||
self.remove_http_body_stream(&request_id).await;
|
||||
debug!("http response stream receiver dropped before body delta delivery");
|
||||
}
|
||||
Err(TrySendError::Full(_)) => {
|
||||
self.record_http_body_stream_failure(
|
||||
&request_id,
|
||||
"body delta channel filled before delivery".to_string(),
|
||||
)
|
||||
.await;
|
||||
self.remove_http_body_stream(&request_id).await;
|
||||
debug!(
|
||||
"closing http response stream `{request_id}` after body delta backpressure"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fails active streamed HTTP bodies so callers do not wait forever after a
|
||||
/// transport disconnect or notification handling failure.
|
||||
pub(super) async fn fail_all_http_body_streams(&self, message: String) {
|
||||
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
|
||||
let streams = self.http_body_streams.load();
|
||||
let streams = streams.as_ref().clone();
|
||||
self.http_body_streams.store(Arc::new(HashMap::new()));
|
||||
for (request_id, tx) in streams {
|
||||
if tx
|
||||
.try_send(HttpRequestBodyDeltaNotification {
|
||||
request_id: request_id.clone(),
|
||||
seq: 1,
|
||||
delta: Vec::new().into(),
|
||||
done: true,
|
||||
error: Some(message.clone()),
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
let mut next_failures = self.http_body_stream_failures.load().as_ref().clone();
|
||||
next_failures.insert(request_id, message.clone());
|
||||
self.http_body_stream_failures
|
||||
.store(Arc::new(next_failures));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocates a connection-local streamed HTTP response id.
|
||||
fn next_http_body_stream_request_id(&self) -> String {
|
||||
let id = self
|
||||
.http_body_stream_next_id
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
format!("http-{id}")
|
||||
}
|
||||
|
||||
/// Registers a request id before issuing an executor streaming HTTP call.
|
||||
async fn insert_http_body_stream(
|
||||
&self,
|
||||
request_id: String,
|
||||
tx: mpsc::Sender<HttpRequestBodyDeltaNotification>,
|
||||
) -> Result<(), ExecServerError> {
|
||||
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
|
||||
let streams = self.http_body_streams.load();
|
||||
if streams.contains_key(&request_id) {
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
"http response stream already registered for request {request_id}"
|
||||
)));
|
||||
}
|
||||
let mut next_streams = streams.as_ref().clone();
|
||||
next_streams.insert(request_id.clone(), tx);
|
||||
self.http_body_streams.store(Arc::new(next_streams));
|
||||
let failures = self.http_body_stream_failures.load();
|
||||
if failures.contains_key(&request_id) {
|
||||
let mut next_failures = failures.as_ref().clone();
|
||||
next_failures.remove(&request_id);
|
||||
self.http_body_stream_failures
|
||||
.store(Arc::new(next_failures));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes a request id after EOF, terminal error, or request failure.
|
||||
async fn remove_http_body_stream(
|
||||
&self,
|
||||
request_id: &str,
|
||||
) -> Option<mpsc::Sender<HttpRequestBodyDeltaNotification>> {
|
||||
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
|
||||
let streams = self.http_body_streams.load();
|
||||
let stream = streams.get(request_id).cloned();
|
||||
stream.as_ref()?;
|
||||
let mut next_streams = streams.as_ref().clone();
|
||||
next_streams.remove(request_id);
|
||||
self.http_body_streams.store(Arc::new(next_streams));
|
||||
stream
|
||||
}
|
||||
|
||||
async fn record_http_body_stream_failure(&self, request_id: &str, message: String) {
|
||||
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
|
||||
let failures = self.http_body_stream_failures.load();
|
||||
let mut next_failures = failures.as_ref().clone();
|
||||
next_failures.insert(request_id.to_string(), message);
|
||||
self.http_body_stream_failures
|
||||
.store(Arc::new(next_failures));
|
||||
}
|
||||
|
||||
async fn take_http_body_stream_failure(&self, request_id: &str) -> Option<String> {
|
||||
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
|
||||
let failures = self.http_body_stream_failures.load();
|
||||
let error = failures.get(request_id).cloned();
|
||||
error.as_ref()?;
|
||||
let mut next_failures = failures.as_ref().clone();
|
||||
next_failures.remove(request_id);
|
||||
self.http_body_stream_failures
|
||||
.store(Arc::new(next_failures));
|
||||
error
|
||||
}
|
||||
}
|
||||
|
||||
/// Active route registration owned while `http_request_stream` awaits headers.
|
||||
struct HttpBodyStreamRegistration {
|
||||
inner: Arc<Inner>,
|
||||
request_id: String,
|
||||
active: bool,
|
||||
}
|
||||
|
||||
impl Drop for HttpBodyStreamRegistration {
|
||||
/// Removes the route if the stream request future is cancelled before headers return.
|
||||
fn drop(&mut self) {
|
||||
if self.active {
|
||||
spawn_remove_http_body_stream(Arc::clone(&self.inner), self.request_id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Schedules HTTP body route removal from synchronous drop paths.
|
||||
fn spawn_remove_http_body_stream(inner: Arc<Inner>, request_id: String) {
|
||||
if let Ok(handle) = Handle::try_current() {
|
||||
handle.spawn(async move {
|
||||
inner.remove_http_body_stream(&request_id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_executor_body_delta(
|
||||
notifications: &RpcNotificationSender,
|
||||
delta: HttpRequestBodyDeltaNotification,
|
||||
) -> bool {
|
||||
notifications
|
||||
.notify(HTTP_REQUEST_BODY_DELTA_METHOD, &delta)
|
||||
.await
|
||||
.is_ok()
|
||||
}
|
||||
pub(crate) use reqwest_http_client::PendingReqwestHttpBodyStream;
|
||||
pub use reqwest_http_client::ReqwestHttpClient;
|
||||
pub(crate) use reqwest_http_client::ReqwestHttpRequestRunner;
|
||||
pub use response_body_stream::HttpResponseBodyStream;
|
||||
|
||||
@@ -0,0 +1,355 @@
|
||||
//! Shared HTTP response-body stream plumbing for local and remote execution.
|
||||
//!
|
||||
//! This module owns the byte-stream type exposed by the `HttpClient`
|
||||
//! capability plus the remote-side routing table used to turn
|
||||
//! `http/request/bodyDelta` notifications back into per-request streams.
|
||||
|
||||
use std::collections::HashMap;
|
||||
use std::pin::Pin;
|
||||
use std::sync::Arc;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use bytes::Bytes;
|
||||
use futures::StreamExt;
|
||||
use reqwest::Response;
|
||||
use serde_json::Value;
|
||||
use serde_json::from_value;
|
||||
use tokio::runtime::Handle;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::mpsc::error::TrySendError;
|
||||
use tracing::debug;
|
||||
|
||||
use crate::client::ExecServerError;
|
||||
use crate::client::Inner;
|
||||
use crate::protocol::HTTP_REQUEST_BODY_DELTA_METHOD;
|
||||
use crate::protocol::HttpRequestBodyDeltaNotification;
|
||||
use crate::rpc::RpcNotificationSender;
|
||||
|
||||
pub(super) struct HttpBodyStreamRegistration {
|
||||
inner: Arc<Inner>,
|
||||
request_id: String,
|
||||
active: bool,
|
||||
}
|
||||
|
||||
enum HttpResponseBodyStreamInner {
|
||||
Local {
|
||||
body: Pin<Box<dyn futures::Stream<Item = Result<Bytes, reqwest::Error>> + Send>>,
|
||||
},
|
||||
Remote {
|
||||
inner: Arc<Inner>,
|
||||
request_id: String,
|
||||
next_seq: u64,
|
||||
rx: mpsc::Receiver<HttpRequestBodyDeltaNotification>,
|
||||
pending_eof: bool,
|
||||
closed: bool,
|
||||
},
|
||||
}
|
||||
|
||||
/// Request-scoped stream of body chunks for an HTTP response.
|
||||
///
|
||||
/// The initial `http/request` call returns status and headers. This stream then
|
||||
/// receives the ordered `http/request/bodyDelta` notifications for that request
|
||||
/// id until EOF or a terminal error.
|
||||
pub struct HttpResponseBodyStream {
|
||||
inner: HttpResponseBodyStreamInner,
|
||||
}
|
||||
|
||||
impl HttpResponseBodyStream {
|
||||
pub(super) fn local(response: Response) -> Self {
|
||||
Self {
|
||||
inner: HttpResponseBodyStreamInner::Local {
|
||||
body: Box::pin(response.bytes_stream()),
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn remote(
|
||||
inner: Arc<Inner>,
|
||||
request_id: String,
|
||||
rx: mpsc::Receiver<HttpRequestBodyDeltaNotification>,
|
||||
) -> Self {
|
||||
Self {
|
||||
inner: HttpResponseBodyStreamInner::Remote {
|
||||
inner,
|
||||
request_id,
|
||||
next_seq: 1,
|
||||
rx,
|
||||
pending_eof: false,
|
||||
closed: false,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
/// Receives the next response-body chunk.
|
||||
///
|
||||
/// Returns `Ok(None)` at EOF and converts sequence gaps or stream-side
|
||||
/// stream errors into protocol errors.
|
||||
pub async fn recv(&mut self) -> Result<Option<Vec<u8>>, ExecServerError> {
|
||||
match &mut self.inner {
|
||||
HttpResponseBodyStreamInner::Local { body } => match body.next().await {
|
||||
Some(chunk) => match chunk {
|
||||
Ok(bytes) => Ok(Some(bytes.to_vec())),
|
||||
Err(error) => Err(ExecServerError::HttpRequest(error.to_string())),
|
||||
},
|
||||
None => Ok(None),
|
||||
},
|
||||
HttpResponseBodyStreamInner::Remote {
|
||||
inner,
|
||||
request_id,
|
||||
next_seq,
|
||||
rx,
|
||||
pending_eof,
|
||||
closed,
|
||||
} => {
|
||||
if *pending_eof {
|
||||
*pending_eof = false;
|
||||
finish_remote_stream(inner, request_id, closed).await;
|
||||
return Ok(None);
|
||||
}
|
||||
|
||||
let Some(delta) = rx.recv().await else {
|
||||
finish_remote_stream(inner, request_id, closed).await;
|
||||
if let Some(error) = inner.take_http_body_stream_failure(request_id).await {
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
"http response stream `{request_id}` failed: {error}",
|
||||
)));
|
||||
}
|
||||
return Ok(None);
|
||||
};
|
||||
if delta.seq != *next_seq {
|
||||
finish_remote_stream(inner, request_id, closed).await;
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
"http response stream `{request_id}` received seq {}, expected {}",
|
||||
delta.seq, *next_seq
|
||||
)));
|
||||
}
|
||||
*next_seq += 1;
|
||||
let chunk = delta.delta.into_inner();
|
||||
|
||||
if let Some(error) = delta.error {
|
||||
finish_remote_stream(inner, request_id, closed).await;
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
"http response stream `{request_id}` failed: {error}",
|
||||
)));
|
||||
}
|
||||
if delta.done {
|
||||
finish_remote_stream(inner, request_id, closed).await;
|
||||
if chunk.is_empty() {
|
||||
return Ok(None);
|
||||
}
|
||||
*pending_eof = true;
|
||||
}
|
||||
Ok(Some(chunk))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for HttpResponseBodyStream {
|
||||
/// Schedules stream-route removal if the consumer drops before EOF.
|
||||
fn drop(&mut self) {
|
||||
if let HttpResponseBodyStreamInner::Remote {
|
||||
inner,
|
||||
request_id,
|
||||
closed,
|
||||
..
|
||||
} = &mut self.inner
|
||||
{
|
||||
if *closed {
|
||||
return;
|
||||
}
|
||||
*closed = true;
|
||||
spawn_remove_http_body_stream(Arc::clone(inner), request_id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpBodyStreamRegistration {
|
||||
pub(super) fn new(inner: Arc<Inner>, request_id: String) -> Self {
|
||||
Self {
|
||||
inner,
|
||||
request_id,
|
||||
active: true,
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn disarm(&mut self) {
|
||||
self.active = false;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for HttpBodyStreamRegistration {
|
||||
/// Removes the route if the stream request future is cancelled before headers return.
|
||||
fn drop(&mut self) {
|
||||
if self.active {
|
||||
spawn_remove_http_body_stream(Arc::clone(&self.inner), self.request_id.clone());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn finish_remote_stream(inner: &Arc<Inner>, request_id: &str, closed: &mut bool) {
|
||||
if *closed {
|
||||
return;
|
||||
}
|
||||
*closed = true;
|
||||
inner.remove_http_body_stream(request_id).await;
|
||||
}
|
||||
|
||||
/// Schedules HTTP body route removal from synchronous drop paths.
|
||||
fn spawn_remove_http_body_stream(inner: Arc<Inner>, request_id: String) {
|
||||
if let Ok(handle) = Handle::try_current() {
|
||||
handle.spawn(async move {
|
||||
inner.remove_http_body_stream(&request_id).await;
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn send_body_delta(
|
||||
notifications: &RpcNotificationSender,
|
||||
delta: HttpRequestBodyDeltaNotification,
|
||||
) -> bool {
|
||||
notifications
|
||||
.notify(HTTP_REQUEST_BODY_DELTA_METHOD, &delta)
|
||||
.await
|
||||
.is_ok()
|
||||
}
|
||||
|
||||
impl Inner {
|
||||
/// Routes one streamed HTTP body notification into its request-local receiver.
|
||||
pub(crate) async fn handle_http_body_delta_notification(
|
||||
&self,
|
||||
params: Option<Value>,
|
||||
) -> Result<(), ExecServerError> {
|
||||
let params: HttpRequestBodyDeltaNotification = from_value(params.unwrap_or(Value::Null))?;
|
||||
// Unknown request ids are ignored intentionally: a stream may have already
|
||||
// reached EOF and released its route.
|
||||
if let Some(tx) = self
|
||||
.http_body_streams
|
||||
.load()
|
||||
.get(¶ms.request_id)
|
||||
.cloned()
|
||||
{
|
||||
let request_id = params.request_id.clone();
|
||||
let terminal_delta = params.done || params.error.is_some();
|
||||
match tx.try_send(params) {
|
||||
Ok(()) => {
|
||||
if terminal_delta {
|
||||
self.remove_http_body_stream(&request_id).await;
|
||||
}
|
||||
}
|
||||
Err(TrySendError::Closed(_)) => {
|
||||
self.remove_http_body_stream(&request_id).await;
|
||||
debug!("http response stream receiver dropped before body delta delivery");
|
||||
}
|
||||
Err(TrySendError::Full(_)) => {
|
||||
self.record_http_body_stream_failure(
|
||||
&request_id,
|
||||
"body delta channel filled before delivery".to_string(),
|
||||
)
|
||||
.await;
|
||||
self.remove_http_body_stream(&request_id).await;
|
||||
debug!(
|
||||
"closing http response stream `{request_id}` after body delta backpressure"
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Fails active streamed HTTP bodies so callers do not wait forever after a
|
||||
/// transport disconnect or notification handling failure.
|
||||
pub(crate) async fn fail_all_http_body_streams(&self, message: String) {
|
||||
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
|
||||
let streams = self.http_body_streams.load();
|
||||
let streams = streams.as_ref().clone();
|
||||
self.http_body_streams.store(Arc::new(HashMap::new()));
|
||||
for (request_id, tx) in streams {
|
||||
if tx
|
||||
.try_send(HttpRequestBodyDeltaNotification {
|
||||
request_id: request_id.clone(),
|
||||
seq: 1,
|
||||
delta: Vec::new().into(),
|
||||
done: true,
|
||||
error: Some(message.clone()),
|
||||
})
|
||||
.is_err()
|
||||
{
|
||||
let mut next_failures = self.http_body_stream_failures.load().as_ref().clone();
|
||||
next_failures.insert(request_id, message.clone());
|
||||
self.http_body_stream_failures
|
||||
.store(Arc::new(next_failures));
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Allocates a connection-local streamed HTTP response id.
|
||||
pub(super) fn next_http_body_stream_request_id(&self) -> String {
|
||||
let id = self
|
||||
.http_body_stream_next_id
|
||||
.fetch_add(1, Ordering::Relaxed);
|
||||
format!("http-{id}")
|
||||
}
|
||||
|
||||
/// Registers a request id before issuing a streaming HTTP call.
|
||||
pub(super) async fn insert_http_body_stream(
|
||||
&self,
|
||||
request_id: String,
|
||||
tx: mpsc::Sender<HttpRequestBodyDeltaNotification>,
|
||||
) -> Result<(), ExecServerError> {
|
||||
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
|
||||
let streams = self.http_body_streams.load();
|
||||
if streams.contains_key(&request_id) {
|
||||
return Err(ExecServerError::Protocol(format!(
|
||||
"http response stream already registered for request {request_id}"
|
||||
)));
|
||||
}
|
||||
let mut next_streams = streams.as_ref().clone();
|
||||
next_streams.insert(request_id.clone(), tx);
|
||||
self.http_body_streams.store(Arc::new(next_streams));
|
||||
let failures = self.http_body_stream_failures.load();
|
||||
if failures.contains_key(&request_id) {
|
||||
let mut next_failures = failures.as_ref().clone();
|
||||
next_failures.remove(&request_id);
|
||||
self.http_body_stream_failures
|
||||
.store(Arc::new(next_failures));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
/// Removes a request id after EOF, terminal error, or request failure.
|
||||
pub(super) async fn remove_http_body_stream(
|
||||
&self,
|
||||
request_id: &str,
|
||||
) -> Option<mpsc::Sender<HttpRequestBodyDeltaNotification>> {
|
||||
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
|
||||
let streams = self.http_body_streams.load();
|
||||
let stream = streams.get(request_id).cloned();
|
||||
stream.as_ref()?;
|
||||
let mut next_streams = streams.as_ref().clone();
|
||||
next_streams.remove(request_id);
|
||||
self.http_body_streams.store(Arc::new(next_streams));
|
||||
stream
|
||||
}
|
||||
|
||||
async fn record_http_body_stream_failure(&self, request_id: &str, message: String) {
|
||||
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
|
||||
let failures = self.http_body_stream_failures.load();
|
||||
let mut next_failures = failures.as_ref().clone();
|
||||
next_failures.insert(request_id.to_string(), message);
|
||||
self.http_body_stream_failures
|
||||
.store(Arc::new(next_failures));
|
||||
}
|
||||
|
||||
async fn take_http_body_stream_failure(&self, request_id: &str) -> Option<String> {
|
||||
let _streams_write_guard = self.http_body_streams_write_lock.lock().await;
|
||||
let failures = self.http_body_stream_failures.load();
|
||||
let error = failures.get(request_id).cloned();
|
||||
error.as_ref()?;
|
||||
let mut next_failures = failures.as_ref().clone();
|
||||
next_failures.remove(request_id);
|
||||
self.http_body_stream_failures
|
||||
.store(Arc::new(next_failures));
|
||||
error
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,267 @@
|
||||
//! `reqwest`-backed `HttpClient` implementation.
|
||||
//!
|
||||
//! This code runs wherever the real network request should originate:
|
||||
//! - in a local environment, that means the orchestrator process
|
||||
//! - in a remote environment, that means the remote runtime after the
|
||||
//! orchestrator has forwarded `http/request` over JSON-RPC
|
||||
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_app_server_protocol::JSONRPCErrorError;
|
||||
use codex_client::build_reqwest_client_with_custom_ca;
|
||||
use futures::FutureExt;
|
||||
use futures::StreamExt;
|
||||
use futures::future::BoxFuture;
|
||||
use reqwest::Method;
|
||||
use reqwest::Url;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderName;
|
||||
use reqwest::header::HeaderValue;
|
||||
|
||||
use super::HttpResponseBodyStream;
|
||||
use super::response_body_stream::send_body_delta;
|
||||
use crate::HttpClient;
|
||||
use crate::client::ExecServerError;
|
||||
use crate::protocol::HttpHeader;
|
||||
use crate::protocol::HttpRequestBodyDeltaNotification;
|
||||
use crate::protocol::HttpRequestParams;
|
||||
use crate::protocol::HttpRequestResponse;
|
||||
use crate::rpc::RpcNotificationSender;
|
||||
use crate::rpc::internal_error;
|
||||
use crate::rpc::invalid_params;
|
||||
|
||||
/// `HttpClient` implementation that performs the actual HTTP request with
|
||||
/// `reqwest`.
|
||||
#[derive(Clone, Default)]
|
||||
pub struct ReqwestHttpClient;
|
||||
|
||||
/// Streaming response state held between the initial HTTP response and
|
||||
/// downstream body-delta forwarding.
|
||||
pub(crate) struct PendingReqwestHttpBodyStream {
|
||||
pub(crate) request_id: String,
|
||||
pub(crate) response: reqwest::Response,
|
||||
}
|
||||
|
||||
/// Validates `http/request` parameters and runs the actual `reqwest` call used
|
||||
/// by the exec-server route and the local [`HttpClient`] backend.
|
||||
pub(crate) struct ReqwestHttpRequestRunner {
|
||||
client: reqwest::Client,
|
||||
}
|
||||
|
||||
impl ReqwestHttpClient {
|
||||
fn build_client(timeout_ms: Option<u64>) -> Result<reqwest::Client, ExecServerError> {
|
||||
let builder = match timeout_ms {
|
||||
None => reqwest::Client::builder(),
|
||||
Some(timeout_ms) => {
|
||||
reqwest::Client::builder().timeout(Duration::from_millis(timeout_ms))
|
||||
}
|
||||
};
|
||||
build_reqwest_client_with_custom_ca(builder)
|
||||
.map_err(|error| ExecServerError::HttpRequest(error.to_string()))
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient for ReqwestHttpClient {
|
||||
fn http_request(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
|
||||
async move {
|
||||
let runner = ReqwestHttpRequestRunner::new(params.timeout_ms)
|
||||
.map_err(|error| ExecServerError::HttpRequest(error.message))?;
|
||||
let (response, _) = runner
|
||||
.run(HttpRequestParams {
|
||||
stream_response: false,
|
||||
..params
|
||||
})
|
||||
.await
|
||||
.map_err(|error| ExecServerError::HttpRequest(error.message))?;
|
||||
Ok(response)
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
|
||||
fn http_request_stream(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
|
||||
async move {
|
||||
let runner = ReqwestHttpRequestRunner::new(params.timeout_ms)
|
||||
.map_err(|error| ExecServerError::HttpRequest(error.message))?;
|
||||
let (response, pending_stream) = runner
|
||||
.run(HttpRequestParams {
|
||||
stream_response: true,
|
||||
..params
|
||||
})
|
||||
.await
|
||||
.map_err(|error| ExecServerError::HttpRequest(error.message))?;
|
||||
let pending_stream = pending_stream.ok_or_else(|| {
|
||||
ExecServerError::Protocol(
|
||||
"http request stream did not return a response body stream".to_string(),
|
||||
)
|
||||
})?;
|
||||
Ok((
|
||||
response,
|
||||
HttpResponseBodyStream::local(pending_stream.response),
|
||||
))
|
||||
}
|
||||
.boxed()
|
||||
}
|
||||
}
|
||||
|
||||
impl ReqwestHttpRequestRunner {
|
||||
pub(crate) fn new(timeout_ms: Option<u64>) -> Result<Self, JSONRPCErrorError> {
|
||||
let client = ReqwestHttpClient::build_client(timeout_ms)
|
||||
.map_err(|error| internal_error(error.to_string()))?;
|
||||
Ok(Self { client })
|
||||
}
|
||||
|
||||
pub(crate) async fn run(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> Result<(HttpRequestResponse, Option<PendingReqwestHttpBodyStream>), JSONRPCErrorError>
|
||||
{
|
||||
let method = Method::from_bytes(params.method.as_bytes())
|
||||
.map_err(|error| invalid_params(format!("http/request method is invalid: {error}")))?;
|
||||
let url = Url::parse(¶ms.url)
|
||||
.map_err(|error| invalid_params(format!("http/request url is invalid: {error}")))?;
|
||||
match url.scheme() {
|
||||
"http" | "https" => {}
|
||||
scheme => {
|
||||
return Err(invalid_params(format!(
|
||||
"http/request only supports http and https URLs, got {scheme}"
|
||||
)));
|
||||
}
|
||||
}
|
||||
|
||||
let headers = Self::build_headers(params.headers)?;
|
||||
let mut request = self.client.request(method, url).headers(headers);
|
||||
if let Some(body) = params.body {
|
||||
request = request.body(body.into_inner());
|
||||
}
|
||||
|
||||
let response = request
|
||||
.send()
|
||||
.await
|
||||
.map_err(|error| internal_error(format!("http/request failed: {error}")))?;
|
||||
let status = response.status().as_u16();
|
||||
let headers = Self::response_headers(response.headers());
|
||||
|
||||
if params.stream_response {
|
||||
return Ok((
|
||||
HttpRequestResponse {
|
||||
status,
|
||||
headers,
|
||||
body: Vec::new().into(),
|
||||
},
|
||||
Some(PendingReqwestHttpBodyStream {
|
||||
request_id: params.request_id,
|
||||
response,
|
||||
}),
|
||||
));
|
||||
}
|
||||
|
||||
let body = response.bytes().await.map_err(|error| {
|
||||
internal_error(format!(
|
||||
"failed to read http/request response body: {error}"
|
||||
))
|
||||
})?;
|
||||
|
||||
Ok((
|
||||
HttpRequestResponse {
|
||||
status,
|
||||
headers,
|
||||
body: body.to_vec().into(),
|
||||
},
|
||||
None,
|
||||
))
|
||||
}
|
||||
|
||||
pub(crate) async fn stream_body(
|
||||
pending_stream: PendingReqwestHttpBodyStream,
|
||||
notifications: RpcNotificationSender,
|
||||
) {
|
||||
let PendingReqwestHttpBodyStream {
|
||||
request_id,
|
||||
response,
|
||||
} = pending_stream;
|
||||
let mut seq = 1;
|
||||
let mut body = response.bytes_stream();
|
||||
while let Some(chunk) = body.next().await {
|
||||
match chunk {
|
||||
Ok(bytes) => {
|
||||
if !send_body_delta(
|
||||
¬ifications,
|
||||
HttpRequestBodyDeltaNotification {
|
||||
request_id: request_id.clone(),
|
||||
seq,
|
||||
delta: bytes.to_vec().into(),
|
||||
done: false,
|
||||
error: None,
|
||||
},
|
||||
)
|
||||
.await
|
||||
{
|
||||
return;
|
||||
}
|
||||
seq += 1;
|
||||
}
|
||||
Err(error) => {
|
||||
let _ = send_body_delta(
|
||||
¬ifications,
|
||||
HttpRequestBodyDeltaNotification {
|
||||
request_id,
|
||||
seq,
|
||||
delta: Vec::new().into(),
|
||||
done: true,
|
||||
error: Some(error.to_string()),
|
||||
},
|
||||
)
|
||||
.await;
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
let _ = send_body_delta(
|
||||
¬ifications,
|
||||
HttpRequestBodyDeltaNotification {
|
||||
request_id,
|
||||
seq,
|
||||
delta: Vec::new().into(),
|
||||
done: true,
|
||||
error: None,
|
||||
},
|
||||
)
|
||||
.await;
|
||||
}
|
||||
|
||||
fn build_headers(headers: Vec<HttpHeader>) -> Result<HeaderMap, JSONRPCErrorError> {
|
||||
let mut header_map = HeaderMap::new();
|
||||
for header in headers {
|
||||
let name = HeaderName::from_bytes(header.name.as_bytes()).map_err(|error| {
|
||||
invalid_params(format!("http/request header name is invalid: {error}"))
|
||||
})?;
|
||||
let value = HeaderValue::from_str(&header.value).map_err(|error| {
|
||||
invalid_params(format!(
|
||||
"http/request header value is invalid for {}: {error}",
|
||||
header.name
|
||||
))
|
||||
})?;
|
||||
header_map.append(name, value);
|
||||
}
|
||||
Ok(header_map)
|
||||
}
|
||||
|
||||
fn response_headers(headers: &HeaderMap) -> Vec<HttpHeader> {
|
||||
headers
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
Some(HttpHeader {
|
||||
name: name.as_str().to_string(),
|
||||
value: value.to_str().ok()?.to_string(),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,88 @@
|
||||
//! JSON-RPC-backed `HttpClient` implementation.
|
||||
//!
|
||||
//! This code runs in the orchestrator process. It does not issue network
|
||||
//! requests directly; instead it forwards `http/request` to the remote runtime
|
||||
//! and then reconstructs streamed bodies from `http/request/bodyDelta`
|
||||
//! notifications on the shared connection.
|
||||
|
||||
use std::sync::Arc;
|
||||
|
||||
use futures::FutureExt;
|
||||
use futures::future::BoxFuture;
|
||||
use tokio::sync::mpsc;
|
||||
|
||||
use super::HttpResponseBodyStream;
|
||||
use super::response_body_stream::HttpBodyStreamRegistration;
|
||||
use crate::HttpClient;
|
||||
use crate::client::ExecServerClient;
|
||||
use crate::client::ExecServerError;
|
||||
use crate::protocol::HTTP_REQUEST_METHOD;
|
||||
use crate::protocol::HttpRequestParams;
|
||||
use crate::protocol::HttpRequestResponse;
|
||||
|
||||
/// Maximum queued body frames per streamed HTTP response.
|
||||
const HTTP_BODY_DELTA_CHANNEL_CAPACITY: usize = 256;
|
||||
|
||||
impl ExecServerClient {
|
||||
/// Performs an HTTP request and buffers the response body.
|
||||
pub async fn http_request(
|
||||
&self,
|
||||
mut params: HttpRequestParams,
|
||||
) -> Result<HttpRequestResponse, ExecServerError> {
|
||||
params.stream_response = false;
|
||||
self.call(HTTP_REQUEST_METHOD, ¶ms).await
|
||||
}
|
||||
|
||||
/// Performs an HTTP request and returns a body stream.
|
||||
///
|
||||
/// The method sets `stream_response` and replaces any caller-supplied
|
||||
/// `request_id` with a connection-local id, so late deltas from abandoned
|
||||
/// streams cannot be confused with later requests.
|
||||
pub async fn http_request_stream(
|
||||
&self,
|
||||
mut params: HttpRequestParams,
|
||||
) -> Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError> {
|
||||
params.stream_response = true;
|
||||
let request_id = self.inner.next_http_body_stream_request_id();
|
||||
params.request_id = request_id.clone();
|
||||
let (tx, rx) = mpsc::channel(HTTP_BODY_DELTA_CHANNEL_CAPACITY);
|
||||
self.inner
|
||||
.insert_http_body_stream(request_id.clone(), tx)
|
||||
.await?;
|
||||
let mut registration =
|
||||
HttpBodyStreamRegistration::new(Arc::clone(&self.inner), request_id.clone());
|
||||
let response = match self.call(HTTP_REQUEST_METHOD, ¶ms).await {
|
||||
Ok(response) => response,
|
||||
Err(error) => {
|
||||
self.inner.remove_http_body_stream(&request_id).await;
|
||||
registration.disarm();
|
||||
return Err(error);
|
||||
}
|
||||
};
|
||||
registration.disarm();
|
||||
Ok((
|
||||
response,
|
||||
HttpResponseBodyStream::remote(Arc::clone(&self.inner), request_id, rx),
|
||||
))
|
||||
}
|
||||
}
|
||||
|
||||
impl HttpClient for ExecServerClient {
|
||||
/// Orchestrator-side adapter that forwards buffered HTTP requests to the
|
||||
/// remote runtime over the shared JSON-RPC connection.
|
||||
fn http_request(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>> {
|
||||
async move { ExecServerClient::http_request(self, params).await }.boxed()
|
||||
}
|
||||
|
||||
/// Orchestrator-side adapter that forwards streamed HTTP requests to the
|
||||
/// remote runtime and exposes body deltas as a byte stream.
|
||||
fn http_request_stream(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>> {
|
||||
async move { ExecServerClient::http_request_stream(self, params).await }.boxed()
|
||||
}
|
||||
}
|
||||
@@ -1,5 +1,12 @@
|
||||
use std::time::Duration;
|
||||
|
||||
use futures::future::BoxFuture;
|
||||
|
||||
use crate::ExecServerError;
|
||||
use crate::HttpRequestParams;
|
||||
use crate::HttpRequestResponse;
|
||||
use crate::HttpResponseBodyStream;
|
||||
|
||||
/// Connection options for any exec-server client transport.
|
||||
#[derive(Debug, Clone, PartialEq, Eq)]
|
||||
pub struct ExecServerClientConnectOptions {
|
||||
@@ -17,3 +24,22 @@ pub struct RemoteExecServerConnectArgs {
|
||||
pub initialize_timeout: Duration,
|
||||
pub resume_session_id: Option<String>,
|
||||
}
|
||||
|
||||
/// Sends HTTP requests through a runtime-selected transport.
|
||||
///
|
||||
/// This is the HTTP capability counterpart to [`crate::ExecBackend`]. Callers
|
||||
/// use it when they need environment-owned network requests but should not
|
||||
/// depend on the concrete connection type or how that connection is established.
|
||||
pub trait HttpClient: Send + Sync {
|
||||
/// Perform an HTTP request and buffer the response body.
|
||||
fn http_request(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<HttpRequestResponse, ExecServerError>>;
|
||||
|
||||
/// Perform an HTTP request and return a streamed body handle.
|
||||
fn http_request_stream(
|
||||
&self,
|
||||
params: HttpRequestParams,
|
||||
) -> BoxFuture<'_, Result<(HttpRequestResponse, HttpResponseBodyStream), ExecServerError>>;
|
||||
}
|
||||
|
||||
@@ -3,7 +3,9 @@ use std::sync::Arc;
|
||||
|
||||
use crate::ExecServerError;
|
||||
use crate::ExecServerRuntimePaths;
|
||||
use crate::HttpClient;
|
||||
use crate::client::LazyRemoteExecServerClient;
|
||||
use crate::client::http_client::ReqwestHttpClient;
|
||||
use crate::file_system::ExecutorFileSystem;
|
||||
use crate::local_file_system::LocalFileSystem;
|
||||
use crate::local_process::LocalProcess;
|
||||
@@ -136,6 +138,7 @@ pub struct Environment {
|
||||
exec_server_url: Option<String>,
|
||||
exec_backend: Arc<dyn ExecBackend>,
|
||||
filesystem: Arc<dyn ExecutorFileSystem>,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
local_runtime_paths: Option<ExecServerRuntimePaths>,
|
||||
}
|
||||
|
||||
@@ -146,6 +149,7 @@ impl Environment {
|
||||
exec_server_url: None,
|
||||
exec_backend: Arc::new(LocalProcess::default()),
|
||||
filesystem: Arc::new(LocalFileSystem::unsandboxed()),
|
||||
http_client: Arc::new(ReqwestHttpClient),
|
||||
local_runtime_paths: None,
|
||||
}
|
||||
}
|
||||
@@ -202,6 +206,7 @@ impl Environment {
|
||||
filesystem: Arc::new(LocalFileSystem::with_runtime_paths(
|
||||
local_runtime_paths.clone(),
|
||||
)),
|
||||
http_client: Arc::new(ReqwestHttpClient),
|
||||
local_runtime_paths: Some(local_runtime_paths),
|
||||
}
|
||||
}
|
||||
@@ -216,12 +221,14 @@ impl Environment {
|
||||
) -> Self {
|
||||
let client = LazyRemoteExecServerClient::new(exec_server_url.clone());
|
||||
let exec_backend: Arc<dyn ExecBackend> = Arc::new(RemoteProcess::new(client.clone()));
|
||||
let filesystem: Arc<dyn ExecutorFileSystem> = Arc::new(RemoteFileSystem::new(client));
|
||||
let filesystem: Arc<dyn ExecutorFileSystem> =
|
||||
Arc::new(RemoteFileSystem::new(client.clone()));
|
||||
|
||||
Self {
|
||||
exec_server_url: Some(exec_server_url),
|
||||
exec_backend,
|
||||
filesystem,
|
||||
http_client: Arc::new(client),
|
||||
local_runtime_paths,
|
||||
}
|
||||
}
|
||||
@@ -243,6 +250,10 @@ impl Environment {
|
||||
Arc::clone(&self.exec_backend)
|
||||
}
|
||||
|
||||
pub fn get_http_client(&self) -> Arc<dyn HttpClient> {
|
||||
Arc::clone(&self.http_client)
|
||||
}
|
||||
|
||||
pub fn get_filesystem(&self) -> Arc<dyn ExecutorFileSystem> {
|
||||
Arc::clone(&self.filesystem)
|
||||
}
|
||||
|
||||
@@ -20,7 +20,10 @@ mod server;
|
||||
|
||||
pub use client::ExecServerClient;
|
||||
pub use client::ExecServerError;
|
||||
pub use client::http_client::HttpResponseBodyStream;
|
||||
pub use client::http_client::ReqwestHttpClient;
|
||||
pub use client_api::ExecServerClientConnectOptions;
|
||||
pub use client_api::HttpClient;
|
||||
pub use client_api::RemoteExecServerConnectArgs;
|
||||
pub use environment::CODEX_EXEC_SERVER_URL_ENV_VAR;
|
||||
pub use environment::Environment;
|
||||
|
||||
@@ -12,8 +12,8 @@ use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use crate::ExecServerRuntimePaths;
|
||||
use crate::client::http_client::ExecutorHttpRequestRunner;
|
||||
use crate::client::http_client::ExecutorPendingHttpBodyStream;
|
||||
use crate::client::http_client::PendingReqwestHttpBodyStream;
|
||||
use crate::client::http_client::ReqwestHttpRequestRunner;
|
||||
use crate::protocol::ExecParams;
|
||||
use crate::protocol::ExecResponse;
|
||||
use crate::protocol::FsCopyParams;
|
||||
@@ -178,7 +178,7 @@ impl ExecServerHandler {
|
||||
if stream_response {
|
||||
self.reserve_http_body_stream(&http_request_id).await?;
|
||||
}
|
||||
let response = ExecutorHttpRequestRunner::new(params.timeout_ms)?
|
||||
let response = ReqwestHttpRequestRunner::new(params.timeout_ms)?
|
||||
.run(params)
|
||||
.await;
|
||||
if response.is_err() && stream_response {
|
||||
@@ -306,7 +306,7 @@ impl ExecServerHandler {
|
||||
|
||||
async fn start_http_body_stream(
|
||||
self: &Arc<Self>,
|
||||
pending_stream: ExecutorPendingHttpBodyStream,
|
||||
pending_stream: PendingReqwestHttpBodyStream,
|
||||
) {
|
||||
let request_id = pending_stream.request_id.clone();
|
||||
if self.background_task_shutdown.is_cancelled() {
|
||||
@@ -320,7 +320,7 @@ impl ExecServerHandler {
|
||||
self.background_tasks.spawn(async move {
|
||||
tokio::select! {
|
||||
_ = shutdown.cancelled() => {}
|
||||
_ = ExecutorHttpRequestRunner::stream_body(pending_stream, notifications) => {}
|
||||
_ = ReqwestHttpRequestRunner::stream_body(pending_stream, notifications) => {}
|
||||
}
|
||||
handler.release_http_body_stream(&finished_request_id).await;
|
||||
});
|
||||
|
||||
@@ -3,4 +3,7 @@ load("//:defs.bzl", "codex_rust_crate")
|
||||
codex_rust_crate(
|
||||
name = "rmcp-client",
|
||||
crate_name = "codex_rmcp_client",
|
||||
extra_binaries = [
|
||||
"//codex-rs/cli:codex",
|
||||
],
|
||||
)
|
||||
|
||||
@@ -20,6 +20,7 @@ codex-keyring-store = { workspace = true }
|
||||
codex-protocol = { workspace = true }
|
||||
codex-utils-pty = { workspace = true }
|
||||
codex-utils-home-dir = { workspace = true }
|
||||
bytes = { workspace = true }
|
||||
futures = { workspace = true, default-features = false, features = ["std"] }
|
||||
keyring = { workspace = true, features = ["crypto-rust"] }
|
||||
oauth2 = "5"
|
||||
|
||||
@@ -0,0 +1,391 @@
|
||||
//! RMCP Streamable HTTP adapter built on top of the shared `HttpClient`
|
||||
//! capability.
|
||||
//!
|
||||
//! This module runs in the orchestrator process. It turns high-level RMCP
|
||||
//! operations like `post_message` and `get_stream` into calls on
|
||||
//! `Arc<dyn HttpClient>`, which may be:
|
||||
//! - a local HTTP client that issues requests from the orchestrator, or
|
||||
//! - a remote HTTP client that forwards requests to the remote runtime
|
||||
|
||||
use std::io;
|
||||
use std::sync::Arc;
|
||||
|
||||
use bytes::Bytes;
|
||||
use codex_exec_server::ExecServerError;
|
||||
use codex_exec_server::HttpClient;
|
||||
use codex_exec_server::HttpHeader;
|
||||
use codex_exec_server::HttpRequestParams;
|
||||
use codex_exec_server::HttpResponseBodyStream;
|
||||
use futures::StreamExt;
|
||||
use futures::stream;
|
||||
use futures::stream::BoxStream;
|
||||
use reqwest::StatusCode;
|
||||
use reqwest::header::ACCEPT;
|
||||
use reqwest::header::AUTHORIZATION;
|
||||
use reqwest::header::CONTENT_TYPE;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::HeaderName;
|
||||
use rmcp::model::ClientJsonRpcMessage;
|
||||
use rmcp::model::ServerJsonRpcMessage;
|
||||
use rmcp::transport::streamable_http_client::AuthRequiredError;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClient;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpError;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpPostResponse;
|
||||
use sse_stream::Sse;
|
||||
use sse_stream::SseStream;
|
||||
|
||||
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
|
||||
const JSON_MIME_TYPE: &str = "application/json";
|
||||
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
|
||||
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
|
||||
|
||||
#[derive(Clone)]
|
||||
pub(crate) struct StreamableHttpClientAdapter {
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
default_headers: HeaderMap,
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
pub(crate) enum StreamableHttpClientAdapterError {
|
||||
#[error("streamable HTTP session expired with 404 Not Found")]
|
||||
SessionExpired404,
|
||||
#[error(transparent)]
|
||||
HttpRequest(#[from] ExecServerError),
|
||||
#[error("invalid HTTP header: {0}")]
|
||||
Header(String),
|
||||
}
|
||||
|
||||
impl StreamableHttpClientAdapter {
|
||||
pub(crate) fn new(http_client: Arc<dyn HttpClient>, default_headers: HeaderMap) -> Self {
|
||||
Self {
|
||||
http_client,
|
||||
default_headers,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl StreamableHttpClient for StreamableHttpClientAdapter {
|
||||
type Error = StreamableHttpClientAdapterError;
|
||||
|
||||
async fn post_message(
|
||||
&self,
|
||||
uri: Arc<str>,
|
||||
message: ClientJsonRpcMessage,
|
||||
session_id: Option<Arc<str>>,
|
||||
auth_token: Option<String>,
|
||||
) -> std::result::Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
|
||||
let mut headers = self.default_headers.clone();
|
||||
insert_header(
|
||||
&mut headers,
|
||||
ACCEPT,
|
||||
[EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "),
|
||||
StreamableHttpClientAdapterError::Header,
|
||||
)?;
|
||||
insert_header(
|
||||
&mut headers,
|
||||
CONTENT_TYPE,
|
||||
JSON_MIME_TYPE.to_string(),
|
||||
StreamableHttpClientAdapterError::Header,
|
||||
)?;
|
||||
if let Some(auth_token) = auth_token {
|
||||
insert_header(
|
||||
&mut headers,
|
||||
AUTHORIZATION,
|
||||
format!("Bearer {auth_token}"),
|
||||
StreamableHttpClientAdapterError::Header,
|
||||
)?;
|
||||
}
|
||||
if let Some(session_id_value) = session_id.as_ref() {
|
||||
insert_header(
|
||||
&mut headers,
|
||||
HeaderName::from_static("mcp-session-id"),
|
||||
session_id_value.to_string(),
|
||||
StreamableHttpClientAdapterError::Header,
|
||||
)?;
|
||||
}
|
||||
|
||||
let body = serde_json::to_vec(&message).map_err(StreamableHttpError::Deserialize)?;
|
||||
let (response, mut body_stream) = self
|
||||
.http_client
|
||||
.http_request_stream(HttpRequestParams {
|
||||
method: "POST".to_string(),
|
||||
url: uri.to_string(),
|
||||
headers: protocol_headers(&headers),
|
||||
body: Some(body.into()),
|
||||
timeout_ms: None,
|
||||
request_id: "buffered-request".to_string(),
|
||||
stream_response: true,
|
||||
})
|
||||
.await
|
||||
.map_err(StreamableHttpClientAdapterError::from)
|
||||
.map_err(StreamableHttpError::Client)?;
|
||||
|
||||
if response.status == StatusCode::NOT_FOUND.as_u16() && session_id.is_some() {
|
||||
return Err(StreamableHttpError::Client(
|
||||
StreamableHttpClientAdapterError::SessionExpired404,
|
||||
));
|
||||
}
|
||||
if response.status == StatusCode::UNAUTHORIZED.as_u16()
|
||||
&& let Some(header) =
|
||||
response_header(&response.headers, reqwest::header::WWW_AUTHENTICATE)
|
||||
{
|
||||
return Err(StreamableHttpError::AuthRequired(AuthRequiredError {
|
||||
www_authenticate_header: header,
|
||||
}));
|
||||
}
|
||||
if matches!(
|
||||
StatusCode::from_u16(response.status).ok(),
|
||||
Some(StatusCode::ACCEPTED | StatusCode::NO_CONTENT)
|
||||
) {
|
||||
return Ok(StreamableHttpPostResponse::Accepted);
|
||||
}
|
||||
|
||||
let content_type = response_header(&response.headers, CONTENT_TYPE);
|
||||
let session_id = response_header(&response.headers, HEADER_SESSION_ID);
|
||||
match content_type.as_deref() {
|
||||
Some(content_type) if content_type.starts_with(EVENT_STREAM_MIME_TYPE) => {
|
||||
let event_stream = sse_stream_from_body(body_stream);
|
||||
Ok(StreamableHttpPostResponse::Sse(event_stream, session_id))
|
||||
}
|
||||
Some(content_type) if content_type.starts_with(JSON_MIME_TYPE) => {
|
||||
let body = collect_body(&mut body_stream).await?;
|
||||
let message: ServerJsonRpcMessage =
|
||||
serde_json::from_slice(&body).map_err(StreamableHttpError::Deserialize)?;
|
||||
Ok(StreamableHttpPostResponse::Json(message, session_id))
|
||||
}
|
||||
_ => {
|
||||
let body = collect_body(&mut body_stream).await?;
|
||||
let content_type = content_type.unwrap_or_else(|| "missing-content-type".into());
|
||||
Err(StreamableHttpError::UnexpectedContentType(Some(format!(
|
||||
"{content_type}; body: {}",
|
||||
body_preview(String::from_utf8_lossy(&body).to_string())
|
||||
))))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_session(
|
||||
&self,
|
||||
uri: Arc<str>,
|
||||
session: Arc<str>,
|
||||
auth_token: Option<String>,
|
||||
) -> std::result::Result<(), StreamableHttpError<Self::Error>> {
|
||||
let mut headers = self.default_headers.clone();
|
||||
if let Some(auth_token) = auth_token {
|
||||
insert_header(
|
||||
&mut headers,
|
||||
AUTHORIZATION,
|
||||
format!("Bearer {auth_token}"),
|
||||
StreamableHttpClientAdapterError::Header,
|
||||
)?;
|
||||
}
|
||||
insert_header(
|
||||
&mut headers,
|
||||
HeaderName::from_static("mcp-session-id"),
|
||||
session.to_string(),
|
||||
StreamableHttpClientAdapterError::Header,
|
||||
)?;
|
||||
|
||||
let response = self
|
||||
.http_client
|
||||
.http_request(HttpRequestParams {
|
||||
method: "DELETE".to_string(),
|
||||
url: uri.to_string(),
|
||||
headers: protocol_headers(&headers),
|
||||
body: None,
|
||||
timeout_ms: None,
|
||||
request_id: "buffered-request".to_string(),
|
||||
stream_response: false,
|
||||
})
|
||||
.await
|
||||
.map_err(StreamableHttpClientAdapterError::from)
|
||||
.map_err(StreamableHttpError::Client)?;
|
||||
|
||||
if response.status == StatusCode::METHOD_NOT_ALLOWED.as_u16() {
|
||||
return Ok(());
|
||||
}
|
||||
if !status_is_success(response.status) {
|
||||
return Err(StreamableHttpError::UnexpectedServerResponse(
|
||||
format!("DELETE returned HTTP {}", response.status).into(),
|
||||
));
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_stream(
|
||||
&self,
|
||||
uri: Arc<str>,
|
||||
session_id: Arc<str>,
|
||||
last_event_id: Option<String>,
|
||||
auth_token: Option<String>,
|
||||
) -> std::result::Result<
|
||||
BoxStream<'static, std::result::Result<Sse, sse_stream::Error>>,
|
||||
StreamableHttpError<Self::Error>,
|
||||
> {
|
||||
let mut headers = self.default_headers.clone();
|
||||
insert_header(
|
||||
&mut headers,
|
||||
ACCEPT,
|
||||
[EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "),
|
||||
StreamableHttpClientAdapterError::Header,
|
||||
)?;
|
||||
insert_header(
|
||||
&mut headers,
|
||||
HeaderName::from_static("mcp-session-id"),
|
||||
session_id.to_string(),
|
||||
StreamableHttpClientAdapterError::Header,
|
||||
)?;
|
||||
if let Some(last_event_id) = last_event_id {
|
||||
insert_header(
|
||||
&mut headers,
|
||||
HeaderName::from_static("last-event-id"),
|
||||
last_event_id,
|
||||
StreamableHttpClientAdapterError::Header,
|
||||
)?;
|
||||
}
|
||||
if let Some(auth_token) = auth_token {
|
||||
insert_header(
|
||||
&mut headers,
|
||||
AUTHORIZATION,
|
||||
format!("Bearer {auth_token}"),
|
||||
StreamableHttpClientAdapterError::Header,
|
||||
)?;
|
||||
}
|
||||
|
||||
let (response, body_stream) = self
|
||||
.http_client
|
||||
.http_request_stream(HttpRequestParams {
|
||||
method: "GET".to_string(),
|
||||
url: uri.to_string(),
|
||||
headers: protocol_headers(&headers),
|
||||
body: None,
|
||||
timeout_ms: None,
|
||||
request_id: "buffered-request".to_string(),
|
||||
stream_response: true,
|
||||
})
|
||||
.await
|
||||
.map_err(StreamableHttpClientAdapterError::from)
|
||||
.map_err(StreamableHttpError::Client)?;
|
||||
|
||||
if response.status == StatusCode::METHOD_NOT_ALLOWED.as_u16() {
|
||||
return Err(StreamableHttpError::ServerDoesNotSupportSse);
|
||||
}
|
||||
if response.status == StatusCode::NOT_FOUND.as_u16() {
|
||||
return Err(StreamableHttpError::Client(
|
||||
StreamableHttpClientAdapterError::SessionExpired404,
|
||||
));
|
||||
}
|
||||
if !status_is_success(response.status) {
|
||||
return Err(StreamableHttpError::UnexpectedServerResponse(
|
||||
format!("GET returned HTTP {}", response.status).into(),
|
||||
));
|
||||
}
|
||||
|
||||
match response_header(&response.headers, CONTENT_TYPE).as_deref() {
|
||||
Some(content_type) if is_streamable_http_content_type(content_type) => {}
|
||||
Some(content_type) => {
|
||||
return Err(StreamableHttpError::UnexpectedContentType(Some(
|
||||
content_type.to_string(),
|
||||
)));
|
||||
}
|
||||
None => {
|
||||
return Err(StreamableHttpError::UnexpectedContentType(None));
|
||||
}
|
||||
}
|
||||
|
||||
Ok(sse_stream_from_body(body_stream))
|
||||
}
|
||||
}
|
||||
|
||||
fn body_preview(body: impl Into<String>) -> String {
|
||||
let mut body_preview = body.into();
|
||||
let body_len = body_preview.len();
|
||||
if body_len > NON_JSON_RESPONSE_BODY_PREVIEW_BYTES {
|
||||
let mut boundary = NON_JSON_RESPONSE_BODY_PREVIEW_BYTES;
|
||||
while !body_preview.is_char_boundary(boundary) {
|
||||
boundary = boundary.saturating_sub(1);
|
||||
}
|
||||
body_preview.truncate(boundary);
|
||||
body_preview.push_str(&format!(
|
||||
"... (truncated {} bytes)",
|
||||
body_len.saturating_sub(boundary)
|
||||
));
|
||||
}
|
||||
body_preview
|
||||
}
|
||||
|
||||
fn insert_header<Error>(
|
||||
headers: &mut HeaderMap,
|
||||
name: HeaderName,
|
||||
value: String,
|
||||
map_error: impl FnOnce(String) -> Error,
|
||||
) -> std::result::Result<(), StreamableHttpError<Error>>
|
||||
where
|
||||
Error: std::error::Error + Send + Sync + 'static,
|
||||
{
|
||||
let value = reqwest::header::HeaderValue::from_str(&value)
|
||||
.map_err(|error| StreamableHttpError::Client(map_error(error.to_string())))?;
|
||||
headers.insert(name, value);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn is_streamable_http_content_type(content_type: &str) -> bool {
|
||||
content_type
|
||||
.as_bytes()
|
||||
.starts_with(EVENT_STREAM_MIME_TYPE.as_bytes())
|
||||
|| content_type
|
||||
.as_bytes()
|
||||
.starts_with(JSON_MIME_TYPE.as_bytes())
|
||||
}
|
||||
|
||||
fn protocol_headers(headers: &HeaderMap) -> Vec<HttpHeader> {
|
||||
headers
|
||||
.iter()
|
||||
.filter_map(|(name, value)| {
|
||||
Some(HttpHeader {
|
||||
name: name.as_str().to_string(),
|
||||
value: value.to_str().ok()?.to_string(),
|
||||
})
|
||||
})
|
||||
.collect()
|
||||
}
|
||||
|
||||
fn response_header(headers: &[HttpHeader], name: impl AsRef<str>) -> Option<String> {
|
||||
let name = name.as_ref();
|
||||
headers
|
||||
.iter()
|
||||
.find(|header| header.name.eq_ignore_ascii_case(name))
|
||||
.map(|header| header.value.clone())
|
||||
}
|
||||
|
||||
fn status_is_success(status: u16) -> bool {
|
||||
StatusCode::from_u16(status).is_ok_and(|status| status.is_success())
|
||||
}
|
||||
|
||||
async fn collect_body(
|
||||
body_stream: &mut HttpResponseBodyStream,
|
||||
) -> std::result::Result<Vec<u8>, StreamableHttpError<StreamableHttpClientAdapterError>> {
|
||||
let mut body = Vec::new();
|
||||
while let Some(chunk) = body_stream
|
||||
.recv()
|
||||
.await
|
||||
.map_err(StreamableHttpClientAdapterError::from)
|
||||
.map_err(StreamableHttpError::Client)?
|
||||
{
|
||||
body.extend_from_slice(&chunk);
|
||||
}
|
||||
Ok(body)
|
||||
}
|
||||
|
||||
fn sse_stream_from_body(
|
||||
body_stream: HttpResponseBodyStream,
|
||||
) -> BoxStream<'static, std::result::Result<Sse, sse_stream::Error>> {
|
||||
SseStream::from_byte_stream(stream::unfold(body_stream, |mut body_stream| async move {
|
||||
match body_stream.recv().await {
|
||||
Ok(Some(bytes)) => Some((Ok(Bytes::from(bytes)), body_stream)),
|
||||
Ok(None) => None,
|
||||
Err(error) => Some((Err(io::Error::other(error)), body_stream)),
|
||||
}
|
||||
}))
|
||||
.boxed()
|
||||
}
|
||||
@@ -1,6 +1,7 @@
|
||||
mod auth_status;
|
||||
mod elicitation_client_service;
|
||||
mod executor_process_transport;
|
||||
mod http_client_adapter;
|
||||
mod logging_client_handler;
|
||||
mod oauth;
|
||||
mod perform_oauth_login;
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
use std::borrow::Cow;
|
||||
use std::collections::HashMap;
|
||||
use std::ffi::OsString;
|
||||
use std::future::Future;
|
||||
@@ -14,16 +13,12 @@ use anyhow::Result;
|
||||
use anyhow::anyhow;
|
||||
use codex_client::build_reqwest_client_with_custom_ca;
|
||||
use codex_config::types::McpServerEnvVar;
|
||||
use codex_exec_server::HttpClient;
|
||||
use futures::FutureExt;
|
||||
use futures::StreamExt;
|
||||
use futures::future::BoxFuture;
|
||||
use futures::stream::BoxStream;
|
||||
use oauth2::TokenResponse;
|
||||
use reqwest::header::ACCEPT;
|
||||
use reqwest::header::AUTHORIZATION;
|
||||
use reqwest::header::CONTENT_TYPE;
|
||||
use reqwest::header::HeaderMap;
|
||||
use reqwest::header::WWW_AUTHENTICATE;
|
||||
use rmcp::model::CallToolRequestParams;
|
||||
use rmcp::model::CallToolResult;
|
||||
use rmcp::model::ClientNotification;
|
||||
@@ -52,16 +47,11 @@ use rmcp::transport::StreamableHttpClientTransport;
|
||||
use rmcp::transport::auth::AuthClient;
|
||||
use rmcp::transport::auth::AuthError;
|
||||
use rmcp::transport::auth::OAuthState;
|
||||
use rmcp::transport::streamable_http_client::AuthRequiredError;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClient;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpClientTransportConfig;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpError;
|
||||
use rmcp::transport::streamable_http_client::StreamableHttpPostResponse;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value;
|
||||
use sse_stream::Sse;
|
||||
use sse_stream::SseStream;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::sync::watch;
|
||||
@@ -69,6 +59,8 @@ use tokio::time;
|
||||
use tracing::warn;
|
||||
|
||||
use crate::elicitation_client_service::ElicitationClientService;
|
||||
use crate::http_client_adapter::StreamableHttpClientAdapter;
|
||||
use crate::http_client_adapter::StreamableHttpClientAdapterError;
|
||||
use crate::load_oauth_tokens;
|
||||
use crate::oauth::OAuthPersistor;
|
||||
use crate::oauth::StoredOAuthTokens;
|
||||
@@ -79,239 +71,15 @@ use crate::utils::apply_default_headers;
|
||||
use crate::utils::build_default_headers;
|
||||
use codex_config::types::OAuthCredentialsStoreMode;
|
||||
|
||||
const EVENT_STREAM_MIME_TYPE: &str = "text/event-stream";
|
||||
const JSON_MIME_TYPE: &str = "application/json";
|
||||
const HEADER_LAST_EVENT_ID: &str = "Last-Event-Id";
|
||||
const HEADER_SESSION_ID: &str = "Mcp-Session-Id";
|
||||
const NON_JSON_RESPONSE_BODY_PREVIEW_BYTES: usize = 8_192;
|
||||
|
||||
#[derive(Clone)]
|
||||
struct StreamableHttpResponseClient {
|
||||
inner: reqwest::Client,
|
||||
}
|
||||
|
||||
impl StreamableHttpResponseClient {
|
||||
fn new(inner: reqwest::Client) -> Self {
|
||||
Self { inner }
|
||||
}
|
||||
|
||||
fn reqwest_error(
|
||||
error: reqwest::Error,
|
||||
) -> StreamableHttpError<StreamableHttpResponseClientError> {
|
||||
StreamableHttpError::Client(StreamableHttpResponseClientError::from(error))
|
||||
}
|
||||
}
|
||||
|
||||
fn build_http_client(default_headers: &HeaderMap) -> Result<reqwest::Client> {
|
||||
let builder = apply_default_headers(reqwest::Client::builder(), default_headers);
|
||||
Ok(build_reqwest_client_with_custom_ca(builder)?)
|
||||
}
|
||||
|
||||
#[derive(Debug, thiserror::Error)]
|
||||
enum StreamableHttpResponseClientError {
|
||||
#[error("streamable HTTP session expired with 404 Not Found")]
|
||||
SessionExpired404,
|
||||
#[error(transparent)]
|
||||
Reqwest(#[from] reqwest::Error),
|
||||
}
|
||||
|
||||
impl StreamableHttpClient for StreamableHttpResponseClient {
|
||||
type Error = StreamableHttpResponseClientError;
|
||||
|
||||
async fn post_message(
|
||||
&self,
|
||||
uri: Arc<str>,
|
||||
message: rmcp::model::ClientJsonRpcMessage,
|
||||
session_id: Option<Arc<str>>,
|
||||
auth_token: Option<String>,
|
||||
) -> std::result::Result<StreamableHttpPostResponse, StreamableHttpError<Self::Error>> {
|
||||
let mut request = self
|
||||
.inner
|
||||
.post(uri.as_ref())
|
||||
.header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "));
|
||||
if let Some(auth_header) = auth_token {
|
||||
request = request.bearer_auth(auth_header);
|
||||
}
|
||||
if let Some(session_id_value) = session_id.as_ref() {
|
||||
request = request.header(HEADER_SESSION_ID, session_id_value.as_ref());
|
||||
}
|
||||
|
||||
let response = request
|
||||
.json(&message)
|
||||
.send()
|
||||
.await
|
||||
.map_err(StreamableHttpResponseClient::reqwest_error)?;
|
||||
if response.status() == reqwest::StatusCode::NOT_FOUND && session_id.is_some() {
|
||||
return Err(StreamableHttpError::Client(
|
||||
StreamableHttpResponseClientError::SessionExpired404,
|
||||
));
|
||||
}
|
||||
if response.status() == reqwest::StatusCode::UNAUTHORIZED
|
||||
&& let Some(header) = response.headers().get(WWW_AUTHENTICATE)
|
||||
{
|
||||
let header = header
|
||||
.to_str()
|
||||
.map_err(|_| {
|
||||
StreamableHttpError::UnexpectedServerResponse(Cow::Borrowed(
|
||||
"invalid www-authenticate header value",
|
||||
))
|
||||
})?
|
||||
.to_string();
|
||||
return Err(StreamableHttpError::AuthRequired(AuthRequiredError {
|
||||
www_authenticate_header: header,
|
||||
}));
|
||||
}
|
||||
|
||||
let status = response.status();
|
||||
if matches!(
|
||||
status,
|
||||
reqwest::StatusCode::ACCEPTED | reqwest::StatusCode::NO_CONTENT
|
||||
) {
|
||||
return Ok(StreamableHttpPostResponse::Accepted);
|
||||
}
|
||||
|
||||
let content_type = response
|
||||
.headers()
|
||||
.get(CONTENT_TYPE)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(str::to_string);
|
||||
let session_id = response
|
||||
.headers()
|
||||
.get(HEADER_SESSION_ID)
|
||||
.and_then(|value| value.to_str().ok())
|
||||
.map(str::to_string);
|
||||
|
||||
match content_type.as_deref() {
|
||||
Some(ct) if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes()) => {
|
||||
let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed();
|
||||
Ok(StreamableHttpPostResponse::Sse(event_stream, session_id))
|
||||
}
|
||||
Some(ct) if ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => {
|
||||
let message = response
|
||||
.json()
|
||||
.await
|
||||
.map_err(StreamableHttpResponseClient::reqwest_error)?;
|
||||
Ok(StreamableHttpPostResponse::Json(message, session_id))
|
||||
}
|
||||
_ => {
|
||||
let body = response
|
||||
.text()
|
||||
.await
|
||||
.map_err(StreamableHttpResponseClient::reqwest_error)?;
|
||||
let mut body_preview = body;
|
||||
let body_len = body_preview.len();
|
||||
if body_len > NON_JSON_RESPONSE_BODY_PREVIEW_BYTES {
|
||||
let mut boundary = NON_JSON_RESPONSE_BODY_PREVIEW_BYTES;
|
||||
while !body_preview.is_char_boundary(boundary) {
|
||||
boundary = boundary.saturating_sub(1);
|
||||
}
|
||||
body_preview.truncate(boundary);
|
||||
body_preview.push_str(&format!(
|
||||
"... (truncated {} bytes)",
|
||||
body_len.saturating_sub(boundary)
|
||||
));
|
||||
}
|
||||
|
||||
let content_type = content_type.unwrap_or_else(|| "missing-content-type".into());
|
||||
Err(StreamableHttpError::UnexpectedContentType(Some(format!(
|
||||
"{content_type}; body: {body_preview}"
|
||||
))))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn delete_session(
|
||||
&self,
|
||||
uri: Arc<str>,
|
||||
session: Arc<str>,
|
||||
auth_token: Option<String>,
|
||||
) -> std::result::Result<(), StreamableHttpError<Self::Error>> {
|
||||
let mut request_builder = self.inner.delete(uri.as_ref());
|
||||
if let Some(auth_header) = auth_token {
|
||||
request_builder = request_builder.bearer_auth(auth_header);
|
||||
}
|
||||
let response = request_builder
|
||||
.header(HEADER_SESSION_ID, session.as_ref())
|
||||
.send()
|
||||
.await
|
||||
.map_err(StreamableHttpResponseClient::reqwest_error)?;
|
||||
|
||||
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
response
|
||||
.error_for_status()
|
||||
.map_err(StreamableHttpResponseClient::reqwest_error)?;
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn get_stream(
|
||||
&self,
|
||||
uri: Arc<str>,
|
||||
session_id: Arc<str>,
|
||||
last_event_id: Option<String>,
|
||||
auth_token: Option<String>,
|
||||
) -> std::result::Result<
|
||||
BoxStream<'static, std::result::Result<Sse, sse_stream::Error>>,
|
||||
StreamableHttpError<Self::Error>,
|
||||
> {
|
||||
let mut request_builder = self
|
||||
.inner
|
||||
.get(uri.as_ref())
|
||||
.header(ACCEPT, [EVENT_STREAM_MIME_TYPE, JSON_MIME_TYPE].join(", "))
|
||||
.header(HEADER_SESSION_ID, session_id.as_ref());
|
||||
if let Some(last_event_id) = last_event_id {
|
||||
request_builder = request_builder.header(HEADER_LAST_EVENT_ID, last_event_id);
|
||||
}
|
||||
if let Some(auth_header) = auth_token {
|
||||
request_builder = request_builder.bearer_auth(auth_header);
|
||||
}
|
||||
|
||||
let response = request_builder
|
||||
.send()
|
||||
.await
|
||||
.map_err(StreamableHttpResponseClient::reqwest_error)?;
|
||||
if response.status() == reqwest::StatusCode::METHOD_NOT_ALLOWED {
|
||||
return Err(StreamableHttpError::ServerDoesNotSupportSse);
|
||||
}
|
||||
if response.status() == reqwest::StatusCode::NOT_FOUND {
|
||||
return Err(StreamableHttpError::Client(
|
||||
StreamableHttpResponseClientError::SessionExpired404,
|
||||
));
|
||||
}
|
||||
|
||||
let response = response
|
||||
.error_for_status()
|
||||
.map_err(StreamableHttpResponseClient::reqwest_error)?;
|
||||
match response.headers().get(CONTENT_TYPE) {
|
||||
Some(ct)
|
||||
if ct.as_bytes().starts_with(EVENT_STREAM_MIME_TYPE.as_bytes())
|
||||
|| ct.as_bytes().starts_with(JSON_MIME_TYPE.as_bytes()) => {}
|
||||
Some(ct) => {
|
||||
return Err(StreamableHttpError::UnexpectedContentType(Some(
|
||||
String::from_utf8_lossy(ct.as_bytes()).to_string(),
|
||||
)));
|
||||
}
|
||||
None => {
|
||||
return Err(StreamableHttpError::UnexpectedContentType(None));
|
||||
}
|
||||
}
|
||||
|
||||
let event_stream = SseStream::from_byte_stream(response.bytes_stream()).boxed();
|
||||
Ok(event_stream)
|
||||
}
|
||||
}
|
||||
|
||||
enum PendingTransport {
|
||||
Stdio {
|
||||
transport: StdioServerTransport,
|
||||
},
|
||||
StreamableHttp {
|
||||
transport: StreamableHttpClientTransport<StreamableHttpResponseClient>,
|
||||
transport: StreamableHttpClientTransport<StreamableHttpClientAdapter>,
|
||||
},
|
||||
StreamableHttpWithOAuth {
|
||||
transport: StreamableHttpClientTransport<AuthClient<StreamableHttpResponseClient>>,
|
||||
transport: StreamableHttpClientTransport<AuthClient<StreamableHttpClientAdapter>>,
|
||||
oauth_persistor: OAuthPersistor,
|
||||
},
|
||||
}
|
||||
@@ -339,6 +107,7 @@ enum TransportRecipe {
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
},
|
||||
}
|
||||
|
||||
@@ -536,6 +305,7 @@ impl RmcpClient {
|
||||
http_headers: Option<HashMap<String, String>>,
|
||||
env_http_headers: Option<HashMap<String, String>>,
|
||||
store_mode: OAuthCredentialsStoreMode,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> Result<Self> {
|
||||
let transport_recipe = TransportRecipe::StreamableHttp {
|
||||
server_name: server_name.to_string(),
|
||||
@@ -544,6 +314,7 @@ impl RmcpClient {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
http_client,
|
||||
};
|
||||
let transport = Self::create_pending_transport(&transport_recipe).await?;
|
||||
Ok(Self {
|
||||
@@ -895,6 +666,7 @@ impl RmcpClient {
|
||||
http_headers,
|
||||
env_http_headers,
|
||||
store_mode,
|
||||
http_client,
|
||||
} => {
|
||||
let default_headers =
|
||||
build_default_headers(http_headers.clone(), env_http_headers.clone())?;
|
||||
@@ -919,6 +691,7 @@ impl RmcpClient {
|
||||
initial_tokens.clone(),
|
||||
*store_mode,
|
||||
default_headers.clone(),
|
||||
Arc::clone(http_client),
|
||||
)
|
||||
.await
|
||||
{
|
||||
@@ -945,9 +718,11 @@ impl RmcpClient {
|
||||
let http_config =
|
||||
StreamableHttpClientTransportConfig::with_uri(url.clone())
|
||||
.auth_header(access_token);
|
||||
let http_client = build_http_client(&default_headers)?;
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
StreamableHttpResponseClient::new(http_client),
|
||||
StreamableHttpClientAdapter::new(
|
||||
Arc::clone(http_client),
|
||||
default_headers,
|
||||
),
|
||||
http_config,
|
||||
);
|
||||
Ok(PendingTransport::StreamableHttp { transport })
|
||||
@@ -961,10 +736,8 @@ impl RmcpClient {
|
||||
http_config = http_config.auth_header(bearer_token);
|
||||
}
|
||||
|
||||
let http_client = build_http_client(&default_headers)?;
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
StreamableHttpResponseClient::new(http_client),
|
||||
StreamableHttpClientAdapter::new(Arc::clone(http_client), default_headers),
|
||||
http_config,
|
||||
);
|
||||
Ok(PendingTransport::StreamableHttp { transport })
|
||||
@@ -1084,12 +857,12 @@ impl RmcpClient {
|
||||
|
||||
error
|
||||
.error
|
||||
.downcast_ref::<StreamableHttpError<StreamableHttpResponseClientError>>()
|
||||
.downcast_ref::<StreamableHttpError<StreamableHttpClientAdapterError>>()
|
||||
.is_some_and(|error| {
|
||||
matches!(
|
||||
error,
|
||||
StreamableHttpError::Client(
|
||||
StreamableHttpResponseClientError::SessionExpired404
|
||||
StreamableHttpClientAdapterError::SessionExpired404
|
||||
)
|
||||
)
|
||||
})
|
||||
@@ -1156,12 +929,18 @@ async fn create_oauth_transport_and_runtime(
|
||||
initial_tokens: StoredOAuthTokens,
|
||||
credentials_store: OAuthCredentialsStoreMode,
|
||||
default_headers: HeaderMap,
|
||||
http_client: Arc<dyn HttpClient>,
|
||||
) -> Result<(
|
||||
StreamableHttpClientTransport<AuthClient<StreamableHttpResponseClient>>,
|
||||
StreamableHttpClientTransport<AuthClient<StreamableHttpClientAdapter>>,
|
||||
OAuthPersistor,
|
||||
)> {
|
||||
let http_client = build_http_client(&default_headers)?;
|
||||
let mut oauth_state = OAuthState::new(url.to_string(), Some(http_client.clone())).await?;
|
||||
let builder = apply_default_headers(reqwest::Client::builder(), &default_headers);
|
||||
let oauth_metadata_client = build_reqwest_client_with_custom_ca(builder)?;
|
||||
// TODO(aibrahim): teach OAuth bootstrap and refresh to use the same
|
||||
// shared HTTP client abstraction instead of always creating the local
|
||||
// reqwest metadata client here.
|
||||
let mut oauth_state =
|
||||
OAuthState::new(url.to_string(), Some(oauth_metadata_client.clone())).await?;
|
||||
|
||||
oauth_state
|
||||
.set_credentials(
|
||||
@@ -1178,7 +957,10 @@ async fn create_oauth_transport_and_runtime(
|
||||
}
|
||||
};
|
||||
|
||||
let auth_client = AuthClient::new(StreamableHttpResponseClient::new(http_client), manager);
|
||||
let auth_client = AuthClient::new(
|
||||
StreamableHttpClientAdapter::new(http_client, default_headers),
|
||||
manager,
|
||||
);
|
||||
let auth_manager = auth_client.auth_manager.clone();
|
||||
|
||||
let transport = StreamableHttpClientTransport::with_client(
|
||||
|
||||
@@ -1,190 +1,12 @@
|
||||
use std::net::TcpListener;
|
||||
use std::path::PathBuf;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
mod streamable_http_test_support;
|
||||
|
||||
use codex_config::types::OAuthCredentialsStoreMode;
|
||||
use codex_rmcp_client::ElicitationAction;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use codex_utils_cargo_bin::CargoBinError;
|
||||
use futures::FutureExt as _;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rmcp::model::CallToolResult;
|
||||
use rmcp::model::ClientCapabilities;
|
||||
use rmcp::model::ElicitationCapability;
|
||||
use rmcp::model::FormElicitationCapability;
|
||||
use rmcp::model::Implementation;
|
||||
use rmcp::model::InitializeRequestParams;
|
||||
use rmcp::model::ProtocolVersion;
|
||||
use serde_json::json;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::sleep;
|
||||
|
||||
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
|
||||
|
||||
fn streamable_http_server_bin() -> Result<PathBuf, CargoBinError> {
|
||||
codex_utils_cargo_bin::cargo_bin("test_streamable_http_server")
|
||||
}
|
||||
|
||||
fn init_params() -> InitializeRequestParams {
|
||||
InitializeRequestParams {
|
||||
meta: None,
|
||||
capabilities: ClientCapabilities {
|
||||
experimental: None,
|
||||
extensions: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
elicitation: Some(ElicitationCapability {
|
||||
form: Some(FormElicitationCapability {
|
||||
schema_validation: None,
|
||||
}),
|
||||
url: None,
|
||||
}),
|
||||
tasks: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-test".into(),
|
||||
version: "0.0.0-test".into(),
|
||||
title: Some("Codex rmcp recovery test".into()),
|
||||
description: None,
|
||||
icons: None,
|
||||
website_url: None,
|
||||
},
|
||||
protocol_version: ProtocolVersion::V_2025_06_18,
|
||||
}
|
||||
}
|
||||
|
||||
fn expected_echo_result(message: &str) -> CallToolResult {
|
||||
CallToolResult {
|
||||
content: Vec::new(),
|
||||
structured_content: Some(json!({
|
||||
"echo": format!("ECHOING: {message}"),
|
||||
"env": null,
|
||||
})),
|
||||
is_error: Some(false),
|
||||
meta: None,
|
||||
}
|
||||
}
|
||||
|
||||
async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
|
||||
let client = RmcpClient::new_streamable_http_client(
|
||||
"test-streamable-http",
|
||||
&format!("{base_url}/mcp"),
|
||||
Some("test-bearer".to_string()),
|
||||
/*http_headers*/ None,
|
||||
/*env_http_headers*/ None,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
)
|
||||
.await?;
|
||||
|
||||
client
|
||||
.initialize(
|
||||
init_params(),
|
||||
Some(Duration::from_secs(5)),
|
||||
Box::new(|_, _| {
|
||||
async {
|
||||
Ok(ElicitationResponse {
|
||||
action: ElicitationAction::Accept,
|
||||
content: Some(json!({})),
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
.boxed()
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
async fn call_echo_tool(client: &RmcpClient, message: &str) -> anyhow::Result<CallToolResult> {
|
||||
client
|
||||
.call_tool(
|
||||
"echo".to_string(),
|
||||
Some(json!({ "message": message })),
|
||||
/*meta*/ None,
|
||||
Some(Duration::from_secs(5)),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
async fn arm_session_post_failure(
|
||||
base_url: &str,
|
||||
status: u16,
|
||||
remaining: usize,
|
||||
) -> anyhow::Result<()> {
|
||||
let response = reqwest::Client::new()
|
||||
.post(format!("{base_url}{SESSION_POST_FAILURE_CONTROL_PATH}"))
|
||||
.json(&json!({
|
||||
"status": status,
|
||||
"remaining": remaining,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let port = listener.local_addr()?.port();
|
||||
drop(listener);
|
||||
|
||||
let bind_addr = format!("127.0.0.1:{port}");
|
||||
let base_url = format!("http://{bind_addr}");
|
||||
let mut child = Command::new(streamable_http_server_bin()?)
|
||||
.kill_on_drop(true)
|
||||
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
|
||||
.spawn()?;
|
||||
|
||||
wait_for_streamable_http_server(&mut child, &bind_addr, Duration::from_secs(5)).await?;
|
||||
Ok((child, base_url))
|
||||
}
|
||||
|
||||
async fn wait_for_streamable_http_server(
|
||||
server_child: &mut Child,
|
||||
address: &str,
|
||||
timeout: Duration,
|
||||
) -> anyhow::Result<()> {
|
||||
let deadline = Instant::now() + timeout;
|
||||
|
||||
loop {
|
||||
if let Some(status) = server_child.try_wait()? {
|
||||
return Err(anyhow::anyhow!(
|
||||
"streamable HTTP server exited early with status {status}"
|
||||
));
|
||||
}
|
||||
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
if remaining.is_zero() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"timed out waiting for streamable HTTP server at {address}: deadline reached"
|
||||
));
|
||||
}
|
||||
|
||||
match tokio::time::timeout(remaining, TcpStream::connect(address)).await {
|
||||
Ok(Ok(_)) => return Ok(()),
|
||||
Ok(Err(error)) => {
|
||||
if Instant::now() >= deadline {
|
||||
return Err(anyhow::anyhow!(
|
||||
"timed out waiting for streamable HTTP server at {address}: {error}"
|
||||
));
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"timed out waiting for streamable HTTP server at {address}: connect call timed out"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
use streamable_http_test_support::arm_session_post_failure;
|
||||
use streamable_http_test_support::call_echo_tool;
|
||||
use streamable_http_test_support::create_client;
|
||||
use streamable_http_test_support::expected_echo_result;
|
||||
use streamable_http_test_support::spawn_streamable_http_server;
|
||||
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 1)]
|
||||
async fn streamable_http_404_session_expiry_recovers_and_retries_once() -> anyhow::Result<()> {
|
||||
|
||||
@@ -0,0 +1,37 @@
|
||||
//! Integration coverage for the remote Streamable HTTP RMCP path.
|
||||
//!
|
||||
//! These tests exercise the orchestrator-side RMCP adapter against a real
|
||||
//! `exec-server` process so HTTP requests go through the remote runtime path
|
||||
//! instead of direct local `reqwest` calls.
|
||||
|
||||
mod streamable_http_test_support;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
|
||||
use streamable_http_test_support::call_echo_tool;
|
||||
use streamable_http_test_support::create_remote_client;
|
||||
use streamable_http_test_support::expected_echo_result;
|
||||
use streamable_http_test_support::spawn_exec_server;
|
||||
use streamable_http_test_support::spawn_streamable_http_server;
|
||||
|
||||
/// What this tests: the RMCP remote Streamable HTTP adapter can initialize
|
||||
/// a server and call a tool while every MCP HTTP request goes through a real
|
||||
/// exec-server process instead of a direct reqwest transport.
|
||||
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
|
||||
async fn streamable_http_remote_client_round_trips_through_exec_server() -> anyhow::Result<()> {
|
||||
// Phase 1: start the MCP Streamable HTTP test server and a local
|
||||
// exec-server process that will own the HTTP network calls.
|
||||
let (_server, base_url) = spawn_streamable_http_server().await?;
|
||||
let exec_server = spawn_exec_server().await?;
|
||||
|
||||
// Phase 2: create and initialize the RMCP client using the executor-backed
|
||||
// Streamable HTTP transport.
|
||||
let client = create_remote_client(&base_url, exec_server.client.clone()).await?;
|
||||
|
||||
// Phase 3: prove the initialized client can complete a tool call and
|
||||
// preserve the normal RMCP response shape.
|
||||
let result = call_echo_tool(&client, "remote").await?;
|
||||
assert_eq!(result, expected_echo_result("remote"));
|
||||
|
||||
Ok(())
|
||||
}
|
||||
@@ -0,0 +1,314 @@
|
||||
//! Shared helpers for Streamable HTTP RMCP integration tests.
|
||||
//!
|
||||
//! This support module starts the test HTTP server, launches a real
|
||||
//! `exec-server` when remote coverage is needed, and provides small helpers for
|
||||
//! creating RMCP clients and asserting round-trip behavior.
|
||||
|
||||
// This support module is included by multiple integration-test crates. Each
|
||||
// crate uses a different subset of the helpers, so dead-code warnings would
|
||||
// otherwise depend on which test file compiled the module.
|
||||
#![allow(dead_code)]
|
||||
|
||||
use std::net::TcpListener;
|
||||
use std::path::PathBuf;
|
||||
use std::process::Stdio;
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
use std::time::Instant;
|
||||
|
||||
use anyhow::Context as _;
|
||||
use codex_config::types::OAuthCredentialsStoreMode;
|
||||
use codex_exec_server::Environment;
|
||||
use codex_exec_server::ExecServerClient;
|
||||
use codex_exec_server::RemoteExecServerConnectArgs;
|
||||
use codex_rmcp_client::ElicitationAction;
|
||||
use codex_rmcp_client::ElicitationResponse;
|
||||
use codex_rmcp_client::RmcpClient;
|
||||
use codex_utils_cargo_bin::CargoBinError;
|
||||
use futures::FutureExt as _;
|
||||
use pretty_assertions::assert_eq;
|
||||
use rmcp::model::CallToolResult;
|
||||
use rmcp::model::ClientCapabilities;
|
||||
use rmcp::model::ElicitationCapability;
|
||||
use rmcp::model::FormElicitationCapability;
|
||||
use rmcp::model::Implementation;
|
||||
use rmcp::model::InitializeRequestParams;
|
||||
use rmcp::model::ProtocolVersion;
|
||||
use serde_json::json;
|
||||
use tempfile::TempDir;
|
||||
use tokio::io::AsyncBufReadExt;
|
||||
use tokio::io::BufReader;
|
||||
use tokio::net::TcpStream;
|
||||
use tokio::process::Child;
|
||||
use tokio::process::Command;
|
||||
use tokio::time::sleep;
|
||||
|
||||
const SESSION_POST_FAILURE_CONTROL_PATH: &str = "/test/control/session-post-failure";
|
||||
|
||||
fn streamable_http_server_bin() -> Result<PathBuf, CargoBinError> {
|
||||
codex_utils_cargo_bin::cargo_bin("test_streamable_http_server")
|
||||
}
|
||||
|
||||
fn init_params() -> InitializeRequestParams {
|
||||
InitializeRequestParams {
|
||||
meta: None,
|
||||
capabilities: ClientCapabilities {
|
||||
experimental: None,
|
||||
extensions: None,
|
||||
roots: None,
|
||||
sampling: None,
|
||||
elicitation: Some(ElicitationCapability {
|
||||
form: Some(FormElicitationCapability {
|
||||
schema_validation: None,
|
||||
}),
|
||||
url: None,
|
||||
}),
|
||||
tasks: None,
|
||||
},
|
||||
client_info: Implementation {
|
||||
name: "codex-test".into(),
|
||||
version: "0.0.0-test".into(),
|
||||
title: Some("Codex rmcp recovery test".into()),
|
||||
description: None,
|
||||
icons: None,
|
||||
website_url: None,
|
||||
},
|
||||
protocol_version: ProtocolVersion::V_2025_06_18,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) fn expected_echo_result(message: &str) -> CallToolResult {
|
||||
CallToolResult {
|
||||
content: Vec::new(),
|
||||
structured_content: Some(json!({
|
||||
"echo": format!("ECHOING: {message}"),
|
||||
"env": null,
|
||||
})),
|
||||
is_error: Some(false),
|
||||
meta: None,
|
||||
}
|
||||
}
|
||||
|
||||
pub(crate) async fn create_client(base_url: &str) -> anyhow::Result<RmcpClient> {
|
||||
let client = RmcpClient::new_streamable_http_client(
|
||||
"test-streamable-http",
|
||||
&format!("{base_url}/mcp"),
|
||||
Some("test-bearer".to_string()),
|
||||
/*http_headers*/ None,
|
||||
/*env_http_headers*/ None,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
Environment::default_for_tests().get_http_client(),
|
||||
)
|
||||
.await?;
|
||||
|
||||
client
|
||||
.initialize(
|
||||
init_params(),
|
||||
Some(Duration::from_secs(5)),
|
||||
Box::new(|_, _| {
|
||||
async {
|
||||
Ok(ElicitationResponse {
|
||||
action: ElicitationAction::Accept,
|
||||
content: Some(json!({})),
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
.boxed()
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
/// Creates a Streamable HTTP RMCP client that sends traffic through the remote
|
||||
/// runtime HTTP API.
|
||||
pub(crate) async fn create_remote_client(
|
||||
base_url: &str,
|
||||
http_client: ExecServerClient,
|
||||
) -> anyhow::Result<RmcpClient> {
|
||||
let client = RmcpClient::new_streamable_http_client(
|
||||
"test-streamable-http-remote",
|
||||
&format!("{base_url}/mcp"),
|
||||
Some("test-bearer".to_string()),
|
||||
/*http_headers*/ None,
|
||||
/*env_http_headers*/ None,
|
||||
OAuthCredentialsStoreMode::File,
|
||||
Arc::new(http_client),
|
||||
)
|
||||
.await?;
|
||||
|
||||
client
|
||||
.initialize(
|
||||
init_params(),
|
||||
Some(Duration::from_secs(5)),
|
||||
Box::new(|_, _| {
|
||||
async {
|
||||
Ok(ElicitationResponse {
|
||||
action: ElicitationAction::Accept,
|
||||
content: Some(json!({})),
|
||||
meta: None,
|
||||
})
|
||||
}
|
||||
.boxed()
|
||||
}),
|
||||
)
|
||||
.await?;
|
||||
|
||||
Ok(client)
|
||||
}
|
||||
|
||||
pub(crate) async fn call_echo_tool(
|
||||
client: &RmcpClient,
|
||||
message: &str,
|
||||
) -> anyhow::Result<CallToolResult> {
|
||||
client
|
||||
.call_tool(
|
||||
"echo".to_string(),
|
||||
Some(json!({ "message": message })),
|
||||
/*meta*/ None,
|
||||
Some(Duration::from_secs(5)),
|
||||
)
|
||||
.await
|
||||
}
|
||||
|
||||
pub(crate) async fn arm_session_post_failure(
|
||||
base_url: &str,
|
||||
status: u16,
|
||||
remaining: usize,
|
||||
) -> anyhow::Result<()> {
|
||||
let response = reqwest::Client::new()
|
||||
.post(format!("{base_url}{SESSION_POST_FAILURE_CONTROL_PATH}"))
|
||||
.json(&json!({
|
||||
"status": status,
|
||||
"remaining": remaining,
|
||||
}))
|
||||
.send()
|
||||
.await?;
|
||||
|
||||
assert_eq!(response.status(), reqwest::StatusCode::NO_CONTENT);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
pub(crate) async fn spawn_streamable_http_server() -> anyhow::Result<(Child, String)> {
|
||||
let listener = TcpListener::bind("127.0.0.1:0")?;
|
||||
let port = listener.local_addr()?.port();
|
||||
drop(listener);
|
||||
|
||||
let bind_addr = format!("127.0.0.1:{port}");
|
||||
let base_url = format!("http://{bind_addr}");
|
||||
let mut child = Command::new(streamable_http_server_bin()?)
|
||||
.kill_on_drop(true)
|
||||
.env("MCP_STREAMABLE_HTTP_BIND_ADDR", &bind_addr)
|
||||
.spawn()?;
|
||||
|
||||
wait_for_streamable_http_server(&mut child, &bind_addr, Duration::from_secs(5)).await?;
|
||||
Ok((child, base_url))
|
||||
}
|
||||
|
||||
/// Owns the exec-server process used by the remote-client integration test.
|
||||
pub(crate) struct ExecServerProcess {
|
||||
_codex_home: TempDir,
|
||||
child: Child,
|
||||
pub(crate) client: ExecServerClient,
|
||||
}
|
||||
|
||||
impl Drop for ExecServerProcess {
|
||||
/// Stops the local exec-server process best-effort when the test exits.
|
||||
fn drop(&mut self) {
|
||||
let _ = self.child.start_kill();
|
||||
}
|
||||
}
|
||||
|
||||
/// Starts a local exec-server and connects an initialized `ExecServerClient`.
|
||||
pub(crate) async fn spawn_exec_server() -> anyhow::Result<ExecServerProcess> {
|
||||
let codex_home = TempDir::new()?;
|
||||
let mut child = Command::new(codex_utils_cargo_bin::cargo_bin("codex")?)
|
||||
.args(["exec-server", "--listen", "ws://127.0.0.1:0"])
|
||||
.stdin(Stdio::null())
|
||||
.stdout(Stdio::piped())
|
||||
.stderr(Stdio::inherit())
|
||||
.kill_on_drop(true)
|
||||
.env("CODEX_HOME", codex_home.path())
|
||||
.spawn()?;
|
||||
|
||||
let websocket_url = read_exec_server_listen_url(&mut child).await?;
|
||||
let client = ExecServerClient::connect_websocket(RemoteExecServerConnectArgs::new(
|
||||
websocket_url,
|
||||
"rmcp-client-remote-http-test".to_string(),
|
||||
))
|
||||
.await?;
|
||||
|
||||
Ok(ExecServerProcess {
|
||||
_codex_home: codex_home,
|
||||
child,
|
||||
client,
|
||||
})
|
||||
}
|
||||
|
||||
/// Reads the websocket URL printed by `codex exec-server --listen`.
|
||||
async fn read_exec_server_listen_url(child: &mut Child) -> anyhow::Result<String> {
|
||||
let stdout = child
|
||||
.stdout
|
||||
.take()
|
||||
.context("failed to capture exec-server stdout")?;
|
||||
let mut lines = BufReader::new(stdout).lines();
|
||||
let deadline = Instant::now() + Duration::from_secs(10);
|
||||
|
||||
loop {
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
if remaining.is_zero() {
|
||||
anyhow::bail!("timed out waiting for exec-server listen URL");
|
||||
}
|
||||
|
||||
let line = tokio::time::timeout(remaining, lines.next_line())
|
||||
.await
|
||||
.context("timed out waiting for exec-server stdout")??
|
||||
.context("exec-server stdout closed before emitting listen URL")?;
|
||||
let listen_url = line.trim();
|
||||
if listen_url.starts_with("ws://") {
|
||||
return Ok(listen_url.to_string());
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn wait_for_streamable_http_server(
|
||||
server_child: &mut Child,
|
||||
address: &str,
|
||||
timeout: Duration,
|
||||
) -> anyhow::Result<()> {
|
||||
let deadline = Instant::now() + timeout;
|
||||
|
||||
loop {
|
||||
if let Some(status) = server_child.try_wait()? {
|
||||
return Err(anyhow::anyhow!(
|
||||
"streamable HTTP server exited early with status {status}"
|
||||
));
|
||||
}
|
||||
|
||||
let remaining = deadline.saturating_duration_since(Instant::now());
|
||||
if remaining.is_zero() {
|
||||
return Err(anyhow::anyhow!(
|
||||
"timed out waiting for streamable HTTP server at {address}: deadline reached"
|
||||
));
|
||||
}
|
||||
|
||||
match tokio::time::timeout(remaining, TcpStream::connect(address)).await {
|
||||
Ok(Ok(_)) => return Ok(()),
|
||||
Ok(Err(error)) => {
|
||||
if Instant::now() >= deadline {
|
||||
return Err(anyhow::anyhow!(
|
||||
"timed out waiting for streamable HTTP server at {address}: {error}"
|
||||
));
|
||||
}
|
||||
}
|
||||
Err(_) => {
|
||||
return Err(anyhow::anyhow!(
|
||||
"timed out waiting for streamable HTTP server at {address}: connect call timed out"
|
||||
));
|
||||
}
|
||||
}
|
||||
|
||||
sleep(Duration::from_millis(50)).await;
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user