From beb52188382d2ecc475e4960f727d31f849a78d2 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 26 Aug 2025 13:21:32 -0700 Subject: [PATCH] Python: Clean up and formatting (#487) * Clean up and formatting * Fix mypy * Bug fix --- .../agent_framework_workflow/_edge.py | 31 ++- .../agent_framework_workflow/_events.py | 7 +- .../agent_framework_workflow/_executor.py | 247 +++++++++--------- .../workflow/step_09a_sub_workflow.py | 30 +-- ...p_09b_sub_workflow_request_interception.py | 23 +- ...step_09c_sub_workflow_parallel_requests.py | 45 ++-- 6 files changed, 196 insertions(+), 187 deletions(-) diff --git a/python/packages/workflow/agent_framework_workflow/_edge.py b/python/packages/workflow/agent_framework_workflow/_edge.py index 2e30c414fe..76a0cb9a3a 100644 --- a/python/packages/workflow/agent_framework_workflow/_edge.py +++ b/python/packages/workflow/agent_framework_workflow/_edge.py @@ -77,6 +77,11 @@ class Edge(AFBaseModel): return self._condition(data) +def _default_edge_list() -> list[Edge]: + """Get the default list of edges for the group.""" + return [] + + class EdgeGroup(AFBaseModel): """Represents a group of edges that share some common properties and can be triggered together.""" @@ -84,7 +89,7 @@ class EdgeGroup(AFBaseModel): default_factory=lambda: f"EdgeGroup/{uuid.uuid4()}", description="Unique identifier for the edge group" ) type: str = Field(description="The type of edge group, corresponding to the class name") - edges: list[Edge] = Field(default_factory=list, description="List of edges in this group") + edges: list[Edge] = Field(default_factory=_default_edge_list, description="List of edges in this group") def __init__(self, **kwargs: Any) -> None: """Initialize the edge group.""" @@ -97,24 +102,12 @@ class EdgeGroup(AFBaseModel): @property def source_executor_ids(self) -> list[str]: """Get the source executor IDs of the edges in the group.""" - seen = set() - result = [] - for edge in self.edges: - if edge.source_id not in seen: - result.append(edge.source_id) - seen.add(edge.source_id) - return result + return list(dict.fromkeys(edge.source_id for edge in self.edges)) @property def target_executor_ids(self) -> list[str]: """Get the target executor IDs of the edges in the group.""" - seen = set() - result = [] - for edge in self.edges: - if edge.target_id not in seen: - result.append(edge.target_id) - seen.add(edge.target_id) - return result + return list(dict.fromkeys(edge.target_id for edge in self.edges)) class SingleEdgeGroup(EdgeGroup): @@ -275,6 +268,11 @@ class SwitchCaseEdgeGroupDefault(AFBaseModel): type: str = Field(default="Default", description="The type of the case") +def _default_case_list() -> list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault]: + """Get the default list of cases for the group.""" + return [] + + class SwitchCaseEdgeGroup(FanOutEdgeGroup): """Represents a group of edges that assemble a conditional routing pattern. @@ -299,7 +297,8 @@ class SwitchCaseEdgeGroup(FanOutEdgeGroup): """ cases: list[SwitchCaseEdgeGroupCase | SwitchCaseEdgeGroupDefault] = Field( - default_factory=list, description="List of conditional cases for this switch-case group" + default_factory=_default_case_list, + description="List of conditional cases for this switch-case group", ) def __init__( diff --git a/python/packages/workflow/agent_framework_workflow/_events.py b/python/packages/workflow/agent_framework_workflow/_events.py index 857ace18d9..f4a67036ff 100644 --- a/python/packages/workflow/agent_framework_workflow/_events.py +++ b/python/packages/workflow/agent_framework_workflow/_events.py @@ -1,9 +1,12 @@ # Copyright (c) Microsoft. All rights reserved. -from typing import Any +from typing import TYPE_CHECKING, Any from agent_framework import AgentRunResponse, AgentRunResponseUpdate +if TYPE_CHECKING: + from ._executor import RequestInfoMessage + class WorkflowEvent: """Base class for workflow events.""" @@ -61,7 +64,7 @@ class RequestInfoEvent(WorkflowEvent): request_id: str, source_executor_id: str, request_type: type, - request_data: Any, + request_data: "RequestInfoMessage", ): """Initialize the request info event. diff --git a/python/packages/workflow/agent_framework_workflow/_executor.py b/python/packages/workflow/agent_framework_workflow/_executor.py index 239cf44b21..daaac755bf 100644 --- a/python/packages/workflow/agent_framework_workflow/_executor.py +++ b/python/packages/workflow/agent_framework_workflow/_executor.py @@ -188,8 +188,10 @@ class Executor(AFBaseModel): # Check if interceptor handled it or needs to forward if isinstance(response, RequestResponse): # Add automatic correlation info to the response - correlated_response = RequestResponse._with_correlation( - response, request.data, request.request_id + correlated_response = RequestResponse[RequestInfoMessage, Any].with_correlation( + response, # pyright: ignore[reportUnknownArgumentType] + request.data, + request.request_id, ) if correlated_response.is_handled: @@ -257,23 +259,8 @@ class Executor(AFBaseModel): ExecutorT = TypeVar("ExecutorT", bound="Executor") -@overload def handler( func: Callable[[ExecutorT, Any, WorkflowContext[Any]], Awaitable[Any]], -) -> Callable[[ExecutorT, Any, WorkflowContext[Any]], Awaitable[Any]]: ... - - -@overload -def handler( - func: None = None, -) -> Callable[ - [Callable[[ExecutorT, Any, WorkflowContext[Any]], Awaitable[Any]]], - Callable[[ExecutorT, Any, WorkflowContext[Any]], Awaitable[Any]], -]: ... - - -def handler( - func: Callable[[ExecutorT, Any, WorkflowContext[Any]], Awaitable[Any]] | None = None, ) -> ( Callable[[ExecutorT, Any, WorkflowContext[Any]], Awaitable[Any]] | Callable[ @@ -382,14 +369,24 @@ def handler( return wrapper - if func is None: - return decorator return decorator(func) # endregion: Handler Decorator + # region Request/Response Types +@dataclass +class RequestInfoMessage: + """Base class for all request messages in workflows. + + Any message that should be routed to the RequestInfoExecutor for external + handling must inherit from this class. This ensures type safety and makes + the request/response pattern explicit. + """ + + request_id: str = field(default_factory=lambda: str(uuid.uuid4())) + TRequest = TypeVar("TRequest", bound="RequestInfoMessage") TResponse = TypeVar("TResponse") @@ -411,7 +408,7 @@ class RequestResponse(Generic[TRequest, TResponse]): request_id: str | None = None # Added for tracking @classmethod - def handled(cls, data: TResponse) -> "RequestResponse[Any, TResponse]": + def handled(cls, data: TResponse) -> "RequestResponse[TRequest, TResponse]": """Create a response indicating the request was handled. Correlation info (original_request, request_id) will be added automatically @@ -420,13 +417,15 @@ class RequestResponse(Generic[TRequest, TResponse]): return cls(is_handled=True, data=data) @classmethod - def forward(cls, modified_request: Any = None) -> "RequestResponse[Any, Any]": + def forward(cls, modified_request: Any = None) -> "RequestResponse[TRequest, TResponse]": """Create a response indicating the request should be forwarded.""" return cls(is_handled=False, forward_request=modified_request) @staticmethod - def _with_correlation( - original_response: "RequestResponse[Any, TResponse]", original_request: TRequest, request_id: str + def with_correlation( + original_response: "RequestResponse[TRequest, TResponse]", + original_request: TRequest, + request_id: str, ) -> "RequestResponse[TRequest, TResponse]": """Internal method to add correlation info to a response. @@ -451,7 +450,7 @@ class SubWorkflowRequestInfo: request_id: str # Original request ID from sub-workflow sub_workflow_id: str # ID of the WorkflowExecutor that sent this - data: Any # The actual request data + data: RequestInfoMessage # The actual request data @dataclass @@ -595,110 +594,9 @@ def intercepts_request( # endregion: Intercepts Request Decorator -# region Agent Executor - - -@dataclass -class AgentExecutorRequest: - """A request to an agent executor. - - Attributes: - messages: A list of chat messages to be processed by the agent. - should_respond: A flag indicating whether the agent should respond to the messages. - If False, the messages will be saved to the executor's cache but not sent to the agent. - """ - - messages: list[ChatMessage] - should_respond: bool = True - - -@dataclass -class AgentExecutorResponse: - """A response from an agent executor. - - Attributes: - executor_id: The ID of the executor that generated the response. - response: The agent run response containing the messages generated by the agent. - """ - - executor_id: str - agent_run_response: AgentRunResponse - - -class AgentExecutor(Executor): - """built-in executor that wraps an agent for handling messages.""" - - def __init__( - self, - agent: AIAgent, - *, - agent_thread: AgentThread | None = None, - streaming: bool = False, - id: str | None = None, - ): - """Initialize the executor with a unique identifier. - - Args: - agent: The agent to be wrapped by this executor. - agent_thread: The thread to use for running the agent. If None, a new thread will be created. - streaming: Whether to enable streaming for the agent. If enabled, the executor will emit - AgentRunStreamingEvent updates instead of a single AgentRunEvent. - id: A unique identifier for the executor. If None, a new UUID will be generated. - """ - super().__init__(id or agent.id) - self._agent = agent - self._agent_thread = agent_thread or self._agent.get_new_thread() - self._streaming = streaming - self._cache: list[ChatMessage] = [] - - @handler - async def run(self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse]) -> None: - """Run the agent executor with the given request.""" - self._cache.extend(request.messages) - - if request.should_respond: - if self._streaming: - updates: list[AgentRunResponseUpdate] = [] - async for update in self._agent.run_streaming( - self._cache, - thread=self._agent_thread, - ): - updates.append(update) - await ctx.add_event(AgentRunStreamingEvent(self.id, update)) - response = AgentRunResponse.from_agent_run_response_updates(updates) - else: - response = await self._agent.run( - self._cache, - thread=self._agent_thread, - ) - await ctx.add_event(AgentRunEvent(self.id, response)) - - await ctx.send_message(AgentExecutorResponse(self.id, response)) - self._cache.clear() - - -# endregion: Agent Executor - - # region Request Info Executor -@dataclass -class RequestInfoMessage: - """Base class for all request messages in workflows. - - Any message that should be routed to the RequestInfoExecutor for external - handling must inherit from this class. This ensures type safety and makes - the request/response pattern explicit. - """ - - request_id: str = field(default_factory=lambda: str(uuid.uuid4())) - - -# Note: SubWorkflowRequestInfo, SubWorkflowResponse, and RequestResponse -# have been moved before intercepts_request decorator - - class RequestInfoExecutor(Executor): """Built-in executor that handles request/response patterns in workflows. @@ -798,14 +696,105 @@ class RequestInfoExecutor(Executor): else: # Regular response - send directly back to source # Create a correlated response that includes both the response data and original request - correlated_response = RequestResponse.handled(response_data) - correlated_response = RequestResponse._with_correlation(correlated_response, event.data, request_id) + if not isinstance(event.data, RequestInfoMessage): + raise TypeError(f"Expected RequestInfoMessage, got {type(event.data)}") + correlated_response = RequestResponse[RequestInfoMessage, Any].handled(response_data) + correlated_response = RequestResponse[RequestInfoMessage, Any].with_correlation( + correlated_response, + event.data, + request_id, + ) await ctx.send_message(correlated_response, target_id=event.source_executor_id) # endregion: Request Info Executor +# region Agent Executor + + +@dataclass +class AgentExecutorRequest: + """A request to an agent executor. + + Attributes: + messages: A list of chat messages to be processed by the agent. + should_respond: A flag indicating whether the agent should respond to the messages. + If False, the messages will be saved to the executor's cache but not sent to the agent. + """ + + messages: list[ChatMessage] + should_respond: bool = True + + +@dataclass +class AgentExecutorResponse: + """A response from an agent executor. + + Attributes: + executor_id: The ID of the executor that generated the response. + response: The agent run response containing the messages generated by the agent. + """ + + executor_id: str + agent_run_response: AgentRunResponse + + +class AgentExecutor(Executor): + """built-in executor that wraps an agent for handling messages.""" + + def __init__( + self, + agent: AIAgent, + *, + agent_thread: AgentThread | None = None, + streaming: bool = False, + id: str | None = None, + ): + """Initialize the executor with a unique identifier. + + Args: + agent: The agent to be wrapped by this executor. + agent_thread: The thread to use for running the agent. If None, a new thread will be created. + streaming: Whether to enable streaming for the agent. If enabled, the executor will emit + AgentRunStreamingEvent updates instead of a single AgentRunEvent. + id: A unique identifier for the executor. If None, a new UUID will be generated. + """ + super().__init__(id or agent.id) + self._agent = agent + self._agent_thread = agent_thread or self._agent.get_new_thread() + self._streaming = streaming + self._cache: list[ChatMessage] = [] + + @handler + async def run(self, request: AgentExecutorRequest, ctx: WorkflowContext[AgentExecutorResponse]) -> None: + """Run the agent executor with the given request.""" + self._cache.extend(request.messages) + + if request.should_respond: + if self._streaming: + updates: list[AgentRunResponseUpdate] = [] + async for update in self._agent.run_streaming( + self._cache, + thread=self._agent_thread, + ): + updates.append(update) + await ctx.add_event(AgentRunStreamingEvent(self.id, update)) + response = AgentRunResponse.from_agent_run_response_updates(updates) + else: + response = await self._agent.run( + self._cache, + thread=self._agent_thread, + ) + await ctx.add_event(AgentRunEvent(self.id, response)) + + await ctx.send_message(AgentExecutorResponse(self.id, response)) + self._cache.clear() + + +# endregion: Agent Executor + + # region Workflow Executor @@ -889,6 +878,8 @@ class WorkflowExecutor(Executor): self._pending_requests[event.request_id] = event.data # Wrap request with routing context and send to parent + if not isinstance(event.data, RequestInfoMessage): + raise TypeError(f"Expected RequestInfoMessage, got {type(event.data)}") wrapped_request = SubWorkflowRequestInfo( request_id=event.request_id, sub_workflow_id=self.id, @@ -957,6 +948,8 @@ class WorkflowExecutor(Executor): self._pending_requests[event.request_id] = event.data # Send the new request to parent + if not isinstance(event.data, RequestInfoMessage): + raise TypeError(f"Expected RequestInfoMessage, got {type(event.data)}") wrapped_request = SubWorkflowRequestInfo( request_id=event.request_id, sub_workflow_id=self.id, diff --git a/python/samples/getting_started/workflow/step_09a_sub_workflow.py b/python/samples/getting_started/workflow/step_09a_sub_workflow.py index 88c775d762..e0ca073557 100644 --- a/python/samples/getting_started/workflow/step_09a_sub_workflow.py +++ b/python/samples/getting_started/workflow/step_09a_sub_workflow.py @@ -79,9 +79,7 @@ class TextProcessor(Executor): super().__init__(id="text_processor") @handler - async def process_text( - self, request: TextProcessingRequest, ctx: WorkflowContext[TextProcessingResult] - ) -> None: + async def process_text(self, request: TextProcessingRequest, ctx: WorkflowContext[TextProcessingResult]) -> None: """Process a text string and return statistics.""" text_preview = f"'{request.text[:50]}{'...' if len(request.text) > 50 else ''}'" print(f"๐Ÿ” Sub-workflow processing text (Task {request.task_id}): {text_preview}") @@ -108,7 +106,7 @@ class TextProcessor(Executor): # Parent workflow class TextProcessingOrchestrator(Executor): """Orchestrates multiple text processing tasks using sub-workflows.""" - + results: list[TextProcessingResult] = [] expected_count: int = 0 @@ -116,9 +114,7 @@ class TextProcessingOrchestrator(Executor): super().__init__(id="text_orchestrator") @handler - async def start_processing( - self, texts: list[str], ctx: WorkflowContext[TextProcessingRequest] - ) -> None: + async def start_processing(self, texts: list[str], ctx: WorkflowContext[TextProcessingRequest]) -> None: """Start processing multiple text strings.""" print(f"๐Ÿ“„ Starting processing of {len(texts)} text strings") print("=" * 60) @@ -127,15 +123,13 @@ class TextProcessingOrchestrator(Executor): # Send each text to a sub-workflow for i, text in enumerate(texts): - task_id = f"task_{i+1}" + task_id = f"task_{i + 1}" request = TextProcessingRequest(text=text, task_id=task_id) print(f"๐Ÿ“ค Dispatching {task_id} to sub-workflow") await ctx.send_message(request, target_id="text_processor_workflow") @handler - async def collect_result( - self, result: TextProcessingResult, ctx: WorkflowContext[None] - ) -> None: + async def collect_result(self, result: TextProcessingResult, ctx: WorkflowContext[None]) -> None: """Collect results from sub-workflows.""" print(f"๐Ÿ“ฅ Collected result from {result.task_id}") self.results.append(result) @@ -168,11 +162,7 @@ async def main(): # Step 1: Create the text processing sub-workflow text_processor = TextProcessor() - processing_workflow = ( - WorkflowBuilder() - .set_start_executor(text_processor) - .build() - ) + processing_workflow = WorkflowBuilder().set_start_executor(text_processor).build() print("๐Ÿ”ง Setting up parent workflow...") @@ -205,7 +195,7 @@ async def main(): result = await main_workflow.run(test_texts) # Step 5: Display results - print(f"\n๐Ÿ“Š Processing Results:") + print("\n๐Ÿ“Š Processing Results:") print("=" * 60) # Sort results by task_id for consistent display @@ -213,12 +203,12 @@ async def main(): for result in sorted_results: preview = result.text[:30] + "..." if len(result.text) > 30 else result.text - preview = preview.replace('\n', ' ').strip() or '(empty)' + preview = preview.replace("\n", " ").strip() or "(empty)" print(f"โœ… {result.task_id}: '{preview}' -> {result.word_count} words, {result.char_count} chars") # Step 6: Display summary summary = orchestrator.get_summary() - print(f"\n๐Ÿ“ˆ Summary:") + print("\n๐Ÿ“ˆ Summary:") print("=" * 60) print(f"๐Ÿ“„ Total texts processed: {summary['total_texts']}") print(f"๐Ÿ“ Total words: {summary['total_words']}") @@ -226,7 +216,7 @@ async def main(): print(f"๐Ÿ“Š Average words per text: {summary['average_words_per_text']}") print(f"๐Ÿ“ Average characters per text: {summary['average_characters_per_text']}") - print(f"\n๐Ÿ Processing complete!") + print("\n๐Ÿ Processing complete!") if __name__ == "__main__": diff --git a/python/samples/getting_started/workflow/step_09b_sub_workflow_request_interception.py b/python/samples/getting_started/workflow/step_09b_sub_workflow_request_interception.py index 21f8b65ad7..a3933c6c9e 100644 --- a/python/samples/getting_started/workflow/step_09b_sub_workflow_request_interception.py +++ b/python/samples/getting_started/workflow/step_09b_sub_workflow_request_interception.py @@ -127,11 +127,19 @@ class EmailValidator(Executor): ) -> None: """Handle domain check response from RequestInfo with correlation.""" approved = bool(response.data) - domain = response.original_request.domain if (hasattr(response, 'original_request') and response.original_request) else "unknown" + domain = ( + response.original_request.domain + if (hasattr(response, "original_request") and response.original_request) + else "unknown" + ) print(f"๐Ÿ“ฌ Sub-workflow received domain response for '{domain}': {approved}") - + # Find the corresponding email using the request_id - request_id = response.original_request.request_id if (hasattr(response, 'original_request') and response.original_request) else None + request_id = ( + response.original_request.request_id + if (hasattr(response, "original_request") and response.original_request) + else None + ) if request_id and request_id in self._pending_emails: email = self._pending_emails.pop(request_id) # Remove from pending result = ValidationResult( @@ -146,6 +154,7 @@ class EmailValidator(Executor): # 3. Implement the parent workflow with request interception class SmartEmailOrchestrator(Executor): """Parent orchestrator that can intercept domain checks.""" + approved_domains: set[str] = set() def __init__(self, approved_domains: set[str] | None = None): @@ -177,7 +186,7 @@ class SmartEmailOrchestrator(Executor): print(f"โœ… Domain '{request.domain}' is pre-approved locally!") return RequestResponse[DomainCheckRequest, bool].handled(True) print(f"โ“ Domain '{request.domain}' unknown, forwarding to external service...") - return RequestResponse.forward() + return RequestResponse[DomainCheckRequest, bool].forward() @handler async def collect_result(self, result: ValidationResult, ctx: WorkflowContext[None]) -> None: @@ -196,7 +205,7 @@ async def run_example() -> None: """Run the sub-workflow example.""" print("๐Ÿš€ Setting up sub-workflow with request interception...") print() - + # 4. Build the sub-workflow email_validator = EmailValidator() # Match the target_id used in EmailValidator ("email_request_info") @@ -262,12 +271,12 @@ async def run_example() -> None: print("\n๐ŸŽฏ All requests were intercepted and handled locally!") # 10. Display final summary - print(f"\n๐Ÿ“Š Final Results Summary:") + print("\n๐Ÿ“Š Final Results Summary:") print("=" * 60) for result in orchestrator.results: status = "โœ… VALID" if result.is_valid else "โŒ INVALID" print(f"{status} {result.email}: {result.reason}") - + print(f"\n๐Ÿ Processed {len(orchestrator.results)} emails total") diff --git a/python/samples/getting_started/workflow/step_09c_sub_workflow_parallel_requests.py b/python/samples/getting_started/workflow/step_09c_sub_workflow_parallel_requests.py index c94cfd79b2..02892d7f0a 100644 --- a/python/samples/getting_started/workflow/step_09c_sub_workflow_parallel_requests.py +++ b/python/samples/getting_started/workflow/step_09c_sub_workflow_parallel_requests.py @@ -143,7 +143,10 @@ class ResourceRequester(Executor): # Send to parent workflow for interception - not to target_id await ctx.send_message(request) elif req_type == "policy": - print(f" ๐Ÿ›ก๏ธ Checking policy: {req_data.get('type', 'cpu')} x{req_data.get('amount', 1)} ({req_data.get('policy_type', 'quota')})") + print( + f" ๐Ÿ›ก๏ธ Checking policy: {req_data.get('type', 'cpu')} x{req_data.get('amount', 1)} " + f"({req_data.get('policy_type', 'quota')})" + ) request = PolicyCheckRequest( resource_type=req_data.get("type", "cpu"), amount=req_data.get("amount", 1), @@ -162,7 +165,8 @@ class ResourceRequester(Executor): if response.data: source_icon = "๐Ÿช" if response.data.source == "cache" else "๐ŸŒ" print( - f"๐Ÿ“ฆ {source_icon} Sub-workflow received: {response.data.allocated} {response.data.resource_type} from {response.data.source}" + f"๐Ÿ“ฆ {source_icon} Sub-workflow received: {response.data.allocated} {response.data.resource_type} " + f"from {response.data.source}" ) if self._collect_results(): # Emit completion event and send RequestFinished to the parent workflow. @@ -175,7 +179,10 @@ class ResourceRequester(Executor): """Handle policy check response.""" if response.data: status_icon = "โœ…" if response.data.approved else "โŒ" - print(f"๐Ÿ›ก๏ธ {status_icon} Sub-workflow received policy response: {response.data.approved} - {response.data.reason}") + print( + f"๐Ÿ›ก๏ธ {status_icon} Sub-workflow received policy response: " + f"{response.data.approved} - {response.data.reason}" + ) if self._collect_results(): # Emit completion event and send RequestFinished to the parent workflow. await ctx.add_event(WorkflowCompletedEvent(RequestFinished())) @@ -218,7 +225,7 @@ class ResourceCache(Executor): # Cache miss - forward to external print(f" โŒ Cache miss: need {request.amount}, have {available} {request.resource_type}") - return RequestResponse.forward() + return RequestResponse[ResourceRequest, ResourceResponse].forward() @handler async def collect_result( @@ -228,7 +235,8 @@ class ResourceCache(Executor): if response.data and response.data.source != "cache": # Don't double-count our own results self.results.append(response.data) print( - f"๐Ÿช ๐ŸŒ Cache received external response: {response.data.allocated} {response.data.resource_type} from {response.data.source}" + f"๐Ÿช ๐ŸŒ Cache received external response: {response.data.allocated} {response.data.resource_type} " + f"from {response.data.source}" ) @@ -265,11 +273,11 @@ class PolicyEngine(Executor): return RequestResponse[PolicyCheckRequest, PolicyResponse].handled(response) # Exceeds quota - forward to external for review print(f" โŒ Policy exceeds quota: {request.amount} > {quota_limit}, forwarding to external") - return RequestResponse.forward() + return RequestResponse[PolicyCheckRequest, PolicyResponse].forward() # Unknown policy type - forward to external print(f" โ“ Unknown policy type: {request.policy_type}, forwarding") - return RequestResponse.forward() + return RequestResponse[PolicyCheckRequest, PolicyResponse].forward() @handler async def collect_policy_result( @@ -292,19 +300,22 @@ class Coordinator(Executor): @handler async def handle_completion(self, completion: RequestFinished, ctx: WorkflowContext[None]) -> None: - """Handle sub-workflow completion. It comes from the sub-workflow emitted WorkflowCompletionEvent's data field.""" - print(f"๐ŸŽฏ Main workflow received completion.") + """Handle sub-workflow completion. + + It comes from the sub-workflow emitted WorkflowCompletionEvent's data field. + """ + print("๐ŸŽฏ Main workflow received completion.") async def main() -> None: """Demonstrate parallel request interception patterns.""" print("๐Ÿš€ Starting Sub-Workflow Parallel Request Interception Demo...") print("=" * 60) - - # 5. Create the sub-workflow + + # 5. Create the sub-workflow resource_requester = ResourceRequester() sub_request_info = RequestInfoExecutor(id="sub_request_info") - + sub_workflow = ( WorkflowBuilder() .set_start_executor(resource_requester) @@ -353,7 +364,10 @@ async def main() -> None: print(f"๐Ÿงช Testing with {len(test_requests)} mixed requests:") for i, req in enumerate(test_requests, 1): req_icon = "๐Ÿ“ฆ" if req["request_type"] == "resource" else "๐Ÿ›ก๏ธ" - print(f" {i}. {req_icon} {req['type']} x{req['amount']} ({req.get('priority', req.get('policy_type', 'default'))})") + print( + f" {i}. {req_icon} {req['type']} x{req['amount']} " + f"({req.get('priority', req.get('policy_type', 'default'))})" + ) print("=" * 70) # 8. Run the workflow @@ -398,11 +412,11 @@ async def main() -> None: status_icon = "โœ…" if result.approved else "โŒ" print(f" {status_icon} Approved: {result.approved} - {result.reason}") - print(f"\n๐Ÿ’พ Final Cache State:") + print("\n๐Ÿ’พ Final Cache State:") for resource, amount in cache.cache.items(): print(f" ๐Ÿ“ฆ {resource}: {amount} remaining") - print(f"\n๐Ÿ“ˆ Summary:") + print("\n๐Ÿ“ˆ Summary:") print(f" ๐ŸŽฏ Total requests: {len(test_requests)}") print(f" ๐Ÿช Resource requests handled: {len(cache.results)}") print(f" ๐Ÿ›ก๏ธ Policy requests handled: {len(policy.results)}") @@ -410,5 +424,6 @@ async def main() -> None: print("\n" + "=" * 70) + if __name__ == "__main__": asyncio.run(main())