# Copyright (c) Microsoft. All rights reserved.

from dataclasses import dataclass
from typing import Any
from unittest.mock import patch

import pytest
from agent_framework.workflow import Executor, WorkflowContext, handler

from agent_framework_workflow._edge import (
    Edge,
    FanInEdgeGroup,
    FanOutEdgeGroup,
    SingleEdgeGroup,
    SwitchCaseEdgeGroup,
    SwitchCaseEdgeGroupCase,
    SwitchCaseEdgeGroupDefault,
)
from agent_framework_workflow._edge_runner import create_edge_runner
from agent_framework_workflow._runner_context import InProcRunnerContext, Message
from agent_framework_workflow._shared_state import SharedState


@dataclass
class MockMessage:
    """A mock message for testing purposes."""

    data: Any


@dataclass
class MockMessageSecondary:
    """A secondary mock message type, not accepted by MockExecutor's handler."""

    data: Any


class MockExecutor(Executor):
    """A mock executor whose only handler accepts MockMessage.

    It records how many times the handler ran and the last message received,
    so tests can check whether an edge runner delivered a message to it.
    """

    call_count: int = 0
    last_message: Any = None

    @handler
    async def mock_handler(self, message: MockMessage, ctx: WorkflowContext) -> None:
        """Count the invocation and remember the received message."""
        self.call_count += 1
        self.last_message = message


class MockExecutorSecondary(Executor):
    """A mock executor whose only handler accepts MockMessageSecondary."""

    call_count: int = 0
    last_message: Any = None

    @handler
    async def mock_handler_secondary(self, message: MockMessageSecondary, ctx: WorkflowContext) -> None:
        """Count the invocation and remember the received message."""
        self.call_count += 1
        self.last_message = message


class MockAggregator(Executor):
    """A mock fan-in target whose handler accepts a list of MockMessage."""

    call_count: int = 0
    last_message: Any = None

    @handler
    async def mock_aggregator_handler(self, message: list[MockMessage], ctx: WorkflowContext) -> None:
        """Count the invocation and remember the received list of messages."""
        self.call_count += 1
        self.last_message = message


# region Edge


def test_create_edge():
    """Test creating an edge with a source and target executor."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    edge = Edge(source_id=source.id, target_id=target.id)

    assert edge.source_id == "source_executor"
    assert edge.target_id == "target_executor"
    assert edge.id == f"{edge.source_id}{Edge.ID_SEPARATOR}{edge.target_id}"


def test_edge_can_handle():
    """Test that an edge created without a condition routes a message."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    edge = Edge(source_id=source.id, target_id=target.id)

    assert edge.should_route(MockMessage(data="test"))


# endregion Edge

# region SingleEdgeGroup


def test_single_edge_group():
    """Test creating a single edge group."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id)

    assert edge_group.source_executor_ids == [source.id]
    assert edge_group.target_executor_ids == [target.id]
    assert edge_group.edges[0].source_id == "source_executor"
    assert edge_group.edges[0].target_id == "target_executor"


def test_single_edge_group_with_condition():
    """Test creating a single edge group with a condition."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id, condition=lambda x: x.data == "test")

    assert edge_group.source_executor_ids == [source.id]
    assert edge_group.target_executor_ids == [target.id]
    assert edge_group.edges[0].source_id == "source_executor"
    assert edge_group.edges[0].target_id == "target_executor"
    assert edge_group.edges[0]._condition is not None  # type: ignore


async def test_single_edge_group_send_message() -> None:
    """Test sending a message through a single edge runner."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")
    executors: dict[str, Executor] = {source.id: source, target.id: target}
    edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id)

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is True
    # Delivery should actually invoke the target's handler, as in the fan-out tests.
    assert target.call_count == 1


async def test_single_edge_group_send_message_with_target() -> None:
    """Test sending a message through a single edge runner with an explicit, matching target."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")
    executors: dict[str, Executor] = {source.id: source, target.id: target}
    edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id)

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source.id, target_id=target.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is True
    assert target.call_count == 1


