Files
agent-framework/python/packages/workflow/tests/test_edge.py
T
Tao Chen ed86baa6cb Python: Workflow Edge Groups (#393)
* Introducing edge groups

* Add conditional and partitioning edge groups; next add samples and tests

* Add unit tests

* Add samples

* Address comments 1

* Address comments 2

* Update conditional edge group to take in cases and default

* Minor updates to sample

* Collapsing Partitioning Edge group and Conditional Edge group into source edge group

* Improve sample clarity

* Name consolidation

---------

Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
2025-08-15 18:11:35 +00:00

832 lines
29 KiB
Python

# Copyright (c) Microsoft. All rights reserved.
from dataclasses import dataclass
from typing import Any
from unittest.mock import patch
import pytest
from agent_framework.workflow import Executor, WorkflowContext, handler
from agent_framework_workflow._edge import (
Case,
Default,
Edge,
FanInEdgeGroup,
FanOutEdgeGroup,
SingleEdgeGroup,
SwitchCaseEdgeGroup,
)
@dataclass
class MockMessage:
    """Primary payload type accepted by ``MockExecutor`` in these tests."""

    # Arbitrary payload; the tests below store strings and integers here.
    data: Any
@dataclass
class MockMessageSecondary:
    """Alternate payload type, accepted only by ``MockExecutorSecondary``."""

    # Arbitrary payload carried by the message.
    data: Any
class MockExecutor(Executor):
    """Test executor with a single no-op handler annotated for ``MockMessage``."""

    @handler
    async def mock_handler(self, message: MockMessage, ctx: WorkflowContext) -> None:
        """Receive a ``MockMessage`` and deliberately ignore it."""
class MockExecutorSecondary(Executor):
    """Test executor with a single no-op handler annotated for ``MockMessageSecondary``."""

    @handler
    async def mock_handler_secondary(self, message: MockMessageSecondary, ctx: WorkflowContext) -> None:
        """Receive a ``MockMessageSecondary`` and deliberately ignore it."""
class MockAggregator(Executor):
    """Fan-in test target whose no-op handler is annotated for a list of ``MockMessage``."""

    @handler
    async def mock_aggregator_handler(self, message: list[MockMessage], ctx: WorkflowContext) -> None:
        """Receive a batch of ``MockMessage`` objects and deliberately ignore it."""
# region Edge
def test_create_edge():
    """An Edge exposes its endpoint ids and composes its own id from them."""
    src = MockExecutor(id="source_executor")
    dst = MockExecutor(id="target_executor")

    edge = Edge(source=src, target=dst)

    assert edge.source_id == "source_executor"
    assert edge.target_id == "target_executor"
    expected_id = f"{edge.source_id}{Edge.ID_SEPARATOR}{edge.target_id}"
    assert edge.id == expected_id
def test_edge_can_handle():
    """Test that an edge accepts data its target can handle and rejects data it cannot."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")
    edge = Edge(source=source, target=target)

    # MockExecutor declares a handler for MockMessage only.
    assert edge.can_handle(MockMessage(data="test"))
    # Payloads with no matching handler on the target must be rejected.
    assert not edge.can_handle("invalid_data")
    assert not edge.can_handle(MockMessageSecondary(data="test"))
# endregion Edge
# region SingleEdgeGroup
def test_single_edge_group():
    """A SingleEdgeGroup exposes its one source and target and an edge linking them."""
    src = MockExecutor(id="source_executor")
    dst = MockExecutor(id="target_executor")

    group = SingleEdgeGroup(source=src, target=dst)

    assert group.source_executors == [src]
    assert group.target_executors == [dst]
    first_edge = group.edges[0]
    assert first_edge.source_id == "source_executor"
    assert first_edge.target_id == "target_executor"
def test_single_edge_group_with_condition():
    """A condition handed to SingleEdgeGroup ends up stored on its edge."""
    src = MockExecutor(id="source_executor")
    dst = MockExecutor(id="target_executor")

    def is_test_payload(x):
        return x.data == "test"

    group = SingleEdgeGroup(source=src, target=dst, condition=is_test_payload)

    assert group.source_executors == [src]
    assert group.target_executors == [dst]
    first_edge = group.edges[0]
    assert first_edge.source_id == "source_executor"
    assert first_edge.target_id == "target_executor"
    assert first_edge._condition is not None  # type: ignore
async def test_single_edge_group_send_message():
    """A message without an explicit target is sent successfully through a single edge group."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    dst = MockExecutor(id="target_executor")
    group = SingleEdgeGroup(source=src, target=dst)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id)

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is True
async def test_single_edge_group_send_message_with_target():
    """Test that a message explicitly addressed to the group's target is sent successfully."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")
    edge_group = SingleEdgeGroup(source=source, target=target)

    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    shared_state = SharedState()
    ctx = InProcRunnerContext()
    data = MockMessage(data="test")
    # target_id names the group's only target, so routing should succeed.
    message = Message(data=data, source_id=source.id, target_id=target.id)
    success = await edge_group.send_message(message, shared_state, ctx)
    assert success is True
async def test_single_edge_group_send_message_with_invalid_target():
    """Test that a message addressed to an executor outside the group is not sent."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")
    edge_group = SingleEdgeGroup(source=source, target=target)

    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    shared_state = SharedState()
    ctx = InProcRunnerContext()
    data = MockMessage(data="test")
    # "invalid_target" is not an executor in this group, so send_message must report failure.
    message = Message(data=data, source_id=source.id, target_id="invalid_target")
    success = await edge_group.send_message(message, shared_state, ctx)
    assert success is False
