mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
ed86baa6cb
* Introducing edge groups * Add conditional and partitioning edge groups; next add samples and tests * Add unit tests * Add samples * Address comments 1 * Address comments 2 * Update conditional edge group to take in cases and default * Minor updates to sample * Collapsing Partitioning Edge group and Conditional Edge group to source edge group * Improve sample clarity * Name consolidation --------- Co-authored-by: Eric Zhu <ekzhu@users.noreply.github.com>
832 lines
29 KiB
Python
832 lines
29 KiB
Python
# Copyright (c) Microsoft. All rights reserved.
|
|
|
|
from dataclasses import dataclass
|
|
from typing import Any
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from agent_framework.workflow import Executor, WorkflowContext, handler
|
|
|
|
from agent_framework_workflow._edge import (
|
|
Case,
|
|
Default,
|
|
Edge,
|
|
FanInEdgeGroup,
|
|
FanOutEdgeGroup,
|
|
SingleEdgeGroup,
|
|
SwitchCaseEdgeGroup,
|
|
)
|
|
|
|
|
|
@dataclass()
class MockMessage:
    """Primary message type; ``MockExecutor`` registers its handler for this type."""

    # Arbitrary payload; the tests use strings ("test") and ints (-1, 1).
    data: Any
|
|
|
|
|
|
@dataclass()
class MockMessageSecondary:
    """Alternate message type; only ``MockExecutorSecondary`` registers a handler for it."""

    # Arbitrary payload carried by the message.
    data: Any
|
|
|
|
|
|
class MockExecutor(Executor):
    """Test executor whose single handler accepts ``MockMessage`` payloads."""

    @handler
    async def mock_handler(self, message: MockMessage, ctx: WorkflowContext) -> None:
        """Receive a ``MockMessage`` and deliberately ignore it."""
|
|
|
|
|
|
class MockExecutorSecondary(Executor):
    """Test executor whose single handler accepts ``MockMessageSecondary`` payloads."""

    @handler
    async def mock_handler_secondary(self, message: MockMessageSecondary, ctx: WorkflowContext) -> None:
        """Receive a ``MockMessageSecondary`` and deliberately ignore it."""
|
|
|
|
|
|
class MockAggregator(Executor):
    """Test executor used as a fan-in target; its handler accepts a list of ``MockMessage``."""

    @handler
    async def mock_aggregator_handler(self, message: list[MockMessage], ctx: WorkflowContext) -> None:
        """Receive a batch of ``MockMessage`` objects and deliberately ignore it."""
|
|
|
|
|
|
# region Edge
|
|
|
|
|
|
def test_create_edge():
    """An ``Edge`` exposes its endpoint ids and derives its own id from them."""
    edge = Edge(
        source=MockExecutor(id="source_executor"),
        target=MockExecutor(id="target_executor"),
    )

    assert edge.source_id == "source_executor"
    assert edge.target_id == "target_executor"
    # The edge id joins both endpoint ids with the class-level separator.
    expected_id = f"{edge.source_id}{Edge.ID_SEPARATOR}{edge.target_id}"
    assert edge.id == expected_id
