# Copyright (c) Microsoft. All rights reserved. import tempfile from dataclasses import dataclass from typing import Any import pytest from agent_framework.workflow import ( Executor, FileCheckpointStorage, RequestInfoEvent, RequestInfoExecutor, RequestInfoMessage, RequestResponse, WorkflowBuilder, WorkflowCompletedEvent, WorkflowContext, WorkflowEvent, handler, ) from agent_framework_workflow import Message @dataclass class NumberMessage: """A mock message for testing purposes.""" data: int class IncrementExecutor(Executor): """An executor that increments message data by a specified amount for testing purposes.""" limit: int = 10 increment: int = 1 @handler async def mock_handler(self, message: NumberMessage, ctx: WorkflowContext[NumberMessage]) -> None: if message.data < self.limit: await ctx.send_message(NumberMessage(data=message.data + self.increment)) else: await ctx.add_event(WorkflowCompletedEvent(data=message.data)) class AggregatorExecutor(Executor): """A mock executor that aggregates results from multiple executors.""" @handler async def mock_handler(self, messages: list[NumberMessage], ctx: WorkflowContext[Any]) -> None: # This mock simply returns the data incremented by 1 await ctx.add_event(WorkflowCompletedEvent(data=sum(msg.data for msg in messages))) @dataclass class ApprovalMessage: """A mock message for approval requests.""" approved: bool class MockExecutorRequestApproval(Executor): """A mock executor that simulates a request for approval.""" @handler async def mock_handler_a(self, message: NumberMessage, ctx: WorkflowContext[RequestInfoMessage]) -> None: """A mock handler that requests approval.""" await ctx.set_shared_state(self.id, message.data) await ctx.send_message(RequestInfoMessage()) @handler async def mock_handler_b( self, message: RequestResponse[RequestInfoMessage, ApprovalMessage], ctx: WorkflowContext[NumberMessage] ) -> None: """A mock handler that processes the approval response.""" data = await ctx.get_shared_state(self.id) assert isinstance(message.data, ApprovalMessage) if message.data.approved: await ctx.add_event(WorkflowCompletedEvent(data=data)) else: await ctx.send_message(NumberMessage(data=data)) async def test_workflow_run_streaming(): """Test the workflow run stream.""" executor_a = IncrementExecutor(id="executor_a") executor_b = IncrementExecutor(id="executor_b") workflow = ( WorkflowBuilder() .set_start_executor(executor_a) .add_edge(executor_a, executor_b) .add_edge(executor_b, executor_a) .build() ) result: int | None = None async for event in workflow.run_streaming(NumberMessage(data=0)): assert isinstance(event, WorkflowEvent) if isinstance(event, WorkflowCompletedEvent): result = event.data assert result is not None and result == 10 async def test_workflow_run_stream_not_completed(): """Test the workflow run stream.""" executor_a = IncrementExecutor(id="executor_a") executor_b = IncrementExecutor(id="executor_b") workflow = ( WorkflowBuilder() .set_start_executor(executor_a) .add_edge(executor_a, executor_b) .add_edge(executor_b, executor_a) .set_max_iterations(5) .build() ) with pytest.raises(RuntimeError): async for _ in workflow.run_streaming(NumberMessage(data=0)): pass async def test_workflow_run(): """Test the workflow run.""" executor_a = IncrementExecutor(id="executor_a") executor_b = IncrementExecutor(id="executor_b") workflow = ( WorkflowBuilder() .set_start_executor(executor_a) .add_edge(executor_a, executor_b) .add_edge(executor_b, executor_a) .build() ) events = await workflow.run(NumberMessage(data=0)) completed_event = events.get_completed_event() assert isinstance(completed_event, WorkflowCompletedEvent) assert completed_event.data == 10 async def test_workflow_run_not_completed(): """Test the workflow run.""" executor_a = IncrementExecutor(id="executor_a") executor_b = IncrementExecutor(id="executor_b") workflow = ( WorkflowBuilder() .set_start_executor(executor_a) .add_edge(executor_a, executor_b) .add_edge(executor_b, executor_a) .set_max_iterations(5) .build() ) with pytest.raises(RuntimeError): await workflow.run(NumberMessage(data=0)) async def test_workflow_send_responses_streaming(): """Test the workflow run with approval.""" executor_a = IncrementExecutor(id="executor_a") executor_b = MockExecutorRequestApproval(id="executor_b") request_info_executor = RequestInfoExecutor() workflow = ( WorkflowBuilder() .set_start_executor(executor_a) .add_edge(executor_a, executor_b) .add_edge(executor_b, executor_a) .add_edge(executor_b, request_info_executor) .add_edge(request_info_executor, executor_b) .build() ) request_info_event: RequestInfoEvent | None = None async for event in workflow.run_streaming(NumberMessage(data=0)): if isinstance(event, RequestInfoEvent): request_info_event = event assert request_info_event is not None result: int | None = None async for event in workflow.send_responses_streaming({ request_info_event.request_id: ApprovalMessage(approved=True) }): if isinstance(event, WorkflowCompletedEvent): result = event.data assert result is not None and result == 1 # The data should be incremented by 1 from the initial message async def test_workflow_send_responses(): """Test the workflow run with approval.""" executor_a = IncrementExecutor(id="executor_a") executor_b = MockExecutorRequestApproval(id="executor_b") request_info_executor = RequestInfoExecutor() workflow = ( WorkflowBuilder() .set_start_executor(executor_a) .add_edge(executor_a, executor_b) .add_edge(executor_b, executor_a) .add_edge(executor_b, request_info_executor) .add_edge(request_info_executor, executor_b) .build() ) events = await workflow.run(NumberMessage(data=0)) request_info_events = events.get_request_info_events() assert len(request_info_events) == 1 result = await workflow.send_responses({request_info_events[0].request_id: ApprovalMessage(approved=True)}) completed_event = result.get_completed_event() assert isinstance(completed_event, WorkflowCompletedEvent) assert completed_event.data == 1 # The data should be incremented by 1 from the initial message async def test_fan_out(): """Test a fan-out workflow.""" executor_a = IncrementExecutor(id="executor_a") executor_b = IncrementExecutor(id="executor_b", limit=1) executor_c = IncrementExecutor(id="executor_c", limit=2) # This executor will not complete the workflow workflow = ( WorkflowBuilder().set_start_executor(executor_a).add_fan_out_edges(executor_a, [executor_b, executor_c]).build() ) events = await workflow.run(NumberMessage(data=0)) # Each executor will emit two events: ExecutorInvokeEvent and ExecutorCompletedEvent # executor_b will also emit a WorkflowCompletedEvent assert len(events) == 7 completed_event = events.get_completed_event() assert completed_event is not None and completed_event.data == 1 async def test_fan_out_multiple_completed_events(): """Test a fan-out workflow with multiple completed events.""" executor_a = IncrementExecutor(id="executor_a") executor_b = IncrementExecutor(id="executor_b", limit=1) executor_c = IncrementExecutor(id="executor_c", limit=1) workflow = ( WorkflowBuilder().set_start_executor(executor_a).add_fan_out_edges(executor_a, [executor_b, executor_c]).build() ) events = await workflow.run(NumberMessage(data=0)) # Each executor will emit two events: ExecutorInvokeEvent and ExecutorCompletedEvent # executor_a and executor_b will also emit a WorkflowCompletedEvent assert len(events) == 8 with pytest.raises(ValueError): events.get_completed_event() async def test_fan_in(): """Test a fan-in workflow.""" executor_a = IncrementExecutor(id="executor_a") executor_b = IncrementExecutor(id="executor_b") executor_c = IncrementExecutor(id="executor_c") aggregator = AggregatorExecutor(id="aggregator") workflow = ( WorkflowBuilder() .set_start_executor(executor_a) .add_fan_out_edges(executor_a, [executor_b, executor_c]) .add_fan_in_edges([executor_b, executor_c], aggregator) .build() ) events = await workflow.run(NumberMessage(data=0)) # Each executor will emit two events: ExecutorInvokeEvent and ExecutorCompletedEvent # aggregator will also emit a WorkflowCompletedEvent assert len(events) == 9 completed_event = events.get_completed_event() assert completed_event is not None and completed_event.data == 4 @pytest.fixture def simple_executor() -> Executor: class SimpleExecutor(Executor): @handler async def handle_message(self, message: Message, context: WorkflowContext[None]) -> None: pass return SimpleExecutor(id="test_executor") async def test_workflow_with_checkpointing_enabled(simple_executor: Executor): """Test that a workflow can be built with checkpointing enabled.""" with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) # Build workflow with checkpointing - should not raise any errors workflow = ( WorkflowBuilder() .add_edge(simple_executor, simple_executor) # Self-loop to satisfy graph requirements .set_start_executor(simple_executor) .with_checkpointing(storage) .build() ) # Verify workflow was created and can run test_message = Message(data="test message", source_id="test", target_id=None) result = await workflow.run(test_message) assert result is not None async def test_workflow_checkpointing_not_enabled_for_external_restore(simple_executor: Executor): """Test that external checkpoint restoration fails when workflow doesn't support checkpointing.""" # Build workflow WITHOUT checkpointing workflow = ( WorkflowBuilder() .add_edge(simple_executor, simple_executor) # Self-loop to satisfy graph requirements .set_start_executor(simple_executor) .build() ) # Attempt to restore from checkpoint without providing external storage should fail try: [event async for event in workflow.run_streaming_from_checkpoint("fake-checkpoint-id")] raise AssertionError("Expected ValueError to be raised") except ValueError as e: assert "Cannot restore from checkpoint" in str(e) assert "either provide checkpoint_storage parameter" in str(e) async def test_workflow_run_stream_from_checkpoint_no_checkpointing_enabled(simple_executor: Executor): # Build workflow WITHOUT checkpointing workflow = ( WorkflowBuilder() .add_edge(simple_executor, simple_executor) # Self-loop to satisfy graph requirements .set_start_executor(simple_executor) .build() ) # Attempt to run from checkpoint should fail try: async for _ in workflow.run_streaming_from_checkpoint("fake_checkpoint_id"): pass raise AssertionError("Expected ValueError to be raised") except ValueError as e: assert "Cannot restore from checkpoint" in str(e) assert "either provide checkpoint_storage parameter" in str(e) async def test_workflow_run_stream_from_checkpoint_invalid_checkpoint(simple_executor: Executor): """Test that attempting to restore from a non-existent checkpoint fails appropriately.""" with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) # Build workflow with checkpointing workflow = ( WorkflowBuilder() .add_edge(simple_executor, simple_executor) # Self-loop to satisfy graph requirements .set_start_executor(simple_executor) .with_checkpointing(storage) .build() ) # Attempt to run from non-existent checkpoint should fail try: async for _ in workflow.run_streaming_from_checkpoint("nonexistent_checkpoint_id"): pass raise AssertionError("Expected RuntimeError to be raised") except RuntimeError as e: assert "Failed to restore from checkpoint" in str(e) async def test_workflow_run_stream_from_checkpoint_with_external_storage(simple_executor: Executor): """Test that external checkpoint storage can be provided for restoration.""" with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) # Create a test checkpoint manually in storage from agent_framework_workflow._checkpoint import WorkflowCheckpoint test_checkpoint = WorkflowCheckpoint( workflow_id="test-workflow", messages={}, shared_state={}, executor_states={}, iteration_count=0, max_iterations=100, ) checkpoint_id = await storage.save_checkpoint(test_checkpoint) # Create a workflow WITHOUT checkpointing workflow_without_checkpointing = ( WorkflowBuilder().add_edge(simple_executor, simple_executor).set_start_executor(simple_executor).build() ) # Resume from checkpoint using external storage parameter try: events: list[WorkflowEvent] = [] async for event in workflow_without_checkpointing.run_streaming_from_checkpoint( checkpoint_id, checkpoint_storage=storage ): events.append(event) if len(events) >= 2: # Limit to avoid infinite loops break except Exception: # Expected since we have minimal setup, but method should accept the parameters pass async def test_workflow_run_from_checkpoint_non_streaming(simple_executor: Executor): """Test the non-streaming run_from_checkpoint method.""" with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) # Create a test checkpoint manually in storage from agent_framework_workflow._checkpoint import WorkflowCheckpoint test_checkpoint = WorkflowCheckpoint( workflow_id="test-workflow", messages={}, shared_state={}, executor_states={}, iteration_count=0, max_iterations=100, ) checkpoint_id = await storage.save_checkpoint(test_checkpoint) # Build workflow with checkpointing workflow = ( WorkflowBuilder() .add_edge(simple_executor, simple_executor) .set_start_executor(simple_executor) .with_checkpointing(storage) .build() ) # Test non-streaming run_from_checkpoint method result = await workflow.run_from_checkpoint(checkpoint_id) assert isinstance(result, list) # Should return WorkflowRunResult which extends list assert hasattr(result, "get_completed_event") # Should have WorkflowRunResult methods async def test_workflow_run_stream_from_checkpoint_with_responses(simple_executor: Executor): """Test that run_streaming_from_checkpoint accepts responses parameter.""" with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) # Create a test checkpoint manually in storage from agent_framework_workflow._checkpoint import WorkflowCheckpoint test_checkpoint = WorkflowCheckpoint( workflow_id="test-workflow", messages={}, shared_state={}, executor_states={}, iteration_count=0, max_iterations=100, ) checkpoint_id = await storage.save_checkpoint(test_checkpoint) # Build workflow with checkpointing workflow = ( WorkflowBuilder() .add_edge(simple_executor, simple_executor) .set_start_executor(simple_executor) .with_checkpointing(storage) .build() ) # Test that run_stream_from_checkpoint accepts responses parameter responses = {"request_123": {"data": "test_response"}} try: events: list[WorkflowEvent] = [] async for event in workflow.run_streaming_from_checkpoint(checkpoint_id, responses=responses): events.append(event) if len(events) >= 2: # Limit to avoid infinite loops break except Exception: # Expected since we have minimal setup, but method should accept the parameters pass @dataclass class StateTrackingMessage: """A message that tracks state for testing context reset behavior.""" data: str run_id: str class StateTrackingExecutor(Executor): """An executor that tracks state in shared state to test context reset behavior.""" @handler async def handle_message(self, message: StateTrackingMessage, ctx: WorkflowContext[Any]) -> None: """Handle the message and track it in shared state.""" # Get existing messages from shared state try: existing_messages = await ctx.get_shared_state("processed_messages") except KeyError: existing_messages = [] # Record this message message_record = f"{message.run_id}:{message.data}" existing_messages.append(message_record) # type: ignore # Update shared state await ctx.set_shared_state("processed_messages", existing_messages) # Complete workflow with current shared state await ctx.add_event(WorkflowCompletedEvent(data=existing_messages.copy())) # type: ignore async def test_workflow_multiple_runs_no_state_collision(): """Test that running the same workflow instance multiple times doesn't have state collision.""" with tempfile.TemporaryDirectory() as temp_dir: storage = FileCheckpointStorage(temp_dir) # Create executor that tracks state in shared state state_executor = StateTrackingExecutor(id="state_executor") # Build workflow with checkpointing workflow = ( WorkflowBuilder() .add_edge(state_executor, state_executor) # Self-loop to satisfy graph requirements .set_start_executor(state_executor) .with_checkpointing(storage) .build() ) # Run 1: Should only see messages from run 1 result1 = await workflow.run(StateTrackingMessage(data="message1", run_id="run1")) completed1 = result1.get_completed_event() assert completed1 is not None assert completed1.data == ["run1:message1"] # Run 2: Should only see messages from run 2, not run 1 result2 = await workflow.run(StateTrackingMessage(data="message2", run_id="run2")) completed2 = result2.get_completed_event() assert completed2 is not None assert completed2.data == ["run2:message2"] # Should NOT contain run1 data # Run 3: Should only see messages from run 3 result3 = await workflow.run(StateTrackingMessage(data="message3", run_id="run3")) completed3 = result3.get_completed_event() assert completed3 is not None assert completed3.data == ["run3:message3"] # Should NOT contain run1 or run2 data # Verify that each run only processed its own message # This confirms that the checkpointable context properly resets between runs assert completed1.data != completed2.data assert completed2.data != completed3.data assert completed1.data != completed3.data async def test_comprehensive_edge_groups_workflow(): """Test a workflow that uses SwitchCaseEdgeGroup, FanOutEdgeGroup, and FanInEdgeGroup.""" from agent_framework_workflow._edge import Case, Default # Create 6 executors for different roles with different increment values router = IncrementExecutor(id="router", limit=1000, increment=1) # Increment by 1 processor_a = IncrementExecutor(id="proc_a", limit=1000, increment=1) # Increment by 1 processor_b = IncrementExecutor(id="proc_b", limit=1000, increment=2) # Increment by 2 (different from proc_a) fanout_hub = IncrementExecutor(id="fanout_hub", limit=1000, increment=1) # Increment by 1 parallel_1 = IncrementExecutor(id="parallel_1", limit=1000, increment=3) # Increment by 3 parallel_2 = IncrementExecutor( id="parallel_2", limit=1000, increment=5 ) # Increment by 5 (different from parallel_1) aggregator = AggregatorExecutor(id="aggregator") # Combines results from parallel processors # Build workflow with different edge group types: # 1. SwitchCase: router -> (proc_a if data < 5, else proc_b) # 2. Direct edge: proc_a -> fanout_hub, proc_b -> fanout_hub # 3. FanOut: fanout_hub -> [parallel_1, parallel_2] # 4. FanIn: [parallel_1, parallel_2] -> aggregator workflow = ( WorkflowBuilder() .set_start_executor(router) # Switch-case routing based on message data .add_switch_case_edge_group( router, [ Case(condition=lambda msg: msg.data < 5, target=processor_a), Default(target=processor_b), ], ) # Both processors send to fanout hub .add_edge(processor_a, fanout_hub) .add_edge(processor_b, fanout_hub) # Fan out to parallel processors .add_fan_out_edges(fanout_hub, [parallel_1, parallel_2]) # Fan in to aggregator .add_fan_in_edges([parallel_1, parallel_2], aggregator) .build() ) # Test with small number (should go through processor_a) # router(2->3) -> switch routes to proc_a -> proc_a(3->4) -> fanout_hub(4->5) # -> [parallel_1(5->8), parallel_2(5->10)] -> aggregator(8+10=18) events_small = await workflow.run(NumberMessage(data=2)) completed_small = events_small.get_completed_event() assert completed_small is not None assert completed_small.data == 18 # Exact expected result: 8+10 from parallel processors # Test with large number (should go through processor_b) # router(8->9) -> switch routes to proc_b -> proc_b(9->11) -> fanout_hub(11->12) # -> [parallel_1(12->15), parallel_2(12->17)] -> aggregator(15+17=32) events_large = await workflow.run(NumberMessage(data=8)) completed_large = events_large.get_completed_event() assert completed_large is not None assert completed_large.data == 32 # Exact expected result: 15+17 from parallel processors # The key verification is that we successfully executed a workflow using all three edge group types # and that both switch-case paths work (small vs large numbers) # Verify we had multiple events indicating complex execution path assert len(events_small) >= 6 # Should have multiple executors involved assert len(events_large) >= 6 # Verify different paths were taken by checking exact results assert completed_small.data == 18, f"Small number path should result in 18, got {completed_small.data}" assert completed_large.data == 32, f"Large number path should result in 32, got {completed_large.data}" assert completed_small.data != completed_large.data, "Different paths should produce different results" # Both tests should complete successfully, proving all edge group types work # Additional verification: check that the workflow contains the expected edge group types edge_groups = workflow.edge_groups has_switch_case = any(edge_group.__class__.__name__ == "SwitchCaseEdgeGroup" for edge_group in edge_groups) has_fan_out = any(edge_group.__class__.__name__ == "FanOutEdgeGroup" for edge_group in edge_groups) has_fan_in = any(edge_group.__class__.__name__ == "FanInEdgeGroup" for edge_group in edge_groups) assert has_switch_case, "Workflow should contain SwitchCaseEdgeGroup" assert has_fan_out, "Workflow should contain FanOutEdgeGroup" assert has_fan_in, "Workflow should contain FanInEdgeGroup" async def test_workflow_with_simple_cycle_and_exit_condition(): """Test a simpler workflow with a cycle that has a clear exit condition.""" # Create a simple cycle: A -> B -> A, with A having an exit condition executor_a = IncrementExecutor(id="exec_a", limit=6, increment=2) # Exit when data >= 6 executor_b = IncrementExecutor(id="exec_b", limit=1000, increment=1) # Never exit, just increment # Simple cycle: A -> B -> A, A exits when limit reached workflow = ( WorkflowBuilder() .set_start_executor(executor_a) .add_edge(executor_a, executor_b) # A -> B .add_edge(executor_b, executor_a) # B -> A (creates cycle) .build() ) # Test the cycle # Expected: exec_a(2->4) -> exec_b(4->5) -> exec_a(5->7, completes because 7 >= 6) events = await workflow.run(NumberMessage(data=2)) completed_event = events.get_completed_event() assert completed_event is not None assert ( completed_event.data is not None and completed_event.data >= 6 ) # Should complete when executor_a reaches its limit # Verify cycling occurred (should have events from both executors) # Check for ExecutorInvokeEvent and ExecutorCompletedEvent types that have executor_id from agent_framework_workflow._events import ExecutorCompletedEvent, ExecutorInvokeEvent executor_events = [e for e in events if isinstance(e, (ExecutorInvokeEvent, ExecutorCompletedEvent))] executor_ids = {e.executor_id for e in executor_events} assert "exec_a" in executor_ids, "Should have events from executor A" assert "exec_b" in executor_ids, "Should have events from executor B" # Should have multiple events due to cycling assert len(events) >= 4, f"Expected at least 4 events due to cycling, got {len(events)}"