# Copyright (c) Microsoft. All rights reserved. from dataclasses import dataclass from typing import Any from unittest.mock import patch import pytest from agent_framework.workflow import Executor, WorkflowContext, handler from opentelemetry.sdk.trace import TracerProvider from opentelemetry.sdk.trace.export import SimpleSpanProcessor from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter from agent_framework_workflow._edge import ( Edge, FanInEdgeGroup, FanOutEdgeGroup, SingleEdgeGroup, SwitchCaseEdgeGroup, SwitchCaseEdgeGroupCase, SwitchCaseEdgeGroupDefault, ) from agent_framework_workflow._edge_runner import create_edge_runner from agent_framework_workflow._runner_context import InProcRunnerContext, Message from agent_framework_workflow._shared_state import SharedState from agent_framework_workflow._telemetry import EdgeGroupDeliveryStatus, workflow_tracer @pytest.fixture def tracing_enabled(): """Enable tracing for tests.""" import os original_value = os.environ.get("AGENT_FRAMEWORK_WORKFLOW_ENABLE_OTEL_DIAGNOSTICS") os.environ["AGENT_FRAMEWORK_WORKFLOW_ENABLE_OTEL_DIAGNOSTICS"] = "true" # Force reload the settings to pick up the environment variable from agent_framework_workflow._telemetry import WorkflowDiagnosticSettings workflow_tracer.settings = WorkflowDiagnosticSettings() yield # Restore original value if original_value is None: os.environ.pop("AGENT_FRAMEWORK_WORKFLOW_ENABLE_OTEL_DIAGNOSTICS", None) else: os.environ["AGENT_FRAMEWORK_WORKFLOW_ENABLE_OTEL_DIAGNOSTICS"] = original_value # Reload settings again workflow_tracer.settings = WorkflowDiagnosticSettings() @pytest.fixture def span_exporter(tracing_enabled): """Set up OpenTelemetry test infrastructure.""" # Use the built-in InMemorySpanExporter for better compatibility exporter = InMemorySpanExporter() tracer_provider = TracerProvider() tracer_provider.add_span_processor(SimpleSpanProcessor(exporter)) # Store original tracer original_tracer = workflow_tracer.tracer # Set up our test tracer workflow_tracer.tracer = tracer_provider.get_tracer("agent_framework") yield exporter # Clean up exporter.clear() workflow_tracer.tracer = original_tracer @dataclass class MockMessage: """A mock message for testing purposes.""" data: Any @dataclass class MockMessageSecondary: """A secondary mock message for testing purposes.""" data: Any class MockExecutor(Executor): """A mock executor for testing purposes.""" call_count: int = 0 last_message: Any = None @handler async def mock_handler(self, message: MockMessage, ctx: WorkflowContext) -> None: """A mock handler that does nothing.""" self.call_count += 1 self.last_message = message class MockExecutorSecondary(Executor): """A secondary mock executor for testing purposes.""" call_count: int = 0 last_message: Any = None @handler async def mock_handler_secondary(self, message: MockMessageSecondary, ctx: WorkflowContext) -> None: """A secondary mock handler that does nothing.""" self.call_count += 1 self.last_message = message class MockAggregator(Executor): """A mock aggregator for testing purposes.""" call_count: int = 0 last_message: Any = None @handler async def mock_aggregator_handler(self, message: list[MockMessage], ctx: WorkflowContext) -> None: """A mock aggregator handler that does nothing.""" self.call_count += 1 self.last_message = message # region Edge def test_create_edge(): """Test creating an edge with a source and target executor.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") edge = Edge(source_id=source.id, target_id=target.id) assert edge.source_id == "source_executor" assert edge.target_id == "target_executor" assert edge.id == f"{edge.source_id}{Edge.ID_SEPARATOR}{edge.target_id}" def test_edge_can_handle(): """Test creating an edge with a source and target executor.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") edge = Edge(source_id=source.id, target_id=target.id) assert edge.should_route(MockMessage(data="test")) # endregion Edge # region SingleEdgeGroup def test_single_edge_group(): """Test creating a single edge group.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) assert edge_group.source_executor_ids == [source.id] assert edge_group.target_executor_ids == [target.id] assert edge_group.edges[0].source_id == "source_executor" assert edge_group.edges[0].target_id == "target_executor" def test_single_edge_group_with_condition(): """Test creating a single edge group with a condition.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id, condition=lambda x: x.data == "test") assert edge_group.source_executor_ids == [source.id] assert edge_group.target_executor_ids == [target.id] assert edge_group.edges[0].source_id == "source_executor" assert edge_group.edges[0].target_id == "target_executor" assert edge_group.edges[0]._condition is not None # type: ignore async def test_single_edge_group_send_message() -> None: """Test sending a message through a single edge runner.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") executors: dict[str, Executor] = {source.id: source, target.id: target} edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is True async def test_single_edge_group_send_message_with_target() -> None: """Test sending a message through a single edge runner.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") executors: dict[str, Executor] = {source.id: source, target.id: target} edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id=target.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is True async def test_single_edge_group_send_message_with_invalid_target() -> None: """Test sending a message through a single edge runner.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") executors: dict[str, Executor] = {source.id: source, target.id: target} edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id="invalid_target") success = await edge_runner.send_message(message, shared_state, ctx) assert success is False async def test_single_edge_group_send_message_with_invalid_data() -> None: """Test sending a message through a single edge runner with invalid data.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") executors: dict[str, Executor] = {source.id: source, target.id: target} edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is False async def test_single_edge_group_send_message_with_condition_pass() -> None: """Test sending a message through a single edge runner with a condition that passes.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") executors: dict[str, Executor] = {source.id: source, target.id: target} # Create edge group with condition that passes when data == "test" edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id, condition=lambda x: x.data == "test") edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is True assert target.call_count == 1 assert target.last_message.data == "test" async def test_single_edge_group_send_message_with_condition_fail() -> None: """Test sending a message through a single edge runner with a condition that fails.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") executors: dict[str, Executor] = {source.id: source, target.id: target} # Create edge group with condition that passes when data == "test" edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id, condition=lambda x: x.data == "test") edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="different") message = Message(data=data, source_id=source.id) success = await edge_runner.send_message(message, shared_state, ctx) # Should return True because message was processed, but condition failed assert success is True # Target should not be called because condition failed assert target.call_count == 0 async def test_single_edge_group_tracing_success(span_exporter) -> None: """Test that single edge group processing creates proper success spans.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") executors: dict[str, Executor] = {source.id: source, target.id: target} edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() # Create trace context and span IDs to simulate a message with tracing information trace_contexts = [{"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}] source_span_ids = ["00f067aa0ba902b7"] data = MockMessage(data="test") message = Message(data=data, source_id=source.id, trace_contexts=trace_contexts, source_span_ids=source_span_ids) # Clear any build spans span_exporter.clear() success = await edge_runner.send_message(message, shared_state, ctx) assert success is True spans = span_exporter.get_finished_spans() edge_group_spans = [s for s in spans if s.name == "edge_group.process"] assert len(edge_group_spans) == 1 span = edge_group_spans[0] assert span.attributes is not None assert span.attributes.get("edge_group.type") == "SingleEdgeGroup" assert span.attributes.get("edge_group.delivered") is True assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DELIVERED.value assert span.attributes.get("edge_group.id") is not None assert span.attributes.get("message.source_id") == source.id # Verify span links are created assert span.links is not None assert len(span.links) == 1 link = span.links[0] # Verify the link points to the correct trace and span assert link.context.trace_id == int("4bf92f3577b34da6a3ce929d0e0e4736", 16) assert link.context.span_id == int("00f067aa0ba902b7", 16) async def test_single_edge_group_tracing_condition_failure(span_exporter) -> None: """Test that single edge group processing creates proper spans for condition failures.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") executors: dict[str, Executor] = {source.id: source, target.id: target} edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id, condition=lambda x: x.data == "pass") edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="fail") message = Message(data=data, source_id=source.id) # Clear any build spans span_exporter.clear() success = await edge_runner.send_message(message, shared_state, ctx) assert success is True # Returns True but condition failed spans = span_exporter.get_finished_spans() edge_group_spans = [s for s in spans if s.name == "edge_group.process"] assert len(edge_group_spans) == 1 span = edge_group_spans[0] assert span.attributes is not None assert span.attributes.get("edge_group.type") == "SingleEdgeGroup" assert span.attributes.get("edge_group.delivered") is False assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DROPPED_CONDITION_FALSE.value async def test_single_edge_group_tracing_type_mismatch(span_exporter) -> None: """Test that single edge group processing creates proper spans for type mismatches.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") executors: dict[str, Executor] = {source.id: source, target.id: target} edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() # Send incompatible data type data = "invalid_data" message = Message(data=data, source_id=source.id) # Clear any build spans span_exporter.clear() success = await edge_runner.send_message(message, shared_state, ctx) assert success is False spans = span_exporter.get_finished_spans() edge_group_spans = [s for s in spans if s.name == "edge_group.process"] assert len(edge_group_spans) == 1 span = edge_group_spans[0] assert span.attributes is not None assert span.attributes.get("edge_group.type") == "SingleEdgeGroup" assert span.attributes.get("edge_group.delivered") is False assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DROPPED_TYPE_MISMATCH.value async def test_single_edge_group_tracing_target_mismatch(span_exporter) -> None: """Test that single edge group processing creates proper spans for target mismatches.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") executors: dict[str, Executor] = {source.id: source, target.id: target} edge_group = SingleEdgeGroup(source_id=source.id, target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id="wrong_target") # Clear any build spans span_exporter.clear() success = await edge_runner.send_message(message, shared_state, ctx) assert success is False spans = span_exporter.get_finished_spans() edge_group_spans = [s for s in spans if s.name == "edge_group.process"] assert len(edge_group_spans) == 1 span = edge_group_spans[0] assert span.attributes is not None assert span.attributes.get("edge_group.type") == "SingleEdgeGroup" assert span.attributes.get("edge_group.delivered") is False assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DROPPED_TARGET_MISMATCH.value assert span.attributes.get("message.target_id") == "wrong_target" # endregion SingleEdgeGroup # region FanOutEdgeGroup def test_source_edge_group(): """Test creating a fan-out group.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id]) assert edge_group.source_executor_ids == [source.id] assert edge_group.target_executor_ids == [target1.id, target2.id] assert len(edge_group.edges) == 2 assert edge_group.edges[0].source_id == "source_executor" assert edge_group.edges[0].target_id == "target_executor_1" assert edge_group.edges[1].source_id == "source_executor" assert edge_group.edges[1].target_id == "target_executor_2" def test_source_edge_group_invalid_number_of_targets() -> None: """Test creating a fan-out group with an invalid number of targets.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") with pytest.raises(ValueError, match="FanOutEdgeGroup must contain at least two targets"): FanOutEdgeGroup(source_id=source.id, target_ids=[target.id]) async def test_source_edge_group_send_message() -> None: """Test sending a message through a fan-out edge runner.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id]) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is True assert target1.call_count == 1 assert target2.call_count == 1 async def test_source_edge_group_send_message_with_target() -> None: """Test sending a message through a fan-out group with a target.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id]) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id=target1.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is True assert target1.call_count == 1 assert target2.call_count == 0 # target2 should not be called since message targets target1 async def test_source_edge_group_send_message_with_invalid_target() -> None: """Test sending a message through a fan-out group with an invalid target.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id]) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id="invalid_target") success = await edge_runner.send_message(message, shared_state, ctx) assert success is False async def test_source_edge_group_send_message_with_invalid_data() -> None: """Test sending a message through a fan-out group with invalid data.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id]) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is False async def test_source_edge_group_send_message_only_one_successful_send() -> None: """Test sending a message through a fan-out group where only one edge can handle the message.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutorSecondary(id="target_executor_2") edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id]) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is True assert target1.call_count == 1 # target1 can handle MockMessage assert target2.call_count == 0 # target2 (MockExecutorSecondary) cannot handle MockMessage def test_source_edge_group_with_selection_func(): """Test creating a partitioning edge group.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = FanOutEdgeGroup( source_id=source.id, target_ids=[target1.id, target2.id], selection_func=lambda data, target_ids: [target1.id], ) assert edge_group.source_executor_ids == [source.id] assert edge_group.target_executor_ids == [target1.id, target2.id] assert len(edge_group.edges) == 2 assert edge_group.edges[0].source_id == "source_executor" assert edge_group.edges[0].target_id == "target_executor_1" assert edge_group.edges[1].source_id == "source_executor" assert edge_group.edges[1].target_id == "target_executor_2" async def test_source_edge_group_with_selection_func_send_message() -> None: """Test sending a message through a fan-out group with a selection function.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = FanOutEdgeGroup( source_id=source.id, target_ids=[target1.id, target2.id], selection_func=lambda data, target_ids: [target1.id, target2.id], ) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) with patch("agent_framework_workflow._edge_runner.EdgeRunner._execute_on_target") as mock_send: success = await edge_runner.send_message(message, shared_state, ctx) assert success is True assert mock_send.call_count == 2 async def test_source_edge_group_with_selection_func_send_message_with_invalid_selection_result() -> None: """Test sending a message through a fan-out group with a selection func with an invalid selection result.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = FanOutEdgeGroup( source_id=source.id, target_ids=[target1.id, target2.id], selection_func=lambda data, target_ids: [target1.id, "invalid_target"], ) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id) with pytest.raises(RuntimeError): await edge_runner.send_message(message, shared_state, ctx) async def test_source_edge_group_with_selection_func_send_message_with_target() -> None: """Test sending a message through a fan-out group with a selection func with a target.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = FanOutEdgeGroup( source_id=source.id, target_ids=[target1.id, target2.id], selection_func=lambda data, target_ids: [target1.id, target2.id], ) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id=target1.id) with patch("agent_framework_workflow._edge_runner.EdgeRunner._execute_on_target") as mock_send: success = await edge_runner.send_message(message, shared_state, ctx) assert success is True assert mock_send.call_count == 1 assert mock_send.call_args[0][0] == target1.id async def test_source_edge_group_with_selection_func_send_message_with_target_not_in_selection() -> None: """Test sending a message through a fan-out group with a selection func with a target not in the selection.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = FanOutEdgeGroup( source_id=source.id, target_ids=[target1.id, target2.id], selection_func=lambda data, target_ids: [target1.id], # Only target1 will receive the message ) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source.id, target_id=target2.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is False async def test_source_edge_group_with_selection_func_send_message_with_invalid_data() -> None: """Test sending a message through a fan-out group with a selection func with invalid data.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = FanOutEdgeGroup( source_id=source.id, target_ids=[target1.id, target2.id], selection_func=lambda data, target_ids: [target1.id, target2.id], ) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is False async def test_source_edge_group_with_selection_func_send_message_with_target_invalid_data() -> None: """Test sending a message through a fan-out group with a selection func with a target and invalid data.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = FanOutEdgeGroup( source_id=source.id, target_ids=[target1.id, target2.id], selection_func=lambda data, target_ids: [target1.id, target2.id], ) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source.id, target_id=target1.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is False async def test_fan_out_edge_group_tracing_success(span_exporter) -> None: """Test that fan-out edge group processing creates proper success spans.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id]) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() # Create trace context and span IDs to simulate a message with tracing information trace_contexts = [{"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}] source_span_ids = ["00f067aa0ba902b7"] data = MockMessage(data="test") message = Message(data=data, source_id=source.id, trace_contexts=trace_contexts, source_span_ids=source_span_ids) # Clear any build spans span_exporter.clear() success = await edge_runner.send_message(message, shared_state, ctx) assert success is True spans = span_exporter.get_finished_spans() edge_group_spans = [s for s in spans if s.name == "edge_group.process"] assert len(edge_group_spans) == 1 span = edge_group_spans[0] assert span.attributes is not None assert span.attributes.get("edge_group.type") == "FanOutEdgeGroup" assert span.attributes.get("edge_group.delivered") is True assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DELIVERED.value assert span.attributes.get("edge_group.id") is not None assert span.attributes.get("message.source_id") == source.id # Verify span links are created assert span.links is not None assert len(span.links) == 1 link = span.links[0] # Verify the link points to the correct trace and span assert link.context.trace_id == int("4bf92f3577b34da6a3ce929d0e0e4736", 16) assert link.context.span_id == int("00f067aa0ba902b7", 16) async def test_fan_out_edge_group_tracing_with_target(span_exporter) -> None: """Test that fan-out edge group processing creates proper spans for targeted messages.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_group = FanOutEdgeGroup(source_id=source.id, target_ids=[target1.id, target2.id]) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() # Create trace context and span IDs to simulate a message with tracing information trace_contexts = [{"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}] source_span_ids = ["00f067aa0ba902b7"] data = MockMessage(data="test") message = Message( data=data, source_id=source.id, target_id=target1.id, trace_contexts=trace_contexts, source_span_ids=source_span_ids, ) # Clear any build spans span_exporter.clear() success = await edge_runner.send_message(message, shared_state, ctx) assert success is True spans = span_exporter.get_finished_spans() edge_group_spans = [s for s in spans if s.name == "edge_group.process"] assert len(edge_group_spans) == 1 span = edge_group_spans[0] assert span.attributes is not None assert span.attributes.get("edge_group.type") == "FanOutEdgeGroup" assert span.attributes.get("edge_group.delivered") is True assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DELIVERED.value assert span.attributes.get("message.target_id") == target1.id # Verify span links are created assert span.links is not None assert len(span.links) == 1 link = span.links[0] # Verify the link points to the correct trace and span assert link.context.trace_id == int("4bf92f3577b34da6a3ce929d0e0e4736", 16) assert link.context.span_id == int("00f067aa0ba902b7", 16) # endregion FanOutEdgeGroup # region FanInEdgeGroup def test_target_edge_group(): """Test creating a fan-in edge group.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") target = MockAggregator(id="target_executor") edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id) assert edge_group.source_executor_ids == [source1.id, source2.id] assert edge_group.target_executor_ids == [target.id] assert len(edge_group.edges) == 2 assert edge_group.edges[0].source_id == "source_executor_1" assert edge_group.edges[0].target_id == "target_executor" assert edge_group.edges[1].source_id == "source_executor_2" assert edge_group.edges[1].target_id == "target_executor" def test_target_edge_group_invalid_number_of_sources(): """Test creating a fan-in edge group with an invalid number of sources.""" source = MockExecutor(id="source_executor") target = MockAggregator(id="target_executor") with pytest.raises(ValueError, match="FanInEdgeGroup must contain at least two sources"): FanInEdgeGroup(source_ids=[source.id], target_id=target.id) async def test_target_edge_group_send_message_buffer() -> None: """Test sending a message through a fan-in edge group with buffering.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") target = MockAggregator(id="target_executor") edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id) executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") with patch("agent_framework_workflow._edge_runner.EdgeRunner._execute_on_target") as mock_send: success = await edge_runner.send_message( Message(data=data, source_id=source1.id), shared_state, ctx, ) assert success is True assert mock_send.call_count == 0 # The message should be buffered and wait for the second source assert len(edge_runner._buffer[source1.id]) == 1 # type: ignore success = await edge_runner.send_message( Message(data=data, source_id=source2.id), shared_state, ctx, ) assert success is True assert mock_send.call_count == 1 # The message should be sent now that both sources have sent their messages # Buffer should be cleared after sending assert not edge_runner._buffer # type: ignore async def test_target_edge_group_send_message_with_invalid_target() -> None: """Test sending a message through a fan-in edge group with an invalid target.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") target = MockAggregator(id="target_executor") edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id) executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") message = Message(data=data, source_id=source1.id, target_id="invalid_target") success = await edge_runner.send_message(message, shared_state, ctx) assert success is False async def test_target_edge_group_send_message_with_invalid_data() -> None: """Test sending a message through a fan-in edge group with invalid data.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") target = MockAggregator(id="target_executor") edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id) executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source1.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is False async def test_fan_in_edge_group_tracing_buffered(span_exporter) -> None: """Test that fan-in edge group processing creates proper spans for buffered messages.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") target = MockAggregator(id="target_executor") executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target} edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data="test") # Create trace context and span IDs to simulate a message with tracing information trace_contexts1 = [{"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01"}] source_span_ids1 = ["00f067aa0ba902b7"] trace_contexts2 = [{"traceparent": "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b8-01"}] source_span_ids2 = ["00f067aa0ba902b8"] # Clear any build spans span_exporter.clear() # Send first message (should be buffered) success = await edge_runner.send_message( Message(data=data, source_id=source1.id, trace_contexts=trace_contexts1, source_span_ids=source_span_ids1), shared_state, ctx, ) assert success is True spans = span_exporter.get_finished_spans() edge_group_spans = [s for s in spans if s.name == "edge_group.process"] assert len(edge_group_spans) == 1 span = edge_group_spans[0] assert span.attributes is not None assert span.attributes.get("edge_group.type") == "FanInEdgeGroup" assert span.attributes.get("edge_group.delivered") is True assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.BUFFERED.value assert span.attributes.get("message.source_id") == source1.id # Verify span links are created for first message assert span.links is not None assert len(span.links) == 1 link = span.links[0] # Verify the link points to the correct trace and span assert link.context.trace_id == int("4bf92f3577b34da6a3ce929d0e0e4736", 16) assert link.context.span_id == int("00f067aa0ba902b7", 16) # Clear spans and send second message (should trigger delivery) span_exporter.clear() success = await edge_runner.send_message( Message(data=data, source_id=source2.id, trace_contexts=trace_contexts2, source_span_ids=source_span_ids2), shared_state, ctx, ) assert success is True spans = span_exporter.get_finished_spans() edge_group_spans = [s for s in spans if s.name == "edge_group.process"] assert len(edge_group_spans) == 1 span = edge_group_spans[0] assert span.attributes is not None assert span.attributes.get("edge_group.type") == "FanInEdgeGroup" assert span.attributes.get("edge_group.delivered") is True assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DELIVERED.value assert span.attributes.get("message.source_id") == source2.id # Verify span links are created for second message assert span.links is not None assert len(span.links) == 1 link = span.links[0] # Verify the link points to the correct trace and span for the second message assert link.context.trace_id == int("4bf92f3577b34da6a3ce929d0e0e4736", 16) assert link.context.span_id == int("00f067aa0ba902b8", 16) async def test_fan_in_edge_group_tracing_type_mismatch(span_exporter) -> None: """Test that fan-in edge group processing creates proper spans for type mismatches.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") target = MockAggregator(id="target_executor") executors: dict[str, Executor] = {source1.id: source1, source2.id: source2, target.id: target} edge_group = FanInEdgeGroup(source_ids=[source1.id, source2.id], target_id=target.id) edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() # Send incompatible data type data = "invalid_data" message = Message(data=data, source_id=source1.id) # Clear any build spans span_exporter.clear() success = await edge_runner.send_message(message, shared_state, ctx) assert success is False spans = span_exporter.get_finished_spans() edge_group_spans = [s for s in spans if s.name == "edge_group.process"] assert len(edge_group_spans) == 1 span = edge_group_spans[0] assert span.attributes is not None assert span.attributes.get("edge_group.type") == "FanInEdgeGroup" assert span.attributes.get("edge_group.delivered") is False assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DROPPED_TYPE_MISMATCH.value # endregion FanInEdgeGroup # region SwitchCaseEdgeGroup def test_switch_case_edge_group() -> None: """Test creating a switch case edge group.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = SwitchCaseEdgeGroup( source_id=source.id, cases=[ SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id), SwitchCaseEdgeGroupDefault(target_id=target2.id), ], ) assert edge_group.source_executor_ids == [source.id] assert edge_group.target_executor_ids == [target1.id, target2.id] assert len(edge_group.edges) == 2 assert edge_group.edges[0].source_id == "source_executor" assert edge_group.edges[0].target_id == "target_executor_1" assert edge_group.edges[1].source_id == "source_executor" assert edge_group.edges[1].target_id == "target_executor_2" assert edge_group._selection_func is not None # type: ignore assert edge_group._selection_func(MockMessage(data=-1), [target1.id, target2.id]) == [target1.id] # type: ignore assert edge_group._selection_func(MockMessage(data=1), [target1.id, target2.id]) == [target2.id] # type: ignore def test_switch_case_edge_group_invalid_number_of_cases(): """Test creating a switch case edge group with an invalid number of cases.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") with pytest.raises( ValueError, match=r"SwitchCaseEdgeGroup must contain at least two cases \(including the default case\)." ): SwitchCaseEdgeGroup( source_id=source.id, cases=[ SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target.id), ], ) with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."): SwitchCaseEdgeGroup( source_id=source.id, cases=[ SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target.id), SwitchCaseEdgeGroupCase(condition=lambda x: x.data >= 0, target_id=target.id), ], ) def test_switch_case_edge_group_invalid_number_of_default_cases(): """Test creating a switch case edge group with an invalid number of conditions.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") with pytest.raises(ValueError, match="SwitchCaseEdgeGroup must contain exactly one default case."): SwitchCaseEdgeGroup( source_id=source.id, cases=[ SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id), SwitchCaseEdgeGroupDefault(target_id=target2.id), SwitchCaseEdgeGroupDefault(target_id=target2.id), ], ) async def test_switch_case_edge_group_send_message() -> None: """Test sending a message through a switch case edge group.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = SwitchCaseEdgeGroup( source_id=source.id, cases=[ SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id), SwitchCaseEdgeGroupDefault(target_id=target2.id), ], ) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data=-1) message = Message(data=data, source_id=source.id) with patch("agent_framework_workflow._edge_runner.EdgeRunner._execute_on_target") as mock_send: success = await edge_runner.send_message(message, shared_state, ctx) assert success is True assert mock_send.call_count == 1 # Default condition should data = MockMessage(data=1) message = Message(data=data, source_id=source.id) with patch("agent_framework_workflow._edge_runner.EdgeRunner._execute_on_target") as mock_send: success = await edge_runner.send_message(message, shared_state, ctx) assert success is True assert mock_send.call_count == 1 async def test_switch_case_edge_group_send_message_with_invalid_target() -> None: """Test sending a message through a switch case edge group with an invalid target.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = SwitchCaseEdgeGroup( source_id=source.id, cases=[ SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id), SwitchCaseEdgeGroupDefault(target_id=target2.id), ], ) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data=-1) message = Message(data=data, source_id=source.id, target_id="invalid_target") success = await edge_runner.send_message(message, shared_state, ctx) assert success is False async def test_switch_case_edge_group_send_message_with_valid_target() -> None: """Test sending a message through a switch case edge group with a target.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = SwitchCaseEdgeGroup( source_id=source.id, cases=[ SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id), SwitchCaseEdgeGroupDefault(target_id=target2.id), ], ) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = MockMessage(data=1) # Condition will fail message = Message(data=data, source_id=source.id, target_id=target1.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is False data = MockMessage(data=-1) # Condition will pass message = Message(data=data, source_id=source.id, target_id=target1.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is True async def test_switch_case_edge_group_send_message_with_invalid_data() -> None: """Test sending a message through a switch case edge group with invalid data.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") target2 = MockExecutor(id="target_executor_2") edge_group = SwitchCaseEdgeGroup( source_id=source.id, cases=[ SwitchCaseEdgeGroupCase(condition=lambda x: x.data < 0, target_id=target1.id), SwitchCaseEdgeGroupDefault(target_id=target2.id), ], ) executors: dict[str, Executor] = {source.id: source, target1.id: target1, target2.id: target2} edge_runner = create_edge_runner(edge_group, executors) shared_state = SharedState() ctx = InProcRunnerContext() data = "invalid_data" message = Message(data=data, source_id=source.id) success = await edge_runner.send_message(message, shared_state, ctx) assert success is False # endregion SwitchCaseEdgeGroup