|
|
|
|
|
|
def test_edge_can_handle():
    """Test that an edge reports whether it can handle a given message payload."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    edge = Edge(source=source, target=target)

    # The target (MockExecutor) registers a handler for MockMessage.
    assert edge.can_handle(MockMessage(data="test"))
    # A payload type the target has no handler for must be rejected.
    assert not edge.can_handle("invalid_data")
|
|
|
|
|
|
# endregion Edge
|
|
|
|
# region SingleEdgeGroup
|
|
|
|
|
|
def test_single_edge_group():
    """A ``SingleEdgeGroup`` builds an edge from its source to its target."""
    src = MockExecutor(id="source_executor")
    dst = MockExecutor(id="target_executor")

    group = SingleEdgeGroup(source=src, target=dst)

    assert group.source_executors == [src]
    assert group.target_executors == [dst]
    first_edge = group.edges[0]
    assert first_edge.source_id == "source_executor"
    assert first_edge.target_id == "target_executor"
|
|
|
|
|
|
def test_single_edge_group_with_condition():
    """A condition passed to ``SingleEdgeGroup`` is attached to the underlying edge."""
    src = MockExecutor(id="source_executor")
    dst = MockExecutor(id="target_executor")

    def only_test_payloads(msg):
        return msg.data == "test"

    group = SingleEdgeGroup(source=src, target=dst, condition=only_test_payloads)

    assert group.source_executors == [src]
    assert group.target_executors == [dst]
    first_edge = group.edges[0]
    assert first_edge.source_id == "source_executor"
    assert first_edge.target_id == "target_executor"
    assert first_edge._condition is not None  # type: ignore
|
|
|
|
|
|
async def test_single_edge_group_send_message():
    """A payload the target can handle is delivered through a ``SingleEdgeGroup``."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    dst = MockExecutor(id="target_executor")
    group = SingleEdgeGroup(source=src, target=dst)

    state = SharedState()
    runner_ctx = InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id)

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is True
|
|
|
|
|
|
async def test_single_edge_group_send_message_with_target():
    """Test sending a message through a single edge group when it is explicitly addressed to the group's target."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    edge_group = SingleEdgeGroup(source=source, target=target)

    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    # target_id matches the only target in the group, so delivery should succeed.
    message = Message(data=data, source_id=source.id, target_id=target.id)

    success = await edge_group.send_message(message, shared_state, ctx)
    assert success is True
|
|
|
|
|
|
async def test_single_edge_group_send_message_with_invalid_target():
    """Test that a single edge group rejects a message addressed to a target it does not contain."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    edge_group = SingleEdgeGroup(source=source, target=target)

    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    shared_state = SharedState()
    ctx = InProcRunnerContext()

    data = MockMessage(data="test")
    # "invalid_target" is not an executor in this group.
    message = Message(data=data, source_id=source.id, target_id="invalid_target")

    success = await edge_group.send_message(message, shared_state, ctx)
    assert success is False
|
|
|
|
|
|
async def test_single_edge_group_send_message_with_invalid_data():
    """Test that a single edge group rejects a payload its target executor has no handler for."""
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    edge_group = SingleEdgeGroup(source=source, target=target)

    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    shared_state = SharedState()
    ctx = InProcRunnerContext()

    # MockExecutor only handles MockMessage; a plain string is unsupported.
    data = "invalid_data"
    message = Message(data=data, source_id=source.id)

    success = await edge_group.send_message(message, shared_state, ctx)
    assert success is False
|
|
|
|
|
|
# endregion SingleEdgeGroup
|
|
|
|
|
|
# region FanOutEdgeGroup
|
|
|
|
|
|
def test_source_edge_group():
    """A ``FanOutEdgeGroup`` creates one edge per target, in target order."""
    src = MockExecutor(id="source_executor")
    first = MockExecutor(id="target_executor_1")
    second = MockExecutor(id="target_executor_2")

    group = FanOutEdgeGroup(source=src, targets=[first, second])

    assert group.source_executors == [src]
    assert group.target_executors == [first, second]
    assert len(group.edges) == 2
    expected_pairs = [
        ("source_executor", "target_executor_1"),
        ("source_executor", "target_executor_2"),
    ]
    for edge, (want_src, want_dst) in zip(group.edges, expected_pairs):
        assert edge.source_id == want_src
        assert edge.target_id == want_dst
|
|
|
|
|
|
def test_source_edge_group_invalid_number_of_targets():
    """Constructing a ``FanOutEdgeGroup`` with only one target is rejected."""
    src = MockExecutor(id="source_executor")
    lone_target = MockExecutor(id="target_executor")

    expected_error = "FanOutEdgeGroup must contain at least two targets"
    with pytest.raises(ValueError, match=expected_error):
        FanOutEdgeGroup(source=src, targets=[lone_target])
|
|
|
|
|
|
async def test_source_edge_group_send_message():
    """Without an explicit target, a fan-out group forwards the message over every edge."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    group = FanOutEdgeGroup(
        source=src,
        targets=[MockExecutor(id="target_executor_1"), MockExecutor(id="target_executor_2")],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id)

    with patch("agent_framework_workflow._edge.Edge.send_message") as send_spy:
        delivered = await group.send_message(outgoing, state, runner_ctx)

    assert delivered is True
    assert send_spy.call_count == 2
|
|
|
|
|
|
async def test_source_edge_group_send_message_with_target():
    """An explicit ``target_id`` restricts a fan-out group to the matching edge only."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    first = MockExecutor(id="target_executor_1")
    second = MockExecutor(id="target_executor_2")
    group = FanOutEdgeGroup(source=src, targets=[first, second])

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id, target_id=first.id)

    with patch("agent_framework_workflow._edge.Edge.send_message") as send_spy:
        delivered = await group.send_message(outgoing, state, runner_ctx)

    assert delivered is True
    assert send_spy.call_count == 1
    forwarded = send_spy.call_args[0][0]
    assert forwarded.target_id == first.id
