mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix workflow tests pyright warnings (#4362)
* Fix workflow tests pyright warnings * Update uv.lock * Fix pyright * Comments * Update root pyproject pyright setting * Update core pyproject pyright setting * Update core pyproject pyright setting
This commit is contained in:
committed by
GitHub
Unverified
parent
2a20750110
commit
1a8729d5a7
@@ -105,6 +105,7 @@ extend = "../../pyproject.toml"
|
||||
|
||||
[tool.pyright]
|
||||
extends = "../../pyproject.toml"
|
||||
include = ["tests/workflow"]
|
||||
|
||||
[tool.mypy]
|
||||
plugins = ['pydantic.mypy']
|
||||
|
||||
@@ -2,19 +2,20 @@
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterable, Awaitable
|
||||
from typing import TYPE_CHECKING, Any
|
||||
from typing import TYPE_CHECKING, Any, Literal, overload
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
Content,
|
||||
Message,
|
||||
ResponseStream,
|
||||
WorkflowEvent,
|
||||
WorkflowRunState,
|
||||
)
|
||||
from agent_framework._workflows._agent_executor import AgentExecutorResponse
|
||||
@@ -32,26 +33,56 @@ class _CountingAgent(BaseAgent):
|
||||
super().__init__(**kwargs)
|
||||
self.call_count = 0
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | list[str] | list[Message] | None = None,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> (
|
||||
Awaitable[AgentResponse[Any]]
|
||||
| ResponseStream[AgentResponseUpdate, AgentResponse[Any]]
|
||||
):
|
||||
self.call_count += 1
|
||||
if stream:
|
||||
|
||||
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
|
||||
yield AgentResponseUpdate(
|
||||
contents=[Content.from_text(text=f"Response #{self.call_count}: {self.name}")]
|
||||
contents=[
|
||||
Content.from_text(
|
||||
text=f"Response #{self.call_count}: {self.name}"
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates)
|
||||
|
||||
async def _run() -> AgentResponse:
|
||||
return AgentResponse(messages=[Message("assistant", [f"Response #{self.call_count}: {self.name}"])])
|
||||
return AgentResponse(
|
||||
messages=[
|
||||
Message("assistant", [f"Response #{self.call_count}: {self.name}"])
|
||||
]
|
||||
)
|
||||
|
||||
return _run()
|
||||
|
||||
@@ -63,13 +94,36 @@ class _StreamingHookAgent(BaseAgent):
|
||||
super().__init__(**kwargs)
|
||||
self.result_hook_called = False
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | list[str] | list[Message] | None = None,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> (
|
||||
Awaitable[AgentResponse[Any]]
|
||||
| ResponseStream[AgentResponseUpdate, AgentResponse[Any]]
|
||||
):
|
||||
if stream:
|
||||
|
||||
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
|
||||
@@ -78,13 +132,15 @@ class _StreamingHookAgent(BaseAgent):
|
||||
role="assistant",
|
||||
)
|
||||
|
||||
async def _mark_result_hook_called(response: AgentResponse) -> AgentResponse:
|
||||
async def _mark_result_hook_called(
|
||||
response: AgentResponse,
|
||||
) -> AgentResponse:
|
||||
self.result_hook_called = True
|
||||
return response
|
||||
|
||||
return ResponseStream(_stream(), finalizer=AgentResponse.from_updates).with_result_hook(
|
||||
_mark_result_hook_called
|
||||
)
|
||||
return ResponseStream(
|
||||
_stream(), finalizer=AgentResponse.from_updates
|
||||
).with_result_hook(_mark_result_hook_called)
|
||||
|
||||
async def _run() -> AgentResponse:
|
||||
return AgentResponse(messages=[Message("assistant", ["hook test"])])
|
||||
@@ -92,7 +148,9 @@ class _StreamingHookAgent(BaseAgent):
|
||||
return _run()
|
||||
|
||||
|
||||
async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> None:
|
||||
async def test_agent_executor_streaming_finalizes_stream_and_runs_result_hooks() -> (
|
||||
None
|
||||
):
|
||||
"""AgentExecutor should call get_final_response() so stream result hooks execute."""
|
||||
agent = _StreamingHookAgent(id="hook_agent", name="HookAgent")
|
||||
executor = AgentExecutor(agent, id="hook_exec")
|
||||
@@ -159,7 +217,9 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||
|
||||
executor_state = executor_states[executor.id] # type: ignore[index]
|
||||
assert "cache" in executor_state, "Checkpoint should store executor cache state"
|
||||
assert "agent_session" in executor_state, "Checkpoint should store executor session state"
|
||||
assert "agent_session" in executor_state, (
|
||||
"Checkpoint should store executor session state"
|
||||
)
|
||||
|
||||
# Verify session state structure
|
||||
session_state = executor_state["agent_session"] # type: ignore[index]
|
||||
@@ -180,11 +240,15 @@ async def test_agent_executor_checkpoint_stores_and_restores_state() -> None:
|
||||
assert restored_agent.call_count == 0
|
||||
|
||||
# Build new workflow with the restored executor
|
||||
wf_resume = SequentialBuilder(participants=[restored_executor], checkpoint_storage=storage).build()
|
||||
wf_resume = SequentialBuilder(
|
||||
participants=[restored_executor], checkpoint_storage=storage
|
||||
).build()
|
||||
|
||||
# Resume from checkpoint
|
||||
resumed_output: AgentExecutorResponse | None = None
|
||||
async for ev in wf_resume.run(checkpoint_id=restore_checkpoint.checkpoint_id, stream=True):
|
||||
async for ev in wf_resume.run(
|
||||
checkpoint_id=restore_checkpoint.checkpoint_id, stream=True
|
||||
):
|
||||
if ev.type == "output":
|
||||
resumed_output = ev.data # type: ignore[assignment]
|
||||
if ev.type == "status" and ev.state in (
|
||||
@@ -278,7 +342,7 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -
|
||||
workflow = SequentialBuilder(participants=[executor]).build()
|
||||
|
||||
# stream=True at workflow level triggers streaming mode (returns async iterable)
|
||||
events = []
|
||||
events: list[WorkflowEvent] = []
|
||||
async for event in workflow.run("hello", stream=True):
|
||||
events.append(event)
|
||||
assert len(events) > 0
|
||||
@@ -288,10 +352,13 @@ async def test_agent_executor_run_streaming_with_stream_kwarg_does_not_raise() -
|
||||
@pytest.mark.parametrize("reserved_kwarg", ["session", "stream", "messages"])
|
||||
async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str, caplog: "LogCaptureFixture") -> None:
|
||||
"""_prepare_agent_run_args must remove reserved kwargs and log a warning."""
|
||||
raw = {reserved_kwarg: "should-be-stripped", "custom_key": "keep-me"}
|
||||
raw: dict[str, Any] = {
|
||||
reserved_kwarg: "should-be-stripped",
|
||||
"custom_key": "keep-me",
|
||||
}
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
|
||||
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert reserved_kwarg not in run_kwargs
|
||||
assert "custom_key" in run_kwargs
|
||||
@@ -302,8 +369,8 @@ async def test_prepare_agent_run_args_strips_reserved_kwargs(reserved_kwarg: str
|
||||
|
||||
async def test_prepare_agent_run_args_preserves_non_reserved_kwargs() -> None:
|
||||
"""Non-reserved workflow kwargs should pass through unchanged."""
|
||||
raw = {"custom_param": "value", "another": 42}
|
||||
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
|
||||
raw: dict[str, Any] = {"custom_param": "value", "another": 42}
|
||||
run_kwargs, _options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
|
||||
assert run_kwargs["custom_param"] == "value"
|
||||
assert run_kwargs["another"] == 42
|
||||
|
||||
@@ -312,10 +379,10 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once(
|
||||
caplog: "LogCaptureFixture",
|
||||
) -> None:
|
||||
"""All reserved kwargs should be stripped when supplied together, each emitting a warning."""
|
||||
raw = {"session": "x", "stream": True, "messages": [], "custom": 1}
|
||||
raw: dict[str, Any] = {"session": "x", "stream": True, "messages": [], "custom": 1}
|
||||
|
||||
with caplog.at_level(logging.WARNING):
|
||||
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw)
|
||||
run_kwargs, options = AgentExecutor._prepare_agent_run_args(raw) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert "session" not in run_kwargs
|
||||
assert "stream" not in run_kwargs
|
||||
@@ -324,7 +391,11 @@ async def test_prepare_agent_run_args_strips_all_reserved_kwargs_at_once(
|
||||
assert options is not None
|
||||
assert options["additional_function_arguments"]["custom"] == 1
|
||||
|
||||
warned_keys = {r.message.split("'")[1] for r in caplog.records if "reserved" in r.message.lower()}
|
||||
warned_keys = {
|
||||
r.message.split("'")[1]
|
||||
for r in caplog.records
|
||||
if "reserved" in r.message.lower()
|
||||
}
|
||||
assert warned_keys == {"session", "stream", "messages"}
|
||||
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
"""Tests for AgentExecutor handling of tool calls and results in streaming mode."""
|
||||
|
||||
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
from typing_extensions import Never
|
||||
|
||||
@@ -13,6 +13,7 @@ from agent_framework import (
|
||||
AgentExecutorResponse,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
ChatResponse,
|
||||
@@ -37,18 +38,38 @@ class _ToolCallingAgent(BaseAgent):
|
||||
def __init__(self, **kwargs: Any) -> None:
|
||||
super().__init__(**kwargs)
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
if stream:
|
||||
return ResponseStream(self._run_stream_impl(), finalizer=AgentResponse.from_updates)
|
||||
|
||||
async def _run() -> AgentResponse:
|
||||
async def _run() -> AgentResponse[Any]:
|
||||
return AgentResponse(messages=[Message("assistant", ["done"])])
|
||||
|
||||
return _run()
|
||||
@@ -111,6 +132,7 @@ async def test_agent_executor_emits_tool_calls_in_streaming_mode() -> None:
|
||||
# First event: text update
|
||||
assert events[0].data is not None
|
||||
assert events[0].data.contents[0].type == "text"
|
||||
assert events[0].data.contents[0].text is not None
|
||||
assert "Let me search" in events[0].data.contents[0].text
|
||||
|
||||
# Second event: function call
|
||||
@@ -129,6 +151,7 @@ async def test_agent_executor_emits_tool_calls_in_streaming_mode() -> None:
|
||||
# Fourth event: final text
|
||||
assert events[3].data is not None
|
||||
assert events[3].data.contents[0].type == "text"
|
||||
assert events[3].data.contents[0].text is not None
|
||||
assert "sunny" in events[3].data.contents[0].text
|
||||
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable
|
||||
from typing import Any
|
||||
from collections.abc import Awaitable
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
from agent_framework import AgentResponse, AgentResponseUpdate, AgentSession, Message
|
||||
from agent_framework import AgentResponse, AgentResponseUpdate, AgentRunInputs, AgentSession, ResponseStream
|
||||
from agent_framework._workflows._agent_utils import resolve_agent_id
|
||||
|
||||
|
||||
@@ -11,40 +11,23 @@ class MockAgent:
|
||||
"""Mock agent for testing agent utilities."""
|
||||
|
||||
def __init__(self, agent_id: str, name: str | None = None) -> None:
|
||||
self._id = agent_id
|
||||
self._name = name
|
||||
self.id: str = agent_id
|
||||
self.name: str | None = name
|
||||
self.description: str | None = None
|
||||
|
||||
@property
|
||||
def id(self) -> str:
|
||||
return self._id
|
||||
|
||||
@property
|
||||
def name(self) -> str | None:
|
||||
return self._name
|
||||
|
||||
@property
|
||||
def display_name(self) -> str:
|
||||
"""Returns the display name of the agent."""
|
||||
...
|
||||
|
||||
@property
|
||||
def description(self) -> str | None:
|
||||
"""Returns the description of the agent."""
|
||||
...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | list[str] | list[Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentResponse | AsyncIterable[AgentResponseUpdate]: ...
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
def run(self, messages: AgentRunInputs | None = None, *, stream: bool = False, session: AgentSession | None = None, **kwargs: Any) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def create_session(self, **kwargs: Any) -> AgentSession:
|
||||
"""Creates a new conversation session for the agent."""
|
||||
...
|
||||
|
||||
def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
|
||||
return AgentSession()
|
||||
|
||||
|
||||
def test_resolve_agent_id_with_name() -> None:
|
||||
"""Test that resolve_agent_id returns name when agent has a name."""
|
||||
|
||||
@@ -5,6 +5,7 @@ import tempfile
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -24,7 +25,7 @@ class _TestToolApprovalRequest:
|
||||
"""Request data for tool approval in tests."""
|
||||
|
||||
tool_name: str
|
||||
arguments: dict
|
||||
arguments: dict[str, Any]
|
||||
timestamp: datetime
|
||||
|
||||
|
||||
@@ -41,7 +42,7 @@ class _TestApprovalRequest:
|
||||
"""Approval request data for tests."""
|
||||
|
||||
action: str
|
||||
params: tuple
|
||||
params: tuple[Any, ...]
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -78,8 +79,8 @@ def test_workflow_checkpoint_custom_values():
|
||||
workflow_name="test-workflow-456",
|
||||
graph_signature_hash="test-hash-456",
|
||||
timestamp=custom_timestamp,
|
||||
messages={"executor1": [{"data": "test"}]},
|
||||
pending_request_info_events={"req123": {"data": "test"}},
|
||||
messages={"executor1": [{"data": "test"}]}, # type: ignore[arg-type] # raw dict for serialization test
|
||||
pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test
|
||||
state={"key": "value"},
|
||||
iteration_count=5,
|
||||
metadata={"test": True},
|
||||
@@ -103,7 +104,7 @@ def test_workflow_checkpoint_to_dict():
|
||||
checkpoint_id="test-id",
|
||||
workflow_name="test-workflow",
|
||||
graph_signature_hash="test-hash",
|
||||
messages={"executor1": [{"data": "test"}]},
|
||||
messages={"executor1": [{"data": "test"}]}, # type: ignore[arg-type] # raw dict for serialization test
|
||||
state={"key": "value"},
|
||||
iteration_count=5,
|
||||
)
|
||||
@@ -161,8 +162,8 @@ async def test_memory_checkpoint_storage_save_and_load():
|
||||
checkpoint = WorkflowCheckpoint(
|
||||
workflow_name="test-workflow",
|
||||
graph_signature_hash="test-hash",
|
||||
messages={"executor1": [{"data": "hello"}]},
|
||||
pending_request_info_events={"req123": {"data": "test"}},
|
||||
messages={"executor1": [{"data": "hello"}]}, # type: ignore[arg-type] # raw dict for serialization test
|
||||
pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test
|
||||
)
|
||||
|
||||
# Save checkpoint
|
||||
@@ -776,9 +777,9 @@ async def test_file_checkpoint_storage_save_and_load():
|
||||
checkpoint = WorkflowCheckpoint(
|
||||
workflow_name="test-workflow",
|
||||
graph_signature_hash="test-hash",
|
||||
messages={"executor1": [{"data": "hello", "source_id": "test", "target_id": None}]},
|
||||
messages={"executor1": [{"data": "hello", "source_id": "test", "target_id": None}]}, # type: ignore[arg-type] # raw dict for serialization test
|
||||
state={"key": "value"},
|
||||
pending_request_info_events={"req123": {"data": "test"}},
|
||||
pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test
|
||||
)
|
||||
|
||||
# Save checkpoint
|
||||
@@ -904,9 +905,9 @@ async def test_file_checkpoint_storage_json_serialization():
|
||||
checkpoint = WorkflowCheckpoint(
|
||||
workflow_name="test-workflow",
|
||||
graph_signature_hash="test-hash",
|
||||
messages={"executor1": [{"data": {"nested": {"value": 42}}, "source_id": "test", "target_id": None}]},
|
||||
messages={"executor1": [{"data": {"nested": {"value": 42}}, "source_id": "test", "target_id": None}]}, # type: ignore[arg-type] # raw dict for serialization test
|
||||
state={"list": [1, 2, 3], "dict": {"a": "b", "c": {"d": "e"}}, "bool": True, "null": None},
|
||||
pending_request_info_events={"req123": {"data": "test"}},
|
||||
pending_request_info_events={"req123": {"data": "test"}}, # type: ignore[arg-type] # raw dict for serialization test
|
||||
)
|
||||
|
||||
# Save and load
|
||||
|
||||
@@ -3,11 +3,11 @@
|
||||
import json
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework._workflows._checkpoint_encoding import (
|
||||
_PICKLE_MARKER,
|
||||
_TYPE_MARKER,
|
||||
_PICKLE_MARKER, # pyright: ignore[reportPrivateUsage]
|
||||
_TYPE_MARKER, # pyright: ignore[reportPrivateUsage]
|
||||
encode_checkpoint_value,
|
||||
)
|
||||
|
||||
@@ -185,8 +185,9 @@ def test_encode_list_of_dataclasses() -> None:
|
||||
result = encode_checkpoint_value(data)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 2
|
||||
for item in result:
|
||||
result_list = cast(list[Any], result)
|
||||
assert len(result_list) == 2
|
||||
for item in result_list:
|
||||
assert _PICKLE_MARKER in item
|
||||
|
||||
|
||||
|
||||
@@ -4,6 +4,8 @@ from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from opentelemetry.sdk.trace.export.in_memory_span_exporter import InMemorySpanExporter
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
@@ -275,6 +277,7 @@ async def test_single_edge_group_send_message_with_condition_pass() -> None:
|
||||
success = await edge_runner.send_message(message, state, ctx)
|
||||
assert success is True
|
||||
assert target.call_count == 1
|
||||
assert target.last_message is not None
|
||||
assert target.last_message.data == "test"
|
||||
|
||||
|
||||
@@ -301,7 +304,7 @@ async def test_single_edge_group_send_message_with_condition_fail() -> None:
|
||||
assert target.call_count == 0
|
||||
|
||||
|
||||
async def test_single_edge_group_tracing_success(span_exporter) -> None:
|
||||
async def test_single_edge_group_tracing_success(span_exporter: InMemorySpanExporter) -> None:
|
||||
"""Test that single edge group processing creates proper success spans."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
@@ -352,7 +355,7 @@ async def test_single_edge_group_tracing_success(span_exporter) -> None:
|
||||
assert link.context.span_id == int("00f067aa0ba902b7", 16)
|
||||
|
||||
|
||||
async def test_single_edge_group_tracing_condition_failure(span_exporter) -> None:
|
||||
async def test_single_edge_group_tracing_condition_failure(span_exporter: InMemorySpanExporter) -> None:
|
||||
"""Test that single edge group processing creates proper spans for condition failures."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
@@ -386,7 +389,7 @@ async def test_single_edge_group_tracing_condition_failure(span_exporter) -> Non
|
||||
assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DROPPED_CONDITION_FALSE.value
|
||||
|
||||
|
||||
async def test_single_edge_group_tracing_type_mismatch(span_exporter) -> None:
|
||||
async def test_single_edge_group_tracing_type_mismatch(span_exporter: InMemorySpanExporter) -> None:
|
||||
"""Test that single edge group processing creates proper spans for type mismatches."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
@@ -421,7 +424,7 @@ async def test_single_edge_group_tracing_type_mismatch(span_exporter) -> None:
|
||||
assert span.attributes.get("edge_group.delivery_status") == EdgeGroupDeliveryStatus.DROPPED_TYPE_MISMATCH.value
|
||||
|
||||
|
||||
async def test_single_edge_group_tracing_target_mismatch(span_exporter) -> None:
|
||||
async def test_single_edge_group_tracing_target_mismatch(span_exporter: InMemorySpanExporter) -> None:
|
||||
"""Test that single edge group processing creates proper spans for target mismatches."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target = MockExecutor(id="target_executor")
|
||||
@@ -775,7 +778,7 @@ async def test_source_edge_group_with_selection_func_send_message_with_target_in
|
||||
assert success is False
|
||||
|
||||
|
||||
async def test_fan_out_edge_group_tracing_success(span_exporter) -> None:
|
||||
async def test_fan_out_edge_group_tracing_success(span_exporter: InMemorySpanExporter) -> None:
|
||||
"""Test that fan-out edge group processing creates proper success spans."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
@@ -827,7 +830,7 @@ async def test_fan_out_edge_group_tracing_success(span_exporter) -> None:
|
||||
assert link.context.span_id == int("00f067aa0ba902b7", 16)
|
||||
|
||||
|
||||
async def test_fan_out_edge_group_tracing_with_target(span_exporter) -> None:
|
||||
async def test_fan_out_edge_group_tracing_with_target(span_exporter: InMemorySpanExporter) -> None:
|
||||
"""Test that fan-out edge group processing creates proper spans for targeted messages."""
|
||||
source = MockExecutor(id="source_executor")
|
||||
target1 = MockExecutor(id="target_executor_1")
|
||||
@@ -994,7 +997,7 @@ async def test_target_edge_group_send_message_with_invalid_data() -> None:
|
||||
assert success is False
|
||||
|
||||
|
||||
async def test_fan_in_edge_group_tracing_buffered(span_exporter) -> None:
|
||||
async def test_fan_in_edge_group_tracing_buffered(span_exporter: InMemorySpanExporter) -> None:
|
||||
"""Test that fan-in edge group processing creates proper spans for buffered messages."""
|
||||
source1 = MockExecutor(id="source_executor_1")
|
||||
source2 = MockExecutor(id="source_executor_2")
|
||||
@@ -1086,7 +1089,7 @@ async def test_fan_in_edge_group_tracing_buffered(span_exporter) -> None:
|
||||
assert link.context.span_id == int("00f067aa0ba902b8", 16)
|
||||
|
||||
|
||||
async def test_fan_in_edge_group_tracing_type_mismatch(span_exporter) -> None:
|
||||
async def test_fan_in_edge_group_tracing_type_mismatch(span_exporter: InMemorySpanExporter) -> None:
|
||||
"""Test that fan-in edge group processing creates proper spans for type mismatches."""
|
||||
source1 = MockExecutor(id="source_executor_1")
|
||||
source2 = MockExecutor(id="source_executor_2")
|
||||
|
||||
@@ -3,8 +3,6 @@
|
||||
from dataclasses import dataclass
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Never
|
||||
|
||||
from agent_framework import (
|
||||
Executor,
|
||||
Message,
|
||||
@@ -16,6 +14,7 @@ from agent_framework import (
|
||||
handler,
|
||||
response_handler,
|
||||
)
|
||||
from typing_extensions import Never
|
||||
|
||||
|
||||
# Module-level types for string forward reference tests
|
||||
@@ -59,7 +58,7 @@ def test_executor_handler_without_annotations():
|
||||
class MockExecutorWithOneHandlerWithoutAnnotations(Executor): # type: ignore
|
||||
"""A mock executor with one handler that does not implement any annotations."""
|
||||
|
||||
@handler
|
||||
@handler # pyright: ignore[reportUnknownArgumentType]
|
||||
async def handle(self, message, ctx) -> None: # type: ignore
|
||||
"""A mock handler that does not implement any annotations."""
|
||||
pass
|
||||
@@ -156,7 +155,11 @@ async def test_executor_invoked_event_contains_input_data():
|
||||
workflow = WorkflowBuilder(start_executor=upper).add_edge(upper, collector).build()
|
||||
|
||||
events = await workflow.run("hello world")
|
||||
invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"]
|
||||
invoked_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"
|
||||
]
|
||||
|
||||
assert len(invoked_events) == 2
|
||||
|
||||
@@ -190,10 +193,16 @@ async def test_executor_completed_event_contains_sent_messages():
|
||||
sender = MultiSenderExecutor(id="sender")
|
||||
collector = CollectorExecutor(id="collector")
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build()
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor=sender).add_edge(sender, collector).build()
|
||||
)
|
||||
|
||||
events = await workflow.run("hello")
|
||||
completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"]
|
||||
completed_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, WorkflowEvent) and e.type == "executor_completed"
|
||||
]
|
||||
|
||||
# Sender should have completed with the sent messages
|
||||
sender_completed = next(e for e in completed_events if e.executor_id == "sender")
|
||||
@@ -201,7 +210,9 @@ async def test_executor_completed_event_contains_sent_messages():
|
||||
assert sender_completed.data == ["hello-first", "hello-second"]
|
||||
|
||||
# Collector should have completed with no sent messages (None)
|
||||
collector_completed_events = [e for e in completed_events if e.executor_id == "collector"]
|
||||
collector_completed_events = [
|
||||
e for e in completed_events if e.executor_id == "collector"
|
||||
]
|
||||
# Collector is called twice (once per message from sender)
|
||||
assert len(collector_completed_events) == 2
|
||||
for collector_completed in collector_completed_events:
|
||||
@@ -220,7 +231,11 @@ async def test_executor_completed_event_includes_yielded_outputs():
|
||||
workflow = WorkflowBuilder(start_executor=executor).build()
|
||||
|
||||
events = await workflow.run("test")
|
||||
completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"]
|
||||
completed_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, WorkflowEvent) and e.type == "executor_completed"
|
||||
]
|
||||
|
||||
assert len(completed_events) == 1
|
||||
assert completed_events[0].executor_id == "yielder"
|
||||
@@ -248,7 +263,9 @@ async def test_executor_events_with_complex_message_types():
|
||||
|
||||
class ProcessorExecutor(Executor):
|
||||
@handler
|
||||
async def handle(self, request: Request, ctx: WorkflowContext[Response]) -> None:
|
||||
async def handle(
|
||||
self, request: Request, ctx: WorkflowContext[Response]
|
||||
) -> None:
|
||||
response = Response(results=[request.query.upper()] * request.limit)
|
||||
await ctx.send_message(response)
|
||||
|
||||
@@ -260,13 +277,23 @@ async def test_executor_events_with_complex_message_types():
|
||||
processor = ProcessorExecutor(id="processor")
|
||||
collector = CollectorExecutor(id="collector")
|
||||
|
||||
workflow = WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build()
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor=processor).add_edge(processor, collector).build()
|
||||
)
|
||||
|
||||
input_request = Request(query="hello", limit=3)
|
||||
events = await workflow.run(input_request)
|
||||
|
||||
invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"]
|
||||
completed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_completed"]
|
||||
invoked_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"
|
||||
]
|
||||
completed_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, WorkflowEvent) and e.type == "executor_completed"
|
||||
]
|
||||
|
||||
# Check processor invoked event has the Request object
|
||||
processor_invoked = next(e for e in invoked_events if e.executor_id == "processor")
|
||||
@@ -275,7 +302,9 @@ async def test_executor_events_with_complex_message_types():
|
||||
assert processor_invoked.data.limit == 3
|
||||
|
||||
# Check processor completed event has the Response object
|
||||
processor_completed = next(e for e in completed_events if e.executor_id == "processor")
|
||||
processor_completed = next(
|
||||
e for e in completed_events if e.executor_id == "processor"
|
||||
)
|
||||
assert processor_completed.data is not None
|
||||
assert len(processor_completed.data) == 1
|
||||
assert isinstance(processor_completed.data[0], Response)
|
||||
@@ -361,7 +390,9 @@ def test_executor_workflow_output_types_property():
|
||||
# Test executor with union workflow output types
|
||||
class UnionWorkflowOutputExecutor(Executor):
|
||||
@handler
|
||||
async def handle(self, text: str, ctx: WorkflowContext[int, str | bool]) -> None:
|
||||
async def handle(
|
||||
self, text: str, ctx: WorkflowContext[int, str | bool]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
executor = UnionWorkflowOutputExecutor(id="union_workflow_output")
|
||||
@@ -372,11 +403,15 @@ def test_executor_workflow_output_types_property():
|
||||
# Test executor with multiple handlers having different workflow output types
|
||||
class MultiHandlerWorkflowExecutor(Executor):
|
||||
@handler
|
||||
async def handle_string(self, text: str, ctx: WorkflowContext[int, str]) -> None:
|
||||
async def handle_string(
|
||||
self, text: str, ctx: WorkflowContext[int, str]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@handler
|
||||
async def handle_number(self, num: int, ctx: WorkflowContext[bool, float]) -> None:
|
||||
async def handle_number(
|
||||
self, num: int, ctx: WorkflowContext[bool, float]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
executor = MultiHandlerWorkflowExecutor(id="multi_workflow")
|
||||
@@ -430,7 +465,9 @@ def test_executor_output_types_includes_response_handlers():
|
||||
pass
|
||||
|
||||
@response_handler
|
||||
async def handle_response(self, original_request: str, response: bool, ctx: WorkflowContext[float]) -> None:
|
||||
async def handle_response(
|
||||
self, original_request: str, response: bool, ctx: WorkflowContext[float]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
executor = RequestResponseExecutor(id="request_response")
|
||||
@@ -452,7 +489,10 @@ def test_executor_workflow_output_types_includes_response_handlers():
|
||||
|
||||
@response_handler
|
||||
async def handle_response(
|
||||
self, original_request: str, response: bool, ctx: WorkflowContext[float, bool]
|
||||
self,
|
||||
original_request: str,
|
||||
response: bool,
|
||||
ctx: WorkflowContext[float, bool],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -509,7 +549,10 @@ def test_executor_response_handler_union_output_types():
|
||||
|
||||
@response_handler
|
||||
async def handle_response(
|
||||
self, original_request: str, response: bool, ctx: WorkflowContext[int | str | float, bool | int]
|
||||
self,
|
||||
original_request: str,
|
||||
response: bool,
|
||||
ctx: WorkflowContext[int | str | float, bool | int],
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
@@ -531,7 +574,9 @@ async def test_executor_invoked_event_data_not_mutated_by_handler():
|
||||
"""Test that executor_invoked event (type='executor_invoked').data captures original input, not mutated input."""
|
||||
|
||||
@executor(id="Mutator")
|
||||
async def mutator(messages: list[Message], ctx: WorkflowContext[list[Message]]) -> None:
|
||||
async def mutator(
|
||||
messages: list[Message], ctx: WorkflowContext[list[Message]]
|
||||
) -> None:
|
||||
# The handler mutates the input list by appending new messages
|
||||
original_len = len(messages)
|
||||
messages.append(Message(role="assistant", text="Added by executor"))
|
||||
@@ -546,7 +591,11 @@ async def test_executor_invoked_event_data_not_mutated_by_handler():
|
||||
events = await workflow.run(input_messages)
|
||||
|
||||
# Find the invoked event for the Mutator executor
|
||||
invoked_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"]
|
||||
invoked_events = [
|
||||
e
|
||||
for e in events
|
||||
if isinstance(e, WorkflowEvent) and e.type == "executor_invoked"
|
||||
]
|
||||
assert len(invoked_events) == 1
|
||||
mutator_invoked = invoked_events[0]
|
||||
|
||||
@@ -577,8 +626,8 @@ class TestHandlerExplicitTypes:
|
||||
exec_instance = ExplicitInputExecutor(id="explicit_input")
|
||||
|
||||
# Handler should be registered for str (explicit), not Any (introspected)
|
||||
assert str in exec_instance._handlers
|
||||
assert len(exec_instance._handlers) == 1
|
||||
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert len(exec_instance._handlers) == 1 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Can handle str messages
|
||||
assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock"))
|
||||
@@ -596,8 +645,8 @@ class TestHandlerExplicitTypes:
|
||||
exec_instance = ExplicitOutputExecutor(id="explicit_output")
|
||||
|
||||
# Handler spec should have int as output type (explicit)
|
||||
handler_func = exec_instance._handlers[str]
|
||||
assert handler_func._handler_spec["output_types"] == [int]
|
||||
handler_func = exec_instance._handlers[str] # pyright: ignore[reportPrivateUsage]
|
||||
assert handler_func._handler_spec["output_types"] == [int] # pyright: ignore[reportFunctionMemberAccess]
|
||||
|
||||
# Executor output_types property should reflect explicit type
|
||||
assert int in exec_instance.output_types
|
||||
@@ -615,16 +664,20 @@ class TestHandlerExplicitTypes:
|
||||
exec_instance = ExplicitBothExecutor(id="explicit_both")
|
||||
|
||||
# Handler should be registered for dict (explicit input type)
|
||||
assert dict in exec_instance._handlers
|
||||
assert len(exec_instance._handlers) == 1
|
||||
assert dict in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert len(exec_instance._handlers) == 1 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Output type should be list (explicit)
|
||||
handler_func = exec_instance._handlers[dict]
|
||||
assert handler_func._handler_spec["output_types"] == [list]
|
||||
handler_func = exec_instance._handlers[dict] # pyright: ignore[reportPrivateUsage]
|
||||
assert handler_func._handler_spec["output_types"] == [list] # pyright: ignore[reportFunctionMemberAccess]
|
||||
|
||||
# Verify can_handle
|
||||
assert exec_instance.can_handle(WorkflowMessage(data={"key": "value"}, source_id="mock"))
|
||||
assert not exec_instance.can_handle(WorkflowMessage(data="string", source_id="mock"))
|
||||
assert exec_instance.can_handle(
|
||||
WorkflowMessage(data={"key": "value"}, source_id="mock")
|
||||
)
|
||||
assert not exec_instance.can_handle(
|
||||
WorkflowMessage(data="string", source_id="mock")
|
||||
)
|
||||
|
||||
def test_handler_with_explicit_union_input_type(self):
|
||||
"""Test that explicit union input_type is handled correctly."""
|
||||
@@ -639,13 +692,15 @@ class TestHandlerExplicitTypes:
|
||||
|
||||
# Handler should be registered for the union type
|
||||
# The union type itself is stored as the key
|
||||
assert len(exec_instance._handlers) == 1
|
||||
assert len(exec_instance._handlers) == 1 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Can handle both str and int messages
|
||||
assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock"))
|
||||
assert exec_instance.can_handle(WorkflowMessage(data=42, source_id="mock"))
|
||||
# Cannot handle float
|
||||
assert not exec_instance.can_handle(WorkflowMessage(data=3.14, source_id="mock"))
|
||||
assert not exec_instance.can_handle(
|
||||
WorkflowMessage(data=3.14, source_id="mock")
|
||||
)
|
||||
|
||||
def test_handler_with_explicit_union_output_type(self):
|
||||
"""Test that explicit union output is normalized to a list."""
|
||||
@@ -674,8 +729,8 @@ class TestHandlerExplicitTypes:
|
||||
exec_instance = PrecedenceExecutor(id="precedence")
|
||||
|
||||
# Should use explicit input type (bytes), not introspected (str)
|
||||
assert bytes in exec_instance._handlers
|
||||
assert str not in exec_instance._handlers
|
||||
assert bytes in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert str not in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Should use explicit output type (float), not introspected (int)
|
||||
assert float in exec_instance.output_types
|
||||
@@ -692,7 +747,7 @@ class TestHandlerExplicitTypes:
|
||||
exec_instance = IntrospectedExecutor(id="introspected")
|
||||
|
||||
# Should use introspected types
|
||||
assert str in exec_instance._handlers
|
||||
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert int in exec_instance.output_types
|
||||
|
||||
def test_handler_explicit_mode_requires_input(self):
|
||||
@@ -705,13 +760,13 @@ class TestHandlerExplicitTypes:
|
||||
pass
|
||||
|
||||
exec_input = OnlyInputExecutor(id="only_input")
|
||||
assert bytes in exec_input._handlers # Explicit
|
||||
assert bytes in exec_input._handlers # pyright: ignore[reportPrivateUsage] # Explicit
|
||||
assert exec_input.output_types == [] # No output types (not introspected)
|
||||
|
||||
# Only explicit output without input should raise error
|
||||
with pytest.raises(ValueError, match="must specify 'input' type"):
|
||||
|
||||
class OnlyOutputExecutor(Executor):
|
||||
class OnlyOutputExecutor(Executor): # pyright: ignore[reportUnusedClass]
|
||||
@handler(output=float)
|
||||
async def handle(self, message: str, ctx: WorkflowContext[int]) -> None:
|
||||
pass
|
||||
@@ -719,9 +774,11 @@ class TestHandlerExplicitTypes:
|
||||
# Only explicit workflow_output without input should raise error
|
||||
with pytest.raises(ValueError, match="must specify 'input' type"):
|
||||
|
||||
class OnlyWorkflowOutputExecutor(Executor):
|
||||
class OnlyWorkflowOutputExecutor(Executor): # pyright: ignore[reportUnusedClass]
|
||||
@handler(workflow_output=bool)
|
||||
async def handle(self, message: str, ctx: WorkflowContext[int, str]) -> None:
|
||||
async def handle(
|
||||
self, message: str, ctx: WorkflowContext[int, str]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
def test_handler_explicit_input_type_allows_no_message_annotation(self):
|
||||
@@ -734,8 +791,7 @@ class TestHandlerExplicitTypes:
|
||||
|
||||
exec_instance = NoAnnotationExecutor(id="no_annotation")
|
||||
|
||||
# Should work with explicit input_type
|
||||
assert str in exec_instance._handlers
|
||||
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock"))
|
||||
|
||||
def test_handler_multiple_handlers_mixed_explicit_and_introspected(self):
|
||||
@@ -747,15 +803,17 @@ class TestHandlerExplicitTypes:
|
||||
pass
|
||||
|
||||
@handler
|
||||
async def handle_introspected(self, message: float, ctx: WorkflowContext[bool]) -> None:
|
||||
async def handle_introspected(
|
||||
self, message: float, ctx: WorkflowContext[bool]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
exec_instance = MixedExecutor(id="mixed")
|
||||
|
||||
# Should have both handlers
|
||||
assert len(exec_instance._handlers) == 2
|
||||
assert str in exec_instance._handlers # Explicit
|
||||
assert float in exec_instance._handlers # Introspected
|
||||
assert len(exec_instance._handlers) == 2 # pyright: ignore[reportPrivateUsage]
|
||||
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage] # Explicit
|
||||
assert float in exec_instance._handlers # pyright: ignore[reportPrivateUsage] # Introspected
|
||||
|
||||
# Should have both output types
|
||||
assert int in exec_instance.output_types # Explicit
|
||||
@@ -772,8 +830,10 @@ class TestHandlerExplicitTypes:
|
||||
exec_instance = StringRefExecutor(id="string_ref")
|
||||
|
||||
# Should resolve the string to the actual type
|
||||
assert ForwardRefMessage in exec_instance._handlers
|
||||
assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock"))
|
||||
assert ForwardRefMessage in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert exec_instance.can_handle(
|
||||
WorkflowMessage(data=ForwardRefMessage("hello"), source_id="mock")
|
||||
)
|
||||
|
||||
def test_handler_with_string_forward_reference_union(self):
|
||||
"""Test that string forward references work with union types."""
|
||||
@@ -786,8 +846,12 @@ class TestHandlerExplicitTypes:
|
||||
exec_instance = StringUnionExecutor(id="string_union")
|
||||
|
||||
# Should handle both types
|
||||
assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock"))
|
||||
assert exec_instance.can_handle(WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock"))
|
||||
assert exec_instance.can_handle(
|
||||
WorkflowMessage(data=ForwardRefTypeA("hello"), source_id="mock")
|
||||
)
|
||||
assert exec_instance.can_handle(
|
||||
WorkflowMessage(data=ForwardRefTypeB(42), source_id="mock")
|
||||
)
|
||||
|
||||
def test_handler_with_string_forward_reference_output_type(self):
|
||||
"""Test that string forward references work for output_type."""
|
||||
@@ -813,8 +877,8 @@ class TestHandlerExplicitTypes:
|
||||
exec_instance = ExplicitWorkflowOutputExecutor(id="explicit_workflow_output")
|
||||
|
||||
# Handler spec should have bool as workflow_output_type (explicit)
|
||||
handler_func = exec_instance._handlers[str]
|
||||
assert handler_func._handler_spec["workflow_output_types"] == [bool]
|
||||
handler_func = exec_instance._handlers[str] # pyright: ignore[reportPrivateUsage]
|
||||
assert handler_func._handler_spec["workflow_output_types"] == [bool] # pyright: ignore[reportFunctionMemberAccess]
|
||||
|
||||
# Executor workflow_output_types property should reflect explicit type
|
||||
assert bool in exec_instance.workflow_output_types
|
||||
@@ -826,13 +890,14 @@ class TestHandlerExplicitTypes:
|
||||
|
||||
class PrecedenceExecutor(Executor):
|
||||
@handler(input=int, output=float, workflow_output=str)
|
||||
async def handle(self, message: int, ctx: WorkflowContext[int, bool]) -> None:
|
||||
async def handle(
|
||||
self, message: int, ctx: WorkflowContext[int, bool]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
exec_instance = PrecedenceExecutor(id="precedence")
|
||||
|
||||
# All types should come from explicit params
|
||||
assert int in exec_instance._handlers
|
||||
assert int in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert float in exec_instance.output_types
|
||||
assert str in exec_instance.workflow_output_types
|
||||
# Introspected types should NOT be present
|
||||
@@ -849,8 +914,7 @@ class TestHandlerExplicitTypes:
|
||||
|
||||
exec_instance = AllExplicitExecutor(id="all_explicit")
|
||||
|
||||
# Check input type
|
||||
assert str in exec_instance._handlers
|
||||
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert exec_instance.can_handle(WorkflowMessage(data="hello", source_id="mock"))
|
||||
|
||||
# Check output_type
|
||||
@@ -894,7 +958,9 @@ class TestHandlerExplicitTypes:
|
||||
async def handle(self, message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def]
|
||||
pass
|
||||
|
||||
exec_instance = StringUnionWorkflowOutputExecutor(id="string_union_workflow_output")
|
||||
exec_instance = StringUnionWorkflowOutputExecutor(
|
||||
id="string_union_workflow_output"
|
||||
)
|
||||
|
||||
# Should resolve both types from string union
|
||||
assert ForwardRefTypeA in exec_instance.workflow_output_types
|
||||
@@ -905,10 +971,14 @@ class TestHandlerExplicitTypes:
|
||||
|
||||
class IntrospectedWorkflowOutputExecutor(Executor):
|
||||
@handler
|
||||
async def handle(self, message: str, ctx: WorkflowContext[int, bool]) -> None:
|
||||
async def handle(
|
||||
self, message: str, ctx: WorkflowContext[int, bool]
|
||||
) -> None:
|
||||
pass
|
||||
|
||||
exec_instance = IntrospectedWorkflowOutputExecutor(id="introspected_workflow_output")
|
||||
exec_instance = IntrospectedWorkflowOutputExecutor(
|
||||
id="introspected_workflow_output"
|
||||
)
|
||||
|
||||
# Should use introspected types from WorkflowContext[int, bool]
|
||||
assert int in exec_instance.output_types
|
||||
|
||||
@@ -34,8 +34,8 @@ class TestExecutorFutureAnnotations:
|
||||
pass
|
||||
|
||||
exec_instance = MyExecutor(id="test")
|
||||
assert str in exec_instance._handlers
|
||||
spec = exec_instance._handler_specs[0]
|
||||
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] is str
|
||||
assert spec["output_types"] == [MyTypeA]
|
||||
assert spec["workflow_output_types"] == [MyTypeB]
|
||||
@@ -49,8 +49,8 @@ class TestExecutorFutureAnnotations:
|
||||
pass
|
||||
|
||||
exec_instance = MyExecutor(id="test")
|
||||
assert int in exec_instance._handlers
|
||||
spec = exec_instance._handler_specs[0]
|
||||
assert int in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] is int
|
||||
assert spec["output_types"] == [MyTypeA]
|
||||
|
||||
@@ -63,7 +63,7 @@ class TestExecutorFutureAnnotations:
|
||||
pass
|
||||
|
||||
exec_instance = MyExecutor(id="test")
|
||||
spec = exec_instance._handler_specs[0]
|
||||
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] == dict[str, Any]
|
||||
assert spec["output_types"] == [list[str]]
|
||||
|
||||
@@ -76,8 +76,8 @@ class TestExecutorFutureAnnotations:
|
||||
pass
|
||||
|
||||
exec_instance = MyExecutor(id="test")
|
||||
assert str in exec_instance._handlers
|
||||
spec = exec_instance._handler_specs[0]
|
||||
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["output_types"] == []
|
||||
assert spec["workflow_output_types"] == []
|
||||
|
||||
@@ -86,12 +86,12 @@ class TestExecutorFutureAnnotations:
|
||||
|
||||
class MyExecutor(Executor):
|
||||
@handler(input=str, output=MyTypeA)
|
||||
async def example(self, input, ctx) -> None:
|
||||
async def example(self, input, ctx) -> None: # type: ignore[no-untyped-def]
|
||||
pass
|
||||
|
||||
exec_instance = MyExecutor(id="test")
|
||||
assert str in exec_instance._handlers
|
||||
spec = exec_instance._handler_specs[0]
|
||||
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] is str
|
||||
assert spec["output_types"] == [MyTypeA]
|
||||
|
||||
@@ -104,8 +104,8 @@ class TestExecutorFutureAnnotations:
|
||||
pass
|
||||
|
||||
exec_instance = MyExecutor(id="test")
|
||||
assert str in exec_instance._handlers
|
||||
spec = exec_instance._handler_specs[0]
|
||||
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
spec = exec_instance._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["output_types"] == [MyTypeA, MyTypeB]
|
||||
assert spec["workflow_output_types"] == [MyTypeC]
|
||||
|
||||
@@ -118,7 +118,7 @@ class TestExecutorFutureAnnotations:
|
||||
"""
|
||||
with pytest.raises(ValueError):
|
||||
|
||||
class Bad(Executor):
|
||||
@handler
|
||||
async def example(self, input: NonExistentType, ctx: WorkflowContext[MyTypeA, MyTypeB]) -> None: # noqa: F821
|
||||
class Bad(Executor): # pyright: ignore[reportUnusedClass]
|
||||
@handler # pyright: ignore[reportUnknownArgumentType]
|
||||
async def example(self, input: NonExistentType, ctx: WorkflowContext[MyTypeA, MyTypeB]) -> None: # noqa: F821 # type: ignore[name-defined]
|
||||
pass
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Awaitable, Sequence
|
||||
from typing import Any
|
||||
from collections.abc import AsyncIterable, Awaitable
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
import pytest
|
||||
from pydantic import PrivateAttr
|
||||
@@ -13,6 +13,7 @@ from agent_framework import (
|
||||
AgentExecutorResponse,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
Content,
|
||||
@@ -34,14 +35,32 @@ class _SimpleAgent(BaseAgent):
|
||||
super().__init__(**kwargs)
|
||||
self._reply_text = reply_text
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
if stream:
|
||||
|
||||
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
|
||||
@@ -81,14 +100,32 @@ class _ToolHistoryAgent(BaseAgent):
|
||||
Message(role="assistant", contents=[Content.from_text(text=self._summary_text)]),
|
||||
]
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
if stream:
|
||||
|
||||
async def _stream() -> AsyncIterable[AgentResponseUpdate]:
|
||||
@@ -165,14 +202,32 @@ class _CaptureAgent(BaseAgent):
|
||||
super().__init__(**kwargs)
|
||||
self._reply_text = reply_text
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
# Normalize and record messages for verification
|
||||
norm: list[Message] = []
|
||||
if messages:
|
||||
@@ -260,7 +315,7 @@ class _RoundTripCoordinator(Executor):
|
||||
async def handle_response(
|
||||
self,
|
||||
response: AgentExecutorResponse,
|
||||
ctx: WorkflowContext[Never, dict[str, Any]],
|
||||
ctx: WorkflowContext[AgentExecutorRequest, dict[str, Any]],
|
||||
) -> None:
|
||||
self._seen += 1
|
||||
if self._seen == 1:
|
||||
@@ -314,14 +369,32 @@ class _SessionIdCapturingAgent(BaseAgent):
|
||||
|
||||
_captured_service_session_id: str | None = PrivateAttr(default="NOT_CAPTURED")
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
self._captured_service_session_id = session.service_session_id if session else None
|
||||
|
||||
async def _run() -> AgentResponse:
|
||||
@@ -342,7 +415,7 @@ class _FullHistoryReplayCoordinator(Executor):
|
||||
async def handle(
|
||||
self,
|
||||
response: AgentExecutorResponse,
|
||||
ctx: WorkflowContext[Never, Any],
|
||||
ctx: WorkflowContext[AgentExecutorRequest, Any],
|
||||
) -> None:
|
||||
full_conv = list(response.full_conversation or response.agent_response.messages)
|
||||
full_conv.append(Message(role="user", text="follow-up"))
|
||||
|
||||
@@ -48,12 +48,12 @@ class TestFunctionExecutor:
|
||||
func_exec = FunctionExecutor(process_string)
|
||||
|
||||
# Check that handler was registered
|
||||
assert len(func_exec._handlers) == 1
|
||||
assert str in func_exec._handlers
|
||||
assert len(func_exec._handlers) == 1 # pyright: ignore[reportPrivateUsage]
|
||||
assert str in func_exec._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Check handler spec was created
|
||||
assert len(func_exec._handler_specs) == 1
|
||||
spec = func_exec._handler_specs[0]
|
||||
assert len(func_exec._handler_specs) == 1 # pyright: ignore[reportPrivateUsage]
|
||||
spec = func_exec._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["name"] == "process_string"
|
||||
assert spec["message_type"] is str
|
||||
assert spec["output_types"] == [str]
|
||||
@@ -67,10 +67,10 @@ class TestFunctionExecutor:
|
||||
|
||||
assert isinstance(process_int, FunctionExecutor)
|
||||
assert process_int.id == "test_executor"
|
||||
assert int in process_int._handlers
|
||||
assert int in process_int._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Check spec
|
||||
spec = process_int._handler_specs[0]
|
||||
spec = process_int._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] is int
|
||||
assert spec["output_types"] == [int]
|
||||
|
||||
@@ -78,7 +78,7 @@ class TestFunctionExecutor:
|
||||
"""Test @executor decorator uses function name as default ID."""
|
||||
|
||||
@executor
|
||||
async def my_function(data: dict, ctx: WorkflowContext[Any]) -> None:
|
||||
async def my_function(data: dict[str, Any], ctx: WorkflowContext[Any]) -> None:
|
||||
await ctx.send_message(data)
|
||||
|
||||
assert my_function.id == "my_function"
|
||||
@@ -92,7 +92,7 @@ class TestFunctionExecutor:
|
||||
|
||||
assert isinstance(no_parens_function, FunctionExecutor)
|
||||
assert no_parens_function.id == "no_parens_function"
|
||||
assert str in no_parens_function._handlers
|
||||
assert str in no_parens_function._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Also test with single parameter function
|
||||
@executor
|
||||
@@ -101,7 +101,7 @@ class TestFunctionExecutor:
|
||||
|
||||
assert isinstance(simple_no_parens, FunctionExecutor)
|
||||
assert simple_no_parens.id == "simple_no_parens"
|
||||
assert int in simple_no_parens._handlers
|
||||
assert int in simple_no_parens._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_union_output_types(self):
|
||||
"""Test that union output types are properly inferred for both messages and workflow outputs."""
|
||||
@@ -113,7 +113,7 @@ class TestFunctionExecutor:
|
||||
else:
|
||||
await ctx.send_message(text.upper())
|
||||
|
||||
spec = multi_output._handler_specs[0]
|
||||
spec = multi_output._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert set(spec["output_types"]) == {str, int}
|
||||
assert spec["workflow_output_types"] == [] # No workflow outputs defined
|
||||
|
||||
@@ -127,7 +127,7 @@ class TestFunctionExecutor:
|
||||
else:
|
||||
await ctx.yield_output(data.upper())
|
||||
|
||||
workflow_spec = multi_workflow_output._handler_specs[0]
|
||||
workflow_spec = multi_workflow_output._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert workflow_spec["output_types"] == [] # None means no message outputs
|
||||
assert set(workflow_spec["workflow_output_types"]) == {str, int, bool}
|
||||
|
||||
@@ -139,7 +139,7 @@ class TestFunctionExecutor:
|
||||
# This executor doesn't send any messages
|
||||
pass
|
||||
|
||||
spec = no_output._handler_specs[0]
|
||||
spec = no_output._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["output_types"] == []
|
||||
assert spec["workflow_output_types"] == [] # No workflow outputs defined
|
||||
|
||||
@@ -150,7 +150,7 @@ class TestFunctionExecutor:
|
||||
async def any_output(data: str, ctx: WorkflowContext[Any]) -> None:
|
||||
await ctx.send_message("result")
|
||||
|
||||
spec = any_output._handler_specs[0]
|
||||
spec = any_output._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["output_types"] == [Any]
|
||||
assert spec["workflow_output_types"] == [] # No workflow outputs defined
|
||||
|
||||
@@ -160,7 +160,7 @@ class TestFunctionExecutor:
|
||||
await ctx.send_message("message")
|
||||
await ctx.yield_output("workflow_output")
|
||||
|
||||
both_spec = any_both_output._handler_specs[0]
|
||||
both_spec = any_both_output._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert both_spec["output_types"] == [Any]
|
||||
assert both_spec["workflow_output_types"] == [Any]
|
||||
|
||||
@@ -228,11 +228,11 @@ class TestFunctionExecutor:
|
||||
await ctx.yield_output(result)
|
||||
|
||||
# Verify type inference for both executors
|
||||
upper_spec = to_upper._handler_specs[0]
|
||||
upper_spec = to_upper._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert upper_spec["output_types"] == [str]
|
||||
assert upper_spec["workflow_output_types"] == [] # No workflow outputs
|
||||
|
||||
reverse_spec = reverse_text._handler_specs[0]
|
||||
reverse_spec = reverse_text._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert reverse_spec["output_types"] == [Any] # First parameter is Any
|
||||
assert reverse_spec["workflow_output_types"] == [str] # Second parameter is str
|
||||
|
||||
@@ -270,7 +270,7 @@ class TestFunctionExecutor:
|
||||
await ctx.send_message(message)
|
||||
|
||||
with pytest.raises(ValueError, match="Handler for type .* already registered"):
|
||||
func_exec._register_instance_handler(
|
||||
func_exec._register_instance_handler( # pyright: ignore[reportPrivateUsage]
|
||||
name="second",
|
||||
func=second_handler,
|
||||
message_type=str,
|
||||
@@ -287,7 +287,7 @@ class TestFunctionExecutor:
|
||||
result = {item: len(item) for item in items}
|
||||
await ctx.send_message(result)
|
||||
|
||||
spec = process_list._handler_specs[0]
|
||||
spec = process_list._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] == list[str]
|
||||
assert spec["output_types"] == [dict[str, int]]
|
||||
|
||||
@@ -300,10 +300,10 @@ class TestFunctionExecutor:
|
||||
|
||||
assert isinstance(process_simple, FunctionExecutor)
|
||||
assert process_simple.id == "simple_processor"
|
||||
assert str in process_simple._handlers
|
||||
assert str in process_simple._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Check spec - single parameter functions have no output types since they can't send messages
|
||||
spec = process_simple._handler_specs[0]
|
||||
spec = process_simple._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] is str
|
||||
assert spec["output_types"] == []
|
||||
assert spec["ctx_annotation"] is None
|
||||
@@ -316,7 +316,7 @@ class TestFunctionExecutor:
|
||||
return data * 2
|
||||
|
||||
func_exec = FunctionExecutor(valid_single)
|
||||
assert int in func_exec._handlers
|
||||
assert int in func_exec._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Single parameter with missing type annotation should still fail
|
||||
async def no_annotation(data): # type: ignore
|
||||
@@ -349,7 +349,7 @@ class TestFunctionExecutor:
|
||||
|
||||
# For testing purposes, we can check that the handler is registered correctly
|
||||
assert double_value.can_handle(WorkflowMessage(data=5, source_id="mock"))
|
||||
assert int in double_value._handlers
|
||||
assert int in double_value._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_sync_function_basic(self):
|
||||
"""Test basic synchronous function support."""
|
||||
@@ -360,10 +360,10 @@ class TestFunctionExecutor:
|
||||
|
||||
assert isinstance(process_sync, FunctionExecutor)
|
||||
assert process_sync.id == "sync_processor"
|
||||
assert str in process_sync._handlers
|
||||
assert str in process_sync._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Check spec - sync single parameter functions have no output types
|
||||
spec = process_sync._handler_specs[0]
|
||||
spec = process_sync._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] is str
|
||||
assert spec["output_types"] == []
|
||||
assert spec["ctx_annotation"] is None
|
||||
@@ -378,10 +378,10 @@ class TestFunctionExecutor:
|
||||
|
||||
assert isinstance(sync_with_ctx, FunctionExecutor)
|
||||
assert sync_with_ctx.id == "sync_with_ctx"
|
||||
assert int in sync_with_ctx._handlers
|
||||
assert int in sync_with_ctx._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Check spec - sync functions with context can infer output types
|
||||
spec = sync_with_ctx._handler_specs[0]
|
||||
spec = sync_with_ctx._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] is int
|
||||
assert spec["output_types"] == [int]
|
||||
|
||||
@@ -404,18 +404,18 @@ class TestFunctionExecutor:
|
||||
return data.upper()
|
||||
|
||||
func_exec = FunctionExecutor(valid_sync)
|
||||
assert str in func_exec._handlers
|
||||
assert str in func_exec._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Valid sync function with two parameters
|
||||
def valid_sync_with_ctx(data: int, ctx: WorkflowContext[str]):
|
||||
return str(data)
|
||||
|
||||
func_exec2 = FunctionExecutor(valid_sync_with_ctx)
|
||||
assert int in func_exec2._handlers
|
||||
assert int in func_exec2._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Sync function with missing type annotation should still fail
|
||||
def no_annotation(data): # type: ignore
|
||||
return data
|
||||
def no_annotation(data): # type: ignore # pyright: ignore[reportUnknownVariableType]
|
||||
return data # pyright: ignore[reportUnknownVariableType]
|
||||
|
||||
with pytest.raises(ValueError, match="type annotation for the message"):
|
||||
FunctionExecutor(no_annotation) # type: ignore
|
||||
@@ -457,11 +457,11 @@ class TestFunctionExecutor:
|
||||
await ctx.yield_output(result)
|
||||
|
||||
# Verify type inference for sync and async functions
|
||||
sync_spec = to_upper_sync._handler_specs[0]
|
||||
sync_spec = to_upper_sync._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert sync_spec["output_types"] == [str]
|
||||
assert sync_spec["workflow_output_types"] == [] # No workflow outputs
|
||||
|
||||
async_spec = reverse_async._handler_specs[0]
|
||||
async_spec = reverse_async._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert async_spec["output_types"] == [Any] # First parameter is Any
|
||||
assert async_spec["workflow_output_types"] == [str] # Second parameter is str
|
||||
|
||||
@@ -471,8 +471,8 @@ class TestFunctionExecutor:
|
||||
|
||||
# For integration testing, we mainly verify that the handlers are properly registered
|
||||
# and the functions are wrapped correctly
|
||||
assert str in to_upper_sync._handlers
|
||||
assert str in reverse_async._handlers
|
||||
assert str in to_upper_sync._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert str in reverse_async._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
async def test_sync_function_thread_execution(self):
|
||||
"""Test that sync functions run in thread pool and don't block the event loop."""
|
||||
@@ -491,13 +491,13 @@ class TestFunctionExecutor:
|
||||
return data.upper()
|
||||
|
||||
# Verify the function is wrapped and registered
|
||||
assert str in blocking_function._handlers
|
||||
assert str in blocking_function._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# For a more complete test, we'd need to create a full workflow context,
|
||||
# but for now we can verify that the function was properly wrapped
|
||||
# and that sync functions store the correct metadata
|
||||
assert not blocking_function._is_async
|
||||
assert not blocking_function._has_context
|
||||
assert not blocking_function._is_async # pyright: ignore[reportPrivateUsage]
|
||||
assert not blocking_function._has_context # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# The actual thread execution test would require a full workflow setup,
|
||||
# but the important thing is that asyncio.to_thread is used in the wrapper
|
||||
@@ -506,7 +506,7 @@ class TestFunctionExecutor:
|
||||
"""Test that @executor decorator properly rejects @staticmethod with clear error."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
|
||||
class Example:
|
||||
class Example: # pyright: ignore[reportUnusedClass]
|
||||
@executor
|
||||
@staticmethod
|
||||
async def bad_handler(data: str) -> str:
|
||||
@@ -519,7 +519,7 @@ class TestFunctionExecutor:
|
||||
"""Test that @executor decorator properly rejects @classmethod with clear error."""
|
||||
with pytest.raises(ValueError) as exc_info:
|
||||
|
||||
class Example:
|
||||
class Example: # pyright: ignore[reportUnusedClass]
|
||||
@executor
|
||||
@classmethod
|
||||
async def bad_handler(cls, data: str) -> str:
|
||||
@@ -570,8 +570,8 @@ class TestExecutorExplicitTypes:
|
||||
pass
|
||||
|
||||
# Handler should be registered for str (explicit)
|
||||
assert str in process._handlers
|
||||
assert len(process._handlers) == 1
|
||||
assert str in process._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert len(process._handlers) == 1 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Can handle str messages
|
||||
assert process.can_handle(WorkflowMessage(data="hello", source_id="mock"))
|
||||
@@ -586,7 +586,7 @@ class TestExecutorExplicitTypes:
|
||||
pass
|
||||
|
||||
# Handler spec should have int as output type (explicit), not str (introspected)
|
||||
spec = process._handler_specs[0]
|
||||
spec = process._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["output_types"] == [int]
|
||||
|
||||
# Executor output_types property should reflect explicit type
|
||||
@@ -601,11 +601,11 @@ class TestExecutorExplicitTypes:
|
||||
pass
|
||||
|
||||
# Handler should be registered for dict (explicit input type)
|
||||
assert dict in process._handlers
|
||||
assert len(process._handlers) == 1
|
||||
assert dict in process._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert len(process._handlers) == 1 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Output type should be list (explicit)
|
||||
spec = process._handler_specs[0]
|
||||
spec = process._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["output_types"] == [list]
|
||||
|
||||
# Verify can_handle
|
||||
@@ -620,7 +620,7 @@ class TestExecutorExplicitTypes:
|
||||
pass
|
||||
|
||||
# Handler should be registered for the union type
|
||||
assert len(process._handlers) == 1
|
||||
assert len(process._handlers) == 1 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Can handle both str and int messages
|
||||
assert process.can_handle(WorkflowMessage(data="hello", source_id="mock"))
|
||||
@@ -648,8 +648,8 @@ class TestExecutorExplicitTypes:
|
||||
pass
|
||||
|
||||
# Should use explicit input type (bytes), not introspected (str)
|
||||
assert bytes in process._handlers
|
||||
assert str not in process._handlers
|
||||
assert bytes in process._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert str not in process._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Should use explicit output type (float), not introspected (int)
|
||||
assert float in process.output_types
|
||||
@@ -663,7 +663,7 @@ class TestExecutorExplicitTypes:
|
||||
pass
|
||||
|
||||
# Should use introspected types
|
||||
assert str in process._handlers
|
||||
assert str in process._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert int in process.output_types
|
||||
|
||||
def test_executor_partial_explicit_types(self):
|
||||
@@ -674,7 +674,7 @@ class TestExecutorExplicitTypes:
|
||||
async def process_input(message: str, ctx: WorkflowContext[int]) -> None:
|
||||
pass
|
||||
|
||||
assert bytes in process_input._handlers # Explicit
|
||||
assert bytes in process_input._handlers # Explicit # pyright: ignore[reportPrivateUsage]
|
||||
assert int in process_input.output_types # Introspected
|
||||
|
||||
# Only explicit output_type, introspect input_type
|
||||
@@ -682,7 +682,7 @@ class TestExecutorExplicitTypes:
|
||||
async def process_output(message: str, ctx: WorkflowContext[int]) -> None:
|
||||
pass
|
||||
|
||||
assert str in process_output._handlers # Introspected
|
||||
assert str in process_output._handlers # Introspected # pyright: ignore[reportPrivateUsage]
|
||||
assert float in process_output.output_types # Explicit
|
||||
assert int not in process_output.output_types # Not introspected when explicit provided
|
||||
|
||||
@@ -694,7 +694,7 @@ class TestExecutorExplicitTypes:
|
||||
pass
|
||||
|
||||
# Should work with explicit input_type
|
||||
assert str in process._handlers
|
||||
assert str in process._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert process.can_handle(WorkflowMessage(data="hello", source_id="mock"))
|
||||
|
||||
def test_executor_explicit_types_with_id(self):
|
||||
@@ -705,7 +705,7 @@ class TestExecutorExplicitTypes:
|
||||
pass
|
||||
|
||||
assert process.id == "custom_id"
|
||||
assert bytes in process._handlers
|
||||
assert bytes in process._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert int in process.output_types
|
||||
|
||||
def test_executor_explicit_types_with_single_param_function(self):
|
||||
@@ -713,10 +713,10 @@ class TestExecutorExplicitTypes:
|
||||
|
||||
@executor(input=str)
|
||||
async def process(message): # type: ignore[no-untyped-def]
|
||||
return message.upper()
|
||||
return message.upper() # pyright: ignore[reportUnknownMemberType, reportUnknownVariableType]
|
||||
|
||||
# Should work with explicit input_type
|
||||
assert str in process._handlers
|
||||
assert str in process._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert process.can_handle(WorkflowMessage(data="hello", source_id="mock"))
|
||||
assert not process.can_handle(WorkflowMessage(data=42, source_id="mock"))
|
||||
|
||||
@@ -727,7 +727,7 @@ class TestExecutorExplicitTypes:
|
||||
def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def]
|
||||
pass
|
||||
|
||||
assert int in process._handlers
|
||||
assert int in process._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert str in process.output_types
|
||||
|
||||
def test_function_executor_constructor_with_explicit_types(self):
|
||||
@@ -736,10 +736,10 @@ class TestExecutorExplicitTypes:
|
||||
async def process(message, ctx: WorkflowContext) -> None: # type: ignore[no-untyped-def]
|
||||
pass
|
||||
|
||||
func_exec = FunctionExecutor(process, id="test", input=dict, output=list)
|
||||
func_exec = FunctionExecutor(process, id="test", input=dict, output=list) # pyright: ignore[reportUnknownArgumentType]
|
||||
|
||||
assert dict in func_exec._handlers
|
||||
spec = func_exec._handler_specs[0]
|
||||
assert dict in func_exec._handlers # pyright: ignore[reportPrivateUsage]
|
||||
spec = func_exec._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] is dict
|
||||
assert spec["output_types"] == [list]
|
||||
|
||||
@@ -766,7 +766,7 @@ class TestExecutorExplicitTypes:
|
||||
pass
|
||||
|
||||
# Should resolve the string to the actual type
|
||||
assert FuncExecForwardRefMessage in process._handlers
|
||||
assert FuncExecForwardRefMessage in process._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert process.can_handle(WorkflowMessage(data=FuncExecForwardRefMessage("hello"), source_id="mock"))
|
||||
|
||||
def test_executor_with_string_forward_reference_union(self):
|
||||
@@ -798,7 +798,7 @@ class TestExecutorExplicitTypes:
|
||||
pass
|
||||
|
||||
# Handler spec should have bool as workflow_output_type (explicit)
|
||||
spec = process._handler_specs[0]
|
||||
spec = process._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["workflow_output_types"] == [bool]
|
||||
|
||||
# Executor workflow_output_types property should reflect explicit type
|
||||
@@ -826,7 +826,7 @@ class TestExecutorExplicitTypes:
|
||||
pass
|
||||
|
||||
# Check input type
|
||||
assert str in process._handlers
|
||||
assert str in process._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert process.can_handle(WorkflowMessage(data="hello", source_id="mock"))
|
||||
|
||||
# Check output_type
|
||||
@@ -892,6 +892,6 @@ class TestExecutorExplicitTypes:
|
||||
workflow_output=bool,
|
||||
)
|
||||
|
||||
assert str in exec_instance._handlers
|
||||
assert str in exec_instance._handlers # pyright: ignore[reportPrivateUsage]
|
||||
assert int in exec_instance.output_types
|
||||
assert bool in exec_instance.workflow_output_types
|
||||
|
||||
@@ -19,10 +19,10 @@ class TestFunctionExecutorFutureAnnotations:
|
||||
|
||||
assert isinstance(process_future, FunctionExecutor)
|
||||
assert process_future.id == "future_test"
|
||||
assert int in process_future._handlers
|
||||
assert int in process_future._handlers # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Check spec
|
||||
spec = process_future._handler_specs[0]
|
||||
spec = process_future._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] is int
|
||||
assert spec["output_types"] == [int]
|
||||
|
||||
@@ -34,6 +34,6 @@ class TestFunctionExecutorFutureAnnotations:
|
||||
await ctx.send_message(["done"])
|
||||
|
||||
assert isinstance(process_complex, FunctionExecutor)
|
||||
spec = process_complex._handler_specs[0]
|
||||
spec = process_complex._handler_specs[0] # pyright: ignore[reportPrivateUsage]
|
||||
assert spec["message_type"] == dict[str, Any]
|
||||
assert spec["output_types"] == [list[str]]
|
||||
|
||||
@@ -794,7 +794,7 @@ class TestResponseHandlerExplicitTypes:
|
||||
"""Test response_handler with explicit request and response types."""
|
||||
|
||||
@response_handler(request=str, response=int)
|
||||
async def test_handler(self, original_request, response, ctx) -> None:
|
||||
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
|
||||
pass
|
||||
|
||||
spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue]
|
||||
@@ -806,7 +806,7 @@ class TestResponseHandlerExplicitTypes:
|
||||
"""Test response_handler with explicit output and workflow_output types."""
|
||||
|
||||
@response_handler(request=str, response=int, output=bool, workflow_output=float)
|
||||
async def test_handler(self, original_request, response, ctx) -> None:
|
||||
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
|
||||
pass
|
||||
|
||||
spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue]
|
||||
@@ -818,8 +818,8 @@ class TestResponseHandlerExplicitTypes:
|
||||
def test_response_handler_with_union_types(self):
|
||||
"""Test response_handler with union types."""
|
||||
|
||||
@response_handler(request=str | int, response=bool | float)
|
||||
async def test_handler(self, original_request, response, ctx) -> None:
|
||||
@response_handler(request=str | int, response=bool | float) # pyright: ignore[reportArgumentType]
|
||||
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
|
||||
pass
|
||||
|
||||
spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue]
|
||||
@@ -830,7 +830,7 @@ class TestResponseHandlerExplicitTypes:
|
||||
"""Test response_handler with string forward references."""
|
||||
|
||||
@response_handler(request="str", response="int")
|
||||
async def test_handler(self, original_request, response, ctx) -> None:
|
||||
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
|
||||
pass
|
||||
|
||||
spec = test_handler._response_handler_spec # type: ignore[reportAttributeAccessIssue]
|
||||
@@ -842,7 +842,7 @@ class TestResponseHandlerExplicitTypes:
|
||||
with pytest.raises(ValueError, match="must specify 'request' type"):
|
||||
|
||||
@response_handler(response=int)
|
||||
async def test_handler(self, original_request, response, ctx) -> None:
|
||||
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: # pyright: ignore[reportUnusedFunction]
|
||||
pass
|
||||
|
||||
def test_response_handler_explicit_missing_response_raises_error(self):
|
||||
@@ -850,7 +850,7 @@ class TestResponseHandlerExplicitTypes:
|
||||
with pytest.raises(ValueError, match="must specify 'response' type"):
|
||||
|
||||
@response_handler(request=str)
|
||||
async def test_handler(self, original_request, response, ctx) -> None:
|
||||
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: # pyright: ignore[reportUnusedFunction]
|
||||
pass
|
||||
|
||||
def test_response_handler_explicit_only_output_raises_error(self):
|
||||
@@ -858,7 +858,7 @@ class TestResponseHandlerExplicitTypes:
|
||||
with pytest.raises(ValueError, match="must specify 'request' type"):
|
||||
|
||||
@response_handler(output=bool)
|
||||
async def test_handler(self, original_request, response, ctx) -> None:
|
||||
async def test_handler(self: Any, original_request: Any, response: Any, ctx: WorkflowContext) -> None: # pyright: ignore[reportUnusedFunction]
|
||||
pass
|
||||
|
||||
def test_executor_with_explicit_response_handlers(self):
|
||||
@@ -873,7 +873,7 @@ class TestResponseHandlerExplicitTypes:
|
||||
pass
|
||||
|
||||
@response_handler(request=str, response=int, output=bool)
|
||||
async def handle_explicit(self, original_request, response, ctx) -> None:
|
||||
async def handle_explicit(self, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
|
||||
pass
|
||||
|
||||
executor = TestExecutor()
|
||||
@@ -907,7 +907,7 @@ class TestResponseHandlerExplicitTypes:
|
||||
pass
|
||||
|
||||
@response_handler(request=str, response=int)
|
||||
async def handle_response(self, original_request, response, ctx) -> None:
|
||||
async def handle_response(self, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
|
||||
self.handled_request = original_request
|
||||
self.handled_response = response
|
||||
|
||||
@@ -942,7 +942,7 @@ class TestResponseHandlerExplicitTypes:
|
||||
|
||||
# Explicit type handler
|
||||
@response_handler(request=dict, response=bool)
|
||||
async def handle_explicit(self, original_request, response, ctx) -> None:
|
||||
async def handle_explicit(self, original_request: Any, response: Any, ctx: WorkflowContext) -> None:
|
||||
pass
|
||||
|
||||
executor = TestExecutor()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
import pytest
|
||||
@@ -113,7 +114,7 @@ async def test_runner_run_until_convergence():
|
||||
assert result is not None and result == 10
|
||||
|
||||
# iteration count shouldn't be reset after convergence
|
||||
assert runner._iteration == 10 # type: ignore
|
||||
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_run_until_convergence_not_completed():
|
||||
@@ -173,7 +174,7 @@ async def test_runner_run_iteration_preserves_message_order_per_edge_runner() ->
|
||||
for index in range(5):
|
||||
await ctx.send_message(WorkflowMessage(data=MockMessage(data=index), source_id="source"))
|
||||
|
||||
await runner._run_iteration()
|
||||
await runner._run_iteration() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
assert edge_runner.received == [0, 1, 2, 3, 4]
|
||||
|
||||
@@ -213,7 +214,7 @@ async def test_runner_run_iteration_delivers_different_edge_runners_concurrently
|
||||
|
||||
await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id="source"))
|
||||
|
||||
iteration_task = asyncio.create_task(runner._run_iteration())
|
||||
iteration_task = asyncio.create_task(runner._run_iteration()) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
await blocking_edge_runner.started.wait()
|
||||
await asyncio.wait_for(probe_edge_runner.probe_completed.wait(), timeout=2.0)
|
||||
@@ -280,7 +281,7 @@ async def test_fanout_edge_runner_delivers_to_multiple_targets_concurrently() ->
|
||||
# Queue a message from source (will be delivered to both targets via FanOut)
|
||||
await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id=source.id))
|
||||
|
||||
iteration_task = asyncio.create_task(runner._run_iteration())
|
||||
iteration_task = asyncio.create_task(runner._run_iteration()) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Wait for the blocking executor to start
|
||||
await blocking_target.started.wait()
|
||||
@@ -477,11 +478,11 @@ async def test_runner_reset_iteration_count():
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
runner._iteration = 10
|
||||
runner._iteration = 10 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
runner.reset_iteration_count()
|
||||
|
||||
assert runner._iteration == 0
|
||||
assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
class CheckpointingContext(InProcRunnerContext):
|
||||
@@ -501,18 +502,19 @@ class CheckpointingContext(InProcRunnerContext):
|
||||
graph_signature_hash: str,
|
||||
state: State,
|
||||
previous_checkpoint_id: str | None,
|
||||
iteration: int,
|
||||
iteration_count: int,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
checkpoint = WorkflowCheckpoint(
|
||||
workflow_name=workflow_name,
|
||||
graph_signature_hash=graph_signature_hash,
|
||||
state=state.export(),
|
||||
state=state.export_state(),
|
||||
previous_checkpoint_id=previous_checkpoint_id,
|
||||
iteration_count=iteration,
|
||||
iteration_count=iteration_count,
|
||||
)
|
||||
return await self._storage.save(checkpoint)
|
||||
|
||||
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None:
|
||||
async def load_checkpoint(self, checkpoint_id: str) -> WorkflowCheckpoint | None: # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
try:
|
||||
return await self._storage.load(checkpoint_id)
|
||||
except WorkflowCheckpointException:
|
||||
@@ -537,7 +539,8 @@ class FailingCheckpointContext(InProcRunnerContext):
|
||||
graph_signature_hash: str,
|
||||
state: State,
|
||||
previous_checkpoint_id: str | None,
|
||||
iteration: int,
|
||||
iteration_count: int,
|
||||
metadata: dict[str, Any] | None = None,
|
||||
) -> str:
|
||||
raise RuntimeError("Simulated checkpoint failure")
|
||||
|
||||
@@ -609,8 +612,8 @@ async def test_runner_restore_from_checkpoint_with_external_storage():
|
||||
# Restore using external storage
|
||||
await runner.restore_from_checkpoint(checkpoint_id, checkpoint_storage=storage)
|
||||
|
||||
assert runner._resumed_from_checkpoint is True
|
||||
assert runner._iteration == 5
|
||||
assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage]
|
||||
assert runner._iteration == 5 # pyright: ignore[reportPrivateUsage]
|
||||
assert state.get("test_key") == "test_value"
|
||||
|
||||
|
||||
@@ -684,7 +687,7 @@ async def test_runner_restore_executor_states_invalid_states_type():
|
||||
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
with pytest.raises(WorkflowCheckpointException, match="not a dictionary"):
|
||||
await runner._restore_executor_states()
|
||||
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_restore_executor_states_invalid_executor_id_type():
|
||||
@@ -698,7 +701,7 @@ async def test_runner_restore_executor_states_invalid_executor_id_type():
|
||||
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
with pytest.raises(WorkflowCheckpointException, match="not a string"):
|
||||
await runner._restore_executor_states()
|
||||
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_restore_executor_states_invalid_state_type():
|
||||
@@ -712,7 +715,7 @@ async def test_runner_restore_executor_states_invalid_state_type():
|
||||
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
with pytest.raises(WorkflowCheckpointException, match="not a dict"):
|
||||
await runner._restore_executor_states()
|
||||
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_restore_executor_states_invalid_state_keys():
|
||||
@@ -726,7 +729,7 @@ async def test_runner_restore_executor_states_invalid_state_keys():
|
||||
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
with pytest.raises(WorkflowCheckpointException, match="not a dict"):
|
||||
await runner._restore_executor_states()
|
||||
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_restore_executor_states_missing_executor():
|
||||
@@ -739,7 +742,7 @@ async def test_runner_restore_executor_states_missing_executor():
|
||||
runner = Runner([], {}, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
with pytest.raises(WorkflowCheckpointException, match="not found during state restoration"):
|
||||
await runner._restore_executor_states()
|
||||
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_set_executor_state_invalid_existing_states():
|
||||
@@ -752,7 +755,7 @@ async def test_runner_set_executor_state_invalid_existing_states():
|
||||
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
with pytest.raises(WorkflowCheckpointException, match="not a dictionary"):
|
||||
await runner._set_executor_state("executor_a", {"key": "value"})
|
||||
await runner._set_executor_state("executor_a", {"key": "value"}) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_with_pre_loop_events():
|
||||
@@ -779,7 +782,7 @@ class EventEmittingExecutor(Executor):
|
||||
"""An executor that emits events during execution."""
|
||||
|
||||
@handler
|
||||
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, int]) -> None:
|
||||
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, str]) -> None:
|
||||
# Emit event during processing
|
||||
await ctx.yield_output(f"processed-{message.data}")
|
||||
if message.data < 3:
|
||||
@@ -831,7 +834,7 @@ async def test_runner_restore_executor_states_no_states():
|
||||
runner = Runner([], {executor_a.id: executor_a}, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
# Should complete without error when no executor states exist
|
||||
await runner._restore_executor_states()
|
||||
await runner._restore_executor_states() # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_checkpoint_with_resumed_flag():
|
||||
@@ -853,7 +856,7 @@ async def test_runner_checkpoint_with_resumed_flag():
|
||||
state = State()
|
||||
|
||||
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
runner._mark_resumed(5)
|
||||
runner._mark_resumed(5) # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Add a message to trigger the checkpoint creation path
|
||||
await ctx.send_message(WorkflowMessage(data=MockMessage(data=8), source_id="START"))
|
||||
@@ -870,7 +873,7 @@ async def test_runner_checkpoint_with_resumed_flag():
|
||||
pass
|
||||
|
||||
# After completing, resumed flag should be reset
|
||||
assert runner._resumed_from_checkpoint is False
|
||||
assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
class ExecutorThatFailsWithEvents(Executor):
|
||||
@@ -883,7 +886,7 @@ class ExecutorThatFailsWithEvents(Executor):
|
||||
self._iteration_count = 0
|
||||
|
||||
@handler
|
||||
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, int]) -> None:
|
||||
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, str]) -> None:
|
||||
self._iteration_count += 1
|
||||
# First emit an output event to the workflow context
|
||||
await ctx.yield_output(f"output-before-failure-{message.data}")
|
||||
@@ -951,7 +954,7 @@ class SlowEventEmittingExecutor(Executor):
|
||||
self.current_iteration = 0
|
||||
|
||||
@handler
|
||||
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, int]) -> None:
|
||||
async def handle(self, message: MockMessage, ctx: WorkflowContext[MockMessage, str]) -> None:
|
||||
self.current_iteration += 1
|
||||
# Emit output event
|
||||
await ctx.yield_output(f"iteration-{self.current_iteration}")
|
||||
|
||||
@@ -61,9 +61,9 @@ class TestSuperstepCaching:
|
||||
state.set("key", "value")
|
||||
|
||||
# Value is in pending
|
||||
assert "key" in state._pending
|
||||
assert "key" in state._pending # pyright: ignore[reportPrivateUsage]
|
||||
# Value is NOT in committed
|
||||
assert "key" not in state._committed
|
||||
assert "key" not in state._committed # pyright: ignore[reportPrivateUsage]
|
||||
# But get() still returns it
|
||||
assert state.get("key") == "value"
|
||||
|
||||
@@ -72,14 +72,14 @@ class TestSuperstepCaching:
|
||||
state.set("key", "value")
|
||||
|
||||
# Before commit: in pending, not committed
|
||||
assert "key" in state._pending
|
||||
assert "key" not in state._committed
|
||||
assert "key" in state._pending # pyright: ignore[reportPrivateUsage]
|
||||
assert "key" not in state._committed # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
state.commit()
|
||||
|
||||
# After commit: in committed, pending cleared
|
||||
assert "key" not in state._pending
|
||||
assert "key" in state._committed
|
||||
assert "key" not in state._pending # pyright: ignore[reportPrivateUsage]
|
||||
assert "key" in state._committed # pyright: ignore[reportPrivateUsage]
|
||||
assert state.get("key") == "value"
|
||||
|
||||
def test_discard_clears_pending_without_committing(self) -> None:
|
||||
@@ -108,7 +108,7 @@ class TestSuperstepCaching:
|
||||
# get() returns pending value, not committed
|
||||
assert state.get("key") == "pending_value"
|
||||
# But committed still has old value
|
||||
assert state._committed["key"] == "committed_value"
|
||||
assert state._committed["key"] == "committed_value" # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_multiple_sets_before_commit(self) -> None:
|
||||
state = State()
|
||||
@@ -130,13 +130,13 @@ class TestDeleteWithSuperstepCaching:
|
||||
state = State()
|
||||
state.set("key", "value")
|
||||
# Key only in pending, not committed
|
||||
assert "key" in state._pending
|
||||
assert "key" not in state._committed
|
||||
assert "key" in state._pending # pyright: ignore[reportPrivateUsage]
|
||||
assert "key" not in state._committed # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
state.delete("key")
|
||||
|
||||
# Should be removed from pending
|
||||
assert "key" not in state._pending
|
||||
assert "key" not in state._pending # pyright: ignore[reportPrivateUsage]
|
||||
assert state.get("key") is None
|
||||
assert state.has("key") is False
|
||||
|
||||
@@ -148,14 +148,14 @@ class TestDeleteWithSuperstepCaching:
|
||||
state.delete("key")
|
||||
|
||||
# Key should be marked for deletion in pending (sentinel)
|
||||
assert "key" in state._pending
|
||||
assert "key" in state._pending # pyright: ignore[reportPrivateUsage]
|
||||
# get() should return default (not the sentinel!)
|
||||
assert state.get("key") is None
|
||||
assert state.get("key", "default") == "default"
|
||||
# has() should return False
|
||||
assert state.has("key") is False
|
||||
# But committed still has it until commit()
|
||||
assert "key" in state._committed
|
||||
assert "key" in state._committed # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_delete_committed_key_removed_on_commit(self) -> None:
|
||||
state = State()
|
||||
@@ -166,8 +166,8 @@ class TestDeleteWithSuperstepCaching:
|
||||
state.commit()
|
||||
|
||||
# Now it should be gone from committed too
|
||||
assert "key" not in state._committed
|
||||
assert "key" not in state._pending
|
||||
assert "key" not in state._committed # pyright: ignore[reportPrivateUsage]
|
||||
assert "key" not in state._pending # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_delete_key_in_both_pending_and_committed(self) -> None:
|
||||
"""Test delete when key exists in both pending (modified) and committed."""
|
||||
@@ -177,8 +177,8 @@ class TestDeleteWithSuperstepCaching:
|
||||
|
||||
# Modify the key (now in both pending and committed)
|
||||
state.set("key", "modified")
|
||||
assert state._pending["key"] == "modified"
|
||||
assert state._committed["key"] == "original"
|
||||
assert state._pending["key"] == "modified" # pyright: ignore[reportPrivateUsage]
|
||||
assert state._committed["key"] == "original" # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
# Delete should mark for deletion from committed
|
||||
state.delete("key")
|
||||
@@ -189,8 +189,8 @@ class TestDeleteWithSuperstepCaching:
|
||||
|
||||
# After commit, key should be fully removed
|
||||
state.commit()
|
||||
assert "key" not in state._committed
|
||||
assert "key" not in state._pending
|
||||
assert "key" not in state._committed # pyright: ignore[reportPrivateUsage]
|
||||
assert "key" not in state._pending # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_discard_after_delete_restores_committed_value(self) -> None:
|
||||
state = State()
|
||||
@@ -238,12 +238,12 @@ class TestFailureScenarios:
|
||||
state.set("key3", "value3")
|
||||
|
||||
# Before commit - nothing in committed
|
||||
assert len(state._committed) == 0
|
||||
assert len(state._committed) == 0 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
state.commit()
|
||||
|
||||
# After commit - all three values committed together
|
||||
assert state._committed == {"key1": "value1", "key2": "value2", "key3": "value3"}
|
||||
assert state._committed == {"key1": "value1", "key2": "value2", "key3": "value3"} # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
def test_repeated_supersteps_are_isolated(self) -> None:
|
||||
"""Test that each superstep's changes are isolated until committed."""
|
||||
@@ -300,4 +300,4 @@ class TestExportImport:
|
||||
|
||||
# Pending is still there
|
||||
assert state.get("pending_key") == "pending_value"
|
||||
assert "pending_key" in state._pending
|
||||
assert "pending_key" in state._pending # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
@@ -36,32 +36,32 @@ def test_normalize_type_to_list_none() -> None:
|
||||
|
||||
def test_normalize_type_to_list_union_pipe_syntax() -> None:
|
||||
"""Test normalize_type_to_list with union types using | syntax."""
|
||||
result = normalize_type_to_list(str | int)
|
||||
result = normalize_type_to_list(str | int) # pyright: ignore[reportArgumentType]
|
||||
assert set(result) == {str, int}
|
||||
|
||||
result = normalize_type_to_list(str | int | bool)
|
||||
result = normalize_type_to_list(str | int | bool) # pyright: ignore[reportArgumentType]
|
||||
assert set(result) == {str, int, bool}
|
||||
|
||||
|
||||
def test_normalize_type_to_list_union_typing_syntax() -> None:
|
||||
"""Test normalize_type_to_list with Union[] from typing module."""
|
||||
result = normalize_type_to_list(Union[str, int])
|
||||
result = normalize_type_to_list(Union[str, int]) # pyright: ignore[reportArgumentType]
|
||||
assert set(result) == {str, int}
|
||||
|
||||
result = normalize_type_to_list(Union[str, int, bool])
|
||||
result = normalize_type_to_list(Union[str, int, bool]) # pyright: ignore[reportArgumentType]
|
||||
assert set(result) == {str, int, bool}
|
||||
|
||||
|
||||
def test_normalize_type_to_list_optional() -> None:
|
||||
"""Test normalize_type_to_list with Optional types (Union[T, None])."""
|
||||
# Optional[str] is Union[str, None]
|
||||
result = normalize_type_to_list(Optional[str])
|
||||
result = normalize_type_to_list(Optional[str]) # pyright: ignore[reportArgumentType]
|
||||
assert str in result
|
||||
assert type(None) in result
|
||||
assert len(result) == 2
|
||||
|
||||
# str | None is equivalent
|
||||
result = normalize_type_to_list(str | None)
|
||||
result = normalize_type_to_list(str | None) # pyright: ignore[reportArgumentType]
|
||||
assert str in result
|
||||
assert type(None) in result
|
||||
assert len(result) == 2
|
||||
@@ -77,7 +77,7 @@ def test_normalize_type_to_list_custom_types() -> None:
|
||||
result = normalize_type_to_list(CustomMessage)
|
||||
assert result == [CustomMessage]
|
||||
|
||||
result = normalize_type_to_list(CustomMessage | str)
|
||||
result = normalize_type_to_list(CustomMessage | str) # pyright: ignore[reportArgumentType]
|
||||
assert set(result) == {CustomMessage, str}
|
||||
|
||||
|
||||
@@ -96,7 +96,7 @@ def test_resolve_type_annotation_actual_types() -> None:
|
||||
"""Test resolve_type_annotation passes through actual types unchanged."""
|
||||
assert resolve_type_annotation(str) is str
|
||||
assert resolve_type_annotation(int) is int
|
||||
assert resolve_type_annotation(str | int) == str | int
|
||||
assert resolve_type_annotation(str | int) == str | int # pyright: ignore[reportArgumentType]
|
||||
|
||||
|
||||
def test_resolve_type_annotation_string_builtin() -> None:
|
||||
|
||||
@@ -484,8 +484,8 @@ def test_handler_ctx_missing_annotation_raises() -> None:
|
||||
# Validation now happens at handler registration time, not workflow build time
|
||||
with pytest.raises(ValueError) as exc:
|
||||
|
||||
class BadExecutor(Executor):
|
||||
@handler
|
||||
class BadExecutor(Executor): # pyright: ignore[reportUnusedClass]
|
||||
@handler # pyright: ignore[reportUnknownArgumentType]
|
||||
async def handle(self, message: str, ctx) -> None: # type: ignore[no-untyped-def]
|
||||
pass
|
||||
|
||||
@@ -496,8 +496,8 @@ def test_handler_ctx_invalid_t_out_entries_raises() -> None:
|
||||
# Validation now happens at handler registration time, not workflow build time
|
||||
with pytest.raises(ValueError) as exc:
|
||||
|
||||
class BadExecutor(Executor):
|
||||
@handler
|
||||
class BadExecutor(Executor): # pyright: ignore[reportUnusedClass]
|
||||
@handler # pyright: ignore[reportUnknownArgumentType]
|
||||
async def handle(self, message: str, ctx: WorkflowContext[123]) -> None: # type: ignore[valid-type]
|
||||
pass
|
||||
|
||||
@@ -555,7 +555,7 @@ def test_output_validation_with_valid_output_executors():
|
||||
)
|
||||
|
||||
assert workflow is not None
|
||||
assert workflow._output_executors == ["executor2"]
|
||||
assert workflow._output_executors == ["executor2"] # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_output_validation_with_multiple_valid_output_executors():
|
||||
@@ -572,7 +572,7 @@ def test_output_validation_with_multiple_valid_output_executors():
|
||||
)
|
||||
|
||||
assert workflow is not None
|
||||
assert set(workflow._output_executors) == {"executor1", "executor3"}
|
||||
assert set(workflow._output_executors) == {"executor1", "executor3"} # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_output_validation_fails_for_nonexistent_executor():
|
||||
|
||||
@@ -2,6 +2,9 @@
|
||||
|
||||
"""Tests for the workflow visualization module."""
|
||||
|
||||
from pathlib import Path
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import Executor, WorkflowBuilder, WorkflowContext, WorkflowExecutor, WorkflowViz, handler
|
||||
@@ -25,7 +28,7 @@ class ListStrTargetExecutor(Executor):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def basic_sub_workflow():
|
||||
def basic_sub_workflow() -> dict[str, Any]:
|
||||
"""Fixture that creates a basic sub-workflow setup for testing."""
|
||||
# Create a sub-workflow
|
||||
sub_exec1 = MockExecutor(id="sub_exec1")
|
||||
@@ -98,7 +101,7 @@ def test_workflow_viz_export_dot():
|
||||
assert '"executor1" -> "executor2"' in content
|
||||
|
||||
|
||||
def test_workflow_viz_export_dot_with_filename(tmp_path):
|
||||
def test_workflow_viz_export_dot_with_filename(tmp_path: Path):
|
||||
"""Test exporting workflow as DOT format with specified filename."""
|
||||
executor1 = MockExecutor(id="executor1")
|
||||
executor2 = MockExecutor(id="executor2")
|
||||
@@ -203,7 +206,7 @@ def test_workflow_viz_graphviz_binary_not_found():
|
||||
mock_source_class.return_value = mock_source
|
||||
|
||||
# Import the ExecutableNotFound exception for the test
|
||||
from graphviz.backend.execute import ExecutableNotFound
|
||||
from graphviz.backend.execute import ExecutableNotFound # type: ignore[import-not-found]
|
||||
|
||||
mock_source.render.side_effect = ExecutableNotFound("failed to execute PosixPath('dot')")
|
||||
|
||||
@@ -329,7 +332,7 @@ def test_workflow_viz_mermaid_fan_in_edge_group():
|
||||
assert "s2 --> t" not in mermaid
|
||||
|
||||
|
||||
def test_workflow_viz_sub_workflow_digraph(basic_sub_workflow):
|
||||
def test_workflow_viz_sub_workflow_digraph(basic_sub_workflow: dict[str, Any]):
|
||||
"""Test that WorkflowViz can visualize sub-workflows in DOT format."""
|
||||
main_workflow = basic_sub_workflow["main_workflow"]
|
||||
|
||||
@@ -353,7 +356,7 @@ def test_workflow_viz_sub_workflow_digraph(basic_sub_workflow):
|
||||
assert '"workflow_executor_1/sub_exec1" -> "workflow_executor_1/sub_exec2"' in dot_content
|
||||
|
||||
|
||||
def test_workflow_viz_sub_workflow_mermaid(basic_sub_workflow):
|
||||
def test_workflow_viz_sub_workflow_mermaid(basic_sub_workflow: dict[str, Any]):
|
||||
"""Test that WorkflowViz can visualize sub-workflows in Mermaid format."""
|
||||
main_workflow = basic_sub_workflow["main_workflow"]
|
||||
|
||||
|
||||
@@ -4,7 +4,7 @@ import asyncio
|
||||
import tempfile
|
||||
from collections.abc import AsyncIterable, Awaitable, Sequence
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Any, cast
|
||||
from typing import Any, Literal, cast, overload
|
||||
from uuid import uuid4
|
||||
|
||||
import pytest
|
||||
@@ -13,6 +13,7 @@ from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
Content,
|
||||
@@ -474,7 +475,7 @@ class StateTrackingExecutor(Executor):
|
||||
) -> None:
|
||||
"""Handle the message and track it in workflow state."""
|
||||
# Get existing messages from workflow state
|
||||
existing_messages = ctx.get_state("processed_messages") or []
|
||||
existing_messages: list[str] = ctx.get_state("processed_messages") or []
|
||||
|
||||
# Record this message
|
||||
message_record = f"{message.run_id}:{message.data}"
|
||||
@@ -833,6 +834,26 @@ class _StreamingTestAgent(BaseAgent):
|
||||
super().__init__(**kwargs)
|
||||
self._reply_text = reply_text
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
@@ -883,8 +904,10 @@ async def test_agent_streaming_vs_non_streaming() -> None:
|
||||
stream_events.append(event)
|
||||
|
||||
# Filter for agent events
|
||||
agent_response = [
|
||||
cast(AgentResponse, e.data) for e in stream_events if e.type == "output" and isinstance(e.data, AgentResponse)
|
||||
agent_response: list[AgentResponse[Any]] = [
|
||||
cast(AgentResponse[Any], e.data) # pyright: ignore[reportUnknownMemberType]
|
||||
for e in stream_events
|
||||
if e.type == "output" and isinstance(e.data, AgentResponse)
|
||||
]
|
||||
agent_response_updates = [
|
||||
e.data for e in stream_events if e.type == "output" and isinstance(e.data, AgentResponseUpdate)
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import uuid
|
||||
from collections.abc import Awaitable, Sequence
|
||||
from typing import Any
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Never
|
||||
@@ -713,6 +713,14 @@ class TestWorkflowAgent:
|
||||
def create_session(self, **kwargs: Any) -> AgentSession:
|
||||
return AgentSession()
|
||||
|
||||
def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
|
||||
return AgentSession()
|
||||
|
||||
@overload
|
||||
def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
@@ -801,6 +809,14 @@ class TestWorkflowAgent:
|
||||
def create_session(self, **kwargs: Any) -> AgentSession:
|
||||
return AgentSession()
|
||||
|
||||
def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
|
||||
return AgentSession()
|
||||
|
||||
@overload
|
||||
def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(self, messages: str | Content | Message | Sequence[str | Content | Message] | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
@@ -1207,7 +1223,7 @@ class TestWorkflowAgentMergeUpdates:
|
||||
]
|
||||
|
||||
# Compare using role.value for Role enum
|
||||
actual_sequence_normalized = [(t, r.value if hasattr(r, "value") else r) for t, r in content_sequence]
|
||||
actual_sequence_normalized = [(t, r.value if hasattr(r, "value") else r) for t, r in content_sequence] # type: ignore[union-attr]
|
||||
|
||||
assert actual_sequence_normalized == expected_sequence, (
|
||||
f"FunctionResultContent should come immediately after FunctionCallContent. "
|
||||
|
||||
@@ -1,7 +1,8 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterator, Awaitable
|
||||
from dataclasses import dataclass
|
||||
from typing import Any
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
import pytest
|
||||
|
||||
@@ -9,10 +10,12 @@ from agent_framework import (
|
||||
AgentExecutor,
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
Executor,
|
||||
Message,
|
||||
ResponseStream,
|
||||
WorkflowBuilder,
|
||||
WorkflowContext,
|
||||
WorkflowValidationError,
|
||||
@@ -21,22 +24,49 @@ from agent_framework import (
|
||||
|
||||
|
||||
class DummyAgent(BaseAgent):
|
||||
def run(self, messages=None, *, stream: bool = False, session: AgentSession | None = None, **kwargs): # type: ignore[override]
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = ...,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = ...,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
if stream:
|
||||
return self._run_stream_impl()
|
||||
return ResponseStream[AgentResponseUpdate, AgentResponse[Any]](self._run_stream_impl())
|
||||
return self._run_impl(messages)
|
||||
|
||||
async def _run_impl(self, messages=None) -> AgentResponse:
|
||||
async def _run_impl(self, messages: AgentRunInputs | None = None) -> AgentResponse:
|
||||
norm: list[Message] = []
|
||||
if messages:
|
||||
for m in messages: # type: ignore[iteration-over-optional]
|
||||
for m in messages: # type: ignore[union-attr]
|
||||
if isinstance(m, Message):
|
||||
norm.append(m)
|
||||
elif isinstance(m, str):
|
||||
norm.append(Message(role="user", text=m))
|
||||
return AgentResponse(messages=norm)
|
||||
|
||||
async def _run_stream_impl(self): # type: ignore[override]
|
||||
async def _run_stream_impl(self) -> AsyncIterator[AgentResponseUpdate]:
|
||||
# Minimal async generator
|
||||
yield AgentResponseUpdate()
|
||||
|
||||
@@ -202,7 +232,7 @@ def test_with_output_from_returns_builder():
|
||||
builder = WorkflowBuilder(output_executors=[executor_a], start_executor=executor_a)
|
||||
|
||||
# Verify builder was created with output_executors
|
||||
assert builder._output_executors == [executor_a]
|
||||
assert builder._output_executors == [executor_a] # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
def test_with_output_from_with_executor_instances():
|
||||
|
||||
@@ -84,7 +84,7 @@ async def test_executor_emits_normal_event() -> None:
|
||||
|
||||
class _TestEvent(WorkflowEvent):
|
||||
def __init__(self, data: Any = None) -> None:
|
||||
super().__init__("test_event", data=data)
|
||||
super().__init__("test_event", data=data) # type: ignore[arg-type]
|
||||
|
||||
|
||||
async def test_workflow_context_type_annotations_no_parameter() -> None:
|
||||
@@ -244,8 +244,8 @@ async def test_workflow_context_missing_annotation_error() -> None:
|
||||
# Test class-based executor with missing ctx annotation
|
||||
with pytest.raises(ValueError, match="must have a WorkflowContext"):
|
||||
|
||||
class _BadExecutor(Executor):
|
||||
@handler
|
||||
class _BadExecutor(Executor): # pyright: ignore[reportUnusedClass]
|
||||
@handler # pyright: ignore[reportUnknownArgumentType]
|
||||
async def bad_handler(self, text: str, ctx) -> None: # type: ignore[no-untyped-def]
|
||||
pass
|
||||
|
||||
@@ -264,8 +264,8 @@ async def test_workflow_context_invalid_type_parameter_error() -> None:
|
||||
# Test class-based executor with invalid type parameter
|
||||
with pytest.raises(ValueError, match="invalid type entry"):
|
||||
|
||||
class _BadExecutor(Executor):
|
||||
@handler
|
||||
class _BadExecutor(Executor): # pyright: ignore[reportUnusedClass]
|
||||
@handler # pyright: ignore[reportUnknownArgumentType]
|
||||
async def bad_handler(self, text: str, ctx: WorkflowContext[456]) -> None: # type: ignore[valid-type]
|
||||
pass
|
||||
|
||||
|
||||
@@ -1,13 +1,14 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import AsyncIterable, Awaitable, Sequence
|
||||
from typing import Annotated, Any
|
||||
from collections.abc import AsyncIterable, Awaitable
|
||||
from typing import Annotated, Any, Literal, overload
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
AgentSession,
|
||||
BaseAgent,
|
||||
Content,
|
||||
@@ -50,14 +51,19 @@ class _KwargsCapturingAgent(BaseAgent):
|
||||
super().__init__(name=name, description="Test agent for kwargs capture")
|
||||
self.captured_kwargs = []
|
||||
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
self.captured_kwargs.append(dict(kwargs))
|
||||
if stream:
|
||||
|
||||
@@ -83,15 +89,20 @@ class _OptionsAwareAgent(BaseAgent):
|
||||
self.captured_options = []
|
||||
self.captured_kwargs = []
|
||||
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
options: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
self.captured_options.append(dict(options) if options is not None else None)
|
||||
self.captured_kwargs.append(dict(kwargs))
|
||||
if stream:
|
||||
@@ -189,15 +200,15 @@ async def test_sequential_run_options_does_not_conflict_with_agent_options() ->
|
||||
break
|
||||
|
||||
assert len(agent.captured_options) >= 1
|
||||
captured_options = agent.captured_options[0]
|
||||
captured_options: dict[str, Any] | None = agent.captured_options[0]
|
||||
assert captured_options is not None
|
||||
assert captured_options.get("store") is False
|
||||
|
||||
additional_args = captured_options.get("additional_function_arguments")
|
||||
additional_args: Any = captured_options.get("additional_function_arguments")
|
||||
assert isinstance(additional_args, dict)
|
||||
assert additional_args.get("source") == "workflow-options"
|
||||
assert additional_args.get("custom_data") == custom_data
|
||||
assert additional_args.get("user_token") == user_token
|
||||
assert additional_args.get("source") == "workflow-options" # pyright: ignore[reportUnknownMemberType]
|
||||
assert additional_args.get("custom_data") == custom_data # pyright: ignore[reportUnknownMemberType]
|
||||
assert additional_args.get("user_token") == user_token # pyright: ignore[reportUnknownMemberType]
|
||||
|
||||
# "options" should be passed once via the dedicated options parameter,
|
||||
# not duplicated in **kwargs.
|
||||
@@ -225,13 +236,13 @@ async def test_sequential_run_additional_function_arguments_flattened() -> None:
|
||||
break
|
||||
|
||||
assert len(agent.captured_options) >= 1
|
||||
captured_options = agent.captured_options[0]
|
||||
captured_options: dict[str, Any] | None = agent.captured_options[0]
|
||||
assert captured_options is not None
|
||||
|
||||
additional_args = captured_options.get("additional_function_arguments")
|
||||
additional_args: Any = captured_options.get("additional_function_arguments")
|
||||
assert isinstance(additional_args, dict)
|
||||
assert additional_args.get("custom_data") == custom_data
|
||||
assert additional_args.get("user_token") == user_token
|
||||
assert additional_args.get("custom_data") == custom_data # pyright: ignore[reportUnknownMemberType]
|
||||
assert additional_args.get("user_token") == user_token # pyright: ignore[reportUnknownMemberType]
|
||||
assert "additional_function_arguments" not in additional_args
|
||||
|
||||
assert len(agent.captured_kwargs) >= 1
|
||||
@@ -255,14 +266,14 @@ async def test_sequential_run_additional_function_arguments_merges_with_options(
|
||||
break
|
||||
|
||||
assert len(agent.captured_options) >= 1
|
||||
captured_options = agent.captured_options[0]
|
||||
captured_options: dict[str, Any] | None = agent.captured_options[0]
|
||||
assert captured_options is not None
|
||||
|
||||
additional_args = captured_options.get("additional_function_arguments")
|
||||
additional_args: Any = captured_options.get("additional_function_arguments")
|
||||
assert isinstance(additional_args, dict)
|
||||
assert additional_args.get("source") == "workflow-options"
|
||||
assert additional_args.get("custom_data") == {"session_id": "abc123"}
|
||||
assert additional_args.get("user_token") == {"user_name": "alice"}
|
||||
assert additional_args.get("source") == "workflow-options" # pyright: ignore[reportUnknownMemberType]
|
||||
assert additional_args.get("custom_data") == {"session_id": "abc123"} # pyright: ignore[reportUnknownMemberType]
|
||||
assert additional_args.get("user_token") == {"user_name": "alice"} # pyright: ignore[reportUnknownMemberType]
|
||||
assert "additional_function_arguments" not in additional_args
|
||||
|
||||
|
||||
@@ -463,14 +474,19 @@ async def test_kwargs_preserved_on_response_continuation() -> None:
|
||||
self.captured_kwargs = []
|
||||
self._asked = False
|
||||
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
self.captured_kwargs.append(dict(kwargs))
|
||||
if not self._asked:
|
||||
self._asked = True
|
||||
@@ -521,14 +537,19 @@ async def test_kwargs_overridden_on_response_continuation() -> None:
|
||||
self.captured_kwargs = []
|
||||
self._asked = False
|
||||
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
self.captured_kwargs.append(dict(kwargs))
|
||||
if not self._asked:
|
||||
self._asked = True
|
||||
@@ -583,14 +604,19 @@ async def test_kwargs_empty_value_passed_on_continuation() -> None:
|
||||
self.captured_kwargs = []
|
||||
self._asked = False
|
||||
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[False] = ..., session: AgentSession | None = ..., **kwargs: Any) -> Awaitable[AgentResponse[Any]]: ...
|
||||
@overload
|
||||
def run(self, messages: AgentRunInputs | None = ..., *, stream: Literal[True], session: AgentSession | None = ..., **kwargs: Any) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
self.captured_kwargs.append(dict(kwargs))
|
||||
if not self._asked:
|
||||
self._asked = True
|
||||
@@ -690,8 +716,8 @@ async def test_handoff_kwargs_flow_to_agents() -> None:
|
||||
|
||||
workflow = (
|
||||
HandoffBuilder(termination_condition=lambda conv: len(conv) >= 4)
|
||||
.participants([agent1, agent2])
|
||||
.with_start_agent(agent1)
|
||||
.participants([agent1, agent2]) # type: ignore[list-item]
|
||||
.with_start_agent(agent1) # type: ignore[arg-type]
|
||||
.with_autonomous_mode()
|
||||
.build()
|
||||
)
|
||||
|
||||
@@ -109,7 +109,7 @@ async def test_span_creation_and_attributes(span_exporter: InMemorySpanExporter)
|
||||
{
|
||||
"id": "test-workflow-123",
|
||||
"max_iterations": 100,
|
||||
"model_dump_json": lambda self: '{"id": "test-workflow-123", "type": "mock"}',
|
||||
"model_dump_json": lambda self: '{"id": "test-workflow-123", "type": "mock"}', # pyright: ignore[reportUnknownLambdaType]
|
||||
},
|
||||
)(),
|
||||
)
|
||||
@@ -122,7 +122,7 @@ async def test_span_creation_and_attributes(span_exporter: InMemorySpanExporter)
|
||||
},
|
||||
) as workflow_span:
|
||||
workflow_span.add_event(OtelAttr.WORKFLOW_STARTED)
|
||||
sending_attributes = {
|
||||
sending_attributes: dict[str, str | int] = {
|
||||
OtelAttr.MESSAGE_TYPE: "ResponseMessage",
|
||||
OtelAttr.MESSAGE_DESTINATION_EXECUTOR_ID: "target-789",
|
||||
}
|
||||
@@ -231,7 +231,7 @@ async def test_trace_context_handling(span_exporter: InMemorySpanExporter) -> No
|
||||
|
||||
@pytest.mark.parametrize("enable_instrumentation", [False], indirect=True)
|
||||
async def test_trace_context_disabled_when_tracing_disabled(
|
||||
enable_instrumentation, span_exporter: InMemorySpanExporter
|
||||
enable_instrumentation: bool, span_exporter: InMemorySpanExporter
|
||||
) -> None:
|
||||
"""Test that no trace context is added when tracing is disabled."""
|
||||
# Tracing should be disabled by default
|
||||
@@ -313,7 +313,7 @@ async def test_end_to_end_workflow_tracing(span_exporter: InMemorySpanExporter)
|
||||
span_exporter.clear()
|
||||
|
||||
# Run workflow (this should create run spans)
|
||||
events = []
|
||||
events: list[Any] = []
|
||||
async for event in workflow.run("test input", stream=True):
|
||||
events.append(event)
|
||||
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Never
|
||||
|
||||
@@ -36,16 +38,16 @@ async def test_executor_failed_and_workflow_failed_events_streaming():
|
||||
events.append(ev)
|
||||
|
||||
# executor_failed event (type='executor_failed') should be emitted before workflow failed event
|
||||
executor_failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"]
|
||||
executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"]
|
||||
assert executor_failed_events, "executor_failed event should be emitted when start executor fails"
|
||||
assert executor_failed_events[0].executor_id == "f"
|
||||
assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK
|
||||
|
||||
# Workflow-level failure and FAILED status should be surfaced
|
||||
failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"]
|
||||
failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"]
|
||||
assert failed_events
|
||||
assert all(e.origin is WorkflowEventSource.FRAMEWORK for e in failed_events)
|
||||
status = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "status"]
|
||||
status: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "status"]
|
||||
assert status and status[-1].state == WorkflowRunState.FAILED
|
||||
assert all(e.origin is WorkflowEventSource.FRAMEWORK for e in status)
|
||||
|
||||
@@ -94,13 +96,13 @@ async def test_executor_failed_event_from_second_executor_in_chain():
|
||||
events.append(ev)
|
||||
|
||||
# executor_failed event should be emitted for the failing executor
|
||||
executor_failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"]
|
||||
executor_failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "executor_failed"]
|
||||
assert executor_failed_events, "executor_failed event should be emitted when second executor fails"
|
||||
assert executor_failed_events[0].executor_id == "failing"
|
||||
assert executor_failed_events[0].origin is WorkflowEventSource.FRAMEWORK
|
||||
|
||||
# Workflow-level failure should also be surfaced
|
||||
failed_events = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"]
|
||||
failed_events: list[WorkflowEvent[Any]] = [e for e in events if isinstance(e, WorkflowEvent) and e.type == "failed"]
|
||||
assert failed_events
|
||||
assert all(e.origin is WorkflowEventSource.FRAMEWORK for e in failed_events)
|
||||
|
||||
|
||||
@@ -184,7 +184,6 @@ omit = [
|
||||
|
||||
[tool.pyright]
|
||||
include = ["agent_framework*"]
|
||||
exclude = ["**/tests/**", "**/.venv/**", "packages/devui/frontend/**"]
|
||||
typeCheckingMode = "strict"
|
||||
reportUnnecessaryIsInstance = false
|
||||
reportMissingTypeStubs = false
|
||||
|
||||
Reference in New Issue
Block a user