diff --git a/sdk/python/src/openai_codex/_goal.py b/sdk/python/src/openai_codex/_goal.py index dc1fe358c..37a63fc28 100644 --- a/sdk/python/src/openai_codex/_goal.py +++ b/sdk/python/src/openai_codex/_goal.py @@ -1,8 +1,12 @@ +import asyncio import queue import threading import time +from collections import deque from dataclasses import dataclass, field +from typing import AsyncIterator, Awaitable, Callable, Iterator +from .generated.notification_registry import notification_turn_id from .generated.v2_all import ( ThreadGoalClearedNotification, ThreadGoalStatus, @@ -10,8 +14,9 @@ from .generated.v2_all import ( Turn, TurnCompletedNotification, TurnStartedNotification, + TurnStatus, ) -from .models import Notification +from .models import Notification, UnknownNotification class _GoalStreamClosed(Exception): @@ -45,10 +50,16 @@ class _GoalOperationState: _notifications: queue.Queue[Notification | BaseException] = field(default_factory=queue.Queue) _failure: BaseException | None = None _finished: bool = False + _turn_routing_active: bool = False - def observe(self, notification: Notification) -> None: + def observe(self, notification: Notification) -> bool: payload = notification.payload with self._condition: + if not self._turn_routing_active and not isinstance( + payload, + ThreadGoalClearedNotification | ThreadGoalUpdatedNotification, + ): + return False if isinstance(payload, TurnStartedNotification): if self.logical_turn_id is None: self.logical_turn_id = payload.turn.id @@ -73,6 +84,12 @@ class _GoalOperationState: self._finished = True self._condition.notify_all() self._notifications.put(notification) + return True + + def activate_turn_routing(self) -> None: + """Accept physical turns after the previous stored goal is cleared.""" + with self._condition: + self._turn_routing_active = True def wait_for_start(self, timeout: float) -> str | None: """Wait for the runtime-generated first turn without consuming its event.""" @@ -121,16 +138,18 @@ class _GoalOperationState: with self._condition: self.interrupted = True self.interrupt_requested = False + self._condition.notify_all() def cancel_interrupt(self) -> None: with self._condition: self.interrupt_requested = False + self._condition.notify_all() - def explicit_interrupt(self, status: ThreadGoalStatus | None) -> bool: + def explicit_interrupt(self) -> bool: with self._condition: - return self.interrupted or ( - self.interrupt_requested and status == ThreadGoalStatus.paused - ) + while self.interrupt_requested: + self._condition.wait() + return self.interrupted def active_turn(self, *, after: str | None = None) -> str | None: """Wait for the current turn, or return None once the goal has ended.""" @@ -151,6 +170,279 @@ class _GoalOperationState: with self._condition: return self.current_turn_id + def resolve_active_turn(self, expected: str, active: str) -> None: + """Adopt a server-reported active id when routed state is still stale.""" + with self._condition: + if self.current_turn_id in {None, expected}: + self.current_turn_id = active + self._condition.notify_all() + + def turn_for_interrupt(self) -> str | None: + """Return an active or stale turn id that can resolve rollover races.""" + with self._condition: + if self.current_turn_id is not None: + return self.current_turn_id + if self.completed_turn is not None: + return self.completed_turn.id + if self.started_turn is not None: + return self.started_turn.id + return None + def wake_notification_reader(self) -> None: """Release a reader blocked after its stream has been closed.""" self._notifications.put(_GoalStreamClosed()) + + +def _logical_notification(notification: Notification, logical_turn_id: str) -> Notification: + """Return a copy whose turn metadata uses the logical operation id.""" + payload = notification.payload + if isinstance(payload, UnknownNotification): + params = dict(payload.params) + if isinstance(params.get("turnId"), str): + params["turnId"] = logical_turn_id + turn = params.get("turn") + if isinstance(turn, dict) and isinstance(turn.get("id"), str): + params["turn"] = {**turn, "id": logical_turn_id} + return Notification(notification.method, UnknownNotification(params)) + + turn_id = notification_turn_id(payload) + if turn_id is None: + return notification + if hasattr(payload, "turn_id"): + return Notification( + notification.method, + payload.model_copy(update={"turn_id": logical_turn_id}), + ) + if hasattr(payload, "turn"): + logical_turn = payload.turn.model_copy(update={"id": logical_turn_id}) + return Notification( + notification.method, + payload.model_copy(update={"turn": logical_turn}), + ) + return notification + + +def _logical_completion( + completed: TurnCompletedNotification, + *, + logical_turn_id: str, + started: Turn | None, + interrupted: bool, +) -> TurnCompletedNotification: + """Coalesce the final physical completion into one logical completion.""" + final_turn = completed.turn + started_at = started.started_at if started is not None else final_turn.started_at + duration_ms = final_turn.duration_ms + if started_at is not None and final_turn.completed_at is not None: + duration_ms = max(0, final_turn.completed_at - started_at) * 1000 + updates: dict[str, object] = { + "id": logical_turn_id, + "started_at": started_at, + "duration_ms": duration_ms, + } + if interrupted: + updates["status"] = TurnStatus.interrupted + return completed.model_copy(update={"turn": final_turn.model_copy(update=updates)}) + + +@dataclass(slots=True) +class _GoalStreamCursor: + """Consume physical goal events as one ordered logical turn stream.""" + + state: _GoalOperationState + started: Turn | None = None + last_completed: TurnCompletedNotification | None = None + failed_completion: TurnCompletedNotification | None = None + status: ThreadGoalStatus | None = None + active: bool = False + cleared: bool = False + + def process(self, notification: Notification) -> tuple[list[Notification], bool]: + logical_turn_id = self.state.logical_turn_id + if logical_turn_id is None: + raise RuntimeError("goal operation has not been bound to a logical turn id") + + payload = notification.payload + if isinstance(payload, TurnStartedNotification): + self.active = True + if self.started is not None: + return [], False + self.started = payload.turn + return [_logical_notification(notification, logical_turn_id)], False + + if isinstance(payload, TurnCompletedNotification): + self.active = False + self.last_completed = payload + if payload.turn.status == TurnStatus.interrupted: + return [ + self._completion( + notification.method, + self.failed_completion or payload, + ) + ], True + if payload.turn.status == TurnStatus.failed: + self.failed_completion = payload + if self.cleared or _terminal_goal_status(self.status): + self.state.finish() + return [self._completion(notification.method, payload)], True + return [], False + if self.status is None and not self.cleared: + raise RuntimeError( + "the connected Codex runtime did not activate goal mode for this turn" + ) + if self.cleared or _terminal_goal_status(self.status): + self.state.finish() + return [ + self._completion( + notification.method, + self.failed_completion or payload, + ) + ], True + return [], False + + events = [_logical_notification(notification, logical_turn_id)] + if isinstance(payload, ThreadGoalUpdatedNotification): + self.status = payload.goal.status + if self.status == ThreadGoalStatus.active: + self.cleared = False + events = [] + elif isinstance(payload, ThreadGoalClearedNotification): + self.cleared = True + events = [] + + if ( + not self.active + and self.last_completed is not None + and (self.cleared or _terminal_goal_status(self.status)) + ): + self.state.finish() + events.append( + self._completion( + "turn/completed", + self.failed_completion or self.last_completed, + ) + ) + return events, True + return events, False + + def _completion( + self, + method: str, + payload: TurnCompletedNotification, + ) -> Notification: + logical_turn_id = self.state.logical_turn_id + if logical_turn_id is None: + raise RuntimeError("goal operation has not been bound to a logical turn id") + return Notification( + method, + _logical_completion( + payload, + logical_turn_id=logical_turn_id, + started=self.started, + interrupted=self.state.explicit_interrupt(), + ), + ) + + +@dataclass(slots=True) +class _GoalNotificationStream(Iterator[Notification]): + """Closeable synchronous view of one logical goal operation.""" + + state: _GoalOperationState + next_notification: Callable[[], Notification] + unregister: Callable[[], None] + cancel_goal: Callable[[], None] + _cursor: _GoalStreamCursor = field(init=False) + _pending: deque[Notification] = field(default_factory=deque) + _closed: bool = False + + def __post_init__(self) -> None: + self._cursor = _GoalStreamCursor(self.state) + + def __iter__(self) -> "_GoalNotificationStream": + return self + + def __next__(self) -> Notification: + if self._closed: + raise StopIteration + try: + while not self._pending: + notification = self.next_notification() + events, completed = self._cursor.process(notification) + self._pending.extend(events) + if completed: + self._finish() + return self._pending.popleft() + except _GoalStreamClosed: + self.close() + raise StopIteration from None + except KeyboardInterrupt: + self.cancel_goal() + self.close() + raise + except BaseException: + self.close() + raise + + def _finish(self) -> None: + if self._closed: + return + self.state.finish() + self.state.wake_notification_reader() + self.unregister() + self._closed = True + + def close(self) -> None: + self._finish() + + +@dataclass(slots=True) +class _AsyncGoalNotificationStream(AsyncIterator[Notification]): + """Closeable asynchronous view of one logical goal operation.""" + + state: _GoalOperationState + next_notification: Callable[[], Awaitable[Notification]] + unregister: Callable[[], None] + cancel_goal: Callable[[], Awaitable[None]] + _cursor: _GoalStreamCursor = field(init=False) + _pending: deque[Notification] = field(default_factory=deque) + _closed: bool = False + + def __post_init__(self) -> None: + self._cursor = _GoalStreamCursor(self.state) + + def __aiter__(self) -> "_AsyncGoalNotificationStream": + return self + + async def __anext__(self) -> Notification: + if self._closed: + raise StopAsyncIteration + try: + while not self._pending: + notification = await self.next_notification() + events, completed = self._cursor.process(notification) + self._pending.extend(events) + if completed: + self._finish() + return self._pending.popleft() + except _GoalStreamClosed: + await self.aclose() + raise StopAsyncIteration from None + except asyncio.CancelledError: + await self.cancel_goal() + await self.aclose() + raise + except BaseException: + await self.aclose() + raise + + def _finish(self) -> None: + if self._closed: + return + self.state.finish() + self.state.wake_notification_reader() + self.unregister() + self._closed = True + + async def aclose(self) -> None: + self._finish() diff --git a/sdk/python/src/openai_codex/_message_router.py b/sdk/python/src/openai_codex/_message_router.py index e71544a0e..c979c8c8d 100644 --- a/sdk/python/src/openai_codex/_message_router.py +++ b/sdk/python/src/openai_codex/_message_router.py @@ -121,10 +121,20 @@ class MessageRouter: def register_goal(self, thread_id: str) -> _GoalOperationState: """Register one thread-scoped logical goal operation before it starts.""" state = _GoalOperationState(thread_id=thread_id) + state.activate_turn_routing() + return self._register_goal(state) + + def reserve_goal(self, thread_id: str) -> _GoalOperationState: + """Reserve a thread route without accepting physical turns yet.""" + return self._register_goal(_GoalOperationState(thread_id=thread_id)) + + def _register_goal(self, state: _GoalOperationState) -> _GoalOperationState: with self._lock: - if thread_id in self._goal_operations: - raise RuntimeError(f"thread {thread_id!r} already has an active goal operation") - self._goal_operations[thread_id] = state + if state.thread_id in self._goal_operations: + raise RuntimeError( + f"thread {state.thread_id!r} already has an active goal operation" + ) + self._goal_operations[state.thread_id] = state return state def unregister_goal(self, state: _GoalOperationState) -> None: @@ -133,6 +143,11 @@ class MessageRouter: if self._goal_operations.get(state.thread_id) is state: self._goal_operations.pop(state.thread_id) + def has_goal(self, thread_id: str) -> bool: + """Return whether a logical goal operation owns this thread route.""" + with self._lock: + return thread_id in self._goal_operations + def route_response(self, msg: dict[str, JsonValue]) -> None: """Deliver a JSON-RPC response or error to its request waiter.""" @@ -181,10 +196,10 @@ class MessageRouter: if goal_state is not None and ( turn_id is not None or notification.method.startswith("thread/goal/") ): - goal_state.observe(notification) - if goal_state.is_finished(): - self.unregister_goal(goal_state) - return + if goal_state.observe(notification): + if goal_state.is_finished(): + self.unregister_goal(goal_state) + return if turn_id is None: self._global_notifications.put(notification) return diff --git a/sdk/python/src/openai_codex/async_client.py b/sdk/python/src/openai_codex/async_client.py index ff12f1688..07e7d0d40 100644 --- a/sdk/python/src/openai_codex/async_client.py +++ b/sdk/python/src/openai_codex/async_client.py @@ -1,7 +1,9 @@ from __future__ import annotations import asyncio +import threading from collections.abc import Iterator +from concurrent.futures import Future from typing import AsyncIterator, Callable, ParamSpec, TypeVar from pydantic import BaseModel @@ -227,6 +229,58 @@ class AsyncCodexClient: """Pause the active goal through the wrapped sync client.""" return await self._call_sync(self._sync.pause_goal, thread_id) + async def cancel_goal_operation(self, state: _GoalOperationState) -> None: + """Stop continuation work after a logical goal operation is cancelled.""" + await self._call_sync(self._sync.cancel_goal_operation, state) + + async def start_goal_operation( + self, + thread_id: str, + objective: str, + ) -> tuple[_GoalOperationState, str]: + """Start a logical goal through the wrapped sync client.""" + operation: Future[tuple[_GoalOperationState, str]] = Future() + + def start_operation() -> None: + try: + operation.set_result(self._sync.start_goal_operation(thread_id, objective)) + except BaseException as exc: + operation.set_exception(exc) + + worker = threading.Thread( + target=start_operation, + name="codex-goal-start", + daemon=True, + ) + worker.start() + try: + return await asyncio.shield(asyncio.wrap_future(operation)) + except asyncio.CancelledError: + + def cleanup_cancelled_start( + completed: Future[tuple[_GoalOperationState, str]], + ) -> None: + try: + state, _ = completed.result() + except BaseException: + return + + def stop_cancelled_goal() -> None: + try: + self._sync.cancel_goal_operation(state) + finally: + state.finish() + self._sync.unregister_goal_operation(state) + + threading.Thread( + target=stop_cancelled_goal, + name="codex-goal-start-cleanup", + daemon=True, + ).start() + + operation.add_done_callback(cleanup_cancelled_start) + raise + async def turn_start( self, thread_id: str, diff --git a/sdk/python/src/openai_codex/client.py b/sdk/python/src/openai_codex/client.py index 951f4f95f..ab5390ae5 100644 --- a/sdk/python/src/openai_codex/client.py +++ b/sdk/python/src/openai_codex/client.py @@ -1,10 +1,13 @@ import json import os +import re import subprocess import threading import uuid +from _thread import LockType from collections import deque -from dataclasses import dataclass +from contextlib import contextmanager +from dataclasses import dataclass, field from pathlib import Path from typing import Callable, Iterator, TypeVar @@ -13,7 +16,7 @@ from pydantic import BaseModel from ._goal import _GoalOperationState from ._message_router import MessageRouter from ._version import __version__ as SDK_VERSION -from .errors import CodexError, TransportClosedError +from .errors import CodexError, InvalidRequestError, TransportClosedError from .generated.notification_registry import NOTIFICATION_MODELS from .generated.v2_all import ( AccountLoginCompletedNotification, @@ -23,6 +26,7 @@ from .generated.v2_all import ( ChatgptLoginAccountResponse, GetAccountParams as V2GetAccountParams, GetAccountResponse, + IdleThreadStatus, LoginAccountParams as V2LoginAccountParams, LoginAccountResponse, LogoutAccountResponse, @@ -61,6 +65,18 @@ from .retry import retry_on_overload ModelT = TypeVar("ModelT", bound=BaseModel) ApprovalHandler = Callable[[str, JsonObject | None], JsonObject] RUNTIME_PKG_NAME = "openai-codex-cli-bin" +_GOAL_START_TIMEOUT_S = 30.0 + + +@dataclass(slots=True) +class _ThreadStartLock: + lock: LockType = field(default_factory=threading.Lock) + users: int = 0 + + +def _active_turn_id_from_error(exc: InvalidRequestError) -> str | None: + match = re.search(r" but found `?([^`]+)`?$", exc.message) + return match.group(1) if match is not None else None def _params_dict( @@ -205,6 +221,8 @@ class CodexClient: self._approval_handler = approval_handler or self._default_approval_handler self._proc: subprocess.Popen[str] | None = None self._lock = threading.Lock() + self._thread_start_locks_guard = threading.Lock() + self._thread_start_locks: dict[str, _ThreadStartLock] = {} self._router = MessageRouter() self._stderr_lines: deque[str] = deque(maxlen=400) self._stderr_thread: threading.Thread | None = None @@ -360,6 +378,10 @@ class CodexClient: """Register a private thread-scoped route for a logical goal turn.""" return self._router.register_goal(thread_id) + def reserve_goal_operation(self, thread_id: str) -> _GoalOperationState: + """Reserve a private thread route before replacing its stored goal.""" + return self._router.reserve_goal(thread_id) + def unregister_goal_operation(self, state: _GoalOperationState) -> None: """Release routing state for one logical goal turn.""" self._router.unregister_goal(state) @@ -499,6 +521,84 @@ class CodexClient: """Pause the active goal used by a logical goal turn.""" return self.thread_goal_set(thread_id, status=ThreadGoalStatus.paused) + def cancel_goal_operation(self, state: _GoalOperationState) -> None: + """Best-effort cleanup after a logical goal operation is cancelled.""" + try: + self.pause_goal(state.thread_id) + except Exception: + pass + self._interrupt_goal_operation(state) + + def _interrupt_goal_operation(self, state: _GoalOperationState) -> None: + turn_id = state.turn_for_interrupt() + if turn_id is None: + return + try: + self.turn_interrupt(state.thread_id, turn_id) + except InvalidRequestError as exc: + if not exc.message.startswith("expected active turn id"): + return + next_turn_id = _active_turn_id_from_error(exc) or state.current_turn() + if next_turn_id is None or next_turn_id == turn_id: + return + try: + self.turn_interrupt(state.thread_id, next_turn_id) + except Exception: + pass + except Exception: + pass + + def start_goal_operation( + self, + thread_id: str, + objective: str, + ) -> tuple[_GoalOperationState, str]: + """Start a logical goal and wait for its runtime-generated first turn.""" + with self._thread_start_lock(thread_id): + return self._start_goal_operation(thread_id, objective) + + def _start_goal_operation( + self, + thread_id: str, + objective: str, + ) -> tuple[_GoalOperationState, str]: + thread = self.thread_read(thread_id).thread + if not isinstance(thread.status.root, IdleThreadStatus): + raise InvalidRequestError( + -32600, + f"thread must be idle before starting a goal: {thread_id}", + ) + if thread.ephemeral or thread.path is None: + raise InvalidRequestError( + -32600, + f"thread must be persisted before starting a goal: {thread_id}", + ) + + state = self.reserve_goal_operation(thread_id) + activated = False + try: + self.thread_goal_clear(thread_id) + state.activate_turn_routing() + self.thread_goal_set( + thread_id, + objective=objective, + status=ThreadGoalStatus.active, + ) + activated = True + turn_id = state.wait_for_start(_GOAL_START_TIMEOUT_S) + if turn_id is None: + raise CodexError( + "timed out waiting for goal turn to start after " + f"{int(_GOAL_START_TIMEOUT_S)} seconds" + ) + return state, turn_id + except BaseException as exc: + if activated or not isinstance(exc, InvalidRequestError): + self.cancel_goal_operation(state) + state.finish() + self.unregister_goal_operation(state) + raise + def turn_start( self, thread_id: str, @@ -506,14 +606,37 @@ class CodexClient: params: V2TurnStartParams | JsonObject | None = None, ) -> TurnStartResponse: """Start a turn and register its notification queue as early as possible.""" - payload = { - **_params_dict(params), - "threadId": thread_id, - "input": self._normalize_input_items(input_items), - } - started = self.request("turn/start", payload, response_model=TurnStartResponse) - self.register_turn_notifications(started.turn.id) - return started + with self._thread_start_lock(thread_id): + if self._router.has_goal(thread_id): + raise InvalidRequestError( + -32600, + f"thread has an active goal operation: {thread_id}", + ) + payload = { + **_params_dict(params), + "threadId": thread_id, + "input": self._normalize_input_items(input_items), + } + started = self.request("turn/start", payload, response_model=TurnStartResponse) + self.register_turn_notifications(started.turn.id) + return started + + @contextmanager + def _thread_start_lock(self, thread_id: str) -> Iterator[None]: + with self._thread_start_locks_guard: + entry = self._thread_start_locks.get(thread_id) + if entry is None: + entry = _ThreadStartLock() + self._thread_start_locks[thread_id] = entry + entry.users += 1 + try: + with entry.lock: + yield + finally: + with self._thread_start_locks_guard: + entry.users -= 1 + if entry.users == 0: + self._thread_start_locks.pop(thread_id, None) def turn_interrupt(self, thread_id: str, turn_id: str) -> TurnInterruptResponse: return self.request( diff --git a/sdk/python/tests/test_app_server_goal_operations.py b/sdk/python/tests/test_app_server_goal_operations.py new file mode 100644 index 000000000..c280c7522 --- /dev/null +++ b/sdk/python/tests/test_app_server_goal_operations.py @@ -0,0 +1,90 @@ +from app_server_harness import ( + AppServerHarness, + ev_assistant_message, + ev_completed, + ev_function_call, + ev_response_created, + sse, +) +from app_server_helpers import agent_message_texts + +from openai_codex import Codex +from openai_codex._goal import _GoalNotificationStream +from openai_codex._run import _collect_turn_result +from openai_codex.generated.notification_registry import notification_turn_id +from openai_codex.generated.v2_all import TurnStatus + + +def test_private_goal_operation_coalesces_runtime_continuations(tmp_path) -> None: + """The private engine should expose automatic continuations as one turn.""" + with AppServerHarness(tmp_path) as harness: + harness.responses.enqueue_assistant_message( + "Initial pass complete.", + response_id="goal-initial", + ) + harness.responses.enqueue_sse( + sse( + [ + ev_response_created("goal-complete-tool"), + ev_function_call( + "call-goal-complete", + "update_goal", + '{"status":"complete"}', + ), + ev_completed("goal-complete-tool"), + ] + ) + ) + harness.responses.enqueue_sse( + sse( + [ + ev_response_created("goal-final"), + ev_assistant_message("msg-goal-final", "Goal complete."), + ev_completed("goal-final"), + ] + ) + ) + + with Codex(config=harness.app_server_config()) as codex: + thread = codex.thread_start() + state, turn_id = codex._client.start_goal_operation( # noqa: SLF001 + thread.id, + "Improve benchmark coverage", + ) + stream = _GoalNotificationStream( + state, + state.next_notification, + lambda: codex._client.unregister_goal_operation(state), # noqa: SLF001 + lambda: codex._client.cancel_goal_operation(state), # noqa: SLF001 + ) + events = list(stream) + result = _collect_turn_result(iter(events), turn_id=turn_id) + routes = codex._client._router._goal_operations.copy() # noqa: SLF001 + requests = harness.responses.wait_for_requests(3) + + lifecycle = [event.method for event in events if event.method.startswith("turn/")] + routed_ids = [ + routed_id + for event in events + if (routed_id := notification_turn_id(event.payload)) is not None + ] + assert { + "lifecycle": lifecycle, + "routed_ids": routed_ids, + "result": (result.id, result.status, result.final_response), + "messages": agent_message_texts(events), + "request_count": len(requests), + "objective_reached_model": ( + "\nImprove benchmark coverage\n" + in "\n".join(requests[0].message_input_texts("user")) + ), + "routes_after_completion": routes, + } == { + "lifecycle": ["turn/started", "turn/completed"], + "routed_ids": [turn_id] * len(routed_ids), + "result": (turn_id, TurnStatus.completed, "Goal complete."), + "messages": ["Initial pass complete.", "Goal complete."], + "request_count": 3, + "objective_reached_model": True, + "routes_after_completion": {}, + }