|
|
|
|
|
|
async def test_source_edge_group_send_message_with_invalid_target():
    """A fan-out group refuses a message addressed to an executor outside its targets."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    group = FanOutEdgeGroup(
        source=src,
        targets=[MockExecutor(id="target_executor_1"), MockExecutor(id="target_executor_2")],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    stray = Message(data=MockMessage(data="test"), source_id=src.id, target_id="invalid_target")

    delivered = await group.send_message(stray, state, runner_ctx)
    assert delivered is False
|
|
|
|
|
|
async def test_source_edge_group_send_message_with_invalid_data():
    """A fan-out group refuses a plain-string payload that its MockExecutor targets do not handle."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    group = FanOutEdgeGroup(
        source=src,
        targets=[MockExecutor(id="target_executor_1"), MockExecutor(id="target_executor_2")],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    bad_payload = Message(data="invalid_data", source_id=src.id)

    delivered = await group.send_message(bad_payload, state, runner_ctx)
    assert delivered is False
|
|
|
|
|
|
async def test_source_edge_group_send_message_only_one_successful_send():
    """Only fan-out edges whose target accepts the payload type are used."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    accepts_primary = MockExecutor(id="target_executor_1")
    accepts_secondary = MockExecutorSecondary(id="target_executor_2")
    group = FanOutEdgeGroup(source=src, targets=[accepts_primary, accepts_secondary])

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id)

    with patch("agent_framework_workflow._edge.Edge.send_message") as send_spy:
        delivered = await group.send_message(outgoing, state, runner_ctx)

    # Only target_executor_1 handles MockMessage, so exactly one edge fires.
    assert delivered is True
    assert send_spy.call_count == 1
|
|
|
|
|
|
def test_source_edge_group_with_selection_func():
    """Test creating a fan-out edge group that routes messages through a selection function."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")

    edge_group = FanOutEdgeGroup(
        source=source,
        targets=[target1, target2],
        selection_func=lambda data, target_ids: [target1.id],
    )

    # An edge is still created for every target, regardless of the selection function.
    assert edge_group.source_executors == [source]
    assert edge_group.target_executors == [target1, target2]
    assert len(edge_group.edges) == 2
    assert edge_group.edges[0].source_id == "source_executor"
    assert edge_group.edges[0].target_id == "target_executor_1"
    assert edge_group.edges[1].source_id == "source_executor"
    assert edge_group.edges[1].target_id == "target_executor_2"

    # The selection function is stored and used as given (mirrors the switch-case construction test).
    assert edge_group._selection_func is not None  # type: ignore
    assert edge_group._selection_func(MockMessage(data="test"), [target1.id, target2.id]) == [target1.id]  # type: ignore
|
|
|
|
|
|
async def test_source_edge_group_with_selection_func_send_message():
    """Every target chosen by the selection function receives the message."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    first = MockExecutor(id="target_executor_1")
    second = MockExecutor(id="target_executor_2")

    def pick_all(data, target_ids):
        return [first.id, second.id]

    group = FanOutEdgeGroup(source=src, targets=[first, second], selection_func=pick_all)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id)

    with patch("agent_framework_workflow._edge.Edge.send_message") as send_spy:
        delivered = await group.send_message(outgoing, state, runner_ctx)

    assert delivered is True
    assert send_spy.call_count == 2
|
|
|
|
|
|
async def test_source_edge_group_with_selection_func_send_message_with_invalid_selection_result():
    """A selection function that names an unknown target makes ``send_message`` raise."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    first = MockExecutor(id="target_executor_1")
    second = MockExecutor(id="target_executor_2")

    def pick_with_unknown(data, target_ids):
        return [first.id, "invalid_target"]

    group = FanOutEdgeGroup(source=src, targets=[first, second], selection_func=pick_with_unknown)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id)

    with pytest.raises(RuntimeError):
        await group.send_message(outgoing, state, runner_ctx)
