mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
[codex] implement standalone code-mode process host (#30111)
## Summary - implement the standalone `codex-code-mode-host` stdio service - route sessions, cells, delegate requests, responses, and cancellation through a bounded host peer - supervise request, writer, cell-forwarding, actor, and V8 failure boundaries - bound request/session tombstones and fail-stop the connection on invalid protocol state - add host-only duplex protocol tests and local Cargo/Bazel run recipes ## Why This stage makes the host process independently runnable and reviewable before exposing any remote client in Codex. Transport or runtime failure closes the connection and relies on process replacement rather than transactional recovery. ## Stack This is **3 of 4** in the process-owned code-mode session stack. - Depends on #30110 - The final client PR targets this branch ## Validation - `just test -p codex-code-mode-host` — 7 host-only tests passed - `just fix -p codex-code-mode-host` - `just bazel-lock-update` - `just bazel-lock-check` - `just fmt`
This commit is contained in:
committed by
GitHub
Unverified
parent
8ebf71ec25
commit
da78d5fdc5
Generated
+8
@@ -2515,6 +2515,14 @@ dependencies = [
|
||||
[[package]]
|
||||
name = "codex-code-mode-host"
|
||||
version = "0.0.0"
|
||||
dependencies = [
|
||||
"anyhow",
|
||||
"codex-code-mode",
|
||||
"codex-code-mode-protocol",
|
||||
"pretty_assertions",
|
||||
"tokio",
|
||||
"tokio-util",
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "codex-code-mode-protocol"
|
||||
|
||||
@@ -8,5 +8,20 @@ license.workspace = true
|
||||
name = "codex-code-mode-host"
|
||||
path = "src/main.rs"
|
||||
|
||||
[lib]
|
||||
doctest = false
|
||||
name = "codex_code_mode_host"
|
||||
path = "src/lib.rs"
|
||||
|
||||
[lints]
|
||||
workspace = true
|
||||
|
||||
[dependencies]
|
||||
anyhow = { workspace = true }
|
||||
codex-code-mode = { workspace = true }
|
||||
codex-code-mode-protocol = { workspace = true }
|
||||
tokio = { workspace = true, features = ["io-std", "io-util", "macros", "rt", "sync", "time"] }
|
||||
tokio-util = { workspace = true, features = ["rt"] }
|
||||
|
||||
[dev-dependencies]
|
||||
pretty_assertions = { workspace = true }
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
use std::sync::Arc;
|
||||
|
||||
use codex_code_mode_protocol::CellId;
|
||||
use codex_code_mode_protocol::CodeModeNestedToolCall;
|
||||
use codex_code_mode_protocol::CodeModeSessionDelegate;
|
||||
use codex_code_mode_protocol::NotificationFuture;
|
||||
use codex_code_mode_protocol::ToolInvocationFuture;
|
||||
use codex_code_mode_protocol::host::DelegateRequest;
|
||||
use codex_code_mode_protocol::host::DelegateResponse;
|
||||
use codex_code_mode_protocol::host::SessionId;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use crate::peer::HostPeer;
|
||||
|
||||
pub(super) struct RemoteDelegate {
|
||||
session_id: SessionId,
|
||||
peer: Arc<HostPeer>,
|
||||
}
|
||||
|
||||
impl RemoteDelegate {
|
||||
pub(super) fn new(session_id: SessionId, peer: Arc<HostPeer>) -> Self {
|
||||
Self { session_id, peer }
|
||||
}
|
||||
}
|
||||
|
||||
impl CodeModeSessionDelegate for RemoteDelegate {
|
||||
fn invoke_tool<'a>(
|
||||
&'a self,
|
||||
invocation: CodeModeNestedToolCall,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> ToolInvocationFuture<'a> {
|
||||
Box::pin(async move {
|
||||
match self
|
||||
.peer
|
||||
.call(
|
||||
self.session_id.clone(),
|
||||
DelegateRequest::InvokeTool {
|
||||
invocation: invocation.into(),
|
||||
},
|
||||
cancellation_token,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
DelegateResponse::ToolResult { result } => Ok(result),
|
||||
DelegateResponse::NotificationDelivered => {
|
||||
Err("code-mode client returned an invalid tool result".to_string())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn notify<'a>(
|
||||
&'a self,
|
||||
call_id: String,
|
||||
cell_id: CellId,
|
||||
text: String,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> NotificationFuture<'a> {
|
||||
Box::pin(async move {
|
||||
match self
|
||||
.peer
|
||||
.call(
|
||||
self.session_id.clone(),
|
||||
DelegateRequest::Notify {
|
||||
call_id,
|
||||
cell_id: cell_id.into(),
|
||||
text,
|
||||
},
|
||||
cancellation_token,
|
||||
)
|
||||
.await?
|
||||
{
|
||||
DelegateResponse::NotificationDelivered => Ok(()),
|
||||
DelegateResponse::ToolResult { .. } => {
|
||||
Err("code-mode client returned an invalid notification result".to_string())
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
fn cell_closed(&self, cell_id: &CellId) {
|
||||
self.peer
|
||||
.close_cell(self.session_id.clone(), cell_id.clone());
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,486 @@
|
||||
use codex_code_mode_protocol::host::Capability;
|
||||
use codex_code_mode_protocol::host::CapabilitySet;
|
||||
use codex_code_mode_protocol::host::ClientHello;
|
||||
use codex_code_mode_protocol::host::ClientToHost;
|
||||
use codex_code_mode_protocol::host::EncodedFrame;
|
||||
use codex_code_mode_protocol::host::FramedReader;
|
||||
use codex_code_mode_protocol::host::FramedWriter;
|
||||
use codex_code_mode_protocol::host::HandshakeRejectReason;
|
||||
use codex_code_mode_protocol::host::HostHello;
|
||||
use codex_code_mode_protocol::host::HostRequest;
|
||||
use codex_code_mode_protocol::host::HostResponse;
|
||||
use codex_code_mode_protocol::host::HostToClient;
|
||||
use codex_code_mode_protocol::host::ProtocolVersion;
|
||||
use codex_code_mode_protocol::host::RequestId;
|
||||
use codex_code_mode_protocol::host::SessionId;
|
||||
use codex_code_mode_protocol::host::SupportedProtocolVersions;
|
||||
use codex_code_mode_protocol::host::WireExecuteRequest;
|
||||
use codex_code_mode_protocol::host::WireResult;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use super::HostState;
|
||||
use super::MAX_ACTIVE_CELLS;
|
||||
use super::MAX_IN_FLIGHT_REQUESTS;
|
||||
use super::MAX_RECENT_REQUEST_IDS;
|
||||
use super::RequestKind;
|
||||
use super::RequestRegistry;
|
||||
use super::SeenSessionIds;
|
||||
use super::peer::HostPeer;
|
||||
use super::run;
|
||||
|
||||
fn client_hello(
|
||||
versions: impl IntoIterator<Item = ProtocolVersion>,
|
||||
required_capabilities: CapabilitySet,
|
||||
) -> ClientToHost {
|
||||
ClientToHost::ClientHello(
|
||||
ClientHello::new(
|
||||
SupportedProtocolVersions::try_new(versions).expect("supported versions"),
|
||||
required_capabilities,
|
||||
CapabilitySet::empty(),
|
||||
)
|
||||
.expect("client hello"),
|
||||
)
|
||||
}
|
||||
|
||||
fn session_id(value: &str) -> SessionId {
|
||||
SessionId::new(value).expect("session ID")
|
||||
}
|
||||
|
||||
fn request_id(value: i64) -> RequestId {
|
||||
RequestId::new(value)
|
||||
}
|
||||
|
||||
async fn decode_frame(frame: EncodedFrame) -> HostToClient {
|
||||
let (reader, writer) = tokio::io::duplex(/*max_buf_size*/ 4096);
|
||||
let writer = tokio::spawn(async move {
|
||||
FramedWriter::new(writer)
|
||||
.write_frame(&frame)
|
||||
.await
|
||||
.expect("write encoded frame");
|
||||
});
|
||||
let message = FramedReader::new(reader)
|
||||
.read()
|
||||
.await
|
||||
.expect("read encoded frame")
|
||||
.expect("encoded frame message");
|
||||
writer.await.expect("frame writer task");
|
||||
message
|
||||
}
|
||||
|
||||
fn execute_request(source: &str) -> WireExecuteRequest {
|
||||
WireExecuteRequest {
|
||||
tool_call_id: "call-1".to_string(),
|
||||
enabled_tools: Vec::new(),
|
||||
source: source.to_string(),
|
||||
yield_time_ms: Some(60_000),
|
||||
max_output_tokens: Some(1_000),
|
||||
}
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn handshake_and_multiple_session_lifecycles_are_ordered() {
|
||||
let (host_stream, client_stream) = tokio::io::duplex(/*max_buf_size*/ 4096);
|
||||
let (host_reader, host_writer) = tokio::io::split(host_stream);
|
||||
let (client_reader, client_writer) = tokio::io::split(client_stream);
|
||||
let host = tokio::spawn(run(host_reader, host_writer));
|
||||
let mut reader = FramedReader::new(client_reader);
|
||||
let mut writer = FramedWriter::new(client_writer);
|
||||
|
||||
writer
|
||||
.write(&client_hello([ProtocolVersion::V1], CapabilitySet::empty()))
|
||||
.await
|
||||
.expect("write hello");
|
||||
assert_eq!(
|
||||
reader.read::<HostToClient>().await.expect("read hello"),
|
||||
Some(HostToClient::HostHello(HostHello::new(
|
||||
ProtocolVersion::V1,
|
||||
CapabilitySet::empty(),
|
||||
)))
|
||||
);
|
||||
|
||||
for (request_id, id) in [
|
||||
(request_id(/*value*/ 1), "session-1"),
|
||||
(request_id(/*value*/ 2), "session-2"),
|
||||
] {
|
||||
writer
|
||||
.write(&ClientToHost::Request {
|
||||
id: request_id,
|
||||
request: HostRequest::OpenSession {
|
||||
session_id: session_id(id),
|
||||
},
|
||||
})
|
||||
.await
|
||||
.expect("open session");
|
||||
assert_eq!(
|
||||
reader.read::<HostToClient>().await.expect("session ready"),
|
||||
Some(HostToClient::Response {
|
||||
id: request_id,
|
||||
result: WireResult::Ok {
|
||||
value: HostResponse::SessionReady {
|
||||
session_id: session_id(id),
|
||||
},
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
for (request_id, id) in [
|
||||
(request_id(/*value*/ 3), "session-1"),
|
||||
(request_id(/*value*/ 4), "session-2"),
|
||||
] {
|
||||
writer
|
||||
.write(&ClientToHost::Request {
|
||||
id: request_id,
|
||||
request: HostRequest::ShutdownSession {
|
||||
session_id: session_id(id),
|
||||
},
|
||||
})
|
||||
.await
|
||||
.expect("shutdown session");
|
||||
assert_eq!(
|
||||
reader.read::<HostToClient>().await.expect("session closed"),
|
||||
Some(HostToClient::Response {
|
||||
id: request_id,
|
||||
result: WireResult::Ok {
|
||||
value: HostResponse::SessionClosed {
|
||||
session_id: session_id(id),
|
||||
},
|
||||
},
|
||||
})
|
||||
);
|
||||
}
|
||||
|
||||
drop(writer);
|
||||
drop(reader);
|
||||
host.await.expect("host task").expect("host connection");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn incompatible_or_invalid_handshake_is_rejected() {
|
||||
let (host_stream, client_stream) = tokio::io::duplex(/*max_buf_size*/ 1024);
|
||||
let (host_reader, host_writer) = tokio::io::split(host_stream);
|
||||
let (client_reader, client_writer) = tokio::io::split(client_stream);
|
||||
let host = tokio::spawn(run(host_reader, host_writer));
|
||||
let mut reader = FramedReader::new(client_reader);
|
||||
let mut writer = FramedWriter::new(client_writer);
|
||||
let version_two = ProtocolVersion::new(/*value*/ 2).expect("protocol version");
|
||||
|
||||
writer
|
||||
.write(&client_hello([version_two], CapabilitySet::empty()))
|
||||
.await
|
||||
.expect("write hello");
|
||||
assert_eq!(
|
||||
reader.read::<HostToClient>().await.expect("rejection"),
|
||||
Some(HostToClient::HandshakeRejected {
|
||||
reason: HandshakeRejectReason::NoCompatibleVersion {
|
||||
supported_versions: SupportedProtocolVersions::try_new([ProtocolVersion::V1])
|
||||
.expect("host versions"),
|
||||
},
|
||||
})
|
||||
);
|
||||
host.await.expect("host task").expect("host connection");
|
||||
|
||||
let (host_stream, client_stream) = tokio::io::duplex(/*max_buf_size*/ 1024);
|
||||
let (host_reader, host_writer) = tokio::io::split(host_stream);
|
||||
let (client_reader, client_writer) = tokio::io::split(client_stream);
|
||||
let host = tokio::spawn(run(host_reader, host_writer));
|
||||
let mut reader = FramedReader::new(client_reader);
|
||||
let mut writer = FramedWriter::new(client_writer);
|
||||
writer
|
||||
.write(&ClientToHost::Request {
|
||||
id: request_id(/*value*/ 1),
|
||||
request: HostRequest::OpenSession {
|
||||
session_id: session_id("session-1"),
|
||||
},
|
||||
})
|
||||
.await
|
||||
.expect("write invalid first message");
|
||||
assert_eq!(
|
||||
reader.read::<HostToClient>().await.expect("rejection"),
|
||||
Some(HostToClient::HandshakeRejected {
|
||||
reason: HandshakeRejectReason::InvalidHello {
|
||||
message: "first message must be connection/hello".to_string(),
|
||||
},
|
||||
})
|
||||
);
|
||||
host.await.expect("host task").expect("host connection");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn unsupported_required_capability_is_rejected() {
|
||||
let (host_stream, client_stream) = tokio::io::duplex(/*max_buf_size*/ 1024);
|
||||
let (host_reader, host_writer) = tokio::io::split(host_stream);
|
||||
let (client_reader, client_writer) = tokio::io::split(client_stream);
|
||||
let host = tokio::spawn(run(host_reader, host_writer));
|
||||
let mut reader = FramedReader::new(client_reader);
|
||||
let mut writer = FramedWriter::new(client_writer);
|
||||
let capability = Capability::new("required").expect("capability");
|
||||
|
||||
writer
|
||||
.write(&client_hello(
|
||||
[ProtocolVersion::V1],
|
||||
CapabilitySet::try_new([capability.clone()]).expect("capabilities"),
|
||||
))
|
||||
.await
|
||||
.expect("write hello");
|
||||
assert_eq!(
|
||||
reader.read::<HostToClient>().await.expect("rejection"),
|
||||
Some(HostToClient::HandshakeRejected {
|
||||
reason: HandshakeRejectReason::MissingRequiredCapability { capability },
|
||||
})
|
||||
);
|
||||
host.await.expect("host task").expect("host connection");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn session_id_cannot_be_reused_after_shutdown() {
|
||||
let (host_stream, client_stream) = tokio::io::duplex(/*max_buf_size*/ 2048);
|
||||
let (host_reader, host_writer) = tokio::io::split(host_stream);
|
||||
let (client_reader, client_writer) = tokio::io::split(client_stream);
|
||||
let host = tokio::spawn(run(host_reader, host_writer));
|
||||
let mut reader = FramedReader::new(client_reader);
|
||||
let mut writer = FramedWriter::new(client_writer);
|
||||
writer
|
||||
.write(&client_hello([ProtocolVersion::V1], CapabilitySet::empty()))
|
||||
.await
|
||||
.expect("write hello");
|
||||
reader
|
||||
.read::<HostToClient>()
|
||||
.await
|
||||
.expect("read hello")
|
||||
.expect("host hello");
|
||||
|
||||
let id = session_id("session-1");
|
||||
for (request_id, request) in [
|
||||
(
|
||||
request_id(/*value*/ 1),
|
||||
HostRequest::OpenSession {
|
||||
session_id: id.clone(),
|
||||
},
|
||||
),
|
||||
(
|
||||
request_id(/*value*/ 2),
|
||||
HostRequest::ShutdownSession {
|
||||
session_id: id.clone(),
|
||||
},
|
||||
),
|
||||
] {
|
||||
writer
|
||||
.write(&ClientToHost::Request {
|
||||
id: request_id,
|
||||
request,
|
||||
})
|
||||
.await
|
||||
.expect("session request");
|
||||
reader
|
||||
.read::<HostToClient>()
|
||||
.await
|
||||
.expect("session response")
|
||||
.expect("session response message");
|
||||
}
|
||||
writer
|
||||
.write(&ClientToHost::Request {
|
||||
id: request_id(/*value*/ 3),
|
||||
request: HostRequest::OpenSession { session_id: id },
|
||||
})
|
||||
.await
|
||||
.expect("reuse session ID");
|
||||
assert_eq!(
|
||||
reader.read::<HostToClient>().await.expect("reuse response"),
|
||||
Some(HostToClient::Response {
|
||||
id: request_id(/*value*/ 3),
|
||||
result: WireResult::Err {
|
||||
message: "code-mode session ID `session-1` was reused".to_string(),
|
||||
},
|
||||
})
|
||||
);
|
||||
drop(writer);
|
||||
drop(reader);
|
||||
host.await.expect("host task").expect("host connection");
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn request_cancellation_tombstones_are_bounded() {
|
||||
let mut requests = RequestRegistry::default();
|
||||
let duplicate = request_id(/*value*/ -1);
|
||||
requests
|
||||
.start(duplicate, RequestKind::OpenSession)
|
||||
.expect("start duplicate probe");
|
||||
assert!(requests.start(duplicate, RequestKind::OpenSession).is_err());
|
||||
requests.finish(duplicate);
|
||||
for value in 1..=MAX_RECENT_REQUEST_IDS as i64 + 100 {
|
||||
let id = request_id(value);
|
||||
requests
|
||||
.start(id, RequestKind::Wait)
|
||||
.expect("start request");
|
||||
requests.cancel(id);
|
||||
requests.finish(id);
|
||||
}
|
||||
for value in 10_000..20_000 {
|
||||
requests.cancel(request_id(value));
|
||||
}
|
||||
|
||||
assert!(requests.active.is_empty());
|
||||
assert_eq!(requests.recent.len(), MAX_RECENT_REQUEST_IDS);
|
||||
assert_eq!(requests.recent_order.len(), MAX_RECENT_REQUEST_IDS);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn request_task_panic_disconnects_host() {
|
||||
let (outgoing_tx, _outgoing_rx) = mpsc::channel(/*max_capacity*/ 1);
|
||||
let peer = Arc::new(HostPeer::new(outgoing_tx));
|
||||
let state = HostState {
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
seen_session_ids: Mutex::new(SeenSessionIds::default()),
|
||||
requests: Mutex::new(RequestRegistry::default()),
|
||||
request_tasks: TaskTracker::new(),
|
||||
request_permits: Arc::new(Semaphore::new(MAX_IN_FLIGHT_REQUESTS)),
|
||||
active_cell_permits: Arc::new(Semaphore::new(MAX_ACTIVE_CELLS)),
|
||||
closing: AtomicBool::new(false),
|
||||
peer: Arc::clone(&peer),
|
||||
};
|
||||
let task = state.request_tasks.spawn(async {
|
||||
panic!("request panic probe");
|
||||
});
|
||||
state.supervise_request_task(task);
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(1), peer.disconnected())
|
||||
.await
|
||||
.expect("request panic should disconnect host");
|
||||
assert!(
|
||||
peer.failure()
|
||||
.expect("request failure")
|
||||
.contains("request task failed")
|
||||
);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn execute_request_id_remains_active_until_initial_response() {
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(/*max_capacity*/ 4);
|
||||
let peer = Arc::new(HostPeer::new(outgoing_tx));
|
||||
let state = Arc::new(HostState {
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
seen_session_ids: Mutex::new(SeenSessionIds::default()),
|
||||
requests: Mutex::new(RequestRegistry::default()),
|
||||
request_tasks: TaskTracker::new(),
|
||||
request_permits: Arc::new(Semaphore::new(MAX_IN_FLIGHT_REQUESTS)),
|
||||
active_cell_permits: Arc::new(Semaphore::new(MAX_ACTIVE_CELLS)),
|
||||
closing: AtomicBool::new(false),
|
||||
peer,
|
||||
});
|
||||
let session_id = session_id("session-1");
|
||||
state
|
||||
.open_session(session_id.clone())
|
||||
.expect("open session");
|
||||
let request_id = request_id(/*value*/ 1);
|
||||
|
||||
state
|
||||
.spawn_request(
|
||||
request_id,
|
||||
HostRequest::Execute {
|
||||
session_id: session_id.clone(),
|
||||
request: execute_request("await new Promise(() => {});"),
|
||||
},
|
||||
)
|
||||
.expect("spawn execute request");
|
||||
let started = decode_frame(outgoing_rx.recv().await.expect("execution started frame")).await;
|
||||
let HostToClient::Response {
|
||||
id,
|
||||
result:
|
||||
WireResult::Ok {
|
||||
value: HostResponse::ExecutionStarted { cell_id },
|
||||
},
|
||||
} = started
|
||||
else {
|
||||
panic!("expected execution started response");
|
||||
};
|
||||
assert_eq!(id, request_id);
|
||||
assert!(
|
||||
state
|
||||
.requests
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.active
|
||||
.contains_key(&request_id)
|
||||
);
|
||||
|
||||
state
|
||||
.session(&session_id)
|
||||
.expect("session")
|
||||
.terminate(cell_id.into())
|
||||
.await
|
||||
.expect("terminate cell");
|
||||
state.disconnect().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn active_cell_limit_rejects_execute_without_disconnecting() {
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(/*max_capacity*/ 1);
|
||||
let peer = Arc::new(HostPeer::new(outgoing_tx));
|
||||
let state = HostState {
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
seen_session_ids: Mutex::new(SeenSessionIds::default()),
|
||||
requests: Mutex::new(RequestRegistry::default()),
|
||||
request_tasks: TaskTracker::new(),
|
||||
request_permits: Arc::new(Semaphore::new(MAX_IN_FLIGHT_REQUESTS)),
|
||||
active_cell_permits: Arc::new(Semaphore::new(/*permits*/ 0)),
|
||||
closing: AtomicBool::new(false),
|
||||
peer: Arc::clone(&peer),
|
||||
};
|
||||
let session_id = session_id("session-1");
|
||||
state
|
||||
.open_session(session_id.clone())
|
||||
.expect("open session");
|
||||
let request_id = request_id(/*value*/ 1);
|
||||
|
||||
state
|
||||
.handle_request(
|
||||
request_id,
|
||||
HostRequest::Execute {
|
||||
session_id,
|
||||
request: execute_request("text(\"hello\");"),
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
decode_frame(outgoing_rx.recv().await.expect("execute response frame")).await,
|
||||
HostToClient::Response {
|
||||
id: request_id,
|
||||
result: WireResult::Err {
|
||||
message: "code-mode host has too many active cells".to_string(),
|
||||
},
|
||||
}
|
||||
);
|
||||
assert!(!peer.is_disconnected());
|
||||
state.disconnect().await;
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn cell_forwarding_panic_disconnects_host() {
|
||||
let (outgoing_tx, _outgoing_rx) = mpsc::channel(/*max_capacity*/ 1);
|
||||
let peer = Arc::new(HostPeer::new(outgoing_tx));
|
||||
peer.spawn_critical("cell forwarding", async {
|
||||
panic!("cell forwarding panic probe");
|
||||
});
|
||||
|
||||
tokio::time::timeout(Duration::from_secs(1), peer.disconnected())
|
||||
.await
|
||||
.expect("cell panic should disconnect host");
|
||||
assert!(
|
||||
peer.failure()
|
||||
.expect("cell failure")
|
||||
.contains("cell forwarding task failed")
|
||||
);
|
||||
}
|
||||
use std::collections::HashMap;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::PoisonError;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::time::Duration;
|
||||
@@ -0,0 +1,605 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::HashSet;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex;
|
||||
use std::sync::PoisonError;
|
||||
use std::sync::atomic::AtomicBool;
|
||||
use std::sync::atomic::Ordering;
|
||||
use std::time::Duration;
|
||||
|
||||
use anyhow::Context;
|
||||
use anyhow::Result;
|
||||
use codex_code_mode::InProcessCodeModeSession;
|
||||
use codex_code_mode_protocol::host::CapabilitySet;
|
||||
use codex_code_mode_protocol::host::ClientToHost;
|
||||
use codex_code_mode_protocol::host::EncodedFrame;
|
||||
use codex_code_mode_protocol::host::FramedReader;
|
||||
use codex_code_mode_protocol::host::FramedWriter;
|
||||
use codex_code_mode_protocol::host::HandshakeRejectReason;
|
||||
use codex_code_mode_protocol::host::HostHello;
|
||||
use codex_code_mode_protocol::host::HostRequest;
|
||||
use codex_code_mode_protocol::host::HostResponse;
|
||||
use codex_code_mode_protocol::host::HostToClient;
|
||||
use codex_code_mode_protocol::host::ProtocolVersion;
|
||||
use codex_code_mode_protocol::host::RequestId;
|
||||
use codex_code_mode_protocol::host::SessionId;
|
||||
use codex_code_mode_protocol::host::SupportedProtocolVersions;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
use tokio_util::task::TaskTracker;
|
||||
|
||||
use self::delegate::RemoteDelegate;
|
||||
use self::peer::HostPeer;
|
||||
|
||||
mod delegate;
|
||||
mod peer;
|
||||
|
||||
const MAX_IN_FLIGHT_REQUESTS: usize = 256;
|
||||
const MAX_ACTIVE_CELLS: usize = 128;
|
||||
const MAX_RECENT_REQUEST_IDS: usize = 4096;
|
||||
const MAX_RECENT_SESSION_IDS: usize = 4096;
|
||||
const SHUTDOWN_TIMEOUT: Duration = Duration::from_secs(5);
|
||||
|
||||
/// Runs one code-mode host connection over the process standard streams.
|
||||
pub async fn run_stdio() -> Result<()> {
|
||||
run(tokio::io::stdin(), tokio::io::stdout()).await
|
||||
}
|
||||
|
||||
/// Runs one code-mode host connection over an ordered input/output pair.
|
||||
async fn run<R, W>(reader: R, writer: W) -> Result<()>
|
||||
where
|
||||
R: AsyncRead + Send + Unpin + 'static,
|
||||
W: AsyncWrite + Send + Unpin + 'static,
|
||||
{
|
||||
let mut reader = FramedReader::new(reader);
|
||||
let mut writer = FramedWriter::new(writer);
|
||||
if !negotiate(&mut reader, &mut writer).await? {
|
||||
return Ok(());
|
||||
}
|
||||
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel::<EncodedFrame>(/*max_capacity*/ 128);
|
||||
let peer = Arc::new(HostPeer::new(outgoing_tx));
|
||||
let state = Arc::new(HostState {
|
||||
sessions: Mutex::new(HashMap::new()),
|
||||
seen_session_ids: Mutex::new(SeenSessionIds::default()),
|
||||
requests: Mutex::new(RequestRegistry::default()),
|
||||
request_tasks: TaskTracker::new(),
|
||||
request_permits: Arc::new(Semaphore::new(MAX_IN_FLIGHT_REQUESTS)),
|
||||
active_cell_permits: Arc::new(Semaphore::new(MAX_ACTIVE_CELLS)),
|
||||
closing: AtomicBool::new(false),
|
||||
peer: Arc::clone(&peer),
|
||||
});
|
||||
let writer_disconnected = peer.disconnection_token();
|
||||
let writer_task = tokio::spawn(async move {
|
||||
loop {
|
||||
tokio::select! {
|
||||
_ = writer_disconnected.cancelled() => return Ok::<(), anyhow::Error>(()),
|
||||
frame = outgoing_rx.recv() => {
|
||||
let Some(frame) = frame else {
|
||||
return Ok(());
|
||||
};
|
||||
if let Err(err) = writer.write_frame(&frame).await {
|
||||
return Err(
|
||||
anyhow::Error::new(err)
|
||||
.context("failed to write code-mode host message")
|
||||
);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
});
|
||||
let writer_peer = Arc::clone(&peer);
|
||||
let writer_supervisor = tokio::spawn(async move {
|
||||
match writer_task.await {
|
||||
Ok(Ok(())) if !writer_peer.is_disconnected() => {
|
||||
writer_peer.fail("code-mode writer task exited unexpectedly".to_string());
|
||||
}
|
||||
Ok(Ok(())) => {}
|
||||
Ok(Err(err)) => {
|
||||
writer_peer.fail(format!("code-mode writer task failed: {err:#}"));
|
||||
}
|
||||
Err(err) => {
|
||||
writer_peer.fail(format!("code-mode writer task failed: {err}"));
|
||||
}
|
||||
}
|
||||
});
|
||||
|
||||
let input_result = async {
|
||||
loop {
|
||||
let message = tokio::select! {
|
||||
_ = peer.disconnected() => break,
|
||||
message = reader.read::<ClientToHost>() => message
|
||||
.context("failed to read code-mode client message")?,
|
||||
};
|
||||
let Some(message) = message else {
|
||||
break;
|
||||
};
|
||||
match message {
|
||||
ClientToHost::ClientHello(_) => {
|
||||
anyhow::bail!("received a second code-mode client hello");
|
||||
}
|
||||
ClientToHost::Request { id, request } => {
|
||||
state.spawn_request(id, request)?;
|
||||
}
|
||||
ClientToHost::CancelRequest { id } => {
|
||||
state.cancel_request(id);
|
||||
}
|
||||
ClientToHost::DelegateResponse { id, result } => {
|
||||
peer.complete(id, result.into_result()).await;
|
||||
}
|
||||
}
|
||||
}
|
||||
Ok::<(), anyhow::Error>(())
|
||||
}
|
||||
.await;
|
||||
|
||||
peer.disconnect();
|
||||
if tokio::time::timeout(SHUTDOWN_TIMEOUT, state.disconnect())
|
||||
.await
|
||||
.is_err()
|
||||
{
|
||||
peer.fail("timed out shutting down code-mode host state".to_string());
|
||||
}
|
||||
drop(state);
|
||||
tokio::time::timeout(SHUTDOWN_TIMEOUT, writer_supervisor)
|
||||
.await
|
||||
.context("timed out supervising code-mode writer task")?
|
||||
.context("code-mode writer supervisor task failed")?;
|
||||
let failure = peer.failure();
|
||||
drop(peer);
|
||||
input_result?;
|
||||
if let Some(failure) = failure {
|
||||
anyhow::bail!(failure);
|
||||
}
|
||||
Ok(())
|
||||
}
|
||||
|
||||
async fn negotiate<R, W>(reader: &mut FramedReader<R>, writer: &mut FramedWriter<W>) -> Result<bool>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
let Some(first_message) = reader
|
||||
.read::<ClientToHost>()
|
||||
.await
|
||||
.context("failed to read code-mode client hello")?
|
||||
else {
|
||||
return Ok(false);
|
||||
};
|
||||
let ClientToHost::ClientHello(client_hello) = first_message else {
|
||||
writer
|
||||
.write(&HostToClient::HandshakeRejected {
|
||||
reason: HandshakeRejectReason::InvalidHello {
|
||||
message: "first message must be connection/hello".to_string(),
|
||||
},
|
||||
})
|
||||
.await
|
||||
.context("failed to reject invalid code-mode client hello")?;
|
||||
return Ok(false);
|
||||
};
|
||||
|
||||
let supported_versions = SupportedProtocolVersions::try_new([ProtocolVersion::V1])?;
|
||||
if !client_hello
|
||||
.supported_versions()
|
||||
.contains(ProtocolVersion::V1)
|
||||
{
|
||||
writer
|
||||
.write(&HostToClient::HandshakeRejected {
|
||||
reason: HandshakeRejectReason::NoCompatibleVersion { supported_versions },
|
||||
})
|
||||
.await
|
||||
.context("failed to reject incompatible code-mode client")?;
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
let host_capabilities = CapabilitySet::empty();
|
||||
if let Some(capability) = client_hello
|
||||
.required_capabilities()
|
||||
.iter()
|
||||
.find(|capability| !host_capabilities.contains(capability))
|
||||
{
|
||||
writer
|
||||
.write(&HostToClient::HandshakeRejected {
|
||||
reason: HandshakeRejectReason::MissingRequiredCapability {
|
||||
capability: capability.clone(),
|
||||
},
|
||||
})
|
||||
.await
|
||||
.context("failed to reject unsupported code-mode capability")?;
|
||||
return Ok(false);
|
||||
}
|
||||
|
||||
writer
|
||||
.write(&HostToClient::HostHello(HostHello::new(
|
||||
ProtocolVersion::V1,
|
||||
host_capabilities,
|
||||
)))
|
||||
.await
|
||||
.context("failed to write code-mode host hello")?;
|
||||
Ok(true)
|
||||
}
|
||||
|
||||
struct HostState {
|
||||
sessions: Mutex<HashMap<SessionId, Arc<InProcessCodeModeSession>>>,
|
||||
seen_session_ids: Mutex<SeenSessionIds>,
|
||||
requests: Mutex<RequestRegistry>,
|
||||
request_tasks: TaskTracker,
|
||||
request_permits: Arc<Semaphore>,
|
||||
active_cell_permits: Arc<Semaphore>,
|
||||
closing: AtomicBool,
|
||||
peer: Arc<HostPeer>,
|
||||
}
|
||||
|
||||
impl HostState {
|
||||
fn spawn_request(
|
||||
self: &Arc<Self>,
|
||||
request_id: RequestId,
|
||||
request: HostRequest,
|
||||
) -> Result<(), anyhow::Error> {
|
||||
let cancellation = self
|
||||
.requests
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.start(request_id, RequestKind::from(&request))?;
|
||||
let Ok(permit) = Arc::clone(&self.request_permits).try_acquire_owned() else {
|
||||
self.respond(
|
||||
request_id,
|
||||
Err("code-mode host has too many in-flight requests".to_string()),
|
||||
);
|
||||
self.finish_request(request_id);
|
||||
return Ok(());
|
||||
};
|
||||
let state = Arc::clone(self);
|
||||
let request_task = self.request_tasks.spawn(async move {
|
||||
let _permit = permit;
|
||||
state
|
||||
.handle_request(request_id, request, cancellation)
|
||||
.await;
|
||||
state.finish_request(request_id);
|
||||
});
|
||||
self.supervise_request_task(request_task);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn supervise_request_task(&self, task: tokio::task::JoinHandle<()>) {
|
||||
let peer = Arc::clone(&self.peer);
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = task.await {
|
||||
peer.fail(format!("code-mode request task failed: {err}"));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
async fn handle_request(
|
||||
&self,
|
||||
request_id: RequestId,
|
||||
request: HostRequest,
|
||||
cancellation: CancellationToken,
|
||||
) {
|
||||
if self.closing.load(Ordering::Acquire) {
|
||||
self.respond(
|
||||
request_id,
|
||||
Err("code-mode host is shutting down".to_string()),
|
||||
);
|
||||
return;
|
||||
}
|
||||
match request {
|
||||
HostRequest::OpenSession { session_id } => {
|
||||
let result = self
|
||||
.open_session(session_id.clone())
|
||||
.map(|()| HostResponse::SessionReady { session_id });
|
||||
self.respond(request_id, result);
|
||||
}
|
||||
HostRequest::Execute {
|
||||
session_id,
|
||||
request,
|
||||
} => {
|
||||
if cancellation.is_cancelled() {
|
||||
self.respond(request_id, Err("code-mode request cancelled".to_string()));
|
||||
return;
|
||||
}
|
||||
let request = match request.try_into() {
|
||||
Ok(request) => request,
|
||||
Err(err) => {
|
||||
self.respond(
|
||||
request_id,
|
||||
Err(format!("invalid code-mode execute request: {err}")),
|
||||
);
|
||||
return;
|
||||
}
|
||||
};
|
||||
let session = match self.session(&session_id) {
|
||||
Ok(session) => session,
|
||||
Err(err) => {
|
||||
self.respond(request_id, Err(err));
|
||||
return;
|
||||
}
|
||||
};
|
||||
let Ok(active_cell_permit) =
|
||||
Arc::clone(&self.active_cell_permits).try_acquire_owned()
|
||||
else {
|
||||
self.respond(
|
||||
request_id,
|
||||
Err("code-mode host has too many active cells".to_string()),
|
||||
);
|
||||
return;
|
||||
};
|
||||
let result = session.execute(request).await;
|
||||
match result {
|
||||
Ok(started) => {
|
||||
let cell_id = started.cell_id.clone();
|
||||
self.respond(
|
||||
request_id,
|
||||
Ok(HostResponse::ExecutionStarted {
|
||||
cell_id: cell_id.into(),
|
||||
}),
|
||||
);
|
||||
let initial_response_sent = self.peer.start_cell(
|
||||
session_id,
|
||||
request_id,
|
||||
started,
|
||||
active_cell_permit,
|
||||
);
|
||||
let _ = initial_response_sent.await;
|
||||
}
|
||||
Err(err) => self.respond(request_id, Err(err)),
|
||||
}
|
||||
}
|
||||
HostRequest::Wait {
|
||||
session_id,
|
||||
request,
|
||||
} => {
|
||||
let result = match self.session(&session_id) {
|
||||
Ok(session) => {
|
||||
tokio::select! {
|
||||
biased;
|
||||
_ = cancellation.cancelled() => {
|
||||
Err("code-mode request cancelled".to_string())
|
||||
}
|
||||
result = session.wait(request.into()) => result.map(|outcome| {
|
||||
HostResponse::WaitCompleted {
|
||||
outcome: outcome.into(),
|
||||
}
|
||||
}),
|
||||
}
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
};
|
||||
self.respond(request_id, result);
|
||||
}
|
||||
HostRequest::Terminate {
|
||||
session_id,
|
||||
cell_id,
|
||||
} => {
|
||||
let result = match self.session(&session_id) {
|
||||
Ok(session) => session.terminate(cell_id.into()).await.map(|outcome| {
|
||||
HostResponse::WaitCompleted {
|
||||
outcome: outcome.into(),
|
||||
}
|
||||
}),
|
||||
Err(err) => Err(err),
|
||||
};
|
||||
self.respond(request_id, result);
|
||||
}
|
||||
HostRequest::ShutdownSession { session_id } => {
|
||||
let session = self
|
||||
.sessions
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.remove(&session_id);
|
||||
let result = match session {
|
||||
Some(session) => match session.shutdown().await {
|
||||
Ok(()) => {
|
||||
self.peer.wait_for_session_cells(&session_id).await;
|
||||
Ok(HostResponse::SessionClosed { session_id })
|
||||
}
|
||||
Err(err) => Err(err),
|
||||
},
|
||||
None => Err(format!("unknown code-mode session {session_id}")),
|
||||
};
|
||||
self.respond(request_id, result);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn open_session(&self, session_id: SessionId) -> Result<(), String> {
|
||||
let mut sessions = self.sessions.lock().unwrap_or_else(PoisonError::into_inner);
|
||||
if sessions.contains_key(&session_id) {
|
||||
return Err(format!(
|
||||
"code-mode session ID `{session_id}` is already open"
|
||||
));
|
||||
}
|
||||
if self.closing.load(Ordering::Acquire) {
|
||||
return Err("code-mode host is shutting down".to_string());
|
||||
}
|
||||
if !self
|
||||
.seen_session_ids
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.remember(session_id.clone())
|
||||
{
|
||||
return Err(format!("code-mode session ID `{session_id}` was reused"));
|
||||
}
|
||||
let delegate = Arc::new(RemoteDelegate::new(
|
||||
session_id.clone(),
|
||||
Arc::clone(&self.peer),
|
||||
));
|
||||
let peer = Arc::downgrade(&self.peer);
|
||||
let task_failure_handler = Arc::new(move |reason| {
|
||||
if let Some(peer) = peer.upgrade() {
|
||||
peer.fail(reason);
|
||||
}
|
||||
});
|
||||
sessions.insert(
|
||||
session_id,
|
||||
Arc::new(
|
||||
InProcessCodeModeSession::with_delegate_and_task_failure_handler(
|
||||
delegate,
|
||||
task_failure_handler,
|
||||
),
|
||||
),
|
||||
);
|
||||
Ok(())
|
||||
}
|
||||
|
||||
fn session(&self, session_id: &SessionId) -> Result<Arc<InProcessCodeModeSession>, String> {
|
||||
self.sessions
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.get(session_id)
|
||||
.cloned()
|
||||
.ok_or_else(|| format!("unknown code-mode session {session_id}"))
|
||||
}
|
||||
|
||||
fn respond(&self, id: RequestId, result: Result<HostResponse, String>) {
|
||||
self.peer.respond(id, result);
|
||||
}
|
||||
|
||||
fn cancel_request(&self, request_id: RequestId) {
|
||||
self.requests
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.cancel(request_id);
|
||||
}
|
||||
|
||||
fn finish_request(&self, request_id: RequestId) {
|
||||
self.requests
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.finish(request_id);
|
||||
}
|
||||
|
||||
async fn disconnect(&self) {
|
||||
self.closing.store(true, Ordering::Release);
|
||||
self.requests
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.cancel_all();
|
||||
self.request_tasks.close();
|
||||
self.request_tasks.wait().await;
|
||||
let sessions = self
|
||||
.sessions
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.drain()
|
||||
.map(|(_, session)| session)
|
||||
.collect::<Vec<_>>();
|
||||
for session in sessions {
|
||||
let _ = session.shutdown().await;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy)]
|
||||
enum RequestKind {
|
||||
OpenSession,
|
||||
Execute,
|
||||
Wait,
|
||||
Terminate,
|
||||
ShutdownSession,
|
||||
}
|
||||
|
||||
impl RequestKind {
|
||||
fn from(request: &HostRequest) -> Self {
|
||||
match request {
|
||||
HostRequest::OpenSession { .. } => Self::OpenSession,
|
||||
HostRequest::Execute { .. } => Self::Execute,
|
||||
HostRequest::Wait { .. } => Self::Wait,
|
||||
HostRequest::Terminate { .. } => Self::Terminate,
|
||||
HostRequest::ShutdownSession { .. } => Self::ShutdownSession,
|
||||
}
|
||||
}
|
||||
|
||||
fn is_cancellable(self) -> bool {
|
||||
matches!(self, Self::Execute | Self::Wait)
|
||||
}
|
||||
}
|
||||
|
||||
struct ActiveRequest {
|
||||
kind: RequestKind,
|
||||
cancellation: CancellationToken,
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct RequestRegistry {
|
||||
active: HashMap<RequestId, ActiveRequest>,
|
||||
recent: HashSet<RequestId>,
|
||||
recent_order: VecDeque<RequestId>,
|
||||
}
|
||||
|
||||
impl RequestRegistry {
|
||||
fn start(
|
||||
&mut self,
|
||||
request_id: RequestId,
|
||||
kind: RequestKind,
|
||||
) -> Result<CancellationToken, anyhow::Error> {
|
||||
if self.active.contains_key(&request_id) || self.recent.contains(&request_id) {
|
||||
anyhow::bail!("duplicate code-mode request ID {request_id:?}");
|
||||
}
|
||||
let cancellation = CancellationToken::new();
|
||||
self.active.insert(
|
||||
request_id,
|
||||
ActiveRequest {
|
||||
kind,
|
||||
cancellation: cancellation.clone(),
|
||||
},
|
||||
);
|
||||
Ok(cancellation)
|
||||
}
|
||||
|
||||
fn cancel(&self, request_id: RequestId) {
|
||||
if let Some(request) = self.active.get(&request_id)
|
||||
&& request.kind.is_cancellable()
|
||||
{
|
||||
request.cancellation.cancel();
|
||||
}
|
||||
}
|
||||
|
||||
fn finish(&mut self, request_id: RequestId) {
|
||||
if self.active.remove(&request_id).is_none() {
|
||||
return;
|
||||
}
|
||||
self.recent.insert(request_id);
|
||||
self.recent_order.push_back(request_id);
|
||||
while self.recent_order.len() > MAX_RECENT_REQUEST_IDS {
|
||||
if let Some(expired) = self.recent_order.pop_front() {
|
||||
self.recent.remove(&expired);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
fn cancel_all(&self) {
|
||||
for request in self.active.values() {
|
||||
request.cancellation.cancel();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Default)]
|
||||
struct SeenSessionIds {
|
||||
ids: HashSet<SessionId>,
|
||||
order: VecDeque<SessionId>,
|
||||
}
|
||||
|
||||
impl SeenSessionIds {
|
||||
fn remember(&mut self, session_id: SessionId) -> bool {
|
||||
if !self.ids.insert(session_id.clone()) {
|
||||
return false;
|
||||
}
|
||||
self.order.push_back(session_id);
|
||||
while self.order.len() > MAX_RECENT_SESSION_IDS {
|
||||
if let Some(expired) = self.order.pop_front() {
|
||||
self.ids.remove(&expired);
|
||||
}
|
||||
}
|
||||
true
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "host_tests.rs"]
|
||||
mod tests;
|
||||
@@ -1 +1,4 @@
|
||||
fn main() {}
|
||||
#[tokio::main(flavor = "current_thread")]
|
||||
async fn main() -> anyhow::Result<()> {
|
||||
codex_code_mode_host::run_stdio().await
|
||||
}
|
||||
|
||||
@@ -0,0 +1,541 @@
|
||||
use std::collections::HashMap;
|
||||
use std::collections::VecDeque;
|
||||
use std::sync::Arc;
|
||||
use std::sync::Mutex as StdMutex;
|
||||
use std::sync::PoisonError;
|
||||
use std::sync::atomic::AtomicI64;
|
||||
use std::sync::atomic::Ordering;
|
||||
|
||||
use codex_code_mode_protocol::CellId;
|
||||
use codex_code_mode_protocol::StartedCell;
|
||||
use codex_code_mode_protocol::host::DelegateRequest;
|
||||
use codex_code_mode_protocol::host::DelegateRequestId;
|
||||
use codex_code_mode_protocol::host::DelegateResponse;
|
||||
use codex_code_mode_protocol::host::EncodedFrame;
|
||||
use codex_code_mode_protocol::host::HostToClient;
|
||||
use codex_code_mode_protocol::host::RequestId;
|
||||
use codex_code_mode_protocol::host::SessionId;
|
||||
use codex_code_mode_protocol::host::WireResult;
|
||||
use tokio::sync::Mutex;
|
||||
use tokio::sync::Notify;
|
||||
use tokio::sync::OwnedSemaphorePermit;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
const CELL_MESSAGE_CAPACITY: usize = 128;
|
||||
const MAX_PENDING_DELEGATE_CALLS: usize = 256;
|
||||
|
||||
pub(super) struct HostPeer {
|
||||
outgoing_tx: mpsc::Sender<EncodedFrame>,
|
||||
pending: Mutex<HashMap<DelegateRequestId, PendingDelegate>>,
|
||||
delegate_permits: Arc<Semaphore>,
|
||||
cell_routes: StdMutex<HashMap<(SessionId, CellId), CellRoute>>,
|
||||
cell_routes_changed: Notify,
|
||||
next_request_id: AtomicI64,
|
||||
disconnected: CancellationToken,
|
||||
failure: StdMutex<Option<String>>,
|
||||
}
|
||||
|
||||
struct PendingDelegate {
|
||||
response_tx: oneshot::Sender<Result<DelegateResponse, String>>,
|
||||
dispatched: bool,
|
||||
_permit: OwnedSemaphorePermit,
|
||||
}
|
||||
|
||||
enum CellRoute {
|
||||
Pending(VecDeque<CellMessage>),
|
||||
Active(mpsc::Sender<CellMessage>),
|
||||
}
|
||||
|
||||
enum CellMessage {
|
||||
Delegate {
|
||||
id: DelegateRequestId,
|
||||
request: DelegateRequest,
|
||||
dispatched_tx: oneshot::Sender<Result<(), String>>,
|
||||
},
|
||||
Closed,
|
||||
}
|
||||
|
||||
impl HostPeer {
|
||||
pub(super) fn new(outgoing_tx: mpsc::Sender<EncodedFrame>) -> Self {
|
||||
Self {
|
||||
outgoing_tx,
|
||||
pending: Mutex::new(HashMap::new()),
|
||||
delegate_permits: Arc::new(Semaphore::new(MAX_PENDING_DELEGATE_CALLS)),
|
||||
cell_routes: StdMutex::new(HashMap::new()),
|
||||
cell_routes_changed: Notify::new(),
|
||||
next_request_id: AtomicI64::new(1),
|
||||
disconnected: CancellationToken::new(),
|
||||
failure: StdMutex::new(None),
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn send(&self, message: HostToClient) -> Result<(), PeerSendError> {
|
||||
let frame = EncodedFrame::encode(&message)
|
||||
.map_err(|err| PeerSendError::Payload(err.to_string()))?;
|
||||
self.send_frame(frame)
|
||||
}
|
||||
|
||||
pub(super) fn respond(
|
||||
&self,
|
||||
id: RequestId,
|
||||
result: Result<codex_code_mode_protocol::host::HostResponse, String>,
|
||||
) {
|
||||
let message = HostToClient::Response {
|
||||
id,
|
||||
result: WireResult::from_result(result),
|
||||
};
|
||||
if let Err(PeerSendError::Payload(err)) = self.send(message) {
|
||||
let _ = self.send(HostToClient::Response {
|
||||
id,
|
||||
result: WireResult::Err {
|
||||
message: format!("code-mode host response exceeds the IPC frame limit: {err}"),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
fn initial_response(
|
||||
&self,
|
||||
id: RequestId,
|
||||
result: Result<codex_code_mode_protocol::host::WireRuntimeResponse, String>,
|
||||
) {
|
||||
let message = HostToClient::InitialResponse {
|
||||
id,
|
||||
result: WireResult::from_result(result),
|
||||
};
|
||||
if let Err(PeerSendError::Payload(err)) = self.send(message) {
|
||||
let _ = self.send(HostToClient::InitialResponse {
|
||||
id,
|
||||
result: WireResult::Err {
|
||||
message: format!(
|
||||
"code-mode initial response exceeds the IPC frame limit: {err}"
|
||||
),
|
||||
},
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn call(
|
||||
self: &Arc<Self>,
|
||||
session_id: SessionId,
|
||||
request: DelegateRequest,
|
||||
cancellation_token: CancellationToken,
|
||||
) -> Result<DelegateResponse, String> {
|
||||
if self.disconnected.is_cancelled() {
|
||||
return Err("code-mode client connection closed".to_string());
|
||||
}
|
||||
let Ok(permit) = Arc::clone(&self.delegate_permits).try_acquire_owned() else {
|
||||
return Err("code-mode host has too many pending delegate calls".to_string());
|
||||
};
|
||||
let id = DelegateRequestId::new(self.next_request_id.fetch_add(1, Ordering::Relaxed));
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
self.pending.lock().await.insert(
|
||||
id,
|
||||
PendingDelegate {
|
||||
response_tx,
|
||||
dispatched: false,
|
||||
_permit: permit,
|
||||
},
|
||||
);
|
||||
let mut pending = PendingDelegateRequest::new(Arc::clone(self), id);
|
||||
let cell_id = match &request {
|
||||
DelegateRequest::InvokeTool { invocation } => invocation.cell_id.clone().into(),
|
||||
DelegateRequest::Notify { cell_id, .. } => cell_id.clone().into(),
|
||||
};
|
||||
let (dispatched_tx, dispatched_rx) = oneshot::channel();
|
||||
if let Err(err) = self.route_cell_message(
|
||||
(session_id, cell_id),
|
||||
CellMessage::Delegate {
|
||||
id,
|
||||
request,
|
||||
dispatched_tx,
|
||||
},
|
||||
) {
|
||||
self.pending.lock().await.remove(&id);
|
||||
pending.disarm();
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
let dispatched = tokio::select! {
|
||||
dispatched = dispatched_rx => dispatched.map_err(|_| {
|
||||
"code-mode cell route closed before dispatching delegate request".to_string()
|
||||
})?,
|
||||
_ = self.disconnected.cancelled() => {
|
||||
self.pending.lock().await.remove(&id);
|
||||
pending.disarm();
|
||||
return Err("code-mode client connection closed".to_string());
|
||||
}
|
||||
};
|
||||
if let Err(err) = dispatched {
|
||||
self.pending.lock().await.remove(&id);
|
||||
pending.disarm();
|
||||
return Err(err);
|
||||
}
|
||||
|
||||
tokio::select! {
|
||||
response = response_rx => {
|
||||
pending.disarm();
|
||||
response.map_err(|_| {
|
||||
"code-mode client closed before returning delegate output".to_string()
|
||||
})?
|
||||
}
|
||||
_ = cancellation_token.cancelled() => {
|
||||
if self.remove_pending(id).await.is_some() {
|
||||
let _ = self.send(HostToClient::CancelDelegateRequest { id });
|
||||
}
|
||||
pending.disarm();
|
||||
Err("code mode delegate request cancelled".to_string())
|
||||
}
|
||||
_ = self.disconnected.cancelled() => {
|
||||
self.pending.lock().await.remove(&id);
|
||||
pending.disarm();
|
||||
Err("code-mode client connection closed".to_string())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) async fn complete(
|
||||
&self,
|
||||
id: DelegateRequestId,
|
||||
response: Result<DelegateResponse, String>,
|
||||
) {
|
||||
if let Some(pending) = self.remove_pending(id).await {
|
||||
let _ = pending.response_tx.send(response);
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) fn start_cell(
|
||||
self: &Arc<Self>,
|
||||
session_id: SessionId,
|
||||
request_id: RequestId,
|
||||
started: StartedCell,
|
||||
active_cell_permit: OwnedSemaphorePermit,
|
||||
) -> oneshot::Receiver<()> {
|
||||
let (initial_response_sent_tx, initial_response_sent_rx) = oneshot::channel();
|
||||
let key = (session_id, started.cell_id.clone());
|
||||
let (messages_tx, messages_rx) = mpsc::channel(CELL_MESSAGE_CAPACITY);
|
||||
let previous = self
|
||||
.cell_routes
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.insert(key.clone(), CellRoute::Active(messages_tx.clone()));
|
||||
match previous {
|
||||
Some(CellRoute::Pending(messages)) => {
|
||||
for message in messages {
|
||||
if messages_tx.try_send(message).is_err() {
|
||||
self.disconnect();
|
||||
return initial_response_sent_rx;
|
||||
}
|
||||
}
|
||||
}
|
||||
Some(CellRoute::Active(_)) => {
|
||||
self.disconnect();
|
||||
return initial_response_sent_rx;
|
||||
}
|
||||
None => {}
|
||||
}
|
||||
let peer = Arc::clone(self);
|
||||
self.spawn_critical("cell forwarding", async move {
|
||||
drive_cell(
|
||||
peer,
|
||||
key,
|
||||
request_id,
|
||||
started,
|
||||
messages_rx,
|
||||
initial_response_sent_tx,
|
||||
active_cell_permit,
|
||||
)
|
||||
.await;
|
||||
});
|
||||
initial_response_sent_rx
|
||||
}
|
||||
|
||||
pub(super) fn close_cell(&self, session_id: SessionId, cell_id: CellId) {
|
||||
let _ = self.route_cell_message((session_id, cell_id), CellMessage::Closed);
|
||||
}
|
||||
|
||||
pub(super) fn disconnect(&self) {
|
||||
self.disconnected.cancel();
|
||||
}
|
||||
|
||||
pub(super) fn fail(&self, reason: String) {
|
||||
let mut failure = self.failure.lock().unwrap_or_else(PoisonError::into_inner);
|
||||
if failure.is_none() {
|
||||
*failure = Some(reason);
|
||||
}
|
||||
drop(failure);
|
||||
self.disconnect();
|
||||
}
|
||||
|
||||
pub(super) fn failure(&self) -> Option<String> {
|
||||
self.failure
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.clone()
|
||||
}
|
||||
|
||||
pub(super) fn is_disconnected(&self) -> bool {
|
||||
self.disconnected.is_cancelled()
|
||||
}
|
||||
|
||||
pub(super) async fn disconnected(&self) {
|
||||
self.disconnected.cancelled().await;
|
||||
}
|
||||
|
||||
pub(super) fn disconnection_token(&self) -> CancellationToken {
|
||||
self.disconnected.clone()
|
||||
}
|
||||
|
||||
pub(super) async fn wait_for_session_cells(&self, session_id: &SessionId) {
|
||||
loop {
|
||||
let changed = self.cell_routes_changed.notified();
|
||||
if !self
|
||||
.cell_routes
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.keys()
|
||||
.any(|(route_session_id, _)| route_session_id == session_id)
|
||||
{
|
||||
return;
|
||||
}
|
||||
tokio::select! {
|
||||
_ = changed => {}
|
||||
_ = self.disconnected.cancelled() => return,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn send_delegate_if_pending(
|
||||
&self,
|
||||
id: DelegateRequestId,
|
||||
session_id: SessionId,
|
||||
request: DelegateRequest,
|
||||
dispatched_tx: oneshot::Sender<Result<(), String>>,
|
||||
) {
|
||||
let result = {
|
||||
let mut pending = self.pending.lock().await;
|
||||
let Some(pending) = pending.get_mut(&id) else {
|
||||
let _ = dispatched_tx.send(Err(
|
||||
"code-mode delegate request was cancelled before dispatch".to_string(),
|
||||
));
|
||||
return;
|
||||
};
|
||||
match self.send(HostToClient::DelegateRequest {
|
||||
id,
|
||||
session_id,
|
||||
request,
|
||||
}) {
|
||||
Ok(()) => {
|
||||
pending.dispatched = true;
|
||||
Ok(())
|
||||
}
|
||||
Err(err) => Err(err.to_string()),
|
||||
}
|
||||
};
|
||||
let _ = dispatched_tx.send(result);
|
||||
}
|
||||
|
||||
fn route_cell_message(
|
||||
&self,
|
||||
key: (SessionId, CellId),
|
||||
message: CellMessage,
|
||||
) -> Result<(), String> {
|
||||
use std::collections::hash_map::Entry;
|
||||
|
||||
let result = match self
|
||||
.cell_routes
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.entry(key)
|
||||
{
|
||||
Entry::Occupied(mut entry) => match entry.get_mut() {
|
||||
CellRoute::Pending(messages) if messages.len() < CELL_MESSAGE_CAPACITY => {
|
||||
messages.push_back(message);
|
||||
Ok(())
|
||||
}
|
||||
CellRoute::Pending(_) => Err("code-mode cell message queue is full".to_string()),
|
||||
CellRoute::Active(sender) => sender
|
||||
.try_send(message)
|
||||
.map_err(|_| "code-mode cell message queue is unavailable".to_string()),
|
||||
},
|
||||
Entry::Vacant(entry) => {
|
||||
entry.insert(CellRoute::Pending(VecDeque::from([message])));
|
||||
Ok(())
|
||||
}
|
||||
};
|
||||
if result.is_err() {
|
||||
self.disconnect();
|
||||
}
|
||||
result
|
||||
}
|
||||
|
||||
async fn remove_pending(&self, id: DelegateRequestId) -> Option<PendingDelegate> {
|
||||
self.pending.lock().await.remove(&id)
|
||||
}
|
||||
|
||||
pub(super) fn spawn_critical<F>(self: &Arc<Self>, task_name: &'static str, future: F)
|
||||
where
|
||||
F: std::future::Future<Output = ()> + Send + 'static,
|
||||
{
|
||||
let task = tokio::spawn(future);
|
||||
let peer = Arc::clone(self);
|
||||
tokio::spawn(async move {
|
||||
if let Err(err) = task.await {
|
||||
peer.fail(format!("code-mode {task_name} task failed: {err}"));
|
||||
}
|
||||
});
|
||||
}
|
||||
|
||||
fn send_frame(&self, frame: EncodedFrame) -> Result<(), PeerSendError> {
|
||||
match self.outgoing_tx.try_send(frame) {
|
||||
Ok(()) => Ok(()),
|
||||
Err(mpsc::error::TrySendError::Full(_)) => {
|
||||
self.disconnect();
|
||||
Err(PeerSendError::Unavailable(
|
||||
"code-mode host outgoing queue is full".to_string(),
|
||||
))
|
||||
}
|
||||
Err(mpsc::error::TrySendError::Closed(_)) => {
|
||||
self.disconnect();
|
||||
Err(PeerSendError::Unavailable(
|
||||
"code-mode client connection closed".to_string(),
|
||||
))
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
async fn drive_cell(
|
||||
peer: Arc<HostPeer>,
|
||||
key: (SessionId, CellId),
|
||||
request_id: RequestId,
|
||||
started: StartedCell,
|
||||
mut messages_rx: mpsc::Receiver<CellMessage>,
|
||||
initial_response_sent_tx: oneshot::Sender<()>,
|
||||
_active_cell_permit: OwnedSemaphorePermit,
|
||||
) {
|
||||
let mut initial_response_sent_tx = Some(initial_response_sent_tx);
|
||||
let initial_response = started.initial_response();
|
||||
tokio::pin!(initial_response);
|
||||
let closed = loop {
|
||||
tokio::select! {
|
||||
biased;
|
||||
result = &mut initial_response => {
|
||||
peer.initial_response(request_id, result.map(Into::into));
|
||||
if let Some(initial_response_sent_tx) = initial_response_sent_tx.take() {
|
||||
let _ = initial_response_sent_tx.send(());
|
||||
}
|
||||
break false;
|
||||
}
|
||||
message = messages_rx.recv() => match message {
|
||||
Some(CellMessage::Delegate {
|
||||
id,
|
||||
request,
|
||||
dispatched_tx,
|
||||
}) => {
|
||||
peer.send_delegate_if_pending(id, key.0.clone(), request, dispatched_tx).await;
|
||||
}
|
||||
Some(CellMessage::Closed) | None => break true,
|
||||
},
|
||||
_ = peer.disconnected.cancelled() => {
|
||||
peer.remove_cell_route(&key);
|
||||
return;
|
||||
}
|
||||
}
|
||||
};
|
||||
|
||||
if closed {
|
||||
peer.initial_response(request_id, initial_response.await.map(Into::into));
|
||||
if let Some(initial_response_sent_tx) = initial_response_sent_tx.take() {
|
||||
let _ = initial_response_sent_tx.send(());
|
||||
}
|
||||
} else {
|
||||
loop {
|
||||
tokio::select! {
|
||||
message = messages_rx.recv() => match message {
|
||||
Some(CellMessage::Delegate {
|
||||
id,
|
||||
request,
|
||||
dispatched_tx,
|
||||
}) => {
|
||||
peer.send_delegate_if_pending(id, key.0.clone(), request, dispatched_tx).await;
|
||||
}
|
||||
Some(CellMessage::Closed) | None => break,
|
||||
},
|
||||
_ = peer.disconnected.cancelled() => {
|
||||
peer.remove_cell_route(&key);
|
||||
return;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
let _ = peer.send(HostToClient::CellClosed {
|
||||
session_id: key.0.clone(),
|
||||
cell_id: (&key.1).into(),
|
||||
});
|
||||
peer.remove_cell_route(&key);
|
||||
}
|
||||
|
||||
impl HostPeer {
|
||||
fn remove_cell_route(&self, key: &(SessionId, CellId)) {
|
||||
let removed = self
|
||||
.cell_routes
|
||||
.lock()
|
||||
.unwrap_or_else(PoisonError::into_inner)
|
||||
.remove(key);
|
||||
if removed.is_some() {
|
||||
self.cell_routes_changed.notify_waiters();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
pub(super) enum PeerSendError {
|
||||
Payload(String),
|
||||
Unavailable(String),
|
||||
}
|
||||
|
||||
impl std::fmt::Display for PeerSendError {
|
||||
fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
|
||||
match self {
|
||||
Self::Payload(message) | Self::Unavailable(message) => formatter.write_str(message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
struct PendingDelegateRequest {
|
||||
peer: Arc<HostPeer>,
|
||||
id: Option<DelegateRequestId>,
|
||||
}
|
||||
|
||||
impl PendingDelegateRequest {
|
||||
fn new(peer: Arc<HostPeer>, id: DelegateRequestId) -> Self {
|
||||
Self { peer, id: Some(id) }
|
||||
}
|
||||
|
||||
fn disarm(&mut self) {
|
||||
self.id = None;
|
||||
}
|
||||
}
|
||||
|
||||
impl Drop for PendingDelegateRequest {
|
||||
fn drop(&mut self) {
|
||||
let Some(id) = self.id.take() else {
|
||||
return;
|
||||
};
|
||||
let peer = Arc::clone(&self.peer);
|
||||
tokio::spawn(async move {
|
||||
if let Some(pending) = peer.remove_pending(id).await
|
||||
&& pending.dispatched
|
||||
{
|
||||
let _ = peer.send(HostToClient::CancelDelegateRequest { id });
|
||||
}
|
||||
});
|
||||
}
|
||||
}
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "peer_tests.rs"]
|
||||
mod tests;
|
||||
@@ -0,0 +1,95 @@
|
||||
use std::sync::Arc;
|
||||
use std::time::Duration;
|
||||
|
||||
use codex_code_mode_protocol::CellId;
|
||||
use codex_code_mode_protocol::RuntimeResponse;
|
||||
use codex_code_mode_protocol::StartedCell;
|
||||
use codex_code_mode_protocol::host::DelegateRequest;
|
||||
use codex_code_mode_protocol::host::RequestId;
|
||||
use codex_code_mode_protocol::host::SessionId;
|
||||
use pretty_assertions::assert_eq;
|
||||
use tokio::sync::Semaphore;
|
||||
use tokio::sync::mpsc;
|
||||
use tokio::sync::oneshot;
|
||||
use tokio::sync::oneshot::error::TryRecvError;
|
||||
use tokio_util::sync::CancellationToken;
|
||||
|
||||
use super::HostPeer;
|
||||
use super::MAX_PENDING_DELEGATE_CALLS;
|
||||
|
||||
fn session_id(value: &str) -> SessionId {
|
||||
SessionId::new(value).expect("session ID")
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn start_cell_reports_when_initial_response_is_enqueued() {
|
||||
let (outgoing_tx, mut outgoing_rx) = mpsc::channel(/*max_capacity*/ 4);
|
||||
let peer = Arc::new(HostPeer::new(outgoing_tx));
|
||||
let cell_id = CellId::new("cell-1".to_string());
|
||||
let (response_tx, response_rx) = oneshot::channel();
|
||||
let started = StartedCell::new(cell_id.clone(), response_rx);
|
||||
let active_cell_permits = Arc::new(Semaphore::new(/*permits*/ 1));
|
||||
let active_cell_permit = Arc::clone(&active_cell_permits)
|
||||
.try_acquire_owned()
|
||||
.expect("active cell permit");
|
||||
|
||||
let mut initial_response_sent = peer.start_cell(
|
||||
session_id("session-1"),
|
||||
RequestId::new(/*value*/ 1),
|
||||
started,
|
||||
active_cell_permit,
|
||||
);
|
||||
assert_eq!(initial_response_sent.try_recv(), Err(TryRecvError::Empty));
|
||||
|
||||
response_tx
|
||||
.send(RuntimeResponse::Result {
|
||||
cell_id: cell_id.clone(),
|
||||
content_items: Vec::new(),
|
||||
error_text: None,
|
||||
})
|
||||
.expect("initial response receiver");
|
||||
initial_response_sent
|
||||
.await
|
||||
.expect("initial response completion");
|
||||
outgoing_rx.recv().await.expect("initial response frame");
|
||||
assert_eq!(active_cell_permits.available_permits(), 0);
|
||||
|
||||
peer.close_cell(session_id("session-1"), cell_id);
|
||||
let permit = tokio::time::timeout(
|
||||
Duration::from_secs(1),
|
||||
Arc::clone(&active_cell_permits).acquire_owned(),
|
||||
)
|
||||
.await
|
||||
.expect("cell permit should be released")
|
||||
.expect("cell permit semaphore should remain open");
|
||||
drop(permit);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn pending_delegate_limit_rejects_call_without_disconnecting() {
|
||||
let (outgoing_tx, _outgoing_rx) = mpsc::channel(/*max_capacity*/ 1);
|
||||
let peer = Arc::new(HostPeer::new(outgoing_tx));
|
||||
let permits = Arc::clone(&peer.delegate_permits)
|
||||
.acquire_many_owned(MAX_PENDING_DELEGATE_CALLS as u32)
|
||||
.await
|
||||
.expect("delegate permits");
|
||||
|
||||
let result = peer
|
||||
.call(
|
||||
session_id("session-1"),
|
||||
DelegateRequest::Notify {
|
||||
call_id: "call-1".to_string(),
|
||||
cell_id: CellId::new("cell-1".to_string()).into(),
|
||||
text: "hello".to_string(),
|
||||
},
|
||||
CancellationToken::new(),
|
||||
)
|
||||
.await;
|
||||
|
||||
assert_eq!(
|
||||
result,
|
||||
Err("code-mode host has too many pending delegate calls".to_string())
|
||||
);
|
||||
assert!(!peer.is_disconnected());
|
||||
drop(permits);
|
||||
}
|
||||
@@ -31,6 +31,10 @@ tui-with-exec-server *args:
|
||||
file-search *args:
|
||||
cargo run --bin codex-file-search -- {args}
|
||||
|
||||
# Run the standalone code-mode host from source.
|
||||
code-mode-host *args:
|
||||
cargo run --bin codex-code-mode-host -- {args}
|
||||
|
||||
# Build the CLI and run the app-server test client
|
||||
app-server-test-client *args:
|
||||
cargo build -p codex-cli
|
||||
@@ -107,6 +111,16 @@ bazel-codex *args:
|
||||
bazel-codex *args:
|
||||
bazel run //codex-rs/cli:codex --run_under='cd /d "{{ invocation_directory_native() }}" &&' -- @($args | Select-Object -Skip 1)
|
||||
|
||||
# Build and run the standalone code-mode host from source using Bazel.
|
||||
[no-cd]
|
||||
[unix]
|
||||
bazel-code-mode-host *args:
|
||||
bazel run //codex-rs/code-mode-host:codex-code-mode-host --run_under="cd $PWD &&" -- "$@"
|
||||
|
||||
[windows]
|
||||
bazel-code-mode-host *args:
|
||||
bazel run //codex-rs/code-mode-host:codex-code-mode-host --run_under='cd /d "{{ invocation_directory_native() }}" &&' -- @($args | Select-Object -Skip 1)
|
||||
|
||||
[no-cd]
|
||||
bazel-lock-update:
|
||||
bazel mod deps --lockfile_mode=update
|
||||
|
||||
Reference in New Issue
Block a user