async def test_single_edge_group_send_message_with_invalid_target() -> None:
    """Test that a single edge runner rejects a message addressed to a target it does not connect to."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")
    executors: dict[str, Executor] = {source.id: source, target.id: target}
    edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id)

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source.id, target_id="invalid_target")

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False
    assert target.call_count == 0


async def test_single_edge_group_send_message_with_invalid_data() -> None:
    """Test sending a message through a single edge runner with invalid data."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")
    executors: dict[str, Executor] = {source.id: source, target.id: target}
    edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id)

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    # A plain string is not a MockMessage, so the target's handler cannot accept it.
    data = "invalid_data"
    message = Message(data=data, source_id=source.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False
    assert target.call_count == 0


# endregion SingleEdgeGroup

# region FanOutEdgeGroup


def test_source_edge_group():
    """Test creating a fan-out group."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")

    edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id])

    assert edge_group.source_executor_ids == [source.id]
    assert edge_group.target_executor_ids == [target1.id, target2.id]
    assert len(edge_group.edges) == 2
    assert edge_group.edges[0].source_id == "source_executor"
    assert edge_group.edges[0].target_id == "target_executor_1"
    assert edge_group.edges[1].source_id == "source_executor"
    assert edge_group.edges[1].target_id == "target_executor_2"


def test_source_edge_group_invalid_number_of_targets() -> None:
    """Test creating a fan-out group with an invalid number of targets."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    with pytest.raises(ValueError, match="FanOutEdgeGroup must contain at least two targets"):
        FanOutEdgeGroup(source_id=source.id, target_ids=[target.id])


async def test_source_edge_group_send_message() -> None:
    """Test sending a message through a fan-out edge runner."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}
    edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id])

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is True
    assert target1.call_count == 1
    assert target2.call_count == 1


async def test_source_edge_group_send_message_with_target() -> None:
    """Test sending a message through a fan-out group with a target."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id])
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source.id, target_id=target1.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is True
    assert target1.call_count == 1
    assert target2.call_count == 0  # target2 should not be called since message targets target1


async def test_source_edge_group_send_message_with_invalid_target() -> None:
    """Test sending a message through a fan-out group with an invalid target."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id])
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source.id, target_id="invalid_target")

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False
    assert target1.call_count == 0
    assert target2.call_count == 0


async def test_source_edge_group_send_message_with_invalid_data() -> None:
    """Test sending a message through a fan-out group with invalid data."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id])
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = "invalid_data"
    message = Message(data=data, source_id=source.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False
    assert target1.call_count == 0
    assert target2.call_count == 0


async def test_source_edge_group_send_message_only_one_successful_send() -> None:
    """Test sending a message through a fan-out group where only one edge can handle the message."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutorSecondary(id="target_executor_2")
    edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id])
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is True
    assert target1.call_count == 1  # target1 can handle MockMessage
    assert target2.call_count == 0  # target2 (MockExecutorSecondary) cannot handle MockMessage


def test_source_edge_group_with_selection_func():
    """Test creating a fan-out (partitioning) edge group with a selection function."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")

    edge_group = FanOutEdgeGroup(
        source_id=source.id,
        target_ids=[target1.id, target2.id],
        selection_func=lambda data, target_ids: [target1.id],
    )

    # The selection function does not change the static topology: both edges still exist.
    assert edge_group.source_executor_ids == [source.id]
    assert edge_group.target_executor_ids == [target1.id, target2.id]
    assert len(edge_group.edges) == 2
    assert edge_group.edges[0].source_id == "source_executor"
    assert edge_group.edges[0].target_id == "target_executor_1"
    assert edge_group.edges[1].source_id == "source_executor"
    assert edge_group.edges[1].target_id == "target_executor_2"


async def test_source_edge_group_with_selection_func_send_message() -> None:
    """Test sending a message through a fan-out group with a selection function."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = FanOutEdgeGroup(
        source_id=source.id,
        target_ids=[target1.id, target2.id],
        selection_func=lambda data, target_ids: [target1.id, target2.id],
    )
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source.id)

    with patch("agent_framework_workflow._edge_runner.EdgeRunner._execute_on_target") as mock_send:
        success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is True
    assert mock_send.call_count == 2


async def test_source_edge_group_with_selection_func_send_message_with_invalid_selection_result() -> None:
    """Test sending a message through a fan-out group with a selection func with an invalid selection result."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = FanOutEdgeGroup(
        source_id=source.id,
        target_ids=[target1.id, target2.id],
        selection_func=lambda data, target_ids: [target1.id, "invalid_target"],
    )
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source.id)

    # A selection naming a target outside the group is a configuration error, not a routing miss.
    with pytest.raises(RuntimeError):
        await edge_runner.send_message(message, shared_state, ctx)


