mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Address comments
This commit is contained in:
@@ -11,6 +11,7 @@ from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping
|
||||
from agent_framework import (
|
||||
ChatOptions,
|
||||
Content,
|
||||
ContextProvider,
|
||||
FileCheckpointStorage,
|
||||
HistoryProvider,
|
||||
Message,
|
||||
@@ -106,22 +107,27 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
"There shouldn't be a history provider with `load_messages=True` already present. "
|
||||
"History is managed by the hosting infrastructure."
|
||||
)
|
||||
provider = cast(ContextProvider, provider)
|
||||
logger.warning(
|
||||
"Context provider %s is present. If it maintains context in memory, "
|
||||
"the context may be lost between requests. Use with caution.",
|
||||
provider.source_id,
|
||||
)
|
||||
|
||||
self._is_workflow_agent = False
|
||||
self._checkpoint_storage = None
|
||||
self._checkpoint_storage_path = None
|
||||
if isinstance(agent, WorkflowAgent):
|
||||
if agent.workflow._runner_context.has_checkpointing(): # pyright: ignore[reportPrivateUsage]
|
||||
raise RuntimeError(
|
||||
"There should not be a checkpoint storage already present in the workflow agent. "
|
||||
"The hosting infrastructure will manage checkpoints instead."
|
||||
)
|
||||
checkpoint_storage_path = (
|
||||
self._checkpoint_storage_path = (
|
||||
self.CHECKPOINT_STORAGE_PATH
|
||||
if self.config.is_hosted
|
||||
else os.path.join(os.getcwd(), self.CHECKPOINT_STORAGE_PATH.lstrip("/"))
|
||||
)
|
||||
self._is_workflow_agent = True
|
||||
self._checkpoint_storage = FileCheckpointStorage(checkpoint_storage_path)
|
||||
|
||||
self._agent = agent
|
||||
self.response_handler(self._handler) # pyright: ignore[reportUnknownMemberType]
|
||||
@@ -129,6 +135,11 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
# Append the user agent prefix for telemetry purposes
|
||||
append_to_user_agent(self.USER_AGENT_PREFIX)
|
||||
|
||||
@staticmethod
|
||||
def _is_streaming_request(request: CreateResponse) -> bool:
|
||||
"""Check if the request is a streaming request."""
|
||||
return request.stream is not None and request.stream is True
|
||||
|
||||
async def _handler(
|
||||
self,
|
||||
request: CreateResponse,
|
||||
@@ -146,19 +157,22 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
history = await context.get_history()
|
||||
messages = [*_to_messages(history), input_text]
|
||||
|
||||
chat_options = _to_chat_options(request)
|
||||
chat_options, are_options_set = _to_chat_options(request)
|
||||
|
||||
is_streaming_request = self._is_streaming_request(request)
|
||||
response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)
|
||||
|
||||
yield response_event_stream.emit_created()
|
||||
yield response_event_stream.emit_in_progress()
|
||||
|
||||
if request.stream is None or request.stream is False:
|
||||
if not is_streaming_request:
|
||||
# Run the agent in non-streaming mode
|
||||
if isinstance(self._agent, RawAgent):
|
||||
raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType]
|
||||
response = await raw_agent.run(messages, stream=False, options=chat_options)
|
||||
else:
|
||||
if are_options_set:
|
||||
logger.warning("Agent doesn't support runtime options. They will be ignored.")
|
||||
response = await self._agent.run(messages, stream=False)
|
||||
|
||||
for message in response.messages:
|
||||
@@ -174,6 +188,8 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType]
|
||||
response_stream = raw_agent.run(messages, stream=True, options=chat_options)
|
||||
else:
|
||||
if are_options_set:
|
||||
logger.warning("Agent doesn't support runtime options. They will be ignored.")
|
||||
response_stream = self._agent.run(messages, stream=True)
|
||||
|
||||
# Track the current active output item builder for streaming;
|
||||
@@ -208,44 +224,71 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
by the hosting infrastructure or files will be preserved upon deactivation.
|
||||
"""
|
||||
input_text = await context.get_input_text()
|
||||
stream = request.stream is not None and request.stream is True
|
||||
is_streaming_request = self._is_streaming_request(request)
|
||||
|
||||
_, are_options_set = _to_chat_options(request)
|
||||
if are_options_set:
|
||||
logger.warning("Workflow agent doesn't support runtime options. They will be ignored.")
|
||||
|
||||
if request.previous_response_id is not None and context.conversation_id is not None:
|
||||
raise RuntimeError("Previous response ID cannot be used in conjunction with conversation ID.")
|
||||
context_id = request.previous_response_id or context.conversation_id
|
||||
|
||||
# The following should never happen due to the checks above.
|
||||
# This is for type safety and defensive programming.
|
||||
if self._checkpoint_storage is None:
|
||||
raise RuntimeError("Checkpoint storage is not available for workflow agent.")
|
||||
if self._checkpoint_storage_path is None:
|
||||
raise RuntimeError("Checkpoint storage path is not configured for workflow agent.")
|
||||
if not isinstance(self._agent, WorkflowAgent):
|
||||
raise RuntimeError("Agent is not a workflow agent.")
|
||||
|
||||
# Restore from the latest checkpoint if available, otherwise start with an empty history
|
||||
latest_checkpoint = await self._checkpoint_storage.get_latest(workflow_name=self._agent.workflow.name)
|
||||
if latest_checkpoint is not None:
|
||||
_ = await self._agent.run(
|
||||
stream=stream,
|
||||
checkpoint_id=latest_checkpoint.checkpoint_id,
|
||||
checkpoint_storage=self._checkpoint_storage,
|
||||
)
|
||||
if context_id is not None:
|
||||
checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id))
|
||||
latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=self._agent.workflow.name)
|
||||
if latest_checkpoint is not None:
|
||||
if not is_streaming_request:
|
||||
_ = await self._agent.run(
|
||||
stream=False,
|
||||
checkpoint_id=latest_checkpoint.checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
)
|
||||
else:
|
||||
# Consume the streaming or the invocation will result in a no-op
|
||||
async for _ in self._agent.run(
|
||||
stream=True,
|
||||
checkpoint_id=latest_checkpoint.checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
):
|
||||
pass
|
||||
|
||||
# Now run the agent with the latest input
|
||||
response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)
|
||||
|
||||
# Create a new checkpoint storage for this response based on the following rules:
|
||||
# - If no previous response ID or conversation ID is provided, create a new checkpoint storage for this response
|
||||
# - If a previous response ID is provided, create a new checkpoint storage for this response
|
||||
# - If a conversation ID is provided, reuse the existing checkpoint storage for the conversation
|
||||
context_id = context.conversation_id or context.response_id
|
||||
checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id))
|
||||
|
||||
yield response_event_stream.emit_created()
|
||||
yield response_event_stream.emit_in_progress()
|
||||
|
||||
if not stream:
|
||||
if not is_streaming_request:
|
||||
# Run the agent in non-streaming mode
|
||||
response = await self._agent.run(input_text, stream=False, checkpoint_storage=self._checkpoint_storage)
|
||||
response = await self._agent.run(input_text, stream=False, checkpoint_storage=checkpoint_storage)
|
||||
|
||||
for message in response.messages:
|
||||
for content in message.contents:
|
||||
async for item in _to_outputs(response_event_stream, content):
|
||||
yield item
|
||||
|
||||
await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name)
|
||||
yield response_event_stream.emit_completed()
|
||||
return
|
||||
|
||||
# Run the agent in streaming mode
|
||||
response_stream = self._agent.run(input_text, stream=True, checkpoint_storage=self._checkpoint_storage)
|
||||
response_stream = self._agent.run(input_text, stream=True, checkpoint_storage=checkpoint_storage)
|
||||
|
||||
# Track the current active output item builder for streaming;
|
||||
# lazily created on matching content, closed when a different type arrives.
|
||||
@@ -264,7 +307,22 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
for event in tracker.close():
|
||||
yield event
|
||||
|
||||
await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name)
|
||||
yield response_event_stream.emit_completed()
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
async def _delete_not_latest_checkpoints(checkpoint_storage: FileCheckpointStorage, workflow_name: str):
|
||||
"""Delete all checkpoints except the latest one.
|
||||
|
||||
We only need the last checkpoint for each invocation.
|
||||
"""
|
||||
latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=workflow_name)
|
||||
if latest_checkpoint is not None:
|
||||
all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow_name)
|
||||
for checkpoint in all_checkpoints:
|
||||
if checkpoint.checkpoint_id != latest_checkpoint.checkpoint_id:
|
||||
await checkpoint_storage.delete(checkpoint.checkpoint_id)
|
||||
|
||||
|
||||
# region Active Builder State
|
||||
@@ -418,7 +476,7 @@ class _OutputItemTracker:
|
||||
# region Option Conversion
|
||||
|
||||
|
||||
def _to_chat_options(request: CreateResponse) -> ChatOptions:
|
||||
def _to_chat_options(request: CreateResponse) -> tuple[ChatOptions, bool]:
|
||||
"""Converts a CreateResponse request to ChatOptions.
|
||||
|
||||
Args:
|
||||
@@ -426,19 +484,26 @@ def _to_chat_options(request: CreateResponse) -> ChatOptions:
|
||||
|
||||
Returns:
|
||||
ChatOptions: The converted ChatOptions.
|
||||
bool: Whether any options were set.
|
||||
|
||||
"""
|
||||
chat_options = ChatOptions()
|
||||
are_options_set = False
|
||||
|
||||
if request.temperature is not None:
|
||||
chat_options["temperature"] = request.temperature
|
||||
are_options_set = True
|
||||
if request.top_p is not None:
|
||||
chat_options["top_p"] = request.top_p
|
||||
are_options_set = True
|
||||
if request.max_output_tokens is not None:
|
||||
chat_options["max_tokens"] = request.max_output_tokens
|
||||
are_options_set = True
|
||||
if request.parallel_tool_calls is not None:
|
||||
chat_options["allow_multiple_tool_calls"] = request.parallel_tool_calls
|
||||
are_options_set = True
|
||||
|
||||
return chat_options
|
||||
return chat_options, are_options_set
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user