Route Python SDK turn notifications by ID (#21778)

## Why

The Python SDK previously protected the stdio transport with a single
active turn-consumer guard. That avoided competing reads from stdout,
but it also meant one `Codex`/`AsyncCodex` client could not stream
multiple active turns at the same time. Notifications could also arrive
before the caller received a `TurnHandle` and registered for streaming,
so the SDK needed an explicit routing layer instead of letting
individual API calls read directly from the shared transport.

## What Changed

- Added a private `MessageRouter` that owns per-request response queues,
per-turn notification queues, pending turn-notification replay, and
global notification delivery behind a single stdout reader thread.
- Generated typed notification routing metadata so turn IDs come from
known payload shapes instead of router-side attribute guessing, with
explicit fallback handling for unknown notification payloads.
- Updated sync and async turn streaming so `TurnHandle.stream()`/`run()`
and `stream_text()` consume only notifications for their own turn ID,
while `AsyncAppServerClient` no longer serializes all transport calls
behind one async lock.
- Cleared pending turn-notification buffers when unregistered turns
complete so never-consumed turn handles do not leave stale queues
behind.
- Removed the internal stream-until helper now that turn completion
waiting can register directly with routed turn notifications.
- Updated Python SDK docs and focused tests for concurrent transport
calls, interleaved turn routing, buffered early notifications, unknown
notification routing, async delegation, and routed turn completion
behavior.

## Validation

- `uv run --extra dev ruff format scripts/update_sdk_artifacts.py
src/codex_app_server/_message_router.py src/codex_app_server/client.py
src/codex_app_server/generated/notification_registry.py
tests/test_client_rpc_methods.py
tests/test_public_api_runtime_behavior.py
tests/test_async_client_behavior.py`
- `uv run --extra dev ruff check scripts/update_sdk_artifacts.py
src/codex_app_server/_message_router.py src/codex_app_server/client.py
src/codex_app_server/generated/notification_registry.py
tests/test_client_rpc_methods.py
tests/test_public_api_runtime_behavior.py
tests/test_async_client_behavior.py`
- `uv run --extra dev pytest tests/test_client_rpc_methods.py
tests/test_public_api_runtime_behavior.py
tests/test_async_client_behavior.py`
- `git diff --check`

---------

Co-authored-by: Codex <noreply@openai.com>
This commit is contained in:
Ahmed Ibrahim
2026-05-09 07:16:23 +03:00
committed by GitHub
Unverified
parent 77d9223e9f
commit ebe75bb683
11 changed files with 916 additions and 197 deletions
+5 -5
View File
@@ -2,7 +2,7 @@
Public surface of `codex_app_server` for app-server v2.
This SDK surface is experimental. The current implementation intentionally allows only one active turn consumer (`Thread.run()`, `TurnHandle.stream()`, or `TurnHandle.run()`) per client instance at a time.
This SDK surface is experimental. Turn streams are routed by turn ID so one client can consume multiple active turns concurrently.
## Package Entry
@@ -137,8 +137,8 @@ Use `turn(...)` when you need low-level turn control (`stream()`, `steer()`,
Behavior notes:
- `stream()` and `run()` are exclusive per client instance in the current experimental build
- starting a second turn consumer on the same `Codex` instance raises `RuntimeError`
- `stream()` and `run()` consume only notifications for their own turn ID
- one `Codex` instance can stream multiple active turns concurrently
### AsyncTurnHandle
@@ -149,8 +149,8 @@ Behavior notes:
Behavior notes:
- `stream()` and `run()` are exclusive per client instance in the current experimental build
- starting a second turn consumer on the same `AsyncCodex` instance raises `RuntimeError`
- `stream()` and `run()` consume only notifications for their own turn ID
- one `AsyncCodex` instance can stream multiple active turns concurrently
## Inputs
+1 -1
View File
@@ -45,7 +45,7 @@ What happened:
- `thread.run("...")` started a turn, consumed events until completion, and returned the final assistant response plus collected items and usage.
- `result.final_response` is `None` when no final-answer or phase-less assistant message item completes for the turn.
- use `thread.turn(...)` when you need a `TurnHandle` for streaming, steering, interrupting, or turn IDs/status
- one client can have only one active turn consumer (`thread.run(...)`, `TurnHandle.stream()`, or `TurnHandle.run()`) at a time in the current experimental build
- one client can consume multiple active turns concurrently; turn streams are routed by turn ID
## 3) Continue the same thread (multi-turn)
+58 -1
View File
@@ -585,6 +585,43 @@ def _notification_specs() -> list[tuple[str, str]]:
return specs
def _notification_turn_id_specs(
specs: list[tuple[str, str]],
) -> tuple[list[str], list[str]]:
server_notifications = json.loads(
(schema_root_dir() / "ServerNotification.json").read_text()
)
definitions = server_notifications.get("definitions", {})
if not isinstance(definitions, dict):
return ([], [])
direct: list[str] = []
nested: list[str] = []
for _, class_name in specs:
definition = definitions.get(class_name)
if not isinstance(definition, dict):
continue
props = definition.get("properties", {})
if not isinstance(props, dict):
continue
if "turnId" in props:
direct.append(class_name)
continue
turn = props.get("turn")
if isinstance(turn, dict) and turn.get("$ref") == "#/definitions/Turn":
nested.append(class_name)
return (sorted(set(direct)), sorted(set(nested)))
def _type_tuple_source(class_names: list[str]) -> str:
if not class_names:
return "()"
if len(class_names) == 1:
return f"({class_names[0]},)"
return "(\n" + "".join(f" {class_name},\n" for class_name in class_names) + ")"
def generate_notification_registry() -> None:
out = (
sdk_root()
@@ -595,6 +632,7 @@ def generate_notification_registry() -> None:
)
specs = _notification_specs()
class_names = sorted({class_name for _, class_name in specs})
direct_turn_id_types, nested_turn_types = _notification_turn_id_specs(specs)
lines = [
"# Auto-generated by scripts/update_sdk_artifacts.py",
@@ -616,7 +654,26 @@ def generate_notification_registry() -> None:
)
for method, class_name in specs:
lines.append(f' "{method}": {class_name},')
lines.extend(["}", ""])
lines.extend(
[
"}",
"",
"DIRECT_TURN_ID_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = "
f"{_type_tuple_source(direct_turn_id_types)}",
"",
"NESTED_TURN_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = "
f"{_type_tuple_source(nested_turn_types)}",
"",
"",
"def notification_turn_id(payload: BaseModel) -> str | None:",
" if isinstance(payload, DIRECT_TURN_ID_NOTIFICATION_TYPES):",
" return payload.turn_id if isinstance(payload.turn_id, str) else None",
" if isinstance(payload, NESTED_TURN_NOTIFICATION_TYPES):",
" return payload.turn.id",
" return None",
"",
]
)
out.write_text("\n".join(lines))
@@ -0,0 +1,158 @@
from __future__ import annotations
import queue
import threading
from collections import deque
from .errors import AppServerError, map_jsonrpc_error
from .generated.notification_registry import notification_turn_id
from .models import JsonValue, Notification, UnknownNotification
ResponseQueueItem = JsonValue | BaseException
NotificationQueueItem = Notification | BaseException
class MessageRouter:
"""Route reader-thread messages to the SDK operation waiting for them.
The app-server stdio transport is a single ordered stream, so only the
reader thread should consume stdout. This router keeps the rest of the SDK
from competing for that stream by giving each in-flight JSON-RPC request
and active turn stream its own queue.
"""
def __init__(self) -> None:
self._lock = threading.Lock()
self._response_waiters: dict[str, queue.Queue[ResponseQueueItem]] = {}
self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
self._pending_turn_notifications: dict[str, deque[Notification]] = {}
self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue()
def create_response_waiter(self, request_id: str) -> queue.Queue[ResponseQueueItem]:
"""Register a one-shot queue for a JSON-RPC response id."""
waiter: queue.Queue[ResponseQueueItem] = queue.Queue(maxsize=1)
with self._lock:
self._response_waiters[request_id] = waiter
return waiter
def discard_response_waiter(self, request_id: str) -> None:
"""Remove a response waiter when the request could not be written."""
with self._lock:
self._response_waiters.pop(request_id, None)
def next_global_notification(self) -> Notification:
"""Block until the next notification that is not scoped to a turn."""
item = self._global_notifications.get()
if isinstance(item, BaseException):
raise item
return item
def register_turn(self, turn_id: str) -> None:
"""Register a queue for a turn stream and replay early events."""
turn_queue: queue.Queue[NotificationQueueItem] = queue.Queue()
with self._lock:
if turn_id in self._turn_notifications:
return
# A turn can emit events immediately after turn/start, before the
# caller receives the TurnHandle and starts streaming.
pending = self._pending_turn_notifications.pop(turn_id, deque())
self._turn_notifications[turn_id] = turn_queue
for notification in pending:
turn_queue.put(notification)
def unregister_turn(self, turn_id: str) -> None:
"""Stop routing future turn events to the stream queue."""
with self._lock:
self._turn_notifications.pop(turn_id, None)
def next_turn_notification(self, turn_id: str) -> Notification:
"""Block until the next notification for a registered turn."""
with self._lock:
turn_queue = self._turn_notifications.get(turn_id)
if turn_queue is None:
raise RuntimeError(f"turn {turn_id!r} is not registered for streaming")
item = turn_queue.get()
if isinstance(item, BaseException):
raise item
return item
def route_response(self, msg: dict[str, JsonValue]) -> None:
"""Deliver a JSON-RPC response or error to its request waiter."""
request_id = msg.get("id")
with self._lock:
waiter = self._response_waiters.pop(str(request_id), None)
if waiter is None:
return
if "error" in msg:
err = msg["error"]
if isinstance(err, dict):
waiter.put(
map_jsonrpc_error(
int(err.get("code", -32000)),
str(err.get("message", "unknown")),
err.get("data"),
)
)
else:
waiter.put(AppServerError("Malformed JSON-RPC error response"))
return
waiter.put(msg.get("result"))
def route_notification(self, notification: Notification) -> None:
"""Deliver a notification to a turn queue or the global queue."""
turn_id = self._notification_turn_id(notification)
if turn_id is None:
self._global_notifications.put(notification)
return
with self._lock:
turn_queue = self._turn_notifications.get(turn_id)
if turn_queue is None:
if notification.method == "turn/completed":
self._pending_turn_notifications.pop(turn_id, None)
return
self._pending_turn_notifications.setdefault(turn_id, deque()).append(
notification
)
return
turn_queue.put(notification)
def fail_all(self, exc: BaseException) -> None:
"""Wake every blocked waiter when the reader thread exits."""
with self._lock:
response_waiters = list(self._response_waiters.values())
self._response_waiters.clear()
turn_queues = list(self._turn_notifications.values())
self._pending_turn_notifications.clear()
# Put the same transport failure into every queue so no SDK call blocks
# forever waiting for a response that cannot arrive.
for waiter in response_waiters:
waiter.put(exc)
for turn_queue in turn_queues:
turn_queue.put(exc)
self._global_notifications.put(exc)
def _notification_turn_id(self, notification: Notification) -> str | None:
payload = notification.payload
if isinstance(payload, UnknownNotification):
raw_turn_id = payload.params.get("turnId")
if isinstance(raw_turn_id, str):
return raw_turn_id
raw_turn = payload.params.get("turn")
if isinstance(raw_turn, dict):
raw_nested_turn_id = raw_turn.get("id")
if isinstance(raw_nested_turn_id, str):
return raw_nested_turn_id
return None
return notification_turn_id(payload)
+16 -14
View File
@@ -38,14 +38,14 @@ from .generated.v2_all import (
)
from .models import InitializeResponse, JsonObject, Notification, ServerInfo
from ._inputs import (
ImageInput,
ImageInput as ImageInput,
Input,
InputItem,
LocalImageInput,
MentionInput,
InputItem as InputItem,
LocalImageInput as LocalImageInput,
MentionInput as MentionInput,
RunInput,
SkillInput,
TextInput,
SkillInput as SkillInput,
TextInput as TextInput,
_normalize_run_input,
_to_wire_input,
)
@@ -274,6 +274,7 @@ class Codex:
def thread_unarchive(self, thread_id: str) -> Thread:
unarchived = self._client.thread_unarchive(thread_id)
return Thread(self._client, unarchived.thread.id)
# END GENERATED: Codex.flat_methods
def models(self, *, include_hidden: bool = False) -> ModelListResponse:
@@ -476,6 +477,7 @@ class AsyncCodex:
await self._ensure_initialized()
unarchived = await self._client.thread_unarchive(thread_id)
return AsyncThread(self, unarchived.thread.id)
# END GENERATED: AsyncCodex.flat_methods
async def models(self, *, include_hidden: bool = False) -> ModelListResponse:
@@ -555,6 +557,7 @@ class Thread:
)
turn = self._client.turn_start(self.id, wire_input, params=params)
return TurnHandle(self._client, self.id, turn.turn.id)
# END GENERATED: Thread.flat_methods
def read(self, *, include_turns: bool = False) -> ThreadReadResponse:
@@ -644,6 +647,7 @@ class AsyncThread:
params=params,
)
return AsyncTurnHandle(self._codex, self.id, turn.turn.id)
# END GENERATED: AsyncThread.flat_methods
async def read(self, *, include_turns: bool = False) -> ThreadReadResponse:
@@ -674,11 +678,10 @@ class TurnHandle:
return self._client.turn_interrupt(self.thread_id, self.id)
def stream(self) -> Iterator[Notification]:
# TODO: replace this client-wide experimental guard with per-turn event demux.
self._client.acquire_turn_consumer(self.id)
self._client.register_turn_notifications(self.id)
try:
while True:
event = self._client.next_notification()
event = self._client.next_turn_notification(self.id)
yield event
if (
event.method == "turn/completed"
@@ -687,7 +690,7 @@ class TurnHandle:
):
break
finally:
self._client.release_turn_consumer(self.id)
self._client.unregister_turn_notifications(self.id)
def run(self) -> AppServerTurn:
completed: TurnCompletedNotification | None = None
@@ -728,11 +731,10 @@ class AsyncTurnHandle:
async def stream(self) -> AsyncIterator[Notification]:
await self._codex._ensure_initialized()
# TODO: replace this client-wide experimental guard with per-turn event demux.
self._codex._client.acquire_turn_consumer(self.id)
self._codex._client.register_turn_notifications(self.id)
try:
while True:
event = await self._codex._client.next_notification()
event = await self._codex._client.next_turn_notification(self.id)
yield event
if (
event.method == "turn/completed"
@@ -741,7 +743,7 @@ class AsyncTurnHandle:
):
break
finally:
self._codex._client.release_turn_consumer(self.id)
self._codex._client.unregister_turn_notifications(self.id)
async def run(self) -> AppServerTurn:
completed: TurnCompletedNotification | None = None
+33 -27
View File
@@ -2,7 +2,7 @@ from __future__ import annotations
import asyncio
from collections.abc import Iterator
from typing import AsyncIterator, Callable, Iterable, ParamSpec, TypeVar
from typing import AsyncIterator, Callable, ParamSpec, TypeVar
from pydantic import BaseModel
@@ -41,8 +41,6 @@ class AsyncAppServerClient:
def __init__(self, config: AppServerConfig | None = None) -> None:
self._sync = AppServerClient(config=config)
# Single stdio transport cannot be read safely from multiple threads.
self._transport_lock = asyncio.Lock()
async def __aenter__(self) -> "AsyncAppServerClient":
await self.start()
@@ -58,8 +56,7 @@ class AsyncAppServerClient:
*args: ParamsT.args,
**kwargs: ParamsT.kwargs,
) -> ReturnT:
async with self._transport_lock:
return await asyncio.to_thread(fn, *args, **kwargs)
return await asyncio.to_thread(fn, *args, **kwargs)
@staticmethod
def _next_from_iterator(
@@ -79,11 +76,11 @@ class AsyncAppServerClient:
async def initialize(self) -> InitializeResponse:
return await self._call_sync(self._sync.initialize)
def acquire_turn_consumer(self, turn_id: str) -> None:
self._sync.acquire_turn_consumer(turn_id)
def register_turn_notifications(self, turn_id: str) -> None:
self._sync.register_turn_notifications(turn_id)
def release_turn_consumer(self, turn_id: str) -> None:
self._sync.release_turn_consumer(turn_id)
def unregister_turn_notifications(self, turn_id: str) -> None:
self._sync.unregister_turn_notifications(turn_id)
async def request(
self,
@@ -99,7 +96,9 @@ class AsyncAppServerClient:
response_model=response_model,
)
async def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
async def thread_start(
self, params: V2ThreadStartParams | JsonObject | None = None
) -> ThreadStartResponse:
return await self._call_sync(self._sync.thread_start, params)
async def thread_resume(
@@ -109,10 +108,14 @@ class AsyncAppServerClient:
) -> ThreadResumeResponse:
return await self._call_sync(self._sync.thread_resume, thread_id, params)
async def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse:
async def thread_list(
self, params: V2ThreadListParams | JsonObject | None = None
) -> ThreadListResponse:
return await self._call_sync(self._sync.thread_list, params)
async def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
async def thread_read(
self, thread_id: str, include_turns: bool = False
) -> ThreadReadResponse:
return await self._call_sync(self._sync.thread_read, thread_id, include_turns)
async def thread_fork(
@@ -140,9 +143,13 @@ class AsyncAppServerClient:
input_items: list[JsonObject] | JsonObject | str,
params: V2TurnStartParams | JsonObject | None = None,
) -> TurnStartResponse:
return await self._call_sync(self._sync.turn_start, thread_id, input_items, params)
return await self._call_sync(
self._sync.turn_start, thread_id, input_items, params
)
async def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
async def turn_interrupt(
self, thread_id: str, turn_id: str
) -> TurnInterruptResponse:
return await self._call_sync(self._sync.turn_interrupt, thread_id, turn_id)
async def turn_steer(
@@ -184,25 +191,24 @@ class AsyncAppServerClient:
async def next_notification(self) -> Notification:
return await self._call_sync(self._sync.next_notification)
async def next_turn_notification(self, turn_id: str) -> Notification:
return await self._call_sync(self._sync.next_turn_notification, turn_id)
async def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
return await self._call_sync(self._sync.wait_for_turn_completed, turn_id)
async def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]:
return await self._call_sync(self._sync.stream_until_methods, methods)
async def stream_text(
self,
thread_id: str,
text: str,
params: V2TurnStartParams | JsonObject | None = None,
) -> AsyncIterator[AgentMessageDeltaNotification]:
async with self._transport_lock:
iterator = self._sync.stream_text(thread_id, text, params)
while True:
has_value, chunk = await asyncio.to_thread(
self._next_from_iterator,
iterator,
)
if not has_value:
break
yield chunk
iterator = self._sync.stream_text(thread_id, text, params)
while True:
has_value, chunk = await asyncio.to_thread(
self._next_from_iterator,
iterator,
)
if not has_value:
break
yield chunk
+127 -102
View File
@@ -8,11 +8,11 @@ import uuid
from collections import deque
from dataclasses import dataclass
from pathlib import Path
from typing import Callable, Iterable, Iterator, TypeVar
from typing import Callable, Iterator, TypeVar
from pydantic import BaseModel
from .errors import AppServerError, TransportClosedError, map_jsonrpc_error
from .errors import AppServerError, TransportClosedError
from .generated.notification_registry import NOTIFICATION_MODELS
from .generated.v2_all import (
AgentMessageDeltaNotification,
@@ -43,6 +43,7 @@ from .models import (
Notification,
UnknownNotification,
)
from ._message_router import MessageRouter
from .retry import retry_on_overload
from ._version import __version__ as SDK_VERSION
@@ -75,7 +76,9 @@ def _params_dict(
return dumped
if isinstance(params, dict):
return params
raise TypeError(f"Expected generated params model or dict, got {type(params).__name__}")
raise TypeError(
f"Expected generated params model or dict, got {type(params).__name__}"
)
def _installed_codex_path() -> Path:
@@ -146,11 +149,10 @@ class AppServerClient:
self._approval_handler = approval_handler or self._default_approval_handler
self._proc: subprocess.Popen[str] | None = None
self._lock = threading.Lock()
self._turn_consumer_lock = threading.Lock()
self._active_turn_consumer: str | None = None
self._pending_notifications: deque[Notification] = deque()
self._router = MessageRouter()
self._stderr_lines: deque[str] = deque(maxlen=400)
self._stderr_thread: threading.Thread | None = None
self._reader_thread: threading.Thread | None = None
def __enter__(self) -> "AppServerClient":
self.start()
@@ -189,13 +191,13 @@ class AppServerClient:
)
self._start_stderr_drain_thread()
self._start_reader_thread()
def close(self) -> None:
if self._proc is None:
return
proc = self._proc
self._proc = None
self._active_turn_consumer = None
if proc.stdin:
proc.stdin.close()
@@ -207,6 +209,8 @@ class AppServerClient:
if self._stderr_thread and self._stderr_thread.is_alive():
self._stderr_thread.join(timeout=0.5)
if self._reader_thread and self._reader_thread.is_alive():
self._reader_thread.join(timeout=0.5)
def initialize(self) -> InitializeResponse:
result = self.request(
@@ -240,70 +244,42 @@ class AppServerClient:
def _request_raw(self, method: str, params: JsonObject | None = None) -> JsonValue:
request_id = str(uuid.uuid4())
self._write_message({"id": request_id, "method": method, "params": params or {}})
waiter = self._router.create_response_waiter(request_id)
while True:
msg = self._read_message()
try:
self._write_message(
{"id": request_id, "method": method, "params": params or {}}
)
except BaseException:
self._router.discard_response_waiter(request_id)
raise
if "method" in msg and "id" in msg:
response = self._handle_server_request(msg)
self._write_message({"id": msg["id"], "result": response})
continue
if "method" in msg and "id" not in msg:
self._pending_notifications.append(
self._coerce_notification(msg["method"], msg.get("params"))
)
continue
if msg.get("id") != request_id:
continue
if "error" in msg:
err = msg["error"]
if isinstance(err, dict):
raise map_jsonrpc_error(
int(err.get("code", -32000)),
str(err.get("message", "unknown")),
err.get("data"),
)
raise AppServerError("Malformed JSON-RPC error response")
return msg.get("result")
item = waiter.get()
if isinstance(item, BaseException):
raise item
return item
def notify(self, method: str, params: JsonObject | None = None) -> None:
self._write_message({"method": method, "params": params or {}})
def next_notification(self) -> Notification:
if self._pending_notifications:
return self._pending_notifications.popleft()
return self._router.next_global_notification()
while True:
msg = self._read_message()
if "method" in msg and "id" in msg:
response = self._handle_server_request(msg)
self._write_message({"id": msg["id"], "result": response})
continue
if "method" in msg and "id" not in msg:
return self._coerce_notification(msg["method"], msg.get("params"))
def register_turn_notifications(self, turn_id: str) -> None:
self._router.register_turn(turn_id)
def acquire_turn_consumer(self, turn_id: str) -> None:
with self._turn_consumer_lock:
if self._active_turn_consumer is not None:
raise RuntimeError(
"Concurrent turn consumers are not yet supported in the experimental SDK. "
f"Client is already streaming turn {self._active_turn_consumer!r}; "
f"cannot start turn {turn_id!r} until the active consumer finishes."
)
self._active_turn_consumer = turn_id
def unregister_turn_notifications(self, turn_id: str) -> None:
self._router.unregister_turn(turn_id)
def release_turn_consumer(self, turn_id: str) -> None:
with self._turn_consumer_lock:
if self._active_turn_consumer == turn_id:
self._active_turn_consumer = None
def next_turn_notification(self, turn_id: str) -> Notification:
return self._router.next_turn_notification(turn_id)
def thread_start(self, params: V2ThreadStartParams | JsonObject | None = None) -> ThreadStartResponse:
return self.request("thread/start", _params_dict(params), response_model=ThreadStartResponse)
def thread_start(
self, params: V2ThreadStartParams | JsonObject | None = None
) -> ThreadStartResponse:
return self.request(
"thread/start", _params_dict(params), response_model=ThreadStartResponse
)
def thread_resume(
self,
@@ -311,12 +287,20 @@ class AppServerClient:
params: V2ThreadResumeParams | JsonObject | None = None,
) -> ThreadResumeResponse:
payload = {"threadId": thread_id, **_params_dict(params)}
return self.request("thread/resume", payload, response_model=ThreadResumeResponse)
return self.request(
"thread/resume", payload, response_model=ThreadResumeResponse
)
def thread_list(self, params: V2ThreadListParams | JsonObject | None = None) -> ThreadListResponse:
return self.request("thread/list", _params_dict(params), response_model=ThreadListResponse)
def thread_list(
self, params: V2ThreadListParams | JsonObject | None = None
) -> ThreadListResponse:
return self.request(
"thread/list", _params_dict(params), response_model=ThreadListResponse
)
def thread_read(self, thread_id: str, include_turns: bool = False) -> ThreadReadResponse:
def thread_read(
self, thread_id: str, include_turns: bool = False
) -> ThreadReadResponse:
return self.request(
"thread/read",
{"threadId": thread_id, "includeTurns": include_turns},
@@ -332,10 +316,18 @@ class AppServerClient:
return self.request("thread/fork", payload, response_model=ThreadForkResponse)
def thread_archive(self, thread_id: str) -> ThreadArchiveResponse:
return self.request("thread/archive", {"threadId": thread_id}, response_model=ThreadArchiveResponse)
return self.request(
"thread/archive",
{"threadId": thread_id},
response_model=ThreadArchiveResponse,
)
def thread_unarchive(self, thread_id: str) -> ThreadUnarchiveResponse:
return self.request("thread/unarchive", {"threadId": thread_id}, response_model=ThreadUnarchiveResponse)
return self.request(
"thread/unarchive",
{"threadId": thread_id},
response_model=ThreadUnarchiveResponse,
)
def thread_set_name(self, thread_id: str, name: str) -> ThreadSetNameResponse:
return self.request(
@@ -362,7 +354,9 @@ class AppServerClient:
"threadId": thread_id,
"input": self._normalize_input_items(input_items),
}
return self.request("turn/start", payload, response_model=TurnStartResponse)
started = self.request("turn/start", payload, response_model=TurnStartResponse)
self.register_turn_notifications(started.turn.id)
return started
def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse:
return self.request(
@@ -412,23 +406,18 @@ class AppServerClient:
)
def wait_for_turn_completed(self, turn_id: str) -> TurnCompletedNotification:
while True:
notification = self.next_notification()
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
return notification.payload
def stream_until_methods(self, methods: Iterable[str] | str) -> list[Notification]:
target_methods = {methods} if isinstance(methods, str) else set(methods)
out: list[Notification] = []
while True:
notification = self.next_notification()
out.append(notification)
if notification.method in target_methods:
return out
self.register_turn_notifications(turn_id)
try:
while True:
notification = self.next_turn_notification(turn_id)
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
return notification.payload
finally:
self.unregister_turn_notifications(turn_id)
def stream_text(
self,
@@ -438,33 +427,41 @@ class AppServerClient:
) -> Iterator[AgentMessageDeltaNotification]:
started = self.turn_start(thread_id, text, params=params)
turn_id = started.turn.id
while True:
notification = self.next_notification()
if (
notification.method == "item/agentMessage/delta"
and isinstance(notification.payload, AgentMessageDeltaNotification)
and notification.payload.turn_id == turn_id
):
yield notification.payload
continue
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
break
self.register_turn_notifications(turn_id)
try:
while True:
notification = self.next_turn_notification(turn_id)
if (
notification.method == "item/agentMessage/delta"
and isinstance(notification.payload, AgentMessageDeltaNotification)
and notification.payload.turn_id == turn_id
):
yield notification.payload
continue
if (
notification.method == "turn/completed"
and isinstance(notification.payload, TurnCompletedNotification)
and notification.payload.turn.id == turn_id
):
break
finally:
self.unregister_turn_notifications(turn_id)
def _coerce_notification(self, method: str, params: object) -> Notification:
params_dict = params if isinstance(params, dict) else {}
model = NOTIFICATION_MODELS.get(method)
if model is None:
return Notification(method=method, payload=UnknownNotification(params=params_dict))
return Notification(
method=method, payload=UnknownNotification(params=params_dict)
)
try:
payload = model.model_validate(params_dict)
except Exception: # noqa: BLE001
return Notification(method=method, payload=UnknownNotification(params=params_dict))
return Notification(
method=method, payload=UnknownNotification(params=params_dict)
)
return Notification(method=method, payload=payload)
def _normalize_input_items(
@@ -477,7 +474,9 @@ class AppServerClient:
return [input_items]
return input_items
def _default_approval_handler(self, method: str, params: JsonObject | None) -> JsonObject:
def _default_approval_handler(
self, method: str, params: JsonObject | None
) -> JsonObject:
if method == "item/commandExecution/requestApproval":
return {"decision": "accept"}
if method == "item/fileChange/requestApproval":
@@ -498,6 +497,32 @@ class AppServerClient:
self._stderr_thread = threading.Thread(target=_drain, daemon=True)
self._stderr_thread.start()
def _start_reader_thread(self) -> None:
if self._proc is None or self._proc.stdout is None:
return
self._reader_thread = threading.Thread(target=self._reader_loop, daemon=True)
self._reader_thread.start()
def _reader_loop(self) -> None:
try:
while True:
msg = self._read_message()
if "method" in msg and "id" in msg:
response = self._handle_server_request(msg)
self._write_message({"id": msg["id"], "result": response})
continue
if "method" in msg and "id" not in msg:
method = msg["method"]
if isinstance(method, str):
self._router.route_notification(
self._coerce_notification(method, msg.get("params"))
)
continue
self._router.route_response(msg)
except BaseException as exc:
self._router.fail_all(exc)
def _stderr_tail(self, limit: int = 40) -> str:
return "\n".join(list(self._stderr_lines)[-limit:])
@@ -130,3 +130,43 @@ NOTIFICATION_MODELS: dict[str, type[BaseModel]] = {
"windows/worldWritableWarning": WindowsWorldWritableWarningNotification,
"windowsSandbox/setupCompleted": WindowsSandboxSetupCompletedNotification,
}
DIRECT_TURN_ID_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = (
AgentMessageDeltaNotification,
CommandExecutionOutputDeltaNotification,
ContextCompactedNotification,
ErrorNotification,
FileChangeOutputDeltaNotification,
FileChangePatchUpdatedNotification,
HookCompletedNotification,
HookStartedNotification,
ItemCompletedNotification,
ItemGuardianApprovalReviewCompletedNotification,
ItemGuardianApprovalReviewStartedNotification,
ItemStartedNotification,
McpToolCallProgressNotification,
ModelReroutedNotification,
ModelVerificationNotification,
PlanDeltaNotification,
ReasoningSummaryPartAddedNotification,
ReasoningSummaryTextDeltaNotification,
ReasoningTextDeltaNotification,
TerminalInteractionNotification,
ThreadGoalUpdatedNotification,
ThreadTokenUsageUpdatedNotification,
TurnDiffUpdatedNotification,
TurnPlanUpdatedNotification,
)
NESTED_TURN_NOTIFICATION_TYPES: tuple[type[BaseModel], ...] = (
TurnCompletedNotification,
TurnStartedNotification,
)
def notification_turn_id(payload: BaseModel) -> str | None:
if isinstance(payload, DIRECT_TURN_ID_NOTIFICATION_TYPES):
return payload.turn_id if isinstance(payload.turn_id, str) else None
if isinstance(payload, NESTED_TURN_NOTIFICATION_TYPES):
return payload.turn.id
return None
+150 -8
View File
@@ -2,11 +2,17 @@ from __future__ import annotations
import asyncio
import time
from types import SimpleNamespace
from codex_app_server.async_client import AsyncAppServerClient
from codex_app_server.generated.v2_all import (
AgentMessageDeltaNotification,
TurnCompletedNotification,
)
from codex_app_server.models import Notification, UnknownNotification
def test_async_client_serializes_transport_calls() -> None:
def test_async_client_allows_concurrent_transport_calls() -> None:
async def scenario() -> int:
client = AsyncAppServerClient()
active = 0
@@ -24,10 +30,10 @@ def test_async_client_serializes_transport_calls() -> None:
await asyncio.gather(client.model_list(), client.model_list())
return max_active
assert asyncio.run(scenario()) == 1
assert asyncio.run(scenario()) == 2
def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
def test_async_stream_text_is_incremental_without_blocking_parallel_calls() -> None:
async def scenario() -> tuple[str, list[str], bool]:
client = AsyncAppServerClient()
@@ -46,19 +52,155 @@ def test_async_stream_text_is_incremental_and_blocks_parallel_calls() -> None:
stream = client.stream_text("thread-1", "hello")
first = await anext(stream)
blocked_before_stream_done = False
competing_call = asyncio.create_task(client.model_list())
await asyncio.sleep(0.01)
blocked_before_stream_done = not competing_call.done()
competing_call_done_before_stream_done = competing_call.done()
remaining: list[str] = []
async for item in stream:
remaining.append(item)
await competing_call
return first, remaining, blocked_before_stream_done
return first, remaining, competing_call_done_before_stream_done
first, remaining, blocked = asyncio.run(scenario())
first, remaining, was_unblocked = asyncio.run(scenario())
assert first == "first"
assert remaining == ["second", "third"]
assert blocked
assert was_unblocked
def test_async_client_turn_notification_methods_delegate_to_sync_client() -> None:
async def scenario() -> tuple[list[tuple[str, str]], Notification, str]:
client = AsyncAppServerClient()
event = Notification(
method="unknown/direct",
payload=UnknownNotification(params={"turnId": "turn-1"}),
)
completed = TurnCompletedNotification.model_validate(
{
"threadId": "thread-1",
"turn": {"id": "turn-1", "items": [], "status": "completed"},
}
)
calls: list[tuple[str, str]] = []
def fake_register(turn_id: str) -> None:
calls.append(("register", turn_id))
def fake_unregister(turn_id: str) -> None:
calls.append(("unregister", turn_id))
def fake_next(turn_id: str) -> Notification:
calls.append(("next", turn_id))
return event
def fake_wait(turn_id: str) -> TurnCompletedNotification:
calls.append(("wait", turn_id))
return completed
client._sync.register_turn_notifications = fake_register # type: ignore[method-assign]
client._sync.unregister_turn_notifications = fake_unregister # type: ignore[method-assign]
client._sync.next_turn_notification = fake_next # type: ignore[method-assign]
client._sync.wait_for_turn_completed = fake_wait # type: ignore[method-assign]
client.register_turn_notifications("turn-1")
next_event = await client.next_turn_notification("turn-1")
completed_event = await client.wait_for_turn_completed("turn-1")
client.unregister_turn_notifications("turn-1")
return calls, next_event, completed_event.turn.id
calls, next_event, completed_turn_id = asyncio.run(scenario())
assert (
calls,
next_event,
completed_turn_id,
) == (
[
("register", "turn-1"),
("next", "turn-1"),
("wait", "turn-1"),
("unregister", "turn-1"),
],
Notification(
method="unknown/direct",
payload=UnknownNotification(params={"turnId": "turn-1"}),
),
"turn-1",
)
def test_async_stream_text_uses_sync_turn_routing() -> None:
async def scenario() -> tuple[list[tuple[str, str]], list[str]]:
client = AsyncAppServerClient()
notifications = [
Notification(
method="item/agentMessage/delta",
payload=AgentMessageDeltaNotification.model_validate(
{
"delta": "first",
"itemId": "item-1",
"threadId": "thread-1",
"turnId": "turn-1",
}
),
),
Notification(
method="item/agentMessage/delta",
payload=AgentMessageDeltaNotification.model_validate(
{
"delta": "second",
"itemId": "item-2",
"threadId": "thread-1",
"turnId": "turn-1",
}
),
),
Notification(
method="turn/completed",
payload=TurnCompletedNotification.model_validate(
{
"threadId": "thread-1",
"turn": {"id": "turn-1", "items": [], "status": "completed"},
}
),
),
]
calls: list[tuple[str, str]] = []
def fake_turn_start(thread_id: str, text: str, *, params=None): # type: ignore[no-untyped-def]
calls.append(("turn_start", thread_id))
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
def fake_register(turn_id: str) -> None:
calls.append(("register", turn_id))
def fake_next(turn_id: str) -> Notification:
calls.append(("next", turn_id))
return notifications.pop(0)
def fake_unregister(turn_id: str) -> None:
calls.append(("unregister", turn_id))
client._sync.turn_start = fake_turn_start # type: ignore[method-assign]
client._sync.register_turn_notifications = fake_register # type: ignore[method-assign]
client._sync.next_turn_notification = fake_next # type: ignore[method-assign]
client._sync.unregister_turn_notifications = fake_unregister # type: ignore[method-assign]
chunks = [chunk async for chunk in client.stream_text("thread-1", "hello")]
return calls, [chunk.delta for chunk in chunks]
calls, deltas = asyncio.run(scenario())
assert (calls, deltas) == (
[
("turn_start", "thread-1"),
("register", "turn-1"),
("next", "turn-1"),
("next", "turn-1"),
("next", "turn-1"),
("unregister", "turn-1"),
],
["first", "second"],
)
+222 -1
View File
@@ -4,13 +4,17 @@ from pathlib import Path
from typing import Any
from codex_app_server.client import AppServerClient, _params_dict
from codex_app_server.generated.notification_registry import notification_turn_id
from codex_app_server.generated.v2_all import (
AgentMessageDeltaNotification,
ApprovalsReviewer,
ThreadListParams,
ThreadResumeResponse,
ThreadTokenUsageUpdatedNotification,
TurnCompletedNotification,
WarningNotification,
)
from codex_app_server.models import UnknownNotification
from codex_app_server.models import Notification, UnknownNotification
ROOT = Path(__file__).resolve().parents[1]
@@ -128,3 +132,220 @@ def test_invalid_notification_payload_falls_back_to_unknown() -> None:
assert event.method == "thread/tokenUsage/updated"
assert isinstance(event.payload, UnknownNotification)
def test_generated_notification_turn_id_handles_known_payload_shapes() -> None:
direct = AgentMessageDeltaNotification.model_validate(
{
"delta": "hello",
"itemId": "item-1",
"threadId": "thread-1",
"turnId": "turn-1",
}
)
nested = TurnCompletedNotification.model_validate(
{
"threadId": "thread-1",
"turn": {"id": "turn-2", "items": [], "status": "completed"},
}
)
unscoped = WarningNotification(message="heads up")
assert [
notification_turn_id(direct),
notification_turn_id(nested),
notification_turn_id(unscoped),
] == ["turn-1", "turn-2", None]
def test_turn_notification_router_demuxes_registered_turns() -> None:
client = AppServerClient()
client.register_turn_notifications("turn-1")
client.register_turn_notifications("turn-2")
client._router.route_notification(
client._coerce_notification(
"item/agentMessage/delta",
{
"delta": "two",
"itemId": "item-2",
"threadId": "thread-1",
"turnId": "turn-2",
},
)
)
client._router.route_notification(
client._coerce_notification(
"item/agentMessage/delta",
{
"delta": "one",
"itemId": "item-1",
"threadId": "thread-1",
"turnId": "turn-1",
},
)
)
first = client.next_turn_notification("turn-1")
second = client.next_turn_notification("turn-2")
assert isinstance(first.payload, AgentMessageDeltaNotification)
assert isinstance(second.payload, AgentMessageDeltaNotification)
assert [
(first.method, first.payload.delta),
(second.method, second.payload.delta),
] == [
("item/agentMessage/delta", "one"),
("item/agentMessage/delta", "two"),
]
def test_client_reader_routes_interleaved_turn_notifications_by_turn_id() -> None:
client = AppServerClient()
client.register_turn_notifications("turn-1")
client.register_turn_notifications("turn-2")
messages: list[dict[str, object]] = [
{
"method": "item/agentMessage/delta",
"params": {
"delta": "one-a",
"itemId": "item-1",
"threadId": "thread-1",
"turnId": "turn-1",
},
},
{
"method": "item/agentMessage/delta",
"params": {
"delta": "two-a",
"itemId": "item-2",
"threadId": "thread-1",
"turnId": "turn-2",
},
},
{
"method": "item/agentMessage/delta",
"params": {
"delta": "one-b",
"itemId": "item-3",
"threadId": "thread-1",
"turnId": "turn-1",
},
},
{
"method": "item/agentMessage/delta",
"params": {
"delta": "two-b",
"itemId": "item-4",
"threadId": "thread-1",
"turnId": "turn-2",
},
},
]
def fake_read_message() -> dict[str, object]:
if messages:
return messages.pop(0)
raise EOFError
client._read_message = fake_read_message # type: ignore[method-assign]
client._reader_loop()
first_turn_events = [
client.next_turn_notification("turn-1"),
client.next_turn_notification("turn-1"),
]
second_turn_events = [
client.next_turn_notification("turn-2"),
client.next_turn_notification("turn-2"),
]
first_turn_deltas = [
event.payload.delta
for event in first_turn_events
if isinstance(event.payload, AgentMessageDeltaNotification)
]
second_turn_deltas = [
event.payload.delta
for event in second_turn_events
if isinstance(event.payload, AgentMessageDeltaNotification)
]
assert (first_turn_deltas, second_turn_deltas) == (
["one-a", "one-b"],
["two-a", "two-b"],
)
def test_turn_notification_router_buffers_events_before_registration() -> None:
client = AppServerClient()
client._router.route_notification(
client._coerce_notification(
"item/agentMessage/delta",
{
"delta": "early",
"itemId": "item-1",
"threadId": "thread-1",
"turnId": "turn-1",
},
)
)
client.register_turn_notifications("turn-1")
event = client.next_turn_notification("turn-1")
assert isinstance(event.payload, AgentMessageDeltaNotification)
assert (event.method, event.payload.delta) == (
"item/agentMessage/delta",
"early",
)
def test_turn_notification_router_clears_unregistered_turn_when_completed() -> None:
client = AppServerClient()
client._router.route_notification(
client._coerce_notification(
"item/agentMessage/delta",
{
"delta": "early",
"itemId": "item-1",
"threadId": "thread-1",
"turnId": "turn-1",
},
)
)
client._router.route_notification(
client._coerce_notification(
"turn/completed",
{
"threadId": "thread-1",
"turn": {"id": "turn-1", "items": [], "status": "completed"},
},
)
)
assert client._router._pending_turn_notifications == {}
def test_turn_notification_router_routes_unknown_turn_notifications() -> None:
client = AppServerClient()
client.register_turn_notifications("turn-1")
client.register_turn_notifications("turn-2")
client._router.route_notification(
Notification(
method="unknown/direct",
payload=UnknownNotification(params={"turnId": "turn-1"}),
)
)
client._router.route_notification(
Notification(
method="unknown/nested",
payload=UnknownNotification(params={"turn": {"id": "turn-2"}}),
)
)
first = client.next_turn_notification("turn-1")
second = client.next_turn_notification("turn-2")
assert [first.method, second.method] == ["unknown/direct", "unknown/nested"]
@@ -226,54 +226,74 @@ def test_async_codex_initializes_only_once_under_concurrency() -> None:
asyncio.run(scenario())
def test_turn_stream_rejects_second_active_consumer() -> None:
def test_turn_streams_can_consume_multiple_turns_on_one_client() -> None:
client = AppServerClient()
notifications: deque[Notification] = deque(
[
_delta_notification(turn_id="turn-1"),
_completed_notification(turn_id="turn-1"),
]
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
notifications: dict[str, deque[Notification]] = {
"turn-1": deque(
[
_delta_notification(turn_id="turn-1", text="one"),
_completed_notification(turn_id="turn-1"),
]
),
"turn-2": deque(
[
_delta_notification(turn_id="turn-2", text="two"),
_completed_notification(turn_id="turn-2"),
]
),
}
client.next_turn_notification = lambda turn_id: notifications[turn_id].popleft() # type: ignore[method-assign]
first_stream = TurnHandle(client, "thread-1", "turn-1").stream()
assert next(first_stream).method == "item/agentMessage/delta"
second_stream = TurnHandle(client, "thread-1", "turn-2").stream()
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
next(second_stream)
assert next(second_stream).method == "item/agentMessage/delta"
assert next(first_stream).method == "turn/completed"
assert next(second_stream).method == "turn/completed"
first_stream.close()
second_stream.close()
def test_async_turn_stream_rejects_second_active_consumer() -> None:
def test_async_turn_streams_can_consume_multiple_turns_on_one_client() -> None:
async def scenario() -> None:
codex = AsyncCodex()
async def fake_ensure_initialized() -> None:
return None
notifications: deque[Notification] = deque(
[
_delta_notification(turn_id="turn-1"),
_completed_notification(turn_id="turn-1"),
]
)
notifications: dict[str, deque[Notification]] = {
"turn-1": deque(
[
_delta_notification(turn_id="turn-1", text="one"),
_completed_notification(turn_id="turn-1"),
]
),
"turn-2": deque(
[
_delta_notification(turn_id="turn-2", text="two"),
_completed_notification(turn_id="turn-2"),
]
),
}
async def fake_next_notification() -> Notification:
return notifications.popleft()
async def fake_next_notification(turn_id: str) -> Notification:
return notifications[turn_id].popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign]
first_stream = AsyncTurnHandle(codex, "thread-1", "turn-1").stream()
assert (await anext(first_stream)).method == "item/agentMessage/delta"
second_stream = AsyncTurnHandle(codex, "thread-1", "turn-2").stream()
with pytest.raises(RuntimeError, match="Concurrent turn consumers are not yet supported"):
await anext(second_stream)
assert (await anext(second_stream)).method == "item/agentMessage/delta"
assert (await anext(first_stream)).method == "turn/completed"
assert (await anext(second_stream)).method == "turn/completed"
await first_stream.aclose()
await second_stream.aclose()
asyncio.run(scenario())
@@ -285,7 +305,7 @@ def test_turn_run_returns_completed_turn_payload() -> None:
_completed_notification(),
]
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
result = TurnHandle(client, "thread-1", "turn-1").run()
@@ -305,7 +325,7 @@ def test_thread_run_accepts_string_input_and_returns_run_result() -> None:
_completed_notification(),
]
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
seen: dict[str, object] = {}
def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202
@@ -338,7 +358,7 @@ def test_thread_run_uses_last_completed_assistant_message_as_final_response() ->
_completed_notification(),
]
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -363,7 +383,7 @@ def test_thread_run_preserves_empty_last_assistant_message() -> None:
_completed_notification(),
]
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -394,7 +414,7 @@ def test_thread_run_prefers_explicit_final_answer_over_later_commentary() -> Non
_completed_notification(),
]
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -420,7 +440,7 @@ def test_thread_run_returns_none_when_only_commentary_messages_complete() -> Non
_completed_notification(),
]
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -438,7 +458,7 @@ def test_thread_run_raises_on_failed_turn() -> None:
_completed_notification(status="failed", error_message="boom"),
]
)
client.next_notification = notifications.popleft # type: ignore[method-assign]
client.next_turn_notification = lambda _turn_id: notifications.popleft() # type: ignore[method-assign]
client.turn_start = lambda thread_id, wire_input, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
@@ -447,6 +467,48 @@ def test_thread_run_raises_on_failed_turn() -> None:
Thread(client, "thread-1").run("hello")
def test_stream_text_registers_and_consumes_turn_notifications() -> None:
client = AppServerClient()
notifications: deque[Notification] = deque(
[
_delta_notification(text="first"),
_delta_notification(text="second"),
_completed_notification(),
]
)
calls: list[tuple[str, str]] = []
client.turn_start = lambda thread_id, input_items, *, params=None: SimpleNamespace( # noqa: ARG005,E731
turn=SimpleNamespace(id="turn-1")
)
def fake_register(turn_id: str) -> None:
calls.append(("register", turn_id))
def fake_next(turn_id: str) -> Notification:
calls.append(("next", turn_id))
return notifications.popleft()
def fake_unregister(turn_id: str) -> None:
calls.append(("unregister", turn_id))
client.register_turn_notifications = fake_register # type: ignore[method-assign]
client.next_turn_notification = fake_next # type: ignore[method-assign]
client.unregister_turn_notifications = fake_unregister # type: ignore[method-assign]
chunks = list(client.stream_text("thread-1", "hello"))
assert ([chunk.delta for chunk in chunks], calls) == (
["first", "second"],
[
("register", "turn-1"),
("next", "turn-1"),
("next", "turn-1"),
("next", "turn-1"),
("unregister", "turn-1"),
],
)
def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
async def scenario() -> None:
codex = AsyncCodex()
@@ -471,12 +533,12 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
seen["params"] = params
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification:
async def fake_next_notification(_turn_id: str) -> Notification:
return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello")
@@ -491,15 +553,21 @@ def test_async_thread_run_accepts_string_input_and_returns_run_result() -> None:
asyncio.run(scenario())
def test_async_thread_run_uses_last_completed_assistant_message_as_final_response() -> None:
def test_async_thread_run_uses_last_completed_assistant_message_as_final_response() -> (
None
):
async def scenario() -> None:
codex = AsyncCodex()
async def fake_ensure_initialized() -> None:
return None
first_item_notification = _item_completed_notification(text="First async message")
second_item_notification = _item_completed_notification(text="Second async message")
first_item_notification = _item_completed_notification(
text="First async message"
)
second_item_notification = _item_completed_notification(
text="Second async message"
)
notifications: deque[Notification] = deque(
[
first_item_notification,
@@ -511,12 +579,12 @@ def test_async_thread_run_uses_last_completed_assistant_message_as_final_respons
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification:
async def fake_next_notification(_turn_id: str) -> Notification:
return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello")
@@ -550,12 +618,12 @@ def test_async_thread_run_returns_none_when_only_commentary_messages_complete()
async def fake_turn_start(thread_id: str, wire_input: object, *, params=None): # noqa: ANN001,ANN202,ARG001
return SimpleNamespace(turn=SimpleNamespace(id="turn-1"))
async def fake_next_notification() -> Notification:
async def fake_next_notification(_turn_id: str) -> Notification:
return notifications.popleft()
codex._ensure_initialized = fake_ensure_initialized # type: ignore[method-assign]
codex._client.turn_start = fake_turn_start # type: ignore[method-assign]
codex._client.next_notification = fake_next_notification # type: ignore[method-assign]
codex._client.next_turn_notification = fake_next_notification # type: ignore[method-assign]
result = await AsyncThread(codex, "thread-1").run("hello")