|
|
|
|
|
|
async def test_source_edge_group_with_selection_func_send_message_with_target():
    """An explicit target within the selection is the only edge that fires."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    first = MockExecutor(id="target_executor_1")
    second = MockExecutor(id="target_executor_2")

    def pick_all(data, target_ids):
        return [first.id, second.id]

    group = FanOutEdgeGroup(source=src, targets=[first, second], selection_func=pick_all)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id, target_id=first.id)

    with patch("agent_framework_workflow._edge.Edge.send_message") as send_spy:
        delivered = await group.send_message(outgoing, state, runner_ctx)

    assert delivered is True
    assert send_spy.call_count == 1
    forwarded = send_spy.call_args[0][0]
    assert forwarded.target_id == first.id
|
|
|
|
|
|
async def test_source_edge_group_with_selection_func_send_message_with_target_not_in_selection():
    """An explicit target that the selection function did not pick is refused."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    first = MockExecutor(id="target_executor_1")
    second = MockExecutor(id="target_executor_2")

    def pick_first_only(data, target_ids):
        # Only target_executor_1 is ever selected.
        return [first.id]

    group = FanOutEdgeGroup(source=src, targets=[first, second], selection_func=pick_first_only)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    outgoing = Message(data=MockMessage(data="test"), source_id=src.id, target_id=second.id)

    delivered = await group.send_message(outgoing, state, runner_ctx)
    assert delivered is False
|
|
|
|
|
|
async def test_source_edge_group_with_selection_func_send_message_with_invalid_data():
    """With a selection function, a plain-string payload is still refused."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    first = MockExecutor(id="target_executor_1")
    second = MockExecutor(id="target_executor_2")

    def pick_all(data, target_ids):
        return [first.id, second.id]

    group = FanOutEdgeGroup(source=src, targets=[first, second], selection_func=pick_all)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    bad_payload = Message(data="invalid_data", source_id=src.id)

    delivered = await group.send_message(bad_payload, state, runner_ctx)
    assert delivered is False
|
|
|
|
|
|
async def test_source_edge_group_with_selection_func_send_message_with_target_invalid_data():
    """An explicitly targeted plain-string payload is refused even when the target is selected."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    first = MockExecutor(id="target_executor_1")
    second = MockExecutor(id="target_executor_2")

    def pick_all(data, target_ids):
        return [first.id, second.id]

    group = FanOutEdgeGroup(source=src, targets=[first, second], selection_func=pick_all)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    bad_payload = Message(data="invalid_data", source_id=src.id, target_id=first.id)

    delivered = await group.send_message(bad_payload, state, runner_ctx)
    assert delivered is False
|
|
|
|
|
|
# endregion FanOutEdgeGroup
|
|
|
|
# region FanInEdgeGroup
|
|
|
|
|
|
def test_target_edge_group():
    """A ``FanInEdgeGroup`` creates one edge from each source to the shared target."""
    first_src = MockExecutor(id="source_executor_1")
    second_src = MockExecutor(id="source_executor_2")
    aggregator = MockAggregator(id="target_executor")

    group = FanInEdgeGroup(sources=[first_src, second_src], target=aggregator)

    assert group.source_executors == [first_src, second_src]
    assert group.target_executors == [aggregator]
    assert len(group.edges) == 2
    for edge, want_src in zip(group.edges, ("source_executor_1", "source_executor_2")):
        assert edge.source_id == want_src
        assert edge.target_id == "target_executor"
|
|
|
|
|
|
def test_target_edge_group_invalid_number_of_sources():
    """Constructing a ``FanInEdgeGroup`` from a single source is rejected."""
    lone_source = MockExecutor(id="source_executor")
    aggregator = MockAggregator(id="target_executor")

    expected_error = "FanInEdgeGroup must contain at least two sources"
    with pytest.raises(ValueError, match=expected_error):
        FanInEdgeGroup(sources=[lone_source], target=aggregator)
