mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
[BREAKING] Python: Refactor SharedState to State with sync methods and superstep caching (#3667)
* Refactor SharedState to State with sync methods and superstep caching * Fixes * Address PR feedback * Remove dead links * Fix lab test import
This commit is contained in:
committed by
GitHub
Unverified
parent
4e25917644
commit
10afb86213
@@ -189,9 +189,9 @@ class DataIngestion(Executor):
|
||||
timestamp=asyncio.get_event_loop().time(),
|
||||
)
|
||||
|
||||
# Store both batch data and original request in shared state
|
||||
await ctx.set_shared_state(f"batch_{batch.batch_id}", batch)
|
||||
await ctx.set_shared_state(f"request_{batch.batch_id}", request)
|
||||
# Store both batch data and original request in workflow state
|
||||
ctx.set_state(f"batch_{batch.batch_id}", batch)
|
||||
ctx.set_state(f"request_{batch.batch_id}", request)
|
||||
|
||||
await ctx.send_message(batch)
|
||||
|
||||
@@ -204,7 +204,7 @@ class SchemaValidator(Executor):
|
||||
async def validate_schema(self, batch: DataBatch, ctx: WorkflowContext[ValidationReport]) -> None:
|
||||
"""Perform schema validation with processing delay."""
|
||||
# Check if schema validation is enabled
|
||||
request = await ctx.get_shared_state(f"request_{batch.batch_id}")
|
||||
request = ctx.get_state(f"request_{batch.batch_id}")
|
||||
if not request or not request.enable_schema_validation:
|
||||
return
|
||||
|
||||
@@ -240,7 +240,7 @@ class DataQualityValidator(Executor):
|
||||
async def validate_quality(self, batch: DataBatch, ctx: WorkflowContext[ValidationReport]) -> None:
|
||||
"""Perform data quality validation."""
|
||||
# Check if quality validation is enabled
|
||||
request = await ctx.get_shared_state(f"request_{batch.batch_id}")
|
||||
request = ctx.get_state(f"request_{batch.batch_id}")
|
||||
if not request or not request.enable_quality_validation:
|
||||
return
|
||||
|
||||
@@ -282,7 +282,7 @@ class SecurityValidator(Executor):
|
||||
async def validate_security(self, batch: DataBatch, ctx: WorkflowContext[ValidationReport]) -> None:
|
||||
"""Perform security validation."""
|
||||
# Check if security validation is enabled
|
||||
request = await ctx.get_shared_state(f"request_{batch.batch_id}")
|
||||
request = ctx.get_state(f"request_{batch.batch_id}")
|
||||
if not request or not request.enable_security_validation:
|
||||
return
|
||||
|
||||
@@ -323,7 +323,7 @@ class ValidationAggregator(Executor):
|
||||
return
|
||||
|
||||
batch_id = reports[0].batch_id
|
||||
request = await ctx.get_shared_state(f"request_{batch_id}")
|
||||
request = ctx.get_state(f"request_{batch_id}")
|
||||
|
||||
await asyncio.sleep(1) # Aggregation processing time
|
||||
|
||||
@@ -353,8 +353,8 @@ class ValidationAggregator(Executor):
|
||||
)
|
||||
return
|
||||
|
||||
# Retrieve original batch from shared state
|
||||
batch_data = await ctx.get_shared_state(f"batch_{batch_id}")
|
||||
# Retrieve original batch from workflow state
|
||||
batch_data = ctx.get_state(f"batch_{batch_id}")
|
||||
if batch_data:
|
||||
await ctx.send_message(batch_data)
|
||||
else:
|
||||
@@ -375,7 +375,7 @@ class DataNormalizer(Executor):
|
||||
@handler
|
||||
async def normalize_data(self, batch: DataBatch, ctx: WorkflowContext[TransformationResult]) -> None:
|
||||
"""Perform data normalization."""
|
||||
request = await ctx.get_shared_state(f"request_{batch.batch_id}")
|
||||
request = ctx.get_state(f"request_{batch.batch_id}")
|
||||
|
||||
# Check if normalization is enabled
|
||||
if not request or "normalize" not in request.transformations:
|
||||
@@ -420,7 +420,7 @@ class DataEnrichment(Executor):
|
||||
@handler
|
||||
async def enrich_data(self, batch: DataBatch, ctx: WorkflowContext[TransformationResult]) -> None:
|
||||
"""Perform data enrichment."""
|
||||
request = await ctx.get_shared_state(f"request_{batch.batch_id}")
|
||||
request = ctx.get_state(f"request_{batch.batch_id}")
|
||||
|
||||
# Check if enrichment is enabled
|
||||
if not request or "enrich" not in request.transformations:
|
||||
@@ -464,7 +464,7 @@ class DataAggregator(Executor):
|
||||
@handler
|
||||
async def aggregate_data(self, batch: DataBatch, ctx: WorkflowContext[TransformationResult]) -> None:
|
||||
"""Perform data aggregation."""
|
||||
request = await ctx.get_shared_state(f"request_{batch.batch_id}")
|
||||
request = ctx.get_state(f"request_{batch.batch_id}")
|
||||
|
||||
# Check if aggregation is enabled
|
||||
if not request or "aggregate" not in request.transformations:
|
||||
@@ -625,12 +625,12 @@ class FinalProcessor(Executor):
|
||||
|
||||
# Workflow Builder Helper
|
||||
class WorkflowSetupHelper:
|
||||
"""Helper class to set up the complex workflow with shared state management."""
|
||||
"""Helper class to set up the complex workflow with state management."""
|
||||
|
||||
@staticmethod
|
||||
async def store_batch_data(batch: DataBatch, ctx: WorkflowContext) -> None:
|
||||
"""Store batch data in shared state for later retrieval."""
|
||||
await ctx.set_shared_state(f"batch_{batch.batch_id}", batch)
|
||||
"""Store batch data in workflow state for later retrieval."""
|
||||
ctx.set_state(f"batch_{batch.batch_id}", batch)
|
||||
|
||||
|
||||
# Create the workflow instance
|
||||
|
||||
@@ -37,8 +37,6 @@ Once comfortable with these, explore the rest of the samples below.
|
||||
| Azure Chat Agents (Streaming) | [agents/azure_chat_agents_streaming.py](./agents/azure_chat_agents_streaming.py) | Add Azure Chat agents as edges and handle streaming events |
|
||||
| Azure AI Agents (Streaming) | [agents/azure_ai_agents_streaming.py](./agents/azure_ai_agents_streaming.py) | Add Azure AI agents as edges and handle streaming events |
|
||||
| Azure AI Agents (Shared Thread) | [agents/azure_ai_agents_with_shared_thread.py](./agents/azure_ai_agents_with_shared_thread.py) | Share a common message thread between multiple Azure AI agents in a workflow |
|
||||
| Azure Chat Agents (Function Bridge) | [agents/azure_chat_agents_function_bridge.py](./agents/azure_chat_agents_function_bridge.py) | Chain two agents with a function executor that injects external context |
|
||||
| Azure Chat Agents (Tools + HITL) | [agents/azure_chat_agents_tool_calls_with_feedback.py](./agents/azure_chat_agents_tool_calls_with_feedback.py) | Tool-enabled writer/editor pipeline with human feedback gating |
|
||||
| Custom Agent Executors | [agents/custom_agent_executors.py](./agents/custom_agent_executors.py) | Create executors to handle agent run methods |
|
||||
| Sequential Workflow as Agent | [agents/sequential_workflow_as_agent.py](./agents/sequential_workflow_as_agent.py) | Build a sequential workflow orchestrating agents, then expose it as a reusable agent |
|
||||
| Concurrent Workflow as Agent | [agents/concurrent_workflow_as_agent.py](./agents/concurrent_workflow_as_agent.py) | Build a concurrent fan-out/fan-in workflow, then expose it as a reusable agent |
|
||||
@@ -146,16 +144,10 @@ to configure which agents can route to which others with a fluent, type-safe API
|
||||
|
||||
### state-management
|
||||
|
||||
| Sample | File | Concepts |
|
||||
| -------------------------------- | ------------------------------------------------------------------------------------------------ | -------------------------------------------------------------------------- |
|
||||
| Shared States | [state-management/shared_states_with_agents.py](./state-management/shared_states_with_agents.py) | Store in shared state once and later reuse across agents |
|
||||
| Workflow Kwargs (Custom Context) | [state-management/workflow_kwargs.py](./state-management/workflow_kwargs.py) | Pass custom context (data, user tokens) via kwargs to `@ai_function` tools |
|
||||
|
||||
=======
|
||||
| Sample | File | Concepts |
|
||||
|---|---|---|
|
||||
| Shared States | [state-management/shared_states_with_agents.py](./state-management/shared_states_with_agents.py) | Store in shared state once and later reuse across agents |
|
||||
| Workflow Kwargs (Custom Context) | [state-management/workflow_kwargs.py](./state-management/workflow_kwargs.py) | Pass custom context (data, user tokens) via kwargs to `@tool` tools |
|
||||
| Sample | File | Concepts |
|
||||
| -------------------------------- | ------------------------------------------------------------------------------------------------ | ----------------------------------------------------------------- |
|
||||
| State with Agents | [state-management/state_with_agents.py](./state-management/state_with_agents.py) | Store in state once and later reuse across agents |
|
||||
| Workflow Kwargs (Custom Context) | [state-management/workflow_kwargs.py](./state-management/workflow_kwargs.py) | Pass custom context (data, user tokens) via kwargs to `@tool` tools |
|
||||
|
||||
### visualization
|
||||
|
||||
|
||||
@@ -18,7 +18,7 @@ Key Concepts:
|
||||
- Build a workflow using SequentialBuilder (or any builder pattern)
|
||||
- Expose the workflow as a reusable agent via workflow.as_agent()
|
||||
- Pass custom context as kwargs when invoking workflow_agent.run() or run_stream()
|
||||
- kwargs are stored in SharedState and propagated to all agent invocations
|
||||
- kwargs are stored in State and propagated to all agent invocations
|
||||
- @tool functions receive kwargs via **kwargs parameter
|
||||
|
||||
When to use workflow.as_agent():
|
||||
|
||||
+2
-2
@@ -81,9 +81,9 @@ class BriefPreparer(Executor):
|
||||
normalized = " ".join(brief.split()).strip()
|
||||
if not normalized.endswith("."):
|
||||
normalized += "."
|
||||
# Persist the cleaned brief in shared state so downstream executors and
|
||||
# Persist the cleaned brief in workflow state so downstream executors and
|
||||
# future checkpoints can recover the original intent.
|
||||
await ctx.set_shared_state("brief", normalized)
|
||||
ctx.set_state("brief", normalized)
|
||||
prompt = (
|
||||
"You are drafting product release notes. Summarise the brief below in two sentences. "
|
||||
"Keep it positive and end with a call to action.\n\n"
|
||||
|
||||
+11
-11
@@ -36,7 +36,7 @@ Purpose:
|
||||
Demonstrate how to use a multi-selection edge group to fan out from one executor to multiple possible targets.
|
||||
Show how to:
|
||||
- Implement a selection function that chooses one or more downstream branches based on analysis.
|
||||
- Share state across branches so different executors can read the same email content.
|
||||
- Share workflow state across branches so different executors can read the same email content.
|
||||
- Validate agent outputs with Pydantic models for robust structured data exchange.
|
||||
- Merge results from multiple branches (e.g., a summary) back into a typed state.
|
||||
- Apply conditional persistence logic (short vs long emails).
|
||||
@@ -44,7 +44,7 @@ Show how to:
|
||||
Prerequisites:
|
||||
- Familiarity with WorkflowBuilder, executors, edges, and events.
|
||||
- Understanding of multi-selection edge groups and how their selection function maps to target ids.
|
||||
- Experience with shared state in workflows for persisting and reusing objects.
|
||||
- Experience with workflow state for persisting and reusing objects.
|
||||
"""
|
||||
|
||||
|
||||
@@ -87,8 +87,8 @@ class DatabaseEvent(WorkflowEvent): ...
|
||||
@executor(id="store_email")
|
||||
async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
new_email = Email(email_id=str(uuid4()), email_content=email_text)
|
||||
await ctx.set_shared_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email)
|
||||
await ctx.set_shared_state(CURRENT_EMAIL_ID_KEY, new_email.email_id)
|
||||
ctx.set_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email)
|
||||
ctx.set_state(CURRENT_EMAIL_ID_KEY, new_email.email_id)
|
||||
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage("user", text=new_email.email_content)], should_respond=True)
|
||||
@@ -98,8 +98,8 @@ async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest
|
||||
@executor(id="to_analysis_result")
|
||||
async def to_analysis_result(response: AgentExecutorResponse, ctx: WorkflowContext[AnalysisResult]) -> None:
|
||||
parsed = AnalysisResultAgent.model_validate_json(response.agent_response.text)
|
||||
email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY)
|
||||
email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{email_id}")
|
||||
email_id: str = ctx.get_state(CURRENT_EMAIL_ID_KEY)
|
||||
email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{email_id}")
|
||||
await ctx.send_message(
|
||||
AnalysisResult(
|
||||
spam_decision=parsed.spam_decision,
|
||||
@@ -116,7 +116,7 @@ async def submit_to_email_assistant(analysis: AnalysisResult, ctx: WorkflowConte
|
||||
if analysis.spam_decision != "NotSpam":
|
||||
raise RuntimeError("This executor should only handle NotSpam messages.")
|
||||
|
||||
email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}")
|
||||
email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}")
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage("user", text=email.email_content)], should_respond=True)
|
||||
)
|
||||
@@ -131,7 +131,7 @@ async def finalize_and_send(response: AgentExecutorResponse, ctx: WorkflowContex
|
||||
@executor(id="summarize_email")
|
||||
async def summarize_email(analysis: AnalysisResult, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
# Only called for long NotSpam emails by selection_func
|
||||
email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}")
|
||||
email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}")
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage("user", text=email.email_content)], should_respond=True)
|
||||
)
|
||||
@@ -140,8 +140,8 @@ async def summarize_email(analysis: AnalysisResult, ctx: WorkflowContext[AgentEx
|
||||
@executor(id="merge_summary")
|
||||
async def merge_summary(response: AgentExecutorResponse, ctx: WorkflowContext[AnalysisResult]) -> None:
|
||||
summary = EmailSummaryModel.model_validate_json(response.agent_response.text)
|
||||
email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY)
|
||||
email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{email_id}")
|
||||
email_id: str = ctx.get_state(CURRENT_EMAIL_ID_KEY)
|
||||
email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{email_id}")
|
||||
# Build an AnalysisResult mirroring to_analysis_result but with summary
|
||||
await ctx.send_message(
|
||||
AnalysisResult(
|
||||
@@ -165,7 +165,7 @@ async def handle_spam(analysis: AnalysisResult, ctx: WorkflowContext[Never, str]
|
||||
@executor(id="handle_uncertain")
|
||||
async def handle_uncertain(analysis: AnalysisResult, ctx: WorkflowContext[Never, str]) -> None:
|
||||
if analysis.spam_decision == "Uncertain":
|
||||
email: Email | None = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}")
|
||||
email: Email | None = ctx.get_state(f"{EMAIL_STATE_PREFIX}{analysis.email_id}")
|
||||
await ctx.yield_output(
|
||||
f"Email marked as uncertain: {analysis.reason}. Email content: {getattr(email, 'email_content', '')}"
|
||||
)
|
||||
|
||||
@@ -25,13 +25,13 @@ from typing_extensions import Never
|
||||
"""
|
||||
Sample: Switch-Case Edge Group with an explicit Uncertain branch.
|
||||
|
||||
The workflow stores a single email in shared state, asks a spam detection agent for a three way decision,
|
||||
The workflow stores a single email in workflow state, asks a spam detection agent for a three way decision,
|
||||
then routes with a switch-case group: NotSpam to the drafting assistant, Spam to a spam handler, and
|
||||
Default to an Uncertain handler.
|
||||
|
||||
Purpose:
|
||||
Demonstrate deterministic one of N routing with switch-case edges. Show how to:
|
||||
- Persist input once in shared state, then pass around a small typed pointer that carries the email id.
|
||||
- Persist input once in workflow state, then pass around a small typed pointer that carries the email id.
|
||||
- Validate agent JSON with Pydantic models for robust parsing.
|
||||
- Keep executor responsibilities narrow. Transform model output to a typed DetectionResult, then route based
|
||||
on that type.
|
||||
@@ -74,7 +74,7 @@ class DetectionResult:
|
||||
|
||||
@dataclass
|
||||
class Email:
|
||||
# In memory record of the email content stored in shared state.
|
||||
# In memory record of the email content stored in workflow state.
|
||||
email_id: str
|
||||
email_content: str
|
||||
|
||||
@@ -93,8 +93,8 @@ def get_case(expected_decision: str):
|
||||
async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
# Persist the raw email once. Store under a unique key and set the current pointer for convenience.
|
||||
new_email = Email(email_id=str(uuid4()), email_content=email_text)
|
||||
await ctx.set_shared_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email)
|
||||
await ctx.set_shared_state(CURRENT_EMAIL_ID_KEY, new_email.email_id)
|
||||
ctx.set_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email)
|
||||
ctx.set_state(CURRENT_EMAIL_ID_KEY, new_email.email_id)
|
||||
|
||||
# Kick off the detector by forwarding the email as a user message to the spam_detection_agent.
|
||||
await ctx.send_message(
|
||||
@@ -106,7 +106,7 @@ async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest
|
||||
async def to_detection_result(response: AgentExecutorResponse, ctx: WorkflowContext[DetectionResult]) -> None:
|
||||
# Parse the detector JSON into a typed model. Attach the current email id for downstream lookups.
|
||||
parsed = DetectionResultAgent.model_validate_json(response.agent_response.text)
|
||||
email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY)
|
||||
email_id: str = ctx.get_state(CURRENT_EMAIL_ID_KEY)
|
||||
await ctx.send_message(DetectionResult(spam_decision=parsed.spam_decision, reason=parsed.reason, email_id=email_id))
|
||||
|
||||
|
||||
@@ -116,8 +116,8 @@ async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowCon
|
||||
if detection.spam_decision != "NotSpam":
|
||||
raise RuntimeError("This executor should only handle NotSpam messages.")
|
||||
|
||||
# Load the original content from shared state using the id carried in DetectionResult.
|
||||
email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}")
|
||||
# Load the original content from workflow state using the id carried in DetectionResult.
|
||||
email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}")
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage("user", text=email.email_content)], should_respond=True)
|
||||
)
|
||||
@@ -143,7 +143,7 @@ async def handle_spam(detection: DetectionResult, ctx: WorkflowContext[Never, st
|
||||
async def handle_uncertain(detection: DetectionResult, ctx: WorkflowContext[Never, str]) -> None:
|
||||
# Uncertain path terminal. Surface the original content to aid human review.
|
||||
if detection.spam_decision == "Uncertain":
|
||||
email: Email | None = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}")
|
||||
email: Email | None = ctx.get_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}")
|
||||
await ctx.yield_output(
|
||||
f"Email marked as uncertain: {detection.reason}. Email content: {getattr(email, 'email_content', '')}"
|
||||
)
|
||||
|
||||
+12
-12
@@ -10,7 +10,7 @@ import aiofiles
|
||||
from agent_framework import (
|
||||
Executor, # Base class for custom workflow steps
|
||||
WorkflowBuilder, # Fluent builder for executors and edges
|
||||
WorkflowContext, # Per run context with shared state and messaging
|
||||
WorkflowContext, # Per run context with workflow state and messaging
|
||||
WorkflowOutputEvent, # Event emitted when workflow yields output
|
||||
WorkflowViz, # Utility to visualize a workflow graph
|
||||
handler, # Decorator to expose an Executor method as a step
|
||||
@@ -26,7 +26,7 @@ It also demonstrates WorkflowViz for graph visualization.
|
||||
|
||||
Purpose:
|
||||
Show how to:
|
||||
- Partition input once and coordinate parallel mappers with shared state.
|
||||
- Partition input once and coordinate parallel mappers with workflow state.
|
||||
- Implement map, shuffle, and reduce executors that pass file paths instead of large payloads.
|
||||
- Use fan out and fan in edges to express parallelism and joins.
|
||||
- Persist intermediate results to disk to bound memory usage for large inputs.
|
||||
@@ -49,8 +49,8 @@ TEMP_DIR = os.path.join(DIR, "tmp")
|
||||
# Ensure the temporary directory exists
|
||||
os.makedirs(TEMP_DIR, exist_ok=True)
|
||||
|
||||
# Define a key for the shared state to store the data to be processed
|
||||
SHARED_STATE_DATA_KEY = "data_to_be_processed"
|
||||
# Define a key for the workflow state to store the data to be processed
|
||||
STATE_DATA_KEY = "data_to_be_processed"
|
||||
|
||||
|
||||
class SplitCompleted:
|
||||
@@ -69,17 +69,17 @@ class Split(Executor):
|
||||
|
||||
@handler
|
||||
async def split(self, data: str, ctx: WorkflowContext[SplitCompleted]) -> None:
|
||||
"""Tokenize input and assign contiguous index ranges to each mapper via shared state.
|
||||
"""Tokenize input and assign contiguous index ranges to each mapper via workflow state.
|
||||
|
||||
Args:
|
||||
data: The raw text to process.
|
||||
ctx: Workflow context to persist shared state and send messages.
|
||||
ctx: Workflow context to persist state and send messages.
|
||||
"""
|
||||
# Process data into a list of words and remove empty lines or words.
|
||||
word_list = self._preprocess(data)
|
||||
|
||||
# Store tokenized words once so all mappers can read by index.
|
||||
await ctx.set_shared_state(SHARED_STATE_DATA_KEY, word_list)
|
||||
ctx.set_state(STATE_DATA_KEY, word_list)
|
||||
|
||||
# Divide indices into contiguous slices for each mapper.
|
||||
map_executor_count = len(self._map_executor_ids)
|
||||
@@ -90,8 +90,8 @@ class Split(Executor):
|
||||
start_index = i * chunk_size
|
||||
end_index = start_index + chunk_size if i < map_executor_count - 1 else len(word_list)
|
||||
|
||||
# The mapper reads its slice from shared state keyed by its own executor id.
|
||||
await ctx.set_shared_state(self._map_executor_ids[i], (start_index, end_index))
|
||||
# The mapper reads its slice from workflow state keyed by its own executor id.
|
||||
ctx.set_state(self._map_executor_ids[i], (start_index, end_index))
|
||||
await ctx.send_message(SplitCompleted(), self._map_executor_ids[i])
|
||||
|
||||
tasks = [asyncio.create_task(_process_chunk(i)) for i in range(map_executor_count)]
|
||||
@@ -119,11 +119,11 @@ class Map(Executor):
|
||||
|
||||
Args:
|
||||
_: SplitCompleted marker indicating maps can begin.
|
||||
ctx: Workflow context for shared state access and messaging.
|
||||
ctx: Workflow context for workflow state access and messaging.
|
||||
"""
|
||||
# Retrieve tokens and our assigned slice.
|
||||
data_to_be_processed: list[str] = await ctx.get_shared_state(SHARED_STATE_DATA_KEY)
|
||||
chunk_start, chunk_end = await ctx.get_shared_state(self.id)
|
||||
data_to_be_processed: list[str] = ctx.get_state(STATE_DATA_KEY)
|
||||
chunk_start, chunk_end = ctx.get_state(self.id)
|
||||
|
||||
results = [(item, 1) for item in data_to_be_processed[chunk_start:chunk_end]]
|
||||
|
||||
|
||||
+12
-12
@@ -21,14 +21,14 @@ from pydantic import BaseModel
|
||||
from typing_extensions import Never
|
||||
|
||||
"""
|
||||
Sample: Shared state with agents and conditional routing.
|
||||
Sample: Workflow state with agents and conditional routing.
|
||||
|
||||
Store an email once by id, classify it with a detector agent, then either draft a reply with an assistant
|
||||
agent or finish with a spam notice. Stream events as the workflow runs.
|
||||
|
||||
Purpose:
|
||||
Show how to:
|
||||
- Use shared state to decouple large payloads from messages and pass around lightweight references.
|
||||
- Use workflow state to decouple large payloads from messages and pass around lightweight references.
|
||||
- Enforce structured agent outputs with Pydantic models via response_format for robust parsing.
|
||||
- Route using conditional edges based on a typed intermediate DetectionResult.
|
||||
- Compose agent backed executors with function style executors and yield the final output when the workflow completes.
|
||||
@@ -58,7 +58,7 @@ class EmailResponse(BaseModel):
|
||||
|
||||
@dataclass
|
||||
class DetectionResult:
|
||||
"""Internal detection result enriched with the shared state email_id for later lookups."""
|
||||
"""Internal detection result enriched with the state email_id for later lookups."""
|
||||
|
||||
is_spam: bool
|
||||
reason: str
|
||||
@@ -67,7 +67,7 @@ class DetectionResult:
|
||||
|
||||
@dataclass
|
||||
class Email:
|
||||
"""In memory record stored in shared state to avoid re-sending large bodies on edges."""
|
||||
"""In memory record stored in state to avoid re-sending large bodies on edges."""
|
||||
|
||||
email_id: str
|
||||
email_content: str
|
||||
@@ -91,7 +91,7 @@ def get_condition(expected_result: bool):
|
||||
|
||||
@executor(id="store_email")
|
||||
async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest]) -> None:
|
||||
"""Persist the raw email content in shared state and trigger spam detection.
|
||||
"""Persist the raw email content in state and trigger spam detection.
|
||||
|
||||
Responsibilities:
|
||||
- Generate a unique email_id (UUID) for downstream retrieval.
|
||||
@@ -99,8 +99,8 @@ async def store_email(email_text: str, ctx: WorkflowContext[AgentExecutorRequest
|
||||
- Emit an AgentExecutorRequest asking the detector to respond.
|
||||
"""
|
||||
new_email = Email(email_id=str(uuid4()), email_content=email_text)
|
||||
await ctx.set_shared_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email)
|
||||
await ctx.set_shared_state(CURRENT_EMAIL_ID_KEY, new_email.email_id)
|
||||
ctx.set_state(f"{EMAIL_STATE_PREFIX}{new_email.email_id}", new_email)
|
||||
ctx.set_state(CURRENT_EMAIL_ID_KEY, new_email.email_id)
|
||||
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage("user", text=new_email.email_content)], should_respond=True)
|
||||
@@ -113,11 +113,11 @@ async def to_detection_result(response: AgentExecutorResponse, ctx: WorkflowCont
|
||||
|
||||
Steps:
|
||||
1) Validate the agent's JSON output into DetectionResultAgent.
|
||||
2) Retrieve the current email_id from shared state.
|
||||
2) Retrieve the current email_id from workflow state.
|
||||
3) Send a typed DetectionResult for conditional routing.
|
||||
"""
|
||||
parsed = DetectionResultAgent.model_validate_json(response.agent_response.text)
|
||||
email_id: str = await ctx.get_shared_state(CURRENT_EMAIL_ID_KEY)
|
||||
email_id: str = ctx.get_state(CURRENT_EMAIL_ID_KEY)
|
||||
await ctx.send_message(DetectionResult(is_spam=parsed.is_spam, reason=parsed.reason, email_id=email_id))
|
||||
|
||||
|
||||
@@ -131,8 +131,8 @@ async def submit_to_email_assistant(detection: DetectionResult, ctx: WorkflowCon
|
||||
if detection.is_spam:
|
||||
raise RuntimeError("This executor should only handle non-spam messages.")
|
||||
|
||||
# Load the original content by id from shared state and forward it to the assistant.
|
||||
email: Email = await ctx.get_shared_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}")
|
||||
# Load the original content by id from workflow state and forward it to the assistant.
|
||||
email: Email = ctx.get_state(f"{EMAIL_STATE_PREFIX}{detection.email_id}")
|
||||
await ctx.send_message(
|
||||
AgentExecutorRequest(messages=[ChatMessage("user", text=email.email_content)], should_respond=True)
|
||||
)
|
||||
@@ -181,7 +181,7 @@ def create_email_assistant_agent() -> ChatAgent:
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Build and run the shared state with agents and conditional routing workflow."""
|
||||
"""Build and run the workflow state with agents and conditional routing workflow."""
|
||||
|
||||
# Build the workflow graph with conditional edges.
|
||||
# Flow:
|
||||
@@ -16,7 +16,7 @@ through any workflow pattern to @tool functions using the **kwargs pattern.
|
||||
|
||||
Key Concepts:
|
||||
- Pass custom context as kwargs when invoking workflow.run_stream() or workflow.run()
|
||||
- kwargs are stored in SharedState and passed to all agent invocations
|
||||
- kwargs are stored in State and passed to all agent invocations
|
||||
- @tool functions receive kwargs via **kwargs parameter
|
||||
- Works with Sequential, Concurrent, GroupChat, Handoff, and Magentic patterns
|
||||
|
||||
|
||||
Reference in New Issue
Block a user