Python: Add AgentLoopMiddleware for re-running agents in a loop (#6174)

* Python: Add AgentLoopMiddleware for re-running agents in a loop

Add `AgentLoopMiddleware`, an `AgentMiddleware` that re-runs the wrapped
agent in a loop. A single configurable class covers three common patterns,
each with a convenience classmethod factory:

- Ralph loop (`.ralph(...)`): no exit criteria, with feedback tracking
  (`record_feedback`/`progress`), progress injection (`inject_progress`),
  optional fresh context per iteration (`fresh_context`), and an early-stop
  completion signal (`is_complete`).
- Predicate (`.with_predicate(...)`): loop while a `should_continue` callable
  returns True (e.g. paired with `todos_remaining`/`background_tasks_running`).
- Judge (`.with_judge(...)`): a second chat client decides whether the original
  request was answered, using a `JudgeVerdict` structured-output response.

The loop also auto-resolves pending function-approval / user-input requests via
an `on_approval_request` callable (bounded by `max_approval_rounds`), and the
next iteration's input is controlled by `next_message`. Supports both streaming
and non-streaming runs.

Exports `AgentLoopMiddleware`, `JudgeVerdict`, `todos_remaining`, and
`background_tasks_running`. Adds tests, a sample, and docs.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Python: Refine AgentLoopMiddleware API and sample

- with_judge: add criteria list with {{criteria}} templating into judge
  instructions plus an agent-side instruction; add fresh_context, additional
  judge feedback relay; default judge max_iterations.
- should_continue is now required and positional; supports (bool, str|None)
  feedback tuples surfaced to next_message/record_feedback via feedback kwarg.
- Judge forwards full multi-modal request and response messages.
- Default max_iterations=10 (explicit None = unbounded); removed is_complete and
  Ralph terminology; ShouldContinueResult is a real TypeAlias.
- Sample: stream all loops, print iteration counts via injected user-block
  boundaries (robust to function calling), <role>: content formatting, per-method
  expected output, and a looping todo sample.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Python: Fix CI checks for AgentLoopMiddleware

- Resolve pyright errors in _loop.py: drop the always-true final_result None
  check (the while loop always assigns it) and cast finish_reason to the
  AgentResponse constructor's expected type.
- Apply pyupgrade --py310-plus: import TypeAlias from typing.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Python: Resolve mypy/pyright disagreement on finish_reason

pyright infers AgentResponse.finish_reason as including str and rejects the
direct assignment, while mypy considers a cast redundant. Drop the cast and
suppress only pyright with a targeted reportArgumentType ignore, satisfying
both type checkers.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Python: Add todo+judge AgentLoopMiddleware sample

Add a second AgentLoopMiddleware sample that composes two criteria in one
should_continue predicate: a TodoProvider check (evaluated first) and a
report-style judge chat client (evaluated once todos are complete) that grades
the assembled report against shared requirements. Register it in the middleware
samples README.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Python: Compose todo+judge loops as two middleware

Rework the todo+judge sample to compose two AgentLoopMiddleware on the agent
itself (middleware=[judge_loop, todo_loop]) instead of a single hand-written
predicate. The inner todos_remaining loop drafts the report todo-by-todo and the
outer with_judge loop re-runs it until an editor chat client judges the report
publication-ready, reusing the built-in helpers.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Reset session for fresh_context loops via snapshot/restore

AgentLoopMiddleware.fresh_context previously only reset context.messages,
so with an attached session each iteration still reloaded the local
transcript or re-threaded the service-side conversation id and the model
saw the accumulated history. Snapshot the session once before the loop
(via to_dict) and restore it (from_dict + field copy) between iterations,
so every pass starts from the pre-loop baseline. The final iteration's
pass is persisted (no restore after the terminating iteration), so a
subsequent agent.run continues from there.

Removed the obsolete warning, updated docstrings and core AGENTS.md, and
added tests: a snapshot/restore round-trip, a session-reset
streaming x fresh_context x inject_progress x store matrix across multiple
runs and loop iterations, and response_format parsing across the loop.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Updated samples and docstrings

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-06-12 16:35:54 +02:00
committed by GitHub
Unverified
parent 3f77c555cf
commit 1acd242550
9 changed files with 2519 additions and 0 deletions
+5
View File
@@ -116,6 +116,11 @@ agent_framework/
available, approval requests for known non-approval-required tools are treated as already approved, hidden, stored
in session state keyed to the visible approval request ids from that batch, and reinjected only when that visible
approval flow resumes.
### Agent Loop (`_harness/_loop.py`)
- **`AgentLoopMiddleware`** - `AgentMiddleware` that re-runs an agent in a loop by calling `call_next()` repeatedly (the pipeline re-reads `context.messages` each time). One configurable class covers two patterns: a required user `should_continue` predicate (sync or async, the first positional/keyword arg), and a chat-client judge built via the `.with_judge(...)` factory (a second chat client decides whether the original request was answered; loops while it is *not*, using a `JudgeVerdict` structured-output response — internally just an async `should_continue` predicate). The constructor covers the predicate pattern directly; only the judge has a convenience classmethod factory (`.with_judge(judge_client, ...)`) that forwards to `__init__`. Supports both streaming and non-streaming runs. By default a non-streaming run returns an aggregated `AgentResponse` containing every iteration's messages plus the injected `next_message` "nudge" messages (as `user` messages); set `return_final_only=True` to return only the last iteration's response. Streaming runs always yield each iteration's updates and emit the injected nudge messages as `user` updates between iterations (the `return_final_only` flag has no effect on streaming, and the final response reflects the last iteration; `MiddlewareTermination` is handled cleanly). `should_continue` is required; other constructor args are optional: `max_iterations` (safety cap; defaults to `DEFAULT_MAX_ITERATIONS`=10, explicit `None`→unbounded, positive int caps; `.with_judge` uses `DEFAULT_JUDGE_MAX_ITERATIONS`=5 as its default), `next_message` (defaults to a short "continue" nudge), `return_final_only`, and `additional_instructions` (an extra `system` message injected ahead of the input before the agent runs — becomes part of the original messages so it survives `fresh_context` resets and persists via a session). The judge is configured only through `.with_judge` (`judge_client`/`instructions`/`criteria`), not the constructor, and its `reasoning` is fed back to the agent as the next iteration's input; the judge forwards the original request messages and the agent's latest response messages verbatim so multi-modal content is preserved. `criteria` (a `list[str]`) is both injected as the agent's `additional_instructions` and rendered into the judge instructions wherever the `{{criteria}}` placeholder (`CRITERIA_PLACEHOLDER`) appears (`DEFAULT_JUDGE_INSTRUCTIONS` ends with it; custom `instructions` may include it, and it is stripped when no criteria are given). The `should_continue`/`next_message` callables are invoked with keyword args (`iteration`, `last_result`, `messages`, `original_messages`, `session`, `agent`, `progress`, `feedback`) and may be sync or async; declare only what you need plus `**kwargs`. `should_continue` may return a plain `bool` or a `(bool, str | None)` tuple whose second item is feedback surfaced to `next_message`/`record_feedback` via the `feedback` kwarg (the judge uses this to relay its `reasoning`). Stop precedence per iteration is `max_iterations``should_continue`, evaluated before `record_feedback` so the feedback is available to it.
- **Feedback tracking** - `record_feedback` captures a per-iteration progress entry (called with the loop kwargs; if it returns a truthy string the entry is appended, otherwise the agent's response text is used as the fallback entry). The accumulated log is exposed to every callback via the `progress` keyword (a per-iteration copy of prior entries) and, when `inject_progress=True` (default), injected into the next iteration's input as a `user` message (the full log without a session, only the latest entry with a session to avoid duplicating history). `fresh_context=True` restarts each iteration from the original task plus the progress log; when a session is attached it is snapshotted (`to_dict()`) before the loop and restored (`from_dict` + field copy) between iterations so the local transcript and any service-side conversation id reset too (in-loop working-state is discarded, pre-loop state preserved, continuity carried only by the progress log).
- **`todos_remaining(provider)`** / **`background_tasks_running(provider)`** - Helper factories returning `should_continue` predicates that loop while a `TodoProvider` has open items, or while a `BackgroundAgentsProvider`'s persisted state shows running tasks.
### Workflows (`_workflows/`)
@@ -102,6 +102,12 @@ from ._harness._file_access import (
FileSystemAgentFileStore,
InMemoryAgentFileStore,
)
from ._harness._loop import (
AgentLoopMiddleware,
JudgeVerdict,
background_tasks_running,
todos_remaining,
)
from ._harness._memory import (
DEFAULT_MEMORY_SOURCE_ID,
MemoryContextProvider,
@@ -363,6 +369,7 @@ __all__ = [
"AgentExecutorResponse",
"AgentFileStore",
"AgentFrameworkException",
"AgentLoopMiddleware",
"AgentMiddleware",
"AgentMiddlewareLayer",
"AgentMiddlewareTypes",
@@ -454,6 +461,7 @@ __all__ = [
"InlineSkill",
"InlineSkillResource",
"InlineSkillScript",
"JudgeVerdict",
"LocalEvaluator",
"MCPSkill",
"MCPSkillResource",
@@ -558,6 +566,7 @@ __all__ = [
"agent_middleware",
"annotate_message_groups",
"apply_compaction",
"background_tasks_running",
"chat_middleware",
"create_always_approve_tool_response",
"create_always_approve_tool_with_arguments_response",
@@ -588,6 +597,7 @@ __all__ = [
"response_handler",
"set_agent_mode",
"step",
"todos_remaining",
"tool",
"tool_call_args_match",
"tool_called_check",
@@ -0,0 +1,796 @@
# Copyright (c) Microsoft. All rights reserved.
"""AgentLoopMiddleware: re-run an agent in a loop until a criterion is met.
This module provides :class:`AgentLoopMiddleware`, an :class:`~agent_framework.AgentMiddleware`
that repeatedly re-invokes the wrapped agent while a ``should_continue`` predicate says to keep
going. It serves two common patterns through a single configurable class:
1. A user-supplied ``should_continue`` predicate - for example, keep looping while a response does
not yet contain a completion marker, while a :class:`~agent_framework.TodoProvider` still has
open items, or while a :class:`~agent_framework.BackgroundAgentsProvider` still has running
tasks (see the :func:`todos_remaining` and :func:`background_tasks_running` helpers). The loop
can track a **feedback log** across iterations (``record_feedback``): each pass contributes an
entry that is exposed to every callback via the ``progress`` keyword and (by default) injected
into the next iteration's input. Set ``fresh_context=True`` to restart each pass from the
original task plus the progress log (with a session attached, the session is also snapshotted
before the loop and restored between iterations so no accumulated history leaks back in).
``max_iterations`` bounds the loop as a safety cap.
2. A chat-client judge (via :meth:`AgentLoopMiddleware.with_judge`) - a second chat client decides
whether the user's original request has been answered (via a :class:`JudgeVerdict` structured
output); the loop continues while the answer is "no". This is a convenience wrapper that builds an
async ``should_continue`` predicate, so it is a special case of (1).
In every case, the input for the next iteration is controlled by the ``next_message`` callable.
"""
from __future__ import annotations
import inspect
from collections.abc import Awaitable, Callable, Sequence
from typing import TYPE_CHECKING, Any, TypeAlias
from pydantic import BaseModel, Field
from typing_extensions import Self
from .._feature_stage import ExperimentalFeature, experimental
from .._middleware import AgentContext, AgentMiddleware, MiddlewareTermination
from .._types import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
Message,
ResponseStream,
UsageDetails,
add_usage_details,
normalize_messages,
)
if TYPE_CHECKING:
from .._clients import SupportsChatGetResponse
__all__ = [
"AgentLoopMiddleware",
"JudgeVerdict",
"background_tasks_running",
"todos_remaining",
]
DEFAULT_NEXT_MESSAGE = "Continue working on the task. If it is complete, say so."
# Placeholder substituted with the rendered ``criteria`` block in judge instructions (see
# :meth:`AgentLoopMiddleware.with_judge`). User-supplied instructions may include it to control
# where the criteria are inserted; if absent, the criteria are not added to the judge instructions.
CRITERIA_PLACEHOLDER = "{{criteria}}"
# Verdict markers the judge is asked to emit for clients that do not honor structured output. They
# are deliberately non-overlapping: neither marker is a substring of the other, nor of the JSON
# field name ``answered``, so the text fallback in :func:`_build_judge_condition` cannot misclassify
# a negative verdict (e.g. ``{"answered": false}``) as a positive one.
JUDGE_VERDICT_DONE = "VERDICT: DONE"
JUDGE_VERDICT_MORE = "VERDICT: MORE"
DEFAULT_JUDGE_INSTRUCTIONS = (
"You are an evaluator. You are given a user's original request and an agent's latest response. "
"Decide whether the agent has fully addressed the original request. "
"Set 'answered' to true if the request has been fully addressed, or false if more work is still "
"required, and use 'reasoning' to briefly justify your decision. "
f"If you cannot return structured output, end your reply with a line reading exactly "
f"'{JUDGE_VERDICT_DONE}' when the request has been fully addressed or '{JUDGE_VERDICT_MORE}' "
f"when more work is still required."
"{{criteria}}"
)
def _render_criteria_block(criteria: Sequence[str] | None) -> str:
"""Render a list of criteria into a bullet block for the judge instructions (``""`` if none)."""
if not criteria:
return ""
bullets = "\n".join(f"- {item}" for item in criteria)
return f"\n\nThe response must satisfy all of the following criteria:\n{bullets}"
def _criteria_agent_instruction(criteria: Sequence[str]) -> str:
"""Render the criteria into an extra instruction injected for the agent before each run."""
bullets = "\n".join(f"- {item}" for item in criteria)
return f"Your response must satisfy all of the following criteria:\n{bullets}"
class JudgeVerdict(BaseModel):
"""Structured verdict returned by the judge chat client."""
answered: bool = Field(
description=(
"True if the agent has fully addressed the original request and it adheres to the other "
"judging standards, otherwise False."
),
)
reasoning: str = Field(
default="",
description="Brief justification for the verdict.",
)
# Default iteration cap applied when ``max_iterations`` is not provided. Loops are bounded by
# default to guard against runaway re-invocation; pass ``max_iterations=None`` explicitly to opt
# into an unbounded loop.
DEFAULT_MAX_ITERATIONS = 10
# Default iteration cap for judge-driven loops. LLM-judged loops are costly and probabilistic, so
# they are bounded by a smaller default. Pass ``max_iterations=None`` explicitly to opt into an
# unbounded judge loop.
DEFAULT_JUDGE_MAX_ITERATIONS = 5
# A callable invoked between iterations. It always receives the loop keyword arguments
# (``iteration``, ``last_result``, ``messages``, ``original_messages``, ``session``, ``agent``,
# ``progress``, ``feedback``). Callers declare only the keywords they need plus ``**kwargs`` to
# ignore the rest. ``should_continue`` may return a plain ``bool`` (continue/stop) or a
# ``(bool, str | None)`` tuple whose second item is feedback surfaced to the ``next_message`` and
# ``record_feedback`` callables via the ``feedback`` keyword argument.
ShouldContinueResult: TypeAlias = "bool | tuple[bool, str | None]"
ShouldContinueCallable = Callable[..., "ShouldContinueResult | Awaitable[ShouldContinueResult]"]
NextMessageCallable = Callable[..., "AgentRunInputs | Awaitable[AgentRunInputs | None] | None"]
# A callable invoked once per work iteration to capture a progress-log entry from that iteration. It
# receives the loop keyword arguments and returns a string entry (appended to the log) or ``None``
# (record nothing for that iteration).
FeedbackCallable = Callable[..., "str | Awaitable[str | None] | None"]
async def _maybe_await(value: Any) -> Any:
"""Await ``value`` if it is awaitable, otherwise return it as-is."""
if inspect.isawaitable(value):
return await value
return value
def _build_judge_condition(
judge_client: SupportsChatGetResponse,
instructions: str,
) -> tuple[ShouldContinueCallable, NextMessageCallable]:
"""Build the ``should_continue`` predicate and ``next_message`` callable for a judge loop.
The judge is called directly (no agent tools, session, or middleware) with fresh messages, so
the loop's evaluation cannot recurse back through the agent pipeline. The original input messages
are forwarded verbatim (rather than collapsed to text) so multi-modal requests are preserved. The
judge is asked for a :class:`JudgeVerdict` structured output; if the client does not honor
structured output the verdict falls back to the explicit, non-overlapping ``VERDICT: DONE`` /
``VERDICT: MORE`` markers (``MORE`` wins, keeping the loop running, when the marker is ambiguous
or absent).
The predicate returns a ``(continue, reasoning)`` tuple; the loop surfaces that ``reasoning`` to
the next-message callable as the ``feedback`` keyword argument, which feeds it back to the agent
so it knows *why* its previous answer was judged incomplete.
"""
async def _judge(
*, last_result: AgentResponse, original_messages: list[Message], **kwargs: Any
) -> tuple[bool, str | None]:
judge_messages = [
Message(role="system", contents=[instructions]),
Message(
role="user",
contents=["Evaluate the agent's work. The user's original request follows:"],
),
*original_messages,
Message(role="user", contents=["The agent's latest response was:"]),
*last_result.messages,
Message(role="user", contents=["Has the original request been fully addressed?"]),
]
response = await judge_client.get_response(judge_messages, options={"response_format": JudgeVerdict})
verdict = response.value
if isinstance(verdict, JudgeVerdict):
answered = verdict.answered
reasoning = verdict.reasoning
else:
# Fallback for clients that do not honor structured output: look for the explicit,
# non-overlapping verdict markers. ``FAIL`` (more work needed) takes precedence so an
# ambiguous or marker-less reply keeps looping rather than stopping on an incomplete
# answer.
text = response.text.upper()
# ``MORE`` (more work needed) takes precedence so an ambiguous reply keeps looping.
answered = False if JUDGE_VERDICT_MORE in text else JUDGE_VERDICT_DONE in text
reasoning = response.text.strip()
# Continue looping while the request is not yet answered, surfacing the reasoning as feedback.
return (not answered), (reasoning or None)
def _next_message(*, feedback: str | None = None, **kwargs: Any) -> AgentRunInputs:
# Feed the judge's reasoning back to the agent so the next iteration addresses the gap.
if feedback:
return (
"An evaluator reviewed your previous response and judged that it does not yet fully "
f"address the original request.\n\nEvaluator feedback: {feedback}\n\n"
"Revise and continue so the original request is fully addressed."
)
return DEFAULT_NEXT_MESSAGE
return _judge, _next_message
@experimental(feature_id=ExperimentalFeature.HARNESS)
class AgentLoopMiddleware(AgentMiddleware):
"""Re-run an agent in a loop until a criterion is met (or never).
This middleware repeatedly invokes the wrapped agent. After each run it decides whether to run
again based on ``should_continue`` and ``max_iterations``, and uses ``next_message`` to build
the input for the next iteration. Use :meth:`with_judge` to drive the loop with a chat-client
judge instead of a hand-written predicate.
By default a non-streaming run returns an aggregated :class:`~agent_framework.AgentResponse`
containing every iteration's messages plus the injected ``next_message`` "nudge" messages (set
``return_final_only=True`` to return only the last iteration's response). Streaming runs always
yield each iteration's updates and emit the injected nudge messages as ``user`` updates between
iterations.
The ``should_continue`` and ``next_message`` callables are invoked with keyword arguments, so a
caller only needs to declare the ones it uses plus ``**kwargs``. The keywords are:
- ``iteration`` (int): the number of completed runs so far (1-based after the first run).
- ``last_result`` (AgentResponse): the result of the iteration that just completed.
- ``messages`` (list[Message]): the messages used for the iteration that just completed.
- ``original_messages`` (list[Message]): the input used for the first iteration.
- ``session`` (AgentSession | None): the active session, used by the provider helpers.
- ``agent``: the agent being looped.
- ``progress`` (list[str]): the feedback log accumulated so far (see ``record_feedback``).
- ``feedback`` (str | None): the feedback string returned by ``should_continue`` for this
iteration (``None`` when it returned a plain bool). ``should_continue`` may return either a
``bool`` or a ``(bool, str | None)`` tuple; the string is surfaced here so ``next_message``
and ``record_feedback`` can reference it.
Examples:
.. code-block:: python
from agent_framework import Agent, AgentResponse
from agent_framework._harness._loop import AgentLoopMiddleware
async def should_continue(*, iteration: int, last_result: AgentResponse, **kwargs) -> bool:
return iteration < 3 and "DONE" not in last_result.text
agent = Agent(client=client, middleware=[AgentLoopMiddleware(should_continue)])
Note:
``max_iterations`` acts as a safety cap and defaults to ``DEFAULT_MAX_ITERATIONS`` (10). Pass
an explicit ``None`` to make the loop unbounded, in which case it relies entirely on
``should_continue`` to stop, so make sure the predicate can eventually return ``False``.
"""
def __init__(
self,
should_continue: ShouldContinueCallable,
*,
max_iterations: int | None = DEFAULT_MAX_ITERATIONS,
next_message: NextMessageCallable | None = None,
record_feedback: FeedbackCallable | None = None,
inject_progress: bool = True,
fresh_context: bool = False,
return_final_only: bool = False,
additional_instructions: str | None = None,
) -> None:
"""Initialize the agent loop middleware.
Args:
should_continue: Predicate that decides whether to run the agent again. May be sync or
async and is called with the loop keyword arguments (``iteration``, ``last_result``,
``messages``, ``original_messages``, ``session``, ``agent``, ``progress``, and
``feedback`` -- see the class docstring for what each one carries; declare only the
ones you need plus ``**kwargs``). Return ``True``/``False`` to
continue/stop, or a ``(bool, str | None)`` tuple to also provide feedback; the
feedback string is surfaced to the ``next_message`` and ``record_feedback`` callables
via the ``feedback`` keyword argument. To loop on a chat-client judge instead, build
the middleware via :meth:`with_judge`.
Keyword Args:
max_iterations: Maximum number of agent runs, used as a safety cap. Defaults to
``DEFAULT_MAX_ITERATIONS`` (10); pass an explicit ``None`` for an unbounded loop, or
a positive integer to set a custom cap. (The :meth:`with_judge` factory uses
``DEFAULT_JUDGE_MAX_ITERATIONS`` (5) as its default instead.)
next_message: Callable that produces the input for the next iteration, called with the
loop keyword arguments. Defaults to a short "continue" nudge. Returning ``None``
reuses the previous iteration's messages verbatim (in which case the progress log is
*not* injected; see ``inject_progress``).
record_feedback: Optional callable invoked once per work iteration to capture a feedback
entry. Called as ``record_feedback(**loop_kwargs)`` and returns a
string entry appended to the progress log, or ``None`` to record nothing for that
iteration. When not provided, the iteration's response text (``last_result.text``) is
recorded instead. The accumulated log is exposed to every callback via the
``progress`` loop keyword argument. For production loops prefer a ``record_feedback``
that returns a terse summary rather than relying on the full response text.
inject_progress: When ``True`` (default), the accumulated progress log is injected into
the next iteration's input as a single ``user`` message ("Progress so far: ..."). To
avoid duplication, only the most recent entry is injected when a session is attached
(the session already retains earlier turns); the full log is injected when there is
no session or ``fresh_context`` is set. When ``False`` the log is only exposed via the
``progress`` loop keyword argument and never injected automatically.
fresh_context: When ``True``, each iteration starts from a clean context: ``context``
messages are reset to the original input messages (plus the injected progress log)
instead of accumulating the prior conversation. When a session is attached, the
session is snapshotted once before the loop and restored to that pre-loop baseline
before each subsequent iteration, so the local transcript and any service-side
conversation id are reset too and the agent does not re-read the accumulated history.
In-loop working-state mutations are discarded; pre-loop state is preserved; continuity
is carried only by the progress log.
return_final_only: Controls what a non-streaming run returns. When ``False`` (default),
the returned :class:`~agent_framework.AgentResponse` aggregates every iteration: each
iteration's response messages plus the injected ``next_message`` "nudge" messages
(as ``user`` messages), so the caller sees the full back-and-forth. When ``True``,
only the final iteration's :class:`~agent_framework.AgentResponse` is returned. This
flag has no effect on streaming runs (the stream cannot know in advance which
iteration is last); streaming always yields each iteration's updates and injects the
``next_message`` messages as ``user`` updates between iterations.
additional_instructions: Optional extra instruction injected as a ``system`` message
ahead of the input messages before the agent runs. It becomes part of the original
messages, so it is preserved across ``fresh_context`` resets and (with a session)
persists server-side across iterations. Used by :meth:`with_judge` to tell the agent
about the criteria its response must satisfy, but available to any loop.
Raises:
ValueError: If ``max_iterations`` is not ``None`` and is less than 1.
"""
if max_iterations is not None and max_iterations < 1:
raise ValueError("max_iterations must be None or a positive integer (>= 1).")
self.max_iterations: int | None = max_iterations
self.should_continue: ShouldContinueCallable = should_continue
self.next_message = next_message
self.record_feedback = record_feedback
self.inject_progress = inject_progress
self.fresh_context = fresh_context
self.return_final_only = return_final_only
self.additional_instructions = additional_instructions
@classmethod
def with_judge(
cls,
judge_client: SupportsChatGetResponse,
*,
criteria: Sequence[str] | None = None,
instructions: str | None = None,
max_iterations: int | None = DEFAULT_JUDGE_MAX_ITERATIONS,
next_message: NextMessageCallable | None = None,
fresh_context: bool = False,
) -> Self:
"""Create a loop that continues until a judge chat client decides the request was answered.
Convenience factory for the judge pattern: ``judge_client`` is queried with a
:class:`JudgeVerdict` structured-output response after each iteration and the loop continues
while the request is *not* answered. The judge's ``reasoning`` is fed back to the agent as
the next iteration's input (unless a custom ``next_message`` is provided), so the agent knows
why its previous answer was judged incomplete. See :meth:`__init__` for the full meaning of
each argument.
Args:
judge_client: Chat client used to judge whether the original request was answered.
Keyword Args:
criteria: Optional list of criteria the response must satisfy. When provided, they are
(1) injected as an extra ``system`` instruction for the agent before it runs (via
``additional_instructions``) and (2) rendered into the judge instructions wherever
the ``{{criteria}}`` placeholder appears (``CRITERIA_PLACEHOLDER``).
instructions: Optional system instructions for the judge. Defaults to
``DEFAULT_JUDGE_INSTRUCTIONS``. May contain the ``{{criteria}}`` placeholder, which
is replaced with the rendered ``criteria`` (or removed when no criteria are given).
max_iterations: Maximum number of agent runs. Defaults to
``DEFAULT_JUDGE_MAX_ITERATIONS`` (5); pass ``None`` for unbounded, or a positive
integer to set a custom cap.
next_message: Callable that produces the next iteration's input. Defaults to one that
relays the judge's ``reasoning`` back to the agent.
fresh_context: When ``True``, each iteration restarts from the original input messages
(plus the injected progress log and judge feedback) instead of accumulating the prior
conversation; an attached session is snapshotted before the loop and restored to that
baseline between iterations. See :meth:`__init__` for the full semantics. Defaults to
``False``.
"""
judge_instructions = (instructions or DEFAULT_JUDGE_INSTRUCTIONS).replace(
CRITERIA_PLACEHOLDER, _render_criteria_block(criteria)
)
should_continue, judge_next_message = _build_judge_condition(judge_client, judge_instructions)
return cls(
should_continue=should_continue,
max_iterations=max_iterations,
next_message=next_message or judge_next_message,
fresh_context=fresh_context,
additional_instructions=_criteria_agent_instruction(criteria) if criteria else None,
)
async def process(
self,
context: AgentContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Run the wrapped agent in a loop."""
if self.additional_instructions is not None:
# Inject the extra instruction as a system message ahead of the input so it is present
# on every iteration and preserved across fresh_context resets (which restart from
# ``original_messages``).
context.messages = [
Message(role="system", contents=[self.additional_instructions]),
*context.messages,
]
original_messages = list(context.messages)
# For a truly fresh context per iteration the session must also be reset, otherwise the
# next run reloads the local transcript or re-threads the service-side conversation and the
# model still sees the accumulated history. Snapshot the session once here (the pre-loop
# baseline) and restore it before each subsequent iteration so every pass starts clean.
snapshot = context.session.to_dict() if self.fresh_context and context.session is not None else None
if context.stream:
self._process_streaming(context, call_next, original_messages, snapshot)
else:
await self._process_non_streaming(context, call_next, original_messages, snapshot)
@staticmethod
def _restore_session(session: Any, snapshot: dict[str, Any]) -> None:
"""Restore a session in place to a previously captured ``to_dict()`` snapshot.
Re-hydrates the snapshot via :meth:`AgentSession.from_dict` and copies the mutable fields
(``service_session_id`` and ``state``) back onto the live ``session`` instance, so any
reference held by the agent/context observes the reset. ``session_id`` is preserved (the
snapshot carries the same id). A fresh ``from_dict`` is built on every call so repeated
restores from one snapshot do not alias the same state dict.
"""
from .._sessions import AgentSession
restored = AgentSession.from_dict(snapshot)
session.service_session_id = restored.service_session_id
session.state = restored.state
async def _process_non_streaming(
self,
context: AgentContext,
call_next: Callable[[], Awaitable[None]],
original_messages: list[Message],
snapshot: dict[str, Any] | None,
) -> None:
iteration = 0
work_iterations = 0
progress: list[str] = []
# Aggregated transcript across iterations: each iteration's response messages plus the
# injected "nudge" messages, used to build the combined response when return_final_only=False.
aggregated: list[Message] = []
aggregated_usage: UsageDetails | None = None
final_result: AgentResponse | None = None
while True:
await call_next()
iteration += 1
result = context.result
if not isinstance(result, AgentResponse):
raise TypeError(
"AgentLoopMiddleware expected an AgentResponse from a non-streaming run, "
f"got {type(result).__name__}."
)
final_result = result
aggregated.extend(result.messages)
if result.usage_details is not None:
aggregated_usage = add_usage_details(aggregated_usage, result.usage_details)
messages_used = context.messages
loop_kwargs = self._build_loop_kwargs(
context=context,
iteration=iteration,
last_result=result,
messages_used=messages_used,
original_messages=original_messages,
progress=progress,
)
work_iterations += 1
# Decide whether to stop and capture any feedback from should_continue first, so the
# feedback is available to both the progress and next-message callables this iteration.
stop, feedback = await self._evaluate_stop(loop_kwargs, work_iterations)
loop_kwargs = self._build_loop_kwargs(
context=context,
iteration=iteration,
last_result=result,
messages_used=messages_used,
original_messages=original_messages,
progress=progress,
feedback=feedback,
)
# Capture this iteration's progress entry, then refresh loop_kwargs so the next-message
# resolution sees the latest entry.
if await self._record_progress(result, loop_kwargs, progress):
loop_kwargs = self._build_loop_kwargs(
context=context,
iteration=iteration,
last_result=result,
messages_used=messages_used,
original_messages=original_messages,
progress=progress,
feedback=feedback,
)
if stop:
break
if snapshot is not None and context.session is not None:
# Reset the session to the pre-loop baseline so the next run starts fresh; only the
# progress log (injected by _resolve_next_message) carries continuity forward.
self._restore_session(context.session, snapshot)
next_messages = await self._resolve_next_message(loop_kwargs, messages_used, original_messages)
context.messages = next_messages
aggregated.extend(next_messages)
if not self.return_final_only:
context.result = self._aggregate_response(final_result, aggregated, aggregated_usage)
def _process_streaming(
self,
context: AgentContext,
call_next: Callable[[], Awaitable[None]],
original_messages: list[Message],
snapshot: dict[str, Any] | None,
) -> None:
# Holds the last iteration's final response so the outer stream's finalizer can return it
# rather than an aggregate of every iteration.
holder: dict[str, AgentResponse | None] = {"final": None}
async def _generator() -> Any:
iteration = 0
work_iterations = 0
progress: list[str] = []
while True:
try:
await call_next()
inner = context.result
if not isinstance(inner, ResponseStream):
raise TypeError(
"AgentLoopMiddleware expected a ResponseStream from a streaming run, "
f"got {type(inner).__name__}."
)
async for update in inner:
yield update
holder["final"] = await inner.get_final_response()
except MiddlewareTermination:
# The pipeline's MiddlewareTermination suppression is no longer active once
# process() has returned (the stream is consumed lazily), so a termination
# raised by a downstream middleware or during stream consumption surfaces here.
# Stop cleanly and keep whatever final response we have from a prior iteration.
return
iteration += 1
messages_used = context.messages
final = holder["final"]
loop_kwargs = self._build_loop_kwargs(
context=context,
iteration=iteration,
last_result=final,
messages_used=messages_used,
original_messages=original_messages,
progress=progress,
)
work_iterations += 1
# Decide whether to stop and capture any feedback from should_continue first, so the
# feedback is available to both the progress and next-message callables this iteration.
stop, feedback = await self._evaluate_stop(loop_kwargs, work_iterations)
loop_kwargs = self._build_loop_kwargs(
context=context,
iteration=iteration,
last_result=final,
messages_used=messages_used,
original_messages=original_messages,
progress=progress,
feedback=feedback,
)
if await self._record_progress(final, loop_kwargs, progress):
loop_kwargs = self._build_loop_kwargs(
context=context,
iteration=iteration,
last_result=final,
messages_used=messages_used,
original_messages=original_messages,
progress=progress,
feedback=feedback,
)
if stop:
return
if snapshot is not None and context.session is not None:
# Reset the session to the pre-loop baseline before the next run. The final
# response was already awaited above, so the service-side conversation id has
# been propagated and is safe to discard here.
self._restore_session(context.session, snapshot)
next_messages = await self._resolve_next_message(loop_kwargs, messages_used, original_messages)
context.messages = next_messages
# Surface the injected "nudge" messages in the stream so consumers see the user
# turns that drive each subsequent iteration (the equivalent of the aggregated
# transcript that non-streaming runs return).
for message in next_messages:
yield self._message_to_update(message)
def _finalize(updates: Sequence[AgentResponseUpdate]) -> AgentResponse:
if holder["final"] is not None:
return holder["final"]
return AgentResponse.from_updates(updates)
context.result = ResponseStream(_generator(), finalizer=_finalize)
def _build_loop_kwargs(
self,
*,
context: AgentContext,
iteration: int,
last_result: AgentResponse | None,
messages_used: list[Message],
original_messages: list[Message],
progress: list[str],
feedback: str | None = None,
) -> dict[str, Any]:
return {
"iteration": iteration,
"last_result": last_result,
"messages": messages_used,
"original_messages": original_messages,
"session": context.session,
"agent": context.agent,
# A copy so user callbacks cannot mutate the loop's internal progress log.
"progress": list(progress),
# Feedback returned by ``should_continue`` for this iteration (``None`` if it returned a
# plain bool, or the stop was decided by ``max_iterations``).
"feedback": feedback,
}
async def _record_progress(
self,
last_result: AgentResponse | None,
loop_kwargs: dict[str, Any],
progress: list[str],
) -> bool:
"""Capture this iteration's feedback into ``progress``. Returns ``True`` if an entry was added."""
if self.record_feedback is not None:
entry = await _maybe_await(self.record_feedback(**loop_kwargs))
else:
entry = last_result.text.strip() if last_result is not None else None
if entry:
progress.append(entry)
return True
return False
async def _evaluate_stop(self, loop_kwargs: dict[str, Any], work_iterations: int) -> tuple[bool, str | None]:
"""Decide whether the loop should stop, returning ``(stop, feedback)``.
``max_iterations`` is a safety cap that short-circuits before ``should_continue`` is
evaluated (so an expensive predicate/judge is not called once the cap has fired). Any
feedback returned by ``should_continue`` is propagated so the progress and next-message
callables can reference it.
"""
if self.max_iterations is not None and work_iterations >= self.max_iterations:
return True, None
keep_going, feedback = await self._should_continue(loop_kwargs)
return (not keep_going), feedback
async def _should_continue(self, loop_kwargs: dict[str, Any]) -> tuple[bool, str | None]:
"""Evaluate the predicate, normalizing its result to ``(continue, feedback)``."""
result = await _maybe_await(self.should_continue(**loop_kwargs))
return (bool(result[0]), result[1]) if isinstance(result, tuple) else (bool(result), None) # type: ignore
@staticmethod
def _message_to_update(message: Message) -> AgentResponseUpdate:
"""Wrap an injected loop message as a streaming update so consumers see it inline."""
return AgentResponseUpdate(
contents=message.contents,
role=message.role,
author_name=message.author_name,
message_id=message.message_id,
)
@staticmethod
def _aggregate_response(
final: AgentResponse,
messages: list[Message],
usage: UsageDetails | None,
) -> AgentResponse:
"""Build a combined response carrying every iteration's messages and summed usage.
Metadata (``response_id``, structured ``value``, etc.) is taken from the final iteration; the
structured value is passed through pre-parsed so it is not re-derived from the aggregated text.
"""
return AgentResponse(
messages=messages,
response_id=final.response_id,
agent_id=final.agent_id,
created_at=final.created_at,
finish_reason=final.finish_reason, # pyright: ignore[reportArgumentType]
usage_details=usage,
value=final.value,
additional_properties=dict(final.additional_properties) if final.additional_properties else None,
raw_representation=final.raw_representation,
)
@staticmethod
def _render_progress(entries: list[str]) -> Message:
"""Format progress-log entries into a single ``user`` message."""
body = "\n".join(f"- {entry}" for entry in entries)
return Message(role="user", contents=[f"Progress so far:\n{body}"])
async def _resolve_next_message(
self,
loop_kwargs: dict[str, Any],
messages_used: list[Message],
original_messages: list[Message],
) -> list[Message]:
# Compute the base next input. A ``next_message`` callable returning None requests a verbatim
# reuse of the previous messages (no progress injection); in fresh-context mode that escape
# hatch does not apply, so fall back to the default nudge instead.
if self.next_message is None:
next_msgs = normalize_messages(DEFAULT_NEXT_MESSAGE)
else:
next_input = await _maybe_await(self.next_message(**loop_kwargs))
if next_input is None:
if not self.fresh_context:
return list(messages_used)
next_msgs = normalize_messages(DEFAULT_NEXT_MESSAGE)
else:
next_msgs = normalize_messages(next_input)
progress: list[str] = loop_kwargs.get("progress") or []
session = loop_kwargs.get("session")
progress_msg: Message | None = None
if self.inject_progress and progress:
# With a session the earlier entries are already retained in the conversation, so only
# the latest entry is injected to avoid duplication. Otherwise inject the full log.
entries = progress if (session is None or self.fresh_context) else progress[-1:]
progress_msg = self._render_progress(entries)
if self.fresh_context:
result = list(original_messages)
if progress_msg is not None:
result.append(progress_msg)
result.extend(next_msgs)
return result
if progress_msg is not None:
return [progress_msg, *next_msgs]
return list(next_msgs)
def todos_remaining(provider: Any) -> ShouldContinueCallable:
"""Build a ``should_continue`` predicate that loops while a ``TodoProvider`` has open items.
Args:
provider: A :class:`~agent_framework.TodoProvider` attached to the same session as the loop.
Returns:
A predicate suitable for :class:`AgentLoopMiddleware`'s ``should_continue`` argument.
"""
async def _should_continue(*, session: Any = None, **kwargs: Any) -> bool:
if session is None:
return False
items = await provider.store.load_items(session, source_id=provider.source_id)
return any(not item.is_complete for item in items)
return _should_continue
def background_tasks_running(provider: Any) -> ShouldContinueCallable:
"""Build a ``should_continue`` predicate that loops while a ``BackgroundAgentsProvider`` is busy.
The predicate inspects the provider's persisted task state and continues while any task is still
marked as running. Pair it with ``max_iterations`` so the loop is guaranteed to stop even if a
task's persisted status is never refreshed.
Args:
provider: A :class:`~agent_framework.BackgroundAgentsProvider` attached to the same session
as the loop.
Returns:
A predicate suitable for :class:`AgentLoopMiddleware`'s ``should_continue`` argument.
"""
from ._background_agents import BackgroundTaskInfo, BackgroundTaskStatus
def _should_continue(*, session: Any = None, **kwargs: Any) -> bool:
if session is None:
return False
state = session.state.get(provider.source_id)
if not state:
return False
return any(
BackgroundTaskInfo.from_dict(task).status == BackgroundTaskStatus.RUNNING for task in state.get("tasks", [])
)
return _should_continue
File diff suppressed because it is too large Load Diff