mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Adding support for nested workflows (#460)
* Adding design documents and data flow descriptions for sub-workflows * Updating docs. * Sub-workflow implementation #1. Stuck because of singleton RequestInfoExecutor, going to make a change to remove that restrivtion. * Removed the singleton restriction on RequestInfoExecutor so enable sub-workflows. * Scenarios seem to be working. * Sample improved. * going to have intern add generic response wrappers. * Wrapped responses working. * Non-hardcoded routing is working. * Sample showing external approved and not approved. * Cleaning up. * Updating some samples and user guide. * Removing old design doc. * Cleaning up. * Adding python-package-setup.md back. * Update python/packages/workflow/agent_framework_workflow/_executor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update python/packages/workflow/agent_framework_workflow/_validation.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Removing prints. * Fixing lint and type issues. * Fixing lint and type issues. * Update python/packages/workflow/agent_framework_workflow/_executor.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * Adding type hints to intercepts decorator. * Removing unused files. * Fixing issue with sample 5 groupchat with hil. * Removing redundent samples. * Updates to ensure no conflicting request interceptors and to support a subflow with multiple requests in a single super step. * Fixing pypi errors. * clean up samples * update samples to make it more clear * warning for unhandled request info from sub workflow * add logger info --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
b26d9c95fe
commit
b0b3fd151c
@@ -12,6 +12,7 @@ from agent_framework.workflow import (
|
||||
RequestInfoEvent,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
@@ -94,16 +95,20 @@ class CriticGroupChatManager(Executor):
|
||||
|
||||
@handler
|
||||
async def handle_request_response(
|
||||
self, response: list[ChatMessage], ctx: WorkflowContext[AgentExecutorRequest]
|
||||
self,
|
||||
response: RequestResponse[RequestInfoMessage, list[ChatMessage]],
|
||||
ctx: WorkflowContext[AgentExecutorRequest],
|
||||
) -> None:
|
||||
"""Handler that processes the response from the RequestInfoExecutor."""
|
||||
messages: list[ChatMessage] = response.data or []
|
||||
|
||||
# Update the chat history with the response
|
||||
self._chat_history.extend(response)
|
||||
self._chat_history.extend(messages)
|
||||
|
||||
# Send the response to the other members
|
||||
await asyncio.gather(*[
|
||||
ctx.send_message(
|
||||
AgentExecutorRequest(messages=response, should_respond=False),
|
||||
AgentExecutorRequest(messages=messages, should_respond=False),
|
||||
target_id=member_id,
|
||||
)
|
||||
for member_id in self._members
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agent_framework.workflow import (
|
||||
Executor,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
WorkflowEvent,
|
||||
WorkflowExecutor,
|
||||
handler,
|
||||
)
|
||||
|
||||
"""
|
||||
The following sample demonstrates basic sub-workflows.
|
||||
|
||||
This sample shows how to:
|
||||
1. Create workflows that execute other workflows as sub-workflows
|
||||
2. Pass data between parent and sub-workflows
|
||||
3. Collect results from sub-workflows
|
||||
|
||||
The example simulates a simple text processing system where:
|
||||
- Sub-workflows process individual text strings (count words, characters)
|
||||
- Parent workflow orchestrates multiple sub-workflows and collects results
|
||||
|
||||
Key concepts demonstrated:
|
||||
- WorkflowExecutor: Wraps a workflow to make it behave as an executor
|
||||
- Sub-workflow isolation: Sub-workflows work independently
|
||||
- Result collection: Parent workflows can gather outputs from sub-workflows
|
||||
|
||||
Simple flow visualization:
|
||||
|
||||
Parent Orchestrator
|
||||
|
|
||||
| TextProcessingRequest(text, task_id)
|
||||
v
|
||||
[ Sub-workflow: WorkflowExecutor(TextProcessor) ]
|
||||
|
|
||||
| WorkflowCompletedEvent(TextProcessingResult)
|
||||
v
|
||||
Parent collects results and summarizes
|
||||
"""
|
||||
|
||||
|
||||
# Message types
|
||||
@dataclass
|
||||
class TextProcessingRequest:
|
||||
"""Request to process a text string."""
|
||||
|
||||
text: str
|
||||
task_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextProcessingResult:
|
||||
"""Result of text processing."""
|
||||
|
||||
task_id: str
|
||||
text: str
|
||||
word_count: int
|
||||
char_count: int
|
||||
|
||||
|
||||
class AllTasksCompleted(WorkflowEvent):
|
||||
"""Event triggered when all processing tasks are complete."""
|
||||
|
||||
def __init__(self, results: list[TextProcessingResult]):
|
||||
super().__init__(results)
|
||||
|
||||
|
||||
# Sub-workflow executor
|
||||
class TextProcessor(Executor):
|
||||
"""Processes text strings - counts words and characters."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="text_processor")
|
||||
|
||||
@handler
|
||||
async def process_text(
|
||||
self, request: TextProcessingRequest, ctx: WorkflowContext[TextProcessingResult]
|
||||
) -> None:
|
||||
"""Process a text string and return statistics."""
|
||||
text_preview = f"'{request.text[:50]}{'...' if len(request.text) > 50 else ''}'"
|
||||
print(f"π Sub-workflow processing text (Task {request.task_id}): {text_preview}")
|
||||
|
||||
# Simple text processing
|
||||
word_count = len(request.text.split()) if request.text.strip() else 0
|
||||
char_count = len(request.text)
|
||||
|
||||
print(f"π Task {request.task_id}: {word_count} words, {char_count} characters")
|
||||
|
||||
# Create result
|
||||
result = TextProcessingResult(
|
||||
task_id=request.task_id,
|
||||
text=request.text,
|
||||
word_count=word_count,
|
||||
char_count=char_count,
|
||||
)
|
||||
|
||||
print(f"β
Sub-workflow completed task {request.task_id}")
|
||||
# Signal completion
|
||||
await ctx.add_event(WorkflowCompletedEvent(data=result))
|
||||
|
||||
|
||||
# Parent workflow
|
||||
class TextProcessingOrchestrator(Executor):
|
||||
"""Orchestrates multiple text processing tasks using sub-workflows."""
|
||||
|
||||
results: list[TextProcessingResult] = []
|
||||
expected_count: int = 0
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="text_orchestrator")
|
||||
|
||||
@handler
|
||||
async def start_processing(
|
||||
self, texts: list[str], ctx: WorkflowContext[TextProcessingRequest]
|
||||
) -> None:
|
||||
"""Start processing multiple text strings."""
|
||||
print(f"π Starting processing of {len(texts)} text strings")
|
||||
print("=" * 60)
|
||||
|
||||
self.expected_count = len(texts)
|
||||
|
||||
# Send each text to a sub-workflow
|
||||
for i, text in enumerate(texts):
|
||||
task_id = f"task_{i+1}"
|
||||
request = TextProcessingRequest(text=text, task_id=task_id)
|
||||
print(f"π€ Dispatching {task_id} to sub-workflow")
|
||||
await ctx.send_message(request, target_id="text_processor_workflow")
|
||||
|
||||
@handler
|
||||
async def collect_result(
|
||||
self, result: TextProcessingResult, ctx: WorkflowContext[None]
|
||||
) -> None:
|
||||
"""Collect results from sub-workflows."""
|
||||
print(f"π₯ Collected result from {result.task_id}")
|
||||
self.results.append(result)
|
||||
|
||||
# Check if all results are collected
|
||||
if len(self.results) == self.expected_count:
|
||||
print("\nπ All tasks completed!")
|
||||
await ctx.add_event(AllTasksCompleted(self.results))
|
||||
|
||||
def get_summary(self) -> dict[str, Any]:
|
||||
"""Get a summary of all processing results."""
|
||||
total_words = sum(result.word_count for result in self.results)
|
||||
total_chars = sum(result.char_count for result in self.results)
|
||||
avg_words = total_words / len(self.results) if self.results else 0
|
||||
avg_chars = total_chars / len(self.results) if self.results else 0
|
||||
|
||||
return {
|
||||
"total_texts": len(self.results),
|
||||
"total_words": total_words,
|
||||
"total_characters": total_chars,
|
||||
"average_words_per_text": round(avg_words, 2),
|
||||
"average_characters_per_text": round(avg_chars, 2),
|
||||
}
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function to run the basic sub-workflow example."""
|
||||
print("π Setting up sub-workflow...")
|
||||
|
||||
# Step 1: Create the text processing sub-workflow
|
||||
text_processor = TextProcessor()
|
||||
|
||||
processing_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(text_processor)
|
||||
.build()
|
||||
)
|
||||
|
||||
print("π§ Setting up parent workflow...")
|
||||
|
||||
# Step 2: Create the parent workflow
|
||||
orchestrator = TextProcessingOrchestrator()
|
||||
workflow_executor = WorkflowExecutor(processing_workflow, id="text_processor_workflow")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(orchestrator)
|
||||
.add_edge(orchestrator, workflow_executor)
|
||||
.add_edge(workflow_executor, orchestrator)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Step 3: Test data - various text strings
|
||||
test_texts = [
|
||||
"Hello world! This is a simple test.",
|
||||
"Python is a powerful programming language used for many applications.",
|
||||
"Short text.",
|
||||
"This is a longer text with multiple sentences. It contains more words and characters. We use it to test our text processing workflow.",
|
||||
"", # Empty string
|
||||
" Spaces around text ",
|
||||
]
|
||||
|
||||
print(f"\nπ§ͺ Testing with {len(test_texts)} text strings")
|
||||
print("=" * 60)
|
||||
|
||||
# Step 4: Run the workflow
|
||||
result = await main_workflow.run(test_texts)
|
||||
|
||||
# Step 5: Display results
|
||||
print(f"\nπ Processing Results:")
|
||||
print("=" * 60)
|
||||
|
||||
# Sort results by task_id for consistent display
|
||||
sorted_results = sorted(orchestrator.results, key=lambda r: r.task_id)
|
||||
|
||||
for result in sorted_results:
|
||||
preview = result.text[:30] + "..." if len(result.text) > 30 else result.text
|
||||
preview = preview.replace('\n', ' ').strip() or '(empty)'
|
||||
print(f"β
{result.task_id}: '{preview}' -> {result.word_count} words, {result.char_count} chars")
|
||||
|
||||
# Step 6: Display summary
|
||||
summary = orchestrator.get_summary()
|
||||
print(f"\nπ Summary:")
|
||||
print("=" * 60)
|
||||
print(f"π Total texts processed: {summary['total_texts']}")
|
||||
print(f"π Total words: {summary['total_words']}")
|
||||
print(f"π€ Total characters: {summary['total_characters']}")
|
||||
print(f"π Average words per text: {summary['average_words_per_text']}")
|
||||
print(f"π Average characters per text: {summary['average_characters_per_text']}")
|
||||
|
||||
print(f"\nπ Processing complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
+275
@@ -0,0 +1,275 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""
|
||||
The following sample demonstrates sub-workflows with request interception and conditional forwarding.
|
||||
|
||||
This sample shows how to:
|
||||
1. Create workflows that execute other workflows as sub-workflows
|
||||
2. Intercept requests from sub-workflows in parent workflows using @intercepts_request
|
||||
3. Conditionally handle or forward requests using RequestResponse.handled() and RequestResponse.forward()
|
||||
4. Handle external requests that are forwarded by the parent workflow
|
||||
5. Proper request/response correlation for concurrent processing
|
||||
|
||||
The example simulates an email validation system where:
|
||||
- Sub-workflows validate multiple email addresses concurrently
|
||||
- Parent workflows can intercept domain check requests for optimization
|
||||
- Known domains (example.com, company.com) are approved locally
|
||||
- Unknown domains (unknown.org) are forwarded to external services
|
||||
- Request correlation ensures each email gets the correct domain check response
|
||||
- External domain check requests are processed and responses routed back correctly
|
||||
|
||||
Key concepts demonstrated:
|
||||
- WorkflowExecutor: Wraps a workflow to make it behave as an executor
|
||||
- @intercepts_request: Decorator for parent workflows to handle sub-workflow requests
|
||||
- RequestResponse: Enables conditional handling vs forwarding of requests
|
||||
- Request correlation: Using request_id to match responses with original requests
|
||||
- Concurrent processing: Multiple emails processed simultaneously without interference
|
||||
- External request routing: RequestInfoExecutor handles forwarded external requests
|
||||
- Sub-workflow isolation: Sub-workflows work normally without knowing they're nested
|
||||
|
||||
Simple flow visualization:
|
||||
|
||||
Parent Orchestrator (@intercepts_request)
|
||||
|
|
||||
| EmailValidationRequest(email) x3 (concurrent)
|
||||
v
|
||||
[ Sub-workflow: WorkflowExecutor(EmailValidator) ]
|
||||
|
|
||||
| DomainCheckRequest(domain) with request_id correlation
|
||||
v
|
||||
Interception? yes -> handled locally with RequestResponse.handled(True)
|
||||
no -> forwarded to RequestInfoExecutor -> external service
|
||||
|
|
||||
v
|
||||
Response routed back to sub-workflow using request_id
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agent_framework_workflow import (
|
||||
Executor,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
WorkflowExecutor,
|
||||
handler,
|
||||
intercepts_request,
|
||||
)
|
||||
|
||||
|
||||
# 1. Define domain-specific message types
|
||||
@dataclass
|
||||
class EmailValidationRequest:
|
||||
"""Request to validate an email address."""
|
||||
|
||||
email: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class DomainCheckRequest(RequestInfoMessage):
|
||||
"""Request to check if a domain is approved."""
|
||||
|
||||
domain: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of email validation."""
|
||||
|
||||
email: str
|
||||
is_valid: bool
|
||||
reason: str
|
||||
|
||||
|
||||
# 2. Implement the sub-workflow executor (completely standard)
|
||||
class EmailValidator(Executor):
|
||||
"""Validates email addresses - doesn't know it's in a sub-workflow."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the EmailValidator executor."""
|
||||
super().__init__(id="email_validator")
|
||||
# Use a dict to track multiple pending emails by request_id
|
||||
self._pending_emails: dict[str, str] = {}
|
||||
|
||||
@handler
|
||||
async def validate_request(
|
||||
self, request: EmailValidationRequest, ctx: WorkflowContext[DomainCheckRequest | ValidationResult]
|
||||
) -> None:
|
||||
"""Validate an email address."""
|
||||
print(f"π Sub-workflow validating email: {request.email}")
|
||||
|
||||
# Extract domain
|
||||
domain = request.email.split("@")[1] if "@" in request.email else ""
|
||||
|
||||
if not domain:
|
||||
print(f"β Invalid email format: {request.email}")
|
||||
result = ValidationResult(email=request.email, is_valid=False, reason="Invalid email format")
|
||||
await ctx.add_event(WorkflowCompletedEvent(data=result))
|
||||
return
|
||||
|
||||
print(f"π Sub-workflow requesting domain check for: {domain}")
|
||||
# Request domain check
|
||||
domain_check = DomainCheckRequest(domain=domain)
|
||||
# Store the pending email with the request_id for correlation
|
||||
self._pending_emails[domain_check.request_id] = request.email
|
||||
await ctx.send_message(domain_check, target_id="email_request_info")
|
||||
|
||||
@handler
|
||||
async def handle_domain_response(
|
||||
self,
|
||||
response: RequestResponse[DomainCheckRequest, bool],
|
||||
ctx: WorkflowContext[ValidationResult],
|
||||
) -> None:
|
||||
"""Handle domain check response from RequestInfo with correlation."""
|
||||
approved = bool(response.data)
|
||||
domain = response.original_request.domain if (hasattr(response, 'original_request') and response.original_request) else "unknown"
|
||||
print(f"π¬ Sub-workflow received domain response for '{domain}': {approved}")
|
||||
|
||||
# Find the corresponding email using the request_id
|
||||
request_id = response.original_request.request_id if (hasattr(response, 'original_request') and response.original_request) else None
|
||||
if request_id and request_id in self._pending_emails:
|
||||
email = self._pending_emails.pop(request_id) # Remove from pending
|
||||
result = ValidationResult(
|
||||
email=email,
|
||||
is_valid=approved,
|
||||
reason="Domain approved" if approved else "Domain not approved",
|
||||
)
|
||||
print(f"β
Sub-workflow completing validation for: {email}")
|
||||
await ctx.add_event(WorkflowCompletedEvent(data=result))
|
||||
|
||||
|
||||
# 3. Implement the parent workflow with request interception
|
||||
class SmartEmailOrchestrator(Executor):
|
||||
"""Parent orchestrator that can intercept domain checks."""
|
||||
approved_domains: set[str] = set()
|
||||
|
||||
def __init__(self, approved_domains: set[str] | None = None):
|
||||
"""Initialize the SmartEmailOrchestrator with approved domains.
|
||||
|
||||
Args:
|
||||
approved_domains: Set of pre-approved domains, defaults to example.com, test.org, company.com
|
||||
"""
|
||||
super().__init__(id="email_orchestrator", approved_domains=approved_domains)
|
||||
self._results: list[ValidationResult] = []
|
||||
|
||||
@handler
|
||||
async def start_validation(self, emails: list[str], ctx: WorkflowContext[EmailValidationRequest]) -> None:
|
||||
"""Start validating a batch of emails."""
|
||||
print(f"π§ Starting validation of {len(emails)} email addresses")
|
||||
print("=" * 60)
|
||||
for email in emails:
|
||||
print(f"π€ Sending '{email}' to sub-workflow for validation")
|
||||
request = EmailValidationRequest(email=email)
|
||||
await ctx.send_message(request, target_id="email_validator_workflow")
|
||||
|
||||
@intercepts_request
|
||||
async def check_domain(
|
||||
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
|
||||
) -> RequestResponse[DomainCheckRequest, bool]:
|
||||
"""Intercept domain check requests from sub-workflows."""
|
||||
print(f"π Parent intercepting domain check for: {request.domain}")
|
||||
if request.domain in self.approved_domains:
|
||||
print(f"β
Domain '{request.domain}' is pre-approved locally!")
|
||||
return RequestResponse[DomainCheckRequest, bool].handled(True)
|
||||
print(f"β Domain '{request.domain}' unknown, forwarding to external service...")
|
||||
return RequestResponse.forward()
|
||||
|
||||
@handler
|
||||
async def collect_result(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None:
|
||||
"""Collect validation results. It comes from the sub-workflow emitted WorkflowCompletionEvent's data field."""
|
||||
status_icon = "β
" if result.is_valid else "β"
|
||||
print(f"π₯ {status_icon} Validation result: {result.email} -> {result.reason}")
|
||||
self._results.append(result)
|
||||
|
||||
@property
|
||||
def results(self) -> list[ValidationResult]:
|
||||
"""Get the collected validation results."""
|
||||
return self._results
|
||||
|
||||
|
||||
async def run_example() -> None:
|
||||
"""Run the sub-workflow example."""
|
||||
print("π Setting up sub-workflow with request interception...")
|
||||
print()
|
||||
|
||||
# 4. Build the sub-workflow
|
||||
email_validator = EmailValidator()
|
||||
# Match the target_id used in EmailValidator ("email_request_info")
|
||||
request_info = RequestInfoExecutor(id="email_request_info")
|
||||
|
||||
validation_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(email_validator)
|
||||
.add_edge(email_validator, request_info)
|
||||
.add_edge(request_info, email_validator)
|
||||
.build()
|
||||
)
|
||||
|
||||
# 5. Build the parent workflow with interception
|
||||
orchestrator = SmartEmailOrchestrator(approved_domains={"example.com", "company.com"})
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, id="email_validator_workflow")
|
||||
# Add a RequestInfoExecutor to handle forwarded external requests
|
||||
main_request_info = RequestInfoExecutor(id="main_request_info")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(orchestrator)
|
||||
.add_edge(orchestrator, workflow_executor)
|
||||
.add_edge(workflow_executor, orchestrator)
|
||||
# Add edges for external request handling
|
||||
.add_edge(orchestrator, main_request_info)
|
||||
.add_edge(main_request_info, workflow_executor) # Route external responses to sub-workflow
|
||||
.build()
|
||||
)
|
||||
|
||||
# 6. Prepare test inputs: known domain, unknown domain
|
||||
test_emails = [
|
||||
"user@example.com", # Should be intercepted and approved
|
||||
"admin@company.com", # Should be intercepted and approved
|
||||
"guest@unknown.org", # Should be forwarded externally
|
||||
]
|
||||
|
||||
# 7. Run the workflow
|
||||
result = await main_workflow.run(test_emails)
|
||||
|
||||
# 8. Handle any external requests
|
||||
request_events = result.get_request_info_events()
|
||||
if request_events:
|
||||
print(f"\nπ Handling {len(request_events)} external request(s)...")
|
||||
for event in request_events:
|
||||
if event.data and hasattr(event.data, "domain"):
|
||||
print(f"π External domain check needed for: {event.data.domain}")
|
||||
|
||||
# Simulate external responses
|
||||
external_responses: dict[str, bool] = {}
|
||||
for event in request_events:
|
||||
# Simulate external domain checking
|
||||
if event.data and hasattr(event.data, "domain"):
|
||||
domain = event.data.domain
|
||||
# Let's say unknown.org is actually approved externally
|
||||
approved = domain == "unknown.org"
|
||||
print(f"π External service response for '{domain}': {'APPROVED' if approved else 'REJECTED'}")
|
||||
external_responses[event.request_id] = approved
|
||||
|
||||
# 9. Send external responses
|
||||
await main_workflow.send_responses(external_responses)
|
||||
else:
|
||||
print("\nπ― All requests were intercepted and handled locally!")
|
||||
|
||||
# 10. Display final summary
|
||||
print(f"\nπ Final Results Summary:")
|
||||
print("=" * 60)
|
||||
for result in orchestrator.results:
|
||||
status = "β
VALID" if result.is_valid else "β INVALID"
|
||||
print(f"{status} {result.email}: {result.reason}")
|
||||
|
||||
print(f"\nπ Processed {len(orchestrator.results)} emails total")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run_example())
|
||||
@@ -0,0 +1,414 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agent_framework.workflow import (
|
||||
Executor,
|
||||
RequestInfoExecutor,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
handler,
|
||||
)
|
||||
|
||||
# Import the new sub-workflow types directly from the implementation package
|
||||
try:
|
||||
from agent_framework_workflow import (
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
WorkflowExecutor,
|
||||
intercepts_request,
|
||||
)
|
||||
except ImportError:
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "workflow"))
|
||||
from agent_framework_workflow import (
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
WorkflowExecutor,
|
||||
intercepts_request,
|
||||
)
|
||||
|
||||
"""
|
||||
This sample demonstrates the PROPER pattern for request interception.
|
||||
|
||||
Key principles:
|
||||
1. Only ONE executor intercepts a given request type from a specific sub-workflow
|
||||
2. Different executors can intercept DIFFERENT request types from the same sub-workflow
|
||||
3. The same executor can intercept the same request type from DIFFERENT sub-workflows
|
||||
|
||||
This ensures:
|
||||
- Deterministic behavior
|
||||
- Clear responsibility boundaries
|
||||
- Easier debugging and maintenance
|
||||
|
||||
The example simulates a resource allocation system where:
|
||||
- Sub-workflow requests resources (CPU, memory, etc.)
|
||||
- A single Cache executor intercepts and handles resource requests
|
||||
- The Cache can either satisfy from cache or forward to external
|
||||
|
||||
Simple flow visualization:
|
||||
|
||||
Coordinator
|
||||
|
|
||||
| list[resource/policy requests]
|
||||
v
|
||||
[ Sub-workflow: WorkflowExecutor(ResourceRequester) ]
|
||||
| |
|
||||
| ResourceRequest | PolicyCheckRequest
|
||||
v v
|
||||
ResourceCache (@intercepts) PolicyEngine (@intercepts)
|
||||
| handled/forward | handled/forward
|
||||
v v
|
||||
RequestInfo (external) <----- forwarded when not handled
|
||||
| responses
|
||||
v
|
||||
Back to sub-workflow -> completion -> results collected
|
||||
"""
|
||||
|
||||
|
||||
# 1. Define domain-specific request/response types
|
||||
@dataclass
|
||||
class ResourceRequest(RequestInfoMessage):
|
||||
"""Request for computing resources."""
|
||||
|
||||
resource_type: str = "cpu" # cpu, memory, disk, etc.
|
||||
amount: int = 1
|
||||
priority: str = "normal" # low, normal, high
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyCheckRequest(RequestInfoMessage):
|
||||
"""Request to check resource allocation policy."""
|
||||
|
||||
resource_type: str = ""
|
||||
amount: int = 0
|
||||
policy_type: str = "quota" # quota, compliance, security
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResourceResponse:
|
||||
"""Response with allocated resources."""
|
||||
|
||||
resource_type: str
|
||||
allocated: int
|
||||
source: str # Which system provided the resources
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyResponse:
|
||||
"""Response from policy check."""
|
||||
|
||||
approved: bool
|
||||
reason: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFinished:
|
||||
pass
|
||||
|
||||
|
||||
# 2. Implement the sub-workflow executor - makes resource and policy requests
|
||||
class ResourceRequester(Executor):
|
||||
"""Simple executor that requests resources and checks policies."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="resource_requester")
|
||||
self._request_count = 0
|
||||
|
||||
@handler
|
||||
async def request_resources(
|
||||
self,
|
||||
requests: list[dict[str, Any]],
|
||||
ctx: WorkflowContext[ResourceRequest | PolicyCheckRequest],
|
||||
) -> None:
|
||||
"""Process a list of resource requests."""
|
||||
print(f"π Sub-workflow processing {len(requests)} requests")
|
||||
self._request_count += len(requests)
|
||||
|
||||
for req_data in requests:
|
||||
req_type = req_data.get("request_type", "resource")
|
||||
|
||||
if req_type == "resource":
|
||||
print(f" π¦ Requesting resource: {req_data.get('type', 'cpu')} x{req_data.get('amount', 1)}")
|
||||
request = ResourceRequest(
|
||||
resource_type=req_data.get("type", "cpu"),
|
||||
amount=req_data.get("amount", 1),
|
||||
priority=req_data.get("priority", "normal"),
|
||||
)
|
||||
# Send to parent workflow for interception - not to target_id
|
||||
await ctx.send_message(request)
|
||||
elif req_type == "policy":
|
||||
print(f" π‘οΈ Checking policy: {req_data.get('type', 'cpu')} x{req_data.get('amount', 1)} ({req_data.get('policy_type', 'quota')})")
|
||||
request = PolicyCheckRequest(
|
||||
resource_type=req_data.get("type", "cpu"),
|
||||
amount=req_data.get("amount", 1),
|
||||
policy_type=req_data.get("policy_type", "quota"),
|
||||
)
|
||||
# Send to parent workflow for interception - not to target_id
|
||||
await ctx.send_message(request)
|
||||
|
||||
@handler
|
||||
async def handle_resource_response(
|
||||
self,
|
||||
response: RequestResponse[ResourceRequest, ResourceResponse],
|
||||
ctx: WorkflowContext[None],
|
||||
) -> None:
|
||||
"""Handle resource allocation response."""
|
||||
if response.data:
|
||||
source_icon = "πͺ" if response.data.source == "cache" else "π"
|
||||
print(
|
||||
f"π¦ {source_icon} Sub-workflow received: {response.data.allocated} {response.data.resource_type} from {response.data.source}"
|
||||
)
|
||||
if self._collect_results():
|
||||
# Emit completion event and send RequestFinished to the parent workflow.
|
||||
await ctx.add_event(WorkflowCompletedEvent(RequestFinished()))
|
||||
|
||||
@handler
|
||||
async def handle_policy_response(
|
||||
self, response: RequestResponse[PolicyCheckRequest, PolicyResponse], ctx: WorkflowContext[None]
|
||||
) -> None:
|
||||
"""Handle policy check response."""
|
||||
if response.data:
|
||||
status_icon = "β
" if response.data.approved else "β"
|
||||
print(f"π‘οΈ {status_icon} Sub-workflow received policy response: {response.data.approved} - {response.data.reason}")
|
||||
if self._collect_results():
|
||||
# Emit completion event and send RequestFinished to the parent workflow.
|
||||
await ctx.add_event(WorkflowCompletedEvent(RequestFinished()))
|
||||
|
||||
def _collect_results(self) -> bool:
|
||||
"""Collect and summarize results."""
|
||||
self._request_count -= 1
|
||||
print(f"π Sub-workflow completed request ({self._request_count} remaining)")
|
||||
return self._request_count == 0
|
||||
|
||||
|
||||
# 3. Implement the Resource Cache - ONLY intercepts ResourceRequest
|
||||
class ResourceCache(Executor):
|
||||
"""Interceptor that handles RESOURCE requests from cache."""
|
||||
|
||||
# Use class attributes to avoid Pydantic assignment restrictions
|
||||
cache: dict[str, int] = {"cpu": 10, "memory": 50, "disk": 100}
|
||||
results: list[ResourceResponse] = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="resource_cache")
|
||||
# Instance initialization only; state kept in class attributes as above
|
||||
|
||||
@intercepts_request
|
||||
async def check_cache(
|
||||
self, request: ResourceRequest, ctx: WorkflowContext[None]
|
||||
) -> RequestResponse[ResourceRequest, ResourceResponse]:
|
||||
"""Intercept RESOURCE requests and check cache first."""
|
||||
print(f"πͺ CACHE interceptor checking: {request.amount} {request.resource_type}")
|
||||
|
||||
available = self.cache.get(request.resource_type, 0)
|
||||
|
||||
if available >= request.amount:
|
||||
# We can satisfy from cache
|
||||
self.cache[request.resource_type] -= request.amount
|
||||
response = ResourceResponse(resource_type=request.resource_type, allocated=request.amount, source="cache")
|
||||
print(f" β
Cache satisfied: {request.amount} {request.resource_type}")
|
||||
self.results.append(response)
|
||||
return RequestResponse[ResourceRequest, ResourceResponse].handled(response)
|
||||
|
||||
# Cache miss - forward to external
|
||||
print(f" β Cache miss: need {request.amount}, have {available} {request.resource_type}")
|
||||
return RequestResponse.forward()
|
||||
|
||||
@handler
|
||||
async def collect_result(
|
||||
self, response: RequestResponse[ResourceRequest, ResourceResponse], ctx: WorkflowContext[None]
|
||||
) -> None:
|
||||
"""Collect results from external requests that were forwarded."""
|
||||
if response.data and response.data.source != "cache": # Don't double-count our own results
|
||||
self.results.append(response.data)
|
||||
print(
|
||||
f"πͺ π Cache received external response: {response.data.allocated} {response.data.resource_type} from {response.data.source}"
|
||||
)
|
||||
|
||||
|
||||
# 4. Implement the Policy Engine - ONLY intercepts PolicyCheckRequest (different type!)
|
||||
class PolicyEngine(Executor):
|
||||
"""Interceptor that handles POLICY requests."""
|
||||
|
||||
# Use class attributes for simple sample state
|
||||
quota: dict[str, int] = {
|
||||
"cpu": 5, # Only allow up to 5 CPU units
|
||||
"memory": 20, # Only allow up to 20 memory units
|
||||
"disk": 1000, # Liberal disk policy
|
||||
}
|
||||
results: list[PolicyResponse] = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="policy_engine")
|
||||
# Instance initialization only; state kept in class attributes as above
|
||||
|
||||
@intercepts_request
|
||||
async def check_policy(
|
||||
self, request: PolicyCheckRequest, ctx: WorkflowContext[None]
|
||||
) -> RequestResponse[PolicyCheckRequest, PolicyResponse]:
|
||||
"""Intercept POLICY requests and apply rules."""
|
||||
print(f"π‘οΈ POLICY interceptor checking: {request.amount} {request.resource_type}, policy={request.policy_type}")
|
||||
|
||||
quota_limit = self.quota.get(request.resource_type, 0)
|
||||
|
||||
if request.policy_type == "quota":
|
||||
if request.amount <= quota_limit:
|
||||
response = PolicyResponse(approved=True, reason=f"Within quota ({quota_limit})")
|
||||
print(f" β
Policy approved: {request.amount} <= {quota_limit}")
|
||||
self.results.append(response)
|
||||
return RequestResponse[PolicyCheckRequest, PolicyResponse].handled(response)
|
||||
# Exceeds quota - forward to external for review
|
||||
print(f" β Policy exceeds quota: {request.amount} > {quota_limit}, forwarding to external")
|
||||
return RequestResponse.forward()
|
||||
|
||||
# Unknown policy type - forward to external
|
||||
print(f" β Unknown policy type: {request.policy_type}, forwarding")
|
||||
return RequestResponse.forward()
|
||||
|
||||
@handler
|
||||
async def collect_policy_result(
|
||||
self, response: RequestResponse[PolicyCheckRequest, PolicyResponse], ctx: WorkflowContext[None]
|
||||
) -> None:
|
||||
"""Collect policy results from external requests that were forwarded."""
|
||||
if response.data:
|
||||
self.results.append(response.data)
|
||||
print(f"π‘οΈ π Policy received external response: {response.data.approved} - {response.data.reason}")
|
||||
|
||||
|
||||
class Coordinator(Executor):
|
||||
def __init__(self):
|
||||
super().__init__(id="coordinator")
|
||||
|
||||
@handler
|
||||
async def start(self, requests: list[dict[str, Any]], ctx: WorkflowContext[object]) -> None:
|
||||
"""Start the resource allocation process."""
|
||||
await ctx.send_message(requests, target_id="resource_workflow")
|
||||
|
||||
@handler
|
||||
async def handle_completion(self, completion: RequestFinished, ctx: WorkflowContext[None]) -> None:
|
||||
"""Handle sub-workflow completion. It comes from the sub-workflow emitted WorkflowCompletionEvent's data field."""
|
||||
print(f"π― Main workflow received completion.")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Demonstrate parallel request interception patterns."""
|
||||
print("π Starting Sub-Workflow Parallel Request Interception Demo...")
|
||||
print("=" * 60)
|
||||
|
||||
# 5. Create the sub-workflow
|
||||
resource_requester = ResourceRequester()
|
||||
sub_request_info = RequestInfoExecutor(id="sub_request_info")
|
||||
|
||||
sub_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(resource_requester)
|
||||
.add_edge(resource_requester, sub_request_info)
|
||||
.add_edge(sub_request_info, resource_requester)
|
||||
.build()
|
||||
)
|
||||
|
||||
# 6. Create parent workflow with PROPER interceptor pattern
|
||||
cache = ResourceCache() # Intercepts ResourceRequest
|
||||
policy = PolicyEngine() # Intercepts PolicyCheckRequest (different type!)
|
||||
workflow_executor = WorkflowExecutor(sub_workflow, id="resource_workflow")
|
||||
main_request_info = RequestInfoExecutor(id="main_request_info")
|
||||
|
||||
# Create a simple coordinator that starts the process
|
||||
coordinator = Coordinator()
|
||||
|
||||
# PROPER PATTERN: Each executor intercepts DIFFERENT request types
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(coordinator)
|
||||
.add_edge(coordinator, workflow_executor) # Start sub-workflow
|
||||
.add_edge(workflow_executor, coordinator) # Sub-workflow completion back to coordinator
|
||||
.add_edge(workflow_executor, cache) # Cache intercepts ResourceRequest
|
||||
.add_edge(cache, workflow_executor) # Cache handles ResourceRequest
|
||||
.add_edge(workflow_executor, policy) # Policy handles PolicyCheckRequest
|
||||
.add_edge(policy, workflow_executor) # Policy intercepts PolicyCheckRequest
|
||||
.add_edge(cache, main_request_info) # Cache forwards to external
|
||||
.add_edge(policy, main_request_info) # Policy forwards to external
|
||||
.add_edge(main_request_info, workflow_executor) # External responses back
|
||||
.add_edge(workflow_executor, main_request_info) # Sub-workflow forwards to main
|
||||
.build()
|
||||
)
|
||||
|
||||
# 7. Test with various requests (mixed resource and policy)
|
||||
test_requests = [
|
||||
{"request_type": "resource", "type": "cpu", "amount": 2, "priority": "normal"}, # Cache hit
|
||||
{"request_type": "policy", "type": "cpu", "amount": 3, "policy_type": "quota"}, # Policy hit
|
||||
{"request_type": "resource", "type": "memory", "amount": 15, "priority": "normal"}, # Cache hit
|
||||
{"request_type": "policy", "type": "memory", "amount": 100, "policy_type": "quota"}, # Policy miss -> external
|
||||
{"request_type": "resource", "type": "gpu", "amount": 1, "priority": "high"}, # Cache miss -> external
|
||||
{"request_type": "policy", "type": "disk", "amount": 500, "policy_type": "quota"}, # Policy hit
|
||||
{"request_type": "policy", "type": "cpu", "amount": 1, "policy_type": "security"}, # Unknown policy -> external
|
||||
]
|
||||
|
||||
print(f"π§ͺ Testing with {len(test_requests)} mixed requests:")
|
||||
for i, req in enumerate(test_requests, 1):
|
||||
req_icon = "π¦" if req["request_type"] == "resource" else "π‘οΈ"
|
||||
print(f" {i}. {req_icon} {req['type']} x{req['amount']} ({req.get('priority', req.get('policy_type', 'default'))})")
|
||||
print("=" * 70)
|
||||
|
||||
# 8. Run the workflow
|
||||
print("π¬ Running workflow...")
|
||||
result = await main_workflow.run(test_requests)
|
||||
|
||||
# 9. Handle any external requests that couldn't be intercepted
|
||||
request_events = result.get_request_info_events()
|
||||
if request_events:
|
||||
print(f"\nπ Handling {len(request_events)} external request(s)...")
|
||||
|
||||
external_responses: dict[str, Any] = {}
|
||||
for event in request_events:
|
||||
if isinstance(event.data, ResourceRequest):
|
||||
# Handle ResourceRequest - create ResourceResponse
|
||||
resource_response = ResourceResponse(
|
||||
resource_type=event.data.resource_type, allocated=event.data.amount, source="external_provider"
|
||||
)
|
||||
external_responses[event.request_id] = resource_response
|
||||
print(f" π External provider: {resource_response.allocated} {resource_response.resource_type}")
|
||||
elif isinstance(event.data, PolicyCheckRequest):
|
||||
# Handle PolicyCheckRequest - create PolicyResponse
|
||||
policy_response = PolicyResponse(approved=True, reason="External policy service approved")
|
||||
external_responses[event.request_id] = policy_response
|
||||
print(f" π External policy: {'β
APPROVED' if policy_response.approved else 'β DENIED'}")
|
||||
|
||||
await main_workflow.send_responses(external_responses)
|
||||
else:
|
||||
print("\nπ― All requests were intercepted internally!")
|
||||
|
||||
# 10. Show results and analysis
|
||||
print("\n" + "=" * 70)
|
||||
print("π RESULTS ANALYSIS")
|
||||
print("=" * 70)
|
||||
|
||||
print(f"\nπͺ Cache Results ({len(cache.results)} handled):")
|
||||
for result in cache.results:
|
||||
print(f" β
{result.allocated} {result.resource_type} from {result.source}")
|
||||
|
||||
print(f"\nπ‘οΈ Policy Results ({len(policy.results)} handled):")
|
||||
for result in policy.results:
|
||||
status_icon = "β
" if result.approved else "β"
|
||||
print(f" {status_icon} Approved: {result.approved} - {result.reason}")
|
||||
|
||||
print(f"\nπΎ Final Cache State:")
|
||||
for resource, amount in cache.cache.items():
|
||||
print(f" π¦ {resource}: {amount} remaining")
|
||||
|
||||
print(f"\nπ Summary:")
|
||||
print(f" π― Total requests: {len(test_requests)}")
|
||||
print(f" πͺ Resource requests handled: {len(cache.results)}")
|
||||
print(f" π‘οΈ Policy requests handled: {len(policy.results)}")
|
||||
print(f" π External requests: {len(request_events) if request_events else 0}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user