mirror of
https://github.com/pchuan98/codex.git
synced 2026-07-01 00:31:56 +08:00
[1/4] Add Python goal routing foundation (#27110)
## Why Goal continuation turns are emitted by the existing runtime as separate physical turns. The Python SDK needs private thread-scoped routing before it can present those notifications as one logical operation, without changing ordinary turn routing or the app-server protocol. ## What - add private goal operation state and thread-scoped notification routing - add internal wrappers for the existing `thread/goal/clear` and `thread/goal/set` RPCs - include existing goal notifications in the SDK notification union - preserve ordinary turn-ID routing unchanged - add focused routing coverage This PR does not expose a public goal API. It is the first PR in the Python goal operations stack. ## Test plan - online CI, including the Python SDK suite - focused typed-notification routing coverage
This commit is contained in:
committed by
GitHub
Unverified
parent
8e69d29521
commit
5a0f913426
@@ -0,0 +1,156 @@
|
||||
import queue
|
||||
import threading
|
||||
import time
|
||||
from dataclasses import dataclass, field
|
||||
|
||||
from .generated.v2_all import (
|
||||
ThreadGoalClearedNotification,
|
||||
ThreadGoalStatus,
|
||||
ThreadGoalUpdatedNotification,
|
||||
Turn,
|
||||
TurnCompletedNotification,
|
||||
TurnStartedNotification,
|
||||
)
|
||||
from .models import Notification
|
||||
|
||||
|
||||
class _GoalStreamClosed(Exception):
|
||||
"""Wake a notification reader after its logical stream closes."""
|
||||
|
||||
|
||||
def _terminal_goal_status(status: ThreadGoalStatus | None) -> bool:
|
||||
return status in {
|
||||
ThreadGoalStatus.paused,
|
||||
ThreadGoalStatus.blocked,
|
||||
ThreadGoalStatus.usage_limited,
|
||||
ThreadGoalStatus.budget_limited,
|
||||
ThreadGoalStatus.complete,
|
||||
}
|
||||
|
||||
|
||||
@dataclass(slots=True)
|
||||
class _GoalOperationState:
|
||||
"""Private state for one goal operation exposed as a logical turn."""
|
||||
|
||||
thread_id: str
|
||||
logical_turn_id: str | None = None
|
||||
current_turn_id: str | None = None
|
||||
status: ThreadGoalStatus | None = None
|
||||
started_turn: Turn | None = None
|
||||
completed_turn: Turn | None = None
|
||||
interrupted: bool = False
|
||||
interrupt_requested: bool = False
|
||||
cleared: bool = False
|
||||
_condition: threading.Condition = field(default_factory=threading.Condition)
|
||||
_notifications: queue.Queue[Notification | BaseException] = field(default_factory=queue.Queue)
|
||||
_failure: BaseException | None = None
|
||||
_finished: bool = False
|
||||
|
||||
def observe(self, notification: Notification) -> None:
|
||||
payload = notification.payload
|
||||
with self._condition:
|
||||
if isinstance(payload, TurnStartedNotification):
|
||||
if self.logical_turn_id is None:
|
||||
self.logical_turn_id = payload.turn.id
|
||||
self.current_turn_id = payload.turn.id
|
||||
if self.started_turn is None:
|
||||
self.started_turn = payload.turn
|
||||
elif isinstance(payload, TurnCompletedNotification):
|
||||
self.completed_turn = payload.turn
|
||||
if self.current_turn_id == payload.turn.id:
|
||||
self.current_turn_id = None
|
||||
elif isinstance(payload, ThreadGoalUpdatedNotification):
|
||||
self.status = payload.goal.status
|
||||
if self.status == ThreadGoalStatus.active:
|
||||
self.cleared = False
|
||||
elif isinstance(payload, ThreadGoalClearedNotification):
|
||||
self.cleared = True
|
||||
if (
|
||||
self.current_turn_id is None
|
||||
and self.completed_turn is not None
|
||||
and (self.cleared or _terminal_goal_status(self.status))
|
||||
):
|
||||
self._finished = True
|
||||
self._condition.notify_all()
|
||||
self._notifications.put(notification)
|
||||
|
||||
def wait_for_start(self, timeout: float) -> str | None:
|
||||
"""Wait for the runtime-generated first turn without consuming its event."""
|
||||
deadline = time.monotonic() + timeout
|
||||
with self._condition:
|
||||
while self.started_turn is None or self.logical_turn_id is None:
|
||||
if self._failure is not None:
|
||||
raise self._failure
|
||||
remaining = deadline - time.monotonic()
|
||||
if remaining <= 0:
|
||||
return None
|
||||
self._condition.wait(remaining)
|
||||
return self.logical_turn_id
|
||||
|
||||
def fail(self, exc: BaseException) -> None:
|
||||
with self._condition:
|
||||
self._failure = exc
|
||||
self._condition.notify_all()
|
||||
self._notifications.put(exc)
|
||||
|
||||
def next_notification(self) -> Notification:
|
||||
item = self._notifications.get()
|
||||
if isinstance(item, BaseException):
|
||||
raise item
|
||||
return item
|
||||
|
||||
def finish(self) -> None:
|
||||
"""Mark the logical operation inactive and wake waiting controls."""
|
||||
with self._condition:
|
||||
self._finished = True
|
||||
self.current_turn_id = None
|
||||
self._condition.notify_all()
|
||||
|
||||
def is_finished(self) -> bool:
|
||||
with self._condition:
|
||||
return self._finished
|
||||
|
||||
def begin_interrupt(self) -> bool:
|
||||
with self._condition:
|
||||
if self._finished:
|
||||
return False
|
||||
self.interrupt_requested = True
|
||||
return True
|
||||
|
||||
def confirm_interrupt(self) -> None:
|
||||
with self._condition:
|
||||
self.interrupted = True
|
||||
self.interrupt_requested = False
|
||||
|
||||
def cancel_interrupt(self) -> None:
|
||||
with self._condition:
|
||||
self.interrupt_requested = False
|
||||
|
||||
def explicit_interrupt(self, status: ThreadGoalStatus | None) -> bool:
|
||||
with self._condition:
|
||||
return self.interrupted or (
|
||||
self.interrupt_requested and status == ThreadGoalStatus.paused
|
||||
)
|
||||
|
||||
def active_turn(self, *, after: str | None = None) -> str | None:
|
||||
"""Wait for the current turn, or return None once the goal has ended."""
|
||||
with self._condition:
|
||||
while True:
|
||||
if self._failure is not None:
|
||||
raise self._failure
|
||||
if self._finished:
|
||||
return None
|
||||
if self.current_turn_id is not None and self.current_turn_id != after:
|
||||
return self.current_turn_id
|
||||
if self.cleared or _terminal_goal_status(self.status):
|
||||
return None
|
||||
self._condition.wait()
|
||||
|
||||
def current_turn(self) -> str | None:
|
||||
"""Return the current physical turn without waiting for rollover."""
|
||||
with self._condition:
|
||||
return self.current_turn_id
|
||||
|
||||
def wake_notification_reader(self) -> None:
|
||||
"""Release a reader blocked after its stream has been closed."""
|
||||
self._notifications.put(_GoalStreamClosed())
|
||||
@@ -4,6 +4,7 @@ import queue
|
||||
import threading
|
||||
from collections import deque
|
||||
|
||||
from ._goal import _GoalOperationState
|
||||
from .errors import CodexError, map_jsonrpc_error
|
||||
from .generated.notification_registry import notification_turn_id
|
||||
from .generated.v2_all import AccountLoginCompletedNotification
|
||||
@@ -30,6 +31,7 @@ class MessageRouter:
|
||||
self._pending_login_notifications: dict[str, deque[Notification]] = {}
|
||||
self._turn_notifications: dict[str, queue.Queue[NotificationQueueItem]] = {}
|
||||
self._pending_turn_notifications: dict[str, deque[Notification]] = {}
|
||||
self._goal_operations: dict[str, _GoalOperationState] = {}
|
||||
self._global_notifications: queue.Queue[NotificationQueueItem] = queue.Queue()
|
||||
|
||||
def create_response_waiter(self, request_id: str) -> queue.Queue[ResponseQueueItem]:
|
||||
@@ -116,6 +118,21 @@ class MessageRouter:
|
||||
raise item
|
||||
return item
|
||||
|
||||
def register_goal(self, thread_id: str) -> _GoalOperationState:
|
||||
"""Register one thread-scoped logical goal operation before it starts."""
|
||||
state = _GoalOperationState(thread_id=thread_id)
|
||||
with self._lock:
|
||||
if thread_id in self._goal_operations:
|
||||
raise RuntimeError(f"thread {thread_id!r} already has an active goal operation")
|
||||
self._goal_operations[thread_id] = state
|
||||
return state
|
||||
|
||||
def unregister_goal(self, state: _GoalOperationState) -> None:
|
||||
"""Stop routing notifications to a completed logical goal operation."""
|
||||
with self._lock:
|
||||
if self._goal_operations.get(state.thread_id) is state:
|
||||
self._goal_operations.pop(state.thread_id)
|
||||
|
||||
def route_response(self, msg: dict[str, JsonValue]) -> None:
|
||||
"""Deliver a JSON-RPC response or error to its request waiter."""
|
||||
|
||||
@@ -157,6 +174,17 @@ class MessageRouter:
|
||||
return
|
||||
|
||||
turn_id = self._notification_turn_id(notification)
|
||||
thread_id = self._notification_thread_id(notification)
|
||||
if thread_id is not None:
|
||||
with self._lock:
|
||||
goal_state = self._goal_operations.get(thread_id)
|
||||
if goal_state is not None and (
|
||||
turn_id is not None or notification.method.startswith("thread/goal/")
|
||||
):
|
||||
goal_state.observe(notification)
|
||||
if goal_state.is_finished():
|
||||
self.unregister_goal(goal_state)
|
||||
return
|
||||
if turn_id is None:
|
||||
self._global_notifications.put(notification)
|
||||
return
|
||||
@@ -182,6 +210,8 @@ class MessageRouter:
|
||||
self._pending_login_notifications.clear()
|
||||
turn_queues = list(self._turn_notifications.values())
|
||||
self._pending_turn_notifications.clear()
|
||||
goal_operations = list(self._goal_operations.values())
|
||||
self._goal_operations.clear()
|
||||
# Put the same transport failure into every queue so no SDK call blocks
|
||||
# forever waiting for a response that cannot arrive.
|
||||
for waiter in response_waiters:
|
||||
@@ -190,8 +220,34 @@ class MessageRouter:
|
||||
login_queue.put(exc)
|
||||
for turn_queue in turn_queues:
|
||||
turn_queue.put(exc)
|
||||
for goal_operation in goal_operations:
|
||||
goal_operation.fail(exc)
|
||||
self._global_notifications.put(exc)
|
||||
|
||||
def _notification_turn_id(self, notification: Notification) -> str | None:
|
||||
"""Extract routing ids from generated metadata or raw unknown payloads."""
|
||||
payload = notification.payload
|
||||
if isinstance(payload, UnknownNotification):
|
||||
raw_turn_id = payload.params.get("turnId")
|
||||
if isinstance(raw_turn_id, str):
|
||||
return raw_turn_id
|
||||
raw_turn = payload.params.get("turn")
|
||||
if isinstance(raw_turn, dict):
|
||||
raw_nested_turn_id = raw_turn.get("id")
|
||||
if isinstance(raw_nested_turn_id, str):
|
||||
return raw_nested_turn_id
|
||||
return None
|
||||
return notification_turn_id(payload)
|
||||
|
||||
def _notification_thread_id(self, notification: Notification) -> str | None:
|
||||
"""Extract thread ids from typed payloads or raw unknown payloads."""
|
||||
payload = notification.payload
|
||||
if isinstance(payload, UnknownNotification):
|
||||
raw_thread_id = payload.params.get("threadId")
|
||||
return raw_thread_id if isinstance(raw_thread_id, str) else None
|
||||
thread_id = getattr(payload, "thread_id", None)
|
||||
return thread_id if isinstance(thread_id, str) else None
|
||||
|
||||
def _notification_login_id(self, notification: Notification) -> str | None:
|
||||
"""Extract the login attempt id from completion notifications."""
|
||||
if notification.method != "account/login/completed":
|
||||
@@ -205,18 +261,3 @@ class MessageRouter:
|
||||
if isinstance(raw_login_id, str):
|
||||
return raw_login_id
|
||||
return None
|
||||
|
||||
def _notification_turn_id(self, notification: Notification) -> str | None:
|
||||
"""Extract routing ids from known generated payloads or raw unknown payloads."""
|
||||
payload = notification.payload
|
||||
if isinstance(payload, UnknownNotification):
|
||||
raw_turn_id = payload.params.get("turnId")
|
||||
if isinstance(raw_turn_id, str):
|
||||
return raw_turn_id
|
||||
raw_turn = payload.params.get("turn")
|
||||
if isinstance(raw_turn, dict):
|
||||
raw_nested_turn_id = raw_turn.get("id")
|
||||
if isinstance(raw_nested_turn_id, str):
|
||||
return raw_nested_turn_id
|
||||
return None
|
||||
return notification_turn_id(payload)
|
||||
|
||||
@@ -6,6 +6,7 @@ from typing import AsyncIterator, Callable, ParamSpec, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._goal import _GoalOperationState
|
||||
from .client import CodexClient, CodexConfig
|
||||
from .generated.v2_all import (
|
||||
AccountLoginCompletedNotification,
|
||||
@@ -21,6 +22,9 @@ from .generated.v2_all import (
|
||||
ThreadCompactStartResponse,
|
||||
ThreadForkParams as V2ThreadForkParams,
|
||||
ThreadForkResponse,
|
||||
ThreadGoalClearResponse,
|
||||
ThreadGoalSetResponse,
|
||||
ThreadGoalStatus,
|
||||
ThreadListParams as V2ThreadListParams,
|
||||
ThreadListResponse,
|
||||
ThreadReadResponse,
|
||||
@@ -107,6 +111,14 @@ class AsyncCodexClient:
|
||||
"""Unregister a turn notification queue on the wrapped sync client."""
|
||||
self._sync.unregister_turn_notifications(turn_id)
|
||||
|
||||
def register_goal_operation(self, thread_id: str) -> _GoalOperationState:
|
||||
"""Register a logical goal route on the wrapped sync client."""
|
||||
return self._sync.register_goal_operation(thread_id)
|
||||
|
||||
def unregister_goal_operation(self, state: _GoalOperationState) -> None:
|
||||
"""Release one logical goal route."""
|
||||
self._sync.unregister_goal_operation(state)
|
||||
|
||||
async def request(
|
||||
self,
|
||||
method: str,
|
||||
@@ -192,6 +204,29 @@ class AsyncCodexClient:
|
||||
"""Start thread compaction using the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_compact, thread_id)
|
||||
|
||||
async def thread_goal_clear(self, thread_id: str) -> ThreadGoalClearResponse:
|
||||
"""Clear the persisted goal through the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.thread_goal_clear, thread_id)
|
||||
|
||||
async def thread_goal_set(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
objective: str | None = None,
|
||||
status: ThreadGoalStatus | None = None,
|
||||
) -> ThreadGoalSetResponse:
|
||||
"""Create or update a persisted goal through the wrapped sync client."""
|
||||
return await self._call_sync(
|
||||
self._sync.thread_goal_set,
|
||||
thread_id,
|
||||
objective=objective,
|
||||
status=status,
|
||||
)
|
||||
|
||||
async def pause_goal(self, thread_id: str) -> ThreadGoalSetResponse:
|
||||
"""Pause the active goal through the wrapped sync client."""
|
||||
return await self._call_sync(self._sync.pause_goal, thread_id)
|
||||
|
||||
async def turn_start(
|
||||
self,
|
||||
thread_id: str,
|
||||
@@ -256,6 +291,10 @@ class AsyncCodexClient:
|
||||
"""Wait for the next notification routed to one turn."""
|
||||
return await self._call_sync(self._sync.next_turn_notification, turn_id)
|
||||
|
||||
async def next_goal_notification(self, state: _GoalOperationState) -> Notification:
|
||||
"""Wait for the next notification in a logical goal turn."""
|
||||
return await self._call_sync(self._sync.next_goal_notification, state)
|
||||
|
||||
async def wait_for_login_completed(
|
||||
self,
|
||||
login_id: str,
|
||||
|
||||
@@ -10,6 +10,7 @@ from typing import Callable, Iterator, TypeVar
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._goal import _GoalOperationState
|
||||
from ._message_router import MessageRouter
|
||||
from ._version import __version__ as SDK_VERSION
|
||||
from .errors import CodexError, TransportClosedError
|
||||
@@ -30,6 +31,9 @@ from .generated.v2_all import (
|
||||
ThreadCompactStartResponse,
|
||||
ThreadForkParams as V2ThreadForkParams,
|
||||
ThreadForkResponse,
|
||||
ThreadGoalClearResponse,
|
||||
ThreadGoalSetResponse,
|
||||
ThreadGoalStatus,
|
||||
ThreadListParams as V2ThreadListParams,
|
||||
ThreadListResponse,
|
||||
ThreadReadResponse,
|
||||
@@ -352,6 +356,18 @@ class CodexClient:
|
||||
"""Return the next routed notification for the requested turn id."""
|
||||
return self._router.next_turn_notification(turn_id)
|
||||
|
||||
def register_goal_operation(self, thread_id: str) -> _GoalOperationState:
|
||||
"""Register a private thread-scoped route for a logical goal turn."""
|
||||
return self._router.register_goal(thread_id)
|
||||
|
||||
def unregister_goal_operation(self, state: _GoalOperationState) -> None:
|
||||
"""Release routing state for one logical goal turn."""
|
||||
self._router.unregister_goal(state)
|
||||
|
||||
def next_goal_notification(self, state: _GoalOperationState) -> Notification:
|
||||
"""Wait for the next notification in a logical goal turn."""
|
||||
return state.next_notification()
|
||||
|
||||
def account_login_start(
|
||||
self,
|
||||
params: V2LoginAccountParams | JsonObject,
|
||||
@@ -452,6 +468,37 @@ class CodexClient:
|
||||
response_model=ThreadCompactStartResponse,
|
||||
)
|
||||
|
||||
def thread_goal_clear(self, thread_id: str) -> ThreadGoalClearResponse:
|
||||
"""Clear the persisted goal for a thread before replacing it."""
|
||||
return self.request(
|
||||
"thread/goal/clear",
|
||||
{"threadId": thread_id},
|
||||
response_model=ThreadGoalClearResponse,
|
||||
)
|
||||
|
||||
def thread_goal_set(
|
||||
self,
|
||||
thread_id: str,
|
||||
*,
|
||||
objective: str | None = None,
|
||||
status: ThreadGoalStatus | None = None,
|
||||
) -> ThreadGoalSetResponse:
|
||||
"""Create or update the persisted goal for a thread."""
|
||||
payload: JsonObject = {"threadId": thread_id}
|
||||
if objective is not None:
|
||||
payload["objective"] = objective
|
||||
if status is not None:
|
||||
payload["status"] = status.value
|
||||
return self.request(
|
||||
"thread/goal/set",
|
||||
payload,
|
||||
response_model=ThreadGoalSetResponse,
|
||||
)
|
||||
|
||||
def pause_goal(self, thread_id: str) -> ThreadGoalSetResponse:
|
||||
"""Pause the active goal used by a logical goal turn."""
|
||||
return self.thread_goal_set(thread_id, status=ThreadGoalStatus.paused)
|
||||
|
||||
def turn_start(
|
||||
self,
|
||||
thread_id: str,
|
||||
|
||||
@@ -27,6 +27,8 @@ from .generated.v2_all import (
|
||||
ReasoningSummaryTextDeltaNotification,
|
||||
ReasoningTextDeltaNotification,
|
||||
TerminalInteractionNotification,
|
||||
ThreadGoalClearedNotification,
|
||||
ThreadGoalUpdatedNotification,
|
||||
ThreadNameUpdatedNotification,
|
||||
ThreadStartedNotification,
|
||||
ThreadTokenUsageUpdatedNotification,
|
||||
@@ -70,6 +72,8 @@ NotificationPayload: TypeAlias = (
|
||||
| ReasoningTextDeltaNotification
|
||||
| TerminalInteractionNotification
|
||||
| ThreadNameUpdatedNotification
|
||||
| ThreadGoalClearedNotification
|
||||
| ThreadGoalUpdatedNotification
|
||||
| ThreadStartedNotification
|
||||
| ThreadTokenUsageUpdatedNotification
|
||||
| TurnCompletedNotification
|
||||
|
||||
@@ -185,6 +185,32 @@ def test_turn_notification_router_demuxes_registered_turns() -> None:
|
||||
]
|
||||
|
||||
|
||||
def test_goal_notification_router_routes_by_thread_id() -> None:
|
||||
"""A goal operation should receive turn notifications across physical turn ids."""
|
||||
client = CodexClient()
|
||||
state = client.register_goal_operation("thread-1")
|
||||
|
||||
client._router.route_notification(
|
||||
client._coerce_notification(
|
||||
"item/agentMessage/delta",
|
||||
{
|
||||
"delta": "continued",
|
||||
"itemId": "item-1",
|
||||
"threadId": "thread-1",
|
||||
"turnId": "turn-2",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
event = client.next_goal_notification(state)
|
||||
|
||||
assert isinstance(event.payload, AgentMessageDeltaNotification)
|
||||
assert (event.method, event.payload.delta) == (
|
||||
"item/agentMessage/delta",
|
||||
"continued",
|
||||
)
|
||||
|
||||
|
||||
def test_client_reader_routes_interleaved_turn_notifications_by_turn_id() -> None:
|
||||
"""Reader-loop routing should preserve order within each interleaved turn stream."""
|
||||
client = CodexClient()
|
||||
|
||||
Reference in New Issue
Block a user