Address comments

This commit is contained in:
Tao Chen
2026-04-16 13:54:15 -07:00
Unverified
parent 4249aef461
commit c99c63810a
@@ -11,6 +11,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping
from agent_framework import (
ChatOptions,
Content,
ContextProvider,
FileCheckpointStorage,
HistoryProvider,
Message,
@@ -106,22 +107,27 @@ class ResponsesHostServer(ResponsesAgentServerHost):
"There shouldn't be a history provider with `load_messages=True` already present. "
"History is managed by the hosting infrastructure."
)
provider = cast(ContextProvider, provider)
logger.warning(
"Context provider %s is present. If it maintains context in memory, "
"the context may be lost between requests. Use with caution.",
provider.source_id,
)
self._is_workflow_agent = False
self._checkpoint_storage = None
self._checkpoint_storage_path = None
if isinstance(agent, WorkflowAgent):
if agent.workflow._runner_context.has_checkpointing(): # pyright: ignore[reportPrivateUsage]
raise RuntimeError(
"There should not be a checkpoint storage already present in the workflow agent. "
"The hosting infrastructure will manage checkpoints instead."
)
checkpoint_storage_path = (
self._checkpoint_storage_path = (
self.CHECKPOINT_STORAGE_PATH
if self.config.is_hosted
else os.path.join(os.getcwd(), self.CHECKPOINT_STORAGE_PATH.lstrip("/"))
)
self._is_workflow_agent = True
self._checkpoint_storage = FileCheckpointStorage(checkpoint_storage_path)
self._agent = agent
self.response_handler(self._handler) # pyright: ignore[reportUnknownMemberType]
@@ -129,6 +135,11 @@ class ResponsesHostServer(ResponsesAgentServerHost):
# Append the user agent prefix for telemetry purposes
append_to_user_agent(self.USER_AGENT_PREFIX)
@staticmethod
def _is_streaming_request(request: CreateResponse) -> bool:
"""Check if the request is a streaming request."""
return request.stream is not None and request.stream is True
async def _handler(
self,
request: CreateResponse,
@@ -146,19 +157,22 @@ class ResponsesHostServer(ResponsesAgentServerHost):
history = await context.get_history()
messages = [*_to_messages(history), input_text]
chat_options = _to_chat_options(request)
chat_options, are_options_set = _to_chat_options(request)
is_streaming_request = self._is_streaming_request(request)
response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)
yield response_event_stream.emit_created()
yield response_event_stream.emit_in_progress()
if request.stream is None or request.stream is False:
if not is_streaming_request:
# Run the agent in non-streaming mode
if isinstance(self._agent, RawAgent):
raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType]
response = await raw_agent.run(messages, stream=False, options=chat_options)
else:
if are_options_set:
logger.warning("Agent doesn't support runtime options. They will be ignored.")
response = await self._agent.run(messages, stream=False)
for message in response.messages:
@@ -174,6 +188,8 @@ class ResponsesHostServer(ResponsesAgentServerHost):
raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType]
response_stream = raw_agent.run(messages, stream=True, options=chat_options)
else:
if are_options_set:
logger.warning("Agent doesn't support runtime options. They will be ignored.")
response_stream = self._agent.run(messages, stream=True)
# Track the current active output item builder for streaming;
@@ -208,44 +224,71 @@ class ResponsesHostServer(ResponsesAgentServerHost):
by the hosting infrastructure or files will be preserved upon deactivation.
"""
input_text = await context.get_input_text()
stream = request.stream is not None and request.stream is True
is_streaming_request = self._is_streaming_request(request)
_, are_options_set = _to_chat_options(request)
if are_options_set:
logger.warning("Workflow agent doesn't support runtime options. They will be ignored.")
if request.previous_response_id is not None and context.conversation_id is not None:
raise RuntimeError("Previous response ID cannot be used in conjunction with conversation ID.")
context_id = request.previous_response_id or context.conversation_id
# The following should never happen due to the checks above.
# This is for type safety and defensive programming.
if self._checkpoint_storage is None:
raise RuntimeError("Checkpoint storage is not available for workflow agent.")
if self._checkpoint_storage_path is None:
raise RuntimeError("Checkpoint storage path is not configured for workflow agent.")
if not isinstance(self._agent, WorkflowAgent):
raise RuntimeError("Agent is not a workflow agent.")
# Restore from the latest checkpoint if available, otherwise start with an empty history
latest_checkpoint = await self._checkpoint_storage.get_latest(workflow_name=self._agent.workflow.name)
if latest_checkpoint is not None:
_ = await self._agent.run(
stream=stream,
checkpoint_id=latest_checkpoint.checkpoint_id,
checkpoint_storage=self._checkpoint_storage,
)
if context_id is not None:
checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id))
latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=self._agent.workflow.name)
if latest_checkpoint is not None:
if not is_streaming_request:
_ = await self._agent.run(
stream=False,
checkpoint_id=latest_checkpoint.checkpoint_id,
checkpoint_storage=checkpoint_storage,
)
else:
# Consume the streaming or the invocation will result in a no-op
async for _ in self._agent.run(
stream=True,
checkpoint_id=latest_checkpoint.checkpoint_id,
checkpoint_storage=checkpoint_storage,
):
pass
# Now run the agent with the latest input
response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)
# Create a new checkpoint storage for this response based on the following rules:
# - If no previous response ID or conversation ID is provided, create a new checkpoint storage for this response
# - If a previous response ID is provided, create a new checkpoint storage for this response
# - If a conversation ID is provided, reuse the existing checkpoint storage for the conversation
context_id = context.conversation_id or context.response_id
checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id))
yield response_event_stream.emit_created()
yield response_event_stream.emit_in_progress()
if not stream:
if not is_streaming_request:
# Run the agent in non-streaming mode
response = await self._agent.run(input_text, stream=False, checkpoint_storage=self._checkpoint_storage)
response = await self._agent.run(input_text, stream=False, checkpoint_storage=checkpoint_storage)
for message in response.messages:
for content in message.contents:
async for item in _to_outputs(response_event_stream, content):
yield item
await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name)
yield response_event_stream.emit_completed()
return
# Run the agent in streaming mode
response_stream = self._agent.run(input_text, stream=True, checkpoint_storage=self._checkpoint_storage)
response_stream = self._agent.run(input_text, stream=True, checkpoint_storage=checkpoint_storage)
# Track the current active output item builder for streaming;
# lazily created on matching content, closed when a different type arrives.
@@ -264,7 +307,22 @@ class ResponsesHostServer(ResponsesAgentServerHost):
for event in tracker.close():
yield event
await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name)
yield response_event_stream.emit_completed()
return
@staticmethod
async def _delete_not_latest_checkpoints(checkpoint_storage: FileCheckpointStorage, workflow_name: str):
"""Delete all checkpoints except the latest one.
We only need the last checkpoint for each invocation.
"""
latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=workflow_name)
if latest_checkpoint is not None:
all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow_name)
for checkpoint in all_checkpoints:
if checkpoint.checkpoint_id != latest_checkpoint.checkpoint_id:
await checkpoint_storage.delete(checkpoint.checkpoint_id)
# region Active Builder State
@@ -418,7 +476,7 @@ class _OutputItemTracker:
# region Option Conversion
def _to_chat_options(request: CreateResponse) -> ChatOptions:
def _to_chat_options(request: CreateResponse) -> tuple[ChatOptions, bool]:
"""Converts a CreateResponse request to ChatOptions.
Args:
@@ -426,19 +484,26 @@ def _to_chat_options(request: CreateResponse) -> ChatOptions:
Returns:
ChatOptions: The converted ChatOptions.
bool: Whether any options were set.
"""
chat_options = ChatOptions()
are_options_set = False
if request.temperature is not None:
chat_options["temperature"] = request.temperature
are_options_set = True
if request.top_p is not None:
chat_options["top_p"] = request.top_p
are_options_set = True
if request.max_output_tokens is not None:
chat_options["max_tokens"] = request.max_output_tokens
are_options_set = True
if request.parallel_tool_calls is not None:
chat_options["allow_multiple_tool_calls"] = request.parallel_tool_calls
are_options_set = True
return chat_options
return chat_options, are_options_set
# endregion