Python: Adding support for nested workflows (#460)

* Adding design documents and data flow descriptions for sub-workflows

* Updating docs.

* Sub-workflow implementation #1. Stuck because of singleton RequestInfoExecutor, going to make a change to remove that restrivtion.

* Removed the singleton restriction on RequestInfoExecutor so enable sub-workflows.

* Scenarios seem to be working.

* Sample improved.

* going to have intern add generic response wrappers.

* Wrapped responses working.

* Non-hardcoded routing is working.

* Sample showing external approved and not approved.

* Cleaning up.

* Updating some samples and user guide.

* Removing old design doc.

* Cleaning up.

* Adding python-package-setup.md back.

* Update python/packages/workflow/agent_framework_workflow/_executor.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Update python/packages/workflow/agent_framework_workflow/_validation.py

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>

* Removing prints.

* Fixing lint and type issues.

* Fixing lint and type issues.

* Update python/packages/workflow/agent_framework_workflow/_executor.py

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>

* Adding type hints to intercepts decorator.

* Removing unused files.

* Fixing issue with sample 5 groupchat with hil.

* Removing redundent samples.

* Updates to ensure no conflicting request interceptors and to support a subflow with multiple requests in a single super step.

* Fixing pypi errors.

* clean up samples

* update samples to make it more clear

* warning for unhandled request info from sub workflow

* add logger info

---------

Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com>
Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
Ben Thomas
2025-08-22 20:21:37 -07:00
committed by GitHub
Unverified
parent b26d9c95fe
commit b0b3fd151c
16 changed files with 2335 additions and 39 deletions
+1 -1
View File
@@ -319,4 +319,4 @@ We should consider auto-instrumentation and provide an implementation of it to t
### Build and release
The build step will be done in GHA, adding the package to the release and then we call into Azure DevOps to use the ESRP pipeline to publish to pypi. This is how SK already works, we will just have to adapt it to the new package structure.
For now we will stick to semantic versioning, and all preview release will be tagged as such.
For now we will stick to semantic versioning, and all preview release will be tagged as such.
@@ -35,6 +35,11 @@ _IMPORTS = [
"WorkflowCheckpoint",
"Case",
"Default",
"RequestResponse",
"SubWorkflowRequestInfo",
"SubWorkflowResponse",
"WorkflowExecutor",
"intercepts_request",
]
@@ -18,17 +18,22 @@ from agent_framework_workflow import (
RequestInfoEvent,
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
SubWorkflowRequestInfo,
SubWorkflowResponse,
Workflow,
WorkflowBuilder,
WorkflowCheckpoint,
WorkflowCompletedEvent,
WorkflowContext,
WorkflowEvent,
WorkflowExecutor,
WorkflowRunResult,
WorkflowStartedEvent,
WorkflowViz,
__version__,
handler,
intercepts_request,
)
__all__ = [
@@ -49,15 +54,20 @@ __all__ = [
"RequestInfoEvent",
"RequestInfoExecutor",
"RequestInfoMessage",
"RequestResponse",
"SubWorkflowRequestInfo",
"SubWorkflowResponse",
"Workflow",
"WorkflowBuilder",
"WorkflowCheckpoint",
"WorkflowCompletedEvent",
"WorkflowContext",
"WorkflowEvent",
"WorkflowExecutor",
"WorkflowRunResult",
"WorkflowStartedEvent",
"WorkflowViz",
"__version__",
"handler",
"intercepts_request",
]
@@ -30,7 +30,12 @@ from ._executor import (
Executor,
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
SubWorkflowRequestInfo,
SubWorkflowResponse,
WorkflowExecutor,
handler,
intercepts_request,
)
from ._runner_context import (
InProcRunnerContext,
@@ -76,11 +81,12 @@ __all__ = [
"InProcRunnerContext",
"Message",
"RequestInfoEvent",
"RequestInfoEvent",
"RequestInfoExecutor",
"RequestInfoExecutor",
"RequestInfoMessage",
"RequestResponse",
"RunnerContext",
"SubWorkflowRequestInfo",
"SubWorkflowResponse",
"TypeCompatibilityError",
"ValidationTypeEnum",
"Workflow",
@@ -89,11 +95,13 @@ __all__ = [
"WorkflowCompletedEvent",
"WorkflowContext",
"WorkflowEvent",
"WorkflowExecutor",
"WorkflowRunResult",
"WorkflowStartedEvent",
"WorkflowValidationError",
"WorkflowViz",
"__version__",
"handler",
"intercepts_request",
"validate_workflow_graph",
]
@@ -5,9 +5,12 @@ import functools
import inspect
import uuid
from collections.abc import Awaitable, Callable
from dataclasses import dataclass
from dataclasses import dataclass, field
from types import UnionType
from typing import Any, ClassVar, TypeVar, Union, get_args, get_origin, overload
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, get_args, get_origin, overload
if TYPE_CHECKING:
from ._workflow import Workflow
from agent_framework import AgentRunResponse, AgentRunResponseUpdate, AgentThread, AIAgent, ChatMessage
from agent_framework._pydantic import AFBaseModel
@@ -55,12 +58,14 @@ class Executor(AFBaseModel):
super().__init__(**kwargs)
self._handlers: dict[type, Callable[[Any, WorkflowContext[Any]], Any]] = {}
self._request_interceptors: dict[type | str, list[dict[str, Any]]] = {}
self._discover_handlers()
if not self._handlers:
if not self._handlers and not self._request_interceptors:
raise ValueError(
f"Executor {self.__class__.__name__} has no handlers defined. "
"Please define at least one handler using the @handler decorator."
"Please define at least one handler using the @handler decorator "
"or @intercepts_request decorator."
)
async def execute(self, message: Any, context: WorkflowContext[Any]) -> None:
@@ -73,6 +78,16 @@ class Executor(AFBaseModel):
Returns:
An awaitable that resolves to the result of the execution.
"""
# Handle case where Message wrapper is passed instead of raw data
# Lazy registration for SubWorkflowRequestInfo if we have interceptors
if self._request_interceptors and message.__class__.__name__ == "SubWorkflowRequestInfo":
# Directly handle SubWorkflowRequestInfo
await context.add_event(ExecutorInvokeEvent(self.id))
await self._handle_sub_workflow_request(message, context)
await context.add_event(ExecutorCompletedEvent(self.id))
return
handler: Callable[[Any, WorkflowContext[Any]], Any] | None = None
for message_type in self._handlers:
if is_instance_of(message, message_type):
@@ -81,30 +96,147 @@ class Executor(AFBaseModel):
if handler is None:
raise RuntimeError(f"Executor {self.__class__.__name__} cannot handle message of type {type(message)}.")
await context.add_event(ExecutorInvokeEvent(self.id))
await handler(message, context)
await context.add_event(ExecutorCompletedEvent(self.id))
def _discover_handlers(self) -> None:
"""Discover message handlers in the executor class."""
"""Discover message handlers and request interceptors in the executor class."""
# Use __class__.__dict__ to avoid accessing pydantic's dynamic attributes
for attr_name in dir(self.__class__):
try:
attr = getattr(self.__class__, attr_name)
if callable(attr) and hasattr(attr, "_handler_spec"):
handler_spec = attr._handler_spec # type: ignore
if self._handlers.get(handler_spec["message_type"]) is not None:
raise ValueError(
f"Duplicate handler for type {handler_spec['message_type']} in {self.__class__.__name__}"
)
# Get the bound method
bound_method = getattr(self, attr_name)
self._handlers[handler_spec["message_type"]] = bound_method
if callable(attr):
# Discover @handler methods
if hasattr(attr, "_handler_spec"):
handler_spec = attr._handler_spec # type: ignore
message_type = handler_spec["message_type"]
# Keep full generic types for handler registration to avoid conflicts
# Different RequestResponse[T, U] specializations are distinct handler types
if self._handlers.get(message_type) is not None:
raise ValueError(f"Duplicate handler for type {message_type} in {self.__class__.__name__}")
# Get the bound method
bound_method = getattr(self, attr_name)
self._handlers[message_type] = bound_method
# Discover @intercepts_request methods
if hasattr(attr, "_intercepts_request"):
# Get the bound method for interceptors
bound_method = getattr(self, attr_name)
interceptor_info = {
"method": bound_method,
"from_workflow": getattr(attr, "_from_workflow", None),
"condition": getattr(attr, "_intercept_condition", None),
}
request_type = attr._intercepts_request # type: ignore
if request_type not in self._request_interceptors:
self._request_interceptors[request_type] = []
self._request_interceptors[request_type].append(interceptor_info)
except AttributeError:
# Skip attributes that may not be accessible
continue
def _register_sub_workflow_handler(self) -> None:
"""Register automatic handler for SubWorkflowRequestInfo messages."""
# We need to use a string reference until the class is defined
# This will be resolved later when the class is actually used
pass # Will be registered lazily when needed
async def _handle_sub_workflow_request(
self,
request: "SubWorkflowRequestInfo",
ctx: WorkflowContext[Any],
) -> None:
"""Automatic routing to @intercepts_request methods.
This is only active for executors that have @intercepts_request methods.
"""
# Try to match against registered interceptors
for request_type, interceptor_list in self._request_interceptors.items():
matched = False
# Check type matching
if isinstance(request_type, type) and is_instance_of(request.data, request_type):
matched = True
elif (
isinstance(request_type, str)
and hasattr(request.data, "__class__")
and request.data.__class__.__name__ == request_type
):
# String matching - could check against type name or other attributes
matched = True
if matched:
# Check each interceptor in the list for this request type
for interceptor_info in interceptor_list:
# Check workflow scope if specified
from_workflow = interceptor_info["from_workflow"]
if from_workflow and request.sub_workflow_id != from_workflow:
continue # Skip this interceptor, wrong workflow
# Check additional condition
condition = interceptor_info["condition"]
if condition and not condition(request):
continue
# Call the interceptor method
method = interceptor_info["method"]
response = await method(request.data, ctx)
# Check if interceptor handled it or needs to forward
if isinstance(response, RequestResponse):
# Add automatic correlation info to the response
correlated_response = RequestResponse._with_correlation(
response, request.data, request.request_id
)
if correlated_response.is_handled:
# Send response back to sub-workflow
from ._runner_context import Message
response_message = Message(
source_id=self.id,
target_id=request.sub_workflow_id,
data=SubWorkflowResponse(
request_id=request.request_id,
data=correlated_response.data,
),
)
await ctx.send_message(response_message)
else:
# Forward WITH CONTEXT PRESERVED
# Update the data if interceptor provided a modified request
if correlated_response.forward_request:
request.data = correlated_response.forward_request
# Send the inner request to RequestInfoExecutor to create external request
from ._runner_context import Message
forward_message = Message(
source_id=self.id,
data=request,
)
await ctx.send_message(forward_message)
else:
# Legacy support: direct return means handled
await ctx.send_message(
SubWorkflowResponse(
request_id=request.request_id,
data=response,
),
target_id=request.sub_workflow_id,
)
return
# No interceptor found - forward inner request to RequestInfoExecutor
# This sends the original request to RequestInfoExecutor
from ._runner_context import Message
passthrough_message = Message(source_id=self.id, data=request.data)
await ctx.send_message(passthrough_message)
def can_handle(self, message: Any) -> bool:
"""Check if the executor can handle a given message type.
@@ -257,6 +389,212 @@ def handler(
# endregion: Handler Decorator
# region Request/Response Types
TRequest = TypeVar("TRequest", bound="RequestInfoMessage")
TResponse = TypeVar("TResponse")
@dataclass
class RequestResponse(Generic[TRequest, TResponse]):
"""Response from @intercepts_request methods with automatic correlation support.
This type allows intercepting executors to indicate whether they handled
a request or whether it should be forwarded to external sources. When handled,
the framework automatically adds correlation info to link responses to requests.
"""
is_handled: bool
data: TResponse | None = None
forward_request: TRequest | None = None
original_request: TRequest | None = None # Added for automatic correlation
request_id: str | None = None # Added for tracking
@classmethod
def handled(cls, data: TResponse) -> "RequestResponse[Any, TResponse]":
"""Create a response indicating the request was handled.
Correlation info (original_request, request_id) will be added automatically
by the framework when processing intercepted requests.
"""
return cls(is_handled=True, data=data)
@classmethod
def forward(cls, modified_request: Any = None) -> "RequestResponse[Any, Any]":
"""Create a response indicating the request should be forwarded."""
return cls(is_handled=False, forward_request=modified_request)
@staticmethod
def _with_correlation(
original_response: "RequestResponse[Any, TResponse]", original_request: TRequest, request_id: str
) -> "RequestResponse[TRequest, TResponse]":
"""Internal method to add correlation info to a response.
This is called automatically by the framework and should not be used directly.
"""
return RequestResponse(
is_handled=original_response.is_handled,
data=original_response.data,
forward_request=original_response.forward_request,
original_request=original_request,
request_id=request_id,
)
@dataclass
class SubWorkflowRequestInfo:
"""Routes requests from sub-workflows to parent workflows.
This message type wraps requests from sub-workflows to add routing context,
allowing parent workflows to intercept and potentially handle the request.
"""
request_id: str # Original request ID from sub-workflow
sub_workflow_id: str # ID of the WorkflowExecutor that sent this
data: Any # The actual request data
@dataclass
class SubWorkflowResponse:
"""Routes responses back to sub-workflows.
This message type is used to send responses back to sub-workflows that
made requests, ensuring the response reaches the correct sub-workflow.
"""
request_id: str # Matches the original request ID
data: Any # The actual response data
# endregion: Request/Response Types
# region Intercepts Request Decorator
# TypeVar for request type that must be a RequestInfoMessage subclass
RequestInfoMessageT = TypeVar("RequestInfoMessageT", bound="RequestInfoMessage")
# Type alias for interceptor functions
InterceptorFunc = Callable[
[Any, RequestInfoMessageT, WorkflowContext[Any]], Awaitable[RequestResponse[RequestInfoMessageT, Any]]
]
@overload
def intercepts_request(
func: Callable[..., Any],
) -> Callable[..., Any]: ...
@overload
def intercepts_request(
*,
from_workflow: str | None = None,
condition: Callable[[Any], bool] | None = None,
) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ...
def intercepts_request(
func: Callable[..., Any] | None = None,
*,
from_workflow: str | None = None,
condition: Callable[[Any], bool] | None = None,
) -> Callable[..., Any]:
"""Decorator to mark methods that intercept sub-workflow requests.
The request type is automatically inferred from the method's second parameter type hint.
The type must be a subclass of RequestInfoMessage.
This decorator allows executors in parent workflows to intercept and handle requests from
sub-workflows before they are sent to external sources.
Args:
func: The function being decorated (automatically passed when used without parentheses).
from_workflow: Optional ID of specific sub-workflow to intercept from.
condition: Optional callable that must return True for interception.
Returns:
The decorated function with interception metadata.
Example:
@intercepts_request
async def check_domain(
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
) -> RequestResponse[DomainCheckRequest, bool]:
# Type automatically inferred as DomainCheckRequest from parameter annotation
if request.domain in self.approved_domains:
return RequestResponse.handled(True)
return RequestResponse.forward()
@intercepts_request(from_workflow="email_validator")
async def handle_specific(
self, request: EmailRequest, ctx: WorkflowContext[Any]
) -> RequestResponse[EmailRequest, str]:
# Only intercepts EmailRequest from the "email_validator" workflow
return RequestResponse.handled("handled by parent")
"""
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
# Extract request type from method signature
sig = inspect.signature(func)
params = list(sig.parameters.values())
if len(params) < 2:
raise ValueError(f"Interceptor method '{func.__name__}' must have at least 2 parameters (self, request)")
request_param = params[1] # Second parameter (after self)
request_type = request_param.annotation
if request_type is inspect.Parameter.empty:
raise ValueError(f"Interceptor method '{func.__name__}' request parameter must have a type annotation")
# Runtime validation that it's a RequestInfoMessage subclass
if isinstance(request_type, type):
# We need to check if RequestInfoMessage is defined yet, since this runs at import time
try:
# Try to get RequestInfoMessage from the current module's globals
request_info_message_class = None
func_module = inspect.getmodule(func)
if func_module and hasattr(func_module, "RequestInfoMessage"):
request_info_message_class = func_module.RequestInfoMessage
else:
# Look in the current module (where this decorator is defined)
import sys
current_module = sys.modules[__name__]
if hasattr(current_module, "RequestInfoMessage"):
request_info_message_class = current_module.RequestInfoMessage
if request_info_message_class and not issubclass(request_type, request_info_message_class):
raise TypeError(
f"Interceptor method '{func.__name__}' can only handle RequestInfoMessage subclasses, "
f"got {request_type}. Make sure your request type inherits from RequestInfoMessage."
)
except (AttributeError, NameError):
# RequestInfoMessage might not be defined yet at import time, skip validation
# This will be caught later when the interceptor is actually called
pass
@functools.wraps(func)
async def wrapper(self: Any, request: Any, ctx: WorkflowContext[Any]) -> Any:
return await func(self, request, ctx)
# Add metadata for discovery - store the inferred type
wrapper._intercepts_request = request_type # type: ignore
wrapper._from_workflow = from_workflow # type: ignore
wrapper._intercept_condition = condition # type: ignore
return wrapper
# If func is provided, we're being called without parentheses: @intercepts_request
if func is not None:
return decorator(func)
# Otherwise, we're being called with parentheses: @intercepts_request(from_workflow="...")
return decorator
# endregion: Intercepts Request Decorator
# region Agent Executor
@@ -354,7 +692,11 @@ class RequestInfoMessage:
the request/response pattern explicit.
"""
request_id: str = str(uuid.uuid4())
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
# Note: SubWorkflowRequestInfo, SubWorkflowResponse, and RequestResponse
# have been moved before intercepts_request decorator
class RequestInfoExecutor(Executor):
@@ -365,13 +707,19 @@ class RequestInfoExecutor(Executor):
a response is provided externally, it emits the response as a message.
"""
# Well-known ID for the request info executor
EXECUTOR_ID: ClassVar[str] = "request_info"
def __init__(self, id: str | None = None):
"""Initialize the RequestInfoExecutor with an optional custom ID.
def __init__(self) -> None:
"""Initialize the RequestInfoExecutor with its well-known ID."""
super().__init__(id=self.EXECUTOR_ID)
Args:
id: Optional custom ID for this RequestInfoExecutor. If not provided,
a unique ID will be generated.
"""
import uuid
executor_id = id or f"request_info_{uuid.uuid4().hex[:8]}"
super().__init__(id=executor_id)
self._request_events: dict[str, RequestInfoEvent] = {}
self._sub_workflow_contexts: dict[str, dict[str, str]] = {}
@handler
async def run(self, message: RequestInfoMessage, ctx: WorkflowContext[None]) -> None:
@@ -387,6 +735,36 @@ class RequestInfoExecutor(Executor):
self._request_events[message.request_id] = event
await ctx.add_event(event)
@handler
async def handle_sub_workflow_request(
self,
message: SubWorkflowRequestInfo,
ctx: WorkflowContext[None],
) -> None:
"""Handle forwarded sub-workflow request.
This method handles requests that were forwarded from parent workflows
because they couldn't be handled locally.
"""
# When called directly from runner, we need to use the sub_workflow_id as the source
source_executor_id = message.sub_workflow_id
# Store context for routing response back
self._sub_workflow_contexts[message.request_id] = {
"sub_workflow_id": message.sub_workflow_id,
"parent_executor_id": source_executor_id,
}
# Create event for external handling - preserve the SubWorkflowRequestInfo wrapper
event = RequestInfoEvent(
request_id=message.request_id, # Use original request ID
source_executor_id=source_executor_id,
request_type=type(message.data), # SubWorkflowRequestInfo type
request_data=message.data, # The full SubWorkflowRequestInfo
)
self._request_events[message.request_id] = event
await ctx.add_event(event)
async def handle_response(
self,
response_data: Any,
@@ -404,7 +782,190 @@ class RequestInfoExecutor(Executor):
raise ValueError(f"No request found with ID: {request_id}")
event = self._request_events.pop(request_id)
await ctx.send_message(response_data, target_id=event.source_executor_id)
# Check if this was a forwarded sub-workflow request
if request_id in self._sub_workflow_contexts:
context = self._sub_workflow_contexts.pop(request_id)
# Send back to sub-workflow that made the original request
await ctx.send_message(
SubWorkflowResponse(
request_id=request_id,
data=response_data,
),
target_id=context["sub_workflow_id"],
)
else:
# Regular response - send directly back to source
# Create a correlated response that includes both the response data and original request
correlated_response = RequestResponse.handled(response_data)
correlated_response = RequestResponse._with_correlation(correlated_response, event.data, request_id)
await ctx.send_message(correlated_response, target_id=event.source_executor_id)
# endregion: Request Info Executor
# region Workflow Executor
class WorkflowExecutor(Executor):
"""An executor that runs another workflow as its execution logic.
This executor wraps a workflow to make it behave as an executor, enabling
hierarchical workflow composition. Sub-workflows can send requests that
are intercepted by parent workflows.
"""
def __init__(self, workflow: "Workflow", id: str | None = None):
"""Initialize the WorkflowExecutor.
Args:
workflow: The workflow to execute as a sub-workflow.
id: Optional unique identifier for this executor.
"""
super().__init__(id)
self._workflow = workflow
# Track pending external responses by request_id
self._pending_responses: dict[str, Any] = {} # request_id -> response_data
# Track workflow state for proper resumption - support multiple concurrent requests
self._pending_requests: dict[str, Any] = {} # request_id -> original request data
self._active_executions: int = 0 # Count of active sub-workflow executions
# Response accumulation for multiple concurrent responses
self._collected_responses: dict[str, Any] = {} # Accumulate responses
self._expected_response_count: int = 0 # Track how many responses we're waiting for
@handler # No output_types - can send any completion data type
async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) -> None:
"""Execute the sub-workflow with raw input data.
This handler starts a new sub-workflow execution. When the sub-workflow
needs external information, it pauses and sends a request to the parent.
Args:
input_data: The input data to send to the sub-workflow.
ctx: The workflow context from the parent.
"""
# Skip SubWorkflowResponse and SubWorkflowRequestInfo - they have specific handlers
if isinstance(input_data, (SubWorkflowResponse, SubWorkflowRequestInfo)):
return
from ._events import RequestInfoEvent, WorkflowCompletedEvent
# Track this execution
self._active_executions += 1
try:
# Run the sub-workflow and collect all events
events = [event async for event in self._workflow.run_streaming(input_data)]
# Count requests and initialize response tracking
request_count = 0
for event in events:
if isinstance(event, RequestInfoEvent):
request_count += 1
# Initialize response accumulation for this execution
# For sequential workflows (like step_08), expect only current requests
# For parallel workflows (like step_09), expect all requests at once
self._expected_response_count = request_count
self._collected_responses = {}
# If no requests in initial run, handle completion immediately
if request_count == 0:
self._expected_response_count = 0
# Process events to check for completion or requests
for event in events:
if isinstance(event, WorkflowCompletedEvent):
# Sub-workflow completed normally - send result to parent
await ctx.send_message(event.data)
self._active_executions -= 1
return # Exit after completion
if isinstance(event, RequestInfoEvent):
# Sub-workflow needs external information
# Track the pending request
self._pending_requests[event.request_id] = event.data
# Wrap request with routing context and send to parent
wrapped_request = SubWorkflowRequestInfo(
request_id=event.request_id,
sub_workflow_id=self.id,
data=event.data,
)
await ctx.send_message(wrapped_request)
# Continue processing remaining events (no return)
except Exception as e:
from ._events import ExecutorEvent
# Sub-workflow failed - create error event
error_event = ExecutorEvent(executor_id=self.id, data={"error": str(e), "type": "sub_workflow_error"})
await ctx.add_event(error_event)
self._active_executions -= 1
raise
@handler
async def handle_response(
self,
response: SubWorkflowResponse,
ctx: WorkflowContext[Any],
) -> None:
"""Handle response from parent for a forwarded request.
This handler accumulates responses and only resumes the sub-workflow
when all expected responses have been received.
Args:
response: The response to a previous request.
ctx: The workflow context.
"""
# Check if we have this pending request
pending_requests = getattr(self, "_pending_requests", {})
if response.request_id not in pending_requests:
return
# Remove the request from pending list
pending_requests.pop(response.request_id, None)
# Accumulate the response
self._collected_responses[response.request_id] = response.data
# Check if we have all expected responses for current batch
if len(self._collected_responses) >= self._expected_response_count:
from ._events import RequestInfoEvent, WorkflowCompletedEvent
# Send all collected responses to the sub-workflow
responses_to_send = dict(self._collected_responses)
self._collected_responses.clear() # Clear for next batch
result_events = [event async for event in self._workflow.send_responses_streaming(responses_to_send)]
# Process the result events
new_request_count = 0
for event in result_events:
if isinstance(event, WorkflowCompletedEvent):
# Sub-workflow completed - send result to parent
await ctx.send_message(event.data)
self._active_executions -= 1
return
if isinstance(event, RequestInfoEvent):
# Sub-workflow sent more requests - prepare for next batch
new_request_count += 1
self._pending_requests[event.request_id] = event.data
# Send the new request to parent
wrapped_request = SubWorkflowRequestInfo(
request_id=event.request_id,
sub_workflow_id=self.id,
data=event.data,
)
await ctx.send_message(wrapped_request)
# Update expected count for next batch of requests
self._expected_response_count = new_request_count
# endregion: Workflow Executor
@@ -4,7 +4,10 @@ import asyncio
import logging
from collections import defaultdict
from collections.abc import AsyncIterable, Sequence
from typing import Any
from typing import TYPE_CHECKING, Any
if TYPE_CHECKING:
from ._executor import RequestInfoExecutor
from ._edge import EdgeGroup
from ._edge_runner import EdgeRunner, create_edge_runner
@@ -12,6 +15,8 @@ from ._events import WorkflowEvent
from ._executor import Executor
from ._runner_context import Message, RunnerContext
from ._shared_state import SharedState
from ._typing_utils import is_instance_of
from ._workflow_context import WorkflowContext
logger = logging.getLogger(__name__)
@@ -132,12 +137,107 @@ class Runner:
async def _deliver_messages(source_executor_id: str, messages: list[Message]) -> None:
"""Outer loop to concurrently deliver messages from all sources to their targets."""
# Special handling for SubWorkflowRequestInfo messages
async def _deliver_sub_workflow_requests(messages: list[Message]) -> None:
from ._executor import SubWorkflowRequestInfo
# Handle SubWorkflowRequestInfo messages - only process those not already targeted
sub_workflow_messages = []
for msg in messages:
# Skip messages sent directly to RequestInfoExecutor - they are already forwarded
if self._is_message_to_request_info_executor(msg):
continue
if isinstance(msg.data, SubWorkflowRequestInfo):
sub_workflow_messages.append(msg)
for message in sub_workflow_messages:
sub_request = message.data
# Find executor that can intercept the wrapped type
interceptor_found = False
for executor in self._executors.values():
if hasattr(executor, "_request_interceptors") and executor.id != message.source_id:
# Check if any registered interceptor can handle this request type
for registered_type in executor._request_interceptors:
# Check type matching - handle both type and string cases
matched = False
if (
isinstance(registered_type, type)
and is_instance_of(sub_request.data, registered_type)
) or (
isinstance(registered_type, str)
and hasattr(sub_request.data, "__class__")
and sub_request.data.__class__.__name__ == registered_type
):
matched = True
if matched:
# Send directly to the intercepting executor
logger.info(
f"Sending sub-workflow request of type '{sub_request.data.__class__.__name__}' "
f"from sub-workflow '{sub_request.sub_workflow_id}' "
f"to executor '{executor.id}' for interception."
)
await executor.execute(sub_request, self._ctx) # type: ignore[arg-type]
interceptor_found = True
break
if interceptor_found:
break
if not interceptor_found:
# No interceptor found - send directly to RequestInfoExecutor if available.
# Find the RequestInfoExecutor instance
request_info_executor = self._find_request_info_executor()
if request_info_executor:
workflow_ctx: WorkflowContext[None] = WorkflowContext(
request_info_executor.id,
["Runner"],
self._shared_state,
self._ctx,
)
logger.info(
f"Sending sub-workflow request of type '{sub_request.data.__class__.__name__}' "
f"from sub-workflow '{sub_request.sub_workflow_id}' to RequestInfoExecutor "
f"'{request_info_executor.id}'"
)
await request_info_executor.execute(sub_request, workflow_ctx)
else:
logger.warning(
f"Sub-workflow request of type '{sub_request.data.__class__.__name__}' "
f"from sub-workflow '{sub_request.sub_workflow_id}' could not be handled: "
f"no RequestInfoExecutor found in the workflow. Add a RequestInfoExecutor "
f"to handle external requests or add an @intercepts_request handler."
)
async def _deliver_message_inner(edge_runner: EdgeRunner, message: Message) -> bool:
"""Inner loop to deliver a single message through an edge runner."""
return await edge_runner.send_message(message, self._shared_state, self._ctx)
# Handle SubWorkflowRequestInfo messages specially
await _deliver_sub_workflow_requests(messages)
# Filter out SubWorkflowRequestInfo messages from normal edge routing
# since they were handled specially
from ._executor import SubWorkflowRequestInfo
non_sub_workflow_messages = []
for msg in messages:
# Keep messages sent directly to RequestInfoExecutor (forwarded messages)
if self._is_message_to_request_info_executor(msg):
non_sub_workflow_messages.append(msg)
continue
# Skip SubWorkflowRequestInfo messages (handled by special routing)
if isinstance(msg.data, SubWorkflowRequestInfo):
continue
non_sub_workflow_messages.append(msg)
associated_edge_runners = self._edge_runner_map.get(source_executor_id, [])
for message in messages:
for message in non_sub_workflow_messages:
# Deliver a message through all edge runners associated with the source executor concurrently.
tasks = [_deliver_message_inner(edge_runner, message) for edge_runner in associated_edge_runners]
results = await asyncio.gather(*tasks)
@@ -282,3 +382,36 @@ class Runner:
parsed[source_executor_id].append(runner)
return parsed
def _find_request_info_executor(self) -> "RequestInfoExecutor | None":
"""Find the RequestInfoExecutor instance in this workflow.
Returns:
The RequestInfoExecutor instance if found, None otherwise.
"""
from ._executor import RequestInfoExecutor
for executor in self._executors.values():
if isinstance(executor, RequestInfoExecutor):
return executor
return None
def _is_message_to_request_info_executor(self, msg: "Message") -> bool:
"""Check if message targets any RequestInfoExecutor in this workflow.
Args:
msg: The message to check.
Returns:
True if the message targets a RequestInfoExecutor, False otherwise.
"""
from ._executor import RequestInfoExecutor
if not msg.target_id:
return False
# Check all executors to see if target_id matches a RequestInfoExecutor
for executor in self._executors.values():
if executor.id == msg.target_id and isinstance(executor, RequestInfoExecutor):
return True
return False
@@ -50,5 +50,29 @@ def is_instance_of(data: Any, target_type: type) -> bool:
for key, value in data.items() # type: ignore
)
# Case 6: target_type is RequestResponse[T, U] - validate generic parameters
if origin and hasattr(origin, "__name__") and origin.__name__ == "RequestResponse":
if not isinstance(data, origin):
return False
# Validate generic parameters for RequestResponse[TRequest, TResponse]
if len(args) >= 2:
request_type, response_type = args[0], args[1]
# Check if the original_request matches TRequest and data matches TResponse
if (
hasattr(data, "original_request")
and data.original_request is not None
and not is_instance_of(data.original_request, request_type)
):
return False
if hasattr(data, "data") and data.data is not None and not is_instance_of(data.data, response_type):
return False
return True
# Case 7: Other custom generic classes - check origin type only
# For generic classes, we check if data is an instance of the origin type
# We don't validate the generic parameters at runtime since that's handled by type system
if origin and hasattr(origin, "__name__"):
return isinstance(data, origin)
# Fallback: if we reach here, we assume data is an instance of the target_type
return isinstance(data, target_type)
@@ -22,6 +22,7 @@ class ValidationTypeEnum(Enum):
TYPE_COMPATIBILITY = "TYPE_COMPATIBILITY"
GRAPH_CONNECTIVITY = "GRAPH_CONNECTIVITY"
HANDLER_OUTPUT_ANNOTATION = "HANDLER_OUTPUT_ANNOTATION"
INTERCEPTOR_CONFLICT = "INTERCEPTOR_CONFLICT"
class WorkflowValidationError(Exception):
@@ -94,6 +95,13 @@ class HandlerOutputAnnotationError(WorkflowValidationError):
self.handler_name = handler_name
class InterceptorConflictError(WorkflowValidationError):
"""Exception raised when multiple executors intercept the same request type from the same sub-workflow."""
def __init__(self, message: str):
super().__init__(message, validation_type=ValidationTypeEnum.INTERCEPTOR_CONFLICT)
# endregion
@@ -129,6 +137,14 @@ class WorkflowGraphValidator:
self._edges = [edge for group in edge_groups for edge in group.edges]
self._edge_groups = edge_groups
# If only the start executor exists, add it to the executor map
# Handle the special case where the workflow consists of only a single executor and no edges.
# In this scenario, the executor map will be empty because there are no edge groups to reference executors.
# Adding the start executor to the map ensures that single-executor workflows (without any edges) are supported,
# allowing validation and execution to proceed for workflows that do not require inter-executor communication.
if not self._executors and start_executor and isinstance(start_executor, Executor):
self._executors[start_executor.id] = start_executor
# Validate that start_executor exists in the graph
# It should because we check for it in the WorkflowBuilder
# but we do it here for completeness.
@@ -145,6 +161,7 @@ class WorkflowGraphValidator:
self._validate_handler_ambiguity()
self._validate_dead_ends()
self._validate_cycles()
self._validate_interceptor_uniqueness()
def _validate_handler_output_annotations(self) -> None:
"""Validate that each handler's ctx parameter is annotated with WorkflowContext[T].
@@ -306,13 +323,21 @@ class WorkflowGraphValidator:
# If either executor has no type information, log warning and skip validation
# This allows for dynamic typing scenarios but warns about reduced validation coverage
if not source_output_types or not target_input_types:
if not source_output_types:
# Suppress warnings for built-in workflow components where dynamic typing is expected
try:
from ._executor import RequestInfoExecutor, WorkflowExecutor # local import to avoid cycles
builtin_types = (RequestInfoExecutor, WorkflowExecutor)
except Exception:
builtin_types = tuple() # type: ignore[assignment]
if not source_output_types and not isinstance(source_executor, builtin_types):
logger.warning(
f"Executor '{source_executor.id}' has no output type annotations. "
f"Type compatibility validation will be skipped for edges from this executor. "
f"Consider adding WorkflowContext[T] generics in handlers for better validation."
)
if not target_input_types:
if not target_input_types and not isinstance(target_executor, builtin_types):
logger.warning(
f"Executor '{target_executor.id}' has no input type annotations. "
f"Type compatibility validation will be skipped for edges to this executor. "
@@ -376,6 +401,13 @@ class WorkflowGraphValidator:
# Skip attributes that may not be accessible
continue
# Also include intercepted request types as potential outputs
# since @intercepts_request methods can forward requests
if hasattr(executor, "_request_interceptors"):
for request_type in executor._request_interceptors:
if isinstance(request_type, type):
output_types.append(request_type)
return output_types
def _get_executor_input_types(self, executor: Executor) -> list[type[Any]]:
@@ -571,6 +603,54 @@ class WorkflowGraphValidator:
"Ensure proper termination conditions exist to prevent infinite loops."
)
def _validate_interceptor_uniqueness(self) -> None:
"""Validate that only one executor intercepts a given request type from a specific sub-workflow.
This prevents non-deterministic behavior where multiple executors could intercept
the same request type from the same sub-workflow.
"""
from ._executor import WorkflowExecutor
# Find all WorkflowExecutor instances in the workflow
workflow_executors: dict[str, WorkflowExecutor] = {}
for executor_id, executor in self._executors.items():
if isinstance(executor, WorkflowExecutor):
workflow_executors[executor_id] = executor
# For each WorkflowExecutor, check which executors can intercept its requests
for workflow_id, _workflow_executor in workflow_executors.items():
# Map of request_type -> list of intercepting executor IDs
interceptors_by_type: dict[type | str, list[str]] = {}
# Find all executors that have edges from this WorkflowExecutor
# These are potential interceptors
for edge in self._edges:
if edge.source_id == workflow_id:
target_executor = self._executors.get(edge.target_id)
if target_executor and hasattr(target_executor, "_request_interceptors"):
# Check what request types this executor intercepts
for request_type, interceptor_list in target_executor._request_interceptors.items():
# Check if any interceptor is scoped to this workflow or unscoped
for interceptor_info in interceptor_list:
from_workflow = interceptor_info.get("from_workflow")
# If unscoped or specifically scoped to this workflow
if from_workflow is None or from_workflow == workflow_id:
if request_type not in interceptors_by_type:
interceptors_by_type[request_type] = []
interceptors_by_type[request_type].append(edge.target_id)
# Check for duplicates
for request_type, executor_ids in interceptors_by_type.items():
unique_executors = list(set(executor_ids)) # Remove duplicates from same executor
if len(unique_executors) > 1:
type_name = request_type.__name__ if isinstance(request_type, type) else str(request_type)
raise InterceptorConflictError(
f"Multiple executors intercept the same request type '{type_name}' "
f"from sub-workflow '{workflow_id}': {', '.join(unique_executors)}. "
f"Only one executor should intercept a given request type from a specific sub-workflow "
f"to ensure deterministic behavior."
)
# endregion
# region Type Compatibility Utilities
@@ -219,8 +219,8 @@ class Workflow(AFBaseModel):
raise RuntimeError(f"Failed to restore from checkpoint: {checkpoint_id}")
if responses:
request_info_executor = self._get_executor_by_id(RequestInfoExecutor.EXECUTOR_ID)
if isinstance(request_info_executor, RequestInfoExecutor):
request_info_executor = self._find_request_info_executor()
if request_info_executor:
for request_id, response_data in responses.items():
await request_info_executor.handle_response(
response_data,
@@ -246,9 +246,9 @@ class Workflow(AFBaseModel):
Yields:
WorkflowEvent: The events generated during the workflow execution after sending the responses.
"""
request_info_executor = self._get_executor_by_id(RequestInfoExecutor.EXECUTOR_ID)
if not isinstance(request_info_executor, RequestInfoExecutor):
raise ValueError(f"Executor with ID {RequestInfoExecutor.EXECUTOR_ID} is not a RequestInfoExecutor.")
request_info_executor = self._find_request_info_executor()
if not request_info_executor:
raise ValueError("No RequestInfoExecutor found in workflow.")
async def _handle_response(response: Any, request_id: str) -> None:
"""Handle the response from the RequestInfoExecutor."""
@@ -332,6 +332,19 @@ class Workflow(AFBaseModel):
raise ValueError(f"Executor with ID {executor_id} not found.")
return self.executors[executor_id]
def _find_request_info_executor(self) -> "RequestInfoExecutor | None":
"""Find the RequestInfoExecutor instance in this workflow.
Returns:
The RequestInfoExecutor instance if found, None otherwise.
"""
from ._executor import RequestInfoExecutor
for executor in self.executors.values():
if isinstance(executor, RequestInfoExecutor):
return executor
return None
async def _restore_from_external_checkpoint(
self, checkpoint_id: str, checkpoint_storage: CheckpointStorage
) -> bool:
@@ -0,0 +1,111 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from dataclasses import dataclass
import pytest
from agent_framework_workflow import (
Executor,
WorkflowBuilder,
WorkflowContext,
WorkflowExecutor,
handler,
)
@dataclass
class SimpleRequest:
"""Simple request for testing."""
text: str
@dataclass
class SimpleResponse:
"""Simple response for testing."""
result: str
class SimpleSubExecutor(Executor):
"""Simple executor for sub-workflow."""
def __init__(self):
super().__init__(id="simple_sub")
@handler
async def process(self, request: SimpleRequest, ctx: WorkflowContext[None]) -> None:
"""Process a simple request."""
from agent_framework_workflow import WorkflowCompletedEvent
# Just echo back with prefix and complete
response = SimpleResponse(result=f"processed: {request.text}")
await ctx.add_event(WorkflowCompletedEvent(data=response))
class SimpleParent(Executor):
"""Simple parent executor."""
result: SimpleResponse | None = None
def __init__(self):
super().__init__(id="simple_parent")
@handler
async def start(self, text: str, ctx: WorkflowContext[SimpleRequest]) -> None:
"""Start the process."""
request = SimpleRequest(text=text)
await ctx.send_message(request, target_id="sub_workflow")
@handler
async def collect(self, response: SimpleResponse, ctx: WorkflowContext[None]) -> None:
"""Collect the result."""
self.result = response
@pytest.mark.asyncio
async def test_simple_sub_workflow():
"""Test the simplest possible sub-workflow."""
# Create sub-workflow with dummy executor to satisfy validation
sub_executor = SimpleSubExecutor()
class DummyExecutor(Executor):
def __init__(self):
super().__init__(id="dummy")
@handler
async def process(self, message: object, ctx: WorkflowContext[None]) -> None:
pass # Do nothing
dummy = DummyExecutor()
sub_workflow = (
WorkflowBuilder()
.set_start_executor(sub_executor)
.add_edge(sub_executor, dummy) # Add edge to satisfy validation
.build()
)
# Create parent workflow
parent = SimpleParent()
workflow_executor = WorkflowExecutor(sub_workflow, id="sub_workflow")
main_workflow = (
WorkflowBuilder()
.set_start_executor(parent)
.add_edge(parent, workflow_executor)
.add_edge(workflow_executor, parent)
.build()
)
# Run the workflow
await main_workflow.run("hello world")
# Check result
assert parent.result is not None
assert parent.result.result == "processed: hello world"
if __name__ == "__main__":
# Run the simple test
asyncio.run(test_simple_sub_workflow())
@@ -0,0 +1,420 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from dataclasses import dataclass
from typing import Any
import pytest
from pydantic import Field
from agent_framework_workflow import (
Executor,
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
WorkflowBuilder,
WorkflowCompletedEvent,
WorkflowContext,
WorkflowExecutor,
handler,
intercepts_request,
)
# Test message types
@dataclass
class EmailValidationRequest:
"""Request to validate an email address."""
email: str
@dataclass
class DomainCheckRequest(RequestInfoMessage):
"""Request to check if a domain is approved."""
domain: str = ""
email: str = "" # Include original email for correlation
@dataclass
class ValidationResult:
"""Result of email validation."""
email: str
is_valid: bool
reason: str
# Test executors
class EmailValidator(Executor):
"""Validates email addresses in a sub-workflow."""
def __init__(self):
super().__init__(id="email_validator")
@handler
async def validate_request(
self, request: EmailValidationRequest, ctx: WorkflowContext[RequestInfoMessage | ValidationResult]
) -> None:
"""Validate an email address."""
# Extract domain and check if it's approved
domain = request.email.split("@")[1] if "@" in request.email else ""
if not domain:
result = ValidationResult(email=request.email, is_valid=False, reason="Invalid email format")
await ctx.add_event(WorkflowCompletedEvent(data=result))
return
# Request domain check from external source
domain_check = DomainCheckRequest(domain=domain, email=request.email)
await ctx.send_message(domain_check)
@handler
async def handle_domain_response(
self, response: RequestResponse[DomainCheckRequest, bool], ctx: WorkflowContext[ValidationResult]
) -> None:
"""Handle domain check response with correlation."""
# Use the original email from the correlated response
result = ValidationResult(
email=response.original_request.email,
is_valid=response.data or False,
reason="Domain approved" if response.data else "Domain not approved",
)
await ctx.add_event(WorkflowCompletedEvent(data=result))
class ParentOrchestrator(Executor):
"""Parent workflow orchestrator with domain knowledge."""
approved_domains: set[str] = Field(default_factory=lambda: {"example.com", "test.org"})
results: list[ValidationResult] = Field(default_factory=list)
def __init__(self, approved_domains: set[str] | None = None, **kwargs: Any):
if approved_domains is not None:
kwargs["approved_domains"] = approved_domains
super().__init__(id="parent_orchestrator", **kwargs)
@handler
async def start(self, emails: list[str], ctx: WorkflowContext[EmailValidationRequest]) -> None:
"""Start processing emails."""
for email in emails:
request = EmailValidationRequest(email=email)
await ctx.send_message(request, target_id="email_workflow")
@intercepts_request
async def check_domain(
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
) -> RequestResponse[DomainCheckRequest, bool]:
"""Intercept domain check requests from sub-workflows."""
# Check if we know this domain
if request.domain in self.approved_domains:
return RequestResponse[DomainCheckRequest, bool].handled(True)
# We don't know this domain, forward to external
return RequestResponse[DomainCheckRequest, bool].forward()
@handler
async def collect_result(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None:
"""Collect validation results."""
self.results.append(result)
@pytest.mark.asyncio
async def test_basic_sub_workflow() -> None:
"""Test basic sub-workflow execution without interception."""
# Create sub-workflow
email_validator = EmailValidator()
email_request_info = RequestInfoExecutor(id="email_request_info")
validation_workflow = (
WorkflowBuilder()
.set_start_executor(email_validator)
.add_edge(email_validator, email_request_info)
.add_edge(email_request_info, email_validator)
.build()
)
# Create parent workflow without interception
class SimpleParent(Executor):
result: ValidationResult | None = Field(default=None)
def __init__(self, **kwargs: Any):
super().__init__(id="simple_parent", **kwargs)
@handler
async def start(self, email: str, ctx: WorkflowContext[EmailValidationRequest]) -> None:
request = EmailValidationRequest(email=email)
await ctx.send_message(request, target_id="email_workflow")
@handler
async def collect(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None:
self.result = result
parent = SimpleParent()
workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow")
main_request_info = RequestInfoExecutor(id="main_request_info")
main_workflow = (
WorkflowBuilder()
.set_start_executor(parent)
.add_edge(parent, workflow_executor)
.add_edge(workflow_executor, parent)
.add_edge(workflow_executor, main_request_info)
.add_edge(main_request_info, workflow_executor) # CRITICAL: For SubWorkflowResponse routing
.build()
)
# Run workflow with mocked external response
result = await main_workflow.run("test@example.com")
# Get request event and respond
request_events = result.get_request_info_events()
assert len(request_events) == 1
assert isinstance(request_events[0].data, DomainCheckRequest)
assert request_events[0].data.domain == "example.com"
# Send response through the main workflow
await main_workflow.send_responses({
request_events[0].request_id: True # Domain is approved
})
# Check result
assert parent.result is not None
assert parent.result.email == "test@example.com"
assert parent.result.is_valid is True
@pytest.mark.asyncio
async def test_sub_workflow_with_interception():
"""Test sub-workflow with parent interception of requests."""
# Create sub-workflow
email_validator = EmailValidator()
email_request_info = RequestInfoExecutor(id="email_request_info")
validation_workflow = (
WorkflowBuilder()
.set_start_executor(email_validator)
.add_edge(email_validator, email_request_info)
.add_edge(email_request_info, email_validator)
.build()
)
# Create parent workflow with interception
parent = ParentOrchestrator(approved_domains={"example.com", "internal.org"})
workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow")
parent_request_info = RequestInfoExecutor()
main_workflow = (
WorkflowBuilder()
.set_start_executor(parent)
.add_edge(parent, workflow_executor)
.add_edge(workflow_executor, parent)
.add_edge(parent, parent_request_info) # For forwarded requests
.add_edge(parent_request_info, workflow_executor) # For SubWorkflowResponse routing
.build()
)
# Test 1: Email with known domain (intercepted)
result = await main_workflow.run(["user@example.com"])
# Should complete without external requests
request_events = result.get_request_info_events()
assert len(request_events) == 0 # No external requests, handled internally
assert len(parent.results) == 1
assert parent.results[0].email == "user@example.com"
assert parent.results[0].is_valid is True
assert parent.results[0].reason == "Domain approved"
# Test 2: Email with unknown domain (forwarded)
parent.results.clear()
result = await main_workflow.run(["user@unknown.com"])
# Should have external request
request_events = result.get_request_info_events()
assert len(request_events) == 1
assert isinstance(request_events[0].data, DomainCheckRequest)
assert request_events[0].data.domain == "unknown.com"
# Send external response
await main_workflow.send_responses({
request_events[0].request_id: False # Domain not approved
})
assert len(parent.results) == 1
assert parent.results[0].email == "user@unknown.com"
assert parent.results[0].is_valid is False
assert parent.results[0].reason == "Domain not approved"
@pytest.mark.asyncio
async def test_conditional_forwarding() -> None:
"""Test conditional forwarding with RequestResponse.forward()."""
class ConditionalParent(Executor):
"""Parent that conditionally handles requests."""
cache: dict[str, bool] = Field(default_factory=lambda: {"cached.com": True})
result: ValidationResult | None = Field(default=None)
def __init__(self, **kwargs: Any):
super().__init__(id="conditional_parent", **kwargs)
@handler
async def start(self, email: str, ctx: WorkflowContext[EmailValidationRequest]) -> None:
request = EmailValidationRequest(email=email)
await ctx.send_message(request, target_id="email_workflow")
@intercepts_request
async def check_domain(
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
) -> RequestResponse[DomainCheckRequest, bool]:
"""Check cache first, then forward if not found."""
if request.domain in self.cache:
# Return cached result
return RequestResponse[DomainCheckRequest, bool].handled(self.cache[request.domain])
# Not in cache, forward to external
return RequestResponse[DomainCheckRequest, bool].forward()
@handler
async def collect(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None:
self.result = result
# Setup workflows
email_validator = EmailValidator()
request_info = RequestInfoExecutor()
validation_workflow = (
WorkflowBuilder()
.set_start_executor(email_validator)
.add_edge(email_validator, request_info)
.add_edge(request_info, email_validator)
.build()
)
parent = ConditionalParent()
workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow")
parent_request_info = RequestInfoExecutor()
main_workflow = (
WorkflowBuilder()
.set_start_executor(parent)
.add_edge(parent, workflow_executor)
.add_edge(workflow_executor, parent)
.add_edge(parent, parent_request_info)
.add_edge(parent_request_info, workflow_executor) # For SubWorkflowResponse routing
.build()
)
# Test cached domain
result = await main_workflow.run("user@cached.com")
request_events = result.get_request_info_events()
assert len(request_events) == 0 # Handled from cache
assert parent.result is not None
assert parent.result.is_valid is True
# Test uncached domain
parent.result = None
result = await main_workflow.run("user@new.com")
request_events = result.get_request_info_events()
assert len(request_events) == 1 # Forwarded to external
await main_workflow.send_responses({request_events[0].request_id: True})
assert parent.result is not None
assert parent.result.is_valid is True
@pytest.mark.asyncio
async def test_workflow_scoped_interception() -> None:
"""Test interception scoped to specific sub-workflows."""
class MultiWorkflowParent(Executor):
"""Parent handling multiple sub-workflows."""
results: dict[str, ValidationResult] = Field(default_factory=dict)
def __init__(self, **kwargs: Any):
super().__init__(id="multi_parent", **kwargs)
@handler
async def start(self, data: dict[str, str], ctx: WorkflowContext[EmailValidationRequest]) -> None:
# Send to different sub-workflows
await ctx.send_message(EmailValidationRequest(email=data["email1"]), target_id="workflow_a")
await ctx.send_message(EmailValidationRequest(email=data["email2"]), target_id="workflow_b")
@intercepts_request(from_workflow="workflow_a")
async def check_domain_a(
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
) -> RequestResponse[DomainCheckRequest, bool]:
"""Strict rules for workflow A."""
if request.domain == "strict.com":
return RequestResponse[DomainCheckRequest, bool].handled(True)
return RequestResponse[DomainCheckRequest, bool].forward()
@intercepts_request(from_workflow="workflow_b")
async def check_domain_b(
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
) -> RequestResponse[DomainCheckRequest, bool]:
"""Lenient rules for workflow B."""
if request.domain.endswith(".com"):
return RequestResponse[DomainCheckRequest, bool].handled(True)
return RequestResponse[DomainCheckRequest, bool].forward()
@handler
async def collect(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None:
self.results[result.email] = result
# Create two identical sub-workflows
def create_validation_workflow():
validator = EmailValidator()
request_info = RequestInfoExecutor()
return (
WorkflowBuilder()
.set_start_executor(validator)
.add_edge(validator, request_info)
.add_edge(request_info, validator)
.build()
)
workflow_a = create_validation_workflow()
workflow_b = create_validation_workflow()
parent = MultiWorkflowParent()
executor_a = WorkflowExecutor(workflow_a, id="workflow_a")
executor_b = WorkflowExecutor(workflow_b, id="workflow_b")
parent_request_info = RequestInfoExecutor()
main_workflow = (
WorkflowBuilder()
.set_start_executor(parent)
.add_edge(parent, executor_a)
.add_edge(parent, executor_b)
.add_edge(executor_a, parent)
.add_edge(executor_b, parent)
.add_edge(parent, parent_request_info)
.add_edge(parent_request_info, executor_a) # For SubWorkflowResponse routing
.add_edge(parent_request_info, executor_b) # For SubWorkflowResponse routing
.build()
)
# Run test
result = await main_workflow.run({"email1": "user@strict.com", "email2": "user@random.com"})
# Workflow A should handle strict.com
# Workflow B should handle any .com domain
request_events = result.get_request_info_events()
assert len(request_events) == 0 # Both handled internally
assert len(parent.results) == 2
assert parent.results["user@strict.com"].is_valid is True
assert parent.results["user@random.com"].is_valid is True
if __name__ == "__main__":
# Run tests
asyncio.run(test_basic_sub_workflow())
asyncio.run(test_sub_workflow_with_interception())
asyncio.run(test_conditional_forwarding())
asyncio.run(test_workflow_scoped_interception())
@@ -11,6 +11,7 @@ from agent_framework.workflow import (
RequestInfoEvent,
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
WorkflowBuilder,
WorkflowCompletedEvent,
WorkflowContext,
@@ -68,10 +69,13 @@ class MockExecutorRequestApproval(Executor):
await ctx.send_message(RequestInfoMessage())
@handler
async def mock_handler_b(self, message: ApprovalMessage, ctx: WorkflowContext[NumberMessage]) -> None:
async def mock_handler_b(
self, message: RequestResponse[RequestInfoMessage, ApprovalMessage], ctx: WorkflowContext[NumberMessage]
) -> None:
"""A mock handler that processes the approval response."""
data = await ctx.get_shared_state(self.id)
if message.approved:
assert isinstance(message.data, ApprovalMessage)
if message.data.approved:
await ctx.add_event(WorkflowCompletedEvent(data=data))
else:
await ctx.send_message(NumberMessage(data=data))
@@ -12,6 +12,7 @@ from agent_framework.workflow import (
RequestInfoEvent,
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
WorkflowBuilder,
WorkflowCompletedEvent,
WorkflowContext,
@@ -94,16 +95,20 @@ class CriticGroupChatManager(Executor):
@handler
async def handle_request_response(
self, response: list[ChatMessage], ctx: WorkflowContext[AgentExecutorRequest]
self,
response: RequestResponse[RequestInfoMessage, list[ChatMessage]],
ctx: WorkflowContext[AgentExecutorRequest],
) -> None:
"""Handler that processes the response from the RequestInfoExecutor."""
messages: list[ChatMessage] = response.data or []
# Update the chat history with the response
self._chat_history.extend(response)
self._chat_history.extend(messages)
# Send the response to the other members
await asyncio.gather(*[
ctx.send_message(
AgentExecutorRequest(messages=response, should_respond=False),
AgentExecutorRequest(messages=messages, should_respond=False),
target_id=member_id,
)
for member_id in self._members
@@ -0,0 +1,233 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from dataclasses import dataclass
from typing import Any
from agent_framework.workflow import (
Executor,
WorkflowBuilder,
WorkflowCompletedEvent,
WorkflowContext,
WorkflowEvent,
WorkflowExecutor,
handler,
)
"""
The following sample demonstrates basic sub-workflows.
This sample shows how to:
1. Create workflows that execute other workflows as sub-workflows
2. Pass data between parent and sub-workflows
3. Collect results from sub-workflows
The example simulates a simple text processing system where:
- Sub-workflows process individual text strings (count words, characters)
- Parent workflow orchestrates multiple sub-workflows and collects results
Key concepts demonstrated:
- WorkflowExecutor: Wraps a workflow to make it behave as an executor
- Sub-workflow isolation: Sub-workflows work independently
- Result collection: Parent workflows can gather outputs from sub-workflows
Simple flow visualization:
Parent Orchestrator
|
| TextProcessingRequest(text, task_id)
v
[ Sub-workflow: WorkflowExecutor(TextProcessor) ]
|
| WorkflowCompletedEvent(TextProcessingResult)
v
Parent collects results and summarizes
"""
# Message types
@dataclass
class TextProcessingRequest:
"""Request to process a text string."""
text: str
task_id: str
@dataclass
class TextProcessingResult:
"""Result of text processing."""
task_id: str
text: str
word_count: int
char_count: int
class AllTasksCompleted(WorkflowEvent):
"""Event triggered when all processing tasks are complete."""
def __init__(self, results: list[TextProcessingResult]):
super().__init__(results)
# Sub-workflow executor
class TextProcessor(Executor):
"""Processes text strings - counts words and characters."""
def __init__(self):
super().__init__(id="text_processor")
@handler
async def process_text(
self, request: TextProcessingRequest, ctx: WorkflowContext[TextProcessingResult]
) -> None:
"""Process a text string and return statistics."""
text_preview = f"'{request.text[:50]}{'...' if len(request.text) > 50 else ''}'"
print(f"πŸ” Sub-workflow processing text (Task {request.task_id}): {text_preview}")
# Simple text processing
word_count = len(request.text.split()) if request.text.strip() else 0
char_count = len(request.text)
print(f"πŸ“Š Task {request.task_id}: {word_count} words, {char_count} characters")
# Create result
result = TextProcessingResult(
task_id=request.task_id,
text=request.text,
word_count=word_count,
char_count=char_count,
)
print(f"βœ… Sub-workflow completed task {request.task_id}")
# Signal completion
await ctx.add_event(WorkflowCompletedEvent(data=result))
# Parent workflow
class TextProcessingOrchestrator(Executor):
"""Orchestrates multiple text processing tasks using sub-workflows."""
results: list[TextProcessingResult] = []
expected_count: int = 0
def __init__(self):
super().__init__(id="text_orchestrator")
@handler
async def start_processing(
self, texts: list[str], ctx: WorkflowContext[TextProcessingRequest]
) -> None:
"""Start processing multiple text strings."""
print(f"πŸ“„ Starting processing of {len(texts)} text strings")
print("=" * 60)
self.expected_count = len(texts)
# Send each text to a sub-workflow
for i, text in enumerate(texts):
task_id = f"task_{i+1}"
request = TextProcessingRequest(text=text, task_id=task_id)
print(f"πŸ“€ Dispatching {task_id} to sub-workflow")
await ctx.send_message(request, target_id="text_processor_workflow")
@handler
async def collect_result(
self, result: TextProcessingResult, ctx: WorkflowContext[None]
) -> None:
"""Collect results from sub-workflows."""
print(f"πŸ“₯ Collected result from {result.task_id}")
self.results.append(result)
# Check if all results are collected
if len(self.results) == self.expected_count:
print("\nπŸŽ‰ All tasks completed!")
await ctx.add_event(AllTasksCompleted(self.results))
def get_summary(self) -> dict[str, Any]:
"""Get a summary of all processing results."""
total_words = sum(result.word_count for result in self.results)
total_chars = sum(result.char_count for result in self.results)
avg_words = total_words / len(self.results) if self.results else 0
avg_chars = total_chars / len(self.results) if self.results else 0
return {
"total_texts": len(self.results),
"total_words": total_words,
"total_characters": total_chars,
"average_words_per_text": round(avg_words, 2),
"average_characters_per_text": round(avg_chars, 2),
}
async def main():
"""Main function to run the basic sub-workflow example."""
print("πŸš€ Setting up sub-workflow...")
# Step 1: Create the text processing sub-workflow
text_processor = TextProcessor()
processing_workflow = (
WorkflowBuilder()
.set_start_executor(text_processor)
.build()
)
print("πŸ”§ Setting up parent workflow...")
# Step 2: Create the parent workflow
orchestrator = TextProcessingOrchestrator()
workflow_executor = WorkflowExecutor(processing_workflow, id="text_processor_workflow")
main_workflow = (
WorkflowBuilder()
.set_start_executor(orchestrator)
.add_edge(orchestrator, workflow_executor)
.add_edge(workflow_executor, orchestrator)
.build()
)
# Step 3: Test data - various text strings
test_texts = [
"Hello world! This is a simple test.",
"Python is a powerful programming language used for many applications.",
"Short text.",
"This is a longer text with multiple sentences. It contains more words and characters. We use it to test our text processing workflow.",
"", # Empty string
" Spaces around text ",
]
print(f"\nπŸ§ͺ Testing with {len(test_texts)} text strings")
print("=" * 60)
# Step 4: Run the workflow
result = await main_workflow.run(test_texts)
# Step 5: Display results
print(f"\nπŸ“Š Processing Results:")
print("=" * 60)
# Sort results by task_id for consistent display
sorted_results = sorted(orchestrator.results, key=lambda r: r.task_id)
for result in sorted_results:
preview = result.text[:30] + "..." if len(result.text) > 30 else result.text
preview = preview.replace('\n', ' ').strip() or '(empty)'
print(f"βœ… {result.task_id}: '{preview}' -> {result.word_count} words, {result.char_count} chars")
# Step 6: Display summary
summary = orchestrator.get_summary()
print(f"\nπŸ“ˆ Summary:")
print("=" * 60)
print(f"πŸ“„ Total texts processed: {summary['total_texts']}")
print(f"πŸ“ Total words: {summary['total_words']}")
print(f"πŸ”€ Total characters: {summary['total_characters']}")
print(f"πŸ“Š Average words per text: {summary['average_words_per_text']}")
print(f"πŸ“ Average characters per text: {summary['average_characters_per_text']}")
print(f"\n🏁 Processing complete!")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,275 @@
# Copyright (c) Microsoft. All rights reserved.
"""
The following sample demonstrates sub-workflows with request interception and conditional forwarding.
This sample shows how to:
1. Create workflows that execute other workflows as sub-workflows
2. Intercept requests from sub-workflows in parent workflows using @intercepts_request
3. Conditionally handle or forward requests using RequestResponse.handled() and RequestResponse.forward()
4. Handle external requests that are forwarded by the parent workflow
5. Proper request/response correlation for concurrent processing
The example simulates an email validation system where:
- Sub-workflows validate multiple email addresses concurrently
- Parent workflows can intercept domain check requests for optimization
- Known domains (example.com, company.com) are approved locally
- Unknown domains (unknown.org) are forwarded to external services
- Request correlation ensures each email gets the correct domain check response
- External domain check requests are processed and responses routed back correctly
Key concepts demonstrated:
- WorkflowExecutor: Wraps a workflow to make it behave as an executor
- @intercepts_request: Decorator for parent workflows to handle sub-workflow requests
- RequestResponse: Enables conditional handling vs forwarding of requests
- Request correlation: Using request_id to match responses with original requests
- Concurrent processing: Multiple emails processed simultaneously without interference
- External request routing: RequestInfoExecutor handles forwarded external requests
- Sub-workflow isolation: Sub-workflows work normally without knowing they're nested
Simple flow visualization:
Parent Orchestrator (@intercepts_request)
|
| EmailValidationRequest(email) x3 (concurrent)
v
[ Sub-workflow: WorkflowExecutor(EmailValidator) ]
|
| DomainCheckRequest(domain) with request_id correlation
v
Interception? yes -> handled locally with RequestResponse.handled(True)
no -> forwarded to RequestInfoExecutor -> external service
|
v
Response routed back to sub-workflow using request_id
"""
import asyncio
from dataclasses import dataclass
from typing import Any
from agent_framework_workflow import (
Executor,
RequestInfoExecutor,
RequestInfoMessage,
RequestResponse,
WorkflowBuilder,
WorkflowCompletedEvent,
WorkflowContext,
WorkflowExecutor,
handler,
intercepts_request,
)
# 1. Define domain-specific message types
@dataclass
class EmailValidationRequest:
"""Request to validate an email address."""
email: str
@dataclass
class DomainCheckRequest(RequestInfoMessage):
"""Request to check if a domain is approved."""
domain: str = ""
@dataclass
class ValidationResult:
"""Result of email validation."""
email: str
is_valid: bool
reason: str
# 2. Implement the sub-workflow executor (completely standard)
class EmailValidator(Executor):
"""Validates email addresses - doesn't know it's in a sub-workflow."""
def __init__(self):
"""Initialize the EmailValidator executor."""
super().__init__(id="email_validator")
# Use a dict to track multiple pending emails by request_id
self._pending_emails: dict[str, str] = {}
@handler
async def validate_request(
self, request: EmailValidationRequest, ctx: WorkflowContext[DomainCheckRequest | ValidationResult]
) -> None:
"""Validate an email address."""
print(f"πŸ” Sub-workflow validating email: {request.email}")
# Extract domain
domain = request.email.split("@")[1] if "@" in request.email else ""
if not domain:
print(f"❌ Invalid email format: {request.email}")
result = ValidationResult(email=request.email, is_valid=False, reason="Invalid email format")
await ctx.add_event(WorkflowCompletedEvent(data=result))
return
print(f"🌐 Sub-workflow requesting domain check for: {domain}")
# Request domain check
domain_check = DomainCheckRequest(domain=domain)
# Store the pending email with the request_id for correlation
self._pending_emails[domain_check.request_id] = request.email
await ctx.send_message(domain_check, target_id="email_request_info")
@handler
async def handle_domain_response(
self,
response: RequestResponse[DomainCheckRequest, bool],
ctx: WorkflowContext[ValidationResult],
) -> None:
"""Handle domain check response from RequestInfo with correlation."""
approved = bool(response.data)
domain = response.original_request.domain if (hasattr(response, 'original_request') and response.original_request) else "unknown"
print(f"πŸ“¬ Sub-workflow received domain response for '{domain}': {approved}")
# Find the corresponding email using the request_id
request_id = response.original_request.request_id if (hasattr(response, 'original_request') and response.original_request) else None
if request_id and request_id in self._pending_emails:
email = self._pending_emails.pop(request_id) # Remove from pending
result = ValidationResult(
email=email,
is_valid=approved,
reason="Domain approved" if approved else "Domain not approved",
)
print(f"βœ… Sub-workflow completing validation for: {email}")
await ctx.add_event(WorkflowCompletedEvent(data=result))
# 3. Implement the parent workflow with request interception
class SmartEmailOrchestrator(Executor):
"""Parent orchestrator that can intercept domain checks."""
approved_domains: set[str] = set()
def __init__(self, approved_domains: set[str] | None = None):
"""Initialize the SmartEmailOrchestrator with approved domains.
Args:
approved_domains: Set of pre-approved domains, defaults to example.com, test.org, company.com
"""
super().__init__(id="email_orchestrator", approved_domains=approved_domains)
self._results: list[ValidationResult] = []
@handler
async def start_validation(self, emails: list[str], ctx: WorkflowContext[EmailValidationRequest]) -> None:
"""Start validating a batch of emails."""
print(f"πŸ“§ Starting validation of {len(emails)} email addresses")
print("=" * 60)
for email in emails:
print(f"πŸ“€ Sending '{email}' to sub-workflow for validation")
request = EmailValidationRequest(email=email)
await ctx.send_message(request, target_id="email_validator_workflow")
@intercepts_request
async def check_domain(
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
) -> RequestResponse[DomainCheckRequest, bool]:
"""Intercept domain check requests from sub-workflows."""
print(f"πŸ” Parent intercepting domain check for: {request.domain}")
if request.domain in self.approved_domains:
print(f"βœ… Domain '{request.domain}' is pre-approved locally!")
return RequestResponse[DomainCheckRequest, bool].handled(True)
print(f"❓ Domain '{request.domain}' unknown, forwarding to external service...")
return RequestResponse.forward()
@handler
async def collect_result(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None:
"""Collect validation results. It comes from the sub-workflow emitted WorkflowCompletionEvent's data field."""
status_icon = "βœ…" if result.is_valid else "❌"
print(f"πŸ“₯ {status_icon} Validation result: {result.email} -> {result.reason}")
self._results.append(result)
@property
def results(self) -> list[ValidationResult]:
"""Get the collected validation results."""
return self._results
async def run_example() -> None:
"""Run the sub-workflow example."""
print("πŸš€ Setting up sub-workflow with request interception...")
print()
# 4. Build the sub-workflow
email_validator = EmailValidator()
# Match the target_id used in EmailValidator ("email_request_info")
request_info = RequestInfoExecutor(id="email_request_info")
validation_workflow = (
WorkflowBuilder()
.set_start_executor(email_validator)
.add_edge(email_validator, request_info)
.add_edge(request_info, email_validator)
.build()
)
# 5. Build the parent workflow with interception
orchestrator = SmartEmailOrchestrator(approved_domains={"example.com", "company.com"})
workflow_executor = WorkflowExecutor(validation_workflow, id="email_validator_workflow")
# Add a RequestInfoExecutor to handle forwarded external requests
main_request_info = RequestInfoExecutor(id="main_request_info")
main_workflow = (
WorkflowBuilder()
.set_start_executor(orchestrator)
.add_edge(orchestrator, workflow_executor)
.add_edge(workflow_executor, orchestrator)
# Add edges for external request handling
.add_edge(orchestrator, main_request_info)
.add_edge(main_request_info, workflow_executor) # Route external responses to sub-workflow
.build()
)
# 6. Prepare test inputs: known domain, unknown domain
test_emails = [
"user@example.com", # Should be intercepted and approved
"admin@company.com", # Should be intercepted and approved
"guest@unknown.org", # Should be forwarded externally
]
# 7. Run the workflow
result = await main_workflow.run(test_emails)
# 8. Handle any external requests
request_events = result.get_request_info_events()
if request_events:
print(f"\n🌐 Handling {len(request_events)} external request(s)...")
for event in request_events:
if event.data and hasattr(event.data, "domain"):
print(f"πŸ” External domain check needed for: {event.data.domain}")
# Simulate external responses
external_responses: dict[str, bool] = {}
for event in request_events:
# Simulate external domain checking
if event.data and hasattr(event.data, "domain"):
domain = event.data.domain
# Let's say unknown.org is actually approved externally
approved = domain == "unknown.org"
print(f"🌐 External service response for '{domain}': {'APPROVED' if approved else 'REJECTED'}")
external_responses[event.request_id] = approved
# 9. Send external responses
await main_workflow.send_responses(external_responses)
else:
print("\n🎯 All requests were intercepted and handled locally!")
# 10. Display final summary
print(f"\nπŸ“Š Final Results Summary:")
print("=" * 60)
for result in orchestrator.results:
status = "βœ… VALID" if result.is_valid else "❌ INVALID"
print(f"{status} {result.email}: {result.reason}")
print(f"\n🏁 Processed {len(orchestrator.results)} emails total")
if __name__ == "__main__":
asyncio.run(run_example())
@@ -0,0 +1,414 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from dataclasses import dataclass
from typing import Any
from agent_framework.workflow import (
Executor,
RequestInfoExecutor,
WorkflowBuilder,
WorkflowCompletedEvent,
WorkflowContext,
handler,
)
# Import the new sub-workflow types directly from the implementation package
try:
from agent_framework_workflow import (
RequestInfoMessage,
RequestResponse,
WorkflowExecutor,
intercepts_request,
)
except ImportError:
import os
import sys
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "workflow"))
from agent_framework_workflow import (
RequestInfoMessage,
RequestResponse,
WorkflowExecutor,
intercepts_request,
)
"""
This sample demonstrates the PROPER pattern for request interception.
Key principles:
1. Only ONE executor intercepts a given request type from a specific sub-workflow
2. Different executors can intercept DIFFERENT request types from the same sub-workflow
3. The same executor can intercept the same request type from DIFFERENT sub-workflows
This ensures:
- Deterministic behavior
- Clear responsibility boundaries
- Easier debugging and maintenance
The example simulates a resource allocation system where:
- Sub-workflow requests resources (CPU, memory, etc.)
- A single Cache executor intercepts and handles resource requests
- The Cache can either satisfy from cache or forward to external
Simple flow visualization:
Coordinator
|
| list[resource/policy requests]
v
[ Sub-workflow: WorkflowExecutor(ResourceRequester) ]
| |
| ResourceRequest | PolicyCheckRequest
v v
ResourceCache (@intercepts) PolicyEngine (@intercepts)
| handled/forward | handled/forward
v v
RequestInfo (external) <----- forwarded when not handled
| responses
v
Back to sub-workflow -> completion -> results collected
"""
# 1. Define domain-specific request/response types
@dataclass
class ResourceRequest(RequestInfoMessage):
"""Request for computing resources."""
resource_type: str = "cpu" # cpu, memory, disk, etc.
amount: int = 1
priority: str = "normal" # low, normal, high
@dataclass
class PolicyCheckRequest(RequestInfoMessage):
"""Request to check resource allocation policy."""
resource_type: str = ""
amount: int = 0
policy_type: str = "quota" # quota, compliance, security
@dataclass
class ResourceResponse:
"""Response with allocated resources."""
resource_type: str
allocated: int
source: str # Which system provided the resources
@dataclass
class PolicyResponse:
"""Response from policy check."""
approved: bool
reason: str
@dataclass
class RequestFinished:
pass
# 2. Implement the sub-workflow executor - makes resource and policy requests
class ResourceRequester(Executor):
"""Simple executor that requests resources and checks policies."""
def __init__(self):
super().__init__(id="resource_requester")
self._request_count = 0
@handler
async def request_resources(
self,
requests: list[dict[str, Any]],
ctx: WorkflowContext[ResourceRequest | PolicyCheckRequest],
) -> None:
"""Process a list of resource requests."""
print(f"🏭 Sub-workflow processing {len(requests)} requests")
self._request_count += len(requests)
for req_data in requests:
req_type = req_data.get("request_type", "resource")
if req_type == "resource":
print(f" πŸ“¦ Requesting resource: {req_data.get('type', 'cpu')} x{req_data.get('amount', 1)}")
request = ResourceRequest(
resource_type=req_data.get("type", "cpu"),
amount=req_data.get("amount", 1),
priority=req_data.get("priority", "normal"),
)
# Send to parent workflow for interception - not to target_id
await ctx.send_message(request)
elif req_type == "policy":
print(f" πŸ›‘οΈ Checking policy: {req_data.get('type', 'cpu')} x{req_data.get('amount', 1)} ({req_data.get('policy_type', 'quota')})")
request = PolicyCheckRequest(
resource_type=req_data.get("type", "cpu"),
amount=req_data.get("amount", 1),
policy_type=req_data.get("policy_type", "quota"),
)
# Send to parent workflow for interception - not to target_id
await ctx.send_message(request)
@handler
async def handle_resource_response(
self,
response: RequestResponse[ResourceRequest, ResourceResponse],
ctx: WorkflowContext[None],
) -> None:
"""Handle resource allocation response."""
if response.data:
source_icon = "πŸͺ" if response.data.source == "cache" else "🌐"
print(
f"πŸ“¦ {source_icon} Sub-workflow received: {response.data.allocated} {response.data.resource_type} from {response.data.source}"
)
if self._collect_results():
# Emit completion event and send RequestFinished to the parent workflow.
await ctx.add_event(WorkflowCompletedEvent(RequestFinished()))
@handler
async def handle_policy_response(
self, response: RequestResponse[PolicyCheckRequest, PolicyResponse], ctx: WorkflowContext[None]
) -> None:
"""Handle policy check response."""
if response.data:
status_icon = "βœ…" if response.data.approved else "❌"
print(f"πŸ›‘οΈ {status_icon} Sub-workflow received policy response: {response.data.approved} - {response.data.reason}")
if self._collect_results():
# Emit completion event and send RequestFinished to the parent workflow.
await ctx.add_event(WorkflowCompletedEvent(RequestFinished()))
def _collect_results(self) -> bool:
"""Collect and summarize results."""
self._request_count -= 1
print(f"πŸ“Š Sub-workflow completed request ({self._request_count} remaining)")
return self._request_count == 0
# 3. Implement the Resource Cache - ONLY intercepts ResourceRequest
class ResourceCache(Executor):
"""Interceptor that handles RESOURCE requests from cache."""
# Use class attributes to avoid Pydantic assignment restrictions
cache: dict[str, int] = {"cpu": 10, "memory": 50, "disk": 100}
results: list[ResourceResponse] = []
def __init__(self):
super().__init__(id="resource_cache")
# Instance initialization only; state kept in class attributes as above
@intercepts_request
async def check_cache(
self, request: ResourceRequest, ctx: WorkflowContext[None]
) -> RequestResponse[ResourceRequest, ResourceResponse]:
"""Intercept RESOURCE requests and check cache first."""
print(f"πŸͺ CACHE interceptor checking: {request.amount} {request.resource_type}")
available = self.cache.get(request.resource_type, 0)
if available >= request.amount:
# We can satisfy from cache
self.cache[request.resource_type] -= request.amount
response = ResourceResponse(resource_type=request.resource_type, allocated=request.amount, source="cache")
print(f" βœ… Cache satisfied: {request.amount} {request.resource_type}")
self.results.append(response)
return RequestResponse[ResourceRequest, ResourceResponse].handled(response)
# Cache miss - forward to external
print(f" ❌ Cache miss: need {request.amount}, have {available} {request.resource_type}")
return RequestResponse.forward()
@handler
async def collect_result(
self, response: RequestResponse[ResourceRequest, ResourceResponse], ctx: WorkflowContext[None]
) -> None:
"""Collect results from external requests that were forwarded."""
if response.data and response.data.source != "cache": # Don't double-count our own results
self.results.append(response.data)
print(
f"πŸͺ 🌐 Cache received external response: {response.data.allocated} {response.data.resource_type} from {response.data.source}"
)
# 4. Implement the Policy Engine - ONLY intercepts PolicyCheckRequest (different type!)
class PolicyEngine(Executor):
"""Interceptor that handles POLICY requests."""
# Use class attributes for simple sample state
quota: dict[str, int] = {
"cpu": 5, # Only allow up to 5 CPU units
"memory": 20, # Only allow up to 20 memory units
"disk": 1000, # Liberal disk policy
}
results: list[PolicyResponse] = []
def __init__(self):
super().__init__(id="policy_engine")
# Instance initialization only; state kept in class attributes as above
@intercepts_request
async def check_policy(
self, request: PolicyCheckRequest, ctx: WorkflowContext[None]
) -> RequestResponse[PolicyCheckRequest, PolicyResponse]:
"""Intercept POLICY requests and apply rules."""
print(f"πŸ›‘οΈ POLICY interceptor checking: {request.amount} {request.resource_type}, policy={request.policy_type}")
quota_limit = self.quota.get(request.resource_type, 0)
if request.policy_type == "quota":
if request.amount <= quota_limit:
response = PolicyResponse(approved=True, reason=f"Within quota ({quota_limit})")
print(f" βœ… Policy approved: {request.amount} <= {quota_limit}")
self.results.append(response)
return RequestResponse[PolicyCheckRequest, PolicyResponse].handled(response)
# Exceeds quota - forward to external for review
print(f" ❌ Policy exceeds quota: {request.amount} > {quota_limit}, forwarding to external")
return RequestResponse.forward()
# Unknown policy type - forward to external
print(f" ❓ Unknown policy type: {request.policy_type}, forwarding")
return RequestResponse.forward()
@handler
async def collect_policy_result(
self, response: RequestResponse[PolicyCheckRequest, PolicyResponse], ctx: WorkflowContext[None]
) -> None:
"""Collect policy results from external requests that were forwarded."""
if response.data:
self.results.append(response.data)
print(f"πŸ›‘οΈ 🌐 Policy received external response: {response.data.approved} - {response.data.reason}")
class Coordinator(Executor):
def __init__(self):
super().__init__(id="coordinator")
@handler
async def start(self, requests: list[dict[str, Any]], ctx: WorkflowContext[object]) -> None:
"""Start the resource allocation process."""
await ctx.send_message(requests, target_id="resource_workflow")
@handler
async def handle_completion(self, completion: RequestFinished, ctx: WorkflowContext[None]) -> None:
"""Handle sub-workflow completion. It comes from the sub-workflow emitted WorkflowCompletionEvent's data field."""
print(f"🎯 Main workflow received completion.")
async def main() -> None:
"""Demonstrate parallel request interception patterns."""
print("πŸš€ Starting Sub-Workflow Parallel Request Interception Demo...")
print("=" * 60)
# 5. Create the sub-workflow
resource_requester = ResourceRequester()
sub_request_info = RequestInfoExecutor(id="sub_request_info")
sub_workflow = (
WorkflowBuilder()
.set_start_executor(resource_requester)
.add_edge(resource_requester, sub_request_info)
.add_edge(sub_request_info, resource_requester)
.build()
)
# 6. Create parent workflow with PROPER interceptor pattern
cache = ResourceCache() # Intercepts ResourceRequest
policy = PolicyEngine() # Intercepts PolicyCheckRequest (different type!)
workflow_executor = WorkflowExecutor(sub_workflow, id="resource_workflow")
main_request_info = RequestInfoExecutor(id="main_request_info")
# Create a simple coordinator that starts the process
coordinator = Coordinator()
# PROPER PATTERN: Each executor intercepts DIFFERENT request types
main_workflow = (
WorkflowBuilder()
.set_start_executor(coordinator)
.add_edge(coordinator, workflow_executor) # Start sub-workflow
.add_edge(workflow_executor, coordinator) # Sub-workflow completion back to coordinator
.add_edge(workflow_executor, cache) # Cache intercepts ResourceRequest
.add_edge(cache, workflow_executor) # Cache handles ResourceRequest
.add_edge(workflow_executor, policy) # Policy handles PolicyCheckRequest
.add_edge(policy, workflow_executor) # Policy intercepts PolicyCheckRequest
.add_edge(cache, main_request_info) # Cache forwards to external
.add_edge(policy, main_request_info) # Policy forwards to external
.add_edge(main_request_info, workflow_executor) # External responses back
.add_edge(workflow_executor, main_request_info) # Sub-workflow forwards to main
.build()
)
# 7. Test with various requests (mixed resource and policy)
test_requests = [
{"request_type": "resource", "type": "cpu", "amount": 2, "priority": "normal"}, # Cache hit
{"request_type": "policy", "type": "cpu", "amount": 3, "policy_type": "quota"}, # Policy hit
{"request_type": "resource", "type": "memory", "amount": 15, "priority": "normal"}, # Cache hit
{"request_type": "policy", "type": "memory", "amount": 100, "policy_type": "quota"}, # Policy miss -> external
{"request_type": "resource", "type": "gpu", "amount": 1, "priority": "high"}, # Cache miss -> external
{"request_type": "policy", "type": "disk", "amount": 500, "policy_type": "quota"}, # Policy hit
{"request_type": "policy", "type": "cpu", "amount": 1, "policy_type": "security"}, # Unknown policy -> external
]
print(f"πŸ§ͺ Testing with {len(test_requests)} mixed requests:")
for i, req in enumerate(test_requests, 1):
req_icon = "πŸ“¦" if req["request_type"] == "resource" else "πŸ›‘οΈ"
print(f" {i}. {req_icon} {req['type']} x{req['amount']} ({req.get('priority', req.get('policy_type', 'default'))})")
print("=" * 70)
# 8. Run the workflow
print("🎬 Running workflow...")
result = await main_workflow.run(test_requests)
# 9. Handle any external requests that couldn't be intercepted
request_events = result.get_request_info_events()
if request_events:
print(f"\n🌐 Handling {len(request_events)} external request(s)...")
external_responses: dict[str, Any] = {}
for event in request_events:
if isinstance(event.data, ResourceRequest):
# Handle ResourceRequest - create ResourceResponse
resource_response = ResourceResponse(
resource_type=event.data.resource_type, allocated=event.data.amount, source="external_provider"
)
external_responses[event.request_id] = resource_response
print(f" 🏭 External provider: {resource_response.allocated} {resource_response.resource_type}")
elif isinstance(event.data, PolicyCheckRequest):
# Handle PolicyCheckRequest - create PolicyResponse
policy_response = PolicyResponse(approved=True, reason="External policy service approved")
external_responses[event.request_id] = policy_response
print(f" πŸ”’ External policy: {'βœ… APPROVED' if policy_response.approved else '❌ DENIED'}")
await main_workflow.send_responses(external_responses)
else:
print("\n🎯 All requests were intercepted internally!")
# 10. Show results and analysis
print("\n" + "=" * 70)
print("πŸ“Š RESULTS ANALYSIS")
print("=" * 70)
print(f"\nπŸͺ Cache Results ({len(cache.results)} handled):")
for result in cache.results:
print(f" βœ… {result.allocated} {result.resource_type} from {result.source}")
print(f"\nπŸ›‘οΈ Policy Results ({len(policy.results)} handled):")
for result in policy.results:
status_icon = "βœ…" if result.approved else "❌"
print(f" {status_icon} Approved: {result.approved} - {result.reason}")
print(f"\nπŸ’Ύ Final Cache State:")
for resource, amount in cache.cache.items():
print(f" πŸ“¦ {resource}: {amount} remaining")
print(f"\nπŸ“ˆ Summary:")
print(f" 🎯 Total requests: {len(test_requests)}")
print(f" πŸͺ Resource requests handled: {len(cache.results)}")
print(f" πŸ›‘οΈ Policy requests handled: {len(policy.results)}")
print(f" 🌐 External requests: {len(request_events) if request_events else 0}")
print("\n" + "=" * 70)
if __name__ == "__main__":
asyncio.run(main())