Python: Fix workflow tests pyright warnings (#4362)

* Fix workflow tests pyright warnings

* Update uv.lock

* Fix pyright

* Comments

* Update root pyproject pyright setting

* Update core pyproject pyright setting

* Update core pyproject pyright setting
This commit is contained in:
Tao Chen
2026-03-03 13:52:05 -08:00
committed by GitHub
Unverified
parent 2a20750110
commit 1a8729d5a7
26 changed files with 696 additions and 368 deletions
+1
View File
@@ -105,6 +105,7 @@ extend = "../../pyproject.toml"
[tool.pyright]
extends = "../../pyproject.toml"
include = ["tests/workflow"]
[tool.mypy]
plugins = ['pydantic.mypy']
@@ -2,19 +2,20 @@
import logging
from collections.abc import AsyncIterable, Awaitable
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Literal, overload
import pytest
from agent_framework import (
AgentExecutor,
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
AgentSession,
BaseAgent,
Content,
Message,
ResponseStream,
WorkflowEvent,
WorkflowRunState,
)
from agent_framework._workflows._agent_executor import AgentExecutorResponse
@@ -32,26 +33,56 @@ class _CountingAgent(BaseAgent):
super().__init__(**kwargs)
self.call_count = 0
@overload
def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> (
Awaitable[AgentResponse[Any]]
| ResponseStream[AgentResponseUpdate, AgentResponse[Any]]
):
self.call_count += 1
if stream:
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
yield AgentResponseUpdate(
contents=[Content.from_text(text=f"Response #{self.call_count}: {self.name}")]
contents=[
Content.from_text(
text=f"Response #{self.call_count}: {self.name}"
)
]
)
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", [f"Response #{self.call_count}: {self.name}"])])
return AgentResponse(
messages=[
Message("assistant", [f"Response #{self.call_count}: {self.name}"])
]
)
return _run()
@@ -63,13 +94,36 @@ class _StreamingHookAgent(BaseAgent):
super().__init__(**kwargs)
self.result_hook_called = False
@overload
def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> (
Awaitable[AgentResponse[Any]]
| ResponseStream[AgentResponseUpdate, AgentResponse[Any]]
):
if stream:
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
@@ -78,13 +132,15 @@ class _StreamingHookAgent(BaseAgent):
role="assistant",
)
async def _mark_result_hook_called(response: AgentResponse) -> AgentResponse:
async def _mark_result_hook_called(
response: AgentResponse,
) -> AgentResponse:
self.result_hook_called = True
return response
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook(
_mark_result_hook_called
)
return ResponseStream(
_stream(), finalizer=AgentResponse.from_updates
).with_result_hook(_mark_result_hook_called)
async def _run() -> AgentResponse:
return AgentResponse(messages=[Message("assistant", ["hook test"])])
@@ -92,7 +148,9 @@ class _StreamingHookAgent(BaseAgent):
return _run()
async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None:
async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> (
None
):
"""AgentExecutor should call get_final_response() so stream result hooks execute."""
agent = _StreamingHookAgent(id="hook_agent", name="HookAgent")
executor = AgentExecutor(agent, id="hook_exec")
@@ -159,7 +217,9 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
executor_state = executor_states[executor.id] # type: ignore[index]
assert "cache" in executor_state, "Checkpoint should store executor cache state"
assert "agent_session" in executor_state, "Checkpoint should store executor session state"
assert "agent_session" in executor_state, (
"Checkpoint should store executor session state"
)
# Verify session state structure
session_state = executor_state["agent_session"] # type: ignore[index]
@@ -180,11 +240,15 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
assert restored_agent.call_count == 0
# Build new workflow with the restored executor
wf_resume = SequentialBuilder(participants=[restored_executor], checkpoint_storage=storage).build()
wf_resume = SequentialBuilder(
participants=[restored_executor], checkpoint_storage=storage
).build()
# Resume from checkpoint
resumed_output: AgentExecutorResponse | None = None
async for ev in wf_resume.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True):
async for ev in wf_resume.run(
checkpoint_id=restore_checkpoint.checkpoint_id, stream=True
):
if ev.type == "output":
resumed_output = ev.data # type: ignore[assignment]
if ev.type == "status" and ev.state in (
@@ -278,7 +342,7 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -
workflow = SequentialBuilder(participants=[executor]).build()
# stream=True at workflow level triggers streaming mode (returns async iterable)
events = []
events: list[WorkflowEvent] = []
async for event in workflow.run("hello", stream=True):
events.append(event)
assert len(events) > 0
@@ -288,10 +352,13 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -
@pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"])
async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str, caplog: "LogCaptureFixture") -> None:
"""_prepare_agent_run_args must remove reserved kwargs and log a warning."""
raw = {reserved_kwarg: "should-be-stripped", "custom_key": "keep-me"}
raw: dict[str, Any] = {
reserved_kwarg: "should-be-stripped",
"custom_key": "keep-me",
}
with caplog.at_level(logging.WARNING):
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert reserved_kwarg not in run_kwargs
assert "custom_key" in run_kwargs
@@ -302,8 +369,8 @@ async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str
async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None:
"""Non-reserved workflow kwargs should pass through unchanged."""
raw = {"custom_param": "value", "another": 42}
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
raw: dict[str, Any] = {"custom_param": "value", "another": 42}
run_kwargs, _options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert run_kwargs["custom_param"] == "value"
assert run_kwargs["another"] == 42
@@ -312,10 +379,10 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once(
caplog: "LogCaptureFixture",
) -> None:
"""All reserved kwargs should be stripped when supplied together, each emitting a warning."""
raw = {"session": "x", "stream": True, "messages": [], "custom": 1}
raw: dict[str, Any] = {"session": "x", "stream": True, "messages": [], "custom": 1}
with caplog.at_level(logging.WARNING):
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
assert "session" not in run_kwargs
assert "stream" not in run_kwargs
@@ -324,7 +391,11 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once(
assert options is not None
assert options["additional_function_arguments"]["custom"] == 1
warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()}
warned_keys = {
r.message.split("'")[1]
for r in caplog.records
if "reserved" in r.message.lower()
}
assert warned_keys == {"session", "stream", "messages"}
@@ -3,7 +3,7 @@
"""Tests for AgentExecutor handling of tool calls and results in streaming mode."""
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
from typing import Any
from typing import Any, Literal, overload
from typing_extensions import Never
@@ -13,6 +13,7 @@ from agent_framework import (
AgentExecutorResponse,
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
AgentSession,
BaseAgent,
ChatResponse,
@@ -37,18 +38,38 @@ class _ToolCallingAgent(BaseAgent):
def __init__(self, **kwargs: Any) -> None:
super().__init__(**kwargs)
@overload
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
if stream:
return ResponseStream(self._run_stream_impl(), finalizer=AgentResponse.from_updates)
async def _run() -> AgentResponse:
async def _run() -> AgentResponse[Any]:
return AgentResponse(messages=[Message("assistant", ["done"])])
return _run()
@@ -111,6 +132,7 @@ async def test_agent_executor_emits_tool_calls_in_streaming_mode() -> None:
# First event: text update
assert events[0].data is not None
assert events[0].data.contents[0].type == "text"
assert events[0].data.contents[0].text is not None
assert "Let me search" in events[0].data.contents[0].text
# Second event: function call
@@ -129,6 +151,7 @@ async def test_agent_executor_emits_tool_calls_in_streaming_mode() -> None:
# Fourth event: final text
assert events[3].data is not None
assert events[3].data.contents[0].type == "text"
assert events[3].data.contents[0].text is not None
assert "sunny" in events[3].data.contents[0].text
@@ -1,9 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable
from typing import Any
from collections.abc import Awaitable
from typing import Any, Literal, overload
from agent_framework import AgentResponse, AgentResponseUpdate, AgentSession, Message
from agent_framework import AgentResponse, AgentResponseUpdate, AgentRunInputs, AgentSession, ResponseStream
from agent_framework._workflows._agent_utils import resolve_agent_id
@@ -11,40 +11,23 @@ class MockAgent:
"""Mock agent for testing agent utilities."""
def __init__(self, agent_id: str, name: str | None = None) -> None:
self._id = agent_id
self._name = name
self.id: str = agent_id
self.name: str | None = name
self.description: str | None = None
@property
def id(self) -> str:
return self._id
@property
def name(self) -> str | None:
return self._name
@property
def display_name(self) -> str:
"""Returns the display name of the agent."""
...
@property
def description(self) -> str | None:
"""Returns the description of the agent."""
...
def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> AgentResponse | AsyncIterable[AgentResponseUpdate]: ...
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def create_session(self, **kwargs: Any) -> AgentSession:
"""Creates a new conversation session for the agent."""
...
def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
return AgentSession()
def test_resolve_agent_id_with_name() -> None:
"""Test that resolve_agent_id returns name when agent has a name."""
@@ -5,6 +5,7 @@ import tempfile
from dataclasses import dataclass
from datetime import datetime, timezone
from pathlib import Path
from typing import Any
import pytest
@@ -24,7 +25,7 @@ class _TestToolApprovalRequest:
"""Request data for tool approval in tests."""
tool_name: str
arguments: dict
arguments: dict[str, Any]
timestamp: datetime
@@ -41,7 +42,7 @@ class _TestApprovalRequest:
"""Approval request data for tests."""
action: str
params: tuple
params: tuple[Any, ...]
@dataclass
@@ -78,8 +79,8 @@ def test_workflow_checkpoint_custom_values():
workflow_name="test-workflow-456",
graph_signature_hash="test-hash-456",
timestamp=custom_timestamp,
messages={"executor1": [{"data": "test"}]},
pending_request_info_events={"req123": {"data": "test"}},
messages={"executor1": [{"data": "test"}]}, # type: ignore[arg-type] # raw dict for serialization test
pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test
state={"key": "value"},
iteration_count=5,
metadata={"test": True},
@@ -103,7 +104,7 @@ def test_workflow_checkpoint_to_dict():
checkpoint_id="test-id",
workflow_name="test-workflow",
graph_signature_hash="test-hash",
messages={"executor1": [{"data": "test"}]},
messages={"executor1": [{"data": "test"}]}, # type: ignore[arg-type] # raw dict for serialization test
state={"key": "value"},
iteration_count=5,
)
@@ -161,8 +162,8 @@ async def test_memory_checkpoint_storage_save_and_load():
checkpoint = WorkflowCheckpoint(
workflow_name="test-workflow",
graph_signature_hash="test-hash",
messages={"executor1": [{"data": "hello"}]},
pending_request_info_events={"req123": {"data": "test"}},
messages={"executor1": [{"data": "hello"}]}, # type: ignore[arg-type] # raw dict for serialization test
pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test
)
# Save checkpoint
@@ -776,9 +777,9 @@ async def test_file_checkpoint_storage_save_and_load():
checkpoint = WorkflowCheckpoint(
workflow_name="test-workflow",
graph_signature_hash="test-hash",
messages={"executor1": [{"data": "hello", "source_id": "test", "target_id": None}]},
messages={"executor1": [{"data": "hello", "source_id": "test", "target_id": None}]}, # type: ignore[arg-type] # raw dict for serialization test
state={"key": "value"},
pending_request_info_events={"req123": {"data": "test"}},
pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test
)
# Save checkpoint
@@ -904,9 +905,9 @@ async def test_file_checkpoint_storage_json_serialization():
checkpoint = WorkflowCheckpoint(
workflow_name="test-workflow",
graph_signature_hash="test-hash",
messages={"executor1": [{"data": {"nested": {"value": 42}}, "source_id": "test", "target_id": None}]},
messages={"executor1": [{"data": {"nested": {"value": 42}}, "source_id": "test", "target_id": None}]}, # type: ignore[arg-type] # raw dict for serialization test
state={"list": [1, 2, 3], "dict": {"a": "b", "c": {"d": "e"}}, "bool": True, "null": None},
pending_request_info_events={"req123": {"data": "test"}},
pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test
)
# Save and load
@@ -3,11 +3,11 @@
import json
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import Any
from typing import Any, cast
from agent_framework._workflows._checkpoint_encoding import (
_PICKLE_MARKER,
_TYPE_MARKER,
_PICKLE_MARKER, # pyright: ignore[reportPrivateUsage]
_TYPE_MARKER, # pyright: ignore[reportPrivateUsage]
encode_checkpoint_value,
)
@@ -185,8 +185,9 @@ def test_encode_list_of_dataclasses() -> None:
result = encode_checkpoint_value(data)
assert isinstance(result, list)
assert len(result) == 2
for item in result:
result_list = cast(list[Any], result)
assert len(result_list) == 2
for item in result_list:
assert _PICKLE_MARKER in item
@@ -4,6 +4,8 @@ from dataclasses import dataclass
from typing import Any
from unittest.mock import patch
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
import pytest
from agent_framework import (
@@ -275,6 +277,7 @@ async def test_single_edge_group_send_message_with_condition_pass() -> None:
success = await edge_runner.send_message(message, state, ctx)
assert success is True
assert target.call_count == 1
assert target.last_message is not None
assert target.last_message.data == "test"
@@ -301,7 +304,7 @@ async def test_single_edge_group_send_message_with_condition_fail() -> None:
assert target.call_count == 0
async def test_single_edge_group_tracing_success(span_exporter) -> None:
async def test_single_edge_group_tracing_success(span_exporter: InMemorySpanExporter) -> None:
"""Test that single edge group processing creates proper success spans."""
source = MockExecutor(id="source_executor")
target = MockExecutor(id="target_executor")
@@ -352,7 +355,7 @@ async def test_single_edge_group_tracing_success(span_exporter) -> None:
assert link.context.span_id == int("00f067aa0ba902b7", 16)
async def test_single_edge_group_tracing_condition_failure(span_exporter) -> None:
async def test_single_edge_group_tracing_condition_failure(span_exporter: InMemorySpanExporter) -> None:
"""Test that single edge group processing creates proper spans for condition failures."""
source = MockExecutor(id="source_executor")
target = MockExecutor(id="target_executor")
@@ -386,7 +389,7 @@ async def test_single_edge_group_tracing_condition_failure(span_exporter) -> Non
assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DROPPED_CONDITION_FALSE.value
async def test_single_edge_group_tracing_type_mismatch(span_exporter) -> None:
async def test_single_edge_group_tracing_type_mismatch(span_exporter: InMemorySpanExporter) -> None:
"""Test that single edge group processing creates proper spans for type mismatches."""
source = MockExecutor(id="source_executor")
target = MockExecutor(id="target_executor")
@@ -421,7 +424,7 @@ async def test_single_edge_group_tracing_type_mismatch(span_exporter) -> None:
assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DROPPED_TYPE_MISMATCH.value
async def test_single_edge_group_tracing_target_mismatch(span_exporter) -> None:
async def test_single_edge_group_tracing_target_mismatch(span_exporter: InMemorySpanExporter) -> None:
"""Test that single edge group processing creates proper spans for target mismatches."""
source = MockExecutor(id="source_executor")
target = MockExecutor(id="target_executor")
@@ -775,7 +778,7 @@ async def test_source_edge_group_with_selection_func_send_message_with_target_in
assert success is False
async def test_fan_out_edge_group_tracing_success(span_exporter) -> None:
async def test_fan_out_edge_group_tracing_success(span_exporter: InMemorySpanExporter) -> None:
"""Test that fan-out edge group processing creates proper success spans."""
source = MockExecutor(id="source_executor")
target1 = MockExecutor(id="target_executor_1")
@@ -827,7 +830,7 @@ async def test_fan_out_edge_group_tracing_success(span_exporter) -> None:
assert link.context.span_id == int("00f067aa0ba902b7", 16)
async def test_fan_out_edge_group_tracing_with_target(span_exporter) -> None:
async def test_fan_out_edge_group_tracing_with_target(span_exporter: InMemorySpanExporter) -> None:
"""Test that fan-out edge group processing creates proper spans for targeted messages."""
source = MockExecutor(id="source_executor")
target1 = MockExecutor(id="target_executor_1")
@@ -994,7 +997,7 @@ async def test_target_edge_group_send_message_with_invalid_data() -> None:
assert success is False
async def test_fan_in_edge_group_tracing_buffered(span_exporter) -> None:
async def test_fan_in_edge_group_tracing_buffered(span_exporter: InMemorySpanExporter) -> None:
"""Test that fan-in edge group processing creates proper spans for buffered messages."""
source1 = MockExecutor(id="source_executor_1")
source2 = MockExecutor(id="source_executor_2")
@@ -1086,7 +1089,7 @@ async def test_fan_in_edge_group_tracing_buffered(span_exporter) -> None:
assert link.context.span_id == int("00f067aa0ba902b8", 16)
async def test_fan_in_edge_group_tracing_type_mismatch(span_exporter) -> None:
async def test_fan_in_edge_group_tracing_type_mismatch(span_exporter: InMemorySpanExporter) -> None:
"""Test that fan-in edge group processing creates proper spans for type mismatches."""
source1 = MockExecutor(id="source_executor_1")
source2 = MockExecutor(id="source_executor_2")
@@ -3,8 +3,6 @@
from dataclasses import dataclass
import pytest
from typing_extensions import Never
from agent_framework import (
Executor,
Message,
@@ -16,6 +14,7 @@ from agent_framework import (
handler,
response_handler,
)
from typing_extensions import Never
# Module-level types for string forward reference tests
@@ -59,7 +58,7 @@ def test_executor_handler_without_annotations():
class MockExecutorWithOneHandlerWithoutAnnotations(Executor): # type: ignore
"""A mock executor with one handler that does not implement any annotations."""
@handler
@handler # pyright: ignore[reportUnknownArgumentType]
async def handle(self, message, ctx) -> None: # type: ignore
"""A mock handler that does not implement any annotations."""
pass
@@ -156,7 +155,11 @@ async def test_executor_invoked_event_contains_input_data():
workflow = WorkflowBuilder(start_executor=upper).add_edge(upper, collector).build()
events = await workflow.run("hello world")
invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"]
invoked_events = [
e
for e in events
if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"
]
assert len(invoked_events) == 2
@@ -190,10 +193,16 @@ async def test_executor_completed_event_contains_sent_messages():
sender = MultiSenderExecutor(id="sender")
collector = CollectorExecutor(id="collector")
workflow = WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build()
workflow = (
WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build()
)
events = await workflow.run("hello")
completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"]
completed_events = [
e
for e in events
if isinstance(e, WorkflowEvent) and e.type == "executor_completed"
]
# Sender should have completed with the sent messages
sender_completed = next(e for e in completed_events if e.executor_id == "sender")
@@ -201,7 +210,9 @@ async def test_executor_completed_event_contains_sent_messages():
assert sender_completed.data == ["hello-first", "hello-second"]
# Collector should have completed with no sent messages (None)
collector_completed_events = [e for e in completed_events if e.executor_id == "collector"]
collector_completed_events = [
e for e in completed_events if e.executor_id == "collector"
]
# Collector is called twice (once per message from sender)
assert len(collector_completed_events) == 2
for collector_completed in collector_completed_events:
@@ -220,7 +231,11 @@ async def test_executor_completed_event_includes_yielded_outputs():
workflow = WorkflowBuilder(start_executor=executor).build()
events = await workflow.run("test")
completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"]
completed_events = [
e
for e in events
if isinstance(e, WorkflowEvent) and e.type == "executor_completed"
]
assert len(completed_events) == 1
assert completed_events[0].executor_id == "yielder"
@@ -248,7 +263,9 @@ async def test_executor_events_with_complex_message_types():
class ProcessorExecutor(Executor):
@handler
async def handle(self, request: Request, ctx: WorkflowContext[Response]) -> None:
async def handle(
self, request: Request, ctx: WorkflowContext[Response]
) -> None:
response = Response(results=[request.query.upper()] * request.limit)
await ctx.send_message(response)
@@ -260,13 +277,23 @@ async def test_executor_events_with_complex_message_types():
processor = ProcessorExecutor(id="processor")
collector = CollectorExecutor(id="collector")
workflow = WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build()
workflow = (
WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build()
)
input_request = Request(query="hello", limit=3)
events = await workflow.run(input_request)
invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"]
completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"]
invoked_events = [
e
for e in events
if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"
]
completed_events = [
e
for e in events
if isinstance(e, WorkflowEvent) and e.type == "executor_completed"
]
# Check processor invoked event has the Request object
processor_invoked = next(e for e in invoked_events if e.executor_id == "processor")
@@ -275,7 +302,9 @@ async def test_executor_events_with_complex_message_types():
assert processor_invoked.data.limit == 3
# Check processor completed event has the Response object
processor_completed = next(e for e in completed_events if e.executor_id == "processor")
processor_completed = next(
e for e in completed_events if e.executor_id == "processor"
)
assert processor_completed.data is not None
assert len(processor_completed.data) == 1
assert isinstance(processor_completed.data[0], Response)
@@ -361,7 +390,9 @@ def test_executor_workflow_output_types_property():
# Test executor with union workflow output types
class UnionWorkflowOutputExecutor(Executor):
@handler
async def handle(self, text: str, ctx: WorkflowContext[int, str | bool]) -> None:
async def handle(
self, text: str, ctx: WorkflowContext[int, str | bool]
) -> None:
pass
executor = UnionWorkflowOutputExecutor(id="union_workflow_output")
@@ -372,11 +403,15 @@ def test_executor_workflow_output_types_property():
# Test executor with multiple handlers having different workflow output types
class MultiHandlerWorkflowExecutor(Executor):
@handler
async def handle_string(self, text: str, ctx: WorkflowContext[int, str]) -> None:
async def handle_string(
self, text: str, ctx: WorkflowContext[int, str]
) -> None:
pass
@handler
async def handle_number(self, num: int, ctx: WorkflowContext[bool, float]) -> None:
async def handle_number(
self, num: int, ctx: WorkflowContext[bool, float]
) -> None:
pass
executor = MultiHandlerWorkflowExecutor(id="multi_workflow")
@@ -430,7 +465,9 @@ def test_executor_output_types_includes_response_handlers():
pass
@response_handler
async def handle_response(self, original_request: str, response: bool, ctx: WorkflowContext[float]) -> None:
async def handle_response(
self, original_request: str, response: bool, ctx: WorkflowContext[float]
) -> None:
pass
executor = RequestResponseExecutor(id="request_response")
@@ -452,7 +489,10 @@ def test_executor_workflow_output_types_includes_response_handlers():
@response_handler
async def handle_response(
self, original_request: str, response: bool, ctx: WorkflowContext[float, bool]
self,
original_request: str,
response: bool,
ctx: WorkflowContext[float, bool],
) -> None:
pass
@@ -509,7 +549,10 @@ def test_executor_response_handler_union_output_types():
@response_handler
async def handle_response(
self, original_request: str, response: bool, ctx: WorkflowContext[int | str | float, bool | int]
self,
original_request: str,
response: bool,
ctx: WorkflowContext[int | str | float, bool | int],
) -> None:
pass
@@ -531,7 +574,9 @@ async def test_executor_invoked_event_data_not_mutated_by_handler():
"""Test that executor_invoked event (type='executor_invoked').data captures original input, not mutated input."""
@executor(id="Mutator")
async def mutator(messages: list[Message], ctx: WorkflowContext[list[Message]]) -> None:
async def mutator(
messages: list[Message], ctx: WorkflowContext[list[Message]]
) -> None:
# The handler mutates the input list by appending new messages
original_len = len(messages)
messages.append(Message(role="assistant", text="Added by executor"))
@@ -546,7 +591,11 @@ async def test_executor_invoked_event_data_not_mutated_by_handler():
events = await workflow.run(input_messages)
# Find the invoked event for the Mutator executor
invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"]
invoked_events = [
e
for e in events
if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"
]
assert len(invoked_events) == 1
mutator_invoked = invoked_events[0]
@@ -577,8 +626,8 @@ class TestHandlerExplicitTypes:
exec_instance = ExplicitInputExecutor(id="explicit_input")
# Handler should be registered for str (explicit), not Any (introspected)
assert str in exec_instance._handlers
assert len(exec_instance._handlers) == 1
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
assert len(exec_instance._handlers) == 1 # pyright: ignore[reportPrivateUsage]
# Can handle str messages
assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock"))
@@ -596,8 +645,8 @@ class TestHandlerExplicitTypes:
exec_instance = ExplicitOutputExecutor(id="explicit_output")
# Handler spec should have int as output type (explicit)
handler_func = exec_instance._handlers[str]
assert handler_func._handler_spec["output_types"] == [int]
handler_func = exec_instance._handlers[str] # pyright: ignore[reportPrivateUsage]
assert handler_func._handler_spec["output_types"] == [int] # pyright: ignore[reportFunctionMemberAccess]
# Executor output_types property should reflect explicit type
assert int in exec_instance.output_types
@@ -615,16 +664,20 @@ class TestHandlerExplicitTypes:
exec_instance = ExplicitBothExecutor(id="explicit_both")
# Handler should be registered for dict (explicit input type)
assert dict in exec_instance._handlers
assert len(exec_instance._handlers) == 1
assert dict in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
assert len(exec_instance._handlers) == 1 # pyright: ignore[reportPrivateUsage]
# Output type should be list (explicit)
handler_func = exec_instance._handlers[dict]
assert handler_func._handler_spec["output_types"] == [list]
handler_func = exec_instance._handlers[dict] # pyright: ignore[reportPrivateUsage]
assert handler_func._handler_spec["output_types"] == [list] # pyright: ignore[reportFunctionMemberAccess]
# Verify can_handle
assert exec_instance.can_handle(WorkflowMessage(data={"key": "value"}, source_id="mock"))
assert not exec_instance.can_handle(WorkflowMessage(data="string", source_id="mock"))
assert exec_instance.can_handle(
WorkflowMessage(data={"key": "value"}, source_id="mock")
)
assert not exec_instance.can_handle(
WorkflowMessage(data="string", source_id="mock")
)
def test_handler_with_explicit_union_input_type(self):
"""Test that explicit union input_type is handled correctly."""
@@ -639,13 +692,15 @@ class TestHandlerExplicitTypes:
# Handler should be registered for the union type
# The union type itself is stored as the key
assert len(exec_instance._handlers) == 1
assert len(exec_instance._handlers) == 1 # pyright: ignore[reportPrivateUsage]
# Can handle both str and int messages
assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock"))
assert exec_instance.can_handle(WorkflowMessage(data=42, source_id="mock"))
# Cannot handle float
assert not exec_instance.can_handle(WorkflowMessage(data=3.14, source_id="mock"))
assert not exec_instance.can_handle(
WorkflowMessage(data=3.14, source_id="mock")
)
def test_handler_with_explicit_union_output_type(self):
"""Test that explicit union output is normalized to a list."""
@@ -674,8 +729,8 @@ class TestHandlerExplicitTypes:
exec_instance = PrecedenceExecutor(id="precedence")
# Should use explicit input type (bytes), not introspected (str)
assert bytes in exec_instance._handlers
assert str not in exec_instance._handlers
assert bytes in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
assert str not in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
# Should use explicit output type (float), not introspected (int)
assert float in exec_instance.output_types
@@ -692,7 +747,7 @@ class TestHandlerExplicitTypes:
exec_instance = IntrospectedExecutor(id="introspected")
# Should use introspected types
assert str in exec_instance._handlers
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
assert int in exec_instance.output_types
def test_handler_explicit_mode_requires_input(self):
@@ -705,13 +760,13 @@ class TestHandlerExplicitTypes:
pass
exec_input = OnlyInputExecutor(id="only_input")
assert bytes in exec_input._handlers # Explicit
assert bytes in exec_input._handlers # pyright: ignore[reportPrivateUsage] # Explicit
assert exec_input.output_types == [] # No output types (not introspected)
# Only explicit output without input should raise error
with pytest.raises(ValueError, match="must specify 'input' type"):
class OnlyOutputExecutor(Executor):
class OnlyOutputExecutor(Executor): # pyright: ignore[reportUnusedClass]
@handler(output=float)
async def handle(self, message: str, ctx: WorkflowContext[int]) -> None:
pass
@@ -719,9 +774,11 @@ class TestHandlerExplicitTypes:
# Only explicit workflow_output without input should raise error
with pytest.raises(ValueError, match="must specify 'input' type"):
class OnlyWorkflowOutputExecutor(Executor):
class OnlyWorkflowOutputExecutor(Executor): # pyright: ignore[reportUnusedClass]
@handler(workflow_output=bool)
async def handle(self, message: str, ctx: WorkflowContext[int, str]) -> None:
async def handle(
self, message: str, ctx: WorkflowContext[int, str]
) -> None:
pass
def test_handler_explicit_input_type_allows_no_message_annotation(self):
@@ -734,8 +791,7 @@ class TestHandlerExplicitTypes:
exec_instance = NoAnnotationExecutor(id="no_annotation")
# Should work with explicit input_type
assert str in exec_instance._handlers
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock"))
def test_handler_multiple_handlers_mixed_explicit_and_introspected(self):
@@ -747,15 +803,17 @@ class TestHandlerExplicitTypes:
pass
@handler
async def handle_introspected(self, message: float, ctx: WorkflowContext[bool]) -> None:
async def handle_introspected(
self, message: float, ctx: WorkflowContext[bool]
) -> None:
pass
exec_instance = MixedExecutor(id="mixed")
# Should have both handlers
assert len(exec_instance._handlers) == 2
assert str in exec_instance._handlers # Explicit
assert float in exec_instance._handlers # Introspected
assert len(exec_instance._handlers) == 2 # pyright: ignore[reportPrivateUsage]
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] # Explicit
assert float in exec_instance._handlers # pyright: ignore[reportPrivateUsage] # Introspected
# Should have both output types
assert int in exec_instance.output_types # Explicit
@@ -772,8 +830,10 @@ class TestHandlerExplicitTypes:
exec_instance = StringRefExecutor(id="string_ref")
# Should resolve the string to the actual type
assert ForwardRefMessage in exec_instance._handlers
assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock"))
assert ForwardRefMessage in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
assert exec_instance.can_handle(
WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock")
)
def test_handler_with_string_forward_reference_union(self):
"""Test that string forward references work with union types."""
@@ -786,8 +846,12 @@ class TestHandlerExplicitTypes:
exec_instance = StringUnionExecutor(id="string_union")
# Should handle both types
assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock"))
assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock"))
assert exec_instance.can_handle(
WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock")
)
assert exec_instance.can_handle(
WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock")
)
def test_handler_with_string_forward_reference_output_type(self):
"""Test that string forward references work for output_type."""
@@ -813,8 +877,8 @@ class TestHandlerExplicitTypes:
exec_instance = ExplicitWorkflowOutputExecutor(id="explicit_workflow_output")
# Handler spec should have bool as workflow_output_type (explicit)
handler_func = exec_instance._handlers[str]
assert handler_func._handler_spec["workflow_output_types"] == [bool]
handler_func = exec_instance._handlers[str] # pyright: ignore[reportPrivateUsage]
assert handler_func._handler_spec["workflow_output_types"] == [bool] # pyright: ignore[reportFunctionMemberAccess]
# Executor workflow_output_types property should reflect explicit type
assert bool in exec_instance.workflow_output_types
@@ -826,13 +890,14 @@ class TestHandlerExplicitTypes:
class PrecedenceExecutor(Executor):
@handler(input=int, output=float, workflow_output=str)
async def handle(self, message: int, ctx: WorkflowContext[int, bool]) -> None:
async def handle(
self, message: int, ctx: WorkflowContext[int, bool]
) -> None:
pass
exec_instance = PrecedenceExecutor(id="precedence")
# All types should come from explicit params
assert int in exec_instance._handlers
assert int in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
assert float in exec_instance.output_types
assert str in exec_instance.workflow_output_types
# Introspected types should NOT be present
@@ -849,8 +914,7 @@ class TestHandlerExplicitTypes:
exec_instance = AllExplicitExecutor(id="all_explicit")
# Check input type
assert str in exec_instance._handlers
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock"))
# Check output_type
@@ -894,7 +958,9 @@ class TestHandlerExplicitTypes:
async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def]
pass
exec_instance = StringUnionWorkflowOutputExecutor(id="string_union_workflow_output")
exec_instance = StringUnionWorkflowOutputExecutor(
id="string_union_workflow_output"
)
# Should resolve both types from string union
assert ForwardRefTypeA in exec_instance.workflow_output_types
@@ -905,10 +971,14 @@ class TestHandlerExplicitTypes:
class IntrospectedWorkflowOutputExecutor(Executor):
@handler
async def handle(self, message: str, ctx: WorkflowContext[int, bool]) -> None:
async def handle(
self, message: str, ctx: WorkflowContext[int, bool]
) -> None:
pass
exec_instance = IntrospectedWorkflowOutputExecutor(id="introspected_workflow_output")
exec_instance = IntrospectedWorkflowOutputExecutor(
id="introspected_workflow_output"
)
# Should use introspected types from WorkflowContext[int, bool]
assert int in exec_instance.output_types
@@ -34,8 +34,8 @@ class TestExecutorFutureAnnotations:
pass
exec_instance = MyExecutor(id="test")
assert str in exec_instance._handlers
spec = exec_instance._handler_specs[0]
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] is str
assert spec["output_types"] == [MyTypeA]
assert spec["workflow_output_types"] == [MyTypeB]
@@ -49,8 +49,8 @@ class TestExecutorFutureAnnotations:
pass
exec_instance = MyExecutor(id="test")
assert int in exec_instance._handlers
spec = exec_instance._handler_specs[0]
assert int in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] is int
assert spec["output_types"] == [MyTypeA]
@@ -63,7 +63,7 @@ class TestExecutorFutureAnnotations:
pass
exec_instance = MyExecutor(id="test")
spec = exec_instance._handler_specs[0]
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] == dict[str, Any]
assert spec["output_types"] == [list[str]]
@@ -76,8 +76,8 @@ class TestExecutorFutureAnnotations:
pass
exec_instance = MyExecutor(id="test")
assert str in exec_instance._handlers
spec = exec_instance._handler_specs[0]
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["output_types"] == []
assert spec["workflow_output_types"] == []
@@ -86,12 +86,12 @@ class TestExecutorFutureAnnotations:
class MyExecutor(Executor):
@handler(input=str, output=MyTypeA)
async def example(self, input, ctx) -> None:
async def example(self, input, ctx) -> None: # type: ignore[no-untyped-def]
pass
exec_instance = MyExecutor(id="test")
assert str in exec_instance._handlers
spec = exec_instance._handler_specs[0]
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] is str
assert spec["output_types"] == [MyTypeA]
@@ -104,8 +104,8 @@ class TestExecutorFutureAnnotations:
pass
exec_instance = MyExecutor(id="test")
assert str in exec_instance._handlers
spec = exec_instance._handler_specs[0]
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["output_types"] == [MyTypeA, MyTypeB]
assert spec["workflow_output_types"] == [MyTypeC]
@@ -118,7 +118,7 @@ class TestExecutorFutureAnnotations:
"""
with pytest.raises(ValueError):
class Bad(Executor):
@handler
async def example(self, input: NonExistentType, ctx: WorkflowContext[MyTypeA, MyTypeB]) -> None: # noqa: F821
# NonExistentType is deliberately left undefined: building this class is
# expected to fail, and the enclosing `pytest.raises(ValueError)` checks that.
# On the handler signature `# type: ignore[...]` must open the comment for
# mypy to honour it, so it is placed ahead of the linter's `# noqa` marker.
class Bad(Executor):  # pyright: ignore[reportUnusedClass]
    @handler  # pyright: ignore[reportUnknownArgumentType]
    async def example(self, input: NonExistentType, ctx: WorkflowContext[MyTypeA, MyTypeB]) -> None:  # type: ignore[name-defined]  # noqa: F821
        pass
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Awaitable, Sequence
from typing import Any
from collections.abc import AsyncIterable, Awaitable
from typing import Any, Literal, overload
import pytest
from pydantic import PrivateAttr
@@ -13,6 +13,7 @@ from agent_framework import (
AgentExecutorResponse,
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
AgentSession,
BaseAgent,
Content,
@@ -34,14 +35,32 @@ class _SimpleAgent(BaseAgent):
super().__init__(**kwargs)
self._reply_text = reply_text
@overload
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
if stream:
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
@@ -81,14 +100,32 @@ class _ToolHistoryAgent(BaseAgent):
Message(role="assistant", contents=[Content.from_text(text=self._summary_text)]),
]
@overload
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
if stream:
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
@@ -165,14 +202,32 @@ class _CaptureAgent(BaseAgent):
super().__init__(**kwargs)
self._reply_text = reply_text
@overload
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
# Normalize and record messages for verification
norm: list[Message] = []
if messages:
@@ -260,7 +315,7 @@ class _RoundTripCoordinator(Executor):
async def handle_response(
self,
response: AgentExecutorResponse,
ctx: WorkflowContext[Never, dict[str, Any]],
ctx: WorkflowContext[AgentExecutorRequest, dict[str, Any]],
) -> None:
self._seen += 1
if self._seen == 1:
@@ -314,14 +369,32 @@ class _SessionIdCapturingAgent(BaseAgent):
_captured_service_session_id: str | None = PrivateAttr(default="NOT_CAPTURED")
@overload
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
self._captured_service_session_id = session.service_session_id if session else None
async def _run() -> AgentResponse:
@@ -342,7 +415,7 @@ class _FullHistoryReplayCoordinator(Executor):
async def handle(
self,
response: AgentExecutorResponse,
ctx: WorkflowContext[Never, Any],
ctx: WorkflowContext[AgentExecutorRequest, Any],
) -> None:
full_conv = list(response.full_conversation or response.agent_response.messages)
full_conv.append(Message(role="user", text="follow-up"))
@@ -48,12 +48,12 @@ class TestFunctionExecutor:
func_exec = FunctionExecutor(process_string)
# Check that handler was registered
assert len(func_exec._handlers) == 1
assert str in func_exec._handlers
assert len(func_exec._handlers) == 1 # pyright: ignore[reportPrivateUsage]
assert str in func_exec._handlers # pyright: ignore[reportPrivateUsage]
# Check handler spec was created
assert len(func_exec._handler_specs) == 1
spec = func_exec._handler_specs[0]
assert len(func_exec._handler_specs) == 1 # pyright: ignore[reportPrivateUsage]
spec = func_exec._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["name"] == "process_string"
assert spec["message_type"] is str
assert spec["output_types"] == [str]
@@ -67,10 +67,10 @@ class TestFunctionExecutor:
assert isinstance(process_int, FunctionExecutor)
assert process_int.id == "test_executor"
assert int in process_int._handlers
assert int in process_int._handlers # pyright: ignore[reportPrivateUsage]
# Check spec
spec = process_int._handler_specs[0]
spec = process_int._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] is int
assert spec["output_types"] == [int]
@@ -78,7 +78,7 @@ class TestFunctionExecutor:
"""Test @executor decorator uses function name as default ID."""
@executor
async def my_function(data: dict, ctx: WorkflowContext[Any]) -> None:
async def my_function(data: dict[str, Any], ctx: WorkflowContext[Any]) -> None:
await ctx.send_message(data)
assert my_function.id == "my_function"
@@ -92,7 +92,7 @@ class TestFunctionExecutor:
assert isinstance(no_parens_function, FunctionExecutor)
assert no_parens_function.id == "no_parens_function"
assert str in no_parens_function._handlers
assert str in no_parens_function._handlers # pyright: ignore[reportPrivateUsage]
# Also test with single parameter function
@executor
@@ -101,7 +101,7 @@ class TestFunctionExecutor:
assert isinstance(simple_no_parens, FunctionExecutor)
assert simple_no_parens.id == "simple_no_parens"
assert int in simple_no_parens._handlers
assert int in simple_no_parens._handlers # pyright: ignore[reportPrivateUsage]
def test_union_output_types(self):
"""Test that union output types are properly inferred for both messages and workflow outputs."""
@@ -113,7 +113,7 @@ class TestFunctionExecutor:
else:
await ctx.send_message(text.upper())
spec = multi_output._handler_specs[0]
spec = multi_output._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert set(spec["output_types"]) == {str, int}
assert spec["workflow_output_types"] == [] # No workflow outputs defined
@@ -127,7 +127,7 @@ class TestFunctionExecutor:
else:
await ctx.yield_output(data.upper())
workflow_spec = multi_workflow_output._handler_specs[0]
workflow_spec = multi_workflow_output._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert workflow_spec["output_types"] == [] # None means no message outputs
assert set(workflow_spec["workflow_output_types"]) == {str, int, bool}
@@ -139,7 +139,7 @@ class TestFunctionExecutor:
# This executor doesn't send any messages
pass
spec = no_output._handler_specs[0]
spec = no_output._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["output_types"] == []
assert spec["workflow_output_types"] == [] # No workflow outputs defined
@@ -150,7 +150,7 @@ class TestFunctionExecutor:
async def any_output(data: str, ctx: WorkflowContext[Any]) -> None:
await ctx.send_message("result")
spec = any_output._handler_specs[0]
spec = any_output._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["output_types"] == [Any]
assert spec["workflow_output_types"] == [] # No workflow outputs defined
@@ -160,7 +160,7 @@ class TestFunctionExecutor:
await ctx.send_message("message")
await ctx.yield_output("workflow_output")
both_spec = any_both_output._handler_specs[0]
both_spec = any_both_output._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert both_spec["output_types"] == [Any]
assert both_spec["workflow_output_types"] == [Any]
@@ -228,11 +228,11 @@ class TestFunctionExecutor:
await ctx.yield_output(result)
# Verify type inference for both executors
upper_spec = to_upper._handler_specs[0]
upper_spec = to_upper._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert upper_spec["output_types"] == [str]
assert upper_spec["workflow_output_types"] == [] # No workflow outputs
reverse_spec = reverse_text._handler_specs[0]
reverse_spec = reverse_text._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert reverse_spec["output_types"] == [Any] # First parameter is Any
assert reverse_spec["workflow_output_types"] == [str] # Second parameter is str
@@ -270,7 +270,7 @@ class TestFunctionExecutor:
await ctx.send_message(message)
with pytest.raises(ValueError, match="Handler for type .* already registered"):
func_exec._register_instance_handler(
func_exec._register_instance_handler( # pyright: ignore[reportPrivateUsage]
name="second",
func=second_handler,
message_type=str,
@@ -287,7 +287,7 @@ class TestFunctionExecutor:
result = {item: len(item) for item in items}
await ctx.send_message(result)
spec = process_list._handler_specs[0]
spec = process_list._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] == list[str]
assert spec["output_types"] == [dict[str, int]]
@@ -300,10 +300,10 @@ class TestFunctionExecutor:
assert isinstance(process_simple, FunctionExecutor)
assert process_simple.id == "simple_processor"
assert str in process_simple._handlers
assert str in process_simple._handlers # pyright: ignore[reportPrivateUsage]
# Check spec - single parameter functions have no output types since they can't send messages
spec = process_simple._handler_specs[0]
spec = process_simple._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] is str
assert spec["output_types"] == []
assert spec["ctx_annotation"] is None
@@ -316,7 +316,7 @@ class TestFunctionExecutor:
return data * 2
func_exec = FunctionExecutor(valid_single)
assert int in func_exec._handlers
assert int in func_exec._handlers # pyright: ignore[reportPrivateUsage]
# Single parameter with missing type annotation should still fail
async def no_annotation(data): # type: ignore
@@ -349,7 +349,7 @@ class TestFunctionExecutor:
# For testing purposes, we can check that the handler is registered correctly
assert double_value.can_handle(WorkflowMessage(data=5, source_id="mock"))
assert int in double_value._handlers
assert int in double_value._handlers # pyright: ignore[reportPrivateUsage]
def test_sync_function_basic(self):
"""Test basic synchronous function support."""
@@ -360,10 +360,10 @@ class TestFunctionExecutor:
assert isinstance(process_sync, FunctionExecutor)
assert process_sync.id == "sync_processor"
assert str in process_sync._handlers
assert str in process_sync._handlers # pyright: ignore[reportPrivateUsage]
# Check spec - sync single parameter functions have no output types
spec = process_sync._handler_specs[0]
spec = process_sync._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] is str
assert spec["output_types"] == []
assert spec["ctx_annotation"] is None
@@ -378,10 +378,10 @@ class TestFunctionExecutor:
assert isinstance(sync_with_ctx, FunctionExecutor)
assert sync_with_ctx.id == "sync_with_ctx"
assert int in sync_with_ctx._handlers
assert int in sync_with_ctx._handlers # pyright: ignore[reportPrivateUsage]
# Check spec - sync functions with context can infer output types
spec = sync_with_ctx._handler_specs[0]
spec = sync_with_ctx._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] is int
assert spec["output_types"] == [int]
@@ -404,18 +404,18 @@ class TestFunctionExecutor:
return data.upper()
func_exec = FunctionExecutor(valid_sync)
assert str in func_exec._handlers
assert str in func_exec._handlers # pyright: ignore[reportPrivateUsage]
# Valid sync function with two parameters
def valid_sync_with_ctx(data: int, ctx: WorkflowContext[str]):
    """Return the integer message rendered as text; ``ctx`` is accepted but unused."""
    rendered = str(data)
    return rendered
func_exec2 = FunctionExecutor(valid_sync_with_ctx)
assert int in func_exec2._handlers
assert int in func_exec2._handlers # pyright: ignore[reportPrivateUsage]
# Sync function with missing type annotation should still fail
def no_annotation(data): # type: ignore
return data
def no_annotation(data): # type: ignore # pyright: ignore[reportUnknownVariableType]
return data # pyright: ignore[reportUnknownVariableType]
with pytest.raises(ValueError, match="type annotation for the message"):
FunctionExecutor(no_annotation) # type: ignore
@@ -457,11 +457,11 @@ class TestFunctionExecutor:
await ctx.yield_output(result)
# Verify type inference for sync and async functions
sync_spec = to_upper_sync._handler_specs[0]
sync_spec = to_upper_sync._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert sync_spec["output_types"] == [str]
assert sync_spec["workflow_output_types"] == [] # No workflow outputs
async_spec = reverse_async._handler_specs[0]
async_spec = reverse_async._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert async_spec["output_types"] == [Any] # First parameter is Any
assert async_spec["workflow_output_types"] == [str] # Second parameter is str
@@ -471,8 +471,8 @@ class TestFunctionExecutor:
# For integration testing, we mainly verify that the handlers are properly registered
# and the functions are wrapped correctly
assert str in to_upper_sync._handlers
assert str in reverse_async._handlers
assert str in to_upper_sync._handlers # pyright: ignore[reportPrivateUsage]
assert str in reverse_async._handlers # pyright: ignore[reportPrivateUsage]
async def test_sync_function_thread_execution(self):
"""Test that sync functions run in thread pool and don't block the event loop."""
@@ -491,13 +491,13 @@ class TestFunctionExecutor:
return data.upper()
# Verify the function is wrapped and registered
assert str in blocking_function._handlers
assert str in blocking_function._handlers # pyright: ignore[reportPrivateUsage]
# For a more complete test, we'd need to create a full workflow context,
# but for now we can verify that the function was properly wrapped
# and that sync functions store the correct metadata
assert not blocking_function._is_async
assert not blocking_function._has_context
assert not blocking_function._is_async # pyright: ignore[reportPrivateUsage]
assert not blocking_function._has_context # pyright: ignore[reportPrivateUsage]
# The actual thread execution test would require a full workflow setup,
# but the important thing is that asyncio.to_thread is used in the wrapper
@@ -506,7 +506,7 @@ class TestFunctionExecutor:
"""Test that @executor decorator properly rejects @staticmethod with clear error."""
with pytest.raises(ValueError) as exc_info:
class Example:
class Example: # pyright: ignore[reportUnusedClass]
@executor
@staticmethod
async def bad_handler(data: str) -> str:
@@ -519,7 +519,7 @@ class TestFunctionExecutor:
"""Test that @executor decorator properly rejects @classmethod with clear error."""
with pytest.raises(ValueError) as exc_info:
class Example:
class Example: # pyright: ignore[reportUnusedClass]
@executor
@classmethod
async def bad_handler(cls, data: str) -> str:
@@ -570,8 +570,8 @@ class TestExecutorExplicitTypes:
pass
# Handler should be registered for str (explicit)
assert str in process._handlers
assert len(process._handlers) == 1
assert str in process._handlers # pyright: ignore[reportPrivateUsage]
assert len(process._handlers) == 1 # pyright: ignore[reportPrivateUsage]
# Can handle str messages
assert process.can_handle(WorkflowMessage(data="hello", source_id="mock"))
@@ -586,7 +586,7 @@ class TestExecutorExplicitTypes:
pass
# Handler spec should have int as output type (explicit), not str (introspected)
spec = process._handler_specs[0]
spec = process._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["output_types"] == [int]
# Executor output_types property should reflect explicit type
@@ -601,11 +601,11 @@ class TestExecutorExplicitTypes:
pass
# Handler should be registered for dict (explicit input type)
assert dict in process._handlers
assert len(process._handlers) == 1
assert dict in process._handlers # pyright: ignore[reportPrivateUsage]
assert len(process._handlers) == 1 # pyright: ignore[reportPrivateUsage]
# Output type should be list (explicit)
spec = process._handler_specs[0]
spec = process._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["output_types"] == [list]
# Verify can_handle
@@ -620,7 +620,7 @@ class TestExecutorExplicitTypes:
pass
# Handler should be registered for the union type
assert len(process._handlers) == 1
assert len(process._handlers) == 1 # pyright: ignore[reportPrivateUsage]
# Can handle both str and int messages
assert process.can_handle(WorkflowMessage(data="hello", source_id="mock"))
@@ -648,8 +648,8 @@ class TestExecutorExplicitTypes:
pass
# Should use explicit input type (bytes), not introspected (str)
assert bytes in process._handlers
assert str not in process._handlers
assert bytes in process._handlers # pyright: ignore[reportPrivateUsage]
assert str not in process._handlers # pyright: ignore[reportPrivateUsage]
# Should use explicit output type (float), not introspected (int)
assert float in process.output_types
@@ -663,7 +663,7 @@ class TestExecutorExplicitTypes:
pass
# Should use introspected types
assert str in process._handlers
assert str in process._handlers # pyright: ignore[reportPrivateUsage]
assert int in process.output_types
def test_executor_partial_explicit_types(self):
@@ -674,7 +674,7 @@ class TestExecutorExplicitTypes:
async def process_input(message: str, ctx: WorkflowContext[int]) -> None:
pass
assert bytes in process_input._handlers # Explicit
assert bytes in process_input._handlers # Explicit # pyright: ignore[reportPrivateUsage]
assert int in process_input.output_types # Introspected
# Only explicit output_type, introspect input_type
@@ -682,7 +682,7 @@ class TestExecutorExplicitTypes:
async def process_output(message: str, ctx: WorkflowContext[int]) -> None:
pass
assert str in process_output._handlers # Introspected
assert str in process_output._handlers # Introspected # pyright: ignore[reportPrivateUsage]
assert float in process_output.output_types # Explicit
assert int not in process_output.output_types # Not introspected when explicit provided
@@ -694,7 +694,7 @@ class TestExecutorExplicitTypes:
pass
# Should work with explicit input_type
assert str in process._handlers
assert str in process._handlers # pyright: ignore[reportPrivateUsage]
assert process.can_handle(WorkflowMessage(data="hello", source_id="mock"))
def test_executor_explicit_types_with_id(self):
@@ -705,7 +705,7 @@ class TestExecutorExplicitTypes:
pass
assert process.id == "custom_id"
assert bytes in process._handlers
assert bytes in process._handlers # pyright: ignore[reportPrivateUsage]
assert int in process.output_types
def test_executor_explicit_types_with_single_param_function(self):
@@ -713,10 +713,10 @@ class TestExecutorExplicitTypes:
@executor(input=str)
async def process(message): # type: ignore[no-untyped-def]
return message.upper()
return message.upper() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
# Should work with explicit input_type
assert str in process._handlers
assert str in process._handlers # pyright: ignore[reportPrivateUsage]
assert process.can_handle(WorkflowMessage(data="hello", source_id="mock"))
assert not process.can_handle(WorkflowMessage(data=42, source_id="mock"))
@@ -727,7 +727,7 @@ class TestExecutorExplicitTypes:
def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def]
pass
assert int in process._handlers
assert int in process._handlers # pyright: ignore[reportPrivateUsage]
assert str in process.output_types
def test_function_executor_constructor_with_explicit_types(self):
@@ -736,10 +736,10 @@ class TestExecutorExplicitTypes:
async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def]
pass
func_exec = FunctionExecutor(process, id="test", input=dict, output=list)
func_exec = FunctionExecutor(process, id="test", input=dict, output=list) # pyright: ignore[reportUnknownArgumentType]
assert dict in func_exec._handlers
spec = func_exec._handler_specs[0]
assert dict in func_exec._handlers # pyright: ignore[reportPrivateUsage]
spec = func_exec._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] is dict
assert spec["output_types"] == [list]
@@ -766,7 +766,7 @@ class TestExecutorExplicitTypes:
pass
# Should resolve the string to the actual type
assert FuncExecForwardRefMessage in process._handlers
assert FuncExecForwardRefMessage in process._handlers # pyright: ignore[reportPrivateUsage]
assert process.can_handle(WorkflowMessage(data=FuncExecForwardRefMessage("hello"), source_id="mock"))
def test_executor_with_string_forward_reference_union(self):
@@ -798,7 +798,7 @@ class TestExecutorExplicitTypes:
pass
# Handler spec should have bool as workflow_output_type (explicit)
spec = process._handler_specs[0]
spec = process._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["workflow_output_types"] == [bool]
# Executor workflow_output_types property should reflect explicit type
@@ -826,7 +826,7 @@ class TestExecutorExplicitTypes:
pass
# Check input type
assert str in process._handlers
assert str in process._handlers # pyright: ignore[reportPrivateUsage]
assert process.can_handle(WorkflowMessage(data="hello", source_id="mock"))
# Check output_type
@@ -892,6 +892,6 @@ class TestExecutorExplicitTypes:
workflow_output=bool,
)
assert str in exec_instance._handlers
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
assert int in exec_instance.output_types
assert bool in exec_instance.workflow_output_types
@@ -19,10 +19,10 @@ class TestFunctionExecutorFutureAnnotations:
assert isinstance(process_future, FunctionExecutor)
assert process_future.id == "future_test"
assert int in process_future._handlers
assert int in process_future._handlers # pyright: ignore[reportPrivateUsage]
# Check spec
spec = process_future._handler_specs[0]
spec = process_future._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] is int
assert spec["output_types"] == [int]
@@ -34,6 +34,6 @@ class TestFunctionExecutorFutureAnnotations:
await ctx.send_message(["done"])
assert isinstance(process_complex, FunctionExecutor)
spec = process_complex._handler_specs[0]
spec = process_complex._handler_specs[0] # pyright: ignore[reportPrivateUsage]
assert spec["message_type"] == dict[str, Any]
assert spec["output_types"] == [list[str]]
@@ -794,7 +794,7 @@ class TestResponseHandlerExplicitTypes:
"""Test response_handler with explicit request and response types."""
@response_handler(request=str, response=int)
async def test_handler(self, original_request, response, ctx) -> None:
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
pass
spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue]
@@ -806,7 +806,7 @@ class TestResponseHandlerExplicitTypes:
"""Test response_handler with explicit output and workflow_output types."""
@response_handler(request=str, response=int, output=bool, workflow_output=float)
async def test_handler(self, original_request, response, ctx) -> None:
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
pass
spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue]
@@ -818,8 +818,8 @@ class TestResponseHandlerExplicitTypes:
def test_response_handler_with_union_types(self):
"""Test response_handler with union types."""
@response_handler(request=str | int, response=bool | float)
async def test_handler(self, original_request, response, ctx) -> None:
@response_handler(request=str | int, response=bool | float) # pyright: ignore[reportArgumentType]
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
pass
spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue]
@@ -830,7 +830,7 @@ class TestResponseHandlerExplicitTypes:
"""Test response_handler with string forward references."""
@response_handler(request="str", response="int")
async def test_handler(self, original_request, response, ctx) -> None:
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
pass
spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue]
@@ -842,7 +842,7 @@ class TestResponseHandlerExplicitTypes:
with pytest.raises(ValueError, match="must specify 'request' type"):
@response_handler(response=int)
async def test_handler(self, original_request, response, ctx) -> None:
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: # pyright: ignore[reportUnusedFunction]
pass
def test_response_handler_explicit_missing_response_raises_error(self):
@@ -850,7 +850,7 @@ class TestResponseHandlerExplicitTypes:
with pytest.raises(ValueError, match="must specify 'response' type"):
@response_handler(request=str)
async def test_handler(self, original_request, response, ctx) -> None:
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: # pyright: ignore[reportUnusedFunction]
pass
def test_response_handler_explicit_only_output_raises_error(self):
@@ -858,7 +858,7 @@ class TestResponseHandlerExplicitTypes:
with pytest.raises(ValueError, match="must specify 'request' type"):
@response_handler(output=bool)
async def test_handler(self, original_request, response, ctx) -> None:
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: # pyright: ignore[reportUnusedFunction]
pass
def test_executor_with_explicit_response_handlers(self):
@@ -873,7 +873,7 @@ class TestResponseHandlerExplicitTypes:
pass
@response_handler(request=str, response=int, output=bool)
async def handle_explicit(self, original_request, response, ctx) -> None:
async def handle_explicit(self, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
pass
executor = TestExecutor()
@@ -907,7 +907,7 @@ class TestResponseHandlerExplicitTypes:
pass
@response_handler(request=str, response=int)
async def handle_response(self, original_request, response, ctx) -> None:
async def handle_response(self, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
self.handled_request = original_request
self.handled_response = response
@@ -942,7 +942,7 @@ class TestResponseHandlerExplicitTypes:
# Explicit type handler
@response_handler(request=dict, response=bool)
async def handle_explicit(self, original_request, response, ctx) -> None:
async def handle_explicit(self, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
pass
executor = TestExecutor()
@@ -2,6 +2,7 @@
import asyncio
from dataclasses import dataclass
from typing import Any
from unittest.mock import AsyncMock, MagicMock
import pytest
@@ -113,7 +114,7 @@ async def test_runner_run_until_convergence():
assert result is not None and result == 10
# iteration count shouldn't be reset after convergence
assert runner._iteration == 10 # type: ignore
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
async def test_runner_run_until_convergence_not_completed():
@@ -173,7 +174,7 @@ async def test_runner_run_iteration_preserves_message_order_per_edge_runner() ->
for index in range(5):
await ctx.send_message(WorkflowMessage(data=MockMessage(data=index), source_id="source"))
await runner._run_iteration()
await runner._run_iteration() # pyright: ignore[reportPrivateUsage]
assert edge_runner.received == [0, 1, 2, 3, 4]
@@ -213,7 +214,7 @@ async def test_runner_run_iteration_delivers_different_edge_runners_concurrently
await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id="source"))
iteration_task = asyncio.create_task(runner._run_iteration())
iteration_task = asyncio.create_task(runner._run_iteration()) # pyright: ignore[reportPrivateUsage]
await blocking_edge_runner.started.wait()
await asyncio.wait_for(probe_edge_runner.probe_completed.wait(), timeout=2.0)
@@ -280,7 +281,7 @@ async def test_fanout_edge_runner_delivers_to_multiple_targets_concurrently() ->
# Queue a message from source (will be delivered to both targets via FanOut)
await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id=source.id))
iteration_task = asyncio.create_task(runner._run_iteration())
iteration_task = asyncio.create_task(runner._run_iteration()) # pyright: ignore[reportPrivateUsage]
# Wait for the blocking executor to start
await blocking_target.started.wait()
@@ -477,11 +478,11 @@ async def test_runner_reset_iteration_count():
ctx = InProcRunnerContext()
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
runner._iteration = 10
runner._iteration = 10 # pyright: ignore[reportPrivateUsage]
runner.reset_iteration_count()
assert runner._iteration == 0
assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
class CheckpointingContext(InProcRunnerContext):
@@ -501,18 +502,19 @@ class CheckpointingContext(InProcRunnerContext):
graph_signature_hash: str,
state: State,
previous_checkpoint_id: str | None,
iteration: int,
iteration_count: int,
metadata: dict[str, Any] | None = None,
) -> str:
checkpoint = WorkflowCheckpoint(
workflow_name=workflow_name,
graph_signature_hash=graph_signature_hash,
state=state.export(),
state=state.export_state(),
previous_checkpoint_id=previous_checkpoint_id,
iteration_count=iteration,
iteration_count=iteration_count,
)
return await self._storage.save(checkpoint)
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: # pyright: ignore[reportIncompatibleMethodOverride]
try:
return await self._storage.load(checkpoint_id)
except WorkflowCheckpointException:
@@ -537,7 +539,8 @@ class FailingCheckpointContext(InProcRunnerContext):
graph_signature_hash: str,
state: State,
previous_checkpoint_id: str | None,
iteration: int,
iteration_count: int,
metadata: dict[str, Any] | None = None,
) -> str:
raise RuntimeError("Simulated checkpoint failure")
@@ -609,8 +612,8 @@ async def test_runner_restore_from_checkpoint_with_external_storage():
# Restore using external storage
await runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage=storage)
assert runner._resumed_from_checkpoint is True
assert runner._iteration == 5
assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage]
assert runner._iteration == 5 # pyright: ignore[reportPrivateUsage]
assert state.get("test_key") == "test_value"
@@ -684,7 +687,7 @@ async def test_runner_restore_executor_states_invalid_states_type():
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
with pytest.raises(WorkflowCheckpointException, match="not a dictionary"):
await runner._restore_executor_states()
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
async def test_runner_restore_executor_states_invalid_executor_id_type():
@@ -698,7 +701,7 @@ async def test_runner_restore_executor_states_invalid_executor_id_type():
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
with pytest.raises(WorkflowCheckpointException, match="not a string"):
await runner._restore_executor_states()
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
async def test_runner_restore_executor_states_invalid_state_type():
@@ -712,7 +715,7 @@ async def test_runner_restore_executor_states_invalid_state_type():
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
with pytest.raises(WorkflowCheckpointException, match="not a dict"):
await runner._restore_executor_states()
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
async def test_runner_restore_executor_states_invalid_state_keys():
@@ -726,7 +729,7 @@ async def test_runner_restore_executor_states_invalid_state_keys():
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
with pytest.raises(WorkflowCheckpointException, match="not a dict"):
await runner._restore_executor_states()
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
async def test_runner_restore_executor_states_missing_executor():
@@ -739,7 +742,7 @@ async def test_runner_restore_executor_states_missing_executor():
runner = Runner([], {}, state, ctx, "test_name", graph_signature_hash="test_hash")
with pytest.raises(WorkflowCheckpointException, match="not found during state restoration"):
await runner._restore_executor_states()
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
async def test_runner_set_executor_state_invalid_existing_states():
@@ -752,7 +755,7 @@ async def test_runner_set_executor_state_invalid_existing_states():
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
with pytest.raises(WorkflowCheckpointException, match="not a dictionary"):
await runner._set_executor_state("executor_a", {"key": "value"})
await runner._set_executor_state("executor_a", {"key": "value"}) # pyright: ignore[reportPrivateUsage]
async def test_runner_with_pre_loop_events():
@@ -779,7 +782,7 @@ class EventEmittingExecutor(Executor):
"""An executor that emits events during execution."""
@handler
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, int]) -> None:
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, str]) -> None:
# Emit event during processing
await ctx.yield_output(f"processed-{message.data}")
if message.data < 3:
@@ -831,7 +834,7 @@ async def test_runner_restore_executor_states_no_states():
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
# Should complete without error when no executor states exist
await runner._restore_executor_states()
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
async def test_runner_checkpoint_with_resumed_flag():
@@ -853,7 +856,7 @@ async def test_runner_checkpoint_with_resumed_flag():
state = State()
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash")
runner._mark_resumed(5)
runner._mark_resumed(5) # pyright: ignore[reportPrivateUsage]
# Add a message to trigger the checkpoint creation path
await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id="START"))
@@ -870,7 +873,7 @@ async def test_runner_checkpoint_with_resumed_flag():
pass
# After completing, resumed flag should be reset
assert runner._resumed_from_checkpoint is False
assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage]
class ExecutorThatFailsWithEvents(Executor):
@@ -883,7 +886,7 @@ class ExecutorThatFailsWithEvents(Executor):
self._iteration_count = 0
@handler
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, int]) -> None:
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, str]) -> None:
self._iteration_count += 1
# First emit an output event to the workflow context
await ctx.yield_output(f"output-before-failure-{message.data}")
@@ -951,7 +954,7 @@ class SlowEventEmittingExecutor(Executor):
self.current_iteration = 0
@handler
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, int]) -> None:
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, str]) -> None:
self.current_iteration += 1
# Emit output event
await ctx.yield_output(f"iteration-{self.current_iteration}")
@@ -61,9 +61,9 @@ class TestSuperstepCaching:
state.set("key", "value")
# Value is in pending
assert "key" in state._pending
assert "key" in state._pending # pyright: ignore[reportPrivateUsage]
# Value is NOT in committed
assert "key" not in state._committed
assert "key" not in state._committed # pyright: ignore[reportPrivateUsage]
# But get() still returns it
assert state.get("key") == "value"
@@ -72,14 +72,14 @@ class TestSuperstepCaching:
state.set("key", "value")
# Before commit: in pending, not committed
assert "key" in state._pending
assert "key" not in state._committed
assert "key" in state._pending # pyright: ignore[reportPrivateUsage]
assert "key" not in state._committed # pyright: ignore[reportPrivateUsage]
state.commit()
# After commit: in committed, pending cleared
assert "key" not in state._pending
assert "key" in state._committed
assert "key" not in state._pending # pyright: ignore[reportPrivateUsage]
assert "key" in state._committed # pyright: ignore[reportPrivateUsage]
assert state.get("key") == "value"
def test_discard_clears_pending_without_committing(self) -> None:
@@ -108,7 +108,7 @@ class TestSuperstepCaching:
# get() returns pending value, not committed
assert state.get("key") == "pending_value"
# But committed still has old value
assert state._committed["key"] == "committed_value"
assert state._committed["key"] == "committed_value" # pyright: ignore[reportPrivateUsage]
def test_multiple_sets_before_commit(self) -> None:
state = State()
@@ -130,13 +130,13 @@ class TestDeleteWithSuperstepCaching:
state = State()
state.set("key", "value")
# Key only in pending, not committed
assert "key" in state._pending
assert "key" not in state._committed
assert "key" in state._pending # pyright: ignore[reportPrivateUsage]
assert "key" not in state._committed # pyright: ignore[reportPrivateUsage]
state.delete("key")
# Should be removed from pending
assert "key" not in state._pending
assert "key" not in state._pending # pyright: ignore[reportPrivateUsage]
assert state.get("key") is None
assert state.has("key") is False
@@ -148,14 +148,14 @@ class TestDeleteWithSuperstepCaching:
state.delete("key")
# Key should be marked for deletion in pending (sentinel)
assert "key" in state._pending
assert "key" in state._pending # pyright: ignore[reportPrivateUsage]
# get() should return default (not the sentinel!)
assert state.get("key") is None
assert state.get("key", "default") == "default"
# has() should return False
assert state.has("key") is False
# But committed still has it until commit()
assert "key" in state._committed
assert "key" in state._committed # pyright: ignore[reportPrivateUsage]
def test_delete_committed_key_removed_on_commit(self) -> None:
state = State()
@@ -166,8 +166,8 @@ class TestDeleteWithSuperstepCaching:
state.commit()
# Now it should be gone from committed too
assert "key" not in state._committed
assert "key" not in state._pending
assert "key" not in state._committed # pyright: ignore[reportPrivateUsage]
assert "key" not in state._pending # pyright: ignore[reportPrivateUsage]
def test_delete_key_in_both_pending_and_committed(self) -> None:
"""Test delete when key exists in both pending (modified) and committed."""
@@ -177,8 +177,8 @@ class TestDeleteWithSuperstepCaching:
# Modify the key (now in both pending and committed)
state.set("key", "modified")
assert state._pending["key"] == "modified"
assert state._committed["key"] == "original"
assert state._pending["key"] == "modified" # pyright: ignore[reportPrivateUsage]
assert state._committed["key"] == "original" # pyright: ignore[reportPrivateUsage]
# Delete should mark for deletion from committed
state.delete("key")
@@ -189,8 +189,8 @@ class TestDeleteWithSuperstepCaching:
# After commit, key should be fully removed
state.commit()
assert "key" not in state._committed
assert "key" not in state._pending
assert "key" not in state._committed # pyright: ignore[reportPrivateUsage]
assert "key" not in state._pending # pyright: ignore[reportPrivateUsage]
def test_discard_after_delete_restores_committed_value(self) -> None:
state = State()
@@ -238,12 +238,12 @@ class TestFailureScenarios:
state.set("key3", "value3")
# Before commit - nothing in committed
assert len(state._committed) == 0
assert len(state._committed) == 0 # pyright: ignore[reportPrivateUsage]
state.commit()
# After commit - all three values committed together
assert state._committed == {"key1": "value1", "key2": "value2", "key3": "value3"}
assert state._committed == {"key1": "value1", "key2": "value2", "key3": "value3"} # pyright: ignore[reportPrivateUsage]
def test_repeated_supersteps_are_isolated(self) -> None:
"""Test that each superstep's changes are isolated until committed."""
@@ -300,4 +300,4 @@ class TestExportImport:
# Pending is still there
assert state.get("pending_key") == "pending_value"
assert "pending_key" in state._pending
assert "pending_key" in state._pending # pyright: ignore[reportPrivateUsage]
@@ -36,32 +36,32 @@ def test_normalize_type_to_list_none() -> None:
def test_normalize_type_to_list_union_pipe_syntax() -> None:
    """PEP 604 unions (``A | B``) are expanded into a list of their member types."""
    # Order of the returned list is not asserted; only membership is compared.
    two_members = normalize_type_to_list(str | int)  # pyright: ignore[reportArgumentType]
    assert {str, int} == set(two_members)

    three_members = normalize_type_to_list(str | int | bool)  # pyright: ignore[reportArgumentType]
    assert {str, int, bool} == set(three_members)
