From 1a8729d5a7162ab4de0cce5077ccd8a22f0208f7 Mon Sep 17 00:00:00 2001 From: Tao Chen Date: Tue, 3 Mar 2026 13:52:05 -0800 Subject: [PATCH] Python: Fix workflow tests pyright warnings (#4362) * Fix workflow tests pyright warnings * Update uv.lock * Fix pyright * Comments * Update root pyproject pyright setting * Update core pyproject pyright setting * Update core pyproject pyright setting --- python/packages/core/pyproject.toml | 1 + .../tests/workflow/test_agent_executor.py | 119 ++++++++--- .../test_agent_executor_tool_calls.py | 31 ++- .../core/tests/workflow/test_agent_utils.py | 45 ++--- .../core/tests/workflow/test_checkpoint.py | 23 ++- .../tests/workflow/test_checkpoint_encode.py | 11 +- .../packages/core/tests/workflow/test_edge.py | 19 +- .../core/tests/workflow/test_executor.py | 190 ++++++++++++------ .../tests/workflow/test_executor_future.py | 30 +-- .../tests/workflow/test_full_conversation.py | 97 +++++++-- .../tests/workflow/test_function_executor.py | 126 ++++++------ .../workflow/test_function_executor_future.py | 6 +- .../tests/workflow/test_request_info_mixin.py | 22 +- .../core/tests/workflow/test_runner.py | 53 ++--- .../core/tests/workflow/test_state.py | 42 ++-- .../core/tests/workflow/test_typing_utils.py | 16 +- .../core/tests/workflow/test_validation.py | 12 +- .../packages/core/tests/workflow/test_viz.py | 13 +- .../core/tests/workflow/test_workflow.py | 31 ++- .../tests/workflow/test_workflow_agent.py | 20 +- .../tests/workflow/test_workflow_builder.py | 44 +++- .../tests/workflow/test_workflow_context.py | 10 +- .../tests/workflow/test_workflow_kwargs.py | 82 +++++--- .../workflow/test_workflow_observability.py | 8 +- .../tests/workflow/test_workflow_states.py | 12 +- python/pyproject.toml | 1 - 26 files changed, 696 insertions(+), 368 deletions(-) diff --git a/python/packages/core/pyproject.toml b/python/packages/core/pyproject.toml index a3c3f53ed6..b16ec09ad8 100644 --- a/python/packages/core/pyproject.toml +++ b/python/packages/core/pyproject.toml @@ -105,6 +105,7 @@ extend = "../../pyproject.toml" [tool.pyright] extends = "../../pyproject.toml" +include = ["tests/workflow"] [tool.mypy] plugins = ['pydantic.mypy'] diff --git a/python/packages/core/tests/workflow/test_agent_executor.py b/python/packages/core/tests/workflow/test_agent_executor.py index 4a850db642..788e96e61e 100644 --- a/python/packages/core/tests/workflow/test_agent_executor.py +++ b/python/packages/core/tests/workflow/test_agent_executor.py @@ -2,19 +2,20 @@ import logging from collections.abc import AsyncIterable, Awaitable -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal, overload import pytest - from agent_framework import ( AgentExecutor, AgentResponse, AgentResponseUpdate, + AgentRunInputs, AgentSession, BaseAgent, Content, Message, ResponseStream, + WorkflowEvent, WorkflowRunState, ) from agent_framework._workflows._agent_executor import AgentExecutorResponse @@ -32,26 +33,56 @@ class _CountingAgent(BaseAgent): super().__init__(**kwargs) self.call_count = 0 + @overload def run( self, - messages: str | Message | list[str] | list[Message] | None = None, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + + def run( + self, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> ( + Awaitable[AgentResponse[Any]] + | ResponseStream[AgentResponseUpdate, AgentResponse[Any]] + ): self.call_count += 1 if stream: async def _stream() -> AsyncIterable[AgentResponseUpdate]: yield AgentResponseUpdate( - contents=[Content.from_text(text=f"Response #{self.call_count}: {self.name}")] + contents=[ + Content.from_text( + text=f"Response #{self.call_count}: {self.name}" + ) + ] ) return ResponseStream(_stream(), finalizer=AgentResponse.from_updates) async def _run() -> AgentResponse: - return AgentResponse(messages=[Message("assistant", [f"Response #{self.call_count}: {self.name}"])]) + return AgentResponse( + messages=[ + Message("assistant", [f"Response #{self.call_count}: {self.name}"]) + ] + ) return _run() @@ -63,13 +94,36 @@ class _StreamingHookAgent(BaseAgent): super().__init__(**kwargs) self.result_hook_called = False + @overload def run( self, - messages: str | Message | list[str] | list[Message] | None = None, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + + def run( + self, + messages: AgentRunInputs | None = None, *, stream: bool = False, + session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> ( + Awaitable[AgentResponse[Any]] + | ResponseStream[AgentResponseUpdate, AgentResponse[Any]] + ): if stream: async def _stream() -> AsyncIterable[AgentResponseUpdate]: @@ -78,13 +132,15 @@ class _StreamingHookAgent(BaseAgent): role="assistant", ) - async def _mark_result_hook_called(response: AgentResponse) -> AgentResponse: + async def _mark_result_hook_called( + response: AgentResponse, + ) -> AgentResponse: self.result_hook_called = True return response - return ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook( - _mark_result_hook_called - ) + return ResponseStream( + _stream(), finalizer=AgentResponse.from_updates + ).with_result_hook(_mark_result_hook_called) async def _run() -> AgentResponse: return AgentResponse(messages=[Message("assistant", ["hook test"])]) @@ -92,7 +148,9 @@ class _StreamingHookAgent(BaseAgent): return _run() -async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None: +async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> ( + None +): """AgentExecutor should call get_final_response() so stream result hooks execute.""" agent = _StreamingHookAgent(id="hook_agent", name="HookAgent") executor = AgentExecutor(agent, id="hook_exec") @@ -159,7 +217,9 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: executor_state = executor_states[executor.id] # type: ignore[index] assert "cache" in executor_state, "Checkpoint should store executor cache state" - assert "agent_session" in executor_state, "Checkpoint should store executor session state" + assert "agent_session" in executor_state, ( + "Checkpoint should store executor session state" + ) # Verify session state structure session_state = executor_state["agent_session"] # type: ignore[index] @@ -180,11 +240,15 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None: assert restored_agent.call_count == 0 # Build new workflow with the restored executor - wf_resume = SequentialBuilder(participants=[restored_executor], checkpoint_storage=storage).build() + wf_resume = SequentialBuilder( + participants=[restored_executor], checkpoint_storage=storage + ).build() # Resume from checkpoint resumed_output: AgentExecutorResponse | None = None - async for ev in wf_resume.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True): + async for ev in wf_resume.run( + checkpoint_id=restore_checkpoint.checkpoint_id, stream=True + ): if ev.type == "output": resumed_output = ev.data # type: ignore[assignment] if ev.type == "status" and ev.state in ( @@ -278,7 +342,7 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() - workflow = SequentialBuilder(participants=[executor]).build() # stream=True at workflow level triggers streaming mode (returns async iterable) - events = [] + events: list[WorkflowEvent] = [] async for event in workflow.run("hello", stream=True): events.append(event) assert len(events) > 0 @@ -288,10 +352,13 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() - @pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"]) async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str, caplog: "LogCaptureFixture") -> None: """_prepare_agent_run_args must remove reserved kwargs and log a warning.""" - raw = {reserved_kwarg: "should-be-stripped", "custom_key": "keep-me"} + raw: dict[str, Any] = { + reserved_kwarg: "should-be-stripped", + "custom_key": "keep-me", + } with caplog.at_level(logging.WARNING): - run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) + run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] assert reserved_kwarg not in run_kwargs assert "custom_key" in run_kwargs @@ -302,8 +369,8 @@ async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None: """Non-reserved workflow kwargs should pass through unchanged.""" - raw = {"custom_param": "value", "another": 42} - run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) + raw: dict[str, Any] = {"custom_param": "value", "another": 42} + run_kwargs, _options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] assert run_kwargs["custom_param"] == "value" assert run_kwargs["another"] == 42 @@ -312,10 +379,10 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once( caplog: "LogCaptureFixture", ) -> None: """All reserved kwargs should be stripped when supplied together, each emitting a warning.""" - raw = {"session": "x", "stream": True, "messages": [], "custom": 1} + raw: dict[str, Any] = {"session": "x", "stream": True, "messages": [], "custom": 1} with caplog.at_level(logging.WARNING): - run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) + run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage] assert "session" not in run_kwargs assert "stream" not in run_kwargs @@ -324,7 +391,11 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once( assert options is not None assert options["additional_function_arguments"]["custom"] == 1 - warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()} + warned_keys = { + r.message.split("'")[1] + for r in caplog.records + if "reserved" in r.message.lower() + } assert warned_keys == {"session", "stream", "messages"} diff --git a/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py b/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py index cae5ea4e3b..07a37f9617 100644 --- a/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py +++ b/python/packages/core/tests/workflow/test_agent_executor_tool_calls.py @@ -3,7 +3,7 @@ """Tests for AgentExecutor handling of tool calls and results in streaming mode.""" from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence -from typing import Any +from typing import Any, Literal, overload from typing_extensions import Never @@ -13,6 +13,7 @@ from agent_framework import ( AgentExecutorResponse, AgentResponse, AgentResponseUpdate, + AgentRunInputs, AgentSession, BaseAgent, ChatResponse, @@ -37,18 +38,38 @@ class _ToolCallingAgent(BaseAgent): def __init__(self, **kwargs: Any) -> None: super().__init__(**kwargs) + @overload def run( self, - messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + + def run( + self, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: if stream: return ResponseStream(self._run_stream_impl(), finalizer=AgentResponse.from_updates) - async def _run() -> AgentResponse: + async def _run() -> AgentResponse[Any]: return AgentResponse(messages=[Message("assistant", ["done"])]) return _run() @@ -111,6 +132,7 @@ async def test_agent_executor_emits_tool_calls_in_streaming_mode() -> None: # First event: text update assert events[0].data is not None assert events[0].data.contents[0].type == "text" + assert events[0].data.contents[0].text is not None assert "Let me search" in events[0].data.contents[0].text # Second event: function call @@ -129,6 +151,7 @@ async def test_agent_executor_emits_tool_calls_in_streaming_mode() -> None: # Fourth event: final text assert events[3].data is not None assert events[3].data.contents[0].type == "text" + assert events[3].data.contents[0].text is not None assert "sunny" in events[3].data.contents[0].text diff --git a/python/packages/core/tests/workflow/test_agent_utils.py b/python/packages/core/tests/workflow/test_agent_utils.py index d3889b4d3b..07d1e64c08 100644 --- a/python/packages/core/tests/workflow/test_agent_utils.py +++ b/python/packages/core/tests/workflow/test_agent_utils.py @@ -1,9 +1,9 @@ # Copyright (c) Microsoft. All rights reserved. -from collections.abc import AsyncIterable -from typing import Any +from collections.abc import Awaitable +from typing import Any, Literal, overload -from agent_framework import AgentResponse, AgentResponseUpdate, AgentSession, Message +from agent_framework import AgentResponse, AgentResponseUpdate, AgentRunInputs, AgentSession, ResponseStream from agent_framework._workflows._agent_utils import resolve_agent_id @@ -11,40 +11,23 @@ class MockAgent: """Mock agent for testing agent utilities.""" def __init__(self, agent_id: str, name: str | None = None) -> None: - self._id = agent_id - self._name = name + self.id: str = agent_id + self.name: str | None = name + self.description: str | None = None - @property - def id(self) -> str: - return self._id - - @property - def name(self) -> str | None: - return self._name - - @property - def display_name(self) -> str: - """Returns the display name of the agent.""" - ... - - @property - def description(self) -> str | None: - """Returns the description of the agent.""" - ... - - def run( - self, - messages: str | Message | list[str] | list[Message] | None = None, - *, - stream: bool = False, - session: AgentSession | None = None, - **kwargs: Any, - ) -> AgentResponse | AsyncIterable[AgentResponseUpdate]: ... + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run(self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... def create_session(self, **kwargs: Any) -> AgentSession: """Creates a new conversation session for the agent.""" ... + def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + return AgentSession() + def test_resolve_agent_id_with_name() -> None: """Test that resolve_agent_id returns name when agent has a name.""" diff --git a/python/packages/core/tests/workflow/test_checkpoint.py b/python/packages/core/tests/workflow/test_checkpoint.py index b05d625502..a32489acc0 100644 --- a/python/packages/core/tests/workflow/test_checkpoint.py +++ b/python/packages/core/tests/workflow/test_checkpoint.py @@ -5,6 +5,7 @@ import tempfile from dataclasses import dataclass from datetime import datetime, timezone from pathlib import Path +from typing import Any import pytest @@ -24,7 +25,7 @@ class _TestToolApprovalRequest: """Request data for tool approval in tests.""" tool_name: str - arguments: dict + arguments: dict[str, Any] timestamp: datetime @@ -41,7 +42,7 @@ class _TestApprovalRequest: """Approval request data for tests.""" action: str - params: tuple + params: tuple[Any, ...] @dataclass @@ -78,8 +79,8 @@ def test_workflow_checkpoint_custom_values(): workflow_name="test-workflow-456", graph_signature_hash="test-hash-456", timestamp=custom_timestamp, - messages={"executor1": [{"data": "test"}]}, - pending_request_info_events={"req123": {"data": "test"}}, + messages={"executor1": [{"data": "test"}]}, # type: ignore[arg-type] # raw dict for serialization test + pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test state={"key": "value"}, iteration_count=5, metadata={"test": True}, @@ -103,7 +104,7 @@ def test_workflow_checkpoint_to_dict(): checkpoint_id="test-id", workflow_name="test-workflow", graph_signature_hash="test-hash", - messages={"executor1": [{"data": "test"}]}, + messages={"executor1": [{"data": "test"}]}, # type: ignore[arg-type] # raw dict for serialization test state={"key": "value"}, iteration_count=5, ) @@ -161,8 +162,8 @@ async def test_memory_checkpoint_storage_save_and_load(): checkpoint = WorkflowCheckpoint( workflow_name="test-workflow", graph_signature_hash="test-hash", - messages={"executor1": [{"data": "hello"}]}, - pending_request_info_events={"req123": {"data": "test"}}, + messages={"executor1": [{"data": "hello"}]}, # type: ignore[arg-type] # raw dict for serialization test + pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test ) # Save checkpoint @@ -776,9 +777,9 @@ async def test_file_checkpoint_storage_save_and_load(): checkpoint = WorkflowCheckpoint( workflow_name="test-workflow", graph_signature_hash="test-hash", - messages={"executor1": [{"data": "hello", "source_id": "test", "target_id": None}]}, + messages={"executor1": [{"data": "hello", "source_id": "test", "target_id": None}]}, # type: ignore[arg-type] # raw dict for serialization test state={"key": "value"}, - pending_request_info_events={"req123": {"data": "test"}}, + pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test ) # Save checkpoint @@ -904,9 +905,9 @@ async def test_file_checkpoint_storage_json_serialization(): checkpoint = WorkflowCheckpoint( workflow_name="test-workflow", graph_signature_hash="test-hash", - messages={"executor1": [{"data": {"nested": {"value": 42}}, "source_id": "test", "target_id": None}]}, + messages={"executor1": [{"data": {"nested": {"value": 42}}, "source_id": "test", "target_id": None}]}, # type: ignore[arg-type] # raw dict for serialization test state={"list": [1, 2, 3], "dict": {"a": "b", "c": {"d": "e"}}, "bool": True, "null": None}, - pending_request_info_events={"req123": {"data": "test"}}, + pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test ) # Save and load diff --git a/python/packages/core/tests/workflow/test_checkpoint_encode.py b/python/packages/core/tests/workflow/test_checkpoint_encode.py index 68ec1ac4e3..02da2f1297 100644 --- a/python/packages/core/tests/workflow/test_checkpoint_encode.py +++ b/python/packages/core/tests/workflow/test_checkpoint_encode.py @@ -3,11 +3,11 @@ import json from dataclasses import dataclass from datetime import datetime, timezone -from typing import Any +from typing import Any, cast from agent_framework._workflows._checkpoint_encoding import ( - _PICKLE_MARKER, - _TYPE_MARKER, + _PICKLE_MARKER, # pyright: ignore[reportPrivateUsage] + _TYPE_MARKER, # pyright: ignore[reportPrivateUsage] encode_checkpoint_value, ) @@ -185,8 +185,9 @@ def test_encode_list_of_dataclasses() -> None: result = encode_checkpoint_value(data) assert isinstance(result, list) - assert len(result) == 2 - for item in result: + result_list = cast(list[Any], result) + assert len(result_list) == 2 + for item in result_list: assert _PICKLE_MARKER in item diff --git a/python/packages/core/tests/workflow/test_edge.py b/python/packages/core/tests/workflow/test_edge.py index f63cf9b45b..ecaa341726 100644 --- a/python/packages/core/tests/workflow/test_edge.py +++ b/python/packages/core/tests/workflow/test_edge.py @@ -4,6 +4,8 @@ from dataclasses import dataclass from typing import Any from unittest.mock import patch +from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter + import pytest from agent_framework import ( @@ -275,6 +277,7 @@ async def test_single_edge_group_send_message_with_condition_pass() -> None: success = await edge_runner.send_message(message, state, ctx) assert success is True assert target.call_count == 1 + assert target.last_message is not None assert target.last_message.data == "test" @@ -301,7 +304,7 @@ async def test_single_edge_group_send_message_with_condition_fail() -> None: assert target.call_count == 0 -async def test_single_edge_group_tracing_success(span_exporter) -> None: +async def test_single_edge_group_tracing_success(span_exporter: InMemorySpanExporter) -> None: """Test that single edge group processing creates proper success spans.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") @@ -352,7 +355,7 @@ async def test_single_edge_group_tracing_success(span_exporter) -> None: assert link.context.span_id == int("00f067aa0ba902b7", 16) -async def test_single_edge_group_tracing_condition_failure(span_exporter) -> None: +async def test_single_edge_group_tracing_condition_failure(span_exporter: InMemorySpanExporter) -> None: """Test that single edge group processing creates proper spans for condition failures.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") @@ -386,7 +389,7 @@ async def test_single_edge_group_tracing_condition_failure(span_exporter) -> Non assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DROPPED_CONDITION_FALSE.value -async def test_single_edge_group_tracing_type_mismatch(span_exporter) -> None: +async def test_single_edge_group_tracing_type_mismatch(span_exporter: InMemorySpanExporter) -> None: """Test that single edge group processing creates proper spans for type mismatches.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") @@ -421,7 +424,7 @@ async def test_single_edge_group_tracing_type_mismatch(span_exporter) -> None: assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DROPPED_TYPE_MISMATCH.value -async def test_single_edge_group_tracing_target_mismatch(span_exporter) -> None: +async def test_single_edge_group_tracing_target_mismatch(span_exporter: InMemorySpanExporter) -> None: """Test that single edge group processing creates proper spans for target mismatches.""" source = MockExecutor(id="source_executor") target = MockExecutor(id="target_executor") @@ -775,7 +778,7 @@ async def test_source_edge_group_with_selection_func_send_message_with_target_in assert success is False -async def test_fan_out_edge_group_tracing_success(span_exporter) -> None: +async def test_fan_out_edge_group_tracing_success(span_exporter: InMemorySpanExporter) -> None: """Test that fan-out edge group processing creates proper success spans.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") @@ -827,7 +830,7 @@ async def test_fan_out_edge_group_tracing_success(span_exporter) -> None: assert link.context.span_id == int("00f067aa0ba902b7", 16) -async def test_fan_out_edge_group_tracing_with_target(span_exporter) -> None: +async def test_fan_out_edge_group_tracing_with_target(span_exporter: InMemorySpanExporter) -> None: """Test that fan-out edge group processing creates proper spans for targeted messages.""" source = MockExecutor(id="source_executor") target1 = MockExecutor(id="target_executor_1") @@ -994,7 +997,7 @@ async def test_target_edge_group_send_message_with_invalid_data() -> None: assert success is False -async def test_fan_in_edge_group_tracing_buffered(span_exporter) -> None: +async def test_fan_in_edge_group_tracing_buffered(span_exporter: InMemorySpanExporter) -> None: """Test that fan-in edge group processing creates proper spans for buffered messages.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") @@ -1086,7 +1089,7 @@ async def test_fan_in_edge_group_tracing_buffered(span_exporter) -> None: assert link.context.span_id == int("00f067aa0ba902b8", 16) -async def test_fan_in_edge_group_tracing_type_mismatch(span_exporter) -> None: +async def test_fan_in_edge_group_tracing_type_mismatch(span_exporter: InMemorySpanExporter) -> None: """Test that fan-in edge group processing creates proper spans for type mismatches.""" source1 = MockExecutor(id="source_executor_1") source2 = MockExecutor(id="source_executor_2") diff --git a/python/packages/core/tests/workflow/test_executor.py b/python/packages/core/tests/workflow/test_executor.py index 06d027f19d..77827c0634 100644 --- a/python/packages/core/tests/workflow/test_executor.py +++ b/python/packages/core/tests/workflow/test_executor.py @@ -3,8 +3,6 @@ from dataclasses import dataclass import pytest -from typing_extensions import Never - from agent_framework import ( Executor, Message, @@ -16,6 +14,7 @@ from agent_framework import ( handler, response_handler, ) +from typing_extensions import Never # Module-level types for string forward reference tests @@ -59,7 +58,7 @@ def test_executor_handler_without_annotations(): class MockExecutorWithOneHandlerWithoutAnnotations(Executor): # type: ignore """A mock executor with one handler that does not implement any annotations.""" - @handler + @handler # pyright: ignore[reportUnknownArgumentType] async def handle(self, message, ctx) -> None: # type: ignore """A mock handler that does not implement any annotations.""" pass @@ -156,7 +155,11 @@ async def test_executor_invoked_event_contains_input_data(): workflow = WorkflowBuilder(start_executor=upper).add_edge(upper, collector).build() events = await workflow.run("hello world") - invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] + invoked_events = [ + e + for e in events + if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" + ] assert len(invoked_events) == 2 @@ -190,10 +193,16 @@ async def test_executor_completed_event_contains_sent_messages(): sender = MultiSenderExecutor(id="sender") collector = CollectorExecutor(id="collector") - workflow = WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build() + workflow = ( + WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build() + ) events = await workflow.run("hello") - completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] + completed_events = [ + e + for e in events + if isinstance(e, WorkflowEvent) and e.type == "executor_completed" + ] # Sender should have completed with the sent messages sender_completed = next(e for e in completed_events if e.executor_id == "sender") @@ -201,7 +210,9 @@ async def test_executor_completed_event_contains_sent_messages(): assert sender_completed.data == ["hello-first", "hello-second"] # Collector should have completed with no sent messages (None) - collector_completed_events = [e for e in completed_events if e.executor_id == "collector"] + collector_completed_events = [ + e for e in completed_events if e.executor_id == "collector" + ] # Collector is called twice (once per message from sender) assert len(collector_completed_events) == 2 for collector_completed in collector_completed_events: @@ -220,7 +231,11 @@ async def test_executor_completed_event_includes_yielded_outputs(): workflow = WorkflowBuilder(start_executor=executor).build() events = await workflow.run("test") - completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] + completed_events = [ + e + for e in events + if isinstance(e, WorkflowEvent) and e.type == "executor_completed" + ] assert len(completed_events) == 1 assert completed_events[0].executor_id == "yielder" @@ -248,7 +263,9 @@ async def test_executor_events_with_complex_message_types(): class ProcessorExecutor(Executor): @handler - async def handle(self, request: Request, ctx: WorkflowContext[Response]) -> None: + async def handle( + self, request: Request, ctx: WorkflowContext[Response] + ) -> None: response = Response(results=[request.query.upper()] * request.limit) await ctx.send_message(response) @@ -260,13 +277,23 @@ async def test_executor_events_with_complex_message_types(): processor = ProcessorExecutor(id="processor") collector = CollectorExecutor(id="collector") - workflow = WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build() + workflow = ( + WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build() + ) input_request = Request(query="hello", limit=3) events = await workflow.run(input_request) - invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] - completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"] + invoked_events = [ + e + for e in events + if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" + ] + completed_events = [ + e + for e in events + if isinstance(e, WorkflowEvent) and e.type == "executor_completed" + ] # Check processor invoked event has the Request object processor_invoked = next(e for e in invoked_events if e.executor_id == "processor") @@ -275,7 +302,9 @@ async def test_executor_events_with_complex_message_types(): assert processor_invoked.data.limit == 3 # Check processor completed event has the Response object - processor_completed = next(e for e in completed_events if e.executor_id == "processor") + processor_completed = next( + e for e in completed_events if e.executor_id == "processor" + ) assert processor_completed.data is not None assert len(processor_completed.data) == 1 assert isinstance(processor_completed.data[0], Response) @@ -361,7 +390,9 @@ def test_executor_workflow_output_types_property(): # Test executor with union workflow output types class UnionWorkflowOutputExecutor(Executor): @handler - async def handle(self, text: str, ctx: WorkflowContext[int, str | bool]) -> None: + async def handle( + self, text: str, ctx: WorkflowContext[int, str | bool] + ) -> None: pass executor = UnionWorkflowOutputExecutor(id="union_workflow_output") @@ -372,11 +403,15 @@ def test_executor_workflow_output_types_property(): # Test executor with multiple handlers having different workflow output types class MultiHandlerWorkflowExecutor(Executor): @handler - async def handle_string(self, text: str, ctx: WorkflowContext[int, str]) -> None: + async def handle_string( + self, text: str, ctx: WorkflowContext[int, str] + ) -> None: pass @handler - async def handle_number(self, num: int, ctx: WorkflowContext[bool, float]) -> None: + async def handle_number( + self, num: int, ctx: WorkflowContext[bool, float] + ) -> None: pass executor = MultiHandlerWorkflowExecutor(id="multi_workflow") @@ -430,7 +465,9 @@ def test_executor_output_types_includes_response_handlers(): pass @response_handler - async def handle_response(self, original_request: str, response: bool, ctx: WorkflowContext[float]) -> None: + async def handle_response( + self, original_request: str, response: bool, ctx: WorkflowContext[float] + ) -> None: pass executor = RequestResponseExecutor(id="request_response") @@ -452,7 +489,10 @@ def test_executor_workflow_output_types_includes_response_handlers(): @response_handler async def handle_response( - self, original_request: str, response: bool, ctx: WorkflowContext[float, bool] + self, + original_request: str, + response: bool, + ctx: WorkflowContext[float, bool], ) -> None: pass @@ -509,7 +549,10 @@ def test_executor_response_handler_union_output_types(): @response_handler async def handle_response( - self, original_request: str, response: bool, ctx: WorkflowContext[int | str | float, bool | int] + self, + original_request: str, + response: bool, + ctx: WorkflowContext[int | str | float, bool | int], ) -> None: pass @@ -531,7 +574,9 @@ async def test_executor_invoked_event_data_not_mutated_by_handler(): """Test that executor_invoked event (type='executor_invoked').data captures original input, not mutated input.""" @executor(id="Mutator") - async def mutator(messages: list[Message], ctx: WorkflowContext[list[Message]]) -> None: + async def mutator( + messages: list[Message], ctx: WorkflowContext[list[Message]] + ) -> None: # The handler mutates the input list by appending new messages original_len = len(messages) messages.append(Message(role="assistant", text="Added by executor")) @@ -546,7 +591,11 @@ async def test_executor_invoked_event_data_not_mutated_by_handler(): events = await workflow.run(input_messages) # Find the invoked event for the Mutator executor - invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"] + invoked_events = [ + e + for e in events + if isinstance(e, WorkflowEvent) and e.type == "executor_invoked" + ] assert len(invoked_events) == 1 mutator_invoked = invoked_events[0] @@ -577,8 +626,8 @@ class TestHandlerExplicitTypes: exec_instance = ExplicitInputExecutor(id="explicit_input") # Handler should be registered for str (explicit), not Any (introspected) - assert str in exec_instance._handlers - assert len(exec_instance._handlers) == 1 + assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] + assert len(exec_instance._handlers) == 1 # pyright: ignore[reportPrivateUsage] # Can handle str messages assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock")) @@ -596,8 +645,8 @@ class TestHandlerExplicitTypes: exec_instance = ExplicitOutputExecutor(id="explicit_output") # Handler spec should have int as output type (explicit) - handler_func = exec_instance._handlers[str] - assert handler_func._handler_spec["output_types"] == [int] + handler_func = exec_instance._handlers[str] # pyright: ignore[reportPrivateUsage] + assert handler_func._handler_spec["output_types"] == [int] # pyright: ignore[reportFunctionMemberAccess] # Executor output_types property should reflect explicit type assert int in exec_instance.output_types @@ -615,16 +664,20 @@ class TestHandlerExplicitTypes: exec_instance = ExplicitBothExecutor(id="explicit_both") # Handler should be registered for dict (explicit input type) - assert dict in exec_instance._handlers - assert len(exec_instance._handlers) == 1 + assert dict in exec_instance._handlers # pyright: ignore[reportPrivateUsage] + assert len(exec_instance._handlers) == 1 # pyright: ignore[reportPrivateUsage] # Output type should be list (explicit) - handler_func = exec_instance._handlers[dict] - assert handler_func._handler_spec["output_types"] == [list] + handler_func = exec_instance._handlers[dict] # pyright: ignore[reportPrivateUsage] + assert handler_func._handler_spec["output_types"] == [list] # pyright: ignore[reportFunctionMemberAccess] # Verify can_handle - assert exec_instance.can_handle(WorkflowMessage(data={"key": "value"}, source_id="mock")) - assert not exec_instance.can_handle(WorkflowMessage(data="string", source_id="mock")) + assert exec_instance.can_handle( + WorkflowMessage(data={"key": "value"}, source_id="mock") + ) + assert not exec_instance.can_handle( + WorkflowMessage(data="string", source_id="mock") + ) def test_handler_with_explicit_union_input_type(self): """Test that explicit union input_type is handled correctly.""" @@ -639,13 +692,15 @@ class TestHandlerExplicitTypes: # Handler should be registered for the union type # The union type itself is stored as the key - assert len(exec_instance._handlers) == 1 + assert len(exec_instance._handlers) == 1 # pyright: ignore[reportPrivateUsage] # Can handle both str and int messages assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock")) assert exec_instance.can_handle(WorkflowMessage(data=42, source_id="mock")) # Cannot handle float - assert not exec_instance.can_handle(WorkflowMessage(data=3.14, source_id="mock")) + assert not exec_instance.can_handle( + WorkflowMessage(data=3.14, source_id="mock") + ) def test_handler_with_explicit_union_output_type(self): """Test that explicit union output is normalized to a list.""" @@ -674,8 +729,8 @@ class TestHandlerExplicitTypes: exec_instance = PrecedenceExecutor(id="precedence") # Should use explicit input type (bytes), not introspected (str) - assert bytes in exec_instance._handlers - assert str not in exec_instance._handlers + assert bytes in exec_instance._handlers # pyright: ignore[reportPrivateUsage] + assert str not in exec_instance._handlers # pyright: ignore[reportPrivateUsage] # Should use explicit output type (float), not introspected (int) assert float in exec_instance.output_types @@ -692,7 +747,7 @@ class TestHandlerExplicitTypes: exec_instance = IntrospectedExecutor(id="introspected") # Should use introspected types - assert str in exec_instance._handlers + assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] assert int in exec_instance.output_types def test_handler_explicit_mode_requires_input(self): @@ -705,13 +760,13 @@ class TestHandlerExplicitTypes: pass exec_input = OnlyInputExecutor(id="only_input") - assert bytes in exec_input._handlers # Explicit + assert bytes in exec_input._handlers # pyright: ignore[reportPrivateUsage] # Explicit assert exec_input.output_types == [] # No output types (not introspected) # Only explicit output without input should raise error with pytest.raises(ValueError, match="must specify 'input' type"): - class OnlyOutputExecutor(Executor): + class OnlyOutputExecutor(Executor): # pyright: ignore[reportUnusedClass] @handler(output=float) async def handle(self, message: str, ctx: WorkflowContext[int]) -> None: pass @@ -719,9 +774,11 @@ class TestHandlerExplicitTypes: # Only explicit workflow_output without input should raise error with pytest.raises(ValueError, match="must specify 'input' type"): - class OnlyWorkflowOutputExecutor(Executor): + class OnlyWorkflowOutputExecutor(Executor): # pyright: ignore[reportUnusedClass] @handler(workflow_output=bool) - async def handle(self, message: str, ctx: WorkflowContext[int, str]) -> None: + async def handle( + self, message: str, ctx: WorkflowContext[int, str] + ) -> None: pass def test_handler_explicit_input_type_allows_no_message_annotation(self): @@ -734,8 +791,7 @@ class TestHandlerExplicitTypes: exec_instance = NoAnnotationExecutor(id="no_annotation") - # Should work with explicit input_type - assert str in exec_instance._handlers + assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock")) def test_handler_multiple_handlers_mixed_explicit_and_introspected(self): @@ -747,15 +803,17 @@ class TestHandlerExplicitTypes: pass @handler - async def handle_introspected(self, message: float, ctx: WorkflowContext[bool]) -> None: + async def handle_introspected( + self, message: float, ctx: WorkflowContext[bool] + ) -> None: pass exec_instance = MixedExecutor(id="mixed") # Should have both handlers - assert len(exec_instance._handlers) == 2 - assert str in exec_instance._handlers # Explicit - assert float in exec_instance._handlers # Introspected + assert len(exec_instance._handlers) == 2 # pyright: ignore[reportPrivateUsage] + assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] # Explicit + assert float in exec_instance._handlers # pyright: ignore[reportPrivateUsage] # Introspected # Should have both output types assert int in exec_instance.output_types # Explicit @@ -772,8 +830,10 @@ class TestHandlerExplicitTypes: exec_instance = StringRefExecutor(id="string_ref") # Should resolve the string to the actual type - assert ForwardRefMessage in exec_instance._handlers - assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock")) + assert ForwardRefMessage in exec_instance._handlers # pyright: ignore[reportPrivateUsage] + assert exec_instance.can_handle( + WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock") + ) def test_handler_with_string_forward_reference_union(self): """Test that string forward references work with union types.""" @@ -786,8 +846,12 @@ class TestHandlerExplicitTypes: exec_instance = StringUnionExecutor(id="string_union") # Should handle both types - assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock")) - assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock")) + assert exec_instance.can_handle( + WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock") + ) + assert exec_instance.can_handle( + WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock") + ) def test_handler_with_string_forward_reference_output_type(self): """Test that string forward references work for output_type.""" @@ -813,8 +877,8 @@ class TestHandlerExplicitTypes: exec_instance = ExplicitWorkflowOutputExecutor(id="explicit_workflow_output") # Handler spec should have bool as workflow_output_type (explicit) - handler_func = exec_instance._handlers[str] - assert handler_func._handler_spec["workflow_output_types"] == [bool] + handler_func = exec_instance._handlers[str] # pyright: ignore[reportPrivateUsage] + assert handler_func._handler_spec["workflow_output_types"] == [bool] # pyright: ignore[reportFunctionMemberAccess] # Executor workflow_output_types property should reflect explicit type assert bool in exec_instance.workflow_output_types @@ -826,13 +890,14 @@ class TestHandlerExplicitTypes: class PrecedenceExecutor(Executor): @handler(input=int, output=float, workflow_output=str) - async def handle(self, message: int, ctx: WorkflowContext[int, bool]) -> None: + async def handle( + self, message: int, ctx: WorkflowContext[int, bool] + ) -> None: pass exec_instance = PrecedenceExecutor(id="precedence") - # All types should come from explicit params - assert int in exec_instance._handlers + assert int in exec_instance._handlers # pyright: ignore[reportPrivateUsage] assert float in exec_instance.output_types assert str in exec_instance.workflow_output_types # Introspected types should NOT be present @@ -849,8 +914,7 @@ class TestHandlerExplicitTypes: exec_instance = AllExplicitExecutor(id="all_explicit") - # Check input type - assert str in exec_instance._handlers + assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock")) # Check output_type @@ -894,7 +958,9 @@ class TestHandlerExplicitTypes: async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass - exec_instance = StringUnionWorkflowOutputExecutor(id="string_union_workflow_output") + exec_instance = StringUnionWorkflowOutputExecutor( + id="string_union_workflow_output" + ) # Should resolve both types from string union assert ForwardRefTypeA in exec_instance.workflow_output_types @@ -905,10 +971,14 @@ class TestHandlerExplicitTypes: class IntrospectedWorkflowOutputExecutor(Executor): @handler - async def handle(self, message: str, ctx: WorkflowContext[int, bool]) -> None: + async def handle( + self, message: str, ctx: WorkflowContext[int, bool] + ) -> None: pass - exec_instance = IntrospectedWorkflowOutputExecutor(id="introspected_workflow_output") + exec_instance = IntrospectedWorkflowOutputExecutor( + id="introspected_workflow_output" + ) # Should use introspected types from WorkflowContext[int, bool] assert int in exec_instance.output_types diff --git a/python/packages/core/tests/workflow/test_executor_future.py b/python/packages/core/tests/workflow/test_executor_future.py index c0916b9cf7..cb0c5c9f58 100644 --- a/python/packages/core/tests/workflow/test_executor_future.py +++ b/python/packages/core/tests/workflow/test_executor_future.py @@ -34,8 +34,8 @@ class TestExecutorFutureAnnotations: pass exec_instance = MyExecutor(id="test") - assert str in exec_instance._handlers - spec = exec_instance._handler_specs[0] + assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] + spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] is str assert spec["output_types"] == [MyTypeA] assert spec["workflow_output_types"] == [MyTypeB] @@ -49,8 +49,8 @@ class TestExecutorFutureAnnotations: pass exec_instance = MyExecutor(id="test") - assert int in exec_instance._handlers - spec = exec_instance._handler_specs[0] + assert int in exec_instance._handlers # pyright: ignore[reportPrivateUsage] + spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] is int assert spec["output_types"] == [MyTypeA] @@ -63,7 +63,7 @@ class TestExecutorFutureAnnotations: pass exec_instance = MyExecutor(id="test") - spec = exec_instance._handler_specs[0] + spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] == dict[str, Any] assert spec["output_types"] == [list[str]] @@ -76,8 +76,8 @@ class TestExecutorFutureAnnotations: pass exec_instance = MyExecutor(id="test") - assert str in exec_instance._handlers - spec = exec_instance._handler_specs[0] + assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] + spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["output_types"] == [] assert spec["workflow_output_types"] == [] @@ -86,12 +86,12 @@ class TestExecutorFutureAnnotations: class MyExecutor(Executor): @handler(input=str, output=MyTypeA) - async def example(self, input, ctx) -> None: + async def example(self, input, ctx) -> None: # type: ignore[no-untyped-def] pass exec_instance = MyExecutor(id="test") - assert str in exec_instance._handlers - spec = exec_instance._handler_specs[0] + assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] + spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] is str assert spec["output_types"] == [MyTypeA] @@ -104,8 +104,8 @@ class TestExecutorFutureAnnotations: pass exec_instance = MyExecutor(id="test") - assert str in exec_instance._handlers - spec = exec_instance._handler_specs[0] + assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] + spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["output_types"] == [MyTypeA, MyTypeB] assert spec["workflow_output_types"] == [MyTypeC] @@ -118,7 +118,7 @@ class TestExecutorFutureAnnotations: """ with pytest.raises(ValueError): - class Bad(Executor): - @handler - async def example(self, input: NonExistentType, ctx: WorkflowContext[MyTypeA, MyTypeB]) -> None: # noqa: F821 + class Bad(Executor): # pyright: ignore[reportUnusedClass] + @handler # pyright: ignore[reportUnknownArgumentType] + async def example(self, input: NonExistentType, ctx: WorkflowContext[MyTypeA, MyTypeB]) -> None: # noqa: F821 # type: ignore[name-defined] pass diff --git a/python/packages/core/tests/workflow/test_full_conversation.py b/python/packages/core/tests/workflow/test_full_conversation.py index 20d9abd8c0..b6b5260d83 100644 --- a/python/packages/core/tests/workflow/test_full_conversation.py +++ b/python/packages/core/tests/workflow/test_full_conversation.py @@ -1,7 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. -from collections.abc import AsyncIterable, Awaitable, Sequence -from typing import Any +from collections.abc import AsyncIterable, Awaitable +from typing import Any, Literal, overload import pytest from pydantic import PrivateAttr @@ -13,6 +13,7 @@ from agent_framework import ( AgentExecutorResponse, AgentResponse, AgentResponseUpdate, + AgentRunInputs, AgentSession, BaseAgent, Content, @@ -34,14 +35,32 @@ class _SimpleAgent(BaseAgent): super().__init__(**kwargs) self._reply_text = reply_text + @overload def run( self, - messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: if stream: async def _stream() -> AsyncIterable[AgentResponseUpdate]: @@ -81,14 +100,32 @@ class _ToolHistoryAgent(BaseAgent): Message(role="assistant", contents=[Content.from_text(text=self._summary_text)]), ] + @overload def run( self, - messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: if stream: async def _stream() -> AsyncIterable[AgentResponseUpdate]: @@ -165,14 +202,32 @@ class _CaptureAgent(BaseAgent): super().__init__(**kwargs) self._reply_text = reply_text + @overload def run( self, - messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: # Normalize and record messages for verification norm: list[Message] = [] if messages: @@ -260,7 +315,7 @@ class _RoundTripCoordinator(Executor): async def handle_response( self, response: AgentExecutorResponse, - ctx: WorkflowContext[Never, dict[str, Any]], + ctx: WorkflowContext[AgentExecutorRequest, dict[str, Any]], ) -> None: self._seen += 1 if self._seen == 1: @@ -314,14 +369,32 @@ class _SessionIdCapturingAgent(BaseAgent): _captured_service_session_id: str | None = PrivateAttr(default="NOT_CAPTURED") + @overload def run( self, - messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( + self, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: self._captured_service_session_id = session.service_session_id if session else None async def _run() -> AgentResponse: @@ -342,7 +415,7 @@ class _FullHistoryReplayCoordinator(Executor): async def handle( self, response: AgentExecutorResponse, - ctx: WorkflowContext[Never, Any], + ctx: WorkflowContext[AgentExecutorRequest, Any], ) -> None: full_conv = list(response.full_conversation or response.agent_response.messages) full_conv.append(Message(role="user", text="follow-up")) diff --git a/python/packages/core/tests/workflow/test_function_executor.py b/python/packages/core/tests/workflow/test_function_executor.py index c0b73156ff..8bb3f94d29 100644 --- a/python/packages/core/tests/workflow/test_function_executor.py +++ b/python/packages/core/tests/workflow/test_function_executor.py @@ -48,12 +48,12 @@ class TestFunctionExecutor: func_exec = FunctionExecutor(process_string) # Check that handler was registered - assert len(func_exec._handlers) == 1 - assert str in func_exec._handlers + assert len(func_exec._handlers) == 1 # pyright: ignore[reportPrivateUsage] + assert str in func_exec._handlers # pyright: ignore[reportPrivateUsage] # Check handler spec was created - assert len(func_exec._handler_specs) == 1 - spec = func_exec._handler_specs[0] + assert len(func_exec._handler_specs) == 1 # pyright: ignore[reportPrivateUsage] + spec = func_exec._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["name"] == "process_string" assert spec["message_type"] is str assert spec["output_types"] == [str] @@ -67,10 +67,10 @@ class TestFunctionExecutor: assert isinstance(process_int, FunctionExecutor) assert process_int.id == "test_executor" - assert int in process_int._handlers + assert int in process_int._handlers # pyright: ignore[reportPrivateUsage] # Check spec - spec = process_int._handler_specs[0] + spec = process_int._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] is int assert spec["output_types"] == [int] @@ -78,7 +78,7 @@ class TestFunctionExecutor: """Test @executor decorator uses function name as default ID.""" @executor - async def my_function(data: dict, ctx: WorkflowContext[Any]) -> None: + async def my_function(data: dict[str, Any], ctx: WorkflowContext[Any]) -> None: await ctx.send_message(data) assert my_function.id == "my_function" @@ -92,7 +92,7 @@ class TestFunctionExecutor: assert isinstance(no_parens_function, FunctionExecutor) assert no_parens_function.id == "no_parens_function" - assert str in no_parens_function._handlers + assert str in no_parens_function._handlers # pyright: ignore[reportPrivateUsage] # Also test with single parameter function @executor @@ -101,7 +101,7 @@ class TestFunctionExecutor: assert isinstance(simple_no_parens, FunctionExecutor) assert simple_no_parens.id == "simple_no_parens" - assert int in simple_no_parens._handlers + assert int in simple_no_parens._handlers # pyright: ignore[reportPrivateUsage] def test_union_output_types(self): """Test that union output types are properly inferred for both messages and workflow outputs.""" @@ -113,7 +113,7 @@ class TestFunctionExecutor: else: await ctx.send_message(text.upper()) - spec = multi_output._handler_specs[0] + spec = multi_output._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert set(spec["output_types"]) == {str, int} assert spec["workflow_output_types"] == [] # No workflow outputs defined @@ -127,7 +127,7 @@ class TestFunctionExecutor: else: await ctx.yield_output(data.upper()) - workflow_spec = multi_workflow_output._handler_specs[0] + workflow_spec = multi_workflow_output._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert workflow_spec["output_types"] == [] # None means no message outputs assert set(workflow_spec["workflow_output_types"]) == {str, int, bool} @@ -139,7 +139,7 @@ class TestFunctionExecutor: # This executor doesn't send any messages pass - spec = no_output._handler_specs[0] + spec = no_output._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["output_types"] == [] assert spec["workflow_output_types"] == [] # No workflow outputs defined @@ -150,7 +150,7 @@ class TestFunctionExecutor: async def any_output(data: str, ctx: WorkflowContext[Any]) -> None: await ctx.send_message("result") - spec = any_output._handler_specs[0] + spec = any_output._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["output_types"] == [Any] assert spec["workflow_output_types"] == [] # No workflow outputs defined @@ -160,7 +160,7 @@ class TestFunctionExecutor: await ctx.send_message("message") await ctx.yield_output("workflow_output") - both_spec = any_both_output._handler_specs[0] + both_spec = any_both_output._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert both_spec["output_types"] == [Any] assert both_spec["workflow_output_types"] == [Any] @@ -228,11 +228,11 @@ class TestFunctionExecutor: await ctx.yield_output(result) # Verify type inference for both executors - upper_spec = to_upper._handler_specs[0] + upper_spec = to_upper._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert upper_spec["output_types"] == [str] assert upper_spec["workflow_output_types"] == [] # No workflow outputs - reverse_spec = reverse_text._handler_specs[0] + reverse_spec = reverse_text._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert reverse_spec["output_types"] == [Any] # First parameter is Any assert reverse_spec["workflow_output_types"] == [str] # Second parameter is str @@ -270,7 +270,7 @@ class TestFunctionExecutor: await ctx.send_message(message) with pytest.raises(ValueError, match="Handler for type .* already registered"): - func_exec._register_instance_handler( + func_exec._register_instance_handler( # pyright: ignore[reportPrivateUsage] name="second", func=second_handler, message_type=str, @@ -287,7 +287,7 @@ class TestFunctionExecutor: result = {item: len(item) for item in items} await ctx.send_message(result) - spec = process_list._handler_specs[0] + spec = process_list._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] == list[str] assert spec["output_types"] == [dict[str, int]] @@ -300,10 +300,10 @@ class TestFunctionExecutor: assert isinstance(process_simple, FunctionExecutor) assert process_simple.id == "simple_processor" - assert str in process_simple._handlers + assert str in process_simple._handlers # pyright: ignore[reportPrivateUsage] # Check spec - single parameter functions have no output types since they can't send messages - spec = process_simple._handler_specs[0] + spec = process_simple._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] is str assert spec["output_types"] == [] assert spec["ctx_annotation"] is None @@ -316,7 +316,7 @@ class TestFunctionExecutor: return data * 2 func_exec = FunctionExecutor(valid_single) - assert int in func_exec._handlers + assert int in func_exec._handlers # pyright: ignore[reportPrivateUsage] # Single parameter with missing type annotation should still fail async def no_annotation(data): # type: ignore @@ -349,7 +349,7 @@ class TestFunctionExecutor: # For testing purposes, we can check that the handler is registered correctly assert double_value.can_handle(WorkflowMessage(data=5, source_id="mock")) - assert int in double_value._handlers + assert int in double_value._handlers # pyright: ignore[reportPrivateUsage] def test_sync_function_basic(self): """Test basic synchronous function support.""" @@ -360,10 +360,10 @@ class TestFunctionExecutor: assert isinstance(process_sync, FunctionExecutor) assert process_sync.id == "sync_processor" - assert str in process_sync._handlers + assert str in process_sync._handlers # pyright: ignore[reportPrivateUsage] # Check spec - sync single parameter functions have no output types - spec = process_sync._handler_specs[0] + spec = process_sync._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] is str assert spec["output_types"] == [] assert spec["ctx_annotation"] is None @@ -378,10 +378,10 @@ class TestFunctionExecutor: assert isinstance(sync_with_ctx, FunctionExecutor) assert sync_with_ctx.id == "sync_with_ctx" - assert int in sync_with_ctx._handlers + assert int in sync_with_ctx._handlers # pyright: ignore[reportPrivateUsage] # Check spec - sync functions with context can infer output types - spec = sync_with_ctx._handler_specs[0] + spec = sync_with_ctx._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] is int assert spec["output_types"] == [int] @@ -404,18 +404,18 @@ class TestFunctionExecutor: return data.upper() func_exec = FunctionExecutor(valid_sync) - assert str in func_exec._handlers + assert str in func_exec._handlers # pyright: ignore[reportPrivateUsage] # Valid sync function with two parameters def valid_sync_with_ctx(data: int, ctx: WorkflowContext[str]): return str(data) func_exec2 = FunctionExecutor(valid_sync_with_ctx) - assert int in func_exec2._handlers + assert int in func_exec2._handlers # pyright: ignore[reportPrivateUsage] # Sync function with missing type annotation should still fail - def no_annotation(data): # type: ignore - return data + def no_annotation(data): # type: ignore # pyright: ignore[reportUnknownVariableType] + return data # pyright: ignore[reportUnknownVariableType] with pytest.raises(ValueError, match="type annotation for the message"): FunctionExecutor(no_annotation) # type: ignore @@ -457,11 +457,11 @@ class TestFunctionExecutor: await ctx.yield_output(result) # Verify type inference for sync and async functions - sync_spec = to_upper_sync._handler_specs[0] + sync_spec = to_upper_sync._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert sync_spec["output_types"] == [str] assert sync_spec["workflow_output_types"] == [] # No workflow outputs - async_spec = reverse_async._handler_specs[0] + async_spec = reverse_async._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert async_spec["output_types"] == [Any] # First parameter is Any assert async_spec["workflow_output_types"] == [str] # Second parameter is str @@ -471,8 +471,8 @@ class TestFunctionExecutor: # For integration testing, we mainly verify that the handlers are properly registered # and the functions are wrapped correctly - assert str in to_upper_sync._handlers - assert str in reverse_async._handlers + assert str in to_upper_sync._handlers # pyright: ignore[reportPrivateUsage] + assert str in reverse_async._handlers # pyright: ignore[reportPrivateUsage] async def test_sync_function_thread_execution(self): """Test that sync functions run in thread pool and don't block the event loop.""" @@ -491,13 +491,13 @@ class TestFunctionExecutor: return data.upper() # Verify the function is wrapped and registered - assert str in blocking_function._handlers + assert str in blocking_function._handlers # pyright: ignore[reportPrivateUsage] # For a more complete test, we'd need to create a full workflow context, # but for now we can verify that the function was properly wrapped # and that sync functions store the correct metadata - assert not blocking_function._is_async - assert not blocking_function._has_context + assert not blocking_function._is_async # pyright: ignore[reportPrivateUsage] + assert not blocking_function._has_context # pyright: ignore[reportPrivateUsage] # The actual thread execution test would require a full workflow setup, # but the important thing is that asyncio.to_thread is used in the wrapper @@ -506,7 +506,7 @@ class TestFunctionExecutor: """Test that @executor decorator properly rejects @staticmethod with clear error.""" with pytest.raises(ValueError) as exc_info: - class Example: + class Example: # pyright: ignore[reportUnusedClass] @executor @staticmethod async def bad_handler(data: str) -> str: @@ -519,7 +519,7 @@ class TestFunctionExecutor: """Test that @executor decorator properly rejects @classmethod with clear error.""" with pytest.raises(ValueError) as exc_info: - class Example: + class Example: # pyright: ignore[reportUnusedClass] @executor @classmethod async def bad_handler(cls, data: str) -> str: @@ -570,8 +570,8 @@ class TestExecutorExplicitTypes: pass # Handler should be registered for str (explicit) - assert str in process._handlers - assert len(process._handlers) == 1 + assert str in process._handlers # pyright: ignore[reportPrivateUsage] + assert len(process._handlers) == 1 # pyright: ignore[reportPrivateUsage] # Can handle str messages assert process.can_handle(WorkflowMessage(data="hello", source_id="mock")) @@ -586,7 +586,7 @@ class TestExecutorExplicitTypes: pass # Handler spec should have int as output type (explicit), not str (introspected) - spec = process._handler_specs[0] + spec = process._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["output_types"] == [int] # Executor output_types property should reflect explicit type @@ -601,11 +601,11 @@ class TestExecutorExplicitTypes: pass # Handler should be registered for dict (explicit input type) - assert dict in process._handlers - assert len(process._handlers) == 1 + assert dict in process._handlers # pyright: ignore[reportPrivateUsage] + assert len(process._handlers) == 1 # pyright: ignore[reportPrivateUsage] # Output type should be list (explicit) - spec = process._handler_specs[0] + spec = process._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["output_types"] == [list] # Verify can_handle @@ -620,7 +620,7 @@ class TestExecutorExplicitTypes: pass # Handler should be registered for the union type - assert len(process._handlers) == 1 + assert len(process._handlers) == 1 # pyright: ignore[reportPrivateUsage] # Can handle both str and int messages assert process.can_handle(WorkflowMessage(data="hello", source_id="mock")) @@ -648,8 +648,8 @@ class TestExecutorExplicitTypes: pass # Should use explicit input type (bytes), not introspected (str) - assert bytes in process._handlers - assert str not in process._handlers + assert bytes in process._handlers # pyright: ignore[reportPrivateUsage] + assert str not in process._handlers # pyright: ignore[reportPrivateUsage] # Should use explicit output type (float), not introspected (int) assert float in process.output_types @@ -663,7 +663,7 @@ class TestExecutorExplicitTypes: pass # Should use introspected types - assert str in process._handlers + assert str in process._handlers # pyright: ignore[reportPrivateUsage] assert int in process.output_types def test_executor_partial_explicit_types(self): @@ -674,7 +674,7 @@ class TestExecutorExplicitTypes: async def process_input(message: str, ctx: WorkflowContext[int]) -> None: pass - assert bytes in process_input._handlers # Explicit + assert bytes in process_input._handlers # Explicit # pyright: ignore[reportPrivateUsage] assert int in process_input.output_types # Introspected # Only explicit output_type, introspect input_type @@ -682,7 +682,7 @@ class TestExecutorExplicitTypes: async def process_output(message: str, ctx: WorkflowContext[int]) -> None: pass - assert str in process_output._handlers # Introspected + assert str in process_output._handlers # Introspected # pyright: ignore[reportPrivateUsage] assert float in process_output.output_types # Explicit assert int not in process_output.output_types # Not introspected when explicit provided @@ -694,7 +694,7 @@ class TestExecutorExplicitTypes: pass # Should work with explicit input_type - assert str in process._handlers + assert str in process._handlers # pyright: ignore[reportPrivateUsage] assert process.can_handle(WorkflowMessage(data="hello", source_id="mock")) def test_executor_explicit_types_with_id(self): @@ -705,7 +705,7 @@ class TestExecutorExplicitTypes: pass assert process.id == "custom_id" - assert bytes in process._handlers + assert bytes in process._handlers # pyright: ignore[reportPrivateUsage] assert int in process.output_types def test_executor_explicit_types_with_single_param_function(self): @@ -713,10 +713,10 @@ class TestExecutorExplicitTypes: @executor(input=str) async def process(message): # type: ignore[no-untyped-def] - return message.upper() + return message.upper() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType] # Should work with explicit input_type - assert str in process._handlers + assert str in process._handlers # pyright: ignore[reportPrivateUsage] assert process.can_handle(WorkflowMessage(data="hello", source_id="mock")) assert not process.can_handle(WorkflowMessage(data=42, source_id="mock")) @@ -727,7 +727,7 @@ class TestExecutorExplicitTypes: def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass - assert int in process._handlers + assert int in process._handlers # pyright: ignore[reportPrivateUsage] assert str in process.output_types def test_function_executor_constructor_with_explicit_types(self): @@ -736,10 +736,10 @@ class TestExecutorExplicitTypes: async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def] pass - func_exec = FunctionExecutor(process, id="test", input=dict, output=list) + func_exec = FunctionExecutor(process, id="test", input=dict, output=list) # pyright: ignore[reportUnknownArgumentType] - assert dict in func_exec._handlers - spec = func_exec._handler_specs[0] + assert dict in func_exec._handlers # pyright: ignore[reportPrivateUsage] + spec = func_exec._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] is dict assert spec["output_types"] == [list] @@ -766,7 +766,7 @@ class TestExecutorExplicitTypes: pass # Should resolve the string to the actual type - assert FuncExecForwardRefMessage in process._handlers + assert FuncExecForwardRefMessage in process._handlers # pyright: ignore[reportPrivateUsage] assert process.can_handle(WorkflowMessage(data=FuncExecForwardRefMessage("hello"), source_id="mock")) def test_executor_with_string_forward_reference_union(self): @@ -798,7 +798,7 @@ class TestExecutorExplicitTypes: pass # Handler spec should have bool as workflow_output_type (explicit) - spec = process._handler_specs[0] + spec = process._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["workflow_output_types"] == [bool] # Executor workflow_output_types property should reflect explicit type @@ -826,7 +826,7 @@ class TestExecutorExplicitTypes: pass # Check input type - assert str in process._handlers + assert str in process._handlers # pyright: ignore[reportPrivateUsage] assert process.can_handle(WorkflowMessage(data="hello", source_id="mock")) # Check output_type @@ -892,6 +892,6 @@ class TestExecutorExplicitTypes: workflow_output=bool, ) - assert str in exec_instance._handlers + assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] assert int in exec_instance.output_types assert bool in exec_instance.workflow_output_types diff --git a/python/packages/core/tests/workflow/test_function_executor_future.py b/python/packages/core/tests/workflow/test_function_executor_future.py index a4a15aeba0..6d1ed32348 100644 --- a/python/packages/core/tests/workflow/test_function_executor_future.py +++ b/python/packages/core/tests/workflow/test_function_executor_future.py @@ -19,10 +19,10 @@ class TestFunctionExecutorFutureAnnotations: assert isinstance(process_future, FunctionExecutor) assert process_future.id == "future_test" - assert int in process_future._handlers + assert int in process_future._handlers # pyright: ignore[reportPrivateUsage] # Check spec - spec = process_future._handler_specs[0] + spec = process_future._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] is int assert spec["output_types"] == [int] @@ -34,6 +34,6 @@ class TestFunctionExecutorFutureAnnotations: await ctx.send_message(["done"]) assert isinstance(process_complex, FunctionExecutor) - spec = process_complex._handler_specs[0] + spec = process_complex._handler_specs[0] # pyright: ignore[reportPrivateUsage] assert spec["message_type"] == dict[str, Any] assert spec["output_types"] == [list[str]] diff --git a/python/packages/core/tests/workflow/test_request_info_mixin.py b/python/packages/core/tests/workflow/test_request_info_mixin.py index 4c3d6560aa..cfde71b481 100644 --- a/python/packages/core/tests/workflow/test_request_info_mixin.py +++ b/python/packages/core/tests/workflow/test_request_info_mixin.py @@ -794,7 +794,7 @@ class TestResponseHandlerExplicitTypes: """Test response_handler with explicit request and response types.""" @response_handler(request=str, response=int) - async def test_handler(self, original_request, response, ctx) -> None: + async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: pass spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] @@ -806,7 +806,7 @@ class TestResponseHandlerExplicitTypes: """Test response_handler with explicit output and workflow_output types.""" @response_handler(request=str, response=int, output=bool, workflow_output=float) - async def test_handler(self, original_request, response, ctx) -> None: + async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: pass spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] @@ -818,8 +818,8 @@ class TestResponseHandlerExplicitTypes: def test_response_handler_with_union_types(self): """Test response_handler with union types.""" - @response_handler(request=str | int, response=bool | float) - async def test_handler(self, original_request, response, ctx) -> None: + @response_handler(request=str | int, response=bool | float) # pyright: ignore[reportArgumentType] + async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: pass spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] @@ -830,7 +830,7 @@ class TestResponseHandlerExplicitTypes: """Test response_handler with string forward references.""" @response_handler(request="str", response="int") - async def test_handler(self, original_request, response, ctx) -> None: + async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: pass spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue] @@ -842,7 +842,7 @@ class TestResponseHandlerExplicitTypes: with pytest.raises(ValueError, match="must specify 'request' type"): @response_handler(response=int) - async def test_handler(self, original_request, response, ctx) -> None: + async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: # pyright: ignore[reportUnusedFunction] pass def test_response_handler_explicit_missing_response_raises_error(self): @@ -850,7 +850,7 @@ class TestResponseHandlerExplicitTypes: with pytest.raises(ValueError, match="must specify 'response' type"): @response_handler(request=str) - async def test_handler(self, original_request, response, ctx) -> None: + async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: # pyright: ignore[reportUnusedFunction] pass def test_response_handler_explicit_only_output_raises_error(self): @@ -858,7 +858,7 @@ class TestResponseHandlerExplicitTypes: with pytest.raises(ValueError, match="must specify 'request' type"): @response_handler(output=bool) - async def test_handler(self, original_request, response, ctx) -> None: + async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: # pyright: ignore[reportUnusedFunction] pass def test_executor_with_explicit_response_handlers(self): @@ -873,7 +873,7 @@ class TestResponseHandlerExplicitTypes: pass @response_handler(request=str, response=int, output=bool) - async def handle_explicit(self, original_request, response, ctx) -> None: + async def handle_explicit(self, original_request: Any, response: Any, ctx: WorkflowContext) -> None: pass executor = TestExecutor() @@ -907,7 +907,7 @@ class TestResponseHandlerExplicitTypes: pass @response_handler(request=str, response=int) - async def handle_response(self, original_request, response, ctx) -> None: + async def handle_response(self, original_request: Any, response: Any, ctx: WorkflowContext) -> None: self.handled_request = original_request self.handled_response = response @@ -942,7 +942,7 @@ class TestResponseHandlerExplicitTypes: # Explicit type handler @response_handler(request=dict, response=bool) - async def handle_explicit(self, original_request, response, ctx) -> None: + async def handle_explicit(self, original_request: Any, response: Any, ctx: WorkflowContext) -> None: pass executor = TestExecutor() diff --git a/python/packages/core/tests/workflow/test_runner.py b/python/packages/core/tests/workflow/test_runner.py index eaf69f90b0..db6dccd9fa 100644 --- a/python/packages/core/tests/workflow/test_runner.py +++ b/python/packages/core/tests/workflow/test_runner.py @@ -2,6 +2,7 @@ import asyncio from dataclasses import dataclass +from typing import Any from unittest.mock import AsyncMock, MagicMock import pytest @@ -113,7 +114,7 @@ async def test_runner_run_until_convergence(): assert result is not None and result == 10 # iteration count shouldn't be reset after convergence - assert runner._iteration == 10 # type: ignore + assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage] async def test_runner_run_until_convergence_not_completed(): @@ -173,7 +174,7 @@ async def test_runner_run_iteration_preserves_message_order_per_edge_runner() -> for index in range(5): await ctx.send_message(WorkflowMessage(data=MockMessage(data=index), source_id="source")) - await runner._run_iteration() + await runner._run_iteration() # pyright: ignore[reportPrivateUsage] assert edge_runner.received == [0, 1, 2, 3, 4] @@ -213,7 +214,7 @@ async def test_runner_run_iteration_delivers_different_edge_runners_concurrently await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id="source")) - iteration_task = asyncio.create_task(runner._run_iteration()) + iteration_task = asyncio.create_task(runner._run_iteration()) # pyright: ignore[reportPrivateUsage] await blocking_edge_runner.started.wait() await asyncio.wait_for(probe_edge_runner.probe_completed.wait(), timeout=2.0) @@ -280,7 +281,7 @@ async def test_fanout_edge_runner_delivers_to_multiple_targets_concurrently() -> # Queue a message from source (will be delivered to both targets via FanOut) await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id=source.id)) - iteration_task = asyncio.create_task(runner._run_iteration()) + iteration_task = asyncio.create_task(runner._run_iteration()) # pyright: ignore[reportPrivateUsage] # Wait for the blocking executor to start await blocking_target.started.wait() @@ -477,11 +478,11 @@ async def test_runner_reset_iteration_count(): ctx = InProcRunnerContext() runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash") - runner._iteration = 10 + runner._iteration = 10 # pyright: ignore[reportPrivateUsage] runner.reset_iteration_count() - assert runner._iteration == 0 + assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage] class CheckpointingContext(InProcRunnerContext): @@ -501,18 +502,19 @@ class CheckpointingContext(InProcRunnerContext): graph_signature_hash: str, state: State, previous_checkpoint_id: str | None, - iteration: int, + iteration_count: int, + metadata: dict[str, Any] | None = None, ) -> str: checkpoint = WorkflowCheckpoint( workflow_name=workflow_name, graph_signature_hash=graph_signature_hash, - state=state.export(), + state=state.export_state(), previous_checkpoint_id=previous_checkpoint_id, - iteration_count=iteration, + iteration_count=iteration_count, ) return await self._storage.save(checkpoint) - async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: + async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: # pyright: ignore[reportIncompatibleMethodOverride] try: return await self._storage.load(checkpoint_id) except WorkflowCheckpointException: @@ -537,7 +539,8 @@ class FailingCheckpointContext(InProcRunnerContext): graph_signature_hash: str, state: State, previous_checkpoint_id: str | None, - iteration: int, + iteration_count: int, + metadata: dict[str, Any] | None = None, ) -> str: raise RuntimeError("Simulated checkpoint failure") @@ -609,8 +612,8 @@ async def test_runner_restore_from_checkpoint_with_external_storage(): # Restore using external storage await runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage=storage) - assert runner._resumed_from_checkpoint is True - assert runner._iteration == 5 + assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage] + assert runner._iteration == 5 # pyright: ignore[reportPrivateUsage] assert state.get("test_key") == "test_value" @@ -684,7 +687,7 @@ async def test_runner_restore_executor_states_invalid_states_type(): runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash") with pytest.raises(WorkflowCheckpointException, match="not a dictionary"): - await runner._restore_executor_states() + await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage] async def test_runner_restore_executor_states_invalid_executor_id_type(): @@ -698,7 +701,7 @@ async def test_runner_restore_executor_states_invalid_executor_id_type(): runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash") with pytest.raises(WorkflowCheckpointException, match="not a string"): - await runner._restore_executor_states() + await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage] async def test_runner_restore_executor_states_invalid_state_type(): @@ -712,7 +715,7 @@ async def test_runner_restore_executor_states_invalid_state_type(): runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash") with pytest.raises(WorkflowCheckpointException, match="not a dict"): - await runner._restore_executor_states() + await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage] async def test_runner_restore_executor_states_invalid_state_keys(): @@ -726,7 +729,7 @@ async def test_runner_restore_executor_states_invalid_state_keys(): runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash") with pytest.raises(WorkflowCheckpointException, match="not a dict"): - await runner._restore_executor_states() + await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage] async def test_runner_restore_executor_states_missing_executor(): @@ -739,7 +742,7 @@ async def test_runner_restore_executor_states_missing_executor(): runner = Runner([], {}, state, ctx, "test_name", graph_signature_hash="test_hash") with pytest.raises(WorkflowCheckpointException, match="not found during state restoration"): - await runner._restore_executor_states() + await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage] async def test_runner_set_executor_state_invalid_existing_states(): @@ -752,7 +755,7 @@ async def test_runner_set_executor_state_invalid_existing_states(): runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash") with pytest.raises(WorkflowCheckpointException, match="not a dictionary"): - await runner._set_executor_state("executor_a", {"key": "value"}) + await runner._set_executor_state("executor_a", {"key": "value"}) # pyright: ignore[reportPrivateUsage] async def test_runner_with_pre_loop_events(): @@ -779,7 +782,7 @@ class EventEmittingExecutor(Executor): """An executor that emits events during execution.""" @handler - async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, int]) -> None: + async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, str]) -> None: # Emit event during processing await ctx.yield_output(f"processed-{message.data}") if message.data < 3: @@ -831,7 +834,7 @@ async def test_runner_restore_executor_states_no_states(): runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash") # Should complete without error when no executor states exist - await runner._restore_executor_states() + await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage] async def test_runner_checkpoint_with_resumed_flag(): @@ -853,7 +856,7 @@ async def test_runner_checkpoint_with_resumed_flag(): state = State() runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash") - runner._mark_resumed(5) + runner._mark_resumed(5) # pyright: ignore[reportPrivateUsage] # Add a message to trigger the checkpoint creation path await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id="START")) @@ -870,7 +873,7 @@ async def test_runner_checkpoint_with_resumed_flag(): pass # After completing, resumed flag should be reset - assert runner._resumed_from_checkpoint is False + assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage] class ExecutorThatFailsWithEvents(Executor): @@ -883,7 +886,7 @@ class ExecutorThatFailsWithEvents(Executor): self._iteration_count = 0 @handler - async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, int]) -> None: + async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, str]) -> None: self._iteration_count += 1 # First emit an output event to the workflow context await ctx.yield_output(f"output-before-failure-{message.data}") @@ -951,7 +954,7 @@ class SlowEventEmittingExecutor(Executor): self.current_iteration = 0 @handler - async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, int]) -> None: + async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, str]) -> None: self.current_iteration += 1 # Emit output event await ctx.yield_output(f"iteration-{self.current_iteration}") diff --git a/python/packages/core/tests/workflow/test_state.py b/python/packages/core/tests/workflow/test_state.py index 486fc9fa25..7781eb4141 100644 --- a/python/packages/core/tests/workflow/test_state.py +++ b/python/packages/core/tests/workflow/test_state.py @@ -61,9 +61,9 @@ class TestSuperstepCaching: state.set("key", "value") # Value is in pending - assert "key" in state._pending + assert "key" in state._pending # pyright: ignore[reportPrivateUsage] # Value is NOT in committed - assert "key" not in state._committed + assert "key" not in state._committed # pyright: ignore[reportPrivateUsage] # But get() still returns it assert state.get("key") == "value" @@ -72,14 +72,14 @@ class TestSuperstepCaching: state.set("key", "value") # Before commit: in pending, not committed - assert "key" in state._pending - assert "key" not in state._committed + assert "key" in state._pending # pyright: ignore[reportPrivateUsage] + assert "key" not in state._committed # pyright: ignore[reportPrivateUsage] state.commit() # After commit: in committed, pending cleared - assert "key" not in state._pending - assert "key" in state._committed + assert "key" not in state._pending # pyright: ignore[reportPrivateUsage] + assert "key" in state._committed # pyright: ignore[reportPrivateUsage] assert state.get("key") == "value" def test_discard_clears_pending_without_committing(self) -> None: @@ -108,7 +108,7 @@ class TestSuperstepCaching: # get() returns pending value, not committed assert state.get("key") == "pending_value" # But committed still has old value - assert state._committed["key"] == "committed_value" + assert state._committed["key"] == "committed_value" # pyright: ignore[reportPrivateUsage] def test_multiple_sets_before_commit(self) -> None: state = State() @@ -130,13 +130,13 @@ class TestDeleteWithSuperstepCaching: state = State() state.set("key", "value") # Key only in pending, not committed - assert "key" in state._pending - assert "key" not in state._committed + assert "key" in state._pending # pyright: ignore[reportPrivateUsage] + assert "key" not in state._committed # pyright: ignore[reportPrivateUsage] state.delete("key") # Should be removed from pending - assert "key" not in state._pending + assert "key" not in state._pending # pyright: ignore[reportPrivateUsage] assert state.get("key") is None assert state.has("key") is False @@ -148,14 +148,14 @@ class TestDeleteWithSuperstepCaching: state.delete("key") # Key should be marked for deletion in pending (sentinel) - assert "key" in state._pending + assert "key" in state._pending # pyright: ignore[reportPrivateUsage] # get() should return default (not the sentinel!) assert state.get("key") is None assert state.get("key", "default") == "default" # has() should return False assert state.has("key") is False # But committed still has it until commit() - assert "key" in state._committed + assert "key" in state._committed # pyright: ignore[reportPrivateUsage] def test_delete_committed_key_removed_on_commit(self) -> None: state = State() @@ -166,8 +166,8 @@ class TestDeleteWithSuperstepCaching: state.commit() # Now it should be gone from committed too - assert "key" not in state._committed - assert "key" not in state._pending + assert "key" not in state._committed # pyright: ignore[reportPrivateUsage] + assert "key" not in state._pending # pyright: ignore[reportPrivateUsage] def test_delete_key_in_both_pending_and_committed(self) -> None: """Test delete when key exists in both pending (modified) and committed.""" @@ -177,8 +177,8 @@ class TestDeleteWithSuperstepCaching: # Modify the key (now in both pending and committed) state.set("key", "modified") - assert state._pending["key"] == "modified" - assert state._committed["key"] == "original" + assert state._pending["key"] == "modified" # pyright: ignore[reportPrivateUsage] + assert state._committed["key"] == "original" # pyright: ignore[reportPrivateUsage] # Delete should mark for deletion from committed state.delete("key") @@ -189,8 +189,8 @@ class TestDeleteWithSuperstepCaching: # After commit, key should be fully removed state.commit() - assert "key" not in state._committed - assert "key" not in state._pending + assert "key" not in state._committed # pyright: ignore[reportPrivateUsage] + assert "key" not in state._pending # pyright: ignore[reportPrivateUsage] def test_discard_after_delete_restores_committed_value(self) -> None: state = State() @@ -238,12 +238,12 @@ class TestFailureScenarios: state.set("key3", "value3") # Before commit - nothing in committed - assert len(state._committed) == 0 + assert len(state._committed) == 0 # pyright: ignore[reportPrivateUsage] state.commit() # After commit - all three values committed together - assert state._committed == {"key1": "value1", "key2": "value2", "key3": "value3"} + assert state._committed == {"key1": "value1", "key2": "value2", "key3": "value3"} # pyright: ignore[reportPrivateUsage] def test_repeated_supersteps_are_isolated(self) -> None: """Test that each superstep's changes are isolated until committed.""" @@ -300,4 +300,4 @@ class TestExportImport: # Pending is still there assert state.get("pending_key") == "pending_value" - assert "pending_key" in state._pending + assert "pending_key" in state._pending # pyright: ignore[reportPrivateUsage] diff --git a/python/packages/core/tests/workflow/test_typing_utils.py b/python/packages/core/tests/workflow/test_typing_utils.py index 4dc8d8c917..f94bd9d52e 100644 --- a/python/packages/core/tests/workflow/test_typing_utils.py +++ b/python/packages/core/tests/workflow/test_typing_utils.py @@ -36,32 +36,32 @@ def test_normalize_type_to_list_none() -> None: def test_normalize_type_to_list_union_pipe_syntax() -> None: """Test normalize_type_to_list with union types using | syntax.""" - result = normalize_type_to_list(str | int) + result = normalize_type_to_list(str | int) # pyright: ignore[reportArgumentType] assert set(result) == {str, int} - result = normalize_type_to_list(str | int | bool) + result = normalize_type_to_list(str | int | bool) # pyright: ignore[reportArgumentType] assert set(result) == {str, int, bool} def test_normalize_type_to_list_union_typing_syntax() -> None: """Test normalize_type_to_list with Union[] from typing module.""" - result = normalize_type_to_list(Union[str, int]) + result = normalize_type_to_list(Union[str, int]) # pyright: ignore[reportArgumentType] assert set(result) == {str, int} - result = normalize_type_to_list(Union[str, int, bool]) + result = normalize_type_to_list(Union[str, int, bool]) # pyright: ignore[reportArgumentType] assert set(result) == {str, int, bool} def test_normalize_type_to_list_optional() -> None: """Test normalize_type_to_list with Optional types (Union[T, None]).""" # Optional[str] is Union[str, None] - result = normalize_type_to_list(Optional[str]) + result = normalize_type_to_list(Optional[str]) # pyright: ignore[reportArgumentType] assert str in result assert type(None) in result assert len(result) == 2 # str | None is equivalent - result = normalize_type_to_list(str | None) + result = normalize_type_to_list(str | None) # pyright: ignore[reportArgumentType] assert str in result assert type(None) in result assert len(result) == 2 @@ -77,7 +77,7 @@ def test_normalize_type_to_list_custom_types() -> None: result = normalize_type_to_list(CustomMessage) assert result == [CustomMessage] - result = normalize_type_to_list(CustomMessage | str) + result = normalize_type_to_list(CustomMessage | str) # pyright: ignore[reportArgumentType] assert set(result) == {CustomMessage, str} @@ -96,7 +96,7 @@ def test_resolve_type_annotation_actual_types() -> None: """Test resolve_type_annotation passes through actual types unchanged.""" assert resolve_type_annotation(str) is str assert resolve_type_annotation(int) is int - assert resolve_type_annotation(str | int) == str | int + assert resolve_type_annotation(str | int) == str | int # pyright: ignore[reportArgumentType] def test_resolve_type_annotation_string_builtin() -> None: diff --git a/python/packages/core/tests/workflow/test_validation.py b/python/packages/core/tests/workflow/test_validation.py index ae694c8354..be3c8b45f7 100644 --- a/python/packages/core/tests/workflow/test_validation.py +++ b/python/packages/core/tests/workflow/test_validation.py @@ -484,8 +484,8 @@ def test_handler_ctx_missing_annotation_raises() -> None: # Validation now happens at handler registration time, not workflow build time with pytest.raises(ValueError) as exc: - class BadExecutor(Executor): - @handler + class BadExecutor(Executor): # pyright: ignore[reportUnusedClass] + @handler # pyright: ignore[reportUnknownArgumentType] async def handle(self, message: str, ctx) -> None: # type: ignore[no-untyped-def] pass @@ -496,8 +496,8 @@ def test_handler_ctx_invalid_t_out_entries_raises() -> None: # Validation now happens at handler registration time, not workflow build time with pytest.raises(ValueError) as exc: - class BadExecutor(Executor): - @handler + class BadExecutor(Executor): # pyright: ignore[reportUnusedClass] + @handler # pyright: ignore[reportUnknownArgumentType] async def handle(self, message: str, ctx: WorkflowContext[123]) -> None: # type: ignore[valid-type] pass @@ -555,7 +555,7 @@ def test_output_validation_with_valid_output_executors(): ) assert workflow is not None - assert workflow._output_executors == ["executor2"] + assert workflow._output_executors == ["executor2"] # pyright: ignore[reportPrivateUsage] def test_output_validation_with_multiple_valid_output_executors(): @@ -572,7 +572,7 @@ def test_output_validation_with_multiple_valid_output_executors(): ) assert workflow is not None - assert set(workflow._output_executors) == {"executor1", "executor3"} + assert set(workflow._output_executors) == {"executor1", "executor3"} # pyright: ignore[reportPrivateUsage] def test_output_validation_fails_for_nonexistent_executor(): diff --git a/python/packages/core/tests/workflow/test_viz.py b/python/packages/core/tests/workflow/test_viz.py index bf7bbffee1..5573dadd61 100644 --- a/python/packages/core/tests/workflow/test_viz.py +++ b/python/packages/core/tests/workflow/test_viz.py @@ -2,6 +2,9 @@ """Tests for the workflow visualization module.""" +from pathlib import Path +from typing import Any + import pytest from agent_framework import Executor, WorkflowBuilder, WorkflowContext, WorkflowExecutor, WorkflowViz, handler @@ -25,7 +28,7 @@ class ListStrTargetExecutor(Executor): @pytest.fixture -def basic_sub_workflow(): +def basic_sub_workflow() -> dict[str, Any]: """Fixture that creates a basic sub-workflow setup for testing.""" # Create a sub-workflow sub_exec1 = MockExecutor(id="sub_exec1") @@ -98,7 +101,7 @@ def test_workflow_viz_export_dot(): assert '"executor1" -> "executor2"' in content -def test_workflow_viz_export_dot_with_filename(tmp_path): +def test_workflow_viz_export_dot_with_filename(tmp_path: Path): """Test exporting workflow as DOT format with specified filename.""" executor1 = MockExecutor(id="executor1") executor2 = MockExecutor(id="executor2") @@ -203,7 +206,7 @@ def test_workflow_viz_graphviz_binary_not_found(): mock_source_class.return_value = mock_source # Import the ExecutableNotFound exception for the test - from graphviz.backend.execute import ExecutableNotFound + from graphviz.backend.execute import ExecutableNotFound # type: ignore[import-not-found] mock_source.render.side_effect = ExecutableNotFound("failed to execute PosixPath('dot')") @@ -329,7 +332,7 @@ def test_workflow_viz_mermaid_fan_in_edge_group(): assert "s2 --> t" not in mermaid -def test_workflow_viz_sub_workflow_digraph(basic_sub_workflow): +def test_workflow_viz_sub_workflow_digraph(basic_sub_workflow: dict[str, Any]): """Test that WorkflowViz can visualize sub-workflows in DOT format.""" main_workflow = basic_sub_workflow["main_workflow"] @@ -353,7 +356,7 @@ def test_workflow_viz_sub_workflow_digraph(basic_sub_workflow): assert '"workflow_executor_1/sub_exec1" -> "workflow_executor_1/sub_exec2"' in dot_content -def test_workflow_viz_sub_workflow_mermaid(basic_sub_workflow): +def test_workflow_viz_sub_workflow_mermaid(basic_sub_workflow: dict[str, Any]): """Test that WorkflowViz can visualize sub-workflows in Mermaid format.""" main_workflow = basic_sub_workflow["main_workflow"] diff --git a/python/packages/core/tests/workflow/test_workflow.py b/python/packages/core/tests/workflow/test_workflow.py index 8bbf11fa6a..f338ce94f6 100644 --- a/python/packages/core/tests/workflow/test_workflow.py +++ b/python/packages/core/tests/workflow/test_workflow.py @@ -4,7 +4,7 @@ import asyncio import tempfile from collections.abc import AsyncIterable, Awaitable, Sequence from dataclasses import dataclass, field -from typing import Any, cast +from typing import Any, Literal, cast, overload from uuid import uuid4 import pytest @@ -13,6 +13,7 @@ from agent_framework import ( AgentExecutor, AgentResponse, AgentResponseUpdate, + AgentRunInputs, AgentSession, BaseAgent, Content, @@ -474,7 +475,7 @@ class StateTrackingExecutor(Executor): ) -> None: """Handle the message and track it in workflow state.""" # Get existing messages from workflow state - existing_messages = ctx.get_state("processed_messages") or [] + existing_messages: list[str] = ctx.get_state("processed_messages") or [] # Record this message message_record = f"{message.run_id}:{message.data}" @@ -833,6 +834,26 @@ class _StreamingTestAgent(BaseAgent): super().__init__(**kwargs) self._reply_text = reply_text + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, @@ -883,8 +904,10 @@ async def test_agent_streaming_vs_non_streaming() -> None: stream_events.append(event) # Filter for agent events - agent_response = [ - cast(AgentResponse, e.data) for e in stream_events if e.type == "output" and isinstance(e.data, AgentResponse) + agent_response: list[AgentResponse[Any]] = [ + cast(AgentResponse[Any], e.data) # pyright: ignore[reportUnknownMemberType] + for e in stream_events + if e.type == "output" and isinstance(e.data, AgentResponse) ] agent_response_updates = [ e.data for e in stream_events if e.type == "output" and isinstance(e.data, AgentResponseUpdate) diff --git a/python/packages/core/tests/workflow/test_workflow_agent.py b/python/packages/core/tests/workflow/test_workflow_agent.py index d20d60ba3b..b5a8bb9902 100644 --- a/python/packages/core/tests/workflow/test_workflow_agent.py +++ b/python/packages/core/tests/workflow/test_workflow_agent.py @@ -2,7 +2,7 @@ import uuid from collections.abc import Awaitable, Sequence -from typing import Any +from typing import Any, Literal, overload import pytest from typing_extensions import Never @@ -713,6 +713,14 @@ class TestWorkflowAgent: def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() + def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + return AgentSession() + + @overload + def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, @@ -801,6 +809,14 @@ class TestWorkflowAgent: def create_session(self, **kwargs: Any) -> AgentSession: return AgentSession() + def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession: + return AgentSession() + + @overload + def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( self, messages: str | Content | Message | Sequence[str | Content | Message] | None = None, @@ -1207,7 +1223,7 @@ class TestWorkflowAgentMergeUpdates: ] # Compare using role.value for Role enum - actual_sequence_normalized = [(t, r.value if hasattr(r, "value") else r) for t, r in content_sequence] + actual_sequence_normalized = [(t, r.value if hasattr(r, "value") else r) for t, r in content_sequence] # type: ignore[union-attr] assert actual_sequence_normalized == expected_sequence, ( f"FunctionResultContent should come immediately after FunctionCallContent. " diff --git a/python/packages/core/tests/workflow/test_workflow_builder.py b/python/packages/core/tests/workflow/test_workflow_builder.py index 073a24e5a3..3a7b719530 100644 --- a/python/packages/core/tests/workflow/test_workflow_builder.py +++ b/python/packages/core/tests/workflow/test_workflow_builder.py @@ -1,7 +1,8 @@ # Copyright (c) Microsoft. All rights reserved. +from collections.abc import AsyncIterator, Awaitable from dataclasses import dataclass -from typing import Any +from typing import Any, Literal, overload import pytest @@ -9,10 +10,12 @@ from agent_framework import ( AgentExecutor, AgentResponse, AgentResponseUpdate, + AgentRunInputs, AgentSession, BaseAgent, Executor, Message, + ResponseStream, WorkflowBuilder, WorkflowContext, WorkflowValidationError, @@ -21,22 +24,49 @@ from agent_framework import ( class DummyAgent(BaseAgent): - def run(self, messages=None, *, stream: bool = False, session: AgentSession | None = None, **kwargs): # type: ignore[override] + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[False] = ..., + session: AgentSession | None = ..., + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]]: ... + + @overload + def run( + self, + messages: AgentRunInputs | None = ..., + *, + stream: Literal[True], + session: AgentSession | None = ..., + **kwargs: Any, + ) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + + def run( + self, + messages: AgentRunInputs | None = None, + *, + stream: bool = False, + session: AgentSession | None = None, + **kwargs: Any, + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: if stream: - return self._run_stream_impl() + return ResponseStream[AgentResponseUpdate, AgentResponse[Any]](self._run_stream_impl()) return self._run_impl(messages) - async def _run_impl(self, messages=None) -> AgentResponse: + async def _run_impl(self, messages: AgentRunInputs | None = None) -> AgentResponse: norm: list[Message] = [] if messages: - for m in messages: # type: ignore[iteration-over-optional] + for m in messages: # type: ignore[union-attr] if isinstance(m, Message): norm.append(m) elif isinstance(m, str): norm.append(Message(role="user", text=m)) return AgentResponse(messages=norm) - async def _run_stream_impl(self): # type: ignore[override] + async def _run_stream_impl(self) -> AsyncIterator[AgentResponseUpdate]: # Minimal async generator yield AgentResponseUpdate() @@ -202,7 +232,7 @@ def test_with_output_from_returns_builder(): builder = WorkflowBuilder(output_executors=[executor_a], start_executor=executor_a) # Verify builder was created with output_executors - assert builder._output_executors == [executor_a] + assert builder._output_executors == [executor_a] # pyright: ignore[reportPrivateUsage] def test_with_output_from_with_executor_instances(): diff --git a/python/packages/core/tests/workflow/test_workflow_context.py b/python/packages/core/tests/workflow/test_workflow_context.py index 53a7e44903..a13c0b5a55 100644 --- a/python/packages/core/tests/workflow/test_workflow_context.py +++ b/python/packages/core/tests/workflow/test_workflow_context.py @@ -84,7 +84,7 @@ async def test_executor_emits_normal_event() -> None: class _TestEvent(WorkflowEvent): def __init__(self, data: Any = None) -> None: - super().__init__("test_event", data=data) + super().__init__("test_event", data=data) # type: ignore[arg-type] async def test_workflow_context_type_annotations_no_parameter() -> None: @@ -244,8 +244,8 @@ async def test_workflow_context_missing_annotation_error() -> None: # Test class-based executor with missing ctx annotation with pytest.raises(ValueError, match="must have a WorkflowContext"): - class _BadExecutor(Executor): - @handler + class _BadExecutor(Executor): # pyright: ignore[reportUnusedClass] + @handler # pyright: ignore[reportUnknownArgumentType] async def bad_handler(self, text: str, ctx) -> None: # type: ignore[no-untyped-def] pass @@ -264,8 +264,8 @@ async def test_workflow_context_invalid_type_parameter_error() -> None: # Test class-based executor with invalid type parameter with pytest.raises(ValueError, match="invalid type entry"): - class _BadExecutor(Executor): - @handler + class _BadExecutor(Executor): # pyright: ignore[reportUnusedClass] + @handler # pyright: ignore[reportUnknownArgumentType] async def bad_handler(self, text: str, ctx: WorkflowContext[456]) -> None: # type: ignore[valid-type] pass diff --git a/python/packages/core/tests/workflow/test_workflow_kwargs.py b/python/packages/core/tests/workflow/test_workflow_kwargs.py index ce1465effc..0850c6b060 100644 --- a/python/packages/core/tests/workflow/test_workflow_kwargs.py +++ b/python/packages/core/tests/workflow/test_workflow_kwargs.py @@ -1,13 +1,14 @@ # Copyright (c) Microsoft. All rights reserved. -from collections.abc import AsyncIterable, Awaitable, Sequence -from typing import Annotated, Any +from collections.abc import AsyncIterable, Awaitable +from typing import Annotated, Any, Literal, overload import pytest from agent_framework import ( AgentResponse, AgentResponseUpdate, + AgentRunInputs, AgentSession, BaseAgent, Content, @@ -50,14 +51,19 @@ class _KwargsCapturingAgent(BaseAgent): super().__init__(name=name, description="Test agent for kwargs capture") self.captured_kwargs = [] + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( self, - messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: self.captured_kwargs.append(dict(kwargs)) if stream: @@ -83,15 +89,20 @@ class _OptionsAwareAgent(BaseAgent): self.captured_options = [] self.captured_kwargs = [] + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( self, - messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, options: dict[str, Any] | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: self.captured_options.append(dict(options) if options is not None else None) self.captured_kwargs.append(dict(kwargs)) if stream: @@ -189,15 +200,15 @@ async def test_sequential_run_options_does_not_conflict_with_agent_options() -> break assert len(agent.captured_options) >= 1 - captured_options = agent.captured_options[0] + captured_options: dict[str, Any] | None = agent.captured_options[0] assert captured_options is not None assert captured_options.get("store") is False - additional_args = captured_options.get("additional_function_arguments") + additional_args: Any = captured_options.get("additional_function_arguments") assert isinstance(additional_args, dict) - assert additional_args.get("source") == "workflow-options" - assert additional_args.get("custom_data") == custom_data - assert additional_args.get("user_token") == user_token + assert additional_args.get("source") == "workflow-options" # pyright: ignore[reportUnknownMemberType] + assert additional_args.get("custom_data") == custom_data # pyright: ignore[reportUnknownMemberType] + assert additional_args.get("user_token") == user_token # pyright: ignore[reportUnknownMemberType] # "options" should be passed once via the dedicated options parameter, # not duplicated in **kwargs. @@ -225,13 +236,13 @@ async def test_sequential_run_additional_function_arguments_flattened() -> None: break assert len(agent.captured_options) >= 1 - captured_options = agent.captured_options[0] + captured_options: dict[str, Any] | None = agent.captured_options[0] assert captured_options is not None - additional_args = captured_options.get("additional_function_arguments") + additional_args: Any = captured_options.get("additional_function_arguments") assert isinstance(additional_args, dict) - assert additional_args.get("custom_data") == custom_data - assert additional_args.get("user_token") == user_token + assert additional_args.get("custom_data") == custom_data # pyright: ignore[reportUnknownMemberType] + assert additional_args.get("user_token") == user_token # pyright: ignore[reportUnknownMemberType] assert "additional_function_arguments" not in additional_args assert len(agent.captured_kwargs) >= 1 @@ -255,14 +266,14 @@ async def test_sequential_run_additional_function_arguments_merges_with_options( break assert len(agent.captured_options) >= 1 - captured_options = agent.captured_options[0] + captured_options: dict[str, Any] | None = agent.captured_options[0] assert captured_options is not None - additional_args = captured_options.get("additional_function_arguments") + additional_args: Any = captured_options.get("additional_function_arguments") assert isinstance(additional_args, dict) - assert additional_args.get("source") == "workflow-options" - assert additional_args.get("custom_data") == {"session_id": "abc123"} - assert additional_args.get("user_token") == {"user_name": "alice"} + assert additional_args.get("source") == "workflow-options" # pyright: ignore[reportUnknownMemberType] + assert additional_args.get("custom_data") == {"session_id": "abc123"} # pyright: ignore[reportUnknownMemberType] + assert additional_args.get("user_token") == {"user_name": "alice"} # pyright: ignore[reportUnknownMemberType] assert "additional_function_arguments" not in additional_args @@ -463,14 +474,19 @@ async def test_kwargs_preserved_on_response_continuation() -> None: self.captured_kwargs = [] self._asked = False + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( self, - messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: self.captured_kwargs.append(dict(kwargs)) if not self._asked: self._asked = True @@ -521,14 +537,19 @@ async def test_kwargs_overridden_on_response_continuation() -> None: self.captured_kwargs = [] self._asked = False + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( self, - messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: self.captured_kwargs.append(dict(kwargs)) if not self._asked: self._asked = True @@ -583,14 +604,19 @@ async def test_kwargs_empty_value_passed_on_continuation() -> None: self.captured_kwargs = [] self._asked = False + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ... + @overload + def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ... + def run( self, - messages: str | Content | Message | Sequence[str | Content | Message] | None = None, + messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any, - ) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]: + ) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: self.captured_kwargs.append(dict(kwargs)) if not self._asked: self._asked = True @@ -690,8 +716,8 @@ async def test_handoff_kwargs_flow_to_agents() -> None: workflow = ( HandoffBuilder(termination_condition=lambda conv: len(conv) >= 4) - .participants([agent1, agent2]) - .with_start_agent(agent1) + .participants([agent1, agent2]) # type: ignore[list-item] + .with_start_agent(agent1) # type: ignore[arg-type] .with_autonomous_mode() .build() ) diff --git a/python/packages/core/tests/workflow/test_workflow_observability.py b/python/packages/core/tests/workflow/test_workflow_observability.py index b2260abe63..b098fa2771 100644 --- a/python/packages/core/tests/workflow/test_workflow_observability.py +++ b/python/packages/core/tests/workflow/test_workflow_observability.py @@ -109,7 +109,7 @@ async def test_span_creation_and_attributes(span_exporter: InMemorySpanExporter) { "id": "test-workflow-123", "max_iterations": 100, - "model_dump_json": lambda self: '{"id": "test-workflow-123", "type": "mock"}', + "model_dump_json": lambda self: '{"id": "test-workflow-123", "type": "mock"}', # pyright: ignore[reportUnknownLambdaType] }, )(), ) @@ -122,7 +122,7 @@ async def test_span_creation_and_attributes(span_exporter: InMemorySpanExporter) }, ) as workflow_span: workflow_span.add_event(OtelAttr.WORKFLOW_STARTED) - sending_attributes = { + sending_attributes: dict[str, str | int] = { OtelAttr.MESSAGE_TYPE: "ResponseMessage", OtelAttr.MESSAGE_DESTINATION_EXECUTOR_ID: "target-789", } @@ -231,7 +231,7 @@ async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> No @pytest.mark.parametrize("enable_instrumentation", [False], indirect=True) async def test_trace_context_disabled_when_tracing_disabled( - enable_instrumentation, span_exporter: InMemorySpanExporter + enable_instrumentation: bool, span_exporter: InMemorySpanExporter ) -> None: """Test that no trace context is added when tracing is disabled.""" # Tracing should be disabled by default @@ -313,7 +313,7 @@ async def test_end_to_end_workflow_tracing(span_exporter: InMemorySpanExporter) span_exporter.clear() # Run workflow (this should create run spans) - events = [] + events: list[Any] = [] async for event in workflow.run("test input", stream=True): events.append(event) diff --git a/python/packages/core/tests/workflow/test_workflow_states.py b/python/packages/core/tests/workflow/test_workflow_states.py index 0ccf84b103..34c7e8c93f 100644 --- a/python/packages/core/tests/workflow/test_workflow_states.py +++ b/python/packages/core/tests/workflow/test_workflow_states.py @@ -1,5 +1,7 @@ # Copyright (c) Microsoft. All rights reserved. +from typing import Any + import pytest from typing_extensions import Never @@ -36,16 +38,16 @@ async def test_executor_failed_and_workflow_failed_events_streaming(): events.append(ev) # executor_failed event (type='executor_failed') should be emitted before workflow failed event - executor_failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"] + executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"] assert executor_failed_events, "executor_failed event should be emitted when start executor fails" assert executor_failed_events[0].executor_id == "f" assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK # Workflow-level failure and FAILED status should be surfaced - failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"] + failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"] assert failed_events assert all(e.origin is WorkflowEventSource.FRAMEWORK for e in failed_events) - status = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "status"] + status: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "status"] assert status and status[-1].state == WorkflowRunState.FAILED assert all(e.origin is WorkflowEventSource.FRAMEWORK for e in status) @@ -94,13 +96,13 @@ async def test_executor_failed_event_from_second_executor_in_chain(): events.append(ev) # executor_failed event should be emitted for the failing executor - executor_failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"] + executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"] assert executor_failed_events, "executor_failed event should be emitted when second executor fails" assert executor_failed_events[0].executor_id == "failing" assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK # Workflow-level failure should also be surfaced - failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"] + failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"] assert failed_events assert all(e.origin is WorkflowEventSource.FRAMEWORK for e in failed_events) diff --git a/python/pyproject.toml b/python/pyproject.toml index af80756bed..6bd15774a9 100644 --- a/python/pyproject.toml +++ b/python/pyproject.toml @@ -184,7 +184,6 @@ omit = [ [tool.pyright] include = ["agent_framework*"] -exclude = ["**/tests/**", "**/.venv/**", "packages/devui/frontend/**"] typeCheckingMode = "strict" reportUnnecessaryIsInstance = false reportMissingTypeStubs = false