Add reset to workflow

This commit is contained in:
Tao Chen
2026-06-08 13:42:42 -07:00
Unverified
parent c5e6a7797f
commit 65522bdbee
19 changed files with 4986 additions and 4136 deletions
+6 -17
View File
@@ -1643,9 +1643,7 @@ class MCPTool:
return parser(fallback_result)
if task_id is None:
raise ToolExecutionException(
f"MCP server did not return a task_id or fallback result for '{tool_name}'."
)
raise ToolExecutionException(f"MCP server did not return a task_id or fallback result for '{tool_name}'.")
# Track to completion: poll until terminal, then fetch payload. Never re-issue
# tools/call past this point; reconnect-and-retry only against the same task_id.
@@ -1765,9 +1763,7 @@ class MCPTool:
transient_codes: frozenset[int] = frozenset({int(httpx.codes.REQUEST_TIMEOUT)})
while True:
request = types.ClientRequest(
types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id))
)
request = types.ClientRequest(types.GetTaskRequest(params=types.GetTaskRequestParams(taskId=task_id)))
try:
# GetTaskResult.ttl is required-but-Optional in the SDK; coerce below.
lenient = await self._send_with_one_reconnect(
@@ -1775,9 +1771,7 @@ class MCPTool:
)
except McpError as ex:
if ex.error.code in transient_codes:
logger.debug(
"Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id
)
logger.debug("Transient %s on tasks/get for '%s'; will retry.", ex.error.code, task_id)
await asyncio.sleep(_MCP_TASK_MIN_POLL_INTERVAL.total_seconds())
continue
# Hard server error mid-poll: task may still be running.
@@ -1906,9 +1900,7 @@ class MCPTool:
if not self._is_connection_lost(ex):
raise
if attempt < _MCP_RECONNECT_ATTEMPTS - 1:
logger.info(
"MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id
)
logger.info("MCP connection lost during %s; reconnecting (task_id=%s).", operation, task_id)
try:
await self.connect(reset=True)
except Exception as reconn_ex:
@@ -1967,9 +1959,7 @@ class MCPTool:
"""
from mcp import types
request = types.ClientRequest(
types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id))
)
request = types.ClientRequest(types.CancelTaskRequest(params=types.CancelTaskRequestParams(taskId=task_id)))
try:
await asyncio.wait_for(
self.session.send_request(request, types.CancelTaskResult), # type: ignore[union-attr]
@@ -1979,8 +1969,7 @@ class MCPTool:
raise
except asyncio.TimeoutError:
logger.warning(
"Best-effort tasks/cancel for '%s' timed out after %.1fs; "
"remote task may still be running.",
"Best-effort tasks/cancel for '%s' timed out after %.1fs; remote task may still be running.",
task_id,
_MCP_TASK_CANCEL_TIMEOUT.total_seconds(),
)
@@ -3516,9 +3516,7 @@ class MCPSkill(Skill):
result = await self._client.read_resource(_mcp_any_url(self._skill_md_uri))
text = _mcp_join_text(result)
if not text:
raise ValueError(
f"The MCP server returned no text content for SKILL.md resource '{self._skill_md_uri}'."
)
raise ValueError(f"The MCP server returned no text content for SKILL.md resource '{self._skill_md_uri}'.")
self._content = text
return text
@@ -3572,11 +3570,7 @@ class MCPSkill(Skill):
or ``None`` if the name is unsafe.
"""
normalized = name.replace("\\", "/")
if (
normalized.startswith("/")
or "://" in normalized
or any(seg == ".." for seg in normalized.split("/"))
):
if normalized.startswith("/") or "://" in normalized or any(seg == ".." for seg in normalized.split("/")):
logger.debug("Rejecting resource name with unsafe path components: %r", name)
return None
return normalized
@@ -166,6 +166,10 @@ class AgentExecutor(Executor):
raise ValueError("Agent must have a non-empty name or id or an explicit id must be provided.")
super().__init__(exec_id)
self._agent = agent
# Track whether the caller supplied a session so reset() can preserve their
# session reference (which may be wired to external/service-side storage)
# and only replace sessions we created ourselves.
self._session_supplied_by_caller = session is not None
self._session = session or self._agent.create_session()
self._pending_agent_requests: dict[str, Content] = {}
@@ -365,10 +369,37 @@ class AgentExecutor(Executor):
pending_responses_payload = state.get("pending_responses_to_agent")
self._pending_responses_to_agent = pending_responses_payload or []
def reset(self) -> None:
"""Reset the internal cache of the executor."""
logger.debug("AgentExecutor %s: Resetting cache", self.id)
@override
async def reset(self) -> None:
"""Reset the executor to its initial state for a new workflow run.
Clears the message cache, full conversation snapshot, and any pending
user-input request/response bookkeeping.
Session handling:
* If the session was created by this executor (no ``session`` argument
was passed to ``__init__``), it is replaced with a fresh one via
``agent.create_session()`` so prior conversation history does not
leak into the next run.
* If the session was supplied by the caller, it is left untouched.
The caller owns the session lifecycle (it may be backed by
service-side or external storage) and is responsible for clearing
or rotating it if a clean slate is desired.
"""
logger.debug("AgentExecutor %s: Resetting state", self.id)
self._cache.clear()
self._full_conversation.clear()
self._pending_agent_requests.clear()
self._pending_responses_to_agent.clear()
if not self._session_supplied_by_caller:
self._session = self._agent.create_session()
else:
logger.warning(
"AgentExecutor %s: Session was supplied by the caller and will not be reset. "
"Prior conversation history retained in the session may leak into the next run. "
"Reset or rotate the session externally if a clean slate is required.",
self.id,
)
async def _run_agent_and_emit(
self,
@@ -516,6 +516,14 @@ class Executor(RequestInfoMixin, DictConvertible):
"""
...
async def reset(self) -> None:
"""Reset the executor to its initial state.
Override this method in subclasses to implement custom logic that should
run when the workflow is reset.
"""
...
# endregion: Executor
@@ -88,9 +88,14 @@ class Runner:
@property
def context(self) -> RunnerContext:
"""Get the workflow context."""
"""Get the runner context for message, event, and checkpoint handling."""
return self._ctx
@property
def state(self) -> State:
"""Get the shared state for the workflow."""
return self._state
def reserve(self) -> None:
"""Synchronously reserve the runner for an upcoming run.
@@ -117,9 +122,38 @@ class Runner:
self._lifecycle = _RunnerLifecycle.IDLE
def reset_iteration_count(self) -> None:
"""Reset the iteration count to zero."""
"""Reset the iteration count to zero.
This is useful when the workflow resumes from a new set of messages.
Note:
When a workflow is resumed from a response (for a request_info_event)
or a checkpoint, the iteration count is normally NOT reset.
"""
self._iteration = 0
async def reset_for_new_run(self) -> None:
"""Reset the runner for a new run.
This is useful when reusing the same workflow instance for a different run
that is independent from prior runs.
Raises:
WorkflowRunnerException: If the runner is reserved or running. Reset is only
allowed when the runner is idle to avoid clobbering in-flight run state.
"""
if self._lifecycle is not _RunnerLifecycle.IDLE:
raise WorkflowRunnerException(
"Cannot reset the runner while a run is in progress. "
"Wait for the current run to complete before calling reset_for_new_run()."
)
self.reset_iteration_count()
self._ctx.reset_for_new_run()
self._state.clear()
self._resumed_from_checkpoint = False
for executor in self._executors.values():
await executor.reset()
async def run_until_convergence(self) -> AsyncGenerator[WorkflowEvent, None]:
"""Run the workflow until no more messages are sent."""
# Mandatory reservation: callers must reserve() the runner first. This makes
@@ -294,7 +328,7 @@ class Runner:
checkpoint_id: CheckpointID,
checkpoint_storage: CheckpointStorage | None = None,
) -> None:
"""Restore workflow state from a checkpoint.
"""Restore the runner from a checkpoint.
Args:
checkpoint_id: The ID of the checkpoint to restore from
@@ -347,11 +347,10 @@ class Workflow(DictConvertible):
# Store non-serializable runtime objects as private attributes
self._runner_context = runner_context
self._runner_context.set_yield_output_classifier(self._output_designation.classify)
self._state = State()
self._runner: Runner = Runner(
self.edge_groups,
self.executors,
self._state,
State(),
runner_context,
self.name,
self.graph_signature_hash,
@@ -552,14 +551,14 @@ class Workflow(DictConvertible):
combined_kwargs["client_kwargs"] = self._resolve_invocation_kwargs(
client_kwargs, "client_kwargs"
)
self._state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs)
self._runner.state.set(WORKFLOW_RUN_KWARGS_KEY, combined_kwargs)
elif not is_continuation:
self._state.set(WORKFLOW_RUN_KWARGS_KEY, {})
self._state.commit() # Commit immediately so kwargs are available
self._runner.state.set(WORKFLOW_RUN_KWARGS_KEY, {})
self._runner.state.commit() # Commit immediately so kwargs are available
# Set streaming mode (always set explicitly per run since
# reset_for_new_run() no longer runs to clear it).
self._runner_context.set_streaming(streaming)
self._runner.context.set_streaming(streaming)
# Execute initial setup if provided
if initial_executor_fn:
@@ -653,7 +652,7 @@ class Workflow(DictConvertible):
await executor.execute(
message,
[self.__class__.__name__],
self._state,
self._runner.state,
self._runner.context,
trace_contexts=None,
source_span_ids=None,
@@ -934,7 +933,7 @@ class Workflow(DictConvertible):
async def _send_responses_internal(self, responses: Mapping[str, Any]) -> None:
"""Internal method to validate and send responses to the executors."""
pending_requests = await self._runner_context.get_pending_request_info_events()
pending_requests = await self._runner.context.get_pending_request_info_events()
if not pending_requests:
raise RuntimeError("No pending requests found in workflow context.")
@@ -954,7 +953,7 @@ class Workflow(DictConvertible):
coerced_responses[request_id] = response
await asyncio.gather(*[
self._runner_context.send_request_info_response(request_id, response)
self._runner.context.send_request_info_response(request_id, response)
for request_id, response in coerced_responses.items()
])
@@ -1150,3 +1149,17 @@ class Workflow(DictConvertible):
context_providers=context_providers,
**kwargs,
)
async def reset_for_new_run(self) -> None:
"""Reset the workflow for a new run that is independent from prior runs.
Note:
This will reset EVERYTHING - executor states, workflow state, and runner
context (including pending requests/messages).
Raises:
WorkflowRunnerException: If a run is currently in progress. Reset is only
allowed when the workflow is idle to avoid clobbering in-flight run state.
"""
await self._runner.reset_for_new_run()
@@ -534,6 +534,19 @@ class WorkflowExecutor(Executor):
for event in request_info_events
])
@override
async def reset(self) -> None:
"""Reset the WorkflowExecutor to its initial state for a new workflow run.
Clears in-flight execution contexts and the request-to-execution mapping,
then recursively resets the wrapped sub-workflow so its executors, runner
context, and shared state are also returned to a clean state.
"""
logger.debug("WorkflowExecutor %s: Resetting state", self.id)
self._execution_contexts.clear()
self._request_to_execution.clear()
await self.workflow.reset_for_new_run()
async def _process_workflow_result(
self,
result: WorkflowRunResult,
@@ -281,9 +281,7 @@ async def test_mcp_prompts_get_creates_client_span(span_exporter: InMemorySpanEx
async def test_mcp_prompts_get_mcp_error_sets_error_type(span_exporter: InMemorySpanExporter):
"""When session.get_prompt() raises McpError, the span should have error.type and ERROR status."""
tool = _make_connected_mcp_tool()
tool.session.get_prompt = AsyncMock(
side_effect=McpError(ErrorData(code=-32602, message="prompt not found"))
)
tool.session.get_prompt = AsyncMock(side_effect=McpError(ErrorData(code=-32602, message="prompt not found")))
span_exporter.clear()
with pytest.raises(ToolExecutionException):
@@ -35,26 +35,22 @@ description: Convert between common units.
Body content here.
"""
SAMPLE_SKILL_INDEX = json.dumps(
{
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
"skills": [
{
"name": "unit-converter",
"type": "skill-md",
"description": "Convert between common units.",
"url": "skill://unit-converter/SKILL.md",
}
],
}
)
SAMPLE_SKILL_INDEX = json.dumps({
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
"skills": [
{
"name": "unit-converter",
"type": "skill-md",
"description": "Convert between common units.",
"url": "skill://unit-converter/SKILL.md",
}
],
})
def _make_text_result(text: str, uri: str = "skill://test") -> ReadResourceResult:
"""Create a ReadResourceResult with a single TextResourceContents."""
return ReadResourceResult(
contents=[TextResourceContents(uri=AnyUrl(uri), text=text, mimeType="text/markdown")]
)
return ReadResourceResult(contents=[TextResourceContents(uri=AnyUrl(uri), text=text, mimeType="text/markdown")])
def _make_blob_result(
@@ -230,12 +226,10 @@ class TestMCPSkill:
@pytest.mark.asyncio
async def test_get_resource_text(self) -> None:
client = _make_client(
**{
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
"skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"),
}
)
client = _make_client(**{
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
"skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"),
})
from agent_framework import SkillFrontmatter
fm = SkillFrontmatter(name="unit-converter", description="Convert between common units.")
@@ -249,12 +243,10 @@ class TestMCPSkill:
@pytest.mark.asyncio
async def test_get_resource_binary(self) -> None:
data = bytes([0x01, 0x02, 0x03, 0x04])
client = _make_client(
**{
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
"skill://unit-converter/assets/icon.bin": _make_blob_result(data),
}
)
client = _make_client(**{
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
"skill://unit-converter/assets/icon.bin": _make_blob_result(data),
})
from agent_framework import SkillFrontmatter
fm = SkillFrontmatter(name="unit-converter", description="Convert between common units.")
@@ -345,12 +337,10 @@ class TestMCPSkillsSource:
@pytest.mark.asyncio
async def test_index_based_discovery_returns_skill(self) -> None:
client = _make_client(
**{
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
}
)
client = _make_client(**{
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
})
source = MCPSkillsSource(client=client)
skills = await source.get_skills()
@@ -373,9 +363,7 @@ class TestMCPSkillsSource:
async def test_does_not_read_skill_md_during_discovery(self) -> None:
# Index points to a skill, but SKILL.md is not registered on the server.
# Discovery should succeed because it only reads the index.
client = _make_client(
**{"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json")}
)
client = _make_client(**{"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json")})
source = MCPSkillsSource(client=client)
skills = await source.get_skills()
@@ -384,19 +372,17 @@ class TestMCPSkillsSource:
@pytest.mark.asyncio
async def test_invalid_name_is_skipped(self) -> None:
index_json = json.dumps(
{
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
"skills": [
{
"name": "UnitConverter", # Invalid: uppercase
"type": "skill-md",
"description": "Convert between common units.",
"url": "skill://UnitConverter/SKILL.md",
}
],
}
)
index_json = json.dumps({
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
"skills": [
{
"name": "UnitConverter", # Invalid: uppercase
"type": "skill-md",
"description": "Convert between common units.",
"url": "skill://UnitConverter/SKILL.md",
}
],
})
client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")})
source = MCPSkillsSource(client=client)
skills = await source.get_skills()
@@ -404,18 +390,16 @@ class TestMCPSkillsSource:
@pytest.mark.asyncio
async def test_missing_required_fields_is_skipped(self) -> None:
index_json = json.dumps(
{
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
"skills": [
{
"name": "unit-converter",
"type": "skill-md",
# Missing description and url
}
],
}
)
index_json = json.dumps({
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
"skills": [
{
"name": "unit-converter",
"type": "skill-md",
# Missing description and url
}
],
})
client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")})
source = MCPSkillsSource(client=client)
skills = await source.get_skills()
@@ -423,19 +407,17 @@ class TestMCPSkillsSource:
@pytest.mark.asyncio
async def test_unsupported_type_is_skipped(self) -> None:
index_json = json.dumps(
{
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
"skills": [
{
"name": "some-skill",
"type": "archive",
"description": "Packaged skill.",
"url": "skill://some-skill.tar.gz",
}
],
}
)
index_json = json.dumps({
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
"skills": [
{
"name": "some-skill",
"type": "archive",
"description": "Packaged skill.",
"url": "skill://some-skill.tar.gz",
}
],
})
client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")})
source = MCPSkillsSource(client=client)
skills = await source.get_skills()
@@ -443,18 +425,16 @@ class TestMCPSkillsSource:
@pytest.mark.asyncio
async def test_template_type_is_skipped(self) -> None:
index_json = json.dumps(
{
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
"skills": [
{
"type": "mcp-resource-template",
"description": "Per-product documentation skill",
"url": "skill://docs/{product}/SKILL.md",
}
],
}
)
index_json = json.dumps({
"$schema": "https://schemas.agentskills.io/discovery/0.2.0/schema.json",
"skills": [
{
"type": "mcp-resource-template",
"description": "Per-product documentation skill",
"url": "skill://docs/{product}/SKILL.md",
}
],
})
client = _make_client(**{"skill://index.json": _make_text_result(index_json, uri="skill://index.json")})
source = MCPSkillsSource(client=client)
skills = await source.get_skills()
@@ -462,31 +442,25 @@ class TestMCPSkillsSource:
@pytest.mark.asyncio
async def test_empty_index_returns_empty(self) -> None:
client = _make_client(
**{"skill://index.json": _make_text_result('{"skills": []}', uri="skill://index.json")}
)
client = _make_client(**{"skill://index.json": _make_text_result('{"skills": []}', uri="skill://index.json")})
source = MCPSkillsSource(client=client)
skills = await source.get_skills()
assert skills == []
@pytest.mark.asyncio
async def test_malformed_index_json_returns_empty(self) -> None:
client = _make_client(
**{"skill://index.json": _make_text_result("not valid json", uri="skill://index.json")}
)
client = _make_client(**{"skill://index.json": _make_text_result("not valid json", uri="skill://index.json")})
source = MCPSkillsSource(client=client)
skills = await source.get_skills()
assert skills == []
@pytest.mark.asyncio
async def test_sibling_text_resource(self) -> None:
client = _make_client(
**{
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
"skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"),
}
)
client = _make_client(**{
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
"skill://unit-converter/references/checklist.md": _make_text_result("- check thing 1\n- check thing 2"),
})
source = MCPSkillsSource(client=client)
skill = (await source.get_skills())[0]
resource = await skill.get_resource("references/checklist.md")
@@ -497,13 +471,11 @@ class TestMCPSkillsSource:
@pytest.mark.asyncio
async def test_sibling_binary_resource(self) -> None:
data = bytes([0x01, 0x02, 0x03, 0x04])
client = _make_client(
**{
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
"skill://unit-converter/assets/icon.bin": _make_blob_result(data),
}
)
client = _make_client(**{
"skill://index.json": _make_text_result(SAMPLE_SKILL_INDEX, uri="skill://index.json"),
"skill://unit-converter/SKILL.md": _make_text_result(SAMPLE_SKILL_MD),
"skill://unit-converter/assets/icon.bin": _make_blob_result(data),
})
source = MCPSkillsSource(client=client)
skill = (await source.get_skills())[0]
resource = await skill.get_resource("assets/icon.bin")
@@ -649,9 +621,7 @@ class TestMCPSkillsSourceErrorCodeBranching:
from agent_framework import SkillFrontmatter
client = AsyncMock()
client.read_resource = AsyncMock(
side_effect=McpError(error=ErrorData(code=0, message="Handler error"))
)
client.read_resource = AsyncMock(side_effect=McpError(error=ErrorData(code=0, message="Handler error")))
fm = SkillFrontmatter(name="test-skill", description="Test.")
skill = MCPSkill(frontmatter=fm, skill_md_uri="skill://test/SKILL.md", client=client)
with pytest.raises(McpError):
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
from collections.abc import AsyncIterable, Awaitable
from typing import Any, Literal, overload
@@ -19,7 +20,7 @@ from agent_framework import (
WorkflowEvent,
WorkflowRunState,
)
from agent_framework._workflows._agent_executor import AgentExecutorResponse
from agent_framework._workflows._agent_executor import AgentExecutorRequest, AgentExecutorResponse
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
from agent_framework._workflows._const import GLOBAL_KWARGS_KEY
@@ -306,6 +307,108 @@ async def test_agent_executor_save_and_restore_state_directly() -> None:
assert restored_session.session_id == session.session_id
# region: Tests for AgentExecutor.reset()
async def test_agent_executor_reset_clears_per_run_state() -> None:
"""reset() clears cache, conversation snapshot, and pending request/response buffers."""
agent = _CountingAgent(id="reset_agent", name="ResetAgent")
executor = AgentExecutor(agent, id="reset_exec")
# Populate every per-run buffer.
executor._cache = [Message(role="user", contents=["cached"])] # type: ignore[reportPrivateUsage]
executor._full_conversation = [ # type: ignore[reportPrivateUsage]
Message(role="user", contents=["prior turn"]),
Message(role="assistant", contents=["prior response"]),
]
pending_request = Content.from_text(text="approve?")
executor._pending_agent_requests = {"req-1": pending_request} # type: ignore[reportPrivateUsage]
executor._pending_responses_to_agent = [Content.from_text(text="approved")] # type: ignore[reportPrivateUsage]
await executor.reset()
assert executor._cache == [] # type: ignore[reportPrivateUsage]
assert executor._full_conversation == [] # type: ignore[reportPrivateUsage]
assert executor._pending_agent_requests == {} # type: ignore[reportPrivateUsage]
assert executor._pending_responses_to_agent == [] # type: ignore[reportPrivateUsage]
async def test_agent_executor_reset_creates_fresh_session_when_auto_created() -> None:
"""reset() replaces the agent session when the executor created it itself."""
agent = _CountingAgent(id="reset_session_agent", name="ResetSessionAgent")
# No session passed in — executor creates one via agent.create_session().
executor = AgentExecutor(agent, id="reset_session_exec")
auto_created = executor._session # type: ignore[reportPrivateUsage]
auto_created.state["history"] = {"messages": [Message(role="user", contents=["old"])]}
await executor.reset()
new_session = executor._session # type: ignore[reportPrivateUsage]
assert new_session is not auto_created
assert new_session.session_id != auto_created.session_id
assert "history" not in new_session.state
async def test_agent_executor_reset_preserves_caller_supplied_session(caplog: pytest.LogCaptureFixture) -> None:
"""reset() leaves a session passed in via __init__ untouched and warns the caller."""
agent = _CountingAgent(id="reset_session_agent", name="ResetSessionAgent")
caller_session = AgentSession()
history_payload = {"messages": [Message(role="user", contents=["old"])]}
caller_session.state["history"] = history_payload
executor = AgentExecutor(agent, id="reset_session_exec", session=caller_session)
assert executor._session is caller_session # type: ignore[reportPrivateUsage]
with caplog.at_level(logging.WARNING, logger="agent_framework._workflows._agent_executor"):
await executor.reset()
# Same instance, state untouched — the caller is responsible for managing the session.
assert executor._session is caller_session # type: ignore[reportPrivateUsage]
assert caller_session.state["history"] is history_payload
assert any("Session was supplied by the caller" in record.message for record in caplog.records)
async def test_agent_executor_reset_allows_subsequent_run() -> None:
"""After reset(), the executor can be reused for a fresh workflow run without leaking state."""
agent = _CountingAgent(id="reset_reuse_agent", name="ResetReuseAgent")
executor = AgentExecutor(agent, id="reset_reuse_exec")
workflow = WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
first_outputs: list[WorkflowEvent] = []
async for event in workflow.run(
AgentExecutorRequest(messages=[Message(role="user", contents=["hello"])]),
stream=True,
):
if event.type == "output":
first_outputs.append(event)
assert first_outputs, "first run should have produced at least one output event"
# After a normal run the cache is drained but the conversation snapshot remains.
assert executor._cache == [] # type: ignore[reportPrivateUsage]
assert executor._full_conversation != [] # type: ignore[reportPrivateUsage]
first_session_id = executor._session.session_id # type: ignore[reportPrivateUsage]
await workflow.reset_for_new_run()
assert executor._full_conversation == [] # type: ignore[reportPrivateUsage]
# Session was auto-created, so reset() rotates it to a fresh one.
assert executor._session.session_id != first_session_id # type: ignore[reportPrivateUsage]
second_outputs: list[WorkflowEvent] = []
async for event in workflow.run(
AgentExecutorRequest(messages=[Message(role="user", contents=["second"])]),
stream=True,
):
if event.type == "output":
second_outputs.append(event)
assert second_outputs, "second run after reset should have produced at least one output event"
assert agent.call_count == 2
# endregion: Tests for AgentExecutor.reset()
async def test_prepare_agent_run_args_extracts_invocation_kwargs() -> None:
"""_prepare_agent_run_args extracts function_invocation_kwargs and client_kwargs."""
agent = _CountingAgent(id="test_agent", name="TestAgent")
@@ -1004,3 +1004,53 @@ def test_handler_typevar_error_takes_priority_over_context_error():
@handler
async def process(self, message: _T, ctx) -> None: # type: ignore[no-untyped-def]
pass
# region: Tests for Executor.reset()
async def test_executor_default_reset_is_noop():
"""The base Executor.reset() is a no-op and must complete without raising.
Subclasses that don't carry reset-relevant state should be able to rely on the
default implementation.
"""
class StatelessExecutor(Executor):
@handler
async def handle(self, message: str, ctx: WorkflowContext[str]) -> None:
await ctx.send_message(message)
executor_instance = StatelessExecutor(id="stateless")
# Must complete without raising and return None.
assert await executor_instance.reset() is None
async def test_executor_subclass_reset_is_invoked():
"""A subclass that overrides reset() can clear its own internal state."""
class CounterExecutor(Executor):
def __init__(self, id: str) -> None:
super().__init__(id=id)
self.counter = 0
self.reset_calls = 0
@handler
async def handle(self, message: int, ctx: WorkflowContext[int]) -> None:
self.counter += message
async def reset(self) -> None:
self.counter = 0
self.reset_calls += 1
executor_instance = CounterExecutor(id="counter")
executor_instance.counter = 42
await executor_instance.reset()
assert executor_instance.counter == 0
assert executor_instance.reset_calls == 1
# endregion: Tests for Executor.reset()
@@ -1112,3 +1112,255 @@ async def test_runner_drains_straggler_events_at_iteration_end():
output_events = [e for e in events if e.type == "output"]
# We should have output events from both executors
assert len(output_events) >= 2
# region: Tests for InProcRunnerContext.reset_for_new_run()
async def test_runner_context_reset_clears_in_flight_messages():
"""reset_for_new_run drops queued executor-to-executor messages."""
ctx = InProcRunnerContext()
await ctx.send_message(WorkflowMessage(data=MockMessage(data=1), source_id="src"))
assert await ctx.has_messages() is True
ctx.reset_for_new_run()
assert await ctx.has_messages() is False
assert await ctx.drain_messages() == {}
async def test_runner_context_reset_drains_pending_events():
"""reset_for_new_run discards any events buffered for streaming."""
ctx = InProcRunnerContext()
await ctx.add_event(WorkflowEvent.superstep_started(iteration=1))
assert await ctx.has_events() is True
ctx.reset_for_new_run()
assert await ctx.has_events() is False
assert await ctx.drain_events() == []
async def test_runner_context_reset_resets_streaming_flag():
"""reset_for_new_run resets streaming back to its non-streaming default."""
ctx = InProcRunnerContext()
ctx.set_streaming(True)
assert ctx.is_streaming() is True
ctx.reset_for_new_run()
assert ctx.is_streaming() is False
# endregion: Tests for InProcRunnerContext.reset_for_new_run()
# region: Tests for Runner.reset_for_new_run()
async def test_runner_reset_for_new_run_resets_iteration_count():
"""reset_for_new_run resets the iteration counter back to zero."""
runner = _make_runner()
runner._iteration = 7 # pyright: ignore[reportPrivateUsage]
await runner.reset_for_new_run()
assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
async def test_runner_reset_for_new_run_clears_shared_state():
"""reset_for_new_run wipes both committed and pending entries from shared state."""
state = State()
state.set("committed_key", "committed_value")
state.commit()
state.set("pending_key", "pending_value") # uncommitted
runner = Runner(
[],
{},
state,
InProcRunnerContext(),
"test_name",
graph_signature_hash="test_hash",
)
await runner.reset_for_new_run()
assert state.get("committed_key") is None
assert state.get("pending_key") is None
assert state.has("committed_key") is False
assert state.has("pending_key") is False
async def test_runner_reset_for_new_run_clears_resumed_from_checkpoint_flag():
"""reset_for_new_run clears the flag set by restore_from_checkpoint."""
runner = _make_runner()
runner._mark_resumed(iteration=5) # pyright: ignore[reportPrivateUsage]
assert runner._resumed_from_checkpoint is True # pyright: ignore[reportPrivateUsage]
await runner.reset_for_new_run()
assert runner._resumed_from_checkpoint is False # pyright: ignore[reportPrivateUsage]
# And the iteration count restored from the checkpoint must be wiped, too.
assert runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
async def test_runner_reset_for_new_run_invokes_executor_reset_for_each_executor():
"""reset_for_new_run calls reset() on every registered executor exactly once."""
class TrackingExecutor(MockExecutor):
def __init__(self, id: str) -> None:
super().__init__(id=id)
self.reset_calls = 0
async def reset(self) -> None:
self.reset_calls += 1
executor_a = TrackingExecutor(id="executor_a")
executor_b = TrackingExecutor(id="executor_b")
runner = Runner(
[],
{executor_a.id: executor_a, executor_b.id: executor_b},
State(),
InProcRunnerContext(),
"test_name",
graph_signature_hash="test_hash",
)
await runner.reset_for_new_run()
assert executor_a.reset_calls == 1
assert executor_b.reset_calls == 1
async def test_runner_reset_for_new_run_resets_runner_context():
"""reset_for_new_run forwards the reset to the underlying runner context."""
ctx = InProcRunnerContext()
await ctx.send_message(WorkflowMessage(data=MockMessage(data=0), source_id="src"))
await ctx.add_event(WorkflowEvent.superstep_started(iteration=1))
ctx.set_streaming(True)
runner = Runner([], {}, State(), ctx, "test_name", graph_signature_hash="test_hash")
await runner.reset_for_new_run()
assert await ctx.has_messages() is False
assert await ctx.has_events() is False
assert ctx.is_streaming() is False
async def test_runner_can_run_again_after_reset_for_new_run():
"""After reset_for_new_run the runner can be reserved and converge a fresh workload."""
executor_a = MockExecutor(id="executor_a")
executor_b = MockExecutor(id="executor_b")
edges = [
SingleEdgeGroup(executor_a.id, executor_b.id),
SingleEdgeGroup(executor_b.id, executor_a.id),
]
executors: dict[str, Executor] = {
executor_a.id: executor_a,
executor_b.id: executor_b,
}
state = State()
ctx = InProcRunnerContext()
runner = Runner(edges, executors, state, ctx, "test_name", graph_signature_hash="test_hash")
# First run: drives MockExecutor's loop until it yields the terminal value.
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
runner.reserve()
async for _ in runner.run_until_convergence():
pass
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
await runner.reset_for_new_run()
# Second run: must succeed cleanly using the same runner instance.
await executor_a.execute(MockMessage(data=0), ["START"], state, ctx)
runner.reserve()
second_run_outputs: list[int] = []
async for event in runner.run_until_convergence():
if event.type == "output":
second_run_outputs.append(event.data)
assert second_run_outputs == [10]
assert runner._iteration == 10 # pyright: ignore[reportPrivateUsage]
async def test_runner_reset_for_new_run_rejected_when_reserved():
"""reset_for_new_run refuses to run when the runner is reserved but not yet running."""
runner = _make_runner()
runner.reserve()
with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"):
await runner.reset_for_new_run()
async def test_runner_reset_for_new_run_rejected_while_running():
"""reset_for_new_run refuses to run while a run is mid-execution."""
runner = _make_runner()
started = asyncio.Event()
release = asyncio.Event()
async def _slow_run() -> None:
runner.reserve()
async for _ in runner.run_until_convergence():
if not started.is_set():
started.set()
await release.wait()
task = asyncio.create_task(_slow_run())
await started.wait() # first run is now executing inside run_until_convergence
try:
with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"):
await runner.reset_for_new_run()
finally:
release.set()
await task
# Once the run drained, reset must succeed again.
await runner.reset_for_new_run()
async def test_runner_reset_for_new_run_does_not_mutate_when_rejected():
"""When reset is rejected, the runner's iteration counter and state are untouched."""
class TrackingExecutor(MockExecutor):
def __init__(self, id: str) -> None:
super().__init__(id=id)
self.reset_calls = 0
async def reset(self) -> None:
self.reset_calls += 1
executor = TrackingExecutor(id="executor")
state = State()
state.set("preserved", 42)
state.commit()
runner = Runner(
[],
{executor.id: executor},
state,
InProcRunnerContext(),
"test_name",
graph_signature_hash="test_hash",
)
runner._iteration = 7 # pyright: ignore[reportPrivateUsage]
runner.reserve()
with pytest.raises(WorkflowRunnerException):
await runner.reset_for_new_run()
# Nothing was mutated by the failed reset.
assert runner._iteration == 7 # pyright: ignore[reportPrivateUsage]
assert state.get("preserved") == 42
assert executor.reset_calls == 0
# endregion: Tests for Runner.reset_for_new_run()
@@ -689,3 +689,89 @@ async def test_sub_workflow_intermediate_outputs_propagate_to_parent() -> None:
# The parent's own terminal output is unaffected.
assert any(e.executor_id == "parent_sink" and e.data == "final: hello" for e in output_events)
# region: Tests for WorkflowExecutor.reset()
async def test_workflow_executor_reset_clears_execution_state() -> None:
"""reset() clears the WorkflowExecutor's per-run execution contexts and request mappings."""
validation_workflow = create_email_validation_workflow()
parent = Coordinator()
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
main_workflow = (
WorkflowBuilder(start_executor=parent)
.add_edge(parent, workflow_executor)
.add_edge(workflow_executor, parent)
.build()
)
# First run pauses with a pending request from the sub-workflow.
result = await main_workflow.run("test@example.com")
assert len(result.get_request_info_events()) == 1
assert len(workflow_executor._execution_contexts) == 1 # type: ignore[reportPrivateUsage]
assert len(workflow_executor._request_to_execution) == 1 # type: ignore[reportPrivateUsage]
await main_workflow.reset_for_new_run()
assert workflow_executor._execution_contexts == {} # type: ignore[reportPrivateUsage]
assert workflow_executor._request_to_execution == {} # type: ignore[reportPrivateUsage]
async def test_workflow_executor_reset_resets_wrapped_workflow() -> None:
"""reset() recursively resets the wrapped workflow (runner iteration counter cleared)."""
validation_workflow = create_email_validation_workflow()
parent = Coordinator()
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
main_workflow = (
WorkflowBuilder(start_executor=parent)
.add_edge(parent, workflow_executor)
.add_edge(workflow_executor, parent)
.build()
)
await main_workflow.run("test@example.com")
# The sub-workflow's runner advanced past iteration 0 during execution.
assert validation_workflow._runner._iteration > 0 # type: ignore[reportPrivateUsage]
await main_workflow.reset_for_new_run()
# The wrapped workflow's runner was reset along with the parent.
assert validation_workflow._runner._iteration == 0 # type: ignore[reportPrivateUsage]
async def test_workflow_executor_reset_allows_subsequent_run() -> None:
"""After reset(), the parent + WorkflowExecutor can be reused for a fresh run with no leakage."""
validation_workflow = create_email_validation_workflow()
parent = Coordinator()
workflow_executor = WorkflowExecutor(validation_workflow, "email_validation_workflow")
main_workflow = (
WorkflowBuilder(start_executor=parent)
.add_edge(parent, workflow_executor)
.add_edge(workflow_executor, parent)
.build()
)
first_result = await main_workflow.run("first@example.com")
assert len(first_result.get_request_info_events()) == 1
await main_workflow.reset_for_new_run()
# State on the WorkflowExecutor and parent's pending-request bookkeeping is clean.
assert workflow_executor._execution_contexts == {} # type: ignore[reportPrivateUsage]
assert workflow_executor._request_to_execution == {} # type: ignore[reportPrivateUsage]
second_result = await main_workflow.run("second@example.com")
second_requests = second_result.get_request_info_events()
assert len(second_requests) == 1
assert isinstance(second_requests[0].data, DomainCheckRequest)
# Confirm the new run produced a request from the second email, not the cached first one.
assert second_requests[0].data.email == "second@example.com"
# And the WorkflowExecutor is now tracking exactly one fresh execution.
assert len(workflow_executor._execution_contexts) == 1 # type: ignore[reportPrivateUsage]
# endregion: Tests for WorkflowExecutor.reset()
@@ -28,6 +28,7 @@ from agent_framework import (
WorkflowEvent,
WorkflowException,
WorkflowMessage,
WorkflowRunnerException,
WorkflowRunState,
handler,
response_handler,
@@ -1266,3 +1267,144 @@ async def test_output_executors_filtering_with_run_responses_streaming() -> None
# endregion
# region: Tests for Workflow.reset_for_new_run()
async def test_workflow_reset_for_new_run_allows_subsequent_run() -> None:
"""After reset_for_new_run() the same workflow instance can be run again from scratch."""
executor_a = IncrementExecutor(id="executor_a")
executor_b = IncrementExecutor(id="executor_b")
workflow = (
WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b])
.add_edge(executor_a, executor_b)
.add_edge(executor_b, executor_a)
.build()
)
first = await workflow.run(NumberMessage(data=0))
assert first.get_outputs() == [10]
await workflow.reset_for_new_run()
second = await workflow.run(NumberMessage(data=0))
assert second.get_outputs() == [10]
async def test_workflow_reset_for_new_run_clears_workflow_state() -> None:
"""reset_for_new_run() clears values that executors persisted in shared workflow state."""
class StateWritingExecutor(Executor):
@handler
async def handle(self, message: NumberMessage, ctx: WorkflowContext[Any, int]) -> None:
previous = ctx.get_state("seen") or 0
ctx.set_state("seen", previous + 1)
await ctx.yield_output(previous + 1)
state_writer = StateWritingExecutor(id="state_writer")
workflow = WorkflowBuilder(start_executor=state_writer, output_from=[state_writer]).build()
first = await workflow.run(NumberMessage(data=1))
assert first.get_outputs() == [1]
# State was persisted by the executor.
assert workflow._runner.state.get("seen") == 1 # pyright: ignore[reportPrivateUsage]
await workflow.reset_for_new_run()
# The runner's shared state has been wiped.
assert workflow._runner.state.get("seen") is None # pyright: ignore[reportPrivateUsage]
second = await workflow.run(NumberMessage(data=1))
# Counter started fresh from 0 again; output is 1, not 2.
assert second.get_outputs() == [1]
async def test_workflow_reset_for_new_run_invokes_executor_reset_hook() -> None:
"""reset_for_new_run() calls Executor.reset() on every executor in the workflow."""
class ResettableExecutor(Executor):
def __init__(self, id: str) -> None:
super().__init__(id=id)
self.reset_calls = 0
self.handled = 0
@handler
async def handle(self, message: NumberMessage, ctx: WorkflowContext[Any, int]) -> None:
self.handled += 1
await ctx.yield_output(self.handled)
async def reset(self) -> None:
self.reset_calls += 1
self.handled = 0
executor = ResettableExecutor(id="resettable")
workflow = WorkflowBuilder(start_executor=executor, output_from=[executor]).build()
await workflow.run(NumberMessage(data=1))
assert executor.handled == 1
assert executor.reset_calls == 0
await workflow.reset_for_new_run()
assert executor.reset_calls == 1
# The executor's own counter was wiped by its overridden reset().
assert executor.handled == 0
async def test_workflow_reset_for_new_run_resets_runner_iteration_counter() -> None:
"""reset_for_new_run() drops the iteration counter accumulated during a prior run."""
executor_a = IncrementExecutor(id="executor_a")
executor_b = IncrementExecutor(id="executor_b")
workflow = (
WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b])
.add_edge(executor_a, executor_b)
.add_edge(executor_b, executor_a)
.build()
)
await workflow.run(NumberMessage(data=0))
assert workflow._runner._iteration > 0 # pyright: ignore[reportPrivateUsage]
await workflow.reset_for_new_run()
assert workflow._runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
async def test_workflow_reset_for_new_run_rejected_during_streaming_run() -> None:
"""reset_for_new_run() raises WorkflowRunnerException while a streaming run is in progress."""
executor_a = IncrementExecutor(id="executor_a")
executor_b = IncrementExecutor(id="executor_b")
workflow = (
WorkflowBuilder(start_executor=executor_a, output_from=[executor_a, executor_b])
.add_edge(executor_a, executor_b)
.add_edge(executor_b, executor_a)
.build()
)
async def consume_stream_slowly() -> list[WorkflowEvent]:
events: list[WorkflowEvent] = []
async for event in workflow.run(NumberMessage(data=0), stream=True):
events.append(event)
await asyncio.sleep(0.01)
return events
task = asyncio.create_task(consume_stream_slowly())
# Let the streaming run start.
await asyncio.sleep(0.02)
try:
with pytest.raises(WorkflowRunnerException, match="Cannot reset the runner while a run is in progress"):
await workflow.reset_for_new_run()
finally:
await task
# After the run completes, reset succeeds again.
await workflow.reset_for_new_run()
assert workflow._runner._iteration == 0 # pyright: ignore[reportPrivateUsage]
# endregion
@@ -598,3 +598,25 @@ class BaseGroupChatOrchestrator(Executor, ABC):
metadata: Pattern-specific state dict
"""
pass
@override
async def reset(self) -> None:
"""Reset the orchestrator to its initial state for a new workflow run.
Clears the shared conversation history and round counter, then delegates
to ``_reset_pattern_state()`` so subclasses can clean up any
pattern-specific per-run state (caches, sessions, ledgers, etc.).
"""
logger.debug("%s %s: Resetting state", self.__class__.__name__, self.id)
self._full_conversation.clear()
self._round_index = 0
self._reset_pattern_state()
def _reset_pattern_state(self) -> None:
"""Reset pattern-specific state.
Override this method in subclasses to clear pattern-specific per-run state
when ``reset()`` is invoked. Called after the base class clears the shared
conversation and round counter.
"""
pass
@@ -327,6 +327,7 @@ class AgentBasedGroupChatOrchestrator(BaseGroupChatOrchestrator):
)
self._agent = agent
self._retry_attempts = retry_attempts
self._session_supplied_by_caller = session is not None
self._session = session or agent.create_session()
# Cache for messages since last agent invocation
# This is different from the full conversation history maintained by the base orchestrator
@@ -337,6 +338,25 @@ class AgentBasedGroupChatOrchestrator(BaseGroupChatOrchestrator):
self._cache.extend(messages)
return super()._append_messages(messages)
@override
def _reset_pattern_state(self) -> None:
"""Reset pattern-specific state for a new workflow run.
Clears the per-run message cache and rotates the orchestrator agent's
session unless the caller supplied a session explicitly (in which case
the caller is responsible for the session's lifecycle).
"""
self._cache.clear()
if self._session_supplied_by_caller:
logger.warning(
"%s %s: Session was supplied by the caller and will not be reset. "
"If you want a fresh session for the next run, reset or replace it before invoking the workflow.",
self.__class__.__name__,
self.id,
)
else:
self._session = self._agent.create_session()
@override
async def _handle_messages(
self,
@@ -1263,6 +1263,14 @@ class MagenticOrchestrator(BaseGroupChatOrchestrator):
# a target will broadcast to all.
await ctx.send_message(MagenticResetSignal())
@override
def _reset_pattern_state(self) -> None:
"""Reset Magentic-specific per-run state for a new workflow run."""
self._magentic_context = None
self._task_ledger = None
self._progress_ledger = None
self._terminated = False
@override
async def on_checkpoint_save(self) -> dict[str, Any]:
"""Capture current orchestrator state for checkpointing."""
@@ -1097,3 +1097,139 @@ def test_group_chat_orchestrator_factory_invalid_return_type():
# endregion
# region Reset
async def test_base_orchestrator_reset_clears_conversation_and_round_index() -> None:
"""reset() clears the conversation history and the round counter."""
from agent_framework.orchestrations import GroupChatOrchestrator
from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry
selector = make_sequence_selector()
orchestrator = GroupChatOrchestrator(
id="orch",
participant_registry=ParticipantRegistry([]),
selection_func=selector,
max_rounds=2,
)
orchestrator._full_conversation = [Message(role="user", contents=["hi"], author_name="user")]
orchestrator._round_index = 4
await orchestrator.reset()
assert orchestrator._full_conversation == []
assert orchestrator._round_index == 0
async def test_base_orchestrator_reset_invokes_pattern_state_hook() -> None:
"""reset() calls _reset_pattern_state() so subclasses can clean up their own state."""
from agent_framework.orchestrations import GroupChatOrchestrator
from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry
selector = make_sequence_selector()
class TrackingOrchestrator(GroupChatOrchestrator):
reset_calls: int = 0
def _reset_pattern_state(self) -> None:
type(self).reset_calls += 1
orchestrator = TrackingOrchestrator(
id="orch",
participant_registry=ParticipantRegistry([]),
selection_func=selector,
max_rounds=2,
)
await orchestrator.reset()
await orchestrator.reset()
assert TrackingOrchestrator.reset_calls == 2
async def test_agent_based_orchestrator_reset_clears_cache_and_rotates_session() -> None:
"""When the session was not supplied by the caller, reset() rotates the session and clears the cache."""
from agent_framework.orchestrations import AgentBasedGroupChatOrchestrator
from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry
agent = cast(Agent, StubManagerAgent())
orchestrator = AgentBasedGroupChatOrchestrator(
agent=agent,
participant_registry=ParticipantRegistry([]),
max_rounds=2,
)
original_session = orchestrator._session
orchestrator._cache = [Message(role="assistant", contents=["x"], author_name="agent")]
orchestrator._full_conversation = [Message(role="user", contents=["x"], author_name="user")]
orchestrator._round_index = 3
await orchestrator.reset()
assert orchestrator._cache == []
assert orchestrator._full_conversation == []
assert orchestrator._round_index == 0
assert orchestrator._session is not original_session
async def test_agent_based_orchestrator_reset_warns_when_session_supplied(caplog: pytest.LogCaptureFixture) -> None:
"""When the caller supplied a session, reset() preserves it and logs a warning."""
import logging
from agent_framework.orchestrations import AgentBasedGroupChatOrchestrator
from agent_framework_orchestrations._base_group_chat_orchestrator import ParticipantRegistry
agent = cast(Agent, StubManagerAgent())
supplied_session = agent.create_session()
orchestrator = AgentBasedGroupChatOrchestrator(
agent=agent,
participant_registry=ParticipantRegistry([]),
session=supplied_session,
max_rounds=2,
)
orchestrator._cache = [Message(role="assistant", contents=["x"], author_name="agent")]
with caplog.at_level(logging.WARNING, logger="agent_framework_orchestrations._group_chat"):
await orchestrator.reset()
assert orchestrator._cache == []
# The caller-owned session must be preserved.
assert orchestrator._session is supplied_session
warnings = [
r for r in caplog.records if r.levelno == logging.WARNING and "Session was supplied by the caller" in r.message
]
assert warnings, f"expected a warning about caller-supplied session, got: {[r.message for r in caplog.records]}"
async def test_workflow_reset_resets_group_chat_orchestrator() -> None:
"""End-to-end: workflow.reset_for_new_run() resets the orchestrator's conversation state."""
selector = make_sequence_selector()
alpha = StubAgent("alpha", "ack from alpha")
beta = StubAgent("beta", "ack from beta")
workflow = GroupChatBuilder(
participants=[alpha, beta],
max_rounds=2,
selection_func=selector,
orchestrator_name="manager",
).build()
async for _ in workflow.run("first task", stream=True):
pass
orchestrator = cast(BaseGroupChatOrchestrator, workflow.executors[GroupChatBuilder.DEFAULT_ORCHESTRATOR_ID])
assert orchestrator._full_conversation, "orchestrator should have accumulated conversation after first run"
assert orchestrator._round_index > 0
await workflow.reset_for_new_run()
assert orchestrator._full_conversation == []
assert orchestrator._round_index == 0
# endregion
+3963 -3982
View File
File diff suppressed because it is too large Load Diff