|
|
|
|
|
|
async def test_target_edge_group_send_message_buffer():
    """A fan-in group buffers partial input and forwards once every source has sent a message."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    first_src = MockExecutor(id="source_executor_1")
    second_src = MockExecutor(id="source_executor_2")
    aggregator = MockAggregator(id="target_executor")
    group = FanInEdgeGroup(sources=[first_src, second_src], target=aggregator)

    state = SharedState()
    runner_ctx = InProcRunnerContext()
    payload = MockMessage(data="test")

    async def deliver_from(sender):
        # Push the shared payload through the group as if ``sender`` had emitted it.
        return await group.send_message(Message(data=payload, source_id=sender.id), state, runner_ctx)

    with patch("agent_framework_workflow._edge.Edge.send_message") as send_spy:
        # Only one of the two sources has sent a message: it is buffered, nothing is forwarded.
        delivered = await deliver_from(first_src)
        assert delivered is True
        assert send_spy.call_count == 0
        assert len(group._buffer[first_src.id]) == 1  # type: ignore

        # The second source completes the set, so the group forwards exactly once.
        delivered = await deliver_from(second_src)
        assert delivered is True
        assert send_spy.call_count == 1

        # Forwarding empties the buffer.
        assert not group._buffer  # type: ignore
|
|
|
|
|
|
async def test_target_edge_group_send_message_with_invalid_target():
    """A fan-in group refuses a message addressed to an executor other than its target."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    first_src = MockExecutor(id="source_executor_1")
    second_src = MockExecutor(id="source_executor_2")
    aggregator = MockAggregator(id="target_executor")
    group = FanInEdgeGroup(sources=[first_src, second_src], target=aggregator)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    stray = Message(data=MockMessage(data="test"), source_id=first_src.id, target_id="invalid_target")

    delivered = await group.send_message(stray, state, runner_ctx)
    assert delivered is False
