mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
code-mode: define process host wire protocol (#29804)
## Why The process-owned code mode implementation needs an explicit, bounded wire contract before either side depends on it. Keeping framing and message semantics in `codex-code-mode-protocol` gives the client and sidecar one shared source of truth and makes compatibility failures detectable during connection setup. ## What changed - adds a versioned client/host handshake with required and optional capabilities - defines operation requests and responses for session lifecycle and cell control - defines reverse delegate request, response, cancellation, and cell-closure messages - adds a four-byte little-endian length-prefixed JSON codec with a hard frame cap - rejects malformed frames, unknown fields, invalid identifiers, and unsupported protocol states - locks the wire representation down with explicit JSON round-trip tests ## Testing - `just test -p codex-code-mode-protocol` ## Stack Part 1 of 6. Followed by [#29805](https://github.com/openai/codex/pull/29805).
This commit is contained in:
committed by
GitHub
Unverified
parent
f8937b7d86
commit
b3e1c33776
@@ -0,0 +1,101 @@
|
||||
use std::io;
|
||||
use std::mem::size_of;
|
||||
|
||||
use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use tokio::io::AsyncRead;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWrite;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
/// Maximum JSON payload size accepted for one IPC frame.
|
||||
pub const MAX_FRAME_BYTES: usize = 64 * 1024 * 1024;
|
||||
|
||||
/// Decodes JSON messages prefixed by a four-byte little-endian payload length.
|
||||
pub struct FramedReader<R> {
|
||||
reader: R,
|
||||
}
|
||||
|
||||
impl<R> FramedReader<R>
|
||||
where
|
||||
R: AsyncRead + Unpin,
|
||||
{
|
||||
pub fn new(reader: R) -> Self {
|
||||
Self { reader }
|
||||
}
|
||||
|
||||
/// Reads the next frame, returning `None` only for EOF at a frame boundary.
|
||||
pub async fn read<T>(&mut self) -> io::Result<Option<T>>
|
||||
where
|
||||
T: DeserializeOwned,
|
||||
{
|
||||
let mut length_bytes = [0_u8; size_of::<u32>()];
|
||||
if self.reader.read(&mut length_bytes[..1]).await? == 0 {
|
||||
return Ok(None);
|
||||
}
|
||||
self.reader.read_exact(&mut length_bytes[1..]).await?;
|
||||
|
||||
let length = u32::from_le_bytes(length_bytes) as usize;
|
||||
if length > MAX_FRAME_BYTES {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("code-mode IPC frame length {length} exceeds {MAX_FRAME_BYTES} bytes"),
|
||||
));
|
||||
}
|
||||
|
||||
let mut payload = vec![0; length];
|
||||
self.reader.read_exact(&mut payload).await?;
|
||||
serde_json::from_slice(&payload).map(Some).map_err(|err| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("failed to decode code-mode IPC frame: {err}"),
|
||||
)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// Encodes JSON messages with a four-byte little-endian payload length.
|
||||
pub struct FramedWriter<W> {
|
||||
writer: W,
|
||||
}
|
||||
|
||||
impl<W> FramedWriter<W>
|
||||
where
|
||||
W: AsyncWrite + Unpin,
|
||||
{
|
||||
pub fn new(writer: W) -> Self {
|
||||
Self { writer }
|
||||
}
|
||||
|
||||
/// Writes and flushes one complete frame.
|
||||
pub async fn write<T>(&mut self, message: &T) -> io::Result<()>
|
||||
where
|
||||
T: Serialize,
|
||||
{
|
||||
let payload = serde_json::to_vec(message).map_err(|err| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!("failed to encode code-mode IPC frame: {err}"),
|
||||
)
|
||||
})?;
|
||||
if payload.len() > MAX_FRAME_BYTES {
|
||||
return Err(io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
format!(
|
||||
"code-mode IPC frame length {} exceeds {MAX_FRAME_BYTES} bytes",
|
||||
payload.len()
|
||||
),
|
||||
));
|
||||
}
|
||||
let length = u32::try_from(payload.len()).map_err(|_| {
|
||||
io::Error::new(
|
||||
io::ErrorKind::InvalidData,
|
||||
"code-mode IPC frame length exceeds u32",
|
||||
)
|
||||
})?;
|
||||
|
||||
self.writer.write_all(&length.to_le_bytes()).await?;
|
||||
self.writer.write_all(&payload).await?;
|
||||
self.writer.flush().await
|
||||
}
|
||||
}
|
||||
@@ -0,0 +1,104 @@
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde_json::json;
|
||||
use tokio::io::AsyncReadExt;
|
||||
use tokio::io::AsyncWriteExt;
|
||||
|
||||
use super::FramedReader;
|
||||
use super::FramedWriter;
|
||||
use super::MAX_FRAME_BYTES;
|
||||
|
||||
#[tokio::test]
|
||||
async fn frame_wire_format_is_little_endian_length_prefixed_json() {
|
||||
let (writer, mut reader) = tokio::io::duplex(/*max_buf_size*/ 128);
|
||||
let write = tokio::spawn(async move {
|
||||
FramedWriter::new(writer)
|
||||
.write(&json!({"value": 1}))
|
||||
.await
|
||||
.expect("write frame");
|
||||
});
|
||||
|
||||
let mut bytes = Vec::new();
|
||||
reader.read_to_end(&mut bytes).await.expect("read bytes");
|
||||
write.await.expect("writer task");
|
||||
|
||||
let payload = br#"{"value":1}"#;
|
||||
let mut expected = (payload.len() as u32).to_le_bytes().to_vec();
|
||||
expected.extend_from_slice(payload);
|
||||
assert_eq!(bytes, expected);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn fragmented_frame_round_trips() {
|
||||
let value = json!({"type": "session/open", "sessionId": "session-1"});
|
||||
let payload = serde_json::to_vec(&value).expect("serialize");
|
||||
let mut bytes = (payload.len() as u32).to_le_bytes().to_vec();
|
||||
bytes.extend(payload);
|
||||
|
||||
let (mut writer, reader) = tokio::io::duplex(/*max_buf_size*/ 128);
|
||||
let write = tokio::spawn(async move {
|
||||
for byte in bytes {
|
||||
writer.write_all(&[byte]).await.expect("write byte");
|
||||
tokio::task::yield_now().await;
|
||||
}
|
||||
});
|
||||
|
||||
assert_eq!(
|
||||
FramedReader::new(reader)
|
||||
.read::<serde_json::Value>()
|
||||
.await
|
||||
.expect("read frame"),
|
||||
Some(value)
|
||||
);
|
||||
write.await.expect("writer task");
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn eof_is_clean_only_at_a_frame_boundary() {
|
||||
let (writer, reader) = tokio::io::duplex(/*max_buf_size*/ 16);
|
||||
drop(writer);
|
||||
assert_eq!(
|
||||
FramedReader::new(reader)
|
||||
.read::<serde_json::Value>()
|
||||
.await
|
||||
.expect("clean eof"),
|
||||
None
|
||||
);
|
||||
|
||||
let (mut writer, reader) = tokio::io::duplex(/*max_buf_size*/ 16);
|
||||
writer
|
||||
.write_all(&[1, 0])
|
||||
.await
|
||||
.expect("write partial header");
|
||||
drop(writer);
|
||||
let err = FramedReader::new(reader)
|
||||
.read::<serde_json::Value>()
|
||||
.await
|
||||
.expect_err("truncated header");
|
||||
assert_eq!(err.kind(), std::io::ErrorKind::UnexpectedEof);
|
||||
}
|
||||
|
||||
#[tokio::test]
|
||||
async fn oversized_and_malformed_frames_are_rejected() {
|
||||
let (mut writer, reader) = tokio::io::duplex(/*max_buf_size*/ 16);
|
||||
writer
|
||||
.write_all(&((MAX_FRAME_BYTES as u32) + 1).to_le_bytes())
|
||||
.await
|
||||
.expect("write oversized header");
|
||||
let err = FramedReader::new(reader)
|
||||
.read::<serde_json::Value>()
|
||||
.await
|
||||
.expect_err("oversized frame");
|
||||
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
|
||||
|
||||
let (mut writer, reader) = tokio::io::duplex(/*max_buf_size*/ 16);
|
||||
writer
|
||||
.write_all(&(1_u32).to_le_bytes())
|
||||
.await
|
||||
.expect("write length");
|
||||
writer.write_all(b"{").await.expect("write malformed json");
|
||||
let err = FramedReader::new(reader)
|
||||
.read::<serde_json::Value>()
|
||||
.await
|
||||
.expect_err("malformed frame");
|
||||
assert_eq!(err.kind(), std::io::ErrorKind::InvalidData);
|
||||
}
|
||||
@@ -1,21 +1,57 @@
|
||||
use std::fmt::Debug;
|
||||
|
||||
use pretty_assertions::assert_eq;
|
||||
use serde::Serialize;
|
||||
use serde::de::DeserializeOwned;
|
||||
use serde_json::Value;
|
||||
use serde_json::json;
|
||||
|
||||
use super::Capability;
|
||||
use super::CapabilitySet;
|
||||
use super::ClientHello;
|
||||
use super::ClientToHost;
|
||||
use super::DelegateRequest;
|
||||
use super::DelegateRequestId;
|
||||
use super::DelegateResponse;
|
||||
use super::HandshakeRejectReason;
|
||||
use super::HostHello;
|
||||
use super::HostRequest;
|
||||
use super::HostResponse;
|
||||
use super::HostToClient;
|
||||
use super::ProtocolVersion;
|
||||
use super::RequestId;
|
||||
use super::SessionId;
|
||||
use super::SupportedProtocolVersions;
|
||||
use super::WireCellId;
|
||||
use super::WireContentItem;
|
||||
use super::WireExecuteRequest;
|
||||
use super::WireImageDetail;
|
||||
use super::WireNestedToolCall;
|
||||
use super::WireResult;
|
||||
use super::WireRuntimeResponse;
|
||||
use super::WireToolDefinition;
|
||||
use super::WireToolKind;
|
||||
use super::WireToolName;
|
||||
use super::WireWaitOutcome;
|
||||
use super::WireWaitRequest;
|
||||
use crate::ExecuteRequest;
|
||||
|
||||
fn session_id() -> SessionId {
|
||||
SessionId::new("session-1").expect("valid session ID")
|
||||
}
|
||||
|
||||
fn cell_id(value: &str) -> WireCellId {
|
||||
WireCellId::new(value)
|
||||
}
|
||||
|
||||
fn request_id(value: i64) -> RequestId {
|
||||
RequestId::new(value)
|
||||
}
|
||||
|
||||
fn delegate_request_id(value: i64) -> DelegateRequestId {
|
||||
DelegateRequestId::new(value)
|
||||
}
|
||||
|
||||
fn capability(value: &str) -> Capability {
|
||||
Capability::new(value).expect("valid capability")
|
||||
}
|
||||
@@ -25,116 +61,539 @@ fn supported_versions() -> SupportedProtocolVersions {
|
||||
.expect("nonempty unique protocol versions")
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn handshake_wire_contract_is_explicit_and_round_trips() {
|
||||
let client_hello = ClientToHost::ClientHello(
|
||||
ClientHello::new(
|
||||
supported_versions(),
|
||||
CapabilitySet::try_new([capability("required")]).expect("valid required set"),
|
||||
CapabilitySet::try_new([capability("optional")]).expect("valid optional set"),
|
||||
)
|
||||
.expect("disjoint capabilities"),
|
||||
);
|
||||
let client_hello_json = json!({
|
||||
"type": "connection/hello",
|
||||
"supportedVersions": [1],
|
||||
"requiredCapabilities": ["required"],
|
||||
"optionalCapabilities": ["optional"],
|
||||
});
|
||||
fn assert_wire_round_trip<T>(message: T, encoded: Value)
|
||||
where
|
||||
T: Debug + DeserializeOwned + PartialEq + Serialize,
|
||||
{
|
||||
assert_eq!(serde_json::to_value(&message).expect("serialize"), encoded);
|
||||
assert_eq!(
|
||||
serde_json::to_value(&client_hello).expect("serialize"),
|
||||
client_hello_json
|
||||
);
|
||||
assert_eq!(
|
||||
serde_json::from_value::<ClientToHost>(client_hello_json).expect("deserialize"),
|
||||
client_hello
|
||||
);
|
||||
|
||||
let host_hello = HostToClient::HostHello(HostHello::new(
|
||||
ProtocolVersion::V1,
|
||||
CapabilitySet::try_new([capability("required")]).expect("valid capabilities"),
|
||||
));
|
||||
let host_hello_json = json!({
|
||||
"type": "connection/ready",
|
||||
"selectedVersion": 1,
|
||||
"capabilities": ["required"],
|
||||
});
|
||||
assert_eq!(
|
||||
serde_json::to_value(&host_hello).expect("serialize"),
|
||||
host_hello_json
|
||||
);
|
||||
assert_eq!(
|
||||
serde_json::from_value::<HostToClient>(host_hello_json).expect("deserialize"),
|
||||
host_hello
|
||||
);
|
||||
|
||||
let rejected = HostToClient::HandshakeRejected {
|
||||
reason: HandshakeRejectReason::NoCompatibleVersion {
|
||||
supported_versions: supported_versions(),
|
||||
},
|
||||
};
|
||||
let rejected_json = json!({
|
||||
"type": "connection/rejected",
|
||||
"reason": {
|
||||
"type": "noCompatibleVersion",
|
||||
"supportedVersions": [1],
|
||||
},
|
||||
});
|
||||
assert_eq!(
|
||||
serde_json::to_value(&rejected).expect("serialize"),
|
||||
rejected_json
|
||||
);
|
||||
assert_eq!(
|
||||
serde_json::from_value::<HostToClient>(rejected_json).expect("deserialize"),
|
||||
rejected
|
||||
serde_json::from_value::<T>(encoded).expect("deserialize"),
|
||||
message
|
||||
);
|
||||
}
|
||||
|
||||
fn execute_request() -> WireExecuteRequest {
|
||||
WireExecuteRequest {
|
||||
tool_call_id: "call-1".to_string(),
|
||||
enabled_tools: vec![
|
||||
WireToolDefinition {
|
||||
name: "function_tool".to_string(),
|
||||
tool_name: WireToolName {
|
||||
name: "function_tool".to_string(),
|
||||
namespace: None,
|
||||
},
|
||||
description: "function tool".to_string(),
|
||||
kind: WireToolKind::Function,
|
||||
input_schema: Some(json!({ "type": "object" })),
|
||||
output_schema: None,
|
||||
},
|
||||
WireToolDefinition {
|
||||
name: "freeform_tool".to_string(),
|
||||
tool_name: WireToolName {
|
||||
name: "freeform_tool".to_string(),
|
||||
namespace: Some("mcp__sample__".to_string()),
|
||||
},
|
||||
description: "freeform tool".to_string(),
|
||||
kind: WireToolKind::Freeform,
|
||||
input_schema: None,
|
||||
output_schema: Some(json!({ "type": "string" })),
|
||||
},
|
||||
],
|
||||
source: "text('hello');".to_string(),
|
||||
yield_time_ms: Some(25),
|
||||
max_output_tokens: Some(100),
|
||||
}
|
||||
}
|
||||
|
||||
fn content_items() -> Vec<WireContentItem> {
|
||||
vec![
|
||||
WireContentItem::InputText {
|
||||
text: "hello".to_string(),
|
||||
},
|
||||
WireContentItem::InputImage {
|
||||
image_url: "data:image/png;base64,none".to_string(),
|
||||
detail: None,
|
||||
},
|
||||
WireContentItem::InputImage {
|
||||
image_url: "data:image/png;base64,auto".to_string(),
|
||||
detail: Some(WireImageDetail::Auto),
|
||||
},
|
||||
WireContentItem::InputImage {
|
||||
image_url: "data:image/png;base64,low".to_string(),
|
||||
detail: Some(WireImageDetail::Low),
|
||||
},
|
||||
WireContentItem::InputImage {
|
||||
image_url: "data:image/png;base64,high".to_string(),
|
||||
detail: Some(WireImageDetail::High),
|
||||
},
|
||||
WireContentItem::InputImage {
|
||||
image_url: "data:image/png;base64,original".to_string(),
|
||||
detail: Some(WireImageDetail::Original),
|
||||
},
|
||||
]
|
||||
}
|
||||
|
||||
fn content_items_json() -> Value {
|
||||
json!([
|
||||
{ "type": "input_text", "text": "hello" },
|
||||
{ "type": "input_image", "image_url": "data:image/png;base64,none" },
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": "data:image/png;base64,auto",
|
||||
"detail": "auto",
|
||||
},
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": "data:image/png;base64,low",
|
||||
"detail": "low",
|
||||
},
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": "data:image/png;base64,high",
|
||||
"detail": "high",
|
||||
},
|
||||
{
|
||||
"type": "input_image",
|
||||
"image_url": "data:image/png;base64,original",
|
||||
"detail": "original",
|
||||
},
|
||||
])
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn session_lifecycle_wire_contract_is_explicit_and_round_trips() {
|
||||
let client_messages = [
|
||||
fn handshake_v1_variants_are_pinned() {
|
||||
assert_wire_round_trip(
|
||||
ClientToHost::ClientHello(
|
||||
ClientHello::new(
|
||||
supported_versions(),
|
||||
CapabilitySet::try_new([capability("required")]).expect("valid required set"),
|
||||
CapabilitySet::try_new([capability("optional")]).expect("valid optional set"),
|
||||
)
|
||||
.expect("disjoint capabilities"),
|
||||
),
|
||||
json!({
|
||||
"type": "connection/hello",
|
||||
"supportedVersions": [1],
|
||||
"requiredCapabilities": ["required"],
|
||||
"optionalCapabilities": ["optional"],
|
||||
}),
|
||||
);
|
||||
assert_wire_round_trip(
|
||||
HostToClient::HostHello(HostHello::new(
|
||||
ProtocolVersion::V1,
|
||||
CapabilitySet::try_new([capability("required")]).expect("valid capabilities"),
|
||||
)),
|
||||
json!({
|
||||
"type": "connection/ready",
|
||||
"selectedVersion": 1,
|
||||
"capabilities": ["required"],
|
||||
}),
|
||||
);
|
||||
for (reason, encoded) in [
|
||||
(
|
||||
ClientToHost::OpenSession {
|
||||
session_id: session_id(),
|
||||
HandshakeRejectReason::NoCompatibleVersion {
|
||||
supported_versions: supported_versions(),
|
||||
},
|
||||
json!({ "type": "session/open", "sessionId": "session-1" }),
|
||||
json!({
|
||||
"type": "connection/rejected",
|
||||
"reason": {
|
||||
"type": "noCompatibleVersion",
|
||||
"supportedVersions": [1],
|
||||
},
|
||||
}),
|
||||
),
|
||||
(
|
||||
ClientToHost::CloseSession {
|
||||
HandshakeRejectReason::MissingRequiredCapability {
|
||||
capability: capability("required"),
|
||||
},
|
||||
json!({
|
||||
"type": "connection/rejected",
|
||||
"reason": {
|
||||
"type": "missingRequiredCapability",
|
||||
"capability": "required",
|
||||
},
|
||||
}),
|
||||
),
|
||||
(
|
||||
HandshakeRejectReason::InvalidHello {
|
||||
message: "invalid hello".to_string(),
|
||||
},
|
||||
json!({
|
||||
"type": "connection/rejected",
|
||||
"reason": {
|
||||
"type": "invalidHello",
|
||||
"message": "invalid hello",
|
||||
},
|
||||
}),
|
||||
),
|
||||
] {
|
||||
assert_wire_round_trip(HostToClient::HandshakeRejected { reason }, encoded);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn client_to_host_v1_variants_are_pinned() {
|
||||
let execute_request = execute_request();
|
||||
for (id, request, encoded_request) in [
|
||||
(
|
||||
request_id(/*value*/ 1),
|
||||
HostRequest::OpenSession {
|
||||
session_id: session_id(),
|
||||
},
|
||||
json!({ "type": "session/close", "sessionId": "session-1" }),
|
||||
json!({ "method": "session/open", "sessionId": "session-1" }),
|
||||
),
|
||||
];
|
||||
for (message, encoded) in client_messages {
|
||||
assert_eq!(serde_json::to_value(&message).expect("serialize"), encoded);
|
||||
assert_eq!(
|
||||
serde_json::from_value::<ClientToHost>(encoded).expect("deserialize"),
|
||||
message
|
||||
(
|
||||
request_id(/*value*/ 2),
|
||||
HostRequest::Execute {
|
||||
session_id: session_id(),
|
||||
request: execute_request,
|
||||
},
|
||||
json!({
|
||||
"method": "session/execute",
|
||||
"sessionId": "session-1",
|
||||
"request": {
|
||||
"tool_call_id": "call-1",
|
||||
"enabled_tools": [
|
||||
{
|
||||
"name": "function_tool",
|
||||
"tool_name": { "name": "function_tool", "namespace": null },
|
||||
"description": "function tool",
|
||||
"kind": "function",
|
||||
"input_schema": { "type": "object" },
|
||||
"output_schema": null,
|
||||
},
|
||||
{
|
||||
"name": "freeform_tool",
|
||||
"tool_name": {
|
||||
"name": "freeform_tool",
|
||||
"namespace": "mcp__sample__",
|
||||
},
|
||||
"description": "freeform tool",
|
||||
"kind": "freeform",
|
||||
"input_schema": null,
|
||||
"output_schema": { "type": "string" },
|
||||
},
|
||||
],
|
||||
"source": "text('hello');",
|
||||
"yield_time_ms": 25,
|
||||
"max_output_tokens": 100,
|
||||
},
|
||||
}),
|
||||
),
|
||||
(
|
||||
request_id(/*value*/ 3),
|
||||
HostRequest::Wait {
|
||||
session_id: session_id(),
|
||||
request: WireWaitRequest {
|
||||
cell_id: cell_id("cell-1"),
|
||||
yield_time_ms: 50,
|
||||
},
|
||||
},
|
||||
json!({
|
||||
"method": "session/wait",
|
||||
"sessionId": "session-1",
|
||||
"request": { "cell_id": "cell-1", "yield_time_ms": 50 },
|
||||
}),
|
||||
),
|
||||
(
|
||||
request_id(/*value*/ 4),
|
||||
HostRequest::Terminate {
|
||||
session_id: session_id(),
|
||||
cell_id: cell_id("cell-1"),
|
||||
},
|
||||
json!({
|
||||
"method": "session/terminate",
|
||||
"sessionId": "session-1",
|
||||
"cellId": "cell-1",
|
||||
}),
|
||||
),
|
||||
(
|
||||
request_id(/*value*/ 5),
|
||||
HostRequest::ShutdownSession {
|
||||
session_id: session_id(),
|
||||
},
|
||||
json!({ "method": "session/shutdown", "sessionId": "session-1" }),
|
||||
),
|
||||
] {
|
||||
assert_wire_round_trip(
|
||||
ClientToHost::Request { id, request },
|
||||
json!({
|
||||
"type": "operation/request",
|
||||
"id": id,
|
||||
"request": encoded_request,
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
let host_messages = [
|
||||
for (id, result, encoded_result) in [
|
||||
(
|
||||
HostToClient::SessionReady {
|
||||
delegate_request_id(/*value*/ 6),
|
||||
WireResult::Ok {
|
||||
value: DelegateResponse::ToolResult {
|
||||
result: json!({ "answer": 42 }),
|
||||
},
|
||||
},
|
||||
json!({
|
||||
"status": "ok",
|
||||
"value": { "type": "tool/result", "result": { "answer": 42 } },
|
||||
}),
|
||||
),
|
||||
(
|
||||
delegate_request_id(/*value*/ 7),
|
||||
WireResult::Ok {
|
||||
value: DelegateResponse::NotificationDelivered,
|
||||
},
|
||||
json!({
|
||||
"status": "ok",
|
||||
"value": { "type": "notification/delivered" },
|
||||
}),
|
||||
),
|
||||
(
|
||||
delegate_request_id(/*value*/ 8),
|
||||
WireResult::Err {
|
||||
message: "delegate failed".to_string(),
|
||||
},
|
||||
json!({ "status": "error", "message": "delegate failed" }),
|
||||
),
|
||||
] {
|
||||
assert_wire_round_trip(
|
||||
ClientToHost::DelegateResponse { id, result },
|
||||
json!({
|
||||
"type": "delegate/response",
|
||||
"id": id,
|
||||
"result": encoded_result,
|
||||
}),
|
||||
);
|
||||
}
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn host_to_client_v1_variants_are_pinned() {
|
||||
for (id, response, encoded_response) in [
|
||||
(
|
||||
request_id(/*value*/ 1),
|
||||
HostResponse::SessionReady {
|
||||
session_id: session_id(),
|
||||
},
|
||||
json!({ "type": "session/ready", "sessionId": "session-1" }),
|
||||
),
|
||||
(
|
||||
HostToClient::SessionClosed {
|
||||
request_id(/*value*/ 2),
|
||||
HostResponse::ExecutionStarted {
|
||||
cell_id: cell_id("cell-1"),
|
||||
},
|
||||
json!({ "type": "execution/started", "cellId": "cell-1" }),
|
||||
),
|
||||
(
|
||||
request_id(/*value*/ 3),
|
||||
HostResponse::WaitCompleted {
|
||||
outcome: WireWaitOutcome::LiveCell(WireRuntimeResponse::Yielded {
|
||||
cell_id: cell_id("cell-1"),
|
||||
content_items: content_items(),
|
||||
}),
|
||||
},
|
||||
json!({
|
||||
"type": "wait/completed",
|
||||
"outcome": {
|
||||
"LiveCell": {
|
||||
"Yielded": {
|
||||
"cell_id": "cell-1",
|
||||
"content_items": content_items_json(),
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
),
|
||||
(
|
||||
request_id(/*value*/ 4),
|
||||
HostResponse::WaitCompleted {
|
||||
outcome: WireWaitOutcome::MissingCell(WireRuntimeResponse::Result {
|
||||
cell_id: cell_id("missing-cell"),
|
||||
content_items: Vec::new(),
|
||||
error_text: Some("cell not found".to_string()),
|
||||
}),
|
||||
},
|
||||
json!({
|
||||
"type": "wait/completed",
|
||||
"outcome": {
|
||||
"MissingCell": {
|
||||
"Result": {
|
||||
"cell_id": "missing-cell",
|
||||
"content_items": [],
|
||||
"error_text": "cell not found",
|
||||
},
|
||||
},
|
||||
},
|
||||
}),
|
||||
),
|
||||
(
|
||||
request_id(/*value*/ 5),
|
||||
HostResponse::SessionClosed {
|
||||
session_id: session_id(),
|
||||
},
|
||||
json!({ "type": "session/closed", "sessionId": "session-1" }),
|
||||
),
|
||||
];
|
||||
for (message, encoded) in host_messages {
|
||||
assert_eq!(serde_json::to_value(&message).expect("serialize"), encoded);
|
||||
assert_eq!(
|
||||
serde_json::from_value::<HostToClient>(encoded).expect("deserialize"),
|
||||
message
|
||||
] {
|
||||
assert_wire_round_trip(
|
||||
HostToClient::Response {
|
||||
id,
|
||||
result: WireResult::Ok { value: response },
|
||||
},
|
||||
json!({
|
||||
"type": "operation/response",
|
||||
"id": id,
|
||||
"result": { "status": "ok", "value": encoded_response },
|
||||
}),
|
||||
);
|
||||
}
|
||||
assert_wire_round_trip(
|
||||
HostToClient::Response {
|
||||
id: request_id(/*value*/ 6),
|
||||
result: WireResult::Err {
|
||||
message: "operation failed".to_string(),
|
||||
},
|
||||
},
|
||||
json!({
|
||||
"type": "operation/response",
|
||||
"id": 6,
|
||||
"result": { "status": "error", "message": "operation failed" },
|
||||
}),
|
||||
);
|
||||
|
||||
assert_wire_round_trip(
|
||||
HostToClient::InitialResponse {
|
||||
id: request_id(/*value*/ 7),
|
||||
result: WireResult::Ok {
|
||||
value: WireRuntimeResponse::Terminated {
|
||||
cell_id: cell_id("cell-1"),
|
||||
content_items: Vec::new(),
|
||||
},
|
||||
},
|
||||
},
|
||||
json!({
|
||||
"type": "execute/initialResponse",
|
||||
"id": 7,
|
||||
"result": {
|
||||
"status": "ok",
|
||||
"value": {
|
||||
"Terminated": { "cell_id": "cell-1", "content_items": [] },
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
assert_wire_round_trip(
|
||||
HostToClient::InitialResponse {
|
||||
id: request_id(/*value*/ 8),
|
||||
result: WireResult::Err {
|
||||
message: "execution failed".to_string(),
|
||||
},
|
||||
},
|
||||
json!({
|
||||
"type": "execute/initialResponse",
|
||||
"id": 8,
|
||||
"result": { "status": "error", "message": "execution failed" },
|
||||
}),
|
||||
);
|
||||
|
||||
assert_wire_round_trip(
|
||||
HostToClient::DelegateRequest {
|
||||
id: delegate_request_id(/*value*/ 9),
|
||||
session_id: session_id(),
|
||||
request: DelegateRequest::InvokeTool {
|
||||
invocation: WireNestedToolCall {
|
||||
cell_id: cell_id("cell-1"),
|
||||
runtime_tool_call_id: "runtime-call-1".to_string(),
|
||||
tool_name: WireToolName {
|
||||
name: "freeform_tool".to_string(),
|
||||
namespace: Some("mcp__sample__".to_string()),
|
||||
},
|
||||
tool_kind: WireToolKind::Freeform,
|
||||
input: Some(json!({ "value": 1 })),
|
||||
},
|
||||
},
|
||||
},
|
||||
json!({
|
||||
"type": "delegate/request",
|
||||
"id": 9,
|
||||
"sessionId": "session-1",
|
||||
"request": {
|
||||
"type": "tool/invoke",
|
||||
"invocation": {
|
||||
"cell_id": "cell-1",
|
||||
"runtime_tool_call_id": "runtime-call-1",
|
||||
"tool_name": {
|
||||
"name": "freeform_tool",
|
||||
"namespace": "mcp__sample__",
|
||||
},
|
||||
"tool_kind": "freeform",
|
||||
"input": { "value": 1 },
|
||||
},
|
||||
},
|
||||
}),
|
||||
);
|
||||
assert_wire_round_trip(
|
||||
HostToClient::DelegateRequest {
|
||||
id: delegate_request_id(/*value*/ 10),
|
||||
session_id: session_id(),
|
||||
request: DelegateRequest::Notify {
|
||||
call_id: "call-1".to_string(),
|
||||
cell_id: cell_id("cell-1"),
|
||||
text: "important".to_string(),
|
||||
},
|
||||
},
|
||||
json!({
|
||||
"type": "delegate/request",
|
||||
"id": 10,
|
||||
"sessionId": "session-1",
|
||||
"request": {
|
||||
"type": "notification/send",
|
||||
"callId": "call-1",
|
||||
"cellId": "cell-1",
|
||||
"text": "important",
|
||||
},
|
||||
}),
|
||||
);
|
||||
assert_wire_round_trip(
|
||||
HostToClient::CancelDelegateRequest {
|
||||
id: delegate_request_id(/*value*/ 11),
|
||||
},
|
||||
json!({ "type": "delegate/cancel", "id": 11 }),
|
||||
);
|
||||
assert_wire_round_trip(
|
||||
HostToClient::CellClosed {
|
||||
session_id: session_id(),
|
||||
cell_id: cell_id("cell-1"),
|
||||
},
|
||||
json!({
|
||||
"type": "cell/closed",
|
||||
"sessionId": "session-1",
|
||||
"cellId": "cell-1",
|
||||
}),
|
||||
);
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn execute_request_integer_bounds_are_enforced() {
|
||||
let wire_request = execute_request();
|
||||
let domain_request = ExecuteRequest::try_from(wire_request.clone())
|
||||
.expect("valid wire request converts to the domain");
|
||||
assert_eq!(
|
||||
WireExecuteRequest::try_from(domain_request.clone())
|
||||
.expect("valid domain request converts to the wire"),
|
||||
wire_request
|
||||
);
|
||||
|
||||
let too_large = ExecuteRequest {
|
||||
max_output_tokens: Some(usize::try_from(i32::MAX).expect("i32::MAX fits usize") + 1),
|
||||
..domain_request
|
||||
};
|
||||
assert!(WireExecuteRequest::try_from(too_large).is_err());
|
||||
|
||||
let negative = WireExecuteRequest {
|
||||
max_output_tokens: Some(-1),
|
||||
..wire_request
|
||||
};
|
||||
assert!(ExecuteRequest::try_from(negative).is_err());
|
||||
}
|
||||
|
||||
#[test]
|
||||
@@ -168,7 +627,11 @@ fn invalid_protocol_states_cannot_be_constructed_or_decoded() {
|
||||
);
|
||||
|
||||
for invalid in [
|
||||
json!({ "type": "session/open", "sessionId": "" }),
|
||||
json!({
|
||||
"type": "operation/request",
|
||||
"id": 1,
|
||||
"request": { "method": "session/open", "sessionId": "" },
|
||||
}),
|
||||
json!({
|
||||
"type": "connection/hello",
|
||||
"supportedVersions": [],
|
||||
@@ -187,19 +650,100 @@ fn invalid_protocol_states_cannot_be_constructed_or_decoded() {
|
||||
}
|
||||
|
||||
#[test]
|
||||
fn unknown_fields_are_rejected() {
|
||||
fn every_nested_v1_object_rejects_unknown_fields() {
|
||||
assert!(
|
||||
serde_json::from_value::<ClientToHost>(json!({
|
||||
"type": "session/open",
|
||||
"type": "operation/request",
|
||||
"id": 1,
|
||||
"request": { "method": "session/open", "sessionId": "session-1" },
|
||||
"unexpected": true,
|
||||
}))
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
serde_json::from_value::<HostRequest>(json!({
|
||||
"method": "session/open",
|
||||
"sessionId": "session-1",
|
||||
"unexpected": true,
|
||||
}))
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
serde_json::from_value::<WireExecuteRequest>(json!({
|
||||
"tool_call_id": "call-1",
|
||||
"enabled_tools": [],
|
||||
"source": "text('hello');",
|
||||
"yield_time_ms": null,
|
||||
"max_output_tokens": null,
|
||||
"unexpected": true,
|
||||
}))
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
serde_json::from_value::<WireToolDefinition>(json!({
|
||||
"name": "tool",
|
||||
"tool_name": { "name": "tool", "namespace": null },
|
||||
"description": "tool",
|
||||
"kind": "function",
|
||||
"input_schema": null,
|
||||
"output_schema": null,
|
||||
"unexpected": true,
|
||||
}))
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
serde_json::from_value::<WireToolName>(json!({
|
||||
"name": "tool",
|
||||
"namespace": null,
|
||||
"unexpected": true,
|
||||
}))
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
serde_json::from_value::<WireWaitRequest>(json!({
|
||||
"cell_id": "cell-1",
|
||||
"yield_time_ms": 50,
|
||||
"unexpected": true,
|
||||
}))
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
serde_json::from_value::<WireRuntimeResponse>(json!({
|
||||
"Yielded": {
|
||||
"cell_id": "cell-1",
|
||||
"content_items": [],
|
||||
"unexpected": true,
|
||||
},
|
||||
}))
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
serde_json::from_value::<WireContentItem>(json!({
|
||||
"type": "input_text",
|
||||
"text": "hello",
|
||||
"unexpected": true,
|
||||
}))
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
serde_json::from_value::<WireNestedToolCall>(json!({
|
||||
"cell_id": "cell-1",
|
||||
"runtime_tool_call_id": "runtime-call-1",
|
||||
"tool_name": { "name": "tool", "namespace": null },
|
||||
"tool_kind": "function",
|
||||
"input": null,
|
||||
"unexpected": true,
|
||||
}))
|
||||
.is_err()
|
||||
);
|
||||
assert!(
|
||||
serde_json::from_value::<HostToClient>(json!({
|
||||
"type": "session/ready",
|
||||
"sessionId": "session-1",
|
||||
"type": "operation/response",
|
||||
"id": 1,
|
||||
"result": {
|
||||
"status": "ok",
|
||||
"value": { "type": "session/ready", "sessionId": "session-1" },
|
||||
},
|
||||
"unexpected": true,
|
||||
}))
|
||||
.is_err()
|
||||
|
||||
@@ -2,13 +2,22 @@ use std::fmt;
|
||||
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
use super::Capability;
|
||||
use super::CapabilitySet;
|
||||
use super::DelegateRequestId;
|
||||
use super::HandshakeRejectReason;
|
||||
use super::ProtocolVersion;
|
||||
use super::RequestId;
|
||||
use super::SessionId;
|
||||
use super::SupportedProtocolVersions;
|
||||
use super::WireCellId;
|
||||
use super::WireExecuteRequest;
|
||||
use super::WireNestedToolCall;
|
||||
use super::WireRuntimeResponse;
|
||||
use super::WireWaitOutcome;
|
||||
use super::WireWaitRequest;
|
||||
|
||||
#[derive(Clone, Debug, PartialEq, Serialize)]
|
||||
#[serde(rename_all = "camelCase")]
|
||||
@@ -116,27 +125,133 @@ impl HostHello {
|
||||
}
|
||||
|
||||
/// Messages sent from a client to the code-mode host.
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[derive(Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields, tag = "type", rename_all_fields = "camelCase")]
|
||||
pub enum ClientToHost {
|
||||
#[serde(rename = "connection/hello")]
|
||||
ClientHello(ClientHello),
|
||||
#[serde(rename = "session/open")]
|
||||
OpenSession { session_id: SessionId },
|
||||
#[serde(rename = "session/close")]
|
||||
CloseSession { session_id: SessionId },
|
||||
#[serde(rename = "operation/request")]
|
||||
Request { id: RequestId, request: HostRequest },
|
||||
#[serde(rename = "delegate/response")]
|
||||
DelegateResponse {
|
||||
id: DelegateRequestId,
|
||||
result: WireResult<DelegateResponse>,
|
||||
},
|
||||
}
|
||||
|
||||
/// Messages sent from the code-mode host to a client.
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[derive(Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields, tag = "type", rename_all_fields = "camelCase")]
|
||||
pub enum HostToClient {
|
||||
#[serde(rename = "connection/ready")]
|
||||
HostHello(HostHello),
|
||||
#[serde(rename = "connection/rejected")]
|
||||
HandshakeRejected { reason: HandshakeRejectReason },
|
||||
#[serde(rename = "operation/response")]
|
||||
Response {
|
||||
id: RequestId,
|
||||
result: WireResult<HostResponse>,
|
||||
},
|
||||
#[serde(rename = "execute/initialResponse")]
|
||||
InitialResponse {
|
||||
id: RequestId,
|
||||
result: WireResult<WireRuntimeResponse>,
|
||||
},
|
||||
#[serde(rename = "delegate/request")]
|
||||
DelegateRequest {
|
||||
id: DelegateRequestId,
|
||||
session_id: SessionId,
|
||||
request: DelegateRequest,
|
||||
},
|
||||
#[serde(rename = "delegate/cancel")]
|
||||
CancelDelegateRequest { id: DelegateRequestId },
|
||||
#[serde(rename = "cell/closed")]
|
||||
CellClosed {
|
||||
session_id: SessionId,
|
||||
cell_id: WireCellId,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields, tag = "method", rename_all_fields = "camelCase")]
|
||||
pub enum HostRequest {
|
||||
#[serde(rename = "session/open")]
|
||||
OpenSession { session_id: SessionId },
|
||||
#[serde(rename = "session/execute")]
|
||||
Execute {
|
||||
session_id: SessionId,
|
||||
request: WireExecuteRequest,
|
||||
},
|
||||
#[serde(rename = "session/wait")]
|
||||
Wait {
|
||||
session_id: SessionId,
|
||||
request: WireWaitRequest,
|
||||
},
|
||||
#[serde(rename = "session/terminate")]
|
||||
Terminate {
|
||||
session_id: SessionId,
|
||||
cell_id: WireCellId,
|
||||
},
|
||||
#[serde(rename = "session/shutdown")]
|
||||
ShutdownSession { session_id: SessionId },
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields, tag = "type", rename_all_fields = "camelCase")]
|
||||
pub enum HostResponse {
|
||||
#[serde(rename = "session/ready")]
|
||||
SessionReady { session_id: SessionId },
|
||||
#[serde(rename = "execution/started")]
|
||||
ExecutionStarted { cell_id: WireCellId },
|
||||
#[serde(rename = "wait/completed")]
|
||||
WaitCompleted { outcome: WireWaitOutcome },
|
||||
#[serde(rename = "session/closed")]
|
||||
SessionClosed { session_id: SessionId },
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields, tag = "type", rename_all_fields = "camelCase")]
|
||||
pub enum DelegateRequest {
|
||||
#[serde(rename = "tool/invoke")]
|
||||
InvokeTool { invocation: WireNestedToolCall },
|
||||
#[serde(rename = "notification/send")]
|
||||
Notify {
|
||||
call_id: String,
|
||||
cell_id: WireCellId,
|
||||
text: String,
|
||||
},
|
||||
}
|
||||
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields, tag = "type", rename_all_fields = "camelCase")]
|
||||
pub enum DelegateResponse {
|
||||
#[serde(rename = "tool/result")]
|
||||
ToolResult { result: JsonValue },
|
||||
#[serde(rename = "notification/delivered")]
|
||||
NotificationDelivered,
|
||||
}
|
||||
|
||||
#[derive(Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields, tag = "status", rename_all_fields = "camelCase")]
|
||||
pub enum WireResult<T> {
|
||||
#[serde(rename = "ok")]
|
||||
Ok { value: T },
|
||||
#[serde(rename = "error")]
|
||||
Err { message: String },
|
||||
}
|
||||
|
||||
impl<T> WireResult<T> {
|
||||
pub fn from_result(result: Result<T, String>) -> Self {
|
||||
match result {
|
||||
Ok(value) => Self::Ok { value },
|
||||
Err(message) => Self::Err { message },
|
||||
}
|
||||
}
|
||||
|
||||
pub fn into_result(self) -> Result<T, String> {
|
||||
match self {
|
||||
Self::Ok { value } => Ok(value),
|
||||
Self::Err { message } => Err(message),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,29 +1,56 @@
|
||||
//! Transport-neutral messages for the callback-only code-mode host boundary.
|
||||
//! Messages and local IPC framing for the code-mode host boundary.
|
||||
//!
|
||||
//! Protocol version 1 relies on ordered framing and connection-scoped
|
||||
//! fail-stop behavior rather than message sequence numbers. It defines no
|
||||
//! optional capabilities yet; capability names provide an extension point for
|
||||
//! later versions without weakening the v1 decoder.
|
||||
//! Protocol version 1 multiplexes session operations and delegate callbacks by
|
||||
//! request ID over one ordered connection. It defines no optional capabilities
|
||||
//! yet; capability names provide an extension point for later versions without
|
||||
//! weakening the v1 decoder.
|
||||
|
||||
mod codec;
|
||||
mod error;
|
||||
mod message;
|
||||
mod payload;
|
||||
mod types;
|
||||
|
||||
pub use codec::FramedReader;
|
||||
pub use codec::FramedWriter;
|
||||
pub use codec::MAX_FRAME_BYTES;
|
||||
pub use error::HandshakeRejectReason;
|
||||
pub use message::ClientHello;
|
||||
pub use message::ClientHelloError;
|
||||
pub use message::ClientToHost;
|
||||
pub use message::DelegateRequest;
|
||||
pub use message::DelegateResponse;
|
||||
pub use message::HostHello;
|
||||
pub use message::HostRequest;
|
||||
pub use message::HostResponse;
|
||||
pub use message::HostToClient;
|
||||
pub use message::WireResult;
|
||||
pub use payload::WireCellId;
|
||||
pub use payload::WireContentItem;
|
||||
pub use payload::WireExecuteRequest;
|
||||
pub use payload::WireImageDetail;
|
||||
pub use payload::WireNestedToolCall;
|
||||
pub use payload::WireRuntimeResponse;
|
||||
pub use payload::WireToolDefinition;
|
||||
pub use payload::WireToolKind;
|
||||
pub use payload::WireToolName;
|
||||
pub use payload::WireWaitOutcome;
|
||||
pub use payload::WireWaitRequest;
|
||||
pub use types::Capability;
|
||||
pub use types::CapabilitySet;
|
||||
pub use types::DelegateRequestId;
|
||||
pub use types::DuplicateCapability;
|
||||
pub use types::InvalidIdentifier;
|
||||
pub use types::InvalidSupportedProtocolVersions;
|
||||
pub use types::ProtocolVersion;
|
||||
pub use types::RequestId;
|
||||
pub use types::SessionId;
|
||||
pub use types::SupportedProtocolVersions;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "host_tests.rs"]
|
||||
mod tests;
|
||||
|
||||
#[cfg(test)]
|
||||
#[path = "codec_tests.rs"]
|
||||
mod codec_tests;
|
||||
|
||||
@@ -0,0 +1,412 @@
|
||||
use std::num::TryFromIntError;
|
||||
|
||||
use codex_protocol::ToolName;
|
||||
use serde::Deserialize;
|
||||
use serde::Serialize;
|
||||
use serde_json::Value as JsonValue;
|
||||
|
||||
use crate::CellId;
|
||||
use crate::CodeModeNestedToolCall;
|
||||
use crate::CodeModeToolKind;
|
||||
use crate::ExecuteRequest;
|
||||
use crate::FunctionCallOutputContentItem;
|
||||
use crate::ImageDetail;
|
||||
use crate::RuntimeResponse;
|
||||
use crate::ToolDefinition;
|
||||
use crate::WaitOutcome;
|
||||
use crate::WaitRequest;
|
||||
|
||||
/// A cell identifier with a wire representation owned by protocol V1.
|
||||
#[derive(Clone, Debug, Deserialize, Eq, Hash, PartialEq, Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct WireCellId(String);
|
||||
|
||||
impl WireCellId {
|
||||
pub fn new(value: impl Into<String>) -> Self {
|
||||
Self(value.into())
|
||||
}
|
||||
|
||||
pub fn as_str(&self) -> &str {
|
||||
&self.0
|
||||
}
|
||||
}
|
||||
|
||||
impl From<CellId> for WireCellId {
|
||||
fn from(value: CellId) -> Self {
|
||||
Self(value.as_str().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<&CellId> for WireCellId {
|
||||
fn from(value: &CellId) -> Self {
|
||||
Self(value.as_str().to_string())
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WireCellId> for CellId {
|
||||
fn from(value: WireCellId) -> Self {
|
||||
Self::new(value.0)
|
||||
}
|
||||
}
|
||||
|
||||
/// The V1 wire representation of a tool's stable name.
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct WireToolName {
|
||||
pub name: String,
|
||||
pub namespace: Option<String>,
|
||||
}
|
||||
|
||||
impl From<ToolName> for WireToolName {
|
||||
fn from(value: ToolName) -> Self {
|
||||
Self {
|
||||
name: value.name,
|
||||
namespace: value.namespace,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WireToolName> for ToolName {
|
||||
fn from(value: WireToolName) -> Self {
|
||||
Self::new(value.namespace, value.name)
|
||||
}
|
||||
}
|
||||
|
||||
/// The tool invocation shape supported by protocol V1.
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
|
||||
#[serde(rename_all = "snake_case")]
|
||||
pub enum WireToolKind {
|
||||
Function,
|
||||
Freeform,
|
||||
}
|
||||
|
||||
impl From<CodeModeToolKind> for WireToolKind {
|
||||
fn from(value: CodeModeToolKind) -> Self {
|
||||
match value {
|
||||
CodeModeToolKind::Function => Self::Function,
|
||||
CodeModeToolKind::Freeform => Self::Freeform,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WireToolKind> for CodeModeToolKind {
|
||||
fn from(value: WireToolKind) -> Self {
|
||||
match value {
|
||||
WireToolKind::Function => Self::Function,
|
||||
WireToolKind::Freeform => Self::Freeform,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A V1 tool definition embedded in an execute request.
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct WireToolDefinition {
|
||||
pub name: String,
|
||||
pub tool_name: WireToolName,
|
||||
pub description: String,
|
||||
pub kind: WireToolKind,
|
||||
pub input_schema: Option<JsonValue>,
|
||||
pub output_schema: Option<JsonValue>,
|
||||
}
|
||||
|
||||
impl From<ToolDefinition> for WireToolDefinition {
|
||||
fn from(value: ToolDefinition) -> Self {
|
||||
Self {
|
||||
name: value.name,
|
||||
tool_name: value.tool_name.into(),
|
||||
description: value.description,
|
||||
kind: value.kind.into(),
|
||||
input_schema: value.input_schema,
|
||||
output_schema: value.output_schema,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WireToolDefinition> for ToolDefinition {
|
||||
fn from(value: WireToolDefinition) -> Self {
|
||||
Self {
|
||||
name: value.name,
|
||||
tool_name: value.tool_name.into(),
|
||||
description: value.description,
|
||||
kind: value.kind.into(),
|
||||
input_schema: value.input_schema,
|
||||
output_schema: value.output_schema,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// The complete execute request shape supported by protocol V1.
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct WireExecuteRequest {
|
||||
pub tool_call_id: String,
|
||||
pub enabled_tools: Vec<WireToolDefinition>,
|
||||
pub source: String,
|
||||
pub yield_time_ms: Option<u64>,
|
||||
pub max_output_tokens: Option<i32>,
|
||||
}
|
||||
|
||||
impl TryFrom<ExecuteRequest> for WireExecuteRequest {
|
||||
type Error = TryFromIntError;
|
||||
|
||||
fn try_from(value: ExecuteRequest) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
tool_call_id: value.tool_call_id,
|
||||
enabled_tools: value.enabled_tools.into_iter().map(Into::into).collect(),
|
||||
source: value.source,
|
||||
yield_time_ms: value.yield_time_ms,
|
||||
max_output_tokens: value.max_output_tokens.map(i32::try_from).transpose()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
impl TryFrom<WireExecuteRequest> for ExecuteRequest {
|
||||
type Error = TryFromIntError;
|
||||
|
||||
fn try_from(value: WireExecuteRequest) -> Result<Self, Self::Error> {
|
||||
Ok(Self {
|
||||
tool_call_id: value.tool_call_id,
|
||||
enabled_tools: value.enabled_tools.into_iter().map(Into::into).collect(),
|
||||
source: value.source,
|
||||
yield_time_ms: value.yield_time_ms,
|
||||
max_output_tokens: value.max_output_tokens.map(usize::try_from).transpose()?,
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
/// The complete wait request shape supported by protocol V1.
|
||||
#[derive(Clone, Debug, Deserialize, Eq, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct WireWaitRequest {
|
||||
pub cell_id: WireCellId,
|
||||
pub yield_time_ms: u64,
|
||||
}
|
||||
|
||||
impl From<WaitRequest> for WireWaitRequest {
|
||||
fn from(value: WaitRequest) -> Self {
|
||||
Self {
|
||||
cell_id: value.cell_id.into(),
|
||||
yield_time_ms: value.yield_time_ms,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WireWaitRequest> for WaitRequest {
|
||||
fn from(value: WireWaitRequest) -> Self {
|
||||
Self {
|
||||
cell_id: value.cell_id.into(),
|
||||
yield_time_ms: value.yield_time_ms,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Image detail values accepted in a V1 runtime response.
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Eq, PartialEq, Serialize)]
|
||||
#[serde(rename_all = "lowercase")]
|
||||
pub enum WireImageDetail {
|
||||
Auto,
|
||||
Low,
|
||||
High,
|
||||
Original,
|
||||
}
|
||||
|
||||
impl From<ImageDetail> for WireImageDetail {
|
||||
fn from(value: ImageDetail) -> Self {
|
||||
match value {
|
||||
ImageDetail::Auto => Self::Auto,
|
||||
ImageDetail::Low => Self::Low,
|
||||
ImageDetail::High => Self::High,
|
||||
ImageDetail::Original => Self::Original,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WireImageDetail> for ImageDetail {
|
||||
fn from(value: WireImageDetail) -> Self {
|
||||
match value {
|
||||
WireImageDetail::Auto => Self::Auto,
|
||||
WireImageDetail::Low => Self::Low,
|
||||
WireImageDetail::High => Self::High,
|
||||
WireImageDetail::Original => Self::Original,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// One output item emitted by a V1 runtime response.
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields, tag = "type", rename_all = "snake_case")]
|
||||
pub enum WireContentItem {
|
||||
InputText {
|
||||
text: String,
|
||||
},
|
||||
InputImage {
|
||||
image_url: String,
|
||||
#[serde(default, skip_serializing_if = "Option::is_none")]
|
||||
detail: Option<WireImageDetail>,
|
||||
},
|
||||
}
|
||||
|
||||
impl From<FunctionCallOutputContentItem> for WireContentItem {
|
||||
fn from(value: FunctionCallOutputContentItem) -> Self {
|
||||
match value {
|
||||
FunctionCallOutputContentItem::InputText { text } => Self::InputText { text },
|
||||
FunctionCallOutputContentItem::InputImage { image_url, detail } => Self::InputImage {
|
||||
image_url,
|
||||
detail: detail.map(Into::into),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WireContentItem> for FunctionCallOutputContentItem {
|
||||
fn from(value: WireContentItem) -> Self {
|
||||
match value {
|
||||
WireContentItem::InputText { text } => Self::InputText { text },
|
||||
WireContentItem::InputImage { image_url, detail } => Self::InputImage {
|
||||
image_url,
|
||||
detail: detail.map(Into::into),
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Runtime output returned over the V1 host connection.
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub enum WireRuntimeResponse {
|
||||
Yielded {
|
||||
cell_id: WireCellId,
|
||||
content_items: Vec<WireContentItem>,
|
||||
},
|
||||
Terminated {
|
||||
cell_id: WireCellId,
|
||||
content_items: Vec<WireContentItem>,
|
||||
},
|
||||
Result {
|
||||
cell_id: WireCellId,
|
||||
content_items: Vec<WireContentItem>,
|
||||
error_text: Option<String>,
|
||||
},
|
||||
}
|
||||
|
||||
impl From<RuntimeResponse> for WireRuntimeResponse {
|
||||
fn from(value: RuntimeResponse) -> Self {
|
||||
match value {
|
||||
RuntimeResponse::Yielded {
|
||||
cell_id,
|
||||
content_items,
|
||||
} => Self::Yielded {
|
||||
cell_id: cell_id.into(),
|
||||
content_items: content_items.into_iter().map(Into::into).collect(),
|
||||
},
|
||||
RuntimeResponse::Terminated {
|
||||
cell_id,
|
||||
content_items,
|
||||
} => Self::Terminated {
|
||||
cell_id: cell_id.into(),
|
||||
content_items: content_items.into_iter().map(Into::into).collect(),
|
||||
},
|
||||
RuntimeResponse::Result {
|
||||
cell_id,
|
||||
content_items,
|
||||
error_text,
|
||||
} => Self::Result {
|
||||
cell_id: cell_id.into(),
|
||||
content_items: content_items.into_iter().map(Into::into).collect(),
|
||||
error_text,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WireRuntimeResponse> for RuntimeResponse {
|
||||
fn from(value: WireRuntimeResponse) -> Self {
|
||||
match value {
|
||||
WireRuntimeResponse::Yielded {
|
||||
cell_id,
|
||||
content_items,
|
||||
} => Self::Yielded {
|
||||
cell_id: cell_id.into(),
|
||||
content_items: content_items.into_iter().map(Into::into).collect(),
|
||||
},
|
||||
WireRuntimeResponse::Terminated {
|
||||
cell_id,
|
||||
content_items,
|
||||
} => Self::Terminated {
|
||||
cell_id: cell_id.into(),
|
||||
content_items: content_items.into_iter().map(Into::into).collect(),
|
||||
},
|
||||
WireRuntimeResponse::Result {
|
||||
cell_id,
|
||||
content_items,
|
||||
error_text,
|
||||
} => Self::Result {
|
||||
cell_id: cell_id.into(),
|
||||
content_items: content_items.into_iter().map(Into::into).collect(),
|
||||
error_text,
|
||||
},
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// Whether a waited-for cell remained live in protocol V1.
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub enum WireWaitOutcome {
|
||||
LiveCell(WireRuntimeResponse),
|
||||
MissingCell(WireRuntimeResponse),
|
||||
}
|
||||
|
||||
impl From<WaitOutcome> for WireWaitOutcome {
|
||||
fn from(value: WaitOutcome) -> Self {
|
||||
match value {
|
||||
WaitOutcome::LiveCell(response) => Self::LiveCell(response.into()),
|
||||
WaitOutcome::MissingCell(response) => Self::MissingCell(response.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WireWaitOutcome> for WaitOutcome {
|
||||
fn from(value: WireWaitOutcome) -> Self {
|
||||
match value {
|
||||
WireWaitOutcome::LiveCell(response) => Self::LiveCell(response.into()),
|
||||
WireWaitOutcome::MissingCell(response) => Self::MissingCell(response.into()),
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
/// A nested tool invocation sent over the V1 host connection.
|
||||
#[derive(Clone, Debug, Deserialize, PartialEq, Serialize)]
|
||||
#[serde(deny_unknown_fields)]
|
||||
pub struct WireNestedToolCall {
|
||||
pub cell_id: WireCellId,
|
||||
pub runtime_tool_call_id: String,
|
||||
pub tool_name: WireToolName,
|
||||
pub tool_kind: WireToolKind,
|
||||
pub input: Option<JsonValue>,
|
||||
}
|
||||
|
||||
impl From<CodeModeNestedToolCall> for WireNestedToolCall {
|
||||
fn from(value: CodeModeNestedToolCall) -> Self {
|
||||
Self {
|
||||
cell_id: value.cell_id.into(),
|
||||
runtime_tool_call_id: value.runtime_tool_call_id,
|
||||
tool_name: value.tool_name.into(),
|
||||
tool_kind: value.tool_kind.into(),
|
||||
input: value.input,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
impl From<WireNestedToolCall> for CodeModeNestedToolCall {
|
||||
fn from(value: WireNestedToolCall) -> Self {
|
||||
Self {
|
||||
cell_id: value.cell_id.into(),
|
||||
runtime_tool_call_id: value.runtime_tool_call_id,
|
||||
tool_name: value.tool_name.into(),
|
||||
tool_kind: value.tool_kind.into(),
|
||||
input: value.input,
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -8,6 +8,28 @@ use serde::Serialize;
|
||||
use serde::Serializer;
|
||||
use serde::de::Error as _;
|
||||
|
||||
/// Correlates one client operation request with the host's response.
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct RequestId(i64);
|
||||
|
||||
impl RequestId {
|
||||
pub const fn new(value: i64) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
/// Correlates one host delegate request with the client's response.
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct DelegateRequestId(i64);
|
||||
|
||||
impl DelegateRequestId {
|
||||
pub const fn new(value: i64) -> Self {
|
||||
Self(value)
|
||||
}
|
||||
}
|
||||
|
||||
#[derive(Clone, Copy, Debug, Deserialize, Eq, Hash, Ord, PartialEq, PartialOrd, Serialize)]
|
||||
#[serde(transparent)]
|
||||
pub struct ProtocolVersion(NonZeroU32);
|
||||
|
||||
Reference in New Issue
Block a user