mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Add special handling for workflows (#5298)
* Add special handling for workflows * Address comments
This commit is contained in:
committed by
GitHub
Unverified
parent
55e0705923
commit
0fcd71dbeb
@@ -5,9 +5,20 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping
|
||||
|
||||
from agent_framework import ChatOptions, Content, HistoryProvider, Message, RawAgent, SupportsAgentRun
|
||||
from agent_framework import (
|
||||
ChatOptions,
|
||||
Content,
|
||||
ContextProvider,
|
||||
FileCheckpointStorage,
|
||||
HistoryProvider,
|
||||
Message,
|
||||
RawAgent,
|
||||
SupportsAgentRun,
|
||||
WorkflowAgent,
|
||||
)
|
||||
from agent_framework._telemetry import append_to_user_agent
|
||||
from azure.ai.agentserver.responses import (
|
||||
ResponseContext,
|
||||
@@ -60,6 +71,8 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
"""A responses server host for an agent."""
|
||||
|
||||
USER_AGENT_PREFIX = "foundry-hosting"
|
||||
# TODO(@taochen): Allow a different checkpoint storage that stores checkpoints externally
|
||||
CHECKPOINT_STORAGE_PATH = "/.checkpoints"
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
@@ -80,8 +93,11 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
**kwargs: Additional keyword arguments.
|
||||
|
||||
Note:
|
||||
The agent must not have a history provider with `load_messages=True`,
|
||||
because history is managed by the hosting infrastructure.
|
||||
1. The agent must not have a history provider with `load_messages=True`,
|
||||
because history is managed by the hosting infrastructure.
|
||||
2. The agent must not have any context providers that maintain context
|
||||
in memory, because the hosting environment may get deactivated between
|
||||
requests, and any in-memory context would be lost.
|
||||
"""
|
||||
super().__init__(prefix=prefix, options=options, store=store, **kwargs)
|
||||
|
||||
@@ -91,13 +107,39 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
"There shouldn't be a history provider with `load_messages=True` already present. "
|
||||
"History is managed by the hosting infrastructure."
|
||||
)
|
||||
self._agent = agent
|
||||
provider = cast(ContextProvider, provider)
|
||||
logger.warning(
|
||||
"Context provider %s is present. If it maintains context in memory, "
|
||||
"the context may be lost between requests. Use with caution.",
|
||||
provider.source_id,
|
||||
)
|
||||
|
||||
self._is_workflow_agent = False
|
||||
self._checkpoint_storage_path = None
|
||||
if isinstance(agent, WorkflowAgent):
|
||||
if agent.workflow._runner_context.has_checkpointing(): # pyright: ignore[reportPrivateUsage]
|
||||
raise RuntimeError(
|
||||
"There should not be a checkpoint storage already present in the workflow agent. "
|
||||
"The hosting infrastructure will manage checkpoints instead."
|
||||
)
|
||||
self._checkpoint_storage_path = (
|
||||
self.CHECKPOINT_STORAGE_PATH
|
||||
if self.config.is_hosted
|
||||
else os.path.join(os.getcwd(), self.CHECKPOINT_STORAGE_PATH.lstrip("/"))
|
||||
)
|
||||
self._is_workflow_agent = True
|
||||
|
||||
self._agent = agent
|
||||
self.response_handler(self._handler) # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
# Append the user agent prefix for telemetry purposes
|
||||
append_to_user_agent(self.USER_AGENT_PREFIX)
|
||||
|
||||
@staticmethod
|
||||
def _is_streaming_request(request: CreateResponse) -> bool:
|
||||
"""Check if the request is a streaming request."""
|
||||
return request.stream is not None and request.stream is True
|
||||
|
||||
async def _handler(
|
||||
self,
|
||||
request: CreateResponse,
|
||||
@@ -105,50 +147,61 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
cancellation_signal: asyncio.Event,
|
||||
) -> AsyncIterable[ResponseStreamEvent | dict[str, Any]]:
|
||||
"""Handle the creation of a response."""
|
||||
if self._is_workflow_agent:
|
||||
# Workflow agents are handled differently because they require checkpoint restoration
|
||||
async for event in self._handle_workflow_agent(request, context, cancellation_signal):
|
||||
yield event
|
||||
return
|
||||
|
||||
input_text = await context.get_input_text()
|
||||
history = await context.get_history()
|
||||
messages = [*_to_messages(history), input_text]
|
||||
|
||||
chat_options = _to_chat_options(request)
|
||||
chat_options, are_options_set = _to_chat_options(request)
|
||||
|
||||
stream = ResponseEventStream(response_id=context.response_id, model=request.model)
|
||||
is_streaming_request = self._is_streaming_request(request)
|
||||
response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)
|
||||
|
||||
yield stream.emit_created()
|
||||
yield stream.emit_in_progress()
|
||||
yield response_event_stream.emit_created()
|
||||
yield response_event_stream.emit_in_progress()
|
||||
|
||||
if request.stream is None or request.stream is False:
|
||||
if not is_streaming_request:
|
||||
# Run the agent in non-streaming mode
|
||||
if isinstance(self._agent, RawAgent):
|
||||
raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType]
|
||||
response = await raw_agent.run(messages, stream=False, options=chat_options)
|
||||
else:
|
||||
if are_options_set:
|
||||
logger.warning("Agent doesn't support runtime options. They will be ignored.")
|
||||
response = await self._agent.run(messages, stream=False)
|
||||
|
||||
for message in response.messages:
|
||||
for content in message.contents:
|
||||
async for item in _to_outputs(stream, content):
|
||||
async for item in _to_outputs(response_event_stream, content):
|
||||
yield item
|
||||
|
||||
yield stream.emit_completed()
|
||||
yield response_event_stream.emit_completed()
|
||||
return
|
||||
|
||||
# Start the streaming response
|
||||
# Run the agent in streaming mode
|
||||
if isinstance(self._agent, RawAgent):
|
||||
raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType]
|
||||
response_stream = raw_agent.run(messages, stream=True, options=chat_options)
|
||||
else:
|
||||
if are_options_set:
|
||||
logger.warning("Agent doesn't support runtime options. They will be ignored.")
|
||||
response_stream = self._agent.run(messages, stream=True)
|
||||
|
||||
# Track the current active output item builder for streaming;
|
||||
# lazily created on matching content, closed when a different type arrives.
|
||||
tracker = _OutputItemTracker(stream)
|
||||
tracker = _OutputItemTracker(response_event_stream)
|
||||
|
||||
async for update in response_stream:
|
||||
for content in update.contents:
|
||||
for event in tracker.handle(content):
|
||||
yield event
|
||||
if tracker.needs_async:
|
||||
async for item in _to_outputs(stream, content):
|
||||
async for item in _to_outputs(response_event_stream, content):
|
||||
yield item
|
||||
tracker.needs_async = False
|
||||
|
||||
@@ -156,7 +209,120 @@ class ResponsesHostServer(ResponsesAgentServerHost):
|
||||
for event in tracker.close():
|
||||
yield event
|
||||
|
||||
yield stream.emit_completed()
|
||||
yield response_event_stream.emit_completed()
|
||||
|
||||
async def _handle_workflow_agent(
|
||||
self,
|
||||
request: CreateResponse,
|
||||
context: ResponseContext,
|
||||
cancellation_signal: asyncio.Event,
|
||||
) -> AsyncIterable[ResponseStreamEvent | dict[str, Any]]:
|
||||
"""Handle the creation of a response for a workflow agent.
|
||||
|
||||
Why this is required:
|
||||
The sandbox may be deactivated after some period of inactivity, and only data managed
|
||||
by the hosting infrastructure or files will be preserved upon deactivation.
|
||||
"""
|
||||
input_text = await context.get_input_text()
|
||||
is_streaming_request = self._is_streaming_request(request)
|
||||
|
||||
_, are_options_set = _to_chat_options(request)
|
||||
if are_options_set:
|
||||
logger.warning("Workflow agent doesn't support runtime options. They will be ignored.")
|
||||
|
||||
if request.previous_response_id is not None and context.conversation_id is not None:
|
||||
raise RuntimeError("Previous response ID cannot be used in conjunction with conversation ID.")
|
||||
context_id = request.previous_response_id or context.conversation_id
|
||||
|
||||
# The following should never happen due to the checks above.
|
||||
# This is for type safety and defensive programming.
|
||||
if self._checkpoint_storage_path is None:
|
||||
raise RuntimeError("Checkpoint storage path is not configured for workflow agent.")
|
||||
if not isinstance(self._agent, WorkflowAgent):
|
||||
raise RuntimeError("Agent is not a workflow agent.")
|
||||
|
||||
# Restore from the latest checkpoint if available, otherwise start with an empty history
|
||||
if context_id is not None:
|
||||
checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id))
|
||||
latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=self._agent.workflow.name)
|
||||
if latest_checkpoint is not None:
|
||||
if not is_streaming_request:
|
||||
_ = await self._agent.run(
|
||||
stream=False,
|
||||
checkpoint_id=latest_checkpoint.checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
)
|
||||
else:
|
||||
# Consume the streaming or the invocation will result in a no-op
|
||||
async for _ in self._agent.run(
|
||||
stream=True,
|
||||
checkpoint_id=latest_checkpoint.checkpoint_id,
|
||||
checkpoint_storage=checkpoint_storage,
|
||||
):
|
||||
pass
|
||||
|
||||
# Now run the agent with the latest input
|
||||
response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)
|
||||
|
||||
# Create a new checkpoint storage for this response based on the following rules:
|
||||
# - If no previous response ID or conversation ID is provided, create a new checkpoint storage for this response
|
||||
# - If a previous response ID is provided, create a new checkpoint storage for this response
|
||||
# - If a conversation ID is provided, reuse the existing checkpoint storage for the conversation
|
||||
context_id = context.conversation_id or context.response_id
|
||||
checkpoint_storage = FileCheckpointStorage(os.path.join(self._checkpoint_storage_path, context_id))
|
||||
|
||||
yield response_event_stream.emit_created()
|
||||
yield response_event_stream.emit_in_progress()
|
||||
|
||||
if not is_streaming_request:
|
||||
# Run the agent in non-streaming mode
|
||||
response = await self._agent.run(input_text, stream=False, checkpoint_storage=checkpoint_storage)
|
||||
|
||||
for message in response.messages:
|
||||
for content in message.contents:
|
||||
async for item in _to_outputs(response_event_stream, content):
|
||||
yield item
|
||||
|
||||
await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name)
|
||||
yield response_event_stream.emit_completed()
|
||||
return
|
||||
|
||||
# Run the agent in streaming mode
|
||||
response_stream = self._agent.run(input_text, stream=True, checkpoint_storage=checkpoint_storage)
|
||||
|
||||
# Track the current active output item builder for streaming;
|
||||
# lazily created on matching content, closed when a different type arrives.
|
||||
tracker = _OutputItemTracker(response_event_stream)
|
||||
|
||||
async for update in response_stream:
|
||||
for content in update.contents:
|
||||
for event in tracker.handle(content):
|
||||
yield event
|
||||
if tracker.needs_async:
|
||||
async for item in _to_outputs(response_event_stream, content):
|
||||
yield item
|
||||
tracker.needs_async = False
|
||||
|
||||
# Close any remaining active builder
|
||||
for event in tracker.close():
|
||||
yield event
|
||||
|
||||
await self._delete_not_latest_checkpoints(checkpoint_storage, self._agent.workflow.name)
|
||||
yield response_event_stream.emit_completed()
|
||||
return
|
||||
|
||||
@staticmethod
|
||||
async def _delete_not_latest_checkpoints(checkpoint_storage: FileCheckpointStorage, workflow_name: str):
|
||||
"""Delete all checkpoints except the latest one.
|
||||
|
||||
We only need the last checkpoint for each invocation.
|
||||
"""
|
||||
latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=workflow_name)
|
||||
if latest_checkpoint is not None:
|
||||
all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow_name)
|
||||
for checkpoint in all_checkpoints:
|
||||
if checkpoint.checkpoint_id != latest_checkpoint.checkpoint_id:
|
||||
await checkpoint_storage.delete(checkpoint.checkpoint_id)
|
||||
|
||||
|
||||
# region Active Builder State
|
||||
@@ -310,7 +476,7 @@ class _OutputItemTracker:
|
||||
# region Option Conversion
|
||||
|
||||
|
||||
def _to_chat_options(request: CreateResponse) -> ChatOptions:
|
||||
def _to_chat_options(request: CreateResponse) -> tuple[ChatOptions, bool]:
|
||||
"""Converts a CreateResponse request to ChatOptions.
|
||||
|
||||
Args:
|
||||
@@ -318,19 +484,26 @@ def _to_chat_options(request: CreateResponse) -> ChatOptions:
|
||||
|
||||
Returns:
|
||||
ChatOptions: The converted ChatOptions.
|
||||
bool: Whether any options were set.
|
||||
|
||||
"""
|
||||
chat_options = ChatOptions()
|
||||
are_options_set = False
|
||||
|
||||
if request.temperature is not None:
|
||||
chat_options["temperature"] = request.temperature
|
||||
are_options_set = True
|
||||
if request.top_p is not None:
|
||||
chat_options["top_p"] = request.top_p
|
||||
are_options_set = True
|
||||
if request.max_output_tokens is not None:
|
||||
chat_options["max_tokens"] = request.max_output_tokens
|
||||
are_options_set = True
|
||||
if request.parallel_tool_calls is not None:
|
||||
chat_options["allow_multiple_tool_calls"] = request.parallel_tool_calls
|
||||
are_options_set = True
|
||||
|
||||
return chat_options
|
||||
return chat_options, are_options_set
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -5,7 +5,6 @@ import os
|
||||
from agent_framework import Agent
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework_foundry_hosting import ResponsesHostServer
|
||||
from azure.ai.agentserver.responses import InMemoryResponseProvider
|
||||
from azure.identity import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -29,7 +28,7 @@ def main():
|
||||
default_options={"store": False},
|
||||
)
|
||||
|
||||
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
|
||||
server = ResponsesHostServer(agent)
|
||||
server.run()
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,6 @@ from random import randint
|
||||
from agent_framework import Agent, tool
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework_foundry_hosting import ResponsesHostServer
|
||||
from azure.ai.agentserver.responses import InMemoryResponseProvider
|
||||
from azure.identity import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
@@ -67,7 +66,7 @@ def main():
|
||||
default_options={"store": False},
|
||||
)
|
||||
|
||||
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
|
||||
server = ResponsesHostServer(agent)
|
||||
server.run()
|
||||
|
||||
|
||||
|
||||
@@ -6,7 +6,6 @@ import httpx
|
||||
from agent_framework import Agent, MCPStreamableHTTPTool
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework_foundry_hosting import ResponsesHostServer
|
||||
from azure.ai.agentserver.responses import InMemoryResponseProvider
|
||||
from azure.identity import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -69,7 +68,7 @@ def main():
|
||||
default_options={"store": False},
|
||||
)
|
||||
|
||||
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
|
||||
server = ResponsesHostServer(agent)
|
||||
server.run()
|
||||
|
||||
|
||||
|
||||
@@ -13,5 +13,5 @@ curl -X POST http://localhost:8088/responses -H "Content-Type: application/json"
|
||||
Invoke with `azd`:
|
||||
|
||||
```bash
|
||||
azd ai agent invoke --local "List all the repositories I own on GitHub."
|
||||
azd ai agent invoke --local "Create a slogan for a new electric SUV that is affordable and fun to drive."
|
||||
```
|
||||
|
||||
@@ -2,11 +2,10 @@
|
||||
|
||||
import os
|
||||
|
||||
from agent_framework import Agent
|
||||
from agent_framework import Agent, AgentExecutor, WorkflowBuilder
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework.orchestrations import GroupChatBuilder, GroupChatState
|
||||
from agent_framework.orchestrations import GroupChatState
|
||||
from agent_framework_foundry_hosting import ResponsesHostServer
|
||||
from azure.ai.agentserver.responses import InMemoryResponseProvider
|
||||
from azure.identity import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -30,35 +29,48 @@ def main():
|
||||
|
||||
writer_agent = Agent(
|
||||
client=client,
|
||||
instructions=(
|
||||
"You are an excellent content writer. You create new content and edit contents based on the feedback."
|
||||
),
|
||||
instructions=("You are an excellent slogan writer. You create new slogans based on the given topic."),
|
||||
name="writer",
|
||||
)
|
||||
|
||||
reviewer_agent = Agent(
|
||||
legal_agent = Agent(
|
||||
client=client,
|
||||
instructions=(
|
||||
"You are an excellent content reviewer."
|
||||
"Provide actionable feedback to the writer about the provided content."
|
||||
"Provide the feedback in the most concise manner possible."
|
||||
"You are an excellent legal reviewer. "
|
||||
"Make necessary corrections to the slogan so that it is legally compliant."
|
||||
),
|
||||
name="reviewer",
|
||||
name="legal_reviewer",
|
||||
)
|
||||
|
||||
format_agent = Agent(
|
||||
client=client,
|
||||
instructions=(
|
||||
"You are an excellent content formatter. "
|
||||
"You take the slogan and format it in a cool retro style when printing to a terminal."
|
||||
),
|
||||
name="formatter",
|
||||
)
|
||||
|
||||
# Set the context mode to `last_agent` so that each agent only sees the output of the
|
||||
# previous agent instead of the full conversation history
|
||||
writer_executor = AgentExecutor(writer_agent, context_mode="last_agent")
|
||||
legal_executor = AgentExecutor(legal_agent, context_mode="last_agent")
|
||||
format_executor = AgentExecutor(format_agent, context_mode="last_agent")
|
||||
|
||||
workflow_agent = (
|
||||
GroupChatBuilder(
|
||||
participants=[writer_agent, reviewer_agent],
|
||||
# Set a hard termination condition to stop after 4 messages:
|
||||
# User message + writer message + reviewer message + writer message
|
||||
termination_condition=lambda conversation: len(conversation) >= 4,
|
||||
selection_func=round_robin_selector,
|
||||
WorkflowBuilder(
|
||||
start_executor=writer_executor,
|
||||
# Limiting the output to only the final formatted result.
|
||||
# If this is not set, all intermediate results will be included in the output.
|
||||
output_executors=[format_executor],
|
||||
)
|
||||
.add_edge(writer_executor, legal_executor)
|
||||
.add_edge(legal_executor, format_executor)
|
||||
.build()
|
||||
.as_agent()
|
||||
)
|
||||
|
||||
server = ResponsesHostServer(workflow_agent, store=InMemoryResponseProvider())
|
||||
server = ResponsesHostServer(workflow_agent)
|
||||
server.run()
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user