|
|
|
|
|
|
async def test_target_edge_group_send_message_with_invalid_data():
    """A fan-in group refuses a plain-string payload."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    first_src = MockExecutor(id="source_executor_1")
    second_src = MockExecutor(id="source_executor_2")
    aggregator = MockAggregator(id="target_executor")
    group = FanInEdgeGroup(sources=[first_src, second_src], target=aggregator)

    state, runner_ctx = SharedState(), InProcRunnerContext()
    bad_payload = Message(data="invalid_data", source_id=first_src.id)

    delivered = await group.send_message(bad_payload, state, runner_ctx)
    assert delivered is False
|
|
|
|
|
|
# endregion FanInEdgeGroup
|
|
|
|
# region SwitchCaseEdgeGroup
|
|
|
|
|
|
def test_switch_case_edge_group():
    """A ``SwitchCaseEdgeGroup`` builds one edge per case and routes through its selection function."""
    src = MockExecutor(id="source_executor")
    negative_target = MockExecutor(id="target_executor_1")
    fallback_target = MockExecutor(id="target_executor_2")

    group = SwitchCaseEdgeGroup(
        source=src,
        cases=[
            Case(condition=lambda msg: msg.data < 0, target=negative_target),
            Default(target=fallback_target),
        ],
    )

    assert group.source_executors == [src]
    assert group.target_executors == [negative_target, fallback_target]
    assert len(group.edges) == 2
    assert [(e.source_id, e.target_id) for e in group.edges] == [
        ("source_executor", "target_executor_1"),
        ("source_executor", "target_executor_2"),
    ]

    select = group._selection_func  # type: ignore
    assert select is not None
    # Negative data matches the first case; anything else falls through to the default.
    assert select(MockMessage(data=-1), [negative_target.id, fallback_target.id]) == [negative_target.id]
    assert select(MockMessage(data=1), [negative_target.id, fallback_target.id]) == [fallback_target.id]
|
|
|
|
|
|
def test_switch_case_edge_group_invalid_number_of_cases():
    """Test that SwitchCaseEdgeGroup rejects case lists that are too short or lack a default case.

    Two invalid configurations are checked:

    * a single non-default case (fewer than the required two cases), and
    * two conditional cases where neither is a ``Default``.
    """
    source = MockExecutor(id="source_executor")
    target = MockExecutor(id="target_executor")

    # Fewer than two cases in total.
    with pytest.raises(
        ValueError, match=r"SwitchCaseEdgeGroup must contain at least two cases \(including the default case\)."
    ):
        SwitchCaseEdgeGroup(
            source=source,
            cases=[
                Case(condition=lambda x: x.data < 0, target=target),
            ],
        )

    # Enough cases, but none of them is the default.
    with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."):
        SwitchCaseEdgeGroup(
            source=source,
            cases=[
                Case(condition=lambda x: x.data < 0, target=target),
                Case(condition=lambda x: x.data >= 0, target=target),
            ],
        )
|
|
|
|
|
|
def test_switch_case_edge_group_invalid_number_of_default_cases():
    """Test that SwitchCaseEdgeGroup rejects a case list containing more than one Default case."""
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")

    with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."):
        SwitchCaseEdgeGroup(
            source=source,
            cases=[
                Case(condition=lambda x: x.data < 0, target=target1),
                # Two Default entries make the fall-through target ambiguous.
                Default(target=target2),
                Default(target=target2),
            ],
        )
|
|
|
|
|
|
async def test_switch_case_edge_group_send_message():
    """Test that a switch case edge group routes each message through exactly one edge.

    A negative value matches the ``x.data < 0`` case; a non-negative value falls
    through to the ``Default`` case. In both situations exactly one edge is used.
    """
    source = MockExecutor(id="source_executor")
    target1 = MockExecutor(id="target_executor_1")
    target2 = MockExecutor(id="target_executor_2")

    edge_group = SwitchCaseEdgeGroup(
        source=source,
        cases=[
            Case(condition=lambda x: x.data < 0, target=target1),
            Default(target=target2),
        ],
    )

    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    shared_state = SharedState()
    ctx = InProcRunnerContext()

    # A negative value satisfies the first case's condition.
    data = MockMessage(data=-1)
    message = Message(data=data, source_id=source.id)

    with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
        success = await edge_group.send_message(message, shared_state, ctx)

    assert success is True
    assert mock_send.call_count == 1

    # A non-negative value fails every case condition and falls through to the default target.
    data = MockMessage(data=1)
    message = Message(data=data, source_id=source.id)
    with patch("agent_framework_workflow._edge.Edge.send_message") as mock_send:
        success = await edge_group.send_message(message, shared_state, ctx)

    assert success is True
    assert mock_send.call_count == 1
|
|
|
|
|
|
async def test_switch_case_edge_group_send_message_with_invalid_target():
    """A switch-case group refuses a message addressed to an executor it does not route to."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    negative_target = MockExecutor(id="target_executor_1")
    fallback_target = MockExecutor(id="target_executor_2")
    group = SwitchCaseEdgeGroup(
        source=src,
        cases=[
            Case(condition=lambda msg: msg.data < 0, target=negative_target),
            Default(target=fallback_target),
        ],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    stray = Message(data=MockMessage(data=-1), source_id=src.id, target_id="invalid_target")

    delivered = await group.send_message(stray, state, runner_ctx)
    assert delivered is False
|
|
|
|
|
|
async def test_switch_case_edge_group_send_message_with_valid_target():
    """An explicit target is honoured only when its case condition also matches."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    negative_target = MockExecutor(id="target_executor_1")
    fallback_target = MockExecutor(id="target_executor_2")
    group = SwitchCaseEdgeGroup(
        source=src,
        cases=[
            Case(condition=lambda msg: msg.data < 0, target=negative_target),
            Default(target=fallback_target),
        ],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()

    # Positive value: the `< 0` case guarding target_executor_1 does not match.
    mismatched = Message(data=MockMessage(data=1), source_id=src.id, target_id=negative_target.id)
    delivered = await group.send_message(mismatched, state, runner_ctx)
    assert delivered is False

    # Negative value: the case matches, so the requested target receives the message.
    matched = Message(data=MockMessage(data=-1), source_id=src.id, target_id=negative_target.id)
    delivered = await group.send_message(matched, state, runner_ctx)
    assert delivered is True
|
|
|
|
|
|
async def test_switch_case_edge_group_send_message_with_invalid_data():
    """A switch-case group refuses a plain-string payload."""
    from agent_framework_workflow._runner_context import InProcRunnerContext, Message
    from agent_framework_workflow._shared_state import SharedState

    src = MockExecutor(id="source_executor")
    negative_target = MockExecutor(id="target_executor_1")
    fallback_target = MockExecutor(id="target_executor_2")
    group = SwitchCaseEdgeGroup(
        source=src,
        cases=[
            Case(condition=lambda msg: msg.data < 0, target=negative_target),
            Default(target=fallback_target),
        ],
    )

    state, runner_ctx = SharedState(), InProcRunnerContext()
    bad_payload = Message(data="invalid_data", source_id=src.id)

    delivered = await group.send_message(bad_payload, state, runner_ctx)
    assert delivered is False
|
|
|
|
|
|
# endregion SwitchCaseEdgeGroup
|