async def test_single_edge_group_send_message_with_invalid_data():
    """Test that a payload the target has no handler for is not sent."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")
    edge_group = SingleEdgeGroup(source=source, target=target)

    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    shared_state = SharedState()
    ctx = InProcRunnerContext()
    # MockExecutor only declares a MockMessage handler, so a bare string cannot be routed.
    data = "invalid_data"
    message = Message(data=data, source_id=source.id)
    success = await edge_group.send_message(message, shared_state, ctx)
    assert success is False
# endregion SingleEdgeGroup
# region FanOutEdgeGroup
def test_source_edge_group():
    """FanOutEdgeGroup builds one edge per target, in the order the targets were given."""
    src = MockExecutor(id="source_executor")
    t1 = MockExecutor(id="target_executor_1")
    t2 = MockExecutor(id="target_executor_2")

    group = FanOutEdgeGroup(source=src, targets=[t1, t2])

    assert group.source_executors == [src]
    assert group.target_executors == [t1, t2]
    assert len(group.edges) == 2
    for edge, expected_target_id in zip(group.edges, ("target_executor_1", "target_executor_2")):
        assert edge.source_id == "source_executor"
        assert edge.target_id == expected_target_id
def test_source_edge_group_invalid_number_of_targets():
    """FanOutEdgeGroup refuses to be built with a single target."""
    src = MockExecutor(id="source_executor")
    lone_target = MockExecutor(id="target_executor")
    expected_error = "FanOutEdgeGroup must contain at least two targets"

    with pytest.raises(ValueError, match=expected_error):
        FanOutEdgeGroup(source=src, targets=[lone_target])
async def test_source_edge_group_send_message():
    """With no explicit target, a fan-out message is forwarded along both edges."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    t1 = MockExecutor(id="target_executor_1")
    t2 = MockExecutor(id="target_executor_2")
    group = FanOutEdgeGroup(source=src, targets=[t1, t2])

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id)

    with patch("agent_framework_workflow._edge.Edge.send_message") as edge_send:
        delivered = await group.send_message(outgoing, state, runner_ctx)

    assert delivered is True
    assert edge_send.call_count == 2
async def test_source_edge_group_send_message_with_target():
    """A fan-out message addressed to one target travels only along that target's edge."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    t1 = MockExecutor(id="target_executor_1")
    t2 = MockExecutor(id="target_executor_2")
    group = FanOutEdgeGroup(source=src, targets=[t1, t2])

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id, target_id=t1.id)

    with patch("agent_framework_workflow._edge.Edge.send_message") as edge_send:
        delivered = await group.send_message(outgoing, state, runner_ctx)

    assert delivered is True
    assert edge_send.call_count == 1
    forwarded = edge_send.call_args.args[0]
    assert forwarded.target_id == t1.id
async def test_source_edge_group_send_message_with_invalid_target():
    """A fan-out message addressed to an unknown executor is not sent."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    t1 = MockExecutor(id="target_executor_1")
    t2 = MockExecutor(id="target_executor_2")
    group = FanOutEdgeGroup(source=src, targets=[t1, t2])

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id, target_id="invalid_target")

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is False
async def test_source_edge_group_send_message_with_invalid_data():
    """A fan-out message whose payload no target can handle is not sent."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    t1 = MockExecutor(id="target_executor_1")
    t2 = MockExecutor(id="target_executor_2")
    group = FanOutEdgeGroup(source=src, targets=[t1, t2])

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data="invalid_data", source_id=src.id)

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is False
async def test_source_edge_group_send_message_only_one_successful_send():
    """Only the edge whose target handles the payload type forwards a fan-out message."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    primary_target = MockExecutor(id="target_executor_1")
    secondary_target = MockExecutorSecondary(id="target_executor_2")
    group = FanOutEdgeGroup(source=src, targets=[primary_target, secondary_target])

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id)

    with patch("agent_framework_workflow._edge.Edge.send_message") as edge_send:
        delivered = await group.send_message(outgoing, state, runner_ctx)

    assert delivered is True
    assert edge_send.call_count == 1
