mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Adding support for nested workflows (#460)
* Adding design documents and data flow descriptions for sub-workflows * Updating docs. * Sub-workflow implementation #1. Stuck because of singleton RequestInfoExecutor, going to make a change to remove that restrivtion. * Removed the singleton restriction on RequestInfoExecutor so enable sub-workflows. * Scenarios seem to be working. * Sample improved. * going to have intern add generic response wrappers. * Wrapped responses working. * Non-hardcoded routing is working. * Sample showing external approved and not approved. * Cleaning up. * Updating some samples and user guide. * Removing old design doc. * Cleaning up. * Adding python-package-setup.md back. * Update python/packages/workflow/agent_framework_workflow/_executor.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Update python/packages/workflow/agent_framework_workflow/_validation.py Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> * Removing prints. * Fixing lint and type issues. * Fixing lint and type issues. * Update python/packages/workflow/agent_framework_workflow/_executor.py Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com> * Adding type hints to intercepts decorator. * Removing unused files. * Fixing issue with sample 5 groupchat with hil. * Removing redundent samples. * Updates to ensure no conflicting request interceptors and to support a subflow with multiple requests in a single super step. * Fixing pypi errors. * clean up samples * update samples to make it more clear * warning for unhandled request info from sub workflow * add logger info --------- Co-authored-by: Copilot <175728472+Copilot@users.noreply.github.com> Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
b26d9c95fe
commit
b0b3fd151c
@@ -319,4 +319,4 @@ We should consider auto-instrumentation and provide an implementation of it to t
|
||||
### Build and release
|
||||
The build step will be done in GHA, adding the package to the release and then we call into Azure DevOps to use the ESRP pipeline to publish to pypi. This is how SK already works, we will just have to adapt it to the new package structure.
|
||||
|
||||
For now we will stick to semantic versioning, and all preview release will be tagged as such.
|
||||
For now we will stick to semantic versioning, and all preview release will be tagged as such.
|
||||
@@ -35,6 +35,11 @@ _IMPORTS = [
|
||||
"WorkflowCheckpoint",
|
||||
"Case",
|
||||
"Default",
|
||||
"RequestResponse",
|
||||
"SubWorkflowRequestInfo",
|
||||
"SubWorkflowResponse",
|
||||
"WorkflowExecutor",
|
||||
"intercepts_request",
|
||||
]
|
||||
|
||||
|
||||
|
||||
@@ -18,17 +18,22 @@ from agent_framework_workflow import (
|
||||
RequestInfoEvent,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
SubWorkflowRequestInfo,
|
||||
SubWorkflowResponse,
|
||||
Workflow,
|
||||
WorkflowBuilder,
|
||||
WorkflowCheckpoint,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
WorkflowEvent,
|
||||
WorkflowExecutor,
|
||||
WorkflowRunResult,
|
||||
WorkflowStartedEvent,
|
||||
WorkflowViz,
|
||||
__version__,
|
||||
handler,
|
||||
intercepts_request,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
@@ -49,15 +54,20 @@ __all__ = [
|
||||
"RequestInfoEvent",
|
||||
"RequestInfoExecutor",
|
||||
"RequestInfoMessage",
|
||||
"RequestResponse",
|
||||
"SubWorkflowRequestInfo",
|
||||
"SubWorkflowResponse",
|
||||
"Workflow",
|
||||
"WorkflowBuilder",
|
||||
"WorkflowCheckpoint",
|
||||
"WorkflowCompletedEvent",
|
||||
"WorkflowContext",
|
||||
"WorkflowEvent",
|
||||
"WorkflowExecutor",
|
||||
"WorkflowRunResult",
|
||||
"WorkflowStartedEvent",
|
||||
"WorkflowViz",
|
||||
"__version__",
|
||||
"handler",
|
||||
"intercepts_request",
|
||||
]
|
||||
|
||||
@@ -30,7 +30,12 @@ from ._executor import (
|
||||
Executor,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
SubWorkflowRequestInfo,
|
||||
SubWorkflowResponse,
|
||||
WorkflowExecutor,
|
||||
handler,
|
||||
intercepts_request,
|
||||
)
|
||||
from ._runner_context import (
|
||||
InProcRunnerContext,
|
||||
@@ -76,11 +81,12 @@ __all__ = [
|
||||
"InProcRunnerContext",
|
||||
"Message",
|
||||
"RequestInfoEvent",
|
||||
"RequestInfoEvent",
|
||||
"RequestInfoExecutor",
|
||||
"RequestInfoExecutor",
|
||||
"RequestInfoMessage",
|
||||
"RequestResponse",
|
||||
"RunnerContext",
|
||||
"SubWorkflowRequestInfo",
|
||||
"SubWorkflowResponse",
|
||||
"TypeCompatibilityError",
|
||||
"ValidationTypeEnum",
|
||||
"Workflow",
|
||||
@@ -89,11 +95,13 @@ __all__ = [
|
||||
"WorkflowCompletedEvent",
|
||||
"WorkflowContext",
|
||||
"WorkflowEvent",
|
||||
"WorkflowExecutor",
|
||||
"WorkflowRunResult",
|
||||
"WorkflowStartedEvent",
|
||||
"WorkflowValidationError",
|
||||
"WorkflowViz",
|
||||
"__version__",
|
||||
"handler",
|
||||
"intercepts_request",
|
||||
"validate_workflow_graph",
|
||||
]
|
||||
|
||||
@@ -5,9 +5,12 @@ import functools
|
||||
import inspect
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Callable
|
||||
from dataclasses import dataclass
|
||||
from dataclasses import dataclass, field
|
||||
from types import UnionType
|
||||
from typing import Any, ClassVar, TypeVar, Union, get_args, get_origin, overload
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypeVar, Union, get_args, get_origin, overload
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._workflow import Workflow
|
||||
|
||||
from agent_framework import AgentRunResponse, AgentRunResponseUpdate, AgentThread, AIAgent, ChatMessage
|
||||
from agent_framework._pydantic import AFBaseModel
|
||||
@@ -55,12 +58,14 @@ class Executor(AFBaseModel):
|
||||
super().__init__(**kwargs)
|
||||
|
||||
self._handlers: dict[type, Callable[[Any, WorkflowContext[Any]], Any]] = {}
|
||||
self._request_interceptors: dict[type | str, list[dict[str, Any]]] = {}
|
||||
self._discover_handlers()
|
||||
|
||||
if not self._handlers:
|
||||
if not self._handlers and not self._request_interceptors:
|
||||
raise ValueError(
|
||||
f"Executor {self.__class__.__name__} has no handlers defined. "
|
||||
"Please define at least one handler using the @handler decorator."
|
||||
"Please define at least one handler using the @handler decorator "
|
||||
"or @intercepts_request decorator."
|
||||
)
|
||||
|
||||
async def execute(self, message: Any, context: WorkflowContext[Any]) -> None:
|
||||
@@ -73,6 +78,16 @@ class Executor(AFBaseModel):
|
||||
Returns:
|
||||
An awaitable that resolves to the result of the execution.
|
||||
"""
|
||||
# Handle case where Message wrapper is passed instead of raw data
|
||||
|
||||
# Lazy registration for SubWorkflowRequestInfo if we have interceptors
|
||||
if self._request_interceptors and message.__class__.__name__ == "SubWorkflowRequestInfo":
|
||||
# Directly handle SubWorkflowRequestInfo
|
||||
await context.add_event(ExecutorInvokeEvent(self.id))
|
||||
await self._handle_sub_workflow_request(message, context)
|
||||
await context.add_event(ExecutorCompletedEvent(self.id))
|
||||
return
|
||||
|
||||
handler: Callable[[Any, WorkflowContext[Any]], Any] | None = None
|
||||
for message_type in self._handlers:
|
||||
if is_instance_of(message, message_type):
|
||||
@@ -81,30 +96,147 @@ class Executor(AFBaseModel):
|
||||
|
||||
if handler is None:
|
||||
raise RuntimeError(f"Executor {self.__class__.__name__} cannot handle message of type {type(message)}.")
|
||||
|
||||
await context.add_event(ExecutorInvokeEvent(self.id))
|
||||
await handler(message, context)
|
||||
await context.add_event(ExecutorCompletedEvent(self.id))
|
||||
|
||||
def _discover_handlers(self) -> None:
|
||||
"""Discover message handlers in the executor class."""
|
||||
"""Discover message handlers and request interceptors in the executor class."""
|
||||
# Use __class__.__dict__ to avoid accessing pydantic's dynamic attributes
|
||||
for attr_name in dir(self.__class__):
|
||||
try:
|
||||
attr = getattr(self.__class__, attr_name)
|
||||
if callable(attr) and hasattr(attr, "_handler_spec"):
|
||||
handler_spec = attr._handler_spec # type: ignore
|
||||
if self._handlers.get(handler_spec["message_type"]) is not None:
|
||||
raise ValueError(
|
||||
f"Duplicate handler for type {handler_spec['message_type']} in {self.__class__.__name__}"
|
||||
)
|
||||
# Get the bound method
|
||||
bound_method = getattr(self, attr_name)
|
||||
self._handlers[handler_spec["message_type"]] = bound_method
|
||||
if callable(attr):
|
||||
# Discover @handler methods
|
||||
if hasattr(attr, "_handler_spec"):
|
||||
handler_spec = attr._handler_spec # type: ignore
|
||||
message_type = handler_spec["message_type"]
|
||||
|
||||
# Keep full generic types for handler registration to avoid conflicts
|
||||
# Different RequestResponse[T, U] specializations are distinct handler types
|
||||
|
||||
if self._handlers.get(message_type) is not None:
|
||||
raise ValueError(f"Duplicate handler for type {message_type} in {self.__class__.__name__}")
|
||||
# Get the bound method
|
||||
bound_method = getattr(self, attr_name)
|
||||
self._handlers[message_type] = bound_method
|
||||
|
||||
# Discover @intercepts_request methods
|
||||
if hasattr(attr, "_intercepts_request"):
|
||||
# Get the bound method for interceptors
|
||||
bound_method = getattr(self, attr_name)
|
||||
interceptor_info = {
|
||||
"method": bound_method,
|
||||
"from_workflow": getattr(attr, "_from_workflow", None),
|
||||
"condition": getattr(attr, "_intercept_condition", None),
|
||||
}
|
||||
request_type = attr._intercepts_request # type: ignore
|
||||
if request_type not in self._request_interceptors:
|
||||
self._request_interceptors[request_type] = []
|
||||
self._request_interceptors[request_type].append(interceptor_info)
|
||||
except AttributeError:
|
||||
# Skip attributes that may not be accessible
|
||||
continue
|
||||
|
||||
def _register_sub_workflow_handler(self) -> None:
|
||||
"""Register automatic handler for SubWorkflowRequestInfo messages."""
|
||||
# We need to use a string reference until the class is defined
|
||||
# This will be resolved later when the class is actually used
|
||||
pass # Will be registered lazily when needed
|
||||
|
||||
async def _handle_sub_workflow_request(
|
||||
self,
|
||||
request: "SubWorkflowRequestInfo",
|
||||
ctx: WorkflowContext[Any],
|
||||
) -> None:
|
||||
"""Automatic routing to @intercepts_request methods.
|
||||
|
||||
This is only active for executors that have @intercepts_request methods.
|
||||
"""
|
||||
# Try to match against registered interceptors
|
||||
for request_type, interceptor_list in self._request_interceptors.items():
|
||||
matched = False
|
||||
|
||||
# Check type matching
|
||||
if isinstance(request_type, type) and is_instance_of(request.data, request_type):
|
||||
matched = True
|
||||
elif (
|
||||
isinstance(request_type, str)
|
||||
and hasattr(request.data, "__class__")
|
||||
and request.data.__class__.__name__ == request_type
|
||||
):
|
||||
# String matching - could check against type name or other attributes
|
||||
matched = True
|
||||
|
||||
if matched:
|
||||
# Check each interceptor in the list for this request type
|
||||
for interceptor_info in interceptor_list:
|
||||
# Check workflow scope if specified
|
||||
from_workflow = interceptor_info["from_workflow"]
|
||||
if from_workflow and request.sub_workflow_id != from_workflow:
|
||||
continue # Skip this interceptor, wrong workflow
|
||||
|
||||
# Check additional condition
|
||||
condition = interceptor_info["condition"]
|
||||
if condition and not condition(request):
|
||||
continue
|
||||
|
||||
# Call the interceptor method
|
||||
method = interceptor_info["method"]
|
||||
response = await method(request.data, ctx)
|
||||
|
||||
# Check if interceptor handled it or needs to forward
|
||||
if isinstance(response, RequestResponse):
|
||||
# Add automatic correlation info to the response
|
||||
correlated_response = RequestResponse._with_correlation(
|
||||
response, request.data, request.request_id
|
||||
)
|
||||
|
||||
if correlated_response.is_handled:
|
||||
# Send response back to sub-workflow
|
||||
from ._runner_context import Message
|
||||
|
||||
response_message = Message(
|
||||
source_id=self.id,
|
||||
target_id=request.sub_workflow_id,
|
||||
data=SubWorkflowResponse(
|
||||
request_id=request.request_id,
|
||||
data=correlated_response.data,
|
||||
),
|
||||
)
|
||||
await ctx.send_message(response_message)
|
||||
else:
|
||||
# Forward WITH CONTEXT PRESERVED
|
||||
# Update the data if interceptor provided a modified request
|
||||
if correlated_response.forward_request:
|
||||
request.data = correlated_response.forward_request
|
||||
|
||||
# Send the inner request to RequestInfoExecutor to create external request
|
||||
from ._runner_context import Message
|
||||
|
||||
forward_message = Message(
|
||||
source_id=self.id,
|
||||
data=request,
|
||||
)
|
||||
await ctx.send_message(forward_message)
|
||||
else:
|
||||
# Legacy support: direct return means handled
|
||||
await ctx.send_message(
|
||||
SubWorkflowResponse(
|
||||
request_id=request.request_id,
|
||||
data=response,
|
||||
),
|
||||
target_id=request.sub_workflow_id,
|
||||
)
|
||||
return
|
||||
|
||||
# No interceptor found - forward inner request to RequestInfoExecutor
|
||||
# This sends the original request to RequestInfoExecutor
|
||||
from ._runner_context import Message
|
||||
|
||||
passthrough_message = Message(source_id=self.id, data=request.data)
|
||||
await ctx.send_message(passthrough_message)
|
||||
|
||||
def can_handle(self, message: Any) -> bool:
|
||||
"""Check if the executor can handle a given message type.
|
||||
|
||||
@@ -257,6 +389,212 @@ def handler(
|
||||
|
||||
# endregion: Handler Decorator
|
||||
|
||||
# region Request/Response Types
|
||||
|
||||
TRequest = TypeVar("TRequest", bound="RequestInfoMessage")
|
||||
TResponse = TypeVar("TResponse")
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestResponse(Generic[TRequest, TResponse]):
|
||||
"""Response from @intercepts_request methods with automatic correlation support.
|
||||
|
||||
This type allows intercepting executors to indicate whether they handled
|
||||
a request or whether it should be forwarded to external sources. When handled,
|
||||
the framework automatically adds correlation info to link responses to requests.
|
||||
"""
|
||||
|
||||
is_handled: bool
|
||||
data: TResponse | None = None
|
||||
forward_request: TRequest | None = None
|
||||
original_request: TRequest | None = None # Added for automatic correlation
|
||||
request_id: str | None = None # Added for tracking
|
||||
|
||||
@classmethod
|
||||
def handled(cls, data: TResponse) -> "RequestResponse[Any, TResponse]":
|
||||
"""Create a response indicating the request was handled.
|
||||
|
||||
Correlation info (original_request, request_id) will be added automatically
|
||||
by the framework when processing intercepted requests.
|
||||
"""
|
||||
return cls(is_handled=True, data=data)
|
||||
|
||||
@classmethod
|
||||
def forward(cls, modified_request: Any = None) -> "RequestResponse[Any, Any]":
|
||||
"""Create a response indicating the request should be forwarded."""
|
||||
return cls(is_handled=False, forward_request=modified_request)
|
||||
|
||||
@staticmethod
|
||||
def _with_correlation(
|
||||
original_response: "RequestResponse[Any, TResponse]", original_request: TRequest, request_id: str
|
||||
) -> "RequestResponse[TRequest, TResponse]":
|
||||
"""Internal method to add correlation info to a response.
|
||||
|
||||
This is called automatically by the framework and should not be used directly.
|
||||
"""
|
||||
return RequestResponse(
|
||||
is_handled=original_response.is_handled,
|
||||
data=original_response.data,
|
||||
forward_request=original_response.forward_request,
|
||||
original_request=original_request,
|
||||
request_id=request_id,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubWorkflowRequestInfo:
|
||||
"""Routes requests from sub-workflows to parent workflows.
|
||||
|
||||
This message type wraps requests from sub-workflows to add routing context,
|
||||
allowing parent workflows to intercept and potentially handle the request.
|
||||
"""
|
||||
|
||||
request_id: str # Original request ID from sub-workflow
|
||||
sub_workflow_id: str # ID of the WorkflowExecutor that sent this
|
||||
data: Any # The actual request data
|
||||
|
||||
|
||||
@dataclass
|
||||
class SubWorkflowResponse:
|
||||
"""Routes responses back to sub-workflows.
|
||||
|
||||
This message type is used to send responses back to sub-workflows that
|
||||
made requests, ensuring the response reaches the correct sub-workflow.
|
||||
"""
|
||||
|
||||
request_id: str # Matches the original request ID
|
||||
data: Any # The actual response data
|
||||
|
||||
|
||||
# endregion: Request/Response Types
|
||||
|
||||
# region Intercepts Request Decorator
|
||||
|
||||
# TypeVar for request type that must be a RequestInfoMessage subclass
|
||||
RequestInfoMessageT = TypeVar("RequestInfoMessageT", bound="RequestInfoMessage")
|
||||
|
||||
# Type alias for interceptor functions
|
||||
InterceptorFunc = Callable[
|
||||
[Any, RequestInfoMessageT, WorkflowContext[Any]], Awaitable[RequestResponse[RequestInfoMessageT, Any]]
|
||||
]
|
||||
|
||||
|
||||
@overload
|
||||
def intercepts_request(
|
||||
func: Callable[..., Any],
|
||||
) -> Callable[..., Any]: ...
|
||||
|
||||
|
||||
@overload
|
||||
def intercepts_request(
|
||||
*,
|
||||
from_workflow: str | None = None,
|
||||
condition: Callable[[Any], bool] | None = None,
|
||||
) -> Callable[[Callable[..., Any]], Callable[..., Any]]: ...
|
||||
|
||||
|
||||
def intercepts_request(
|
||||
func: Callable[..., Any] | None = None,
|
||||
*,
|
||||
from_workflow: str | None = None,
|
||||
condition: Callable[[Any], bool] | None = None,
|
||||
) -> Callable[..., Any]:
|
||||
"""Decorator to mark methods that intercept sub-workflow requests.
|
||||
|
||||
The request type is automatically inferred from the method's second parameter type hint.
|
||||
The type must be a subclass of RequestInfoMessage.
|
||||
|
||||
This decorator allows executors in parent workflows to intercept and handle requests from
|
||||
sub-workflows before they are sent to external sources.
|
||||
|
||||
Args:
|
||||
func: The function being decorated (automatically passed when used without parentheses).
|
||||
from_workflow: Optional ID of specific sub-workflow to intercept from.
|
||||
condition: Optional callable that must return True for interception.
|
||||
|
||||
Returns:
|
||||
The decorated function with interception metadata.
|
||||
|
||||
Example:
|
||||
@intercepts_request
|
||||
async def check_domain(
|
||||
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
|
||||
) -> RequestResponse[DomainCheckRequest, bool]:
|
||||
# Type automatically inferred as DomainCheckRequest from parameter annotation
|
||||
if request.domain in self.approved_domains:
|
||||
return RequestResponse.handled(True)
|
||||
return RequestResponse.forward()
|
||||
|
||||
@intercepts_request(from_workflow="email_validator")
|
||||
async def handle_specific(
|
||||
self, request: EmailRequest, ctx: WorkflowContext[Any]
|
||||
) -> RequestResponse[EmailRequest, str]:
|
||||
# Only intercepts EmailRequest from the "email_validator" workflow
|
||||
return RequestResponse.handled("handled by parent")
|
||||
"""
|
||||
|
||||
def decorator(func: Callable[..., Any]) -> Callable[..., Any]:
|
||||
# Extract request type from method signature
|
||||
sig = inspect.signature(func)
|
||||
params = list(sig.parameters.values())
|
||||
|
||||
if len(params) < 2:
|
||||
raise ValueError(f"Interceptor method '{func.__name__}' must have at least 2 parameters (self, request)")
|
||||
|
||||
request_param = params[1] # Second parameter (after self)
|
||||
request_type = request_param.annotation
|
||||
|
||||
if request_type is inspect.Parameter.empty:
|
||||
raise ValueError(f"Interceptor method '{func.__name__}' request parameter must have a type annotation")
|
||||
|
||||
# Runtime validation that it's a RequestInfoMessage subclass
|
||||
if isinstance(request_type, type):
|
||||
# We need to check if RequestInfoMessage is defined yet, since this runs at import time
|
||||
try:
|
||||
# Try to get RequestInfoMessage from the current module's globals
|
||||
request_info_message_class = None
|
||||
func_module = inspect.getmodule(func)
|
||||
if func_module and hasattr(func_module, "RequestInfoMessage"):
|
||||
request_info_message_class = func_module.RequestInfoMessage
|
||||
else:
|
||||
# Look in the current module (where this decorator is defined)
|
||||
import sys
|
||||
|
||||
current_module = sys.modules[__name__]
|
||||
if hasattr(current_module, "RequestInfoMessage"):
|
||||
request_info_message_class = current_module.RequestInfoMessage
|
||||
|
||||
if request_info_message_class and not issubclass(request_type, request_info_message_class):
|
||||
raise TypeError(
|
||||
f"Interceptor method '{func.__name__}' can only handle RequestInfoMessage subclasses, "
|
||||
f"got {request_type}. Make sure your request type inherits from RequestInfoMessage."
|
||||
)
|
||||
except (AttributeError, NameError):
|
||||
# RequestInfoMessage might not be defined yet at import time, skip validation
|
||||
# This will be caught later when the interceptor is actually called
|
||||
pass
|
||||
|
||||
@functools.wraps(func)
|
||||
async def wrapper(self: Any, request: Any, ctx: WorkflowContext[Any]) -> Any:
|
||||
return await func(self, request, ctx)
|
||||
|
||||
# Add metadata for discovery - store the inferred type
|
||||
wrapper._intercepts_request = request_type # type: ignore
|
||||
wrapper._from_workflow = from_workflow # type: ignore
|
||||
wrapper._intercept_condition = condition # type: ignore
|
||||
|
||||
return wrapper
|
||||
|
||||
# If func is provided, we're being called without parentheses: @intercepts_request
|
||||
if func is not None:
|
||||
return decorator(func)
|
||||
|
||||
# Otherwise, we're being called with parentheses: @intercepts_request(from_workflow="...")
|
||||
return decorator
|
||||
|
||||
|
||||
# endregion: Intercepts Request Decorator
|
||||
|
||||
# region Agent Executor
|
||||
|
||||
|
||||
@@ -354,7 +692,11 @@ class RequestInfoMessage:
|
||||
the request/response pattern explicit.
|
||||
"""
|
||||
|
||||
request_id: str = str(uuid.uuid4())
|
||||
request_id: str = field(default_factory=lambda: str(uuid.uuid4()))
|
||||
|
||||
|
||||
# Note: SubWorkflowRequestInfo, SubWorkflowResponse, and RequestResponse
|
||||
# have been moved before intercepts_request decorator
|
||||
|
||||
|
||||
class RequestInfoExecutor(Executor):
|
||||
@@ -365,13 +707,19 @@ class RequestInfoExecutor(Executor):
|
||||
a response is provided externally, it emits the response as a message.
|
||||
"""
|
||||
|
||||
# Well-known ID for the request info executor
|
||||
EXECUTOR_ID: ClassVar[str] = "request_info"
|
||||
def __init__(self, id: str | None = None):
|
||||
"""Initialize the RequestInfoExecutor with an optional custom ID.
|
||||
|
||||
def __init__(self) -> None:
|
||||
"""Initialize the RequestInfoExecutor with its well-known ID."""
|
||||
super().__init__(id=self.EXECUTOR_ID)
|
||||
Args:
|
||||
id: Optional custom ID for this RequestInfoExecutor. If not provided,
|
||||
a unique ID will be generated.
|
||||
"""
|
||||
import uuid
|
||||
|
||||
executor_id = id or f"request_info_{uuid.uuid4().hex[:8]}"
|
||||
super().__init__(id=executor_id)
|
||||
self._request_events: dict[str, RequestInfoEvent] = {}
|
||||
self._sub_workflow_contexts: dict[str, dict[str, str]] = {}
|
||||
|
||||
@handler
|
||||
async def run(self, message: RequestInfoMessage, ctx: WorkflowContext[None]) -> None:
|
||||
@@ -387,6 +735,36 @@ class RequestInfoExecutor(Executor):
|
||||
self._request_events[message.request_id] = event
|
||||
await ctx.add_event(event)
|
||||
|
||||
@handler
|
||||
async def handle_sub_workflow_request(
|
||||
self,
|
||||
message: SubWorkflowRequestInfo,
|
||||
ctx: WorkflowContext[None],
|
||||
) -> None:
|
||||
"""Handle forwarded sub-workflow request.
|
||||
|
||||
This method handles requests that were forwarded from parent workflows
|
||||
because they couldn't be handled locally.
|
||||
"""
|
||||
# When called directly from runner, we need to use the sub_workflow_id as the source
|
||||
source_executor_id = message.sub_workflow_id
|
||||
|
||||
# Store context for routing response back
|
||||
self._sub_workflow_contexts[message.request_id] = {
|
||||
"sub_workflow_id": message.sub_workflow_id,
|
||||
"parent_executor_id": source_executor_id,
|
||||
}
|
||||
|
||||
# Create event for external handling - preserve the SubWorkflowRequestInfo wrapper
|
||||
event = RequestInfoEvent(
|
||||
request_id=message.request_id, # Use original request ID
|
||||
source_executor_id=source_executor_id,
|
||||
request_type=type(message.data), # SubWorkflowRequestInfo type
|
||||
request_data=message.data, # The full SubWorkflowRequestInfo
|
||||
)
|
||||
self._request_events[message.request_id] = event
|
||||
await ctx.add_event(event)
|
||||
|
||||
async def handle_response(
|
||||
self,
|
||||
response_data: Any,
|
||||
@@ -404,7 +782,190 @@ class RequestInfoExecutor(Executor):
|
||||
raise ValueError(f"No request found with ID: {request_id}")
|
||||
|
||||
event = self._request_events.pop(request_id)
|
||||
await ctx.send_message(response_data, target_id=event.source_executor_id)
|
||||
|
||||
# Check if this was a forwarded sub-workflow request
|
||||
if request_id in self._sub_workflow_contexts:
|
||||
context = self._sub_workflow_contexts.pop(request_id)
|
||||
|
||||
# Send back to sub-workflow that made the original request
|
||||
await ctx.send_message(
|
||||
SubWorkflowResponse(
|
||||
request_id=request_id,
|
||||
data=response_data,
|
||||
),
|
||||
target_id=context["sub_workflow_id"],
|
||||
)
|
||||
else:
|
||||
# Regular response - send directly back to source
|
||||
# Create a correlated response that includes both the response data and original request
|
||||
correlated_response = RequestResponse.handled(response_data)
|
||||
correlated_response = RequestResponse._with_correlation(correlated_response, event.data, request_id)
|
||||
|
||||
await ctx.send_message(correlated_response, target_id=event.source_executor_id)
|
||||
|
||||
|
||||
# endregion: Request Info Executor
|
||||
|
||||
# region Workflow Executor
|
||||
|
||||
|
||||
class WorkflowExecutor(Executor):
|
||||
"""An executor that runs another workflow as its execution logic.
|
||||
|
||||
This executor wraps a workflow to make it behave as an executor, enabling
|
||||
hierarchical workflow composition. Sub-workflows can send requests that
|
||||
are intercepted by parent workflows.
|
||||
"""
|
||||
|
||||
def __init__(self, workflow: "Workflow", id: str | None = None):
|
||||
"""Initialize the WorkflowExecutor.
|
||||
|
||||
Args:
|
||||
workflow: The workflow to execute as a sub-workflow.
|
||||
id: Optional unique identifier for this executor.
|
||||
"""
|
||||
super().__init__(id)
|
||||
self._workflow = workflow
|
||||
# Track pending external responses by request_id
|
||||
self._pending_responses: dict[str, Any] = {} # request_id -> response_data
|
||||
# Track workflow state for proper resumption - support multiple concurrent requests
|
||||
self._pending_requests: dict[str, Any] = {} # request_id -> original request data
|
||||
self._active_executions: int = 0 # Count of active sub-workflow executions
|
||||
# Response accumulation for multiple concurrent responses
|
||||
self._collected_responses: dict[str, Any] = {} # Accumulate responses
|
||||
self._expected_response_count: int = 0 # Track how many responses we're waiting for
|
||||
|
||||
@handler # No output_types - can send any completion data type
|
||||
async def process_workflow(self, input_data: object, ctx: WorkflowContext[Any]) -> None:
|
||||
"""Execute the sub-workflow with raw input data.
|
||||
|
||||
This handler starts a new sub-workflow execution. When the sub-workflow
|
||||
needs external information, it pauses and sends a request to the parent.
|
||||
|
||||
Args:
|
||||
input_data: The input data to send to the sub-workflow.
|
||||
ctx: The workflow context from the parent.
|
||||
"""
|
||||
# Skip SubWorkflowResponse and SubWorkflowRequestInfo - they have specific handlers
|
||||
if isinstance(input_data, (SubWorkflowResponse, SubWorkflowRequestInfo)):
|
||||
return
|
||||
|
||||
from ._events import RequestInfoEvent, WorkflowCompletedEvent
|
||||
|
||||
# Track this execution
|
||||
self._active_executions += 1
|
||||
|
||||
try:
|
||||
# Run the sub-workflow and collect all events
|
||||
events = [event async for event in self._workflow.run_streaming(input_data)]
|
||||
|
||||
# Count requests and initialize response tracking
|
||||
request_count = 0
|
||||
for event in events:
|
||||
if isinstance(event, RequestInfoEvent):
|
||||
request_count += 1
|
||||
|
||||
# Initialize response accumulation for this execution
|
||||
# For sequential workflows (like step_08), expect only current requests
|
||||
# For parallel workflows (like step_09), expect all requests at once
|
||||
self._expected_response_count = request_count
|
||||
self._collected_responses = {}
|
||||
|
||||
# If no requests in initial run, handle completion immediately
|
||||
if request_count == 0:
|
||||
self._expected_response_count = 0
|
||||
|
||||
# Process events to check for completion or requests
|
||||
for event in events:
|
||||
if isinstance(event, WorkflowCompletedEvent):
|
||||
# Sub-workflow completed normally - send result to parent
|
||||
await ctx.send_message(event.data)
|
||||
self._active_executions -= 1
|
||||
return # Exit after completion
|
||||
|
||||
if isinstance(event, RequestInfoEvent):
|
||||
# Sub-workflow needs external information
|
||||
# Track the pending request
|
||||
self._pending_requests[event.request_id] = event.data
|
||||
|
||||
# Wrap request with routing context and send to parent
|
||||
wrapped_request = SubWorkflowRequestInfo(
|
||||
request_id=event.request_id,
|
||||
sub_workflow_id=self.id,
|
||||
data=event.data,
|
||||
)
|
||||
|
||||
await ctx.send_message(wrapped_request)
|
||||
# Continue processing remaining events (no return)
|
||||
|
||||
except Exception as e:
|
||||
from ._events import ExecutorEvent
|
||||
|
||||
# Sub-workflow failed - create error event
|
||||
error_event = ExecutorEvent(executor_id=self.id, data={"error": str(e), "type": "sub_workflow_error"})
|
||||
await ctx.add_event(error_event)
|
||||
self._active_executions -= 1
|
||||
raise
|
||||
|
||||
@handler
|
||||
async def handle_response(
|
||||
self,
|
||||
response: SubWorkflowResponse,
|
||||
ctx: WorkflowContext[Any],
|
||||
) -> None:
|
||||
"""Handle response from parent for a forwarded request.
|
||||
|
||||
This handler accumulates responses and only resumes the sub-workflow
|
||||
when all expected responses have been received.
|
||||
|
||||
Args:
|
||||
response: The response to a previous request.
|
||||
ctx: The workflow context.
|
||||
"""
|
||||
# Check if we have this pending request
|
||||
pending_requests = getattr(self, "_pending_requests", {})
|
||||
if response.request_id not in pending_requests:
|
||||
return
|
||||
|
||||
# Remove the request from pending list
|
||||
pending_requests.pop(response.request_id, None)
|
||||
|
||||
# Accumulate the response
|
||||
self._collected_responses[response.request_id] = response.data
|
||||
|
||||
# Check if we have all expected responses for current batch
|
||||
if len(self._collected_responses) >= self._expected_response_count:
|
||||
from ._events import RequestInfoEvent, WorkflowCompletedEvent
|
||||
|
||||
# Send all collected responses to the sub-workflow
|
||||
responses_to_send = dict(self._collected_responses)
|
||||
self._collected_responses.clear() # Clear for next batch
|
||||
|
||||
result_events = [event async for event in self._workflow.send_responses_streaming(responses_to_send)]
|
||||
|
||||
# Process the result events
|
||||
new_request_count = 0
|
||||
for event in result_events:
|
||||
if isinstance(event, WorkflowCompletedEvent):
|
||||
# Sub-workflow completed - send result to parent
|
||||
await ctx.send_message(event.data)
|
||||
self._active_executions -= 1
|
||||
return
|
||||
if isinstance(event, RequestInfoEvent):
|
||||
# Sub-workflow sent more requests - prepare for next batch
|
||||
new_request_count += 1
|
||||
self._pending_requests[event.request_id] = event.data
|
||||
|
||||
# Send the new request to parent
|
||||
wrapped_request = SubWorkflowRequestInfo(
|
||||
request_id=event.request_id,
|
||||
sub_workflow_id=self.id,
|
||||
data=event.data,
|
||||
)
|
||||
await ctx.send_message(wrapped_request)
|
||||
|
||||
# Update expected count for next batch of requests
|
||||
self._expected_response_count = new_request_count
|
||||
|
||||
|
||||
# endregion: Workflow Executor
|
||||
|
||||
@@ -4,7 +4,10 @@ import asyncio
|
||||
import logging
|
||||
from collections import defaultdict
|
||||
from collections.abc import AsyncIterable, Sequence
|
||||
from typing import Any
|
||||
from typing import TYPE_CHECKING, Any
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._executor import RequestInfoExecutor
|
||||
|
||||
from ._edge import EdgeGroup
|
||||
from ._edge_runner import EdgeRunner, create_edge_runner
|
||||
@@ -12,6 +15,8 @@ from ._events import WorkflowEvent
|
||||
from ._executor import Executor
|
||||
from ._runner_context import Message, RunnerContext
|
||||
from ._shared_state import SharedState
|
||||
from ._typing_utils import is_instance_of
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -132,12 +137,107 @@ class Runner:
|
||||
async def _deliver_messages(source_executor_id: str, messages: list[Message]) -> None:
|
||||
"""Outer loop to concurrently deliver messages from all sources to their targets."""
|
||||
|
||||
# Special handling for SubWorkflowRequestInfo messages
|
||||
async def _deliver_sub_workflow_requests(messages: list[Message]) -> None:
|
||||
from ._executor import SubWorkflowRequestInfo
|
||||
|
||||
# Handle SubWorkflowRequestInfo messages - only process those not already targeted
|
||||
sub_workflow_messages = []
|
||||
for msg in messages:
|
||||
# Skip messages sent directly to RequestInfoExecutor - they are already forwarded
|
||||
if self._is_message_to_request_info_executor(msg):
|
||||
continue
|
||||
|
||||
if isinstance(msg.data, SubWorkflowRequestInfo):
|
||||
sub_workflow_messages.append(msg)
|
||||
|
||||
for message in sub_workflow_messages:
|
||||
sub_request = message.data
|
||||
|
||||
# Find executor that can intercept the wrapped type
|
||||
interceptor_found = False
|
||||
for executor in self._executors.values():
|
||||
if hasattr(executor, "_request_interceptors") and executor.id != message.source_id:
|
||||
# Check if any registered interceptor can handle this request type
|
||||
for registered_type in executor._request_interceptors:
|
||||
# Check type matching - handle both type and string cases
|
||||
matched = False
|
||||
if (
|
||||
isinstance(registered_type, type)
|
||||
and is_instance_of(sub_request.data, registered_type)
|
||||
) or (
|
||||
isinstance(registered_type, str)
|
||||
and hasattr(sub_request.data, "__class__")
|
||||
and sub_request.data.__class__.__name__ == registered_type
|
||||
):
|
||||
matched = True
|
||||
|
||||
if matched:
|
||||
# Send directly to the intercepting executor
|
||||
logger.info(
|
||||
f"Sending sub-workflow request of type '{sub_request.data.__class__.__name__}' "
|
||||
f"from sub-workflow '{sub_request.sub_workflow_id}' "
|
||||
f"to executor '{executor.id}' for interception."
|
||||
)
|
||||
await executor.execute(sub_request, self._ctx) # type: ignore[arg-type]
|
||||
interceptor_found = True
|
||||
break
|
||||
if interceptor_found:
|
||||
break
|
||||
|
||||
if not interceptor_found:
|
||||
# No interceptor found - send directly to RequestInfoExecutor if available.
|
||||
|
||||
# Find the RequestInfoExecutor instance
|
||||
request_info_executor = self._find_request_info_executor()
|
||||
|
||||
if request_info_executor:
|
||||
workflow_ctx: WorkflowContext[None] = WorkflowContext(
|
||||
request_info_executor.id,
|
||||
["Runner"],
|
||||
self._shared_state,
|
||||
self._ctx,
|
||||
)
|
||||
logger.info(
|
||||
f"Sending sub-workflow request of type '{sub_request.data.__class__.__name__}' "
|
||||
f"from sub-workflow '{sub_request.sub_workflow_id}' to RequestInfoExecutor "
|
||||
f"'{request_info_executor.id}'"
|
||||
)
|
||||
await request_info_executor.execute(sub_request, workflow_ctx)
|
||||
else:
|
||||
logger.warning(
|
||||
f"Sub-workflow request of type '{sub_request.data.__class__.__name__}' "
|
||||
f"from sub-workflow '{sub_request.sub_workflow_id}' could not be handled: "
|
||||
f"no RequestInfoExecutor found in the workflow. Add a RequestInfoExecutor "
|
||||
f"to handle external requests or add an @intercepts_request handler."
|
||||
)
|
||||
|
||||
async def _deliver_message_inner(edge_runner: EdgeRunner, message: Message) -> bool:
|
||||
"""Inner loop to deliver a single message through an edge runner."""
|
||||
return await edge_runner.send_message(message, self._shared_state, self._ctx)
|
||||
|
||||
# Handle SubWorkflowRequestInfo messages specially
|
||||
await _deliver_sub_workflow_requests(messages)
|
||||
|
||||
# Filter out SubWorkflowRequestInfo messages from normal edge routing
|
||||
# since they were handled specially
|
||||
from ._executor import SubWorkflowRequestInfo
|
||||
|
||||
non_sub_workflow_messages = []
|
||||
for msg in messages:
|
||||
# Keep messages sent directly to RequestInfoExecutor (forwarded messages)
|
||||
if self._is_message_to_request_info_executor(msg):
|
||||
non_sub_workflow_messages.append(msg)
|
||||
continue
|
||||
|
||||
# Skip SubWorkflowRequestInfo messages (handled by special routing)
|
||||
if isinstance(msg.data, SubWorkflowRequestInfo):
|
||||
continue
|
||||
|
||||
non_sub_workflow_messages.append(msg)
|
||||
|
||||
associated_edge_runners = self._edge_runner_map.get(source_executor_id, [])
|
||||
for message in messages:
|
||||
for message in non_sub_workflow_messages:
|
||||
# Deliver a message through all edge runners associated with the source executor concurrently.
|
||||
tasks = [_deliver_message_inner(edge_runner, message) for edge_runner in associated_edge_runners]
|
||||
results = await asyncio.gather(*tasks)
|
||||
@@ -282,3 +382,36 @@ class Runner:
|
||||
parsed[source_executor_id].append(runner)
|
||||
|
||||
return parsed
|
||||
|
||||
def _find_request_info_executor(self) -> "RequestInfoExecutor | None":
|
||||
"""Find the RequestInfoExecutor instance in this workflow.
|
||||
|
||||
Returns:
|
||||
The RequestInfoExecutor instance if found, None otherwise.
|
||||
"""
|
||||
from ._executor import RequestInfoExecutor
|
||||
|
||||
for executor in self._executors.values():
|
||||
if isinstance(executor, RequestInfoExecutor):
|
||||
return executor
|
||||
return None
|
||||
|
||||
def _is_message_to_request_info_executor(self, msg: "Message") -> bool:
|
||||
"""Check if message targets any RequestInfoExecutor in this workflow.
|
||||
|
||||
Args:
|
||||
msg: The message to check.
|
||||
|
||||
Returns:
|
||||
True if the message targets a RequestInfoExecutor, False otherwise.
|
||||
"""
|
||||
from ._executor import RequestInfoExecutor
|
||||
|
||||
if not msg.target_id:
|
||||
return False
|
||||
|
||||
# Check all executors to see if target_id matches a RequestInfoExecutor
|
||||
for executor in self._executors.values():
|
||||
if executor.id == msg.target_id and isinstance(executor, RequestInfoExecutor):
|
||||
return True
|
||||
return False
|
||||
|
||||
@@ -50,5 +50,29 @@ def is_instance_of(data: Any, target_type: type) -> bool:
|
||||
for key, value in data.items() # type: ignore
|
||||
)
|
||||
|
||||
# Case 6: target_type is RequestResponse[T, U] - validate generic parameters
|
||||
if origin and hasattr(origin, "__name__") and origin.__name__ == "RequestResponse":
|
||||
if not isinstance(data, origin):
|
||||
return False
|
||||
# Validate generic parameters for RequestResponse[TRequest, TResponse]
|
||||
if len(args) >= 2:
|
||||
request_type, response_type = args[0], args[1]
|
||||
# Check if the original_request matches TRequest and data matches TResponse
|
||||
if (
|
||||
hasattr(data, "original_request")
|
||||
and data.original_request is not None
|
||||
and not is_instance_of(data.original_request, request_type)
|
||||
):
|
||||
return False
|
||||
if hasattr(data, "data") and data.data is not None and not is_instance_of(data.data, response_type):
|
||||
return False
|
||||
return True
|
||||
|
||||
# Case 7: Other custom generic classes - check origin type only
|
||||
# For generic classes, we check if data is an instance of the origin type
|
||||
# We don't validate the generic parameters at runtime since that's handled by type system
|
||||
if origin and hasattr(origin, "__name__"):
|
||||
return isinstance(data, origin)
|
||||
|
||||
# Fallback: if we reach here, we assume data is an instance of the target_type
|
||||
return isinstance(data, target_type)
|
||||
|
||||
@@ -22,6 +22,7 @@ class ValidationTypeEnum(Enum):
|
||||
TYPE_COMPATIBILITY = "TYPE_COMPATIBILITY"
|
||||
GRAPH_CONNECTIVITY = "GRAPH_CONNECTIVITY"
|
||||
HANDLER_OUTPUT_ANNOTATION = "HANDLER_OUTPUT_ANNOTATION"
|
||||
INTERCEPTOR_CONFLICT = "INTERCEPTOR_CONFLICT"
|
||||
|
||||
|
||||
class WorkflowValidationError(Exception):
|
||||
@@ -94,6 +95,13 @@ class HandlerOutputAnnotationError(WorkflowValidationError):
|
||||
self.handler_name = handler_name
|
||||
|
||||
|
||||
class InterceptorConflictError(WorkflowValidationError):
|
||||
"""Exception raised when multiple executors intercept the same request type from the same sub-workflow."""
|
||||
|
||||
def __init__(self, message: str):
|
||||
super().__init__(message, validation_type=ValidationTypeEnum.INTERCEPTOR_CONFLICT)
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
@@ -129,6 +137,14 @@ class WorkflowGraphValidator:
|
||||
self._edges = [edge for group in edge_groups for edge in group.edges]
|
||||
self._edge_groups = edge_groups
|
||||
|
||||
# If only the start executor exists, add it to the executor map
|
||||
# Handle the special case where the workflow consists of only a single executor and no edges.
|
||||
# In this scenario, the executor map will be empty because there are no edge groups to reference executors.
|
||||
# Adding the start executor to the map ensures that single-executor workflows (without any edges) are supported,
|
||||
# allowing validation and execution to proceed for workflows that do not require inter-executor communication.
|
||||
if not self._executors and start_executor and isinstance(start_executor, Executor):
|
||||
self._executors[start_executor.id] = start_executor
|
||||
|
||||
# Validate that start_executor exists in the graph
|
||||
# It should because we check for it in the WorkflowBuilder
|
||||
# but we do it here for completeness.
|
||||
@@ -145,6 +161,7 @@ class WorkflowGraphValidator:
|
||||
self._validate_handler_ambiguity()
|
||||
self._validate_dead_ends()
|
||||
self._validate_cycles()
|
||||
self._validate_interceptor_uniqueness()
|
||||
|
||||
def _validate_handler_output_annotations(self) -> None:
|
||||
"""Validate that each handler's ctx parameter is annotated with WorkflowContext[T].
|
||||
@@ -306,13 +323,21 @@ class WorkflowGraphValidator:
|
||||
# If either executor has no type information, log warning and skip validation
|
||||
# This allows for dynamic typing scenarios but warns about reduced validation coverage
|
||||
if not source_output_types or not target_input_types:
|
||||
if not source_output_types:
|
||||
# Suppress warnings for built-in workflow components where dynamic typing is expected
|
||||
try:
|
||||
from ._executor import RequestInfoExecutor, WorkflowExecutor # local import to avoid cycles
|
||||
|
||||
builtin_types = (RequestInfoExecutor, WorkflowExecutor)
|
||||
except Exception:
|
||||
builtin_types = tuple() # type: ignore[assignment]
|
||||
|
||||
if not source_output_types and not isinstance(source_executor, builtin_types):
|
||||
logger.warning(
|
||||
f"Executor '{source_executor.id}' has no output type annotations. "
|
||||
f"Type compatibility validation will be skipped for edges from this executor. "
|
||||
f"Consider adding WorkflowContext[T] generics in handlers for better validation."
|
||||
)
|
||||
if not target_input_types:
|
||||
if not target_input_types and not isinstance(target_executor, builtin_types):
|
||||
logger.warning(
|
||||
f"Executor '{target_executor.id}' has no input type annotations. "
|
||||
f"Type compatibility validation will be skipped for edges to this executor. "
|
||||
@@ -376,6 +401,13 @@ class WorkflowGraphValidator:
|
||||
# Skip attributes that may not be accessible
|
||||
continue
|
||||
|
||||
# Also include intercepted request types as potential outputs
|
||||
# since @intercepts_request methods can forward requests
|
||||
if hasattr(executor, "_request_interceptors"):
|
||||
for request_type in executor._request_interceptors:
|
||||
if isinstance(request_type, type):
|
||||
output_types.append(request_type)
|
||||
|
||||
return output_types
|
||||
|
||||
def _get_executor_input_types(self, executor: Executor) -> list[type[Any]]:
|
||||
@@ -571,6 +603,54 @@ class WorkflowGraphValidator:
|
||||
"Ensure proper termination conditions exist to prevent infinite loops."
|
||||
)
|
||||
|
||||
def _validate_interceptor_uniqueness(self) -> None:
|
||||
"""Validate that only one executor intercepts a given request type from a specific sub-workflow.
|
||||
|
||||
This prevents non-deterministic behavior where multiple executors could intercept
|
||||
the same request type from the same sub-workflow.
|
||||
"""
|
||||
from ._executor import WorkflowExecutor
|
||||
|
||||
# Find all WorkflowExecutor instances in the workflow
|
||||
workflow_executors: dict[str, WorkflowExecutor] = {}
|
||||
for executor_id, executor in self._executors.items():
|
||||
if isinstance(executor, WorkflowExecutor):
|
||||
workflow_executors[executor_id] = executor
|
||||
|
||||
# For each WorkflowExecutor, check which executors can intercept its requests
|
||||
for workflow_id, _workflow_executor in workflow_executors.items():
|
||||
# Map of request_type -> list of intercepting executor IDs
|
||||
interceptors_by_type: dict[type | str, list[str]] = {}
|
||||
|
||||
# Find all executors that have edges from this WorkflowExecutor
|
||||
# These are potential interceptors
|
||||
for edge in self._edges:
|
||||
if edge.source_id == workflow_id:
|
||||
target_executor = self._executors.get(edge.target_id)
|
||||
if target_executor and hasattr(target_executor, "_request_interceptors"):
|
||||
# Check what request types this executor intercepts
|
||||
for request_type, interceptor_list in target_executor._request_interceptors.items():
|
||||
# Check if any interceptor is scoped to this workflow or unscoped
|
||||
for interceptor_info in interceptor_list:
|
||||
from_workflow = interceptor_info.get("from_workflow")
|
||||
# If unscoped or specifically scoped to this workflow
|
||||
if from_workflow is None or from_workflow == workflow_id:
|
||||
if request_type not in interceptors_by_type:
|
||||
interceptors_by_type[request_type] = []
|
||||
interceptors_by_type[request_type].append(edge.target_id)
|
||||
|
||||
# Check for duplicates
|
||||
for request_type, executor_ids in interceptors_by_type.items():
|
||||
unique_executors = list(set(executor_ids)) # Remove duplicates from same executor
|
||||
if len(unique_executors) > 1:
|
||||
type_name = request_type.__name__ if isinstance(request_type, type) else str(request_type)
|
||||
raise InterceptorConflictError(
|
||||
f"Multiple executors intercept the same request type '{type_name}' "
|
||||
f"from sub-workflow '{workflow_id}': {', '.join(unique_executors)}. "
|
||||
f"Only one executor should intercept a given request type from a specific sub-workflow "
|
||||
f"to ensure deterministic behavior."
|
||||
)
|
||||
|
||||
# endregion
|
||||
|
||||
# region Type Compatibility Utilities
|
||||
|
||||
@@ -219,8 +219,8 @@ class Workflow(AFBaseModel):
|
||||
raise RuntimeError(f"Failed to restore from checkpoint: {checkpoint_id}")
|
||||
|
||||
if responses:
|
||||
request_info_executor = self._get_executor_by_id(RequestInfoExecutor.EXECUTOR_ID)
|
||||
if isinstance(request_info_executor, RequestInfoExecutor):
|
||||
request_info_executor = self._find_request_info_executor()
|
||||
if request_info_executor:
|
||||
for request_id, response_data in responses.items():
|
||||
await request_info_executor.handle_response(
|
||||
response_data,
|
||||
@@ -246,9 +246,9 @@ class Workflow(AFBaseModel):
|
||||
Yields:
|
||||
WorkflowEvent: The events generated during the workflow execution after sending the responses.
|
||||
"""
|
||||
request_info_executor = self._get_executor_by_id(RequestInfoExecutor.EXECUTOR_ID)
|
||||
if not isinstance(request_info_executor, RequestInfoExecutor):
|
||||
raise ValueError(f"Executor with ID {RequestInfoExecutor.EXECUTOR_ID} is not a RequestInfoExecutor.")
|
||||
request_info_executor = self._find_request_info_executor()
|
||||
if not request_info_executor:
|
||||
raise ValueError("No RequestInfoExecutor found in workflow.")
|
||||
|
||||
async def _handle_response(response: Any, request_id: str) -> None:
|
||||
"""Handle the response from the RequestInfoExecutor."""
|
||||
@@ -332,6 +332,19 @@ class Workflow(AFBaseModel):
|
||||
raise ValueError(f"Executor with ID {executor_id} not found.")
|
||||
return self.executors[executor_id]
|
||||
|
||||
def _find_request_info_executor(self) -> "RequestInfoExecutor | None":
|
||||
"""Find the RequestInfoExecutor instance in this workflow.
|
||||
|
||||
Returns:
|
||||
The RequestInfoExecutor instance if found, None otherwise.
|
||||
"""
|
||||
from ._executor import RequestInfoExecutor
|
||||
|
||||
for executor in self.executors.values():
|
||||
if isinstance(executor, RequestInfoExecutor):
|
||||
return executor
|
||||
return None
|
||||
|
||||
async def _restore_from_external_checkpoint(
|
||||
self, checkpoint_id: str, checkpoint_storage: CheckpointStorage
|
||||
) -> bool:
|
||||
|
||||
@@ -0,0 +1,111 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework_workflow import (
|
||||
Executor,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
WorkflowExecutor,
|
||||
handler,
|
||||
)
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleRequest:
|
||||
"""Simple request for testing."""
|
||||
|
||||
text: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class SimpleResponse:
|
||||
"""Simple response for testing."""
|
||||
|
||||
result: str
|
||||
|
||||
|
||||
class SimpleSubExecutor(Executor):
|
||||
"""Simple executor for sub-workflow."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="simple_sub")
|
||||
|
||||
@handler
|
||||
async def process(self, request: SimpleRequest, ctx: WorkflowContext[None]) -> None:
|
||||
"""Process a simple request."""
|
||||
from agent_framework_workflow import WorkflowCompletedEvent
|
||||
|
||||
# Just echo back with prefix and complete
|
||||
response = SimpleResponse(result=f"processed: {request.text}")
|
||||
await ctx.add_event(WorkflowCompletedEvent(data=response))
|
||||
|
||||
|
||||
class SimpleParent(Executor):
|
||||
"""Simple parent executor."""
|
||||
|
||||
result: SimpleResponse | None = None
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="simple_parent")
|
||||
|
||||
@handler
|
||||
async def start(self, text: str, ctx: WorkflowContext[SimpleRequest]) -> None:
|
||||
"""Start the process."""
|
||||
request = SimpleRequest(text=text)
|
||||
await ctx.send_message(request, target_id="sub_workflow")
|
||||
|
||||
@handler
|
||||
async def collect(self, response: SimpleResponse, ctx: WorkflowContext[None]) -> None:
|
||||
"""Collect the result."""
|
||||
self.result = response
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_simple_sub_workflow():
|
||||
"""Test the simplest possible sub-workflow."""
|
||||
# Create sub-workflow with dummy executor to satisfy validation
|
||||
sub_executor = SimpleSubExecutor()
|
||||
|
||||
class DummyExecutor(Executor):
|
||||
def __init__(self):
|
||||
super().__init__(id="dummy")
|
||||
|
||||
@handler
|
||||
async def process(self, message: object, ctx: WorkflowContext[None]) -> None:
|
||||
pass # Do nothing
|
||||
|
||||
dummy = DummyExecutor()
|
||||
sub_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(sub_executor)
|
||||
.add_edge(sub_executor, dummy) # Add edge to satisfy validation
|
||||
.build()
|
||||
)
|
||||
|
||||
# Create parent workflow
|
||||
parent = SimpleParent()
|
||||
workflow_executor = WorkflowExecutor(sub_workflow, id="sub_workflow")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(parent)
|
||||
.add_edge(parent, workflow_executor)
|
||||
.add_edge(workflow_executor, parent)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Run the workflow
|
||||
await main_workflow.run("hello world")
|
||||
|
||||
# Check result
|
||||
assert parent.result is not None
|
||||
assert parent.result.result == "processed: hello world"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run the simple test
|
||||
asyncio.run(test_simple_sub_workflow())
|
||||
@@ -0,0 +1,420 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from pydantic import Field
|
||||
|
||||
from agent_framework_workflow import (
|
||||
Executor,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
WorkflowExecutor,
|
||||
handler,
|
||||
intercepts_request,
|
||||
)
|
||||
|
||||
|
||||
# Test message types
|
||||
@dataclass
|
||||
class EmailValidationRequest:
|
||||
"""Request to validate an email address."""
|
||||
|
||||
email: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class DomainCheckRequest(RequestInfoMessage):
|
||||
"""Request to check if a domain is approved."""
|
||||
|
||||
domain: str = ""
|
||||
email: str = "" # Include original email for correlation
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of email validation."""
|
||||
|
||||
email: str
|
||||
is_valid: bool
|
||||
reason: str
|
||||
|
||||
|
||||
# Test executors
|
||||
class EmailValidator(Executor):
|
||||
"""Validates email addresses in a sub-workflow."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="email_validator")
|
||||
|
||||
@handler
|
||||
async def validate_request(
|
||||
self, request: EmailValidationRequest, ctx: WorkflowContext[RequestInfoMessage | ValidationResult]
|
||||
) -> None:
|
||||
"""Validate an email address."""
|
||||
# Extract domain and check if it's approved
|
||||
domain = request.email.split("@")[1] if "@" in request.email else ""
|
||||
|
||||
if not domain:
|
||||
result = ValidationResult(email=request.email, is_valid=False, reason="Invalid email format")
|
||||
await ctx.add_event(WorkflowCompletedEvent(data=result))
|
||||
return
|
||||
|
||||
# Request domain check from external source
|
||||
domain_check = DomainCheckRequest(domain=domain, email=request.email)
|
||||
await ctx.send_message(domain_check)
|
||||
|
||||
@handler
|
||||
async def handle_domain_response(
|
||||
self, response: RequestResponse[DomainCheckRequest, bool], ctx: WorkflowContext[ValidationResult]
|
||||
) -> None:
|
||||
"""Handle domain check response with correlation."""
|
||||
# Use the original email from the correlated response
|
||||
result = ValidationResult(
|
||||
email=response.original_request.email,
|
||||
is_valid=response.data or False,
|
||||
reason="Domain approved" if response.data else "Domain not approved",
|
||||
)
|
||||
await ctx.add_event(WorkflowCompletedEvent(data=result))
|
||||
|
||||
|
||||
class ParentOrchestrator(Executor):
|
||||
"""Parent workflow orchestrator with domain knowledge."""
|
||||
|
||||
approved_domains: set[str] = Field(default_factory=lambda: {"example.com", "test.org"})
|
||||
results: list[ValidationResult] = Field(default_factory=list)
|
||||
|
||||
def __init__(self, approved_domains: set[str] | None = None, **kwargs: Any):
|
||||
if approved_domains is not None:
|
||||
kwargs["approved_domains"] = approved_domains
|
||||
super().__init__(id="parent_orchestrator", **kwargs)
|
||||
|
||||
@handler
|
||||
async def start(self, emails: list[str], ctx: WorkflowContext[EmailValidationRequest]) -> None:
|
||||
"""Start processing emails."""
|
||||
for email in emails:
|
||||
request = EmailValidationRequest(email=email)
|
||||
await ctx.send_message(request, target_id="email_workflow")
|
||||
|
||||
@intercepts_request
|
||||
async def check_domain(
|
||||
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
|
||||
) -> RequestResponse[DomainCheckRequest, bool]:
|
||||
"""Intercept domain check requests from sub-workflows."""
|
||||
# Check if we know this domain
|
||||
if request.domain in self.approved_domains:
|
||||
return RequestResponse[DomainCheckRequest, bool].handled(True)
|
||||
|
||||
# We don't know this domain, forward to external
|
||||
return RequestResponse[DomainCheckRequest, bool].forward()
|
||||
|
||||
@handler
|
||||
async def collect_result(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None:
|
||||
"""Collect validation results."""
|
||||
self.results.append(result)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_basic_sub_workflow() -> None:
|
||||
"""Test basic sub-workflow execution without interception."""
|
||||
# Create sub-workflow
|
||||
email_validator = EmailValidator()
|
||||
email_request_info = RequestInfoExecutor(id="email_request_info")
|
||||
|
||||
validation_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(email_validator)
|
||||
.add_edge(email_validator, email_request_info)
|
||||
.add_edge(email_request_info, email_validator)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Create parent workflow without interception
|
||||
class SimpleParent(Executor):
|
||||
result: ValidationResult | None = Field(default=None)
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(id="simple_parent", **kwargs)
|
||||
|
||||
@handler
|
||||
async def start(self, email: str, ctx: WorkflowContext[EmailValidationRequest]) -> None:
|
||||
request = EmailValidationRequest(email=email)
|
||||
await ctx.send_message(request, target_id="email_workflow")
|
||||
|
||||
@handler
|
||||
async def collect(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None:
|
||||
self.result = result
|
||||
|
||||
parent = SimpleParent()
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow")
|
||||
main_request_info = RequestInfoExecutor(id="main_request_info")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(parent)
|
||||
.add_edge(parent, workflow_executor)
|
||||
.add_edge(workflow_executor, parent)
|
||||
.add_edge(workflow_executor, main_request_info)
|
||||
.add_edge(main_request_info, workflow_executor) # CRITICAL: For SubWorkflowResponse routing
|
||||
.build()
|
||||
)
|
||||
|
||||
# Run workflow with mocked external response
|
||||
result = await main_workflow.run("test@example.com")
|
||||
|
||||
# Get request event and respond
|
||||
request_events = result.get_request_info_events()
|
||||
assert len(request_events) == 1
|
||||
assert isinstance(request_events[0].data, DomainCheckRequest)
|
||||
assert request_events[0].data.domain == "example.com"
|
||||
|
||||
# Send response through the main workflow
|
||||
await main_workflow.send_responses({
|
||||
request_events[0].request_id: True # Domain is approved
|
||||
})
|
||||
|
||||
# Check result
|
||||
assert parent.result is not None
|
||||
assert parent.result.email == "test@example.com"
|
||||
assert parent.result.is_valid is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sub_workflow_with_interception():
|
||||
"""Test sub-workflow with parent interception of requests."""
|
||||
# Create sub-workflow
|
||||
email_validator = EmailValidator()
|
||||
email_request_info = RequestInfoExecutor(id="email_request_info")
|
||||
|
||||
validation_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(email_validator)
|
||||
.add_edge(email_validator, email_request_info)
|
||||
.add_edge(email_request_info, email_validator)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Create parent workflow with interception
|
||||
parent = ParentOrchestrator(approved_domains={"example.com", "internal.org"})
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow")
|
||||
parent_request_info = RequestInfoExecutor()
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(parent)
|
||||
.add_edge(parent, workflow_executor)
|
||||
.add_edge(workflow_executor, parent)
|
||||
.add_edge(parent, parent_request_info) # For forwarded requests
|
||||
.add_edge(parent_request_info, workflow_executor) # For SubWorkflowResponse routing
|
||||
.build()
|
||||
)
|
||||
|
||||
# Test 1: Email with known domain (intercepted)
|
||||
result = await main_workflow.run(["user@example.com"])
|
||||
|
||||
# Should complete without external requests
|
||||
request_events = result.get_request_info_events()
|
||||
assert len(request_events) == 0 # No external requests, handled internally
|
||||
|
||||
assert len(parent.results) == 1
|
||||
assert parent.results[0].email == "user@example.com"
|
||||
assert parent.results[0].is_valid is True
|
||||
assert parent.results[0].reason == "Domain approved"
|
||||
|
||||
# Test 2: Email with unknown domain (forwarded)
|
||||
parent.results.clear()
|
||||
result = await main_workflow.run(["user@unknown.com"])
|
||||
|
||||
# Should have external request
|
||||
request_events = result.get_request_info_events()
|
||||
assert len(request_events) == 1
|
||||
assert isinstance(request_events[0].data, DomainCheckRequest)
|
||||
assert request_events[0].data.domain == "unknown.com"
|
||||
|
||||
# Send external response
|
||||
await main_workflow.send_responses({
|
||||
request_events[0].request_id: False # Domain not approved
|
||||
})
|
||||
|
||||
assert len(parent.results) == 1
|
||||
assert parent.results[0].email == "user@unknown.com"
|
||||
assert parent.results[0].is_valid is False
|
||||
assert parent.results[0].reason == "Domain not approved"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_conditional_forwarding() -> None:
|
||||
"""Test conditional forwarding with RequestResponse.forward()."""
|
||||
|
||||
class ConditionalParent(Executor):
|
||||
"""Parent that conditionally handles requests."""
|
||||
|
||||
cache: dict[str, bool] = Field(default_factory=lambda: {"cached.com": True})
|
||||
result: ValidationResult | None = Field(default=None)
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(id="conditional_parent", **kwargs)
|
||||
|
||||
@handler
|
||||
async def start(self, email: str, ctx: WorkflowContext[EmailValidationRequest]) -> None:
|
||||
request = EmailValidationRequest(email=email)
|
||||
await ctx.send_message(request, target_id="email_workflow")
|
||||
|
||||
@intercepts_request
|
||||
async def check_domain(
|
||||
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
|
||||
) -> RequestResponse[DomainCheckRequest, bool]:
|
||||
"""Check cache first, then forward if not found."""
|
||||
if request.domain in self.cache:
|
||||
# Return cached result
|
||||
return RequestResponse[DomainCheckRequest, bool].handled(self.cache[request.domain])
|
||||
|
||||
# Not in cache, forward to external
|
||||
return RequestResponse[DomainCheckRequest, bool].forward()
|
||||
|
||||
@handler
|
||||
async def collect(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None:
|
||||
self.result = result
|
||||
|
||||
# Setup workflows
|
||||
email_validator = EmailValidator()
|
||||
request_info = RequestInfoExecutor()
|
||||
|
||||
validation_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(email_validator)
|
||||
.add_edge(email_validator, request_info)
|
||||
.add_edge(request_info, email_validator)
|
||||
.build()
|
||||
)
|
||||
|
||||
parent = ConditionalParent()
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, id="email_workflow")
|
||||
parent_request_info = RequestInfoExecutor()
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(parent)
|
||||
.add_edge(parent, workflow_executor)
|
||||
.add_edge(workflow_executor, parent)
|
||||
.add_edge(parent, parent_request_info)
|
||||
.add_edge(parent_request_info, workflow_executor) # For SubWorkflowResponse routing
|
||||
.build()
|
||||
)
|
||||
|
||||
# Test cached domain
|
||||
result = await main_workflow.run("user@cached.com")
|
||||
request_events = result.get_request_info_events()
|
||||
assert len(request_events) == 0 # Handled from cache
|
||||
assert parent.result is not None
|
||||
assert parent.result.is_valid is True
|
||||
|
||||
# Test uncached domain
|
||||
parent.result = None
|
||||
result = await main_workflow.run("user@new.com")
|
||||
request_events = result.get_request_info_events()
|
||||
assert len(request_events) == 1 # Forwarded to external
|
||||
|
||||
await main_workflow.send_responses({request_events[0].request_id: True})
|
||||
assert parent.result is not None
|
||||
assert parent.result.is_valid is True
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_workflow_scoped_interception() -> None:
|
||||
"""Test interception scoped to specific sub-workflows."""
|
||||
|
||||
class MultiWorkflowParent(Executor):
|
||||
"""Parent handling multiple sub-workflows."""
|
||||
|
||||
results: dict[str, ValidationResult] = Field(default_factory=dict)
|
||||
|
||||
def __init__(self, **kwargs: Any):
|
||||
super().__init__(id="multi_parent", **kwargs)
|
||||
|
||||
@handler
|
||||
async def start(self, data: dict[str, str], ctx: WorkflowContext[EmailValidationRequest]) -> None:
|
||||
# Send to different sub-workflows
|
||||
await ctx.send_message(EmailValidationRequest(email=data["email1"]), target_id="workflow_a")
|
||||
await ctx.send_message(EmailValidationRequest(email=data["email2"]), target_id="workflow_b")
|
||||
|
||||
@intercepts_request(from_workflow="workflow_a")
|
||||
async def check_domain_a(
|
||||
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
|
||||
) -> RequestResponse[DomainCheckRequest, bool]:
|
||||
"""Strict rules for workflow A."""
|
||||
if request.domain == "strict.com":
|
||||
return RequestResponse[DomainCheckRequest, bool].handled(True)
|
||||
return RequestResponse[DomainCheckRequest, bool].forward()
|
||||
|
||||
@intercepts_request(from_workflow="workflow_b")
|
||||
async def check_domain_b(
|
||||
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
|
||||
) -> RequestResponse[DomainCheckRequest, bool]:
|
||||
"""Lenient rules for workflow B."""
|
||||
if request.domain.endswith(".com"):
|
||||
return RequestResponse[DomainCheckRequest, bool].handled(True)
|
||||
return RequestResponse[DomainCheckRequest, bool].forward()
|
||||
|
||||
@handler
|
||||
async def collect(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None:
|
||||
self.results[result.email] = result
|
||||
|
||||
# Create two identical sub-workflows
|
||||
def create_validation_workflow():
|
||||
validator = EmailValidator()
|
||||
request_info = RequestInfoExecutor()
|
||||
return (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(validator)
|
||||
.add_edge(validator, request_info)
|
||||
.add_edge(request_info, validator)
|
||||
.build()
|
||||
)
|
||||
|
||||
workflow_a = create_validation_workflow()
|
||||
workflow_b = create_validation_workflow()
|
||||
|
||||
parent = MultiWorkflowParent()
|
||||
executor_a = WorkflowExecutor(workflow_a, id="workflow_a")
|
||||
executor_b = WorkflowExecutor(workflow_b, id="workflow_b")
|
||||
parent_request_info = RequestInfoExecutor()
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(parent)
|
||||
.add_edge(parent, executor_a)
|
||||
.add_edge(parent, executor_b)
|
||||
.add_edge(executor_a, parent)
|
||||
.add_edge(executor_b, parent)
|
||||
.add_edge(parent, parent_request_info)
|
||||
.add_edge(parent_request_info, executor_a) # For SubWorkflowResponse routing
|
||||
.add_edge(parent_request_info, executor_b) # For SubWorkflowResponse routing
|
||||
.build()
|
||||
)
|
||||
|
||||
# Run test
|
||||
result = await main_workflow.run({"email1": "user@strict.com", "email2": "user@random.com"})
|
||||
|
||||
# Workflow A should handle strict.com
|
||||
# Workflow B should handle any .com domain
|
||||
request_events = result.get_request_info_events()
|
||||
assert len(request_events) == 0 # Both handled internally
|
||||
|
||||
assert len(parent.results) == 2
|
||||
assert parent.results["user@strict.com"].is_valid is True
|
||||
assert parent.results["user@random.com"].is_valid is True
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# Run tests
|
||||
asyncio.run(test_basic_sub_workflow())
|
||||
asyncio.run(test_sub_workflow_with_interception())
|
||||
asyncio.run(test_conditional_forwarding())
|
||||
asyncio.run(test_workflow_scoped_interception())
|
||||
@@ -11,6 +11,7 @@ from agent_framework.workflow import (
|
||||
RequestInfoEvent,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
@@ -68,10 +69,13 @@ class MockExecutorRequestApproval(Executor):
|
||||
await ctx.send_message(RequestInfoMessage())
|
||||
|
||||
@handler
|
||||
async def mock_handler_b(self, message: ApprovalMessage, ctx: WorkflowContext[NumberMessage]) -> None:
|
||||
async def mock_handler_b(
|
||||
self, message: RequestResponse[RequestInfoMessage, ApprovalMessage], ctx: WorkflowContext[NumberMessage]
|
||||
) -> None:
|
||||
"""A mock handler that processes the approval response."""
|
||||
data = await ctx.get_shared_state(self.id)
|
||||
if message.approved:
|
||||
assert isinstance(message.data, ApprovalMessage)
|
||||
if message.data.approved:
|
||||
await ctx.add_event(WorkflowCompletedEvent(data=data))
|
||||
else:
|
||||
await ctx.send_message(NumberMessage(data=data))
|
||||
|
||||
@@ -12,6 +12,7 @@ from agent_framework.workflow import (
|
||||
RequestInfoEvent,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
@@ -94,16 +95,20 @@ class CriticGroupChatManager(Executor):
|
||||
|
||||
@handler
|
||||
async def handle_request_response(
|
||||
self, response: list[ChatMessage], ctx: WorkflowContext[AgentExecutorRequest]
|
||||
self,
|
||||
response: RequestResponse[RequestInfoMessage, list[ChatMessage]],
|
||||
ctx: WorkflowContext[AgentExecutorRequest],
|
||||
) -> None:
|
||||
"""Handler that processes the response from the RequestInfoExecutor."""
|
||||
messages: list[ChatMessage] = response.data or []
|
||||
|
||||
# Update the chat history with the response
|
||||
self._chat_history.extend(response)
|
||||
self._chat_history.extend(messages)
|
||||
|
||||
# Send the response to the other members
|
||||
await asyncio.gather(*[
|
||||
ctx.send_message(
|
||||
AgentExecutorRequest(messages=response, should_respond=False),
|
||||
AgentExecutorRequest(messages=messages, should_respond=False),
|
||||
target_id=member_id,
|
||||
)
|
||||
for member_id in self._members
|
||||
|
||||
@@ -0,0 +1,233 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agent_framework.workflow import (
|
||||
Executor,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
WorkflowEvent,
|
||||
WorkflowExecutor,
|
||||
handler,
|
||||
)
|
||||
|
||||
"""
|
||||
The following sample demonstrates basic sub-workflows.
|
||||
|
||||
This sample shows how to:
|
||||
1. Create workflows that execute other workflows as sub-workflows
|
||||
2. Pass data between parent and sub-workflows
|
||||
3. Collect results from sub-workflows
|
||||
|
||||
The example simulates a simple text processing system where:
|
||||
- Sub-workflows process individual text strings (count words, characters)
|
||||
- Parent workflow orchestrates multiple sub-workflows and collects results
|
||||
|
||||
Key concepts demonstrated:
|
||||
- WorkflowExecutor: Wraps a workflow to make it behave as an executor
|
||||
- Sub-workflow isolation: Sub-workflows work independently
|
||||
- Result collection: Parent workflows can gather outputs from sub-workflows
|
||||
|
||||
Simple flow visualization:
|
||||
|
||||
Parent Orchestrator
|
||||
|
|
||||
| TextProcessingRequest(text, task_id)
|
||||
v
|
||||
[ Sub-workflow: WorkflowExecutor(TextProcessor) ]
|
||||
|
|
||||
| WorkflowCompletedEvent(TextProcessingResult)
|
||||
v
|
||||
Parent collects results and summarizes
|
||||
"""
|
||||
|
||||
|
||||
# Message types
|
||||
@dataclass
|
||||
class TextProcessingRequest:
|
||||
"""Request to process a text string."""
|
||||
|
||||
text: str
|
||||
task_id: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class TextProcessingResult:
|
||||
"""Result of text processing."""
|
||||
|
||||
task_id: str
|
||||
text: str
|
||||
word_count: int
|
||||
char_count: int
|
||||
|
||||
|
||||
class AllTasksCompleted(WorkflowEvent):
|
||||
"""Event triggered when all processing tasks are complete."""
|
||||
|
||||
def __init__(self, results: list[TextProcessingResult]):
|
||||
super().__init__(results)
|
||||
|
||||
|
||||
# Sub-workflow executor
|
||||
class TextProcessor(Executor):
|
||||
"""Processes text strings - counts words and characters."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="text_processor")
|
||||
|
||||
@handler
|
||||
async def process_text(
|
||||
self, request: TextProcessingRequest, ctx: WorkflowContext[TextProcessingResult]
|
||||
) -> None:
|
||||
"""Process a text string and return statistics."""
|
||||
text_preview = f"'{request.text[:50]}{'...' if len(request.text) > 50 else ''}'"
|
||||
print(f"π Sub-workflow processing text (Task {request.task_id}): {text_preview}")
|
||||
|
||||
# Simple text processing
|
||||
word_count = len(request.text.split()) if request.text.strip() else 0
|
||||
char_count = len(request.text)
|
||||
|
||||
print(f"π Task {request.task_id}: {word_count} words, {char_count} characters")
|
||||
|
||||
# Create result
|
||||
result = TextProcessingResult(
|
||||
task_id=request.task_id,
|
||||
text=request.text,
|
||||
word_count=word_count,
|
||||
char_count=char_count,
|
||||
)
|
||||
|
||||
print(f"β
Sub-workflow completed task {request.task_id}")
|
||||
# Signal completion
|
||||
await ctx.add_event(WorkflowCompletedEvent(data=result))
|
||||
|
||||
|
||||
# Parent workflow
|
||||
class TextProcessingOrchestrator(Executor):
|
||||
"""Orchestrates multiple text processing tasks using sub-workflows."""
|
||||
|
||||
results: list[TextProcessingResult] = []
|
||||
expected_count: int = 0
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="text_orchestrator")
|
||||
|
||||
@handler
|
||||
async def start_processing(
|
||||
self, texts: list[str], ctx: WorkflowContext[TextProcessingRequest]
|
||||
) -> None:
|
||||
"""Start processing multiple text strings."""
|
||||
print(f"π Starting processing of {len(texts)} text strings")
|
||||
print("=" * 60)
|
||||
|
||||
self.expected_count = len(texts)
|
||||
|
||||
# Send each text to a sub-workflow
|
||||
for i, text in enumerate(texts):
|
||||
task_id = f"task_{i+1}"
|
||||
request = TextProcessingRequest(text=text, task_id=task_id)
|
||||
print(f"π€ Dispatching {task_id} to sub-workflow")
|
||||
await ctx.send_message(request, target_id="text_processor_workflow")
|
||||
|
||||
@handler
|
||||
async def collect_result(
|
||||
self, result: TextProcessingResult, ctx: WorkflowContext[None]
|
||||
) -> None:
|
||||
"""Collect results from sub-workflows."""
|
||||
print(f"π₯ Collected result from {result.task_id}")
|
||||
self.results.append(result)
|
||||
|
||||
# Check if all results are collected
|
||||
if len(self.results) == self.expected_count:
|
||||
print("\nπ All tasks completed!")
|
||||
await ctx.add_event(AllTasksCompleted(self.results))
|
||||
|
||||
def get_summary(self) -> dict[str, Any]:
|
||||
"""Get a summary of all processing results."""
|
||||
total_words = sum(result.word_count for result in self.results)
|
||||
total_chars = sum(result.char_count for result in self.results)
|
||||
avg_words = total_words / len(self.results) if self.results else 0
|
||||
avg_chars = total_chars / len(self.results) if self.results else 0
|
||||
|
||||
return {
|
||||
"total_texts": len(self.results),
|
||||
"total_words": total_words,
|
||||
"total_characters": total_chars,
|
||||
"average_words_per_text": round(avg_words, 2),
|
||||
"average_characters_per_text": round(avg_chars, 2),
|
||||
}
|
||||
|
||||
|
||||
async def main():
|
||||
"""Main function to run the basic sub-workflow example."""
|
||||
print("π Setting up sub-workflow...")
|
||||
|
||||
# Step 1: Create the text processing sub-workflow
|
||||
text_processor = TextProcessor()
|
||||
|
||||
processing_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(text_processor)
|
||||
.build()
|
||||
)
|
||||
|
||||
print("π§ Setting up parent workflow...")
|
||||
|
||||
# Step 2: Create the parent workflow
|
||||
orchestrator = TextProcessingOrchestrator()
|
||||
workflow_executor = WorkflowExecutor(processing_workflow, id="text_processor_workflow")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(orchestrator)
|
||||
.add_edge(orchestrator, workflow_executor)
|
||||
.add_edge(workflow_executor, orchestrator)
|
||||
.build()
|
||||
)
|
||||
|
||||
# Step 3: Test data - various text strings
|
||||
test_texts = [
|
||||
"Hello world! This is a simple test.",
|
||||
"Python is a powerful programming language used for many applications.",
|
||||
"Short text.",
|
||||
"This is a longer text with multiple sentences. It contains more words and characters. We use it to test our text processing workflow.",
|
||||
"", # Empty string
|
||||
" Spaces around text ",
|
||||
]
|
||||
|
||||
print(f"\nπ§ͺ Testing with {len(test_texts)} text strings")
|
||||
print("=" * 60)
|
||||
|
||||
# Step 4: Run the workflow
|
||||
result = await main_workflow.run(test_texts)
|
||||
|
||||
# Step 5: Display results
|
||||
print(f"\nπ Processing Results:")
|
||||
print("=" * 60)
|
||||
|
||||
# Sort results by task_id for consistent display
|
||||
sorted_results = sorted(orchestrator.results, key=lambda r: r.task_id)
|
||||
|
||||
for result in sorted_results:
|
||||
preview = result.text[:30] + "..." if len(result.text) > 30 else result.text
|
||||
preview = preview.replace('\n', ' ').strip() or '(empty)'
|
||||
print(f"β
{result.task_id}: '{preview}' -> {result.word_count} words, {result.char_count} chars")
|
||||
|
||||
# Step 6: Display summary
|
||||
summary = orchestrator.get_summary()
|
||||
print(f"\nπ Summary:")
|
||||
print("=" * 60)
|
||||
print(f"π Total texts processed: {summary['total_texts']}")
|
||||
print(f"π Total words: {summary['total_words']}")
|
||||
print(f"π€ Total characters: {summary['total_characters']}")
|
||||
print(f"π Average words per text: {summary['average_words_per_text']}")
|
||||
print(f"π Average characters per text: {summary['average_characters_per_text']}")
|
||||
|
||||
print(f"\nπ Processing complete!")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
+275
@@ -0,0 +1,275 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""
|
||||
The following sample demonstrates sub-workflows with request interception and conditional forwarding.
|
||||
|
||||
This sample shows how to:
|
||||
1. Create workflows that execute other workflows as sub-workflows
|
||||
2. Intercept requests from sub-workflows in parent workflows using @intercepts_request
|
||||
3. Conditionally handle or forward requests using RequestResponse.handled() and RequestResponse.forward()
|
||||
4. Handle external requests that are forwarded by the parent workflow
|
||||
5. Proper request/response correlation for concurrent processing
|
||||
|
||||
The example simulates an email validation system where:
|
||||
- Sub-workflows validate multiple email addresses concurrently
|
||||
- Parent workflows can intercept domain check requests for optimization
|
||||
- Known domains (example.com, company.com) are approved locally
|
||||
- Unknown domains (unknown.org) are forwarded to external services
|
||||
- Request correlation ensures each email gets the correct domain check response
|
||||
- External domain check requests are processed and responses routed back correctly
|
||||
|
||||
Key concepts demonstrated:
|
||||
- WorkflowExecutor: Wraps a workflow to make it behave as an executor
|
||||
- @intercepts_request: Decorator for parent workflows to handle sub-workflow requests
|
||||
- RequestResponse: Enables conditional handling vs forwarding of requests
|
||||
- Request correlation: Using request_id to match responses with original requests
|
||||
- Concurrent processing: Multiple emails processed simultaneously without interference
|
||||
- External request routing: RequestInfoExecutor handles forwarded external requests
|
||||
- Sub-workflow isolation: Sub-workflows work normally without knowing they're nested
|
||||
|
||||
Simple flow visualization:
|
||||
|
||||
Parent Orchestrator (@intercepts_request)
|
||||
|
|
||||
| EmailValidationRequest(email) x3 (concurrent)
|
||||
v
|
||||
[ Sub-workflow: WorkflowExecutor(EmailValidator) ]
|
||||
|
|
||||
| DomainCheckRequest(domain) with request_id correlation
|
||||
v
|
||||
Interception? yes -> handled locally with RequestResponse.handled(True)
|
||||
no -> forwarded to RequestInfoExecutor -> external service
|
||||
|
|
||||
v
|
||||
Response routed back to sub-workflow using request_id
|
||||
"""
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agent_framework_workflow import (
|
||||
Executor,
|
||||
RequestInfoExecutor,
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
WorkflowExecutor,
|
||||
handler,
|
||||
intercepts_request,
|
||||
)
|
||||
|
||||
|
||||
# 1. Define domain-specific message types
|
||||
@dataclass
|
||||
class EmailValidationRequest:
|
||||
"""Request to validate an email address."""
|
||||
|
||||
email: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class DomainCheckRequest(RequestInfoMessage):
|
||||
"""Request to check if a domain is approved."""
|
||||
|
||||
domain: str = ""
|
||||
|
||||
|
||||
@dataclass
|
||||
class ValidationResult:
|
||||
"""Result of email validation."""
|
||||
|
||||
email: str
|
||||
is_valid: bool
|
||||
reason: str
|
||||
|
||||
|
||||
# 2. Implement the sub-workflow executor (completely standard)
|
||||
class EmailValidator(Executor):
|
||||
"""Validates email addresses - doesn't know it's in a sub-workflow."""
|
||||
|
||||
def __init__(self):
|
||||
"""Initialize the EmailValidator executor."""
|
||||
super().__init__(id="email_validator")
|
||||
# Use a dict to track multiple pending emails by request_id
|
||||
self._pending_emails: dict[str, str] = {}
|
||||
|
||||
@handler
|
||||
async def validate_request(
|
||||
self, request: EmailValidationRequest, ctx: WorkflowContext[DomainCheckRequest | ValidationResult]
|
||||
) -> None:
|
||||
"""Validate an email address."""
|
||||
print(f"π Sub-workflow validating email: {request.email}")
|
||||
|
||||
# Extract domain
|
||||
domain = request.email.split("@")[1] if "@" in request.email else ""
|
||||
|
||||
if not domain:
|
||||
print(f"β Invalid email format: {request.email}")
|
||||
result = ValidationResult(email=request.email, is_valid=False, reason="Invalid email format")
|
||||
await ctx.add_event(WorkflowCompletedEvent(data=result))
|
||||
return
|
||||
|
||||
print(f"π Sub-workflow requesting domain check for: {domain}")
|
||||
# Request domain check
|
||||
domain_check = DomainCheckRequest(domain=domain)
|
||||
# Store the pending email with the request_id for correlation
|
||||
self._pending_emails[domain_check.request_id] = request.email
|
||||
await ctx.send_message(domain_check, target_id="email_request_info")
|
||||
|
||||
@handler
|
||||
async def handle_domain_response(
|
||||
self,
|
||||
response: RequestResponse[DomainCheckRequest, bool],
|
||||
ctx: WorkflowContext[ValidationResult],
|
||||
) -> None:
|
||||
"""Handle domain check response from RequestInfo with correlation."""
|
||||
approved = bool(response.data)
|
||||
domain = response.original_request.domain if (hasattr(response, 'original_request') and response.original_request) else "unknown"
|
||||
print(f"π¬ Sub-workflow received domain response for '{domain}': {approved}")
|
||||
|
||||
# Find the corresponding email using the request_id
|
||||
request_id = response.original_request.request_id if (hasattr(response, 'original_request') and response.original_request) else None
|
||||
if request_id and request_id in self._pending_emails:
|
||||
email = self._pending_emails.pop(request_id) # Remove from pending
|
||||
result = ValidationResult(
|
||||
email=email,
|
||||
is_valid=approved,
|
||||
reason="Domain approved" if approved else "Domain not approved",
|
||||
)
|
||||
print(f"β
Sub-workflow completing validation for: {email}")
|
||||
await ctx.add_event(WorkflowCompletedEvent(data=result))
|
||||
|
||||
|
||||
# 3. Implement the parent workflow with request interception
|
||||
class SmartEmailOrchestrator(Executor):
|
||||
"""Parent orchestrator that can intercept domain checks."""
|
||||
approved_domains: set[str] = set()
|
||||
|
||||
def __init__(self, approved_domains: set[str] | None = None):
|
||||
"""Initialize the SmartEmailOrchestrator with approved domains.
|
||||
|
||||
Args:
|
||||
approved_domains: Set of pre-approved domains, defaults to example.com, test.org, company.com
|
||||
"""
|
||||
super().__init__(id="email_orchestrator", approved_domains=approved_domains)
|
||||
self._results: list[ValidationResult] = []
|
||||
|
||||
@handler
|
||||
async def start_validation(self, emails: list[str], ctx: WorkflowContext[EmailValidationRequest]) -> None:
|
||||
"""Start validating a batch of emails."""
|
||||
print(f"π§ Starting validation of {len(emails)} email addresses")
|
||||
print("=" * 60)
|
||||
for email in emails:
|
||||
print(f"π€ Sending '{email}' to sub-workflow for validation")
|
||||
request = EmailValidationRequest(email=email)
|
||||
await ctx.send_message(request, target_id="email_validator_workflow")
|
||||
|
||||
@intercepts_request
|
||||
async def check_domain(
|
||||
self, request: DomainCheckRequest, ctx: WorkflowContext[Any]
|
||||
) -> RequestResponse[DomainCheckRequest, bool]:
|
||||
"""Intercept domain check requests from sub-workflows."""
|
||||
print(f"π Parent intercepting domain check for: {request.domain}")
|
||||
if request.domain in self.approved_domains:
|
||||
print(f"β
Domain '{request.domain}' is pre-approved locally!")
|
||||
return RequestResponse[DomainCheckRequest, bool].handled(True)
|
||||
print(f"β Domain '{request.domain}' unknown, forwarding to external service...")
|
||||
return RequestResponse.forward()
|
||||
|
||||
@handler
|
||||
async def collect_result(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None:
|
||||
"""Collect validation results. It comes from the sub-workflow emitted WorkflowCompletionEvent's data field."""
|
||||
status_icon = "β
" if result.is_valid else "β"
|
||||
print(f"π₯ {status_icon} Validation result: {result.email} -> {result.reason}")
|
||||
self._results.append(result)
|
||||
|
||||
@property
|
||||
def results(self) -> list[ValidationResult]:
|
||||
"""Get the collected validation results."""
|
||||
return self._results
|
||||
|
||||
|
||||
async def run_example() -> None:
|
||||
"""Run the sub-workflow example."""
|
||||
print("π Setting up sub-workflow with request interception...")
|
||||
print()
|
||||
|
||||
# 4. Build the sub-workflow
|
||||
email_validator = EmailValidator()
|
||||
# Match the target_id used in EmailValidator ("email_request_info")
|
||||
request_info = RequestInfoExecutor(id="email_request_info")
|
||||
|
||||
validation_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(email_validator)
|
||||
.add_edge(email_validator, request_info)
|
||||
.add_edge(request_info, email_validator)
|
||||
.build()
|
||||
)
|
||||
|
||||
# 5. Build the parent workflow with interception
|
||||
orchestrator = SmartEmailOrchestrator(approved_domains={"example.com", "company.com"})
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, id="email_validator_workflow")
|
||||
# Add a RequestInfoExecutor to handle forwarded external requests
|
||||
main_request_info = RequestInfoExecutor(id="main_request_info")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(orchestrator)
|
||||
.add_edge(orchestrator, workflow_executor)
|
||||
.add_edge(workflow_executor, orchestrator)
|
||||
# Add edges for external request handling
|
||||
.add_edge(orchestrator, main_request_info)
|
||||
.add_edge(main_request_info, workflow_executor) # Route external responses to sub-workflow
|
||||
.build()
|
||||
)
|
||||
|
||||
# 6. Prepare test inputs: known domain, unknown domain
|
||||
test_emails = [
|
||||
"user@example.com", # Should be intercepted and approved
|
||||
"admin@company.com", # Should be intercepted and approved
|
||||
"guest@unknown.org", # Should be forwarded externally
|
||||
]
|
||||
|
||||
# 7. Run the workflow
|
||||
result = await main_workflow.run(test_emails)
|
||||
|
||||
# 8. Handle any external requests
|
||||
request_events = result.get_request_info_events()
|
||||
if request_events:
|
||||
print(f"\nπ Handling {len(request_events)} external request(s)...")
|
||||
for event in request_events:
|
||||
if event.data and hasattr(event.data, "domain"):
|
||||
print(f"π External domain check needed for: {event.data.domain}")
|
||||
|
||||
# Simulate external responses
|
||||
external_responses: dict[str, bool] = {}
|
||||
for event in request_events:
|
||||
# Simulate external domain checking
|
||||
if event.data and hasattr(event.data, "domain"):
|
||||
domain = event.data.domain
|
||||
# Let's say unknown.org is actually approved externally
|
||||
approved = domain == "unknown.org"
|
||||
print(f"π External service response for '{domain}': {'APPROVED' if approved else 'REJECTED'}")
|
||||
external_responses[event.request_id] = approved
|
||||
|
||||
# 9. Send external responses
|
||||
await main_workflow.send_responses(external_responses)
|
||||
else:
|
||||
print("\nπ― All requests were intercepted and handled locally!")
|
||||
|
||||
# 10. Display final summary
|
||||
print(f"\nπ Final Results Summary:")
|
||||
print("=" * 60)
|
||||
for result in orchestrator.results:
|
||||
status = "β
VALID" if result.is_valid else "β INVALID"
|
||||
print(f"{status} {result.email}: {result.reason}")
|
||||
|
||||
print(f"\nπ Processed {len(orchestrator.results)} emails total")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(run_example())
|
||||
@@ -0,0 +1,414 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
|
||||
from agent_framework.workflow import (
|
||||
Executor,
|
||||
RequestInfoExecutor,
|
||||
WorkflowBuilder,
|
||||
WorkflowCompletedEvent,
|
||||
WorkflowContext,
|
||||
handler,
|
||||
)
|
||||
|
||||
# Import the new sub-workflow types directly from the implementation package
|
||||
try:
|
||||
from agent_framework_workflow import (
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
WorkflowExecutor,
|
||||
intercepts_request,
|
||||
)
|
||||
except ImportError:
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "..", "..", "packages", "workflow"))
|
||||
from agent_framework_workflow import (
|
||||
RequestInfoMessage,
|
||||
RequestResponse,
|
||||
WorkflowExecutor,
|
||||
intercepts_request,
|
||||
)
|
||||
|
||||
"""
|
||||
This sample demonstrates the PROPER pattern for request interception.
|
||||
|
||||
Key principles:
|
||||
1. Only ONE executor intercepts a given request type from a specific sub-workflow
|
||||
2. Different executors can intercept DIFFERENT request types from the same sub-workflow
|
||||
3. The same executor can intercept the same request type from DIFFERENT sub-workflows
|
||||
|
||||
This ensures:
|
||||
- Deterministic behavior
|
||||
- Clear responsibility boundaries
|
||||
- Easier debugging and maintenance
|
||||
|
||||
The example simulates a resource allocation system where:
|
||||
- Sub-workflow requests resources (CPU, memory, etc.)
|
||||
- A single Cache executor intercepts and handles resource requests
|
||||
- The Cache can either satisfy from cache or forward to external
|
||||
|
||||
Simple flow visualization:
|
||||
|
||||
Coordinator
|
||||
|
|
||||
| list[resource/policy requests]
|
||||
v
|
||||
[ Sub-workflow: WorkflowExecutor(ResourceRequester) ]
|
||||
| |
|
||||
| ResourceRequest | PolicyCheckRequest
|
||||
v v
|
||||
ResourceCache (@intercepts) PolicyEngine (@intercepts)
|
||||
| handled/forward | handled/forward
|
||||
v v
|
||||
RequestInfo (external) <----- forwarded when not handled
|
||||
| responses
|
||||
v
|
||||
Back to sub-workflow -> completion -> results collected
|
||||
"""
|
||||
|
||||
|
||||
# 1. Define domain-specific request/response types
|
||||
@dataclass
|
||||
class ResourceRequest(RequestInfoMessage):
|
||||
"""Request for computing resources."""
|
||||
|
||||
resource_type: str = "cpu" # cpu, memory, disk, etc.
|
||||
amount: int = 1
|
||||
priority: str = "normal" # low, normal, high
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyCheckRequest(RequestInfoMessage):
|
||||
"""Request to check resource allocation policy."""
|
||||
|
||||
resource_type: str = ""
|
||||
amount: int = 0
|
||||
policy_type: str = "quota" # quota, compliance, security
|
||||
|
||||
|
||||
@dataclass
|
||||
class ResourceResponse:
|
||||
"""Response with allocated resources."""
|
||||
|
||||
resource_type: str
|
||||
allocated: int
|
||||
source: str # Which system provided the resources
|
||||
|
||||
|
||||
@dataclass
|
||||
class PolicyResponse:
|
||||
"""Response from policy check."""
|
||||
|
||||
approved: bool
|
||||
reason: str
|
||||
|
||||
|
||||
@dataclass
|
||||
class RequestFinished:
|
||||
pass
|
||||
|
||||
|
||||
# 2. Implement the sub-workflow executor - makes resource and policy requests
|
||||
class ResourceRequester(Executor):
|
||||
"""Simple executor that requests resources and checks policies."""
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="resource_requester")
|
||||
self._request_count = 0
|
||||
|
||||
@handler
|
||||
async def request_resources(
|
||||
self,
|
||||
requests: list[dict[str, Any]],
|
||||
ctx: WorkflowContext[ResourceRequest | PolicyCheckRequest],
|
||||
) -> None:
|
||||
"""Process a list of resource requests."""
|
||||
print(f"π Sub-workflow processing {len(requests)} requests")
|
||||
self._request_count += len(requests)
|
||||
|
||||
for req_data in requests:
|
||||
req_type = req_data.get("request_type", "resource")
|
||||
|
||||
if req_type == "resource":
|
||||
print(f" π¦ Requesting resource: {req_data.get('type', 'cpu')} x{req_data.get('amount', 1)}")
|
||||
request = ResourceRequest(
|
||||
resource_type=req_data.get("type", "cpu"),
|
||||
amount=req_data.get("amount", 1),
|
||||
priority=req_data.get("priority", "normal"),
|
||||
)
|
||||
# Send to parent workflow for interception - not to target_id
|
||||
await ctx.send_message(request)
|
||||
elif req_type == "policy":
|
||||
print(f" π‘οΈ Checking policy: {req_data.get('type', 'cpu')} x{req_data.get('amount', 1)} ({req_data.get('policy_type', 'quota')})")
|
||||
request = PolicyCheckRequest(
|
||||
resource_type=req_data.get("type", "cpu"),
|
||||
amount=req_data.get("amount", 1),
|
||||
policy_type=req_data.get("policy_type", "quota"),
|
||||
)
|
||||
# Send to parent workflow for interception - not to target_id
|
||||
await ctx.send_message(request)
|
||||
|
||||
@handler
|
||||
async def handle_resource_response(
|
||||
self,
|
||||
response: RequestResponse[ResourceRequest, ResourceResponse],
|
||||
ctx: WorkflowContext[None],
|
||||
) -> None:
|
||||
"""Handle resource allocation response."""
|
||||
if response.data:
|
||||
source_icon = "πͺ" if response.data.source == "cache" else "π"
|
||||
print(
|
||||
f"π¦ {source_icon} Sub-workflow received: {response.data.allocated} {response.data.resource_type} from {response.data.source}"
|
||||
)
|
||||
if self._collect_results():
|
||||
# Emit completion event and send RequestFinished to the parent workflow.
|
||||
await ctx.add_event(WorkflowCompletedEvent(RequestFinished()))
|
||||
|
||||
@handler
|
||||
async def handle_policy_response(
|
||||
self, response: RequestResponse[PolicyCheckRequest, PolicyResponse], ctx: WorkflowContext[None]
|
||||
) -> None:
|
||||
"""Handle policy check response."""
|
||||
if response.data:
|
||||
status_icon = "β
" if response.data.approved else "β"
|
||||
print(f"π‘οΈ {status_icon} Sub-workflow received policy response: {response.data.approved} - {response.data.reason}")
|
||||
if self._collect_results():
|
||||
# Emit completion event and send RequestFinished to the parent workflow.
|
||||
await ctx.add_event(WorkflowCompletedEvent(RequestFinished()))
|
||||
|
||||
def _collect_results(self) -> bool:
|
||||
"""Collect and summarize results."""
|
||||
self._request_count -= 1
|
||||
print(f"π Sub-workflow completed request ({self._request_count} remaining)")
|
||||
return self._request_count == 0
|
||||
|
||||
|
||||
# 3. Implement the Resource Cache - ONLY intercepts ResourceRequest
|
||||
class ResourceCache(Executor):
|
||||
"""Interceptor that handles RESOURCE requests from cache."""
|
||||
|
||||
# Use class attributes to avoid Pydantic assignment restrictions
|
||||
cache: dict[str, int] = {"cpu": 10, "memory": 50, "disk": 100}
|
||||
results: list[ResourceResponse] = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="resource_cache")
|
||||
# Instance initialization only; state kept in class attributes as above
|
||||
|
||||
@intercepts_request
|
||||
async def check_cache(
|
||||
self, request: ResourceRequest, ctx: WorkflowContext[None]
|
||||
) -> RequestResponse[ResourceRequest, ResourceResponse]:
|
||||
"""Intercept RESOURCE requests and check cache first."""
|
||||
print(f"πͺ CACHE interceptor checking: {request.amount} {request.resource_type}")
|
||||
|
||||
available = self.cache.get(request.resource_type, 0)
|
||||
|
||||
if available >= request.amount:
|
||||
# We can satisfy from cache
|
||||
self.cache[request.resource_type] -= request.amount
|
||||
response = ResourceResponse(resource_type=request.resource_type, allocated=request.amount, source="cache")
|
||||
print(f" β
Cache satisfied: {request.amount} {request.resource_type}")
|
||||
self.results.append(response)
|
||||
return RequestResponse[ResourceRequest, ResourceResponse].handled(response)
|
||||
|
||||
# Cache miss - forward to external
|
||||
print(f" β Cache miss: need {request.amount}, have {available} {request.resource_type}")
|
||||
return RequestResponse.forward()
|
||||
|
||||
@handler
|
||||
async def collect_result(
|
||||
self, response: RequestResponse[ResourceRequest, ResourceResponse], ctx: WorkflowContext[None]
|
||||
) -> None:
|
||||
"""Collect results from external requests that were forwarded."""
|
||||
if response.data and response.data.source != "cache": # Don't double-count our own results
|
||||
self.results.append(response.data)
|
||||
print(
|
||||
f"πͺ π Cache received external response: {response.data.allocated} {response.data.resource_type} from {response.data.source}"
|
||||
)
|
||||
|
||||
|
||||
# 4. Implement the Policy Engine - ONLY intercepts PolicyCheckRequest (different type!)
|
||||
class PolicyEngine(Executor):
|
||||
"""Interceptor that handles POLICY requests."""
|
||||
|
||||
# Use class attributes for simple sample state
|
||||
quota: dict[str, int] = {
|
||||
"cpu": 5, # Only allow up to 5 CPU units
|
||||
"memory": 20, # Only allow up to 20 memory units
|
||||
"disk": 1000, # Liberal disk policy
|
||||
}
|
||||
results: list[PolicyResponse] = []
|
||||
|
||||
def __init__(self):
|
||||
super().__init__(id="policy_engine")
|
||||
# Instance initialization only; state kept in class attributes as above
|
||||
|
||||
@intercepts_request
|
||||
async def check_policy(
|
||||
self, request: PolicyCheckRequest, ctx: WorkflowContext[None]
|
||||
) -> RequestResponse[PolicyCheckRequest, PolicyResponse]:
|
||||
"""Intercept POLICY requests and apply rules."""
|
||||
print(f"π‘οΈ POLICY interceptor checking: {request.amount} {request.resource_type}, policy={request.policy_type}")
|
||||
|
||||
quota_limit = self.quota.get(request.resource_type, 0)
|
||||
|
||||
if request.policy_type == "quota":
|
||||
if request.amount <= quota_limit:
|
||||
response = PolicyResponse(approved=True, reason=f"Within quota ({quota_limit})")
|
||||
print(f" β
Policy approved: {request.amount} <= {quota_limit}")
|
||||
self.results.append(response)
|
||||
return RequestResponse[PolicyCheckRequest, PolicyResponse].handled(response)
|
||||
# Exceeds quota - forward to external for review
|
||||
print(f" β Policy exceeds quota: {request.amount} > {quota_limit}, forwarding to external")
|
||||
return RequestResponse.forward()
|
||||
|
||||
# Unknown policy type - forward to external
|
||||
print(f" β Unknown policy type: {request.policy_type}, forwarding")
|
||||
return RequestResponse.forward()
|
||||
|
||||
@handler
|
||||
async def collect_policy_result(
|
||||
self, response: RequestResponse[PolicyCheckRequest, PolicyResponse], ctx: WorkflowContext[None]
|
||||
) -> None:
|
||||
"""Collect policy results from external requests that were forwarded."""
|
||||
if response.data:
|
||||
self.results.append(response.data)
|
||||
print(f"π‘οΈ π Policy received external response: {response.data.approved} - {response.data.reason}")
|
||||
|
||||
|
||||
class Coordinator(Executor):
|
||||
def __init__(self):
|
||||
super().__init__(id="coordinator")
|
||||
|
||||
@handler
|
||||
async def start(self, requests: list[dict[str, Any]], ctx: WorkflowContext[object]) -> None:
|
||||
"""Start the resource allocation process."""
|
||||
await ctx.send_message(requests, target_id="resource_workflow")
|
||||
|
||||
@handler
|
||||
async def handle_completion(self, completion: RequestFinished, ctx: WorkflowContext[None]) -> None:
|
||||
"""Handle sub-workflow completion. It comes from the sub-workflow emitted WorkflowCompletionEvent's data field."""
|
||||
print(f"π― Main workflow received completion.")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
"""Demonstrate parallel request interception patterns."""
|
||||
print("π Starting Sub-Workflow Parallel Request Interception Demo...")
|
||||
print("=" * 60)
|
||||
|
||||
# 5. Create the sub-workflow
|
||||
resource_requester = ResourceRequester()
|
||||
sub_request_info = RequestInfoExecutor(id="sub_request_info")
|
||||
|
||||
sub_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(resource_requester)
|
||||
.add_edge(resource_requester, sub_request_info)
|
||||
.add_edge(sub_request_info, resource_requester)
|
||||
.build()
|
||||
)
|
||||
|
||||
# 6. Create parent workflow with PROPER interceptor pattern
|
||||
cache = ResourceCache() # Intercepts ResourceRequest
|
||||
policy = PolicyEngine() # Intercepts PolicyCheckRequest (different type!)
|
||||
workflow_executor = WorkflowExecutor(sub_workflow, id="resource_workflow")
|
||||
main_request_info = RequestInfoExecutor(id="main_request_info")
|
||||
|
||||
# Create a simple coordinator that starts the process
|
||||
coordinator = Coordinator()
|
||||
|
||||
# PROPER PATTERN: Each executor intercepts DIFFERENT request types
|
||||
main_workflow = (
|
||||
WorkflowBuilder()
|
||||
.set_start_executor(coordinator)
|
||||
.add_edge(coordinator, workflow_executor) # Start sub-workflow
|
||||
.add_edge(workflow_executor, coordinator) # Sub-workflow completion back to coordinator
|
||||
.add_edge(workflow_executor, cache) # Cache intercepts ResourceRequest
|
||||
.add_edge(cache, workflow_executor) # Cache handles ResourceRequest
|
||||
.add_edge(workflow_executor, policy) # Policy handles PolicyCheckRequest
|
||||
.add_edge(policy, workflow_executor) # Policy intercepts PolicyCheckRequest
|
||||
.add_edge(cache, main_request_info) # Cache forwards to external
|
||||
.add_edge(policy, main_request_info) # Policy forwards to external
|
||||
.add_edge(main_request_info, workflow_executor) # External responses back
|
||||
.add_edge(workflow_executor, main_request_info) # Sub-workflow forwards to main
|
||||
.build()
|
||||
)
|
||||
|
||||
# 7. Test with various requests (mixed resource and policy)
|
||||
test_requests = [
|
||||
{"request_type": "resource", "type": "cpu", "amount": 2, "priority": "normal"}, # Cache hit
|
||||
{"request_type": "policy", "type": "cpu", "amount": 3, "policy_type": "quota"}, # Policy hit
|
||||
{"request_type": "resource", "type": "memory", "amount": 15, "priority": "normal"}, # Cache hit
|
||||
{"request_type": "policy", "type": "memory", "amount": 100, "policy_type": "quota"}, # Policy miss -> external
|
||||
{"request_type": "resource", "type": "gpu", "amount": 1, "priority": "high"}, # Cache miss -> external
|
||||
{"request_type": "policy", "type": "disk", "amount": 500, "policy_type": "quota"}, # Policy hit
|
||||
{"request_type": "policy", "type": "cpu", "amount": 1, "policy_type": "security"}, # Unknown policy -> external
|
||||
]
|
||||
|
||||
print(f"π§ͺ Testing with {len(test_requests)} mixed requests:")
|
||||
for i, req in enumerate(test_requests, 1):
|
||||
req_icon = "π¦" if req["request_type"] == "resource" else "π‘οΈ"
|
||||
print(f" {i}. {req_icon} {req['type']} x{req['amount']} ({req.get('priority', req.get('policy_type', 'default'))})")
|
||||
print("=" * 70)
|
||||
|
||||
# 8. Run the workflow
|
||||
print("π¬ Running workflow...")
|
||||
result = await main_workflow.run(test_requests)
|
||||
|
||||
# 9. Handle any external requests that couldn't be intercepted
|
||||
request_events = result.get_request_info_events()
|
||||
if request_events:
|
||||
print(f"\nπ Handling {len(request_events)} external request(s)...")
|
||||
|
||||
external_responses: dict[str, Any] = {}
|
||||
for event in request_events:
|
||||
if isinstance(event.data, ResourceRequest):
|
||||
# Handle ResourceRequest - create ResourceResponse
|
||||
resource_response = ResourceResponse(
|
||||
resource_type=event.data.resource_type, allocated=event.data.amount, source="external_provider"
|
||||
)
|
||||
external_responses[event.request_id] = resource_response
|
||||
print(f" π External provider: {resource_response.allocated} {resource_response.resource_type}")
|
||||
elif isinstance(event.data, PolicyCheckRequest):
|
||||
# Handle PolicyCheckRequest - create PolicyResponse
|
||||
policy_response = PolicyResponse(approved=True, reason="External policy service approved")
|
||||
external_responses[event.request_id] = policy_response
|
||||
print(f" π External policy: {'β
APPROVED' if policy_response.approved else 'β DENIED'}")
|
||||
|
||||
await main_workflow.send_responses(external_responses)
|
||||
else:
|
||||
print("\nπ― All requests were intercepted internally!")
|
||||
|
||||
# 10. Show results and analysis
|
||||
print("\n" + "=" * 70)
|
||||
print("π RESULTS ANALYSIS")
|
||||
print("=" * 70)
|
||||
|
||||
print(f"\nπͺ Cache Results ({len(cache.results)} handled):")
|
||||
for result in cache.results:
|
||||
print(f" β
{result.allocated} {result.resource_type} from {result.source}")
|
||||
|
||||
print(f"\nπ‘οΈ Policy Results ({len(policy.results)} handled):")
|
||||
for result in policy.results:
|
||||
status_icon = "β
" if result.approved else "β"
|
||||
print(f" {status_icon} Approved: {result.approved} - {result.reason}")
|
||||
|
||||
print(f"\nπΎ Final Cache State:")
|
||||
for resource, amount in cache.cache.items():
|
||||
print(f" π¦ {resource}: {amount} remaining")
|
||||
|
||||
print(f"\nπ Summary:")
|
||||
print(f" π― Total requests: {len(test_requests)}")
|
||||
print(f" πͺ Resource requests handled: {len(cache.results)}")
|
||||
print(f" π‘οΈ Policy requests handled: {len(policy.results)}")
|
||||
print(f" π External requests: {len(request_events) if request_events else 0}")
|
||||
|
||||
print("\n" + "=" * 70)
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user