mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Add reset to workflow
This commit is contained in:
@@ -1643,9 +1643,7 @@ class MCPTool:
|
||||
return parser(fallback_result)
|
||||
|
||||
if task_id is None:
|
||||
raise ToolExecutionException(
|
||||
f"MCP server did not return a task_id or fallback result for '{tool_name}'."
|
||||
)
|
||||
raise ToolExecutionException(f"MCP server did not return a task_id or fallback result for '{tool_name}'.")
|
||||
|
||||
# Track to completion: poll until terminal, then fetch payload. Never re-issue
|
||||
# tools/call past this point; reconnect-and-retry only against the same task_id.
|
||||
@@ -1765,9 +1763,7 @@ class MCPTool:
|
||||
transient_codes: frozenset[int] = frozenset({int(httpx.codes.REQUEST_TIMEOUT)})
|
||||
|
||||
while True:
|
||||
request = types.ClientRequest(
|
||||
types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))
|
||||
)
|
||||
request = types.ClientRequest(types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id)))
|
||||
try:
|
||||
# GetTaskResult.ttl is required-but-Optional in the SDK; coerce below.
|
||||
lenient = await self._send_with_one_reconnect(
|
||||
@@ -1775,9 +1771,7 @@ class MCPTool:
|
||||
)
|
||||
except McpError as ex:
|
||||
if ex.error.code in transient_codes:
|
||||
logger.debug(
|
||||
"Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id
|
||||
)
|
||||
logger.debug("Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id)
|
||||
await asyncio.sleep(_MCP_TASK_MIN_POLL_INTERVAL.total_seconds())
|
||||
continue
|
||||
# Hard server error mid-poll: task may still be running.
|
||||
@@ -1906,9 +1900,7 @@ class MCPTool:
|
||||
if not self._is_connection_lost(ex):
|
||||
raise
|
||||
if attempt < _MCP_RECONNECT_ATTEMPTS - 1:
|
||||
logger.info(
|
||||
"MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id
|
||||
)
|
||||
logger.info("MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id)
|
||||
try:
|
||||
await self.connect(reset=True)
|
||||
except Exception as reconn_ex:
|
||||
@@ -1967,9 +1959,7 @@ class MCPTool:
|
||||
"""
|
||||
from mcp import types
|
||||
|
||||
request = types.ClientRequest(
|
||||
types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id))
|
||||
)
|
||||
request = types.ClientRequest(types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id)))
|
||||
try:
|
||||
await asyncio.wait_for(
|
||||
self.session.send_request(request, types.CancelTaskResult), # type: ignore[union-attr]
|
||||
@@ -1979,8 +1969,7 @@ class MCPTool:
|
||||
raise
|
||||
except asyncio.TimeoutError:
|
||||
logger.warning(
|
||||
"Best-effort tasks/cancel for '%s' timed out after %.1fs; "
|
||||
"remote task may still be running.",
|
||||
"Best-effort tasks/cancel for '%s' timed out after %.1fs; remote task may still be running.",
|
||||
task_id,
|
||||
_MCP_TASK_CANCEL_TIMEOUT.total_seconds(),
|
||||
)
|
||||
|
||||
@@ -3516,9 +3516,7 @@ class MCPSkill(Skill):
|
||||
result = await self._client.read_resource(_mcp_any_url(self._skill_md_uri))
|
||||
text = _mcp_join_text(result)
|
||||
if not text:
|
||||
raise ValueError(
|
||||
f"The MCP server returned no text content for SKILL.md resource '{self._skill_md_uri}'."
|
||||
)
|
||||
raise ValueError(f"The MCP server returned no text content for SKILL.md resource '{self._skill_md_uri}'.")
|
||||
self._content = text
|
||||
return text
|
||||
|
||||
@@ -3572,11 +3570,7 @@ class MCPSkill(Skill):
|
||||
or ``None`` if the name is unsafe.
|
||||
"""
|
||||
normalized = name.replace("\\", "/")
|
||||
if (
|
||||
normalized.startswith("/")
|
||||
or "://" in normalized
|
||||
or any(seg == ".." for seg in normalized.split("/"))
|
||||
):
|
||||
if normalized.startswith("/") or "://" in normalized or any(seg == ".." for seg in normalized.split("/")):
|
||||
logger.debug("Rejecting resource name with unsafe path components: %r", name)
|
||||
return None
|
||||
return normalized
|
||||
|
||||
@@ -166,6 +166,10 @@ class AgentExecutor(Executor):
|
||||
raise ValueError("Agent must have a non-empty name or id or an explicit id must be provided.")
|
||||
super().__init__(exec_id)
|
||||
self._agent = agent
|
||||
# Track whether the caller supplied a session so reset() can preserve their
|
||||
# session reference (which may be wired to external/service-side storage)
|
||||
# and only replace sessions we created ourselves.
|
||||
self._session_supplied_by_caller = session is not None
|
||||
self._session = session or self._agent.create_session()
|
||||
|
||||
self._pending_agent_requests: dict[str, Content] = {}
|
||||
@@ -365,10 +369,37 @@ class AgentExecutor(Executor):
|
||||
pending_responses_payload = state.get("pending_responses_to_agent")
|
||||
self._pending_responses_to_agent = pending_responses_payload or []
|
||||
|
||||
def reset(self) -> None:
|
||||
"""Reset the internal cache of the executor."""
|
||||
logger.debug("AgentExecutor %s: Resetting cache", self.id)
|
||||
@override
|
||||
async def reset(self) -> None:
|
||||
"""Reset the executor to its initial state for a new workflow run.
|
||||
|
||||
Clears the message cache, full conversation snapshot, and any pending
|
||||
user-input request/response bookkeeping.
|
||||
|
||||
Session handling:
|
||||
* If the session was created by this executor (no ``session`` argument
|
||||
was passed to ``__init__``), it is replaced with a fresh one via
|
||||
``agent.create_session()`` so prior conversation history does not
|
||||
leak into the next run.
|
||||
* If the session was supplied by the caller, it is left untouched.
|
||||
The caller owns the session lifecycle (it may be backed by
|
||||
service-side or external storage) and is responsible for clearing
|
||||
or rotating it if a clean slate is desired.
|
||||
"""
|
||||
logger.debug("AgentExecutor %s: Resetting state", self.id)
|
||||
self._cache.clear()
|
||||
self._full_conversation.clear()
|
||||
self._pending_agent_requests.clear()
|
||||
self._pending_responses_to_agent.clear()
|
||||
if not self._session_supplied_by_caller:
|
||||
self._session = self._agent.create_session()
|
||||
else:
|
||||
logger.warning(
|
||||
"AgentExecutor %s: Session was supplied by the caller and will not be reset. "
|
||||
"Prior conversation history retained in the session may leak into the next run. "
|
||||
"Reset or rotate the session externally if a clean slate is required.",
|
||||
self.id,
|
||||
)
|
||||
|
||||
async def _run_agent_and_emit(
|
||||
self,
|
||||
|
||||
@@ -516,6 +516,14 @@ class Executor(RequestInfoMixin, DictConvertible):
|
||||
"""
|
||||
...
|
||||
|
||||
async def reset(self) -> None:
|
||||
"""Reset the executor to its initial state.
|
||||
|
||||
Override this method in subclasses to implement custom logic that should
|
||||
run when the workflow is reset.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
# endregion: Executor
|
||||
|
||||
|
||||
@@ -88,9 +88,14 @@ class Runner:
|
||||
|
||||
@property
|
||||
def context(self) -> RunnerContext:
|
||||
"""Get the workflow context."""
|
||||
"""Get the runner context for message, event, and checkpoint handling."""
|
||||
return self._ctx
|
||||
|
||||
@property
|
||||
def state(self) -> State:
|
||||
"""Get the shared state for the workflow."""
|
||||
return self._state
|
||||
|
||||
def reserve(self) -> None:
|
||||
"""Synchronously reserve the runner for an upcoming run.
|
||||
|
||||
@@ -117,9 +122,38 @@ class Runner:
|
||||
self._lifecycle = _RunnerLifecycle.IDLE
|
||||
|
||||
def reset_iteration_count(self) -> None:
|
||||
"""Reset the iteration count to zero."""
|
||||
"""Reset the iteration count to zero.
|
||||
|
||||
This is useful when the workflow resumes from a new set of messages.
|
||||
|
||||
Note:
|
||||
When a workflow is resumed from a response (for a request_info_event)
|
||||
or a checkpoint, the iteration count is normally NOT reset.
|
||||
"""
|
||||
self._iteration = 0
|
||||
|
||||
async def reset_for_new_run(self) -> None:
|
||||
"""Reset the runner for a new run.
|
||||
|
||||
This is useful when reusing the same workflow instance for a different run
|
||||
that is independent from prior runs.
|
||||
|
||||
Raises:
|
||||
WorkflowRunnerException: If the runner is reserved or running. Reset is only
|
||||
allowed when the runner is idle to avoid clobbering in-flight run state.
|
||||
"""
|
||||
if self._lifecycle is not _RunnerLifecycle.IDLE:
|
||||
raise WorkflowRunnerException(
|
||||
"Cannot reset the runner while a run is in progress. "
|
||||
"Wait for the current run to complete before calling reset_for_new_run()."
|
||||
)
|
||||
self.reset_iteration_count()
|
||||
self._ctx.reset_for_new_run()
|
||||
self._state.clear()
|
||||
self._resumed_from_checkpoint = False
|
||||
for executor in self._executors.values():
|
||||
await executor.reset()
|
||||
|
||||
async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
|
||||
"""Run the workflow until no more messages are sent."""
|
||||
# Mandatory reservation: callers must reserve() the runner first. This makes
|
||||
@@ -294,7 +328,7 @@ class Runner:
|
||||
checkpoint_id: CheckpointID,
|
||||
checkpoint_storage: CheckpointStorage | None = None,
|
||||
) -> None:
|
||||
"""Restore workflow state from a checkpoint.
|
||||
"""Restore the runner from a checkpoint.
|
||||
|
||||
Args:
|
||||
checkpoint_id: The ID of the checkpoint to restore from
|
||||
|
||||
@@ -347,11 +347,10 @@ class Workflow(DictConvertible):
|
||||
# Store non-serializable runtime objects as private attributes
|
||||
self._runner_context = runner_context
|
||||
self._runner_context.set_yield_output_classifier(self._output_designation.classify)
|
||||
self._state = State()
|
||||
self._runner: Runner = Runner(
|
||||
self.edge_groups,
|
||||
self.executors,
|
||||
self._state,
|
||||
State(),
|
||||
runner_context,
|
||||
self.name,
|
||||
self.graph_signature_hash,
|
||||
@@ -552,14 +551,14 @@ class Workflow(DictConvertible):
|
||||
combined_kwargs["client_kwargs"] = self._resolve_invocation_kwargs(
|
||||
client_kwargs, "client_kwargs"
|
||||
)
|
||||
self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs)
|
||||
self._runner.state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs)
|
||||
elif not is_continuation:
|
||||
self._state.set(WORKFLOW_RUN_KWARGS_KEY, {})
|
||||
self._state.commit() # Commit immediately so kwargs are available
|
||||
self._runner.state.set(WORKFLOW_RUN_KWARGS_KEY, {})
|
||||
self._runner.state.commit() # Commit immediately so kwargs are available
|
||||
|
||||
# Set streaming mode (always set explicitly per run since
|
||||
# reset_for_new_run() no longer runs to clear it).
|
||||
self._runner_context.set_streaming(streaming)
|
||||
self._runner.context.set_streaming(streaming)
|
||||
|
||||
# Execute initial setup if provided
|
||||
if initial_executor_fn:
|
||||
@@ -653,7 +652,7 @@ class Workflow(DictConvertible):
|
||||
await executor.execute(
|
||||
message,
|
||||
[self.__class__.__name__],
|
||||
self._state,
|
||||
self._runner.state,
|
||||
self._runner.context,
|
||||
trace_contexts=None,
|
||||
source_span_ids=None,
|
||||
@@ -934,7 +933,7 @@ class Workflow(DictConvertible):
|
||||
|
||||
async def _send_responses_internal(self, responses: Mapping[str, Any]) -> None:
|
||||
"""Internal method to validate and send responses to the executors."""
|
||||
pending_requests = await self._runner_context.get_pending_request_info_events()
|
||||
pending_requests = await self._runner.context.get_pending_request_info_events()
|
||||
if not pending_requests:
|
||||
raise RuntimeError("No pending requests found in workflow context.")
|
||||
|
||||
@@ -954,7 +953,7 @@ class Workflow(DictConvertible):
|
||||
coerced_responses[request_id] = response
|
||||
|
||||
await asyncio.gather(*[
|
||||
self._runner_context.send_request_info_response(request_id, response)
|
||||
self._runner.context.send_request_info_response(request_id, response)
|
||||
for request_id, response in coerced_responses.items()
|
||||
])
|
||||
|
||||
@@ -1150,3 +1149,17 @@ class Workflow(DictConvertible):
|
||||
context_providers=context_providers,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
async def reset_for_new_run(self) -> None:
|
||||
"""Reset the workflow for a new run that is independent from prior runs.
|
||||
|
||||
Note:
|
||||
This will reset EVERYTHING - executor states, workflow state, and runner
|
||||
context (including pending requests/messages).
|
||||
|
||||
Raises:
|
||||
WorkflowRunnerException: If a run is currently in progress. Reset is only
|
||||
allowed when the workflow is idle to avoid clobbering in-flight run state.
|
||||
|
||||
"""
|
||||
await self._runner.reset_for_new_run()
|
||||
|
||||
@@ -534,6 +534,19 @@ class WorkflowExecutor(Executor):
|
||||
for event in request_info_events
|
||||
])
|
||||
|
||||
@override
|
||||
async def reset(self) -> None:
|
||||
"""Reset the WorkflowExecutor to its initial state for a new workflow run.
|
||||
|
||||
Clears in-flight execution contexts and the request-to-execution mapping,
|
||||
then recursively resets the wrapped sub-workflow so its executors, runner
|
||||
context, and shared state are also returned to a clean state.
|
||||
"""
|
||||
logger.debug("WorkflowExecutor %s: Resetting state", self.id)
|
||||
self._execution_contexts.clear()
|
||||
self._request_to_execution.clear()
|
||||
await self.workflow.reset_for_new_run()
|
||||
|
||||
async def _process_workflow_result(
|
||||
self,
|
||||
result: WorkflowRunResult,
|
||||
|
||||
@@ -281,9 +281,7 @@ async def test_mcp_prompts_get_creates_client_span(span_exporter: InMemorySpanEx
|
||||
async def test_mcp_prompts_get_mcp_error_sets_error_type(span_exporter: InMemorySpanExporter):
|
||||
"""When session.get_prompt() raises McpError, the span should have error.type and ERROR status."""
|
||||
tool = _make_connected_mcp_tool()
|
||||
tool.session.get_prompt = AsyncMock(
|
||||
side_effect=McpError(ErrorData(code=-32602, message="prompt not found"))
|
||||
)
|
||||
tool.session.get_prompt = AsyncMock(side_effect=McpError(ErrorData(code=-32602, message="prompt not found")))
|
||||
|
||||
span_exporter.clear()
|
||||
with pytest.raises(ToolExecutionException):
|
||||
|
||||
@@ -35,26 +35,22 @@ description: Convert between common units.
|
||||
Body content here.
|
||||
"""
|
||||
|
||||
SAMPLE_SKILL_INDEX = json.dumps(
|
||||
{
|
||||
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
|
||||
"skills": [
|
||||
{
|
||||
"name": "unit-converter",
|
||||
"type": "skill-md",
|
||||
"description": "Convert between common units.",
|
||||
"url": "skill://unit-converter/SKILL.md",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
SAMPLE_SKILL_INDEX = json.dumps({
|
||||
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
|
||||
"skills": [
|
||||
{
|
||||
"name": "unit-converter",
|
||||
"type": "skill-md",
|
||||
"description": "Convert between common units.",
|
||||
"url": "skill://unit-converter/SKILL.md",
|
||||
}
|
||||
],
|
||||
})
|
||||
|
||||
|
||||
def _make_text_result(text: str, uri: str = "skill://test") -> ReadResourceResult:
|
||||
"""Create a ReadResourceResult with a single TextResourceContents."""
|
||||
return ReadResourceResult(
|
||||
contents=[TextResourceContents(uri=AnyUrl(uri), text=text, mimeType="text/markdown")]
|
||||
)
|
||||
return ReadResourceResult(contents=[TextResourceContents(uri=AnyUrl(uri), text=text, mimeType="text/markdown")])
|
||||
|
||||
|
||||
def _make_blob_result(
|
||||
@@ -230,12 +226,10 @@ class TestMCPSkill:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_resource_text(self) -> None:
|
||||
client = _make_client(
|
||||
**{
|
||||
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
|
||||
"skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"),
|
||||
}
|
||||
)
|
||||
client = _make_client(**{
|
||||
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
|
||||
"skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"),
|
||||
})
|
||||
from agent_framework import SkillFrontmatter
|
||||
|
||||
fm = SkillFrontmatter(name="unit-converter", description="Convert between common units.")
|
||||
@@ -249,12 +243,10 @@ class TestMCPSkill:
|
||||
@pytest.mark.asyncio
|
||||
async def test_get_resource_binary(self) -> None:
|
||||
data = bytes([0x01, 0x02, 0x03, 0x04])
|
||||
client = _make_client(
|
||||
**{
|
||||
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
|
||||
"skill://unit-converter/assets/icon.bin": _make_blob_result(data),
|
||||
}
|
||||
)
|
||||
client = _make_client(**{
|
||||
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
|
||||
"skill://unit-converter/assets/icon.bin": _make_blob_result(data),
|
||||
})
|
||||
from agent_framework import SkillFrontmatter
|
||||
|
||||
fm = SkillFrontmatter(name="unit-converter", description="Convert between common units.")
|
||||
@@ -345,12 +337,10 @@ class TestMCPSkillsSource:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_index_based_discovery_returns_skill(self) -> None:
|
||||
client = _make_client(
|
||||
**{
|
||||
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
|
||||
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
|
||||
}
|
||||
)
|
||||
client = _make_client(**{
|
||||
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
|
||||
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
|
||||
})
|
||||
source = MCPSkillsSource(client=client)
|
||||
skills = await source.get_skills()
|
||||
|
||||
@@ -373,9 +363,7 @@ class TestMCPSkillsSource:
|
||||
async def test_does_not_read_skill_md_during_discovery(self) -> None:
|
||||
# Index points to a skill, but SKILL.md is not registered on the server.
|
||||
# Discovery should succeed because it only reads the index.
|
||||
client = _make_client(
|
||||
**{"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json")}
|
||||
)
|
||||
client = _make_client(**{"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json")})
|
||||
source = MCPSkillsSource(client=client)
|
||||
skills = await source.get_skills()
|
||||
|
||||
@@ -384,19 +372,17 @@ class TestMCPSkillsSource:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_invalid_name_is_skipped(self) -> None:
|
||||
index_json = json.dumps(
|
||||
{
|
||||
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
|
||||
"skills": [
|
||||
{
|
||||
"name": "UnitConverter", # Invalid: uppercase
|
||||
"type": "skill-md",
|
||||
"description": "Convert between common units.",
|
||||
"url": "skill://UnitConverter/SKILL.md",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
index_json = json.dumps({
|
||||
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
|
||||
"skills": [
|
||||
{
|
||||
"name": "UnitConverter", # Invalid: uppercase
|
||||
"type": "skill-md",
|
||||
"description": "Convert between common units.",
|
||||
"url": "skill://UnitConverter/SKILL.md",
|
||||
}
|
||||
],
|
||||
})
|
||||
client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")})
|
||||
source = MCPSkillsSource(client=client)
|
||||
skills = await source.get_skills()
|
||||
@@ -404,18 +390,16 @@ class TestMCPSkillsSource:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_missing_required_fields_is_skipped(self) -> None:
|
||||
index_json = json.dumps(
|
||||
{
|
||||
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
|
||||
"skills": [
|
||||
{
|
||||
"name": "unit-converter",
|
||||
"type": "skill-md",
|
||||
# Missing description and url
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
index_json = json.dumps({
|
||||
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
|
||||
"skills": [
|
||||
{
|
||||
"name": "unit-converter",
|
||||
"type": "skill-md",
|
||||
# Missing description and url
|
||||
}
|
||||
],
|
||||
})
|
||||
client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")})
|
||||
source = MCPSkillsSource(client=client)
|
||||
skills = await source.get_skills()
|
||||
@@ -423,19 +407,17 @@ class TestMCPSkillsSource:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_unsupported_type_is_skipped(self) -> None:
|
||||
index_json = json.dumps(
|
||||
{
|
||||
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
|
||||
"skills": [
|
||||
{
|
||||
"name": "some-skill",
|
||||
"type": "archive",
|
||||
"description": "Packaged skill.",
|
||||
"url": "skill://some-skill.tar.gz",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
index_json = json.dumps({
|
||||
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
|
||||
"skills": [
|
||||
{
|
||||
"name": "some-skill",
|
||||
"type": "archive",
|
||||
"description": "Packaged skill.",
|
||||
"url": "skill://some-skill.tar.gz",
|
||||
}
|
||||
],
|
||||
})
|
||||
client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")})
|
||||
source = MCPSkillsSource(client=client)
|
||||
skills = await source.get_skills()
|
||||
@@ -443,18 +425,16 @@ class TestMCPSkillsSource:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_template_type_is_skipped(self) -> None:
|
||||
index_json = json.dumps(
|
||||
{
|
||||
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
|
||||
"skills": [
|
||||
{
|
||||
"type": "mcp-resource-template",
|
||||
"description": "Per-product documentation skill",
|
||||
"url": "skill://docs/{product}/SKILL.md",
|
||||
}
|
||||
],
|
||||
}
|
||||
)
|
||||
index_json = json.dumps({
|
||||
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
|
||||
"skills": [
|
||||
{
|
||||
"type": "mcp-resource-template",
|
||||
"description": "Per-product documentation skill",
|
||||
"url": "skill://docs/{product}/SKILL.md",
|
||||
}
|
||||
],
|
||||
})
|
||||
client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")})
|
||||
source = MCPSkillsSource(client=client)
|
||||
skills = await source.get_skills()
|
||||
@@ -462,31 +442,25 @@ class TestMCPSkillsSource:
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_empty_index_returns_empty(self) -> None:
|
||||
client = _make_client(
|
||||
**{"skill://index.json": _make_text_result('{"skills": []}', uri="skill://index.json")}
|
||||
)
|
||||
client = _make_client(**{"skill://index.json": _make_text_result('{"skills": []}', uri="skill://index.json")})
|
||||
source = MCPSkillsSource(client=client)
|
||||
skills = await source.get_skills()
|
||||
assert skills == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_malformed_index_json_returns_empty(self) -> None:
|
||||
client = _make_client(
|
||||
**{"skill://index.json": _make_text_result("not valid json", uri="skill://index.json")}
|
||||
)
|
||||
client = _make_client(**{"skill://index.json": _make_text_result("not valid json", uri="skill://index.json")})
|
||||
source = MCPSkillsSource(client=client)
|
||||
skills = await source.get_skills()
|
||||
assert skills == []
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sibling_text_resource(self) -> None:
|
||||
client = _make_client(
|
||||
**{
|
||||
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
|
||||
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
|
||||
"skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"),
|
||||
}
|
||||
)
|
||||
client = _make_client(**{
|
||||
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
|
||||
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
|
||||
"skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"),
|
||||
})
|
||||
source = MCPSkillsSource(client=client)
|
||||
skill = (await source.get_skills())[0]
|
||||
resource = await skill.get_resource("references/checklist.md")
|
||||
@@ -497,13 +471,11 @@ class TestMCPSkillsSource:
|
||||
@pytest.mark.asyncio
|
||||
async def test_sibling_binary_resource(self) -> None:
|
||||
data = bytes([0x01, 0x02, 0x03, 0x04])
|
||||
client = _make_client(
|
||||
**{
|
||||
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
|
||||
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
|
||||
"skill://unit-converter/assets/icon.bin": _make_blob_result(data),
|
||||
}
|
||||
)
|
||||
client = _make_client(**{
|
||||
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
|
||||
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
|
||||
"skill://unit-converter/assets/icon.bin": _make_blob_result(data),
|
||||
})
|
||||
source = MCPSkillsSource(client=client)
|
||||
skill = (await source.get_skills())[0]
|
||||
resource = await skill.get_resource("assets/icon.bin")
|
||||
@@ -649,9 +621,7 @@ class TestMCPSkillsSourceErrorCodeBranching:
|
||||
from agent_framework import SkillFrontmatter
|
||||
|
||||
client = AsyncMock()
|
||||
client.read_resource = AsyncMock(
|
||||
side_effect=McpError(error=ErrorData(code=0, message="Handler error"))
|
||||
)
|
||||
client.read_resource = AsyncMock(side_effect=McpError(error=ErrorData(code=0, message="Handler error")))
|
||||
fm = SkillFrontmatter(name="test-skill", description="Test.")
|
||||
skill = MCPSkill(frontmatter=fm, skill_md_uri="skill://test/SKILL.md", client=client)
|
||||
with pytest.raises(McpError):
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncIterable, Awaitable
|
||||
from typing import Any, Literal, overload
|
||||
|
||||
@@ -19,7 +20,7 @@ from agent_framework import (
|
||||
WorkflowEvent,
|
||||
WorkflowRunState,
|
||||
)
|
||||
from agent_framework._workflows._agent_executor import AgentExecutorResponse
|
||||
from agent_framework._workflows._agent_executor import AgentExecutorRequest, AgentExecutorResponse
|
||||
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
|
||||
from agent_framework._workflows._const import GLOBAL_KWARGS_KEY
|
||||
|
||||
@@ -306,6 +307,108 @@ async def test_agent_executor_save_and_restore_state_directly() -> None:
|
||||
assert restored_session.session_id == session.session_id
|
||||
|
||||
|
||||
# region: Tests for AgentExecutor.reset()
|
||||
|
||||
|
||||
async def test_agent_executor_reset_clears_per_run_state() -> None:
|
||||
"""reset() clears cache, conversation snapshot, and pending request/response buffers."""
|
||||
agent = _CountingAgent(id="reset_agent", name="ResetAgent")
|
||||
executor = AgentExecutor(agent, id="reset_exec")
|
||||
|
||||
# Populate every per-run buffer.
|
||||
executor._cache = [Message(role="user", contents=["cached"])] # type: ignore[reportPrivateUsage]
|
||||
executor._full_conversation = [ # type: ignore[reportPrivateUsage]
|
||||
Message(role="user", contents=["prior turn"]),
|
||||
Message(role="assistant", contents=["prior response"]),
|
||||
]
|
||||
pending_request = Content.from_text(text="approve?")
|
||||
executor._pending_agent_requests = {"req-1": pending_request} # type: ignore[reportPrivateUsage]
|
||||
executor._pending_responses_to_agent = [Content.from_text(text="approved")] # type: ignore[reportPrivateUsage]
|
||||
|
||||
await executor.reset()
|
||||
|
||||
assert executor._cache == [] # type: ignore[reportPrivateUsage]
|
||||
assert executor._full_conversation == [] # type: ignore[reportPrivateUsage]
|
||||
assert executor._pending_agent_requests == {} # type: ignore[reportPrivateUsage]
|
||||
assert executor._pending_responses_to_agent == [] # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_agent_executor_reset_creates_fresh_session_when_auto_created() -> None:
|
||||
"""reset() replaces the agent session when the executor created it itself."""
|
||||
agent = _CountingAgent(id="reset_session_agent", name="ResetSessionAgent")
|
||||
# No session passed in — executor creates one via agent.create_session().
|
||||
executor = AgentExecutor(agent, id="reset_session_exec")
|
||||
auto_created = executor._session # type: ignore[reportPrivateUsage]
|
||||
auto_created.state["history"] = {"messages": [Message(role="user", contents=["old"])]}
|
||||
|
||||
await executor.reset()
|
||||
|
||||
new_session = executor._session # type: ignore[reportPrivateUsage]
|
||||
assert new_session is not auto_created
|
||||
assert new_session.session_id != auto_created.session_id
|
||||
assert "history" not in new_session.state
|
||||
|
||||
|
||||
async def test_agent_executor_reset_preserves_caller_supplied_session(caplog: pytest.LogCaptureFixture) -> None:
|
||||
"""reset() leaves a session passed in via __init__ untouched and warns the caller."""
|
||||
agent = _CountingAgent(id="reset_session_agent", name="ResetSessionAgent")
|
||||
caller_session = AgentSession()
|
||||
history_payload = {"messages": [Message(role="user", contents=["old"])]}
|
||||
caller_session.state["history"] = history_payload
|
||||
executor = AgentExecutor(agent, id="reset_session_exec", session=caller_session)
|
||||
|
||||
assert executor._session is caller_session # type: ignore[reportPrivateUsage]
|
||||
|
||||
with caplog.at_level(logging.WARNING, logger="agent_framework._workflows._agent_executor"):
|
||||
await executor.reset()
|
||||
|
||||
# Same instance, state untouched — the caller is responsible for managing the session.
|
||||
assert executor._session is caller_session # type: ignore[reportPrivateUsage]
|
||||
assert caller_session.state["history"] is history_payload
|
||||
assert any("Session was supplied by the caller" in record.message for record in caplog.records)
|
||||
|
||||
|
||||
async def test_agent_executor_reset_allows_subsequent_run() -> None:
|
||||
"""After reset(), the executor can be reused for a fresh workflow run without leaking state."""
|
||||
agent = _CountingAgent(id="reset_reuse_agent", name="ResetReuseAgent")
|
||||
executor = AgentExecutor(agent, id="reset_reuse_exec")
|
||||
workflow = WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
|
||||
|
||||
first_outputs: list[WorkflowEvent] = []
|
||||
async for event in workflow.run(
|
||||
AgentExecutorRequest(messages=[Message(role="user", contents=["hello"])]),
|
||||
stream=True,
|
||||
):
|
||||
if event.type == "output":
|
||||
first_outputs.append(event)
|
||||
|
||||
assert first_outputs, "first run should have produced at least one output event"
|
||||
# After a normal run the cache is drained but the conversation snapshot remains.
|
||||
assert executor._cache == [] # type: ignore[reportPrivateUsage]
|
||||
assert executor._full_conversation != [] # type: ignore[reportPrivateUsage]
|
||||
first_session_id = executor._session.session_id # type: ignore[reportPrivateUsage]
|
||||
|
||||
await workflow.reset_for_new_run()
|
||||
|
||||
assert executor._full_conversation == [] # type: ignore[reportPrivateUsage]
|
||||
# Session was auto-created, so reset() rotates it to a fresh one.
|
||||
assert executor._session.session_id != first_session_id # type: ignore[reportPrivateUsage]
|
||||
|
||||
second_outputs: list[WorkflowEvent] = []
|
||||
async for event in workflow.run(
|
||||
AgentExecutorRequest(messages=[Message(role="user", contents=["second"])]),
|
||||
stream=True,
|
||||
):
|
||||
if event.type == "output":
|
||||
second_outputs.append(event)
|
||||
|
||||
assert second_outputs, "second run after reset should have produced at least one output event"
|
||||
assert agent.call_count == 2
|
||||
|
||||
|
||||
# endregion: Tests for AgentExecutor.reset()
|
||||
|
||||
|
||||
async def test_prepare_agent_run_args_extracts_invocation_kwargs() -> None:
|
||||
"""_prepare_agent_run_args extracts function_invocation_kwargs and client_kwargs."""
|
||||
agent = _CountingAgent(id="test_agent", name="TestAgent")
|
||||
|
||||
@@ -1004,3 +1004,53 @@ def test_handler_typevar_error_takes_priority_over_context_error():
|
||||
@handler
|
||||
async def process(self, message: _T, ctx) -> None: # type: ignore[no-untyped-def]
|
||||
pass
|
||||
|
||||
|
||||
# region: Tests for Executor.reset()
|
||||
|
||||
|
||||
async def test_executor_default_reset_is_noop():
|
||||
"""The base Executor.reset() is a no-op and must complete without raising.
|
||||
|
||||
Subclasses that don't carry reset-relevant state should be able to rely on the
|
||||
default implementation.
|
||||
"""
|
||||
|
||||
class StatelessExecutor(Executor):
|
||||
@handler
|
||||
async def handle(self, message: str, ctx: WorkflowContext[str]) -> None:
|
||||
await ctx.send_message(message)
|
||||
|
||||
executor_instance = StatelessExecutor(id="stateless")
|
||||
|
||||
# Must complete without raising and return None.
|
||||
assert await executor_instance.reset() is None
|
||||
|
||||
|
||||
async def test_executor_subclass_reset_is_invoked():
|
||||
"""A subclass that overrides reset() can clear its own internal state."""
|
||||
|
||||
class CounterExecutor(Executor):
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self.counter = 0
|
||||
self.reset_calls = 0
|
||||
|
||||
@handler
|
||||
async def handle(self, message: int, ctx: WorkflowContext[int]) -> None:
|
||||
self.counter += message
|
||||
|
||||
async def reset(self) -> None:
|
||||
self.counter = 0
|
||||
self.reset_calls += 1
|
||||
|
||||
executor_instance = CounterExecutor(id="counter")
|
||||
executor_instance.counter = 42
|
||||
|
||||
await executor_instance.reset()
|
||||
|
||||
assert executor_instance.counter == 0
|
||||
assert executor_instance.reset_calls == 1
|
||||
|
||||
|
||||
# endregion: Tests for Executor.reset()
|
||||
|
||||
@@ -1112,3 +1112,255 @@ async def test_runner_drains_straggler_events_at_iteration_end():
|
||||
output_events = [e for e in events if e.type == "output"]
|
||||
# We should have output events from both executors
|
||||
assert len(output_events) >= 2
|
||||
|
||||
|
||||
# region: Tests for InProcRunnerContext.reset_for_new_run()
|
||||
|
||||
|
||||
async def test_runner_context_reset_clears_in_flight_messages():
|
||||
"""reset_for_new_run drops queued executor-to-executor messages."""
|
||||
ctx = InProcRunnerContext()
|
||||
await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id="src"))
|
||||
|
||||
assert await ctx.has_messages() is True
|
||||
|
||||
ctx.reset_for_new_run()
|
||||
|
||||
assert await ctx.has_messages() is False
|
||||
assert await ctx.drain_messages() == {}
|
||||
|
||||
|
||||
async def test_runner_context_reset_drains_pending_events():
|
||||
"""reset_for_new_run discards any events buffered for streaming."""
|
||||
ctx = InProcRunnerContext()
|
||||
await ctx.add_event(WorkflowEvent.superstep_started(iteration=1))
|
||||
assert await ctx.has_events() is True
|
||||
|
||||
ctx.reset_for_new_run()
|
||||
|
||||
assert await ctx.has_events() is False
|
||||
assert await ctx.drain_events() == []
|
||||
|
||||
|
||||
async def test_runner_context_reset_resets_streaming_flag():
|
||||
"""reset_for_new_run resets streaming back to its non-streaming default."""
|
||||
ctx = InProcRunnerContext()
|
||||
ctx.set_streaming(True)
|
||||
assert ctx.is_streaming() is True
|
||||
|
||||
ctx.reset_for_new_run()
|
||||
|
||||
assert ctx.is_streaming() is False
|
||||
|
||||
|
||||
# endregion: Tests for InProcRunnerContext.reset_for_new_run()
|
||||
|
||||
|
||||
# region: Tests for Runner.reset_for_new_run()
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_resets_iteration_count():
|
||||
"""reset_for_new_run resets the iteration counter back to zero."""
|
||||
runner = _make_runner()
|
||||
runner._iteration = 7 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_clears_shared_state():
|
||||
"""reset_for_new_run wipes both committed and pending entries from shared state."""
|
||||
state = State()
|
||||
state.set("committed_key", "committed_value")
|
||||
state.commit()
|
||||
state.set("pending_key", "pending_value") # uncommitted
|
||||
|
||||
runner = Runner(
|
||||
[],
|
||||
{},
|
||||
state,
|
||||
InProcRunnerContext(),
|
||||
"test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
)
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
assert state.get("committed_key") is None
|
||||
assert state.get("pending_key") is None
|
||||
assert state.has("committed_key") is False
|
||||
assert state.has("pending_key") is False
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_clears_resumed_from_checkpoint_flag():
|
||||
"""reset_for_new_run clears the flag set by restore_from_checkpoint."""
|
||||
runner = _make_runner()
|
||||
runner._mark_resumed(iteration=5) # pyright: ignore[reportPrivateUsage]
|
||||
assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage]
|
||||
# And the iteration count restored from the checkpoint must be wiped, too.
|
||||
assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_invokes_executor_reset_for_each_executor():
|
||||
"""reset_for_new_run calls reset() on every registered executor exactly once."""
|
||||
|
||||
class TrackingExecutor(MockExecutor):
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self.reset_calls = 0
|
||||
|
||||
async def reset(self) -> None:
|
||||
self.reset_calls += 1
|
||||
|
||||
executor_a = TrackingExecutor(id="executor_a")
|
||||
executor_b = TrackingExecutor(id="executor_b")
|
||||
|
||||
runner = Runner(
|
||||
[],
|
||||
{executor_a.id: executor_a, executor_b.id: executor_b},
|
||||
State(),
|
||||
InProcRunnerContext(),
|
||||
"test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
)
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
assert executor_a.reset_calls == 1
|
||||
assert executor_b.reset_calls == 1
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_resets_runner_context():
|
||||
"""reset_for_new_run forwards the reset to the underlying runner context."""
|
||||
ctx = InProcRunnerContext()
|
||||
await ctx.send_message(WorkflowMessage(data=MockMessage(data=0), source_id="src"))
|
||||
await ctx.add_event(WorkflowEvent.superstep_started(iteration=1))
|
||||
ctx.set_streaming(True)
|
||||
|
||||
runner = Runner([], {}, State(), ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
assert await ctx.has_messages() is False
|
||||
assert await ctx.has_events() is False
|
||||
assert ctx.is_streaming() is False
|
||||
|
||||
|
||||
async def test_runner_can_run_again_after_reset_for_new_run():
|
||||
"""After reset_for_new_run the runner can be reserved and converge a fresh workload."""
|
||||
executor_a = MockExecutor(id="executor_a")
|
||||
executor_b = MockExecutor(id="executor_b")
|
||||
|
||||
edges = [
|
||||
SingleEdgeGroup(executor_a.id, executor_b.id),
|
||||
SingleEdgeGroup(executor_b.id, executor_a.id),
|
||||
]
|
||||
executors: dict[str, Executor] = {
|
||||
executor_a.id: executor_a,
|
||||
executor_b.id: executor_b,
|
||||
}
|
||||
state = State()
|
||||
ctx = InProcRunnerContext()
|
||||
|
||||
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash")
|
||||
|
||||
# First run: drives MockExecutor's loop until it yields the terminal value.
|
||||
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
|
||||
runner.reserve()
|
||||
async for _ in runner.run_until_convergence():
|
||||
pass
|
||||
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
# Second run: must succeed cleanly using the same runner instance.
|
||||
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
|
||||
runner.reserve()
|
||||
second_run_outputs: list[int] = []
|
||||
async for event in runner.run_until_convergence():
|
||||
if event.type == "output":
|
||||
second_run_outputs.append(event.data)
|
||||
|
||||
assert second_run_outputs == [10]
|
||||
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_rejected_when_reserved():
|
||||
"""reset_for_new_run refuses to run when the runner is reserved but not yet running."""
|
||||
runner = _make_runner()
|
||||
runner.reserve()
|
||||
|
||||
with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"):
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_rejected_while_running():
|
||||
"""reset_for_new_run refuses to run while a run is mid-execution."""
|
||||
runner = _make_runner()
|
||||
|
||||
started = asyncio.Event()
|
||||
release = asyncio.Event()
|
||||
|
||||
async def _slow_run() -> None:
|
||||
runner.reserve()
|
||||
async for _ in runner.run_until_convergence():
|
||||
if not started.is_set():
|
||||
started.set()
|
||||
await release.wait()
|
||||
|
||||
task = asyncio.create_task(_slow_run())
|
||||
await started.wait() # first run is now executing inside run_until_convergence
|
||||
|
||||
try:
|
||||
with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"):
|
||||
await runner.reset_for_new_run()
|
||||
finally:
|
||||
release.set()
|
||||
await task
|
||||
|
||||
# Once the run drained, reset must succeed again.
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
|
||||
async def test_runner_reset_for_new_run_does_not_mutate_when_rejected():
|
||||
"""When reset is rejected, the runner's iteration counter and state are untouched."""
|
||||
|
||||
class TrackingExecutor(MockExecutor):
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self.reset_calls = 0
|
||||
|
||||
async def reset(self) -> None:
|
||||
self.reset_calls += 1
|
||||
|
||||
executor = TrackingExecutor(id="executor")
|
||||
state = State()
|
||||
state.set("preserved", 42)
|
||||
state.commit()
|
||||
|
||||
runner = Runner(
|
||||
[],
|
||||
{executor.id: executor},
|
||||
state,
|
||||
InProcRunnerContext(),
|
||||
"test_name",
|
||||
graph_signature_hash="test_hash",
|
||||
)
|
||||
runner._iteration = 7 # pyright: ignore[reportPrivateUsage]
|
||||
runner.reserve()
|
||||
|
||||
with pytest.raises(WorkflowRunnerException):
|
||||
await runner.reset_for_new_run()
|
||||
|
||||
# Nothing was mutated by the failed reset.
|
||||
assert runner._iteration == 7 # pyright: ignore[reportPrivateUsage]
|
||||
assert state.get("preserved") == 42
|
||||
assert executor.reset_calls == 0
|
||||
|
||||
|
||||
# endregion: Tests for Runner.reset_for_new_run()
|
||||
|
||||
@@ -689,3 +689,89 @@ async def test_sub_workflow_intermediate_outputs_propagate_to_parent() -> None:
|
||||
|
||||
# The parent's own terminal output is unaffected.
|
||||
assert any(e.executor_id == "parent_sink" and e.data == "final: hello" for e in output_events)
|
||||
|
||||
|
||||
# region: Tests for WorkflowExecutor.reset()
|
||||
|
||||
|
||||
async def test_workflow_executor_reset_clears_execution_state() -> None:
|
||||
"""reset() clears the WorkflowExecutor's per-run execution contexts and request mappings."""
|
||||
validation_workflow = create_email_validation_workflow()
|
||||
parent = Coordinator()
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder(start_executor=parent)
|
||||
.add_edge(parent, workflow_executor)
|
||||
.add_edge(workflow_executor, parent)
|
||||
.build()
|
||||
)
|
||||
|
||||
# First run pauses with a pending request from the sub-workflow.
|
||||
result = await main_workflow.run("test@example.com")
|
||||
assert len(result.get_request_info_events()) == 1
|
||||
assert len(workflow_executor._execution_contexts) == 1 # type: ignore[reportPrivateUsage]
|
||||
assert len(workflow_executor._request_to_execution) == 1 # type: ignore[reportPrivateUsage]
|
||||
|
||||
await main_workflow.reset_for_new_run()
|
||||
|
||||
assert workflow_executor._execution_contexts == {} # type: ignore[reportPrivateUsage]
|
||||
assert workflow_executor._request_to_execution == {} # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_workflow_executor_reset_resets_wrapped_workflow() -> None:
|
||||
"""reset() recursively resets the wrapped workflow (runner iteration counter cleared)."""
|
||||
validation_workflow = create_email_validation_workflow()
|
||||
parent = Coordinator()
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder(start_executor=parent)
|
||||
.add_edge(parent, workflow_executor)
|
||||
.add_edge(workflow_executor, parent)
|
||||
.build()
|
||||
)
|
||||
|
||||
await main_workflow.run("test@example.com")
|
||||
# The sub-workflow's runner advanced past iteration 0 during execution.
|
||||
assert validation_workflow._runner._iteration > 0 # type: ignore[reportPrivateUsage]
|
||||
|
||||
await main_workflow.reset_for_new_run()
|
||||
|
||||
# The wrapped workflow's runner was reset along with the parent.
|
||||
assert validation_workflow._runner._iteration == 0 # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_workflow_executor_reset_allows_subsequent_run() -> None:
|
||||
"""After reset(), the parent + WorkflowExecutor can be reused for a fresh run with no leakage."""
|
||||
validation_workflow = create_email_validation_workflow()
|
||||
parent = Coordinator()
|
||||
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
|
||||
|
||||
main_workflow = (
|
||||
WorkflowBuilder(start_executor=parent)
|
||||
.add_edge(parent, workflow_executor)
|
||||
.add_edge(workflow_executor, parent)
|
||||
.build()
|
||||
)
|
||||
|
||||
first_result = await main_workflow.run("first@example.com")
|
||||
assert len(first_result.get_request_info_events()) == 1
|
||||
|
||||
await main_workflow.reset_for_new_run()
|
||||
|
||||
# State on the WorkflowExecutor and parent's pending-request bookkeeping is clean.
|
||||
assert workflow_executor._execution_contexts == {} # type: ignore[reportPrivateUsage]
|
||||
assert workflow_executor._request_to_execution == {} # type: ignore[reportPrivateUsage]
|
||||
|
||||
second_result = await main_workflow.run("second@example.com")
|
||||
second_requests = second_result.get_request_info_events()
|
||||
assert len(second_requests) == 1
|
||||
assert isinstance(second_requests[0].data, DomainCheckRequest)
|
||||
# Confirm the new run produced a request from the second email, not the cached first one.
|
||||
assert second_requests[0].data.email == "second@example.com"
|
||||
# And the WorkflowExecutor is now tracking exactly one fresh execution.
|
||||
assert len(workflow_executor._execution_contexts) == 1 # type: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
# endregion: Tests for WorkflowExecutor.reset()
|
||||
|
||||
@@ -28,6 +28,7 @@ from agent_framework import (
|
||||
WorkflowEvent,
|
||||
WorkflowException,
|
||||
WorkflowMessage,
|
||||
WorkflowRunnerException,
|
||||
WorkflowRunState,
|
||||
handler,
|
||||
response_handler,
|
||||
@@ -1266,3 +1267,144 @@ async def test_output_executors_filtering_with_run_responses_streaming() -> None
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
# region: Tests for Workflow.reset_for_new_run()
|
||||
|
||||
|
||||
async def test_workflow_reset_for_new_run_allows_subsequent_run() -> None:
|
||||
"""After reset_for_new_run() the same workflow instance can be run again from scratch."""
|
||||
executor_a = IncrementExecutor(id="executor_a")
|
||||
executor_b = IncrementExecutor(id="executor_b")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b])
|
||||
.add_edge(executor_a, executor_b)
|
||||
.add_edge(executor_b, executor_a)
|
||||
.build()
|
||||
)
|
||||
|
||||
first = await workflow.run(NumberMessage(data=0))
|
||||
assert first.get_outputs() == [10]
|
||||
|
||||
await workflow.reset_for_new_run()
|
||||
|
||||
second = await workflow.run(NumberMessage(data=0))
|
||||
assert second.get_outputs() == [10]
|
||||
|
||||
|
||||
async def test_workflow_reset_for_new_run_clears_workflow_state() -> None:
|
||||
"""reset_for_new_run() clears values that executors persisted in shared workflow state."""
|
||||
|
||||
class StateWritingExecutor(Executor):
|
||||
@handler
|
||||
async def handle(self, message: NumberMessage, ctx: WorkflowContext[Any, int]) -> None:
|
||||
previous = ctx.get_state("seen") or 0
|
||||
ctx.set_state("seen", previous + 1)
|
||||
await ctx.yield_output(previous + 1)
|
||||
|
||||
state_writer = StateWritingExecutor(id="state_writer")
|
||||
workflow = WorkflowBuilder(start_executor=state_writer, output_from=[state_writer]).build()
|
||||
|
||||
first = await workflow.run(NumberMessage(data=1))
|
||||
assert first.get_outputs() == [1]
|
||||
# State was persisted by the executor.
|
||||
assert workflow._runner.state.get("seen") == 1 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
await workflow.reset_for_new_run()
|
||||
|
||||
# The runner's shared state has been wiped.
|
||||
assert workflow._runner.state.get("seen") is None # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
second = await workflow.run(NumberMessage(data=1))
|
||||
# Counter started fresh from 0 again; output is 1, not 2.
|
||||
assert second.get_outputs() == [1]
|
||||
|
||||
|
||||
async def test_workflow_reset_for_new_run_invokes_executor_reset_hook() -> None:
|
||||
"""reset_for_new_run() calls Executor.reset() on every executor in the workflow."""
|
||||
|
||||
class ResettableExecutor(Executor):
|
||||
def __init__(self, id: str) -> None:
|
||||
super().__init__(id=id)
|
||||
self.reset_calls = 0
|
||||
self.handled = 0
|
||||
|
||||
@handler
|
||||
async def handle(self, message: NumberMessage, ctx: WorkflowContext[Any, int]) -> None:
|
||||
self.handled += 1
|
||||
await ctx.yield_output(self.handled)
|
||||
|
||||
async def reset(self) -> None:
|
||||
self.reset_calls += 1
|
||||
self.handled = 0
|
||||
|
||||
executor = ResettableExecutor(id="resettable")
|
||||
workflow = WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
|
||||
|
||||
await workflow.run(NumberMessage(data=1))
|
||||
assert executor.handled == 1
|
||||
assert executor.reset_calls == 0
|
||||
|
||||
await workflow.reset_for_new_run()
|
||||
|
||||
assert executor.reset_calls == 1
|
||||
# The executor's own counter was wiped by its overridden reset().
|
||||
assert executor.handled == 0
|
||||
|
||||
|
||||
async def test_workflow_reset_for_new_run_resets_runner_iteration_counter() -> None:
|
||||
"""reset_for_new_run() drops the iteration counter accumulated during a prior run."""
|
||||
executor_a = IncrementExecutor(id="executor_a")
|
||||
executor_b = IncrementExecutor(id="executor_b")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b])
|
||||
.add_edge(executor_a, executor_b)
|
||||
.add_edge(executor_b, executor_a)
|
||||
.build()
|
||||
)
|
||||
|
||||
await workflow.run(NumberMessage(data=0))
|
||||
assert workflow._runner._iteration > 0 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
await workflow.reset_for_new_run()
|
||||
|
||||
assert workflow._runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
async def test_workflow_reset_for_new_run_rejected_during_streaming_run() -> None:
|
||||
"""reset_for_new_run() raises WorkflowRunnerException while a streaming run is in progress."""
|
||||
executor_a = IncrementExecutor(id="executor_a")
|
||||
executor_b = IncrementExecutor(id="executor_b")
|
||||
|
||||
workflow = (
|
||||
WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b])
|
||||
.add_edge(executor_a, executor_b)
|
||||
.add_edge(executor_b, executor_a)
|
||||
.build()
|
||||
)
|
||||
|
||||
async def consume_stream_slowly() -> list[WorkflowEvent]:
|
||||
events: list[WorkflowEvent] = []
|
||||
async for event in workflow.run(NumberMessage(data=0), stream=True):
|
||||
events.append(event)
|
||||
await asyncio.sleep(0.01)
|
||||
return events
|
||||
|
||||
task = asyncio.create_task(consume_stream_slowly())
|
||||
# Let the streaming run start.
|
||||
await asyncio.sleep(0.02)
|
||||
|
||||
try:
|
||||
with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"):
|
||||
await workflow.reset_for_new_run()
|
||||
finally:
|
||||
await task
|
||||
|
||||
# After the run completes, reset succeeds again.
|
||||
await workflow.reset_for_new_run()
|
||||
assert workflow._runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
Reference in New Issue
Block a user