Python: Fix middleware terminate flag to exit function calling loop immediately (#2868)

* Fix middleware terminate flag to exit function calling loop immediately

* Eliminating duck typing

* Improve function exec result handling

* Fix race condition

* Fix mypy issues
This commit is contained in:
Evan Mattson
2025-12-16 18:52:52 +09:00
committed by GitHub
Unverified
parent 3139347526
commit 11d6dcfe80
4 changed files with 341 additions and 62 deletions
@@ -1405,13 +1405,17 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien
call_middleware = kwargs.pop("middleware", None)
instance_middleware = getattr(self, "middleware", None)
# Merge middleware from both sources, filtering for chat middleware only
all_middleware: list[ChatMiddleware | ChatMiddlewareCallable] = _merge_and_filter_chat_middleware(
instance_middleware, call_middleware
)
# Merge all middleware and separate by type
middleware = categorize_middleware(instance_middleware, call_middleware)
chat_middleware_list = middleware["chat"]
function_middleware_list = middleware["function"]
# If no middleware, use original method
if not all_middleware:
# Pass function middleware to function invocation system if present
if function_middleware_list:
kwargs["_function_middleware_pipeline"] = FunctionMiddlewarePipeline(function_middleware_list)
# If no chat middleware, use original method
if not chat_middleware_list:
async for update in original_get_streaming_response(self, messages, **kwargs):
yield update
return
@@ -1422,7 +1426,7 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien
# Extract chat_options or create default
chat_options = kwargs.pop("chat_options", ChatOptions())
pipeline = ChatMiddlewarePipeline(all_middleware) # type: ignore[arg-type]
pipeline = ChatMiddlewarePipeline(chat_middleware_list) # type: ignore[arg-type]
context = ChatContext(
chat_client=self,
messages=prepare_messages(messages),
@@ -1536,27 +1540,40 @@ def _merge_and_filter_chat_middleware(
return middleware["chat"] # type: ignore[return-value]
def extract_and_merge_function_middleware(chat_client: Any, **kwargs: Any) -> None:
def extract_and_merge_function_middleware(
chat_client: Any, kwargs: dict[str, Any]
) -> "FunctionMiddlewarePipeline | None":
"""Extract function middleware from chat client and merge with existing pipeline in kwargs.
Args:
chat_client: The chat client instance to extract middleware from.
kwargs: Dictionary containing middleware and pipeline information.
Keyword Args:
**kwargs: Dictionary containing middleware and pipeline information.
Returns:
A FunctionMiddlewarePipeline if function middleware is found, None otherwise.
"""
# Check if a pipeline was already created by use_chat_middleware
existing_pipeline: FunctionMiddlewarePipeline | None = kwargs.get("_function_middleware_pipeline")
# Get middleware sources
client_middleware = getattr(chat_client, "middleware", None) if hasattr(chat_client, "middleware") else None
run_level_middleware = kwargs.get("middleware")
existing_pipeline = kwargs.get("_function_middleware_pipeline")
# Extract existing pipeline middlewares if present
existing_middlewares = existing_pipeline._middlewares if existing_pipeline else None
# If we have an existing pipeline but no additional middleware sources, return it directly
if existing_pipeline and not client_middleware and not run_level_middleware:
return existing_pipeline
# If we have an existing pipeline with additional middleware, we need to merge
# Extract existing pipeline middlewares if present - cast to list[Middleware] for type compatibility
existing_middlewares: list[Middleware] | None = list(existing_pipeline._middlewares) if existing_pipeline else None
# Create combined pipeline from all sources using existing helper
combined_pipeline = create_function_middleware_pipeline(
client_middleware, run_level_middleware, existing_middlewares
)
if combined_pipeline:
kwargs["_function_middleware_pipeline"] = combined_pipeline
# If we have an existing pipeline but combined is None (no new middlewares), return existing
if existing_pipeline and combined_pipeline is None:
return existing_pipeline
return combined_pipeline
+116 -42
View File
@@ -1348,6 +1348,35 @@ class FunctionInvocationConfiguration(SerializationMixin):
self.include_detailed_errors = include_detailed_errors
class FunctionExecutionResult:
"""Internal wrapper pairing function output with loop control signals.
Function execution produces two distinct concerns: the semantic result (returned to
the LLM as FunctionResultContent) and control flow decisions (whether middleware
requested early termination). This wrapper keeps control signals out of user-facing
content types while allowing _try_execute_function_calls to communicate both.
Not exposed to users.
Attributes:
content: The FunctionResultContent or other content from the function execution.
terminate: If True, the function invocation loop should exit immediately without
another LLM call. Set when middleware sets context.terminate=True.
"""
__slots__ = ("content", "terminate")
def __init__(self, content: "Contents", terminate: bool = False) -> None:
"""Initialize FunctionExecutionResult.
Args:
content: The content from the function execution.
terminate: Whether to terminate the function calling loop.
"""
self.content = content
self.terminate = terminate
async def _auto_invoke_function(
function_call_content: "FunctionCallContent | FunctionApprovalResponseContent",
custom_args: dict[str, Any] | None = None,
@@ -1357,7 +1386,7 @@ async def _auto_invoke_function(
sequence_index: int | None = None,
request_index: int | None = None,
middleware_pipeline: Any = None, # Optional MiddlewarePipeline
) -> "Contents":
) -> "FunctionExecutionResult | Contents":
"""Invoke a function call requested by the agent, applying middleware that is defined.
Args:
@@ -1372,7 +1401,8 @@ async def _auto_invoke_function(
middleware_pipeline: Optional middleware pipeline to apply during execution.
Returns:
A FunctionResultContent containing the result or exception.
A FunctionExecutionResult wrapping the content and terminate signal,
or a Contents object for approval/hosted tool scenarios.
Raises:
KeyError: If the requested function is not found in the tool map.
@@ -1392,10 +1422,12 @@ async def _auto_invoke_function(
# Tool should exist because _try_execute_function_calls validates this
if tool is None:
exc = KeyError(f'Function "{function_call_content.name}" not found.')
return FunctionResultContent(
call_id=function_call_content.call_id,
result=f'Error: Requested function "{function_call_content.name}" not found.',
exception=exc,
return FunctionExecutionResult(
content=FunctionResultContent(
call_id=function_call_content.call_id,
result=f'Error: Requested function "{function_call_content.name}" not found.',
exception=exc,
)
)
else:
# Note: Unapproved tools (approved=False) are handled in _replace_approval_contents_with_results
@@ -1420,7 +1452,9 @@ async def _auto_invoke_function(
message = "Error: Argument parsing failed."
if config.include_detailed_errors:
message = f"{message} Exception: {exc}"
return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
return FunctionExecutionResult(
content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
)
if not middleware_pipeline or (
not hasattr(middleware_pipeline, "has_middlewares") and not middleware_pipeline.has_middlewares
@@ -1432,15 +1466,19 @@ async def _auto_invoke_function(
tool_call_id=function_call_content.call_id,
**runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {},
)
return FunctionResultContent(
call_id=function_call_content.call_id,
result=function_result,
return FunctionExecutionResult(
content=FunctionResultContent(
call_id=function_call_content.call_id,
result=function_result,
)
)
except Exception as exc:
message = "Error: Function failed."
if config.include_detailed_errors:
message = f"{message} Exception: {exc}"
return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
return FunctionExecutionResult(
content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
)
# Execute through middleware pipeline if available
from ._middleware import FunctionInvocationContext
@@ -1464,15 +1502,20 @@ async def _auto_invoke_function(
context=middleware_context,
final_handler=final_function_handler,
)
return FunctionResultContent(
call_id=function_call_content.call_id,
result=function_result,
return FunctionExecutionResult(
content=FunctionResultContent(
call_id=function_call_content.call_id,
result=function_result,
),
terminate=middleware_context.terminate,
)
except Exception as exc:
message = "Error: Function failed."
if config.include_detailed_errors:
message = f"{message} Exception: {exc}"
return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
return FunctionExecutionResult(
content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
)
def _get_tool_map(
@@ -1503,7 +1546,7 @@ async def _try_execute_function_calls(
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]",
config: FunctionInvocationConfiguration,
middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports
) -> Sequence["Contents"]:
) -> tuple[Sequence["Contents"], bool]:
"""Execute multiple function calls concurrently.
Args:
@@ -1515,9 +1558,11 @@ async def _try_execute_function_calls(
middleware_pipeline: Optional middleware pipeline to apply during execution.
Returns:
A list of Contents containing the results of each function call,
or the approval requests if any function requires approval,
or the original function calls if any are declaration only.
A tuple of:
- A list of Contents containing the results of each function call,
or the approval requests if any function requires approval,
or the original function calls if any are declaration only.
- A boolean indicating whether to terminate the function calling loop.
"""
from ._types import FunctionApprovalRequestContent, FunctionCallContent
@@ -1540,17 +1585,20 @@ async def _try_execute_function_calls(
raise KeyError(f'Error: Requested function "{fcc.name}" not found.')
if approval_needed:
# approval can only be needed for Function Call Contents, not Approval Responses.
return [
FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc)
for fcc in function_calls
if isinstance(fcc, FunctionCallContent)
]
return (
[
FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc)
for fcc in function_calls
if isinstance(fcc, FunctionCallContent)
],
False,
)
if declaration_only_flag:
# return the declaration only tools to the user, since we cannot execute them.
return [fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)]
return ([fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)], False)
# Run all function calls concurrently
return await asyncio.gather(*[
execution_results = await asyncio.gather(*[
_auto_invoke_function(
function_call_content=function_call, # type: ignore[arg-type]
custom_args=custom_args,
@@ -1563,6 +1611,20 @@ async def _try_execute_function_calls(
for seq_idx, function_call in enumerate(function_calls)
])
# Unpack FunctionExecutionResult wrappers and check for terminate signal
contents: list[Contents] = []
should_terminate = False
for result in execution_results:
if isinstance(result, FunctionExecutionResult):
contents.append(result.content)
if result.terminate:
should_terminate = True
else:
# Direct Contents (e.g., from hosted tools)
contents.append(result)
return (contents, should_terminate)
def _update_conversation_id(kwargs: dict[str, Any], conversation_id: str | None) -> None:
"""Update kwargs with conversation id.
@@ -1695,12 +1757,8 @@ def _handle_function_calls_response(
prepare_messages,
)
# Extract and merge function middleware from chat client with kwargs pipeline
extract_and_merge_function_middleware(self, **kwargs)
# Extract the middleware pipeline before calling the underlying function
# because the underlying function may not preserve it in kwargs
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
# Extract and merge function middleware from chat client with kwargs
stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs)
# Get the config for function invocation (not part of ChatClientProtocol, hence getattr)
config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None)
@@ -1726,7 +1784,7 @@ def _handle_function_calls_response(
approved_responses = [resp for resp in fcc_todo.values() if resp.approved]
approved_function_results: list[Contents] = []
if approved_responses:
approved_function_results = await _try_execute_function_calls(
results, _ = await _try_execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
function_calls=approved_responses,
@@ -1734,6 +1792,7 @@ def _handle_function_calls_response(
middleware_pipeline=stored_middleware_pipeline,
config=config,
)
approved_function_results = list(results)
if any(
fcr.exception is not None
for fcr in approved_function_results
@@ -1773,7 +1832,7 @@ def _handle_function_calls_response(
if function_calls and tools:
# Use the stored middleware pipeline instead of extracting from kwargs
# because kwargs may have been modified by the underlying function
function_call_results: list[Contents] = await _try_execute_function_calls(
function_call_results, should_terminate = await _try_execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
function_calls=function_calls,
@@ -1798,6 +1857,17 @@ def _handle_function_calls_response(
# the function calls are already in the response, so we just continue
return response
# Check if middleware signaled to terminate the loop (context.terminate=True)
# This allows middleware to short-circuit the tool loop without another LLM call
if should_terminate:
# Add tool results to response and return immediately without calling LLM again
result_message = ChatMessage(role="tool", contents=function_call_results)
response.messages.append(result_message)
if fcc_messages:
for msg in reversed(fcc_messages):
response.messages.insert(0, msg)
return response
if any(
fcr.exception is not None
for fcr in function_call_results
@@ -1890,12 +1960,8 @@ def _handle_function_calls_streaming_response(
prepare_messages,
)
# Extract and merge function middleware from chat client with kwargs pipeline
extract_and_merge_function_middleware(self, **kwargs)
# Extract the middleware pipeline before calling the underlying function
# because the underlying function may not preserve it in kwargs
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
# Extract and merge function middleware from chat client with kwargs
stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs)
# Get the config for function invocation (not part of ChatClientProtocol, hence getattr)
config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None)
@@ -1914,7 +1980,7 @@ def _handle_function_calls_streaming_response(
approved_responses = [resp for resp in fcc_todo.values() if resp.approved]
approved_function_results: list[Contents] = []
if approved_responses:
approved_function_results = await _try_execute_function_calls(
results, _ = await _try_execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
function_calls=approved_responses,
@@ -1922,6 +1988,7 @@ def _handle_function_calls_streaming_response(
middleware_pipeline=stored_middleware_pipeline,
config=config,
)
approved_function_results = list(results)
if any(
fcr.exception is not None
for fcr in approved_function_results
@@ -1976,7 +2043,7 @@ def _handle_function_calls_streaming_response(
if function_calls and tools:
# Use the stored middleware pipeline instead of extracting from kwargs
# because kwargs may have been modified by the underlying function
function_call_results: list[Contents] = await _try_execute_function_calls(
function_call_results, should_terminate = await _try_execute_function_calls(
custom_args=kwargs,
attempt_idx=attempt_idx,
function_calls=function_calls,
@@ -2005,6 +2072,13 @@ def _handle_function_calls_streaming_response(
# the function calls were already yielded.
return
# Check if middleware signaled to terminate the loop (context.terminate=True)
# This allows middleware to short-circuit the tool loop without another LLM call
if should_terminate:
# Yield tool results and return immediately without calling LLM again
yield ChatResponseUpdate(contents=function_call_results, role="tool")
return
if any(
fcr.exception is not None
for fcr in function_call_results
@@ -1,5 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
from collections.abc import Awaitable, Callable
import pytest
from agent_framework import (
@@ -16,6 +18,7 @@ from agent_framework import (
TextContent,
ai_function,
)
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware
async def test_base_client_with_function_calling(chat_client_base: ChatClientProtocol):
@@ -2206,3 +2209,175 @@ async def test_streaming_error_recovery_resets_counter(chat_client_base: ChatCli
assert len(error_results) >= 1
assert len(success_results) >= 1
assert call_count == 2 # Both calls executed
class TerminateLoopMiddleware(FunctionMiddleware):
"""Middleware that sets terminate=True to exit the function calling loop."""
async def process(
self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
# Set result to a simple value - the framework will wrap it in FunctionResultContent
context.result = "terminated by middleware"
context.terminate = True
async def test_terminate_loop_single_function_call(chat_client_base: ChatClientProtocol):
"""Test that terminate_loop=True exits the function calling loop after single function call."""
exec_counter = 0
@ai_function(name="test_function")
def ai_func(arg1: str) -> str:
nonlocal exec_counter
exec_counter += 1
return f"Processed {arg1}"
# Queue up two responses: function call, then final text
# If terminate_loop works, only the first response should be consumed
chat_client_base.run_responses = [
ChatResponse(
messages=ChatMessage(
role="assistant",
contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')],
)
),
ChatResponse(messages=ChatMessage(role="assistant", text="done")),
]
response = await chat_client_base.get_response(
"hello",
tool_choice="auto",
tools=[ai_func],
middleware=[TerminateLoopMiddleware()],
)
# Function should NOT have been executed - middleware intercepted it
assert exec_counter == 0
# There should be 2 messages: assistant with function call, tool result from middleware
# The loop should NOT have continued to call the LLM again
assert len(response.messages) == 2
assert response.messages[0].role == Role.ASSISTANT
assert isinstance(response.messages[0].contents[0], FunctionCallContent)
assert response.messages[1].role == Role.TOOL
assert isinstance(response.messages[1].contents[0], FunctionResultContent)
assert response.messages[1].contents[0].result == "terminated by middleware"
# Verify the second response is still in the queue (wasn't consumed)
assert len(chat_client_base.run_responses) == 1
class SelectiveTerminateMiddleware(FunctionMiddleware):
"""Only terminates for terminating_function."""
async def process(
self, context: FunctionInvocationContext, next_handler: Callable[[FunctionInvocationContext], Awaitable[None]]
) -> None:
if context.function.name == "terminating_function":
# Set result to a simple value - the framework will wrap it in FunctionResultContent
context.result = "terminated by middleware"
context.terminate = True
else:
await next_handler(context)
async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client_base: ChatClientProtocol):
"""Test that any(terminate_loop=True) exits loop even with multiple function calls."""
normal_call_count = 0
terminating_call_count = 0
@ai_function(name="normal_function")
def normal_func(arg1: str) -> str:
nonlocal normal_call_count
normal_call_count += 1
return f"Normal {arg1}"
@ai_function(name="terminating_function")
def terminating_func(arg1: str) -> str:
nonlocal terminating_call_count
terminating_call_count += 1
return f"Terminating {arg1}"
# Queue up two responses: parallel function calls, then final text
chat_client_base.run_responses = [
ChatResponse(
messages=ChatMessage(
role="assistant",
contents=[
FunctionCallContent(call_id="1", name="normal_function", arguments='{"arg1": "value1"}'),
FunctionCallContent(call_id="2", name="terminating_function", arguments='{"arg1": "value2"}'),
],
)
),
ChatResponse(messages=ChatMessage(role="assistant", text="done")),
]
response = await chat_client_base.get_response(
"hello",
tool_choice="auto",
tools=[normal_func, terminating_func],
middleware=[SelectiveTerminateMiddleware()],
)
# normal_function should have executed (middleware calls next_handler)
# terminating_function should NOT have executed (middleware intercepts it)
assert normal_call_count == 1
assert terminating_call_count == 0
# There should be 2 messages: assistant with function calls, tool results
# The loop should NOT have continued to call the LLM again
assert len(response.messages) == 2
assert response.messages[0].role == Role.ASSISTANT
assert len(response.messages[0].contents) == 2
assert response.messages[1].role == Role.TOOL
# Both function results should be present
assert len(response.messages[1].contents) == 2
# Verify the second response is still in the queue (wasn't consumed)
assert len(chat_client_base.run_responses) == 1
async def test_terminate_loop_streaming_single_function_call(chat_client_base: ChatClientProtocol):
"""Test that terminate_loop=True exits the streaming function calling loop."""
exec_counter = 0
@ai_function(name="test_function")
def ai_func(arg1: str) -> str:
nonlocal exec_counter
exec_counter += 1
return f"Processed {arg1}"
# Queue up two streaming responses
chat_client_base.streaming_responses = [
[
ChatResponseUpdate(
contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')],
role="assistant",
),
],
[
ChatResponseUpdate(
contents=[TextContent(text="done")],
role="assistant",
)
],
]
updates = []
async for update in chat_client_base.get_streaming_response(
"hello",
tool_choice="auto",
tools=[ai_func],
middleware=[TerminateLoopMiddleware()],
):
updates.append(update)
# Function should NOT have been executed - middleware intercepted it
assert exec_counter == 0
# Should have function call update and function result update
# The loop should NOT have continued to call the LLM again
assert len(updates) == 2
# Verify the second streaming response is still in the queue (wasn't consumed)
assert len(chat_client_base.streaming_responses) == 1
@@ -193,7 +193,8 @@ class TestChatAgentFunctionBasedMiddleware:
# Create a message to start the conversation
messages = [ChatMessage(role=Role.USER, text="test message")]
# Set up chat client to return a function call
# Set up chat client to return a function call, then a final response
# If terminate works correctly, only the first response should be consumed
chat_client.responses = [
ChatResponse(
messages=[
@@ -204,7 +205,8 @@ class TestChatAgentFunctionBasedMiddleware:
],
)
]
)
),
ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="this should not be consumed")]),
]
# Create the test function with the expected signature
@@ -222,7 +224,11 @@ class TestChatAgentFunctionBasedMiddleware:
# Verify that function was not called and only middleware executed
assert execution_order == ["middleware_before", "middleware_after"]
assert "function_called" not in execution_order
assert execution_order == ["middleware_before", "middleware_after"]
# Verify the chat client was only called once (no extra LLM call after termination)
assert chat_client.call_count == 1
# Verify the second response is still in the queue (wasn't consumed)
assert len(chat_client.responses) == 1
async def test_function_middleware_with_post_termination(self, chat_client: "MockChatClient") -> None:
"""Test that function middleware can terminate execution after calling next()."""
@@ -242,7 +248,8 @@ class TestChatAgentFunctionBasedMiddleware:
# Create a message to start the conversation
messages = [ChatMessage(role=Role.USER, text="test message")]
# Set up chat client to return a function call
# Set up chat client to return a function call, then a final response
# If terminate works correctly, only the first response should be consumed
chat_client.responses = [
ChatResponse(
messages=[
@@ -253,7 +260,8 @@ class TestChatAgentFunctionBasedMiddleware:
],
)
]
)
),
ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="this should not be consumed")]),
]
# Create the test function with the expected signature
@@ -273,6 +281,11 @@ class TestChatAgentFunctionBasedMiddleware:
assert "function_called" in execution_order
assert execution_order == ["middleware_before", "function_called", "middleware_after"]
# Verify the chat client was only called once (no extra LLM call after termination)
assert chat_client.call_count == 1
# Verify the second response is still in the queue (wasn't consumed)
assert len(chat_client.responses) == 1
async def test_function_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
"""Test function-based agent middleware with ChatAgent."""
execution_order: list[str] = []