Python: Add special handling for workflows (#5298)

* Add special handling for workflows

* Address comments
This commit is contained in:
Tao Chen
2026-04-16 17:55:45 -07:00
committed by GitHub
Unverified
parent 55e0705923
commit 0fcd71dbeb
6 changed files with 224 additions and 42 deletions
@@ -5,9 +5,20 @@ from __future__ import annotations
import asyncio
import json
import logging
import os
from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping
from agent_framework import ChatOptions, Content, HistoryProvider, Message, RawAgent, SupportsAgentRun
from agent_framework import (
ChatOptions,
Content,
ContextProvider,
FileCheckpointStorage,
HistoryProvider,
Message,
RawAgent,
SupportsAgentRun,
WorkflowAgent,
)
from agent_framework._telemetry import append_to_user_agent
from azure.ai.agentserver.responses import (
ResponseContext,
@@ -60,6 +71,8 @@ class ResponsesHostServer(ResponsesAgentServerHost):
"""A responses server host for an agent."""
USER_AGENT_PREFIX = "foundry-hosting"
# TODO(@taochen): Allow a different checkpoint storage that stores checkpoints externally
CHECKPOINT_STORAGE_PATH = "/.checkpoints"
def __init__(
self,
@@ -80,8 +93,11 @@ class ResponsesHostServer(ResponsesAgentServerHost):
**kwargs: Additional keyword arguments.
Note:
The agent must not have a history provider with `load_messages=True`,
because history is managed by the hosting infrastructure.
1. The agent must not have a history provider with `load_messages=True`,
because history is managed by the hosting infrastructure.
2. The agent must not have any context providers that maintain context
in memory, because the hosting environment may get deactivated between
requests, and any in-memory context would be lost.
"""
super().__init__(prefix=prefix, options=options, store=store, **kwargs)
@@ -91,13 +107,39 @@ class ResponsesHostServer(ResponsesAgentServerHost):
"There shouldn't be a history provider with `load_messages=True` already present. "
"History is managed by the hosting infrastructure."
)
self._agent = agent
provider = cast(ContextProvider, provider)
logger.warning(
"Context provider %s is present. If it maintains context in memory, "
"the context may be lost between requests. Use with caution.",
provider.source_id,
)
self._is_workflow_agent = False
self._checkpoint_storage_path = None
if isinstance(agent, WorkflowAgent):
if agent.workflow._runner_context.has_checkpointing(): # pyright: ignore[reportPrivateUsage]
raise RuntimeError(
"There should not be a checkpoint storage already present in the workflow agent. "
"The hosting infrastructure will manage checkpoints instead."
)
self._checkpoint_storage_path = (
self.CHECKPOINT_STORAGE_PATH
if self.config.is_hosted
else os.path.join(os.getcwd(), self.CHECKPOINT_STORAGE_PATH.lstrip("/"))
)
self._is_workflow_agent = True
self._agent = agent
self.response_handler(self._handler) # pyright: ignore[reportUnknownMemberType]
# Append the user agent prefix for telemetry purposes
append_to_user_agent(self.USER_AGENT_PREFIX)
@staticmethod
def _is_streaming_request(request: CreateResponse) -> bool:
"""Check if the request is a streaming request."""
return request.stream is not None and request.stream is True
async def _handler(
self,
request: CreateResponse,
@@ -105,50 +147,61 @@ class ResponsesHostServer(ResponsesAgentServerHost):
cancellation_signal: asyncio.Event,
) -> AsyncIterable[ResponseStreamEvent | dict[str, Any]]:
"""Handle the creation of a response."""
if self._is_workflow_agent:
# Workflow agents are handled differently because they require checkpoint restoration
async for event in self._handle_workflow_agent(request, context, cancellation_signal):
yield event
return
input_text = await context.get_input_text()
history = await context.get_history()
messages = [*_to_messages(history), input_text]
chat_options = _to_chat_options(request)
chat_options, are_options_set = _to_chat_options(request)
stream = ResponseEventStream(response_id=context.response_id, model=request.model)
is_streaming_request = self._is_streaming_request(request)
response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)
yield stream.emit_created()
yield stream.emit_in_progress()
yield response_event_stream.emit_created()
yield response_event_stream.emit_in_progress()
if request.stream is None or request.stream is False:
if not is_streaming_request:
# Run the agent in non-streaming mode
if isinstance(self._agent, RawAgent):
raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType]
response = await raw_agent.run(messages, stream=False, options=chat_options)
else:
if are_options_set:
logger.warning("Agent doesn't support runtime options. They will be ignored.")
response = await self._agent.run(messages, stream=False)
for message in response.messages:
for content in message.contents:
async for item in _to_outputs(stream, content):
async for item in _to_outputs(response_event_stream, content):
yield item
yield stream.emit_completed()
yield response_event_stream.emit_completed()
return
# Start the streaming response
# Run the agent in streaming mode
if isinstance(self._agent, RawAgent):
raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType]
response_stream = raw_agent.run(messages, stream=True, options=chat_options)
else:
if are_options_set:
logger.warning("Agent doesn't support runtime options. They will be ignored.")
response_stream = self._agent.run(messages, stream=True)
# Track the current active output item builder for streaming;
# lazily created on matching content, closed when a different type arrives.
tracker = _OutputItemTracker(stream)
tracker = _OutputItemTracker(response_event_stream)
async for update in response_stream:
for content in update.contents:
for event in tracker.handle(content):
yield event
if tracker.needs_async:
async for item in _to_outputs(stream, content):
async for item in _to_outputs(response_event_stream, content):
yield item
tracker.needs_async = False
@@ -156,7 +209,120 @@ class ResponsesHostServer(ResponsesAgentServerHost):
for event in tracker.close():
yield event
yield stream.emit_completed()
yield response_event_stream.emit_completed()
async def _handle_workflow_agent(
    self,
    request: CreateResponse,
    context: ResponseContext,
    cancellation_signal: asyncio.Event,
) -> AsyncIterable[ResponseStreamEvent | dict[str, Any]]:
    """Handle the creation of a response for a workflow agent.

    Why this is required:
    The sandbox may be deactivated after some period of inactivity, and only data managed
    by the hosting infrastructure or files will be preserved upon deactivation.

    The method first replays the latest checkpoint stored for the previous response or
    conversation (if any), then runs the workflow with the new input while writing
    checkpoints to the directory that belongs to this response or conversation. Once the
    run has finished, every checkpoint except the latest one is deleted.

    Args:
        request: The incoming create-response request.
        context: The response context providing the input text and the response and
            conversation IDs.
        cancellation_signal: Cancellation event supplied by the host. It is not
            consulted by this method.

    Yields:
        The response stream events produced while running the workflow.

    Raises:
        RuntimeError: If both a previous response ID and a conversation ID are given,
            if the server was not set up for a workflow agent, or if an ID cannot be
            used safely as a checkpoint directory name.
    """
    input_text = await context.get_input_text()
    is_streaming_request = self._is_streaming_request(request)

    _, are_options_set = _to_chat_options(request)
    if are_options_set:
        logger.warning("Workflow agent doesn't support runtime options. They will be ignored.")

    if request.previous_response_id is not None and context.conversation_id is not None:
        raise RuntimeError("Previous response ID cannot be used in conjunction with conversation ID.")
    restore_id = request.previous_response_id or context.conversation_id

    # The following should never happen because the constructor sets both for workflow
    # agents. The checks exist for type narrowing and defensive programming.
    if self._checkpoint_storage_path is None:
        raise RuntimeError("Checkpoint storage path is not configured for workflow agent.")
    if not isinstance(self._agent, WorkflowAgent):
        raise RuntimeError("Agent is not a workflow agent.")
    workflow_name = self._agent.workflow.name

    # Restore from the latest checkpoint if available, otherwise start with an empty history.
    if restore_id is not None:
        # The ID comes from the request, so it is validated before it becomes a path.
        restore_storage = FileCheckpointStorage(
            self._checkpoint_directory(self._checkpoint_storage_path, restore_id)
        )
        latest_checkpoint = await restore_storage.get_latest(workflow_name=workflow_name)
        if latest_checkpoint is not None:
            if is_streaming_request:
                # The stream must be consumed; otherwise the invocation is a no-op.
                async for _ in self._agent.run(
                    stream=True,
                    checkpoint_id=latest_checkpoint.checkpoint_id,
                    checkpoint_storage=restore_storage,
                ):
                    pass
            else:
                await self._agent.run(
                    stream=False,
                    checkpoint_id=latest_checkpoint.checkpoint_id,
                    checkpoint_storage=restore_storage,
                )

    # Now run the agent with the latest input.
    response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)

    # Choose the checkpoint storage for this response based on the following rules:
    # - If no previous response ID or conversation ID is provided, create a new storage for this response
    # - If a previous response ID is provided, create a new storage for this response
    # - If a conversation ID is provided, reuse the existing storage for the conversation
    target_id = context.conversation_id or context.response_id
    checkpoint_storage = FileCheckpointStorage(
        self._checkpoint_directory(self._checkpoint_storage_path, target_id)
    )

    yield response_event_stream.emit_created()
    yield response_event_stream.emit_in_progress()

    if not is_streaming_request:
        # Run the agent in non-streaming mode.
        response = await self._agent.run(input_text, stream=False, checkpoint_storage=checkpoint_storage)

        for message in response.messages:
            for content in message.contents:
                async for item in _to_outputs(response_event_stream, content):
                    yield item

        await self._delete_not_latest_checkpoints(checkpoint_storage, workflow_name)
        yield response_event_stream.emit_completed()
        return

    # Run the agent in streaming mode.
    response_stream = self._agent.run(input_text, stream=True, checkpoint_storage=checkpoint_storage)

    # Track the current active output item builder for streaming;
    # lazily created on matching content, closed when a different type arrives.
    tracker = _OutputItemTracker(response_event_stream)

    async for update in response_stream:
        for content in update.contents:
            for event in tracker.handle(content):
                yield event
            if tracker.needs_async:
                async for item in _to_outputs(response_event_stream, content):
                    yield item
                tracker.needs_async = False

    # Close any remaining active builder.
    for event in tracker.close():
        yield event

    await self._delete_not_latest_checkpoints(checkpoint_storage, workflow_name)
    yield response_event_stream.emit_completed()

@staticmethod
def _checkpoint_directory(base_path: str, context_id: str) -> str:
    """Return the checkpoint directory for ``context_id`` under ``base_path``.

    The ID must be a single path component. ``os.path.join`` discards ``base_path``
    when the second argument is absolute, and ``..`` segments walk out of it, so an
    unchecked ID could point the checkpoint storage outside the checkpoint root.

    Raises:
        RuntimeError: If ``context_id`` is empty, is ``.`` or ``..``, contains a NUL
            character, or contains a path separator or drive prefix.
    """
    is_single_component = (
        context_id not in ("", ".", "..")
        and "\x00" not in context_id
        and os.path.basename(context_id) == context_id
    )
    if not is_single_component:
        raise RuntimeError(f"Invalid response or conversation ID for checkpoint storage: {context_id!r}")
    return os.path.join(base_path, context_id)
@staticmethod
async def _delete_not_latest_checkpoints(checkpoint_storage: FileCheckpointStorage, workflow_name: str):
"""Delete all checkpoints except the latest one.
We only need the last checkpoint for each invocation.
"""
latest_checkpoint = await checkpoint_storage.get_latest(workflow_name=workflow_name)
if latest_checkpoint is not None:
all_checkpoints = await checkpoint_storage.list_checkpoints(workflow_name=workflow_name)
for checkpoint in all_checkpoints:
if checkpoint.checkpoint_id != latest_checkpoint.checkpoint_id:
await checkpoint_storage.delete(checkpoint.checkpoint_id)
# region Active Builder State
@@ -310,7 +476,7 @@ class _OutputItemTracker:
# region Option Conversion
def _to_chat_options(request: CreateResponse) -> ChatOptions:
def _to_chat_options(request: CreateResponse) -> tuple[ChatOptions, bool]:
"""Converts a CreateResponse request to ChatOptions.
Args:
@@ -318,19 +484,26 @@ def _to_chat_options(request: CreateResponse) -> ChatOptions:
Returns:
ChatOptions: The converted ChatOptions.
bool: Whether any options were set.
"""
chat_options = ChatOptions()
are_options_set = False
if request.temperature is not None:
chat_options["temperature"] = request.temperature
are_options_set = True
if request.top_p is not None:
chat_options["top_p"] = request.top_p
are_options_set = True
if request.max_output_tokens is not None:
chat_options["max_tokens"] = request.max_output_tokens
are_options_set = True
if request.parallel_tool_calls is not None:
chat_options["allow_multiple_tool_calls"] = request.parallel_tool_calls
are_options_set = True
return chat_options
return chat_options, are_options_set
# endregion
@@ -5,7 +5,6 @@ import os
from agent_framework import Agent
from agent_framework.foundry import FoundryChatClient
from agent_framework_foundry_hosting import ResponsesHostServer
from azure.ai.agentserver.responses import InMemoryResponseProvider
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
@@ -29,7 +28,7 @@ def main():
default_options={"store": False},
)
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
server = ResponsesHostServer(agent)
server.run()
@@ -7,7 +7,6 @@ from random import randint
from agent_framework import Agent, tool
from agent_framework.foundry import FoundryChatClient
from agent_framework_foundry_hosting import ResponsesHostServer
from azure.ai.agentserver.responses import InMemoryResponseProvider
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
@@ -67,7 +66,7 @@ def main():
default_options={"store": False},
)
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
server = ResponsesHostServer(agent)
server.run()
@@ -6,7 +6,6 @@ import httpx
from agent_framework import Agent, MCPStreamableHTTPTool
from agent_framework.foundry import FoundryChatClient
from agent_framework_foundry_hosting import ResponsesHostServer
from azure.ai.agentserver.responses import InMemoryResponseProvider
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
@@ -69,7 +68,7 @@ def main():
default_options={"store": False},
)
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
server = ResponsesHostServer(agent)
server.run()
@@ -13,5 +13,5 @@ curl -X POST http://localhost:8088/responses -H "Content-Type: application/json"
Invoke with `azd`:
```bash
azd ai agent invoke --local "List all the repositories I own on GitHub."
azd ai agent invoke --local "Create a slogan for a new electric SUV that is affordable and fun to drive."
```
@@ -2,11 +2,10 @@
import os
from agent_framework import Agent
from agent_framework import Agent, AgentExecutor, WorkflowBuilder
from agent_framework.foundry import FoundryChatClient
from agent_framework.orchestrations import GroupChatBuilder, GroupChatState
from agent_framework.orchestrations import GroupChatState
from agent_framework_foundry_hosting import ResponsesHostServer
from azure.ai.agentserver.responses import InMemoryResponseProvider
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
@@ -30,35 +29,48 @@ def main():
writer_agent = Agent(
client=client,
instructions=(
"You are an excellent content writer. You create new content and edit contents based on the feedback."
),
instructions=("You are an excellent slogan writer. You create new slogans based on the given topic."),
name="writer",
)
reviewer_agent = Agent(
legal_agent = Agent(
client=client,
instructions=(
"You are an excellent content reviewer."
"Provide actionable feedback to the writer about the provided content."
"Provide the feedback in the most concise manner possible."
"You are an excellent legal reviewer. "
"Make necessary corrections to the slogan so that it is legally compliant."
),
name="reviewer",
name="legal_reviewer",
)
format_agent = Agent(
client=client,
instructions=(
"You are an excellent content formatter. "
"You take the slogan and format it in a cool retro style when printing to a terminal."
),
name="formatter",
)
# Set the context mode to `last_agent` so that each agent only sees the output of the
# previous agent instead of the full conversation history
writer_executor = AgentExecutor(writer_agent, context_mode="last_agent")
legal_executor = AgentExecutor(legal_agent, context_mode="last_agent")
format_executor = AgentExecutor(format_agent, context_mode="last_agent")
workflow_agent = (
GroupChatBuilder(
participants=[writer_agent, reviewer_agent],
# Set a hard termination condition to stop after 4 messages:
# User message + writer message + reviewer message + writer message
termination_condition=lambda conversation: len(conversation) >= 4,
selection_func=round_robin_selector,
WorkflowBuilder(
start_executor=writer_executor,
# Limiting the output to only the final formatted result.
# If this is not set, all intermediate results will be included in the output.
output_executors=[format_executor],
)
.add_edge(writer_executor, legal_executor)
.add_edge(legal_executor, format_executor)
.build()
.as_agent()
)
server = ResponsesHostServer(workflow_agent, store=InMemoryResponseProvider())
server = ResponsesHostServer(workflow_agent)
server.run()