def test_normalize_type_to_list_union_typing_syntax() -> None:
    """``typing.Union[...]`` is expanded into a list of its member types."""
    # Order of the returned list is not asserted; only membership is compared.
    two_members = normalize_type_to_list(Union[str, int])  # pyright: ignore[reportArgumentType]
    assert {str, int} == set(two_members)

    three_members = normalize_type_to_list(Union[str, int, bool])  # pyright: ignore[reportArgumentType]
    assert {str, int, bool} == set(three_members)
def test_normalize_type_to_list_optional() -> None:
"""Test normalize_type_to_list with Optional types (Union[T, None])."""
# Optional[str] is Union[str, None]
result = normalize_type_to_list(Optional[str])
result = normalize_type_to_list(Optional[str]) # pyright: ignore[reportArgumentType]
assert str in result
assert type(None) in result
assert len(result) == 2
# str | None is equivalent
result = normalize_type_to_list(str | None)
result = normalize_type_to_list(str | None) # pyright: ignore[reportArgumentType]
assert str in result
assert type(None) in result
assert len(result) == 2
@@ -77,7 +77,7 @@ def test_normalize_type_to_list_custom_types() -> None:
result = normalize_type_to_list(CustomMessage)
assert result == [CustomMessage]
result = normalize_type_to_list(CustomMessage | str)
result = normalize_type_to_list(CustomMessage | str) # pyright: ignore[reportArgumentType]
assert set(result) == {CustomMessage, str}
@@ -96,7 +96,7 @@ def test_resolve_type_annotation_actual_types() -> None:
"""Test resolve_type_annotation passes through actual types unchanged."""
assert resolve_type_annotation(str) is str
assert resolve_type_annotation(int) is int
assert resolve_type_annotation(str | int) == str | int
assert resolve_type_annotation(str | int) == str | int # pyright: ignore[reportArgumentType]
def test_resolve_type_annotation_string_builtin() -> None:
@@ -484,8 +484,8 @@ def test_handler_ctx_missing_annotation_raises() -> None:
# Validation now happens at handler registration time, not workflow build time
with pytest.raises(ValueError) as exc:
class BadExecutor(Executor):
@handler
class BadExecutor(Executor): # pyright: ignore[reportUnusedClass]
@handler # pyright: ignore[reportUnknownArgumentType]
async def handle(self, message: str, ctx) -> None: # type: ignore[no-untyped-def]
pass
@@ -496,8 +496,8 @@ def test_handler_ctx_invalid_t_out_entries_raises() -> None:
# Validation now happens at handler registration time, not workflow build time
with pytest.raises(ValueError) as exc:
class BadExecutor(Executor):
@handler
class BadExecutor(Executor): # pyright: ignore[reportUnusedClass]
@handler # pyright: ignore[reportUnknownArgumentType]
async def handle(self, message: str, ctx: WorkflowContext[123]) -> None: # type: ignore[valid-type]
pass
@@ -555,7 +555,7 @@ def test_output_validation_with_valid_output_executors():
)
assert workflow is not None
assert workflow._output_executors == ["executor2"]
assert workflow._output_executors == ["executor2"] # pyright: ignore[reportPrivateUsage]
def test_output_validation_with_multiple_valid_output_executors():
@@ -572,7 +572,7 @@ def test_output_validation_with_multiple_valid_output_executors():
)
assert workflow is not None
assert set(workflow._output_executors) == {"executor1", "executor3"}
assert set(workflow._output_executors) == {"executor1", "executor3"} # pyright: ignore[reportPrivateUsage]
def test_output_validation_fails_for_nonexistent_executor():
@@ -2,6 +2,9 @@
"""Tests for the workflow visualization module."""
from pathlib import Path
from typing import Any
import pytest
from agent_framework import Executor, WorkflowBuilder, WorkflowContext, WorkflowExecutor, WorkflowViz, handler
@@ -25,7 +28,7 @@ class ListStrTargetExecutor(Executor):
@pytest.fixture
def basic_sub_workflow():
def basic_sub_workflow() -> dict[str, Any]:
"""Fixture that creates a basic sub-workflow setup for testing."""
# Create a sub-workflow
sub_exec1 = MockExecutor(id="sub_exec1")
@@ -98,7 +101,7 @@ def test_workflow_viz_export_dot():
assert '"executor1" -> "executor2"' in content
def test_workflow_viz_export_dot_with_filename(tmp_path):
def test_workflow_viz_export_dot_with_filename(tmp_path: Path):
"""Test exporting workflow as DOT format with specified filename."""
executor1 = MockExecutor(id="executor1")
executor2 = MockExecutor(id="executor2")
@@ -203,7 +206,7 @@ def test_workflow_viz_graphviz_binary_not_found():
mock_source_class.return_value = mock_source
# Import the ExecutableNotFound exception for the test
from graphviz.backend.execute import ExecutableNotFound
from graphviz.backend.execute import ExecutableNotFound # type: ignore[import-not-found]
mock_source.render.side_effect = ExecutableNotFound("failed to execute PosixPath('dot')")
@@ -329,7 +332,7 @@ def test_workflow_viz_mermaid_fan_in_edge_group():
assert "s2 --> t" not in mermaid
def test_workflow_viz_sub_workflow_digraph(basic_sub_workflow):
def test_workflow_viz_sub_workflow_digraph(basic_sub_workflow: dict[str, Any]):
"""Test that WorkflowViz can visualize sub-workflows in DOT format."""
main_workflow = basic_sub_workflow["main_workflow"]
@@ -353,7 +356,7 @@ def test_workflow_viz_sub_workflow_digraph(basic_sub_workflow):
assert '"workflow_executor_1/sub_exec1" -> "workflow_executor_1/sub_exec2"' in dot_content
def test_workflow_viz_sub_workflow_mermaid(basic_sub_workflow):
def test_workflow_viz_sub_workflow_mermaid(basic_sub_workflow: dict[str, Any]):
"""Test that WorkflowViz can visualize sub-workflows in Mermaid format."""
main_workflow = basic_sub_workflow["main_workflow"]
@@ -4,7 +4,7 @@ import asyncio
import tempfile
from collections.abc import AsyncIterable, Awaitable, Sequence
from dataclasses import dataclass, field
from typing import Any, cast
from typing import Any, Literal, cast, overload
from uuid import uuid4
import pytest
@@ -13,6 +13,7 @@ from agent_framework import (
AgentExecutor,
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
AgentSession,
BaseAgent,
Content,
@@ -474,7 +475,7 @@ class StateTrackingExecutor(Executor):
) -> None:
"""Handle the message and track it in workflow state."""
# Get existing messages from workflow state
existing_messages = ctx.get_state("processed_messages") or []
existing_messages: list[str] = ctx.get_state("processed_messages") or []
# Record this message
message_record = f"{message.run_id}:{message.data}"
@@ -833,6 +834,26 @@ class _StreamingTestAgent(BaseAgent):
super().__init__(**kwargs)
self._reply_text = reply_text
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[False] = ...,
session: AgentSession | None = ...,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = ...,
*,
stream: Literal[True],
session: AgentSession | None = ...,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
@@ -883,8 +904,10 @@ async def test_agent_streaming_vs_non_streaming() -> None:
stream_events.append(event)
# Filter for agent events
agent_response = [
cast(AgentResponse, e.data) for e in stream_events if e.type == "output" and isinstance(e.data, AgentResponse)
agent_response: list[AgentResponse[Any]] = [
cast(AgentResponse[Any], e.data) # pyright: ignore[reportUnknownMemberType]
for e in stream_events
if e.type == "output" and isinstance(e.data, AgentResponse)
]
agent_response_updates = [
e.data for e in stream_events if e.type == "output" and isinstance(e.data, AgentResponseUpdate)
@@ -2,7 +2,7 @@
import uuid
from collections.abc import Awaitable, Sequence
from typing import Any
from typing import Any, Literal, overload
import pytest
from typing_extensions import Never
@@ -713,6 +713,14 @@ class TestWorkflowAgent:
# Test stub: hands back a brand-new AgentSession on every call; **kwargs is accepted but ignored.
def create_session(self, **kwargs: Any) -> AgentSession:
return AgentSession()
# Test stub: service_session_id and **kwargs are accepted but not used for any lookup;
# a brand-new AgentSession is returned on every call.
def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
return AgentSession()
@overload
def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
@@ -801,6 +809,14 @@ class TestWorkflowAgent:
# Test stub: hands back a brand-new AgentSession on every call; **kwargs is accepted but ignored.
def create_session(self, **kwargs: Any) -> AgentSession:
return AgentSession()
# Test stub: service_session_id and **kwargs are accepted but not used for any lookup;
# a brand-new AgentSession is returned on every call.
def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
return AgentSession()
@overload
def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
@@ -1207,7 +1223,7 @@ class TestWorkflowAgentMergeUpdates:
]
# Compare using role.value for Role enum
actual_sequence_normalized = [(t, r.value if hasattr(r, "value") else r) for t, r in content_sequence]
actual_sequence_normalized = [(t, r.value if hasattr(r, "value") else r) for t, r in content_sequence] # type: ignore[union-attr]
assert actual_sequence_normalized == expected_sequence, (
f"FunctionResultContent should come immediately after FunctionCallContent. "
@@ -1,7 +1,8 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterator, Awaitable
from dataclasses import dataclass
from typing import Any
from typing import Any, Literal, overload
import pytest
@@ -9,10 +10,12 @@ from agent_framework import (
AgentExecutor,
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
AgentSession,
BaseAgent,
Executor,
Message,
ResponseStream,
WorkflowBuilder,
WorkflowContext,
WorkflowValidationError,
@@ -21,22 +24,49 @@ from agent_framework import (
class DummyAgent(BaseAgent):
    """Minimal ``BaseAgent`` test double.

    ``run`` is overloaded on ``stream`` so a type checker can tell the
    awaitable (non-streaming) result apart from the ``ResponseStream``
    (streaming) result.
    """

    @overload
    def run(
        self,
        messages: AgentRunInputs | None = ...,
        *,
        stream: Literal[False] = ...,
        session: AgentSession | None = ...,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]]: ...

    @overload
    def run(
        self,
        messages: AgentRunInputs | None = ...,
        *,
        stream: Literal[True],
        session: AgentSession | None = ...,
        **kwargs: Any,
    ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...

    def run(
        self,
        messages: AgentRunInputs | None = None,
        *,
        stream: bool = False,
        session: AgentSession | None = None,
        **kwargs: Any,
    ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
        """Dispatch to the streaming or the non-streaming stub.

        Args:
            messages: Input handed to the non-streaming path; not used when streaming.
            stream: When true, return a ``ResponseStream`` wrapping the stub generator;
                otherwise return the coroutine produced by ``_run_impl``.
            session: Accepted for signature compatibility; not used.
            **kwargs: Accepted for signature compatibility; not used.
        """
        if stream:
            return ResponseStream[AgentResponseUpdate, AgentResponse[Any]](self._run_stream_impl())
        return self._run_impl(messages)

    async def _run_impl(self, messages: AgentRunInputs | None = None) -> AgentResponse:
        """Echo the input back as an ``AgentResponse``.

        ``Message`` items pass through unchanged, ``str`` items become user
        messages, and items of any other type are dropped. Falsy input
        (``None``, ``""``, an empty sequence) yields an empty response.
        """
        norm: list[Message] = []
        if messages:
            # A lone str/Message has to be wrapped before iterating: looping over
            # a bare str would otherwise emit one user message per character.
            items = [messages] if isinstance(messages, (str, Message)) else messages
            for m in items:  # type: ignore[union-attr]
                if isinstance(m, Message):
                    norm.append(m)
                elif isinstance(m, str):
                    norm.append(Message(role="user", text=m))
        return AgentResponse(messages=norm)

    async def _run_stream_impl(self) -> AsyncIterator[AgentResponseUpdate]:
        """Yield a single empty update; a minimal async generator for the streaming path."""
        yield AgentResponseUpdate()
@@ -202,7 +232,7 @@ def test_with_output_from_returns_builder():
builder = WorkflowBuilder(output_executors=[executor_a], start_executor=executor_a)
# Verify builder was created with output_executors
assert builder._output_executors == [executor_a]
assert builder._output_executors == [executor_a] # pyright: ignore[reportPrivateUsage]
def test_with_output_from_with_executor_instances():
@@ -84,7 +84,7 @@ async def test_executor_emits_normal_event() -> None:
class _TestEvent(WorkflowEvent):
    """Custom workflow event whose type string is ``"test_event"``.

    Args:
        data: Optional payload forwarded unchanged to ``WorkflowEvent``.
    """

    def __init__(self, data: Any = None) -> None:
        # Initialize the base event exactly once. The ignore is needed because
        # "test_event" presumably is not one of the type values the base
        # constructor is annotated to accept -- confirm against WorkflowEvent.
        super().__init__("test_event", data=data)  # type: ignore[arg-type]
async def test_workflow_context_type_annotations_no_parameter() -> None:
@@ -244,8 +244,8 @@ async def test_workflow_context_missing_annotation_error() -> None:
# Test class-based executor with missing ctx annotation
with pytest.raises(ValueError, match="must have a WorkflowContext"):
class _BadExecutor(Executor):
@handler
class _BadExecutor(Executor): # pyright: ignore[reportUnusedClass]
@handler # pyright: ignore[reportUnknownArgumentType]
async def bad_handler(self, text: str, ctx) -> None: # type: ignore[no-untyped-def]
pass
@@ -264,8 +264,8 @@ async def test_workflow_context_invalid_type_parameter_error() -> None:
# Test class-based executor with invalid type parameter
with pytest.raises(ValueError, match="invalid type entry"):
class _BadExecutor(Executor):
@handler
class _BadExecutor(Executor): # pyright: ignore[reportUnusedClass]
@handler # pyright: ignore[reportUnknownArgumentType]
async def bad_handler(self, text: str, ctx: WorkflowContext[456]) -> None: # type: ignore[valid-type]
pass
@@ -1,13 +1,14 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import AsyncIterable, Awaitable, Sequence
from typing import Annotated, Any
from collections.abc import AsyncIterable, Awaitable
from typing import Annotated, Any, Literal, overload
import pytest
from agent_framework import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
AgentSession,
BaseAgent,
Content,
@@ -50,14 +51,19 @@ class _KwargsCapturingAgent(BaseAgent):
super().__init__(name=name, description="Test agent for kwargs capture")
self.captured_kwargs = []
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
self.captured_kwargs.append(dict(kwargs))
if stream:
@@ -83,15 +89,20 @@ class _OptionsAwareAgent(BaseAgent):
self.captured_options = []
self.captured_kwargs = []
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
options: dict[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
self.captured_options.append(dict(options) if options is not None else None)
self.captured_kwargs.append(dict(kwargs))
if stream:
@@ -189,15 +200,15 @@ async def test_sequential_run_options_does_not_conflict_with_agent_options() ->
break
assert len(agent.captured_options) >= 1
captured_options = agent.captured_options[0]
captured_options: dict[str, Any] | None = agent.captured_options[0]
assert captured_options is not None
assert captured_options.get("store") is False
additional_args = captured_options.get("additional_function_arguments")
additional_args: Any = captured_options.get("additional_function_arguments")
assert isinstance(additional_args, dict)
assert additional_args.get("source") == "workflow-options"
assert additional_args.get("custom_data") == custom_data
assert additional_args.get("user_token") == user_token
assert additional_args.get("source") == "workflow-options" # pyright: ignore[reportUnknownMemberType]
assert additional_args.get("custom_data") == custom_data # pyright: ignore[reportUnknownMemberType]
assert additional_args.get("user_token") == user_token # pyright: ignore[reportUnknownMemberType]
# "options" should be passed once via the dedicated options parameter,
# not duplicated in **kwargs.
@@ -225,13 +236,13 @@ async def test_sequential_run_additional_function_arguments_flattened() -> None:
break
assert len(agent.captured_options) >= 1
captured_options = agent.captured_options[0]
captured_options: dict[str, Any] | None = agent.captured_options[0]
assert captured_options is not None
additional_args = captured_options.get("additional_function_arguments")
additional_args: Any = captured_options.get("additional_function_arguments")
assert isinstance(additional_args, dict)
assert additional_args.get("custom_data") == custom_data
assert additional_args.get("user_token") == user_token
assert additional_args.get("custom_data") == custom_data # pyright: ignore[reportUnknownMemberType]
assert additional_args.get("user_token") == user_token # pyright: ignore[reportUnknownMemberType]
assert "additional_function_arguments" not in additional_args
assert len(agent.captured_kwargs) >= 1
@@ -255,14 +266,14 @@ async def test_sequential_run_additional_function_arguments_merges_with_options(
break
assert len(agent.captured_options) >= 1
captured_options = agent.captured_options[0]
captured_options: dict[str, Any] | None = agent.captured_options[0]
assert captured_options is not None
additional_args = captured_options.get("additional_function_arguments")
additional_args: Any = captured_options.get("additional_function_arguments")
assert isinstance(additional_args, dict)
assert additional_args.get("source") == "workflow-options"
assert additional_args.get("custom_data") == {"session_id": "abc123"}
assert additional_args.get("user_token") == {"user_name": "alice"}
assert additional_args.get("source") == "workflow-options" # pyright: ignore[reportUnknownMemberType]
assert additional_args.get("custom_data") == {"session_id": "abc123"} # pyright: ignore[reportUnknownMemberType]
assert additional_args.get("user_token") == {"user_name": "alice"} # pyright: ignore[reportUnknownMemberType]
assert "additional_function_arguments" not in additional_args
@@ -463,14 +474,19 @@ async def test_kwargs_preserved_on_response_continuation() -> None:
self.captured_kwargs = []
self._asked = False
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
self.captured_kwargs.append(dict(kwargs))
if not self._asked:
self._asked = True
@@ -521,14 +537,19 @@ async def test_kwargs_overridden_on_response_continuation() -> None:
self.captured_kwargs = []
self._asked = False
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
self.captured_kwargs.append(dict(kwargs))
if not self._asked:
self._asked = True
@@ -583,14 +604,19 @@ async def test_kwargs_empty_value_passed_on_continuation() -> None:
self.captured_kwargs = []
self._asked = False
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
self.captured_kwargs.append(dict(kwargs))
if not self._asked:
self._asked = True
@@ -690,8 +716,8 @@ async def test_handoff_kwargs_flow_to_agents() -> None:
workflow = (
HandoffBuilder(termination_condition=lambda conv: len(conv) >= 4)
.participants([agent1, agent2])
.with_start_agent(agent1)
.participants([agent1, agent2]) # type: ignore[list-item]
.with_start_agent(agent1) # type: ignore[arg-type]
.with_autonomous_mode()
.build()
)
@@ -109,7 +109,7 @@ async def test_span_creation_and_attributes(span_exporter: InMemorySpanExporter)
{
"id": "test-workflow-123",
"max_iterations": 100,
"model_dump_json": lambda self: '{"id": "test-workflow-123", "type": "mock"}',
"model_dump_json": lambda self: '{"id": "test-workflow-123", "type": "mock"}', # pyright: ignore[reportUnknownLambdaType]
},
)(),
)
@@ -122,7 +122,7 @@ async def test_span_creation_and_attributes(span_exporter: InMemorySpanExporter)
},
) as workflow_span:
workflow_span.add_event(OtelAttr.WORKFLOW_STARTED)
sending_attributes = {
sending_attributes: dict[str, str | int] = {
OtelAttr.MESSAGE_TYPE: "ResponseMessage",
OtelAttr.MESSAGE_DESTINATION_EXECUTOR_ID: "target-789",
}
@@ -231,7 +231,7 @@ async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> No
@pytest.mark.parametrize("enable_instrumentation", [False], indirect=True)
async def test_trace_context_disabled_when_tracing_disabled(
enable_instrumentation, span_exporter: InMemorySpanExporter
enable_instrumentation: bool, span_exporter: InMemorySpanExporter
) -> None:
"""Test that no trace context is added when tracing is disabled."""
# Tracing should be disabled by default
@@ -313,7 +313,7 @@ async def test_end_to_end_workflow_tracing(span_exporter: InMemorySpanExporter)
span_exporter.clear()
# Run workflow (this should create run spans)
events = []
events: list[Any] = []
async for event in workflow.run("test input", stream=True):
events.append(event)
@@ -1,5 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import Any
import pytest
from typing_extensions import Never
@@ -36,16 +38,16 @@ async def test_executor_failed_and_workflow_failed_events_streaming():
events.append(ev)
# executor_failed event (type='executor_failed') should be emitted before workflow failed event
executor_failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"]
executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"]
assert executor_failed_events, "executor_failed event should be emitted when start executor fails"
assert executor_failed_events[0].executor_id == "f"
assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK
# Workflow-level failure and FAILED status should be surfaced
failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"]
failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"]
assert failed_events
assert all(e.origin is WorkflowEventSource.FRAMEWORK for e in failed_events)
status = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "status"]
status: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "status"]
assert status and status[-1].state == WorkflowRunState.FAILED
assert all(e.origin is WorkflowEventSource.FRAMEWORK for e in status)
@@ -94,13 +96,13 @@ async def test_executor_failed_event_from_second_executor_in_chain():
events.append(ev)
# executor_failed event should be emitted for the failing executor
executor_failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"]
executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"]
assert executor_failed_events, "executor_failed event should be emitted when second executor fails"
assert executor_failed_events[0].executor_id == "failing"
assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK
# Workflow-level failure should also be surfaced
failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"]
failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"]
assert failed_events
assert all(e.origin is WorkflowEventSource.FRAMEWORK for e in failed_events)
-1
View File
@@ -184,7 +184,6 @@ omit = [
[tool.pyright]
include = ["agent_framework*"]
exclude = ["**/tests/**", "**/.venv/**", "packages/devui/frontend/**"]
typeCheckingMode = "strict"
reportUnnecessaryIsInstance = false
reportMissingTypeStubs = false