async def test_source_edge_group_with_selection_func_send_message_with_target() -> None:
    """Test sending a message through a fan-out group with a selection func with a target."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = FanOutEdgeGroup(
        source_id=source.id,
        target_ids=[target1.id, target2.id],
        selection_func=lambda data, target_ids: [target1.id, target2.id],
    )
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source.id, target_id=target1.id)

    with patch("agent_framework_workflow._edge_runner.EdgeRunner._execute_on_target") as mock_send:
        success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is True
    assert mock_send.call_count == 1
    # The first positional argument of _execute_on_target is the target executor id.
    assert mock_send.call_args[0][0] == target1.id


async def test_source_edge_group_with_selection_func_send_message_with_target_not_in_selection() -> None:
    """Test sending a message through a fan-out group with a selection func with a target not in the selection."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = FanOutEdgeGroup(
        source_id=source.id,
        target_ids=[target1.id, target2.id],
        selection_func=lambda data, target_ids: [target1.id],  # Only target1 will receive the message
    )
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source.id, target_id=target2.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False


async def test_source_edge_group_with_selection_func_send_message_with_invalid_data() -> None:
    """Test sending a message through a fan-out group with a selection func with invalid data."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = FanOutEdgeGroup(
        source_id=source.id,
        target_ids=[target1.id, target2.id],
        selection_func=lambda data, target_ids: [target1.id, target2.id],
    )
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = "invalid_data"
    message = Message(data=data, source_id=source.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False


async def test_source_edge_group_with_selection_func_send_message_with_target_invalid_data() -> None:
    """Test sending a message through a fan-out group with a selection func with a target and invalid data."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = FanOutEdgeGroup(
        source_id=source.id,
        target_ids=[target1.id, target2.id],
        selection_func=lambda data, target_ids: [target1.id, target2.id],
    )
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = "invalid_data"
    message = Message(data=data, source_id=source.id, target_id=target1.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False


# endregion FanOutEdgeGroup

# region FanInEdgeGroup


def test_target_edge_group():
    """Test creating a fan-in edge group."""
    source1 = MockExecutor(id="source_executor_1")
    source2 = MockExecutor(id="source_executor_2")
    target = MockAggregator(id="target_executor")

    edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id)

    assert edge_group.source_executor_ids == [source1.id, source2.id]
    assert edge_group.target_executor_ids == [target.id]
    assert len(edge_group.edges) == 2
    assert edge_group.edges[0].source_id == "source_executor_1"
    assert edge_group.edges[0].target_id == "target_executor"
    assert edge_group.edges[1].source_id == "source_executor_2"
    assert edge_group.edges[1].target_id == "target_executor"


def test_target_edge_group_invalid_number_of_sources():
    """Test creating a fan-in edge group with an invalid number of sources."""
    source = MockExecutor(id="source_executor")
    target = MockAggregator(id="target_executor")

    with pytest.raises(ValueError, match="FanInEdgeGroup must contain at least two sources"):
        FanInEdgeGroup(source_ids=[source.id], target_id=target.id)


async def test_target_edge_group_send_message_buffer() -> None:
    """Test that a fan-in edge runner buffers messages until every source has sent one."""
    source1 = MockExecutor(id="source_executor_1")
    source2 = MockExecutor(id="source_executor_2")
    target = MockAggregator(id="target_executor")
    edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id)
    executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")

    with patch("agent_framework_workflow._edge_runner.EdgeRunner._execute_on_target") as mock_send:
        success = await edge_runner.send_message(
            Message(data=data, source_id=source1.id),
            shared_state,
            ctx,
        )

        assert success is True
        assert mock_send.call_count == 0  # The message should be buffered and wait for the second source
        assert len(edge_runner._buffer[source1.id]) == 1  # type: ignore

        success = await edge_runner.send_message(
            Message(data=data, source_id=source2.id),
            shared_state,
            ctx,
        )
        assert success is True
        assert mock_send.call_count == 1  # The message should be sent now that both sources have sent their messages

        # Buffer should be cleared after sending
        assert not edge_runner._buffer  # type: ignore


async def test_target_edge_group_send_message_with_invalid_target() -> None:
    """Test sending a message through a fan-in edge group with an invalid target."""
    source1 = MockExecutor(id="source_executor_1")
    source2 = MockExecutor(id="source_executor_2")
    target = MockAggregator(id="target_executor")
    edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id)
    executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    message = Message(data=data, source_id=source1.id, target_id="invalid_target")

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False