def test_source_edge_group_with_selection_func():
    """Test creating a fan-out edge group with a selection function."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = FanOutEdgeGroup(
        source=source,
        targets=[target1, target2],
        selection_func=lambda data, target_ids: [target1.id],
    )

    # A selection function does not change the edge topology: one edge per target is still built.
    assert edge_group.source_executors == [source]
    assert edge_group.target_executors == [target1, target2]
    assert len(edge_group.edges) == 2
    assert edge_group.edges[0].source_id == "source_executor"
    assert edge_group.edges[0].target_id == "target_executor_1"
    assert edge_group.edges[1].source_id == "source_executor"
    assert edge_group.edges[1].target_id == "target_executor_2"

    # The selection function must be retained so it can be applied at send time.
    assert edge_group._selection_func is not None  # type: ignore
    assert edge_group._selection_func(MockMessage(data="test"), [target1.id, target2.id]) == [target1.id]  # type: ignore
async def test_source_edge_group_with_selection_func_send_message():
    """When the selection function picks every target, the message goes down both edges."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    t1 = MockExecutor(id="target_executor_1")
    t2 = MockExecutor(id="target_executor_2")
    group = FanOutEdgeGroup(
        source=src,
        targets=[t1, t2],
        selection_func=lambda data, target_ids: [t1.id, t2.id],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id)

    with patch("agent_framework_workflow._edge.Edge.send_message") as edge_send:
        delivered = await group.send_message(outgoing, state, runner_ctx)

    assert delivered is True
    assert edge_send.call_count == 2
async def test_source_edge_group_with_selection_func_send_message_with_invalid_selection_result():
    """A selection function that names an id outside the group causes send_message to raise."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    t1 = MockExecutor(id="target_executor_1")
    t2 = MockExecutor(id="target_executor_2")
    group = FanOutEdgeGroup(
        source=src,
        targets=[t1, t2],
        selection_func=lambda data, target_ids: [t1.id, "invalid_target"],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id)

    with pytest.raises(RuntimeError):
        await group.send_message(outgoing, state, runner_ctx)
async def test_source_edge_group_with_selection_func_send_message_with_target():
    """An explicit target that is also selected receives the message on its edge alone."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    t1 = MockExecutor(id="target_executor_1")
    t2 = MockExecutor(id="target_executor_2")
    group = FanOutEdgeGroup(
        source=src,
        targets=[t1, t2],
        selection_func=lambda data, target_ids: [t1.id, t2.id],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id, target_id=t1.id)

    with patch("agent_framework_workflow._edge.Edge.send_message") as edge_send:
        delivered = await group.send_message(outgoing, state, runner_ctx)

    assert delivered is True
    assert edge_send.call_count == 1
    forwarded = edge_send.call_args.args[0]
    assert forwarded.target_id == t1.id
