Add special handling for workflows

This commit is contained in:
Tao Chen
2026-04-15 20:47:40 -07:00
Unverified
parent 55e0705923
commit 4249aef461
6 changed files with 155 additions and 38 deletions
@@ -5,9 +5,19 @@ from __future__ import annotations
import asyncio
import json
import logging
import os
from collections.abc import AsyncIterable, AsyncIterator, Generator, Mapping
from agent_framework import ChatOptions, Content, HistoryProvider, Message, RawAgent, SupportsAgentRun
from agent_framework import (
ChatOptions,
Content,
FileCheckpointStorage,
HistoryProvider,
Message,
RawAgent,
SupportsAgentRun,
WorkflowAgent,
)
from agent_framework._telemetry import append_to_user_agent
from azure.ai.agentserver.responses import (
ResponseContext,
@@ -60,6 +70,8 @@ class ResponsesHostServer(ResponsesAgentServerHost):
"""A responses server host for an agent."""
USER_AGENT_PREFIX = "foundry-hosting"
# TODO(@taochen): Allow a different checkpoint storage that stores checkpoints externally
CHECKPOINT_STORAGE_PATH = "/.checkpoints"
def __init__(
self,
@@ -80,8 +92,11 @@ class ResponsesHostServer(ResponsesAgentServerHost):
**kwargs: Additional keyword arguments.
Note:
The agent must not have a history provider with `load_messages=True`,
because history is managed by the hosting infrastructure.
1. The agent must not have a history provider with `load_messages=True`,
because history is managed by the hosting infrastructure.
2. The agent must not have any context providers that maintain context
in memory, because the hosting environment may get deactivated between
requests, and any in-memory context would be lost.
"""
super().__init__(prefix=prefix, options=options, store=store, **kwargs)
@@ -91,8 +106,24 @@ class ResponsesHostServer(ResponsesAgentServerHost):
"There shouldn't be a history provider with `load_messages=True` already present. "
"History is managed by the hosting infrastructure."
)
self._agent = agent
self._is_workflow_agent = False
self._checkpoint_storage = None
if isinstance(agent, WorkflowAgent):
if agent.workflow._runner_context.has_checkpointing(): # pyright: ignore[reportPrivateUsage]
raise RuntimeError(
"There should not be a checkpoint storage already present in the workflow agent. "
"The hosting infrastructure will manage checkpoints instead."
)
checkpoint_storage_path = (
self.CHECKPOINT_STORAGE_PATH
if self.config.is_hosted
else os.path.join(os.getcwd(), self.CHECKPOINT_STORAGE_PATH.lstrip("/"))
)
self._is_workflow_agent = True
self._checkpoint_storage = FileCheckpointStorage(checkpoint_storage_path)
self._agent = agent
self.response_handler(self._handler) # pyright: ignore[reportUnknownMemberType]
# Append the user agent prefix for telemetry purposes
@@ -105,16 +136,22 @@ class ResponsesHostServer(ResponsesAgentServerHost):
cancellation_signal: asyncio.Event,
) -> AsyncIterable[ResponseStreamEvent | dict[str, Any]]:
"""Handle the creation of a response."""
if self._is_workflow_agent:
# Workflow agents are handled differently because they require checkpoint restoration
async for event in self._handle_workflow_agent(request, context, cancellation_signal):
yield event
return
input_text = await context.get_input_text()
history = await context.get_history()
messages = [*_to_messages(history), input_text]
chat_options = _to_chat_options(request)
stream = ResponseEventStream(response_id=context.response_id, model=request.model)
response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)
yield stream.emit_created()
yield stream.emit_in_progress()
yield response_event_stream.emit_created()
yield response_event_stream.emit_in_progress()
if request.stream is None or request.stream is False:
# Run the agent in non-streaming mode
@@ -126,13 +163,13 @@ class ResponsesHostServer(ResponsesAgentServerHost):
for message in response.messages:
for content in message.contents:
async for item in _to_outputs(stream, content):
async for item in _to_outputs(response_event_stream, content):
yield item
yield stream.emit_completed()
yield response_event_stream.emit_completed()
return
# Start the streaming response
# Run the agent in streaming mode
if isinstance(self._agent, RawAgent):
raw_agent = cast("RawAgent[Any]", self._agent) # pyright: ignore[reportUnknownMemberType]
response_stream = raw_agent.run(messages, stream=True, options=chat_options)
@@ -141,14 +178,14 @@ class ResponsesHostServer(ResponsesAgentServerHost):
# Track the current active output item builder for streaming;
# lazily created on matching content, closed when a different type arrives.
tracker = _OutputItemTracker(stream)
tracker = _OutputItemTracker(response_event_stream)
async for update in response_stream:
for content in update.contents:
for event in tracker.handle(content):
yield event
if tracker.needs_async:
async for item in _to_outputs(stream, content):
async for item in _to_outputs(response_event_stream, content):
yield item
tracker.needs_async = False
@@ -156,7 +193,78 @@ class ResponsesHostServer(ResponsesAgentServerHost):
for event in tracker.close():
yield event
yield stream.emit_completed()
yield response_event_stream.emit_completed()
async def _handle_workflow_agent(
self,
request: CreateResponse,
context: ResponseContext,
cancellation_signal: asyncio.Event,
) -> AsyncIterable[ResponseStreamEvent | dict[str, Any]]:
"""Handle the creation of a response for a workflow agent.
Why this is required:
The sandbox may be deactivated after some period of inactivity, and only data managed
by the hosting infrastructure or files will be preserved upon deactivation.
"""
input_text = await context.get_input_text()
stream = request.stream is not None and request.stream is True
# The following should never happen due to the checks above.
# This is for type safety and defensive programming.
if self._checkpoint_storage is None:
raise RuntimeError("Checkpoint storage is not available for workflow agent.")
if not isinstance(self._agent, WorkflowAgent):
raise RuntimeError("Agent is not a workflow agent.")
# Restore from the latest checkpoint if available, otherwise start with an empty history
latest_checkpoint = await self._checkpoint_storage.get_latest(workflow_name=self._agent.workflow.name)
if latest_checkpoint is not None:
_ = await self._agent.run(
stream=stream,
checkpoint_id=latest_checkpoint.checkpoint_id,
checkpoint_storage=self._checkpoint_storage,
)
# Now run the agent with the latest input
response_event_stream = ResponseEventStream(response_id=context.response_id, model=request.model)
yield response_event_stream.emit_created()
yield response_event_stream.emit_in_progress()
if not stream:
# Run the agent in non-streaming mode
response = await self._agent.run(input_text, stream=False, checkpoint_storage=self._checkpoint_storage)
for message in response.messages:
for content in message.contents:
async for item in _to_outputs(response_event_stream, content):
yield item
yield response_event_stream.emit_completed()
return
# Run the agent in streaming mode
response_stream = self._agent.run(input_text, stream=True, checkpoint_storage=self._checkpoint_storage)
# Track the current active output item builder for streaming;
# lazily created on matching content, closed when a different type arrives.
tracker = _OutputItemTracker(response_event_stream)
async for update in response_stream:
for content in update.contents:
for event in tracker.handle(content):
yield event
if tracker.needs_async:
async for item in _to_outputs(response_event_stream, content):
yield item
tracker.needs_async = False
# Close any remaining active builder
for event in tracker.close():
yield event
yield response_event_stream.emit_completed()
# region Active Builder State
@@ -5,7 +5,6 @@ import os
from agent_framework import Agent
from agent_framework.foundry import FoundryChatClient
from agent_framework_foundry_hosting import ResponsesHostServer
from azure.ai.agentserver.responses import InMemoryResponseProvider
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
@@ -29,7 +28,7 @@ def main():
default_options={"store": False},
)
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
server = ResponsesHostServer(agent)
server.run()
@@ -7,7 +7,6 @@ from random import randint
from agent_framework import Agent, tool
from agent_framework.foundry import FoundryChatClient
from agent_framework_foundry_hosting import ResponsesHostServer
from azure.ai.agentserver.responses import InMemoryResponseProvider
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
from pydantic import Field
@@ -67,7 +66,7 @@ def main():
default_options={"store": False},
)
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
server = ResponsesHostServer(agent)
server.run()
@@ -6,7 +6,6 @@ import httpx
from agent_framework import Agent, MCPStreamableHTTPTool
from agent_framework.foundry import FoundryChatClient
from agent_framework_foundry_hosting import ResponsesHostServer
from azure.ai.agentserver.responses import InMemoryResponseProvider
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
@@ -69,7 +68,7 @@ def main():
default_options={"store": False},
)
server = ResponsesHostServer(agent, store=InMemoryResponseProvider())
server = ResponsesHostServer(agent)
server.run()
@@ -13,5 +13,5 @@ curl -X POST http://localhost:8088/responses -H "Content-Type: application/json"
Invoke with `azd`:
```bash
azd ai agent invoke --local "List all the repositories I own on GitHub."
azd ai agent invoke --local "Create a slogan for a new electric SUV that is affordable and fun to drive."
```
@@ -2,11 +2,10 @@
import os
from agent_framework import Agent
from agent_framework import Agent, AgentExecutor, WorkflowBuilder
from agent_framework.foundry import FoundryChatClient
from agent_framework.orchestrations import GroupChatBuilder, GroupChatState
from agent_framework.orchestrations import GroupChatState
from agent_framework_foundry_hosting import ResponsesHostServer
from azure.ai.agentserver.responses import InMemoryResponseProvider
from azure.identity import AzureCliCredential
from dotenv import load_dotenv
@@ -30,35 +29,48 @@ def main():
writer_agent = Agent(
client=client,
instructions=(
"You are an excellent content writer. You create new content and edit contents based on the feedback."
),
instructions=("You are an excellent slogan writer. You create new slogans based on the given topic."),
name="writer",
)
reviewer_agent = Agent(
legal_agent = Agent(
client=client,
instructions=(
"You are an excellent content reviewer."
"Provide actionable feedback to the writer about the provided content."
"Provide the feedback in the most concise manner possible."
"You are an excellent legal reviewer. "
"Make necessary corrections to the slogan so that it is legally compliant."
),
name="reviewer",
name="legal_reviewer",
)
format_agent = Agent(
client=client,
instructions=(
"You are an excellent content formatter. "
"You take the slogan and format it in a cool retro style when printing to a terminal."
),
name="formatter",
)
# Set the context mode to `last_agent` so that each agent only sees the output of the
# previous agent instead of the full conversation history
writer_executor = AgentExecutor(writer_agent, context_mode="last_agent")
legal_executor = AgentExecutor(legal_agent, context_mode="last_agent")
format_executor = AgentExecutor(format_agent, context_mode="last_agent")
workflow_agent = (
GroupChatBuilder(
participants=[writer_agent, reviewer_agent],
# Set a hard termination condition to stop after 4 messages:
# User message + writer message + reviewer message + writer message
termination_condition=lambda conversation: len(conversation) >= 4,
selection_func=round_robin_selector,
WorkflowBuilder(
start_executor=writer_executor,
# Limiting the output to only the final formatted result.
# If this is not set, all intermediate results will be included in the output.
output_executors=[format_executor],
)
.add_edge(writer_executor, legal_executor)
.add_edge(legal_executor, format_executor)
.build()
.as_agent()
)
server = ResponsesHostServer(workflow_agent, store=InMemoryResponseProvider())
server = ResponsesHostServer(workflow_agent)
server.run()