async def test_target_edge_group_send_message_with_invalid_data() -> None:
    """Test sending a message through a fan-in edge group with invalid data."""
    source1 = MockExecutor(id="source_executor_1")
    source2 = MockExecutor(id="source_executor_2")
    target = MockAggregator(id="target_executor")
    edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id)
    executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = "invalid_data"
    message = Message(data=data, source_id=source1.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False


# endregion FanInEdgeGroup

# region SwitchCaseEdgeGroup


def test_switch_case_edge_group() -> None:
    """Test creating a switch case edge group and evaluating its selection function."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")

    edge_group = SwitchCaseEdgeGroup(
        source_id=source.id,
        cases=[
            SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id),
            SwitchCaseEdgeGroupDefault(target_id=target2.id),
        ],
    )

    assert edge_group.source_executor_ids == [source.id]
    assert edge_group.target_executor_ids == [target1.id, target2.id]
    assert len(edge_group.edges) == 2
    assert edge_group.edges[0].source_id == "source_executor"
    assert edge_group.edges[0].target_id == "target_executor_1"
    assert edge_group.edges[1].source_id == "source_executor"
    assert edge_group.edges[1].target_id == "target_executor_2"

    # The cases are compiled into a selection function: matching case first, default otherwise.
    assert edge_group._selection_func is not None  # type: ignore
    assert edge_group._selection_func(MockMessage(data=-1), [target1.id, target2.id]) == [target1.id]  # type: ignore
    assert edge_group._selection_func(MockMessage(data=1), [target1.id, target2.id]) == [target2.id]  # type: ignore


def test_switch_case_edge_group_invalid_number_of_cases():
    """Test that a switch case edge group rejects too few cases and a missing default case."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    with pytest.raises(
        ValueError, match=r"SwitchCaseEdgeGroup must contain at least two cases \(including the default case\)."
    ):
        SwitchCaseEdgeGroup(
            source_id=source.id,
            cases=[
                SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target.id),
            ],
        )

    with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."):
        SwitchCaseEdgeGroup(
            source_id=source.id,
            cases=[
                SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target.id),
                SwitchCaseEdgeGroupCase(condition=lambda x: x.data >= 0, target_id=target.id),
            ],
        )


def test_switch_case_edge_group_invalid_number_of_default_cases():
    """Test that a switch case edge group rejects more than one default case."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")

    with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."):
        SwitchCaseEdgeGroup(
            source_id=source.id,
            cases=[
                SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id),
                SwitchCaseEdgeGroupDefault(target_id=target2.id),
                SwitchCaseEdgeGroupDefault(target_id=target2.id),
            ],
        )


async def test_switch_case_edge_group_send_message() -> None:
    """Test that a switch case edge group routes each message to exactly one target."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = SwitchCaseEdgeGroup(
        source_id=source.id,
        cases=[
            SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id),
            SwitchCaseEdgeGroupDefault(target_id=target2.id),
        ],
    )
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    # A negative value matches the first case and should be routed to target1 only.
    data = MockMessage(data=-1)
    message = Message(data=data, source_id=source.id)

    with patch("agent_framework_workflow._edge_runner.EdgeRunner._execute_on_target") as mock_send:
        success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is True
    assert mock_send.call_count == 1
    # The first positional argument of _execute_on_target is the target executor id.
    assert mock_send.call_args[0][0] == target1.id

    # A non-negative value matches no explicit case, so the default case routes it to target2.
    data = MockMessage(data=1)
    message = Message(data=data, source_id=source.id)

    with patch("agent_framework_workflow._edge_runner.EdgeRunner._execute_on_target") as mock_send:
        success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is True
    assert mock_send.call_count == 1
    assert mock_send.call_args[0][0] == target2.id


async def test_switch_case_edge_group_send_message_with_invalid_target() -> None:
    """Test sending a message through a switch case edge group with an invalid target."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = SwitchCaseEdgeGroup(
        source_id=source.id,
        cases=[
            SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id),
            SwitchCaseEdgeGroupDefault(target_id=target2.id),
        ],
    )
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data=-1)
    message = Message(data=data, source_id=source.id, target_id="invalid_target")

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False


async def test_switch_case_edge_group_send_message_with_valid_target() -> None:
    """Test that an explicit target is honored only when the switch case selects it."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = SwitchCaseEdgeGroup(
        source_id=source.id,
        cases=[
            SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id),
            SwitchCaseEdgeGroupDefault(target_id=target2.id),
        ],
    )
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data=1)  # Condition will fail
    message = Message(data=data, source_id=source.id, target_id=target1.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False

    data = MockMessage(data=-1)  # Condition will pass
    message = Message(data=data, source_id=source.id, target_id=target1.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is True


async def test_switch_case_edge_group_send_message_with_invalid_data() -> None:
    """Test sending a message through a switch case edge group with invalid data."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = SwitchCaseEdgeGroup(
        source_id=source.id,
        cases=[
            SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id),
            SwitchCaseEdgeGroupDefault(target_id=target2.id),
        ],
    )
    executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2}

    edge_runner = create_edge_runner(edge_group, executors)
    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = "invalid_data"
    message = Message(data=data, source_id=source.id)

    success = await edge_runner.send_message(message, shared_state, ctx)

    assert success is False


# endregion SwitchCaseEdgeGroup