async def test_source_edge_group_with_selection_func_send_message_with_target_not_in_selection():
    """An explicit target that the selection function excludes does not get the message."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    t1 = MockExecutor(id="target_executor_1")
    t2 = MockExecutor(id="target_executor_2")
    # The selector only ever picks t1.
    group = FanOutEdgeGroup(
        source=src,
        targets=[t1, t2],
        selection_func=lambda data, target_ids: [t1.id],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id, target_id=t2.id)

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is False
async def test_source_edge_group_with_selection_func_send_message_with_invalid_data():
    """With a selection function in place, an unhandled payload is still not sent."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    t1 = MockExecutor(id="target_executor_1")
    t2 = MockExecutor(id="target_executor_2")
    group = FanOutEdgeGroup(
        source=src,
        targets=[t1, t2],
        selection_func=lambda data, target_ids: [t1.id, t2.id],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data="invalid_data", source_id=src.id)

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is False
async def test_source_edge_group_with_selection_func_send_message_with_target_invalid_data():
    """An explicitly targeted, selected message with an unhandled payload is not sent."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    t1 = MockExecutor(id="target_executor_1")
    t2 = MockExecutor(id="target_executor_2")
    group = FanOutEdgeGroup(
        source=src,
        targets=[t1, t2],
        selection_func=lambda data, target_ids: [t1.id, t2.id],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data="invalid_data", source_id=src.id, target_id=t1.id)

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is False
# endregion FanOutEdgeGroup
# region FanInEdgeGroup
def test_target_edge_group():
    """FanInEdgeGroup builds one edge per source, all converging on the same target."""
    s1 = MockExecutor(id="source_executor_1")
    s2 = MockExecutor(id="source_executor_2")
    sink = MockAggregator(id="target_executor")

    group = FanInEdgeGroup(sources=[s1, s2], target=sink)

    assert group.source_executors == [s1, s2]
    assert group.target_executors == [sink]
    assert len(group.edges) == 2
    for edge, expected_source_id in zip(group.edges, ("source_executor_1", "source_executor_2")):
        assert edge.source_id == expected_source_id
        assert edge.target_id == "target_executor"
def test_target_edge_group_invalid_number_of_sources():
    """FanInEdgeGroup refuses to be built with a single source."""
    lone_source = MockExecutor(id="source_executor")
    sink = MockAggregator(id="target_executor")
    expected_error = "FanInEdgeGroup must contain at least two sources"

    with pytest.raises(ValueError, match=expected_error):
        FanInEdgeGroup(sources=[lone_source], target=sink)
async def test_target_edge_group_send_message_buffer():
    """Fan-in buffers messages until every source has contributed, then flushes once."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    s1 = MockExecutor(id="source_executor_1")
    s2 = MockExecutor(id="source_executor_2")
    sink = MockAggregator(id="target_executor")
    group = FanInEdgeGroup(sources=[s1, s2], target=sink)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    payload = MockMessage(data="test")

    with patch("agent_framework_workflow._edge.Edge.send_message") as edge_send:
        # First contribution: held back while s2 has not sent anything yet.
        first_ok = await group.send_message(Message(data=payload, source_id=s1.id), state, runner_ctx)
        assert first_ok is True
        assert edge_send.call_count == 0
        assert len(group._buffer[s1.id]) == 1  # type: ignore

        # Second contribution completes the set, so the buffered batch is forwarded.
        second_ok = await group.send_message(Message(data=payload, source_id=s2.id), state, runner_ctx)
        assert second_ok is True
        assert edge_send.call_count == 1
        # Forwarding empties the buffer.
        assert not group._buffer  # type: ignore
async def test_target_edge_group_send_message_with_invalid_target():
    """A fan-in message addressed to an unknown executor is not sent."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    s1 = MockExecutor(id="source_executor_1")
    s2 = MockExecutor(id="source_executor_2")
    sink = MockAggregator(id="target_executor")
    group = FanInEdgeGroup(sources=[s1, s2], target=sink)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=s1.id, target_id="invalid_target")

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is False
async def test_target_edge_group_send_message_with_invalid_data():
    """A fan-in message whose payload is not a MockMessage is not sent."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    s1 = MockExecutor(id="source_executor_1")
    s2 = MockExecutor(id="source_executor_2")
    sink = MockAggregator(id="target_executor")
    group = FanInEdgeGroup(sources=[s1, s2], target=sink)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data="invalid_data", source_id=s1.id)

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is False
# endregion FanInEdgeGroup
# region SwitchCaseEdgeGroup
def test_switch_case_edge_group():
    """SwitchCaseEdgeGroup builds an edge per case and selects the matching case or the default."""
    src = MockExecutor(id="source_executor")
    negative_target = MockExecutor(id="target_executor_1")
    default_target = MockExecutor(id="target_executor_2")

    def is_negative(x):
        return x.data < 0

    group = SwitchCaseEdgeGroup(
        source=src,
        cases=[
            Case(condition=is_negative, target=negative_target),
            Default(target=default_target),
        ],
    )

    assert group.source_executors == [src]
    assert group.target_executors == [negative_target, default_target]
    assert len(group.edges) == 2
    for edge, expected_target_id in zip(group.edges, ("target_executor_1", "target_executor_2")):
        assert edge.source_id == "source_executor"
        assert edge.target_id == expected_target_id

    select = group._selection_func  # type: ignore
    assert select is not None
    # Negative payloads match the explicit case; everything else falls through to the default.
    assert select(MockMessage(data=-1), [negative_target.id, default_target.id]) == [negative_target.id]
    assert select(MockMessage(data=1), [negative_target.id, default_target.id]) == [default_target.id]
def test_switch_case_edge_group_invalid_number_of_cases():
    """Test that SwitchCaseEdgeGroup rejects too few cases and a case list without a default."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    # A single case is below the two-case minimum.
    # The trailing period is escaped so the regex matches it literally.
    with pytest.raises(
        ValueError, match=r"SwitchCaseEdgeGroup must contain at least two cases \(including the default case\)\."
    ):
        SwitchCaseEdgeGroup(
            source=source,
            cases=[
                Case(condition=lambda x: x.data < 0, target=target),
            ],
        )

    # Two cases satisfy the count, but neither is a Default.
    with pytest.raises(ValueError, match=r"SwitchCaseEdgeGroup must contain exactly one default case\."):
        SwitchCaseEdgeGroup(
            source=source,
            cases=[
                Case(condition=lambda x: x.data < 0, target=target),
                Case(condition=lambda x: x.data >= 0, target=target),
            ],
        )
def test_switch_case_edge_group_invalid_number_of_default_cases():
    """Test that SwitchCaseEdgeGroup rejects more than one default case."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")

    # Two Default entries violate the exactly-one-default rule.
    # The trailing period is escaped so the regex matches it literally.
    with pytest.raises(ValueError, match=r"SwitchCaseEdgeGroup must contain exactly one default case\."):
        SwitchCaseEdgeGroup(
            source=source,
            cases=[
                Case(condition=lambda x: x.data < 0, target=target1),
                Default(target=target2),
                Default(target=target2),
            ],
        )
async def test_switch_case_edge_group_send_message():
    """Test that a switch-case group forwards each message along exactly one edge.

    A negative value matches the explicit case; a non-negative value falls through
    to the default case. In both situations a single edge send is expected.
    """
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")
    edge_group = SwitchCaseEdgeGroup(
        source=source,
        cases=[
            Case(condition=lambda x: x.data < 0, target=target1),
            Default(target=target2),
        ],
    )

    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    shared_state = SharedState()
    ctx = InProcRunnerContext()

    # A negative value satisfies the explicit case's condition.
    data = MockMessage(data=-1)
    message = Message(data=data, source_id=source.id)
    with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
        success = await edge_group.send_message(message, shared_state, ctx)
        assert success is True
        assert mock_send.call_count == 1

    # A non-negative value fails the explicit case, so the default case should route it.
    data = MockMessage(data=1)
    message = Message(data=data, source_id=source.id)
    with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
        success = await edge_group.send_message(message, shared_state, ctx)
        assert success is True
        assert mock_send.call_count == 1
async def test_switch_case_edge_group_send_message_with_invalid_target():
    """A switch-case message addressed to an unknown executor is not sent."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    negative_target = MockExecutor(id="target_executor_1")
    default_target = MockExecutor(id="target_executor_2")
    group = SwitchCaseEdgeGroup(
        source=src,
        cases=[
            Case(condition=lambda x: x.data < 0, target=negative_target),
            Default(target=default_target),
        ],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data=-1), source_id=src.id, target_id="invalid_target")

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is False
async def test_switch_case_edge_group_send_message_with_valid_target():
    """An explicitly targeted message is sent only when that target's case condition holds."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    negative_target = MockExecutor(id="target_executor_1")
    default_target = MockExecutor(id="target_executor_2")
    group = SwitchCaseEdgeGroup(
        source=src,
        cases=[
            Case(condition=lambda x: x.data < 0, target=negative_target),
            Default(target=default_target),
        ],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()

    # Positive value: the negative_target case does not match, so nothing is sent.
    rejected = Message(data=MockMessage(data=1), source_id=src.id, target_id=negative_target.id)
    assert await group.send_message(rejected, state, runner_ctx) is False

    # Negative value: the case matches and the message goes through.
    accepted = Message(data=MockMessage(data=-1), source_id=src.id, target_id=negative_target.id)
    assert await group.send_message(accepted, state, runner_ctx) is True
async def test_switch_case_edge_group_send_message_with_invalid_data():
    """A switch-case message whose payload is not a MockMessage is not sent."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    negative_target = MockExecutor(id="target_executor_1")
    default_target = MockExecutor(id="target_executor_2")
    group = SwitchCaseEdgeGroup(
        source=src,
        cases=[
            Case(condition=lambda x: x.data < 0, target=negative_target),
            Default(target=default_target),
        ],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data="invalid_data", source_id=src.id)

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is False
# endregion SwitchCaseEdgeGroup