mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix middleware terminate flag to exit function calling loop immediately (#2868)
* Fix middleware terminate flag to exit function calling loop immediately
* Eliminating duck typing
* Improve function exec result handling
* Fix race condition
* Fix mypy issues
This commit is contained in:
committed by
GitHub
Unverified
parent
3139347526
commit
11d6dcfe80
@@ -1405,13 +1405,17 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien
|
||||
call_middleware = kwargs.pop("middleware", None)
|
||||
instance_middleware = getattr(self, "middleware", None)
|
||||
|
||||
# Merge middleware from both sources, filtering for chat middleware only
|
||||
all_middleware: list[ChatMiddleware | ChatMiddlewareCallable] = _merge_and_filter_chat_middleware(
|
||||
instance_middleware, call_middleware
|
||||
)
|
||||
# Merge all middleware and separate by type
|
||||
middleware = categorize_middleware(instance_middleware, call_middleware)
|
||||
chat_middleware_list = middleware["chat"]
|
||||
function_middleware_list = middleware["function"]
|
||||
|
||||
# If no middleware, use original method
|
||||
if not all_middleware:
|
||||
# Pass function middleware to function invocation system if present
|
||||
if function_middleware_list:
|
||||
kwargs["_function_middleware_pipeline"] = FunctionMiddlewarePipeline(function_middleware_list)
|
||||
|
||||
# If no chat middleware, use original method
|
||||
if not chat_middleware_list:
|
||||
async for update in original_get_streaming_response(self, messages, **kwargs):
|
||||
yield update
|
||||
return
|
||||
@@ -1422,7 +1426,7 @@ def use_chat_middleware(chat_client_class: type[TChatClient]) -> type[TChatClien
|
||||
# Extract chat_options or create default
|
||||
chat_options = kwargs.pop("chat_options", ChatOptions())
|
||||
|
||||
pipeline = ChatMiddlewarePipeline(all_middleware) # type: ignore[arg-type]
|
||||
pipeline = ChatMiddlewarePipeline(chat_middleware_list) # type: ignore[arg-type]
|
||||
context = ChatContext(
|
||||
chat_client=self,
|
||||
messages=prepare_messages(messages),
|
||||
@@ -1536,27 +1540,40 @@ def _merge_and_filter_chat_middleware(
|
||||
return middleware["chat"] # type: ignore[return-value]
|
||||
|
||||
|
||||
def extract_and_merge_function_middleware(
    chat_client: Any, kwargs: dict[str, Any]
) -> "FunctionMiddlewarePipeline | None":
    """Extract function middleware from chat client and merge with existing pipeline in kwargs.

    The ``kwargs`` mapping is only read here; the resulting pipeline is handed back
    through the return value rather than being written into ``kwargs``.

    Args:
        chat_client: The chat client instance to extract middleware from. Its
            ``middleware`` attribute is used when present.
        kwargs: Dictionary containing middleware and pipeline information. The keys
            ``"middleware"`` and ``"_function_middleware_pipeline"`` are consulted.

    Returns:
        A FunctionMiddlewarePipeline if function middleware is found, None otherwise.
    """
    # Check if a pipeline was already created by use_chat_middleware
    existing_pipeline: FunctionMiddlewarePipeline | None = kwargs.get("_function_middleware_pipeline")

    # Get middleware sources (getattr with a default already covers a missing attribute)
    client_middleware = getattr(chat_client, "middleware", None)
    run_level_middleware = kwargs.get("middleware")

    # If we have an existing pipeline but no additional middleware sources, return it directly
    if existing_pipeline and not client_middleware and not run_level_middleware:
        return existing_pipeline

    # If we have an existing pipeline with additional middleware, we need to merge.
    # Copy the existing pipeline's middlewares into a list[Middleware] for type compatibility.
    existing_middlewares: list[Middleware] | None = (
        list(existing_pipeline._middlewares) if existing_pipeline else None
    )

    # Create combined pipeline from all sources using existing helper
    combined_pipeline = create_function_middleware_pipeline(
        client_middleware, run_level_middleware, existing_middlewares
    )

    # If we have an existing pipeline but combined is None (no new middlewares), return existing
    if existing_pipeline and combined_pipeline is None:
        return existing_pipeline

    return combined_pipeline
|
||||
|
||||
@@ -1348,6 +1348,35 @@ class FunctionInvocationConfiguration(SerializationMixin):
|
||||
self.include_detailed_errors = include_detailed_errors
|
||||
|
||||
|
||||
class FunctionExecutionResult:
    """Internal wrapper pairing function output with loop control signals.

    Function execution produces two distinct concerns: the semantic result (returned to
    the LLM as FunctionResultContent) and control flow decisions (whether middleware
    requested early termination). This wrapper keeps control signals out of user-facing
    content types while allowing _try_execute_function_calls to communicate both.

    Not exposed to users.

    Attributes:
        content: The FunctionResultContent or other content from the function execution.
        terminate: If True, the function invocation loop should exit immediately without
            another LLM call. Set when middleware sets context.terminate=True.
    """

    # Fixed attribute set; instances carry no per-instance __dict__.
    __slots__ = ("content", "terminate")

    def __init__(self, content: "Contents", terminate: bool = False) -> None:
        """Initialize FunctionExecutionResult.

        Args:
            content: The content from the function execution.
            terminate: Whether to terminate the function calling loop.
        """
        self.content = content
        self.terminate = terminate

    def __repr__(self) -> str:
        """Return a debugging representation showing both stored fields."""
        return f"{type(self).__name__}(content={self.content!r}, terminate={self.terminate!r})"
|
||||
|
||||
|
||||
async def _auto_invoke_function(
|
||||
function_call_content: "FunctionCallContent | FunctionApprovalResponseContent",
|
||||
custom_args: dict[str, Any] | None = None,
|
||||
@@ -1357,7 +1386,7 @@ async def _auto_invoke_function(
|
||||
sequence_index: int | None = None,
|
||||
request_index: int | None = None,
|
||||
middleware_pipeline: Any = None, # Optional MiddlewarePipeline
|
||||
) -> "Contents":
|
||||
) -> "FunctionExecutionResult | Contents":
|
||||
"""Invoke a function call requested by the agent, applying middleware that is defined.
|
||||
|
||||
Args:
|
||||
@@ -1372,7 +1401,8 @@ async def _auto_invoke_function(
|
||||
middleware_pipeline: Optional middleware pipeline to apply during execution.
|
||||
|
||||
Returns:
|
||||
A FunctionResultContent containing the result or exception.
|
||||
A FunctionExecutionResult wrapping the content and terminate signal,
|
||||
or a Contents object for approval/hosted tool scenarios.
|
||||
|
||||
Raises:
|
||||
KeyError: If the requested function is not found in the tool map.
|
||||
@@ -1392,10 +1422,12 @@ async def _auto_invoke_function(
|
||||
# Tool should exist because _try_execute_function_calls validates this
|
||||
if tool is None:
|
||||
exc = KeyError(f'Function "{function_call_content.name}" not found.')
|
||||
return FunctionResultContent(
|
||||
call_id=function_call_content.call_id,
|
||||
result=f'Error: Requested function "{function_call_content.name}" not found.',
|
||||
exception=exc,
|
||||
return FunctionExecutionResult(
|
||||
content=FunctionResultContent(
|
||||
call_id=function_call_content.call_id,
|
||||
result=f'Error: Requested function "{function_call_content.name}" not found.',
|
||||
exception=exc,
|
||||
)
|
||||
)
|
||||
else:
|
||||
# Note: Unapproved tools (approved=False) are handled in _replace_approval_contents_with_results
|
||||
@@ -1420,7 +1452,9 @@ async def _auto_invoke_function(
|
||||
message = "Error: Argument parsing failed."
|
||||
if config.include_detailed_errors:
|
||||
message = f"{message} Exception: {exc}"
|
||||
return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
|
||||
return FunctionExecutionResult(
|
||||
content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
|
||||
)
|
||||
|
||||
if not middleware_pipeline or (
|
||||
not hasattr(middleware_pipeline, "has_middlewares") and not middleware_pipeline.has_middlewares
|
||||
@@ -1432,15 +1466,19 @@ async def _auto_invoke_function(
|
||||
tool_call_id=function_call_content.call_id,
|
||||
**runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {},
|
||||
)
|
||||
return FunctionResultContent(
|
||||
call_id=function_call_content.call_id,
|
||||
result=function_result,
|
||||
return FunctionExecutionResult(
|
||||
content=FunctionResultContent(
|
||||
call_id=function_call_content.call_id,
|
||||
result=function_result,
|
||||
)
|
||||
)
|
||||
except Exception as exc:
|
||||
message = "Error: Function failed."
|
||||
if config.include_detailed_errors:
|
||||
message = f"{message} Exception: {exc}"
|
||||
return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
|
||||
return FunctionExecutionResult(
|
||||
content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
|
||||
)
|
||||
# Execute through middleware pipeline if available
|
||||
from ._middleware import FunctionInvocationContext
|
||||
|
||||
@@ -1464,15 +1502,20 @@ async def _auto_invoke_function(
|
||||
context=middleware_context,
|
||||
final_handler=final_function_handler,
|
||||
)
|
||||
return FunctionResultContent(
|
||||
call_id=function_call_content.call_id,
|
||||
result=function_result,
|
||||
return FunctionExecutionResult(
|
||||
content=FunctionResultContent(
|
||||
call_id=function_call_content.call_id,
|
||||
result=function_result,
|
||||
),
|
||||
terminate=middleware_context.terminate,
|
||||
)
|
||||
except Exception as exc:
|
||||
message = "Error: Function failed."
|
||||
if config.include_detailed_errors:
|
||||
message = f"{message} Exception: {exc}"
|
||||
return FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
|
||||
return FunctionExecutionResult(
|
||||
content=FunctionResultContent(call_id=function_call_content.call_id, result=message, exception=exc)
|
||||
)
|
||||
|
||||
|
||||
def _get_tool_map(
|
||||
@@ -1503,7 +1546,7 @@ async def _try_execute_function_calls(
|
||||
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]",
|
||||
config: FunctionInvocationConfiguration,
|
||||
middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports
|
||||
) -> Sequence["Contents"]:
|
||||
) -> tuple[Sequence["Contents"], bool]:
|
||||
"""Execute multiple function calls concurrently.
|
||||
|
||||
Args:
|
||||
@@ -1515,9 +1558,11 @@ async def _try_execute_function_calls(
|
||||
middleware_pipeline: Optional middleware pipeline to apply during execution.
|
||||
|
||||
Returns:
|
||||
A list of Contents containing the results of each function call,
|
||||
or the approval requests if any function requires approval,
|
||||
or the original function calls if any are declaration only.
|
||||
A tuple of:
|
||||
- A list of Contents containing the results of each function call,
|
||||
or the approval requests if any function requires approval,
|
||||
or the original function calls if any are declaration only.
|
||||
- A boolean indicating whether to terminate the function calling loop.
|
||||
"""
|
||||
from ._types import FunctionApprovalRequestContent, FunctionCallContent
|
||||
|
||||
@@ -1540,17 +1585,20 @@ async def _try_execute_function_calls(
|
||||
raise KeyError(f'Error: Requested function "{fcc.name}" not found.')
|
||||
if approval_needed:
|
||||
# approval can only be needed for Function Call Contents, not Approval Responses.
|
||||
return [
|
||||
FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc)
|
||||
for fcc in function_calls
|
||||
if isinstance(fcc, FunctionCallContent)
|
||||
]
|
||||
return (
|
||||
[
|
||||
FunctionApprovalRequestContent(id=fcc.call_id, function_call=fcc)
|
||||
for fcc in function_calls
|
||||
if isinstance(fcc, FunctionCallContent)
|
||||
],
|
||||
False,
|
||||
)
|
||||
if declaration_only_flag:
|
||||
# return the declaration only tools to the user, since we cannot execute them.
|
||||
return [fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)]
|
||||
return ([fcc for fcc in function_calls if isinstance(fcc, FunctionCallContent)], False)
|
||||
|
||||
# Run all function calls concurrently
|
||||
return await asyncio.gather(*[
|
||||
execution_results = await asyncio.gather(*[
|
||||
_auto_invoke_function(
|
||||
function_call_content=function_call, # type: ignore[arg-type]
|
||||
custom_args=custom_args,
|
||||
@@ -1563,6 +1611,20 @@ async def _try_execute_function_calls(
|
||||
for seq_idx, function_call in enumerate(function_calls)
|
||||
])
|
||||
|
||||
# Unpack FunctionExecutionResult wrappers and check for terminate signal
|
||||
contents: list[Contents] = []
|
||||
should_terminate = False
|
||||
for result in execution_results:
|
||||
if isinstance(result, FunctionExecutionResult):
|
||||
contents.append(result.content)
|
||||
if result.terminate:
|
||||
should_terminate = True
|
||||
else:
|
||||
# Direct Contents (e.g., from hosted tools)
|
||||
contents.append(result)
|
||||
|
||||
return (contents, should_terminate)
|
||||
|
||||
|
||||
def _update_conversation_id(kwargs: dict[str, Any], conversation_id: str | None) -> None:
|
||||
"""Update kwargs with conversation id.
|
||||
@@ -1695,12 +1757,8 @@ def _handle_function_calls_response(
|
||||
prepare_messages,
|
||||
)
|
||||
|
||||
# Extract and merge function middleware from chat client with kwargs pipeline
|
||||
extract_and_merge_function_middleware(self, **kwargs)
|
||||
|
||||
# Extract the middleware pipeline before calling the underlying function
|
||||
# because the underlying function may not preserve it in kwargs
|
||||
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
# Extract and merge function middleware from chat client with kwargs
|
||||
stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs)
|
||||
|
||||
# Get the config for function invocation (not part of ChatClientProtocol, hence getattr)
|
||||
config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None)
|
||||
@@ -1726,7 +1784,7 @@ def _handle_function_calls_response(
|
||||
approved_responses = [resp for resp in fcc_todo.values() if resp.approved]
|
||||
approved_function_results: list[Contents] = []
|
||||
if approved_responses:
|
||||
approved_function_results = await _try_execute_function_calls(
|
||||
results, _ = await _try_execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
function_calls=approved_responses,
|
||||
@@ -1734,6 +1792,7 @@ def _handle_function_calls_response(
|
||||
middleware_pipeline=stored_middleware_pipeline,
|
||||
config=config,
|
||||
)
|
||||
approved_function_results = list(results)
|
||||
if any(
|
||||
fcr.exception is not None
|
||||
for fcr in approved_function_results
|
||||
@@ -1773,7 +1832,7 @@ def _handle_function_calls_response(
|
||||
if function_calls and tools:
|
||||
# Use the stored middleware pipeline instead of extracting from kwargs
|
||||
# because kwargs may have been modified by the underlying function
|
||||
function_call_results: list[Contents] = await _try_execute_function_calls(
|
||||
function_call_results, should_terminate = await _try_execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
function_calls=function_calls,
|
||||
@@ -1798,6 +1857,17 @@ def _handle_function_calls_response(
|
||||
# the function calls are already in the response, so we just continue
|
||||
return response
|
||||
|
||||
# Check if middleware signaled to terminate the loop (context.terminate=True)
|
||||
# This allows middleware to short-circuit the tool loop without another LLM call
|
||||
if should_terminate:
|
||||
# Add tool results to response and return immediately without calling LLM again
|
||||
result_message = ChatMessage(role="tool", contents=function_call_results)
|
||||
response.messages.append(result_message)
|
||||
if fcc_messages:
|
||||
for msg in reversed(fcc_messages):
|
||||
response.messages.insert(0, msg)
|
||||
return response
|
||||
|
||||
if any(
|
||||
fcr.exception is not None
|
||||
for fcr in function_call_results
|
||||
@@ -1890,12 +1960,8 @@ def _handle_function_calls_streaming_response(
|
||||
prepare_messages,
|
||||
)
|
||||
|
||||
# Extract and merge function middleware from chat client with kwargs pipeline
|
||||
extract_and_merge_function_middleware(self, **kwargs)
|
||||
|
||||
# Extract the middleware pipeline before calling the underlying function
|
||||
# because the underlying function may not preserve it in kwargs
|
||||
stored_middleware_pipeline = kwargs.get("_function_middleware_pipeline")
|
||||
# Extract and merge function middleware from chat client with kwargs
|
||||
stored_middleware_pipeline = extract_and_merge_function_middleware(self, kwargs)
|
||||
|
||||
# Get the config for function invocation (not part of ChatClientProtocol, hence getattr)
|
||||
config: FunctionInvocationConfiguration | None = getattr(self, "function_invocation_configuration", None)
|
||||
@@ -1914,7 +1980,7 @@ def _handle_function_calls_streaming_response(
|
||||
approved_responses = [resp for resp in fcc_todo.values() if resp.approved]
|
||||
approved_function_results: list[Contents] = []
|
||||
if approved_responses:
|
||||
approved_function_results = await _try_execute_function_calls(
|
||||
results, _ = await _try_execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
function_calls=approved_responses,
|
||||
@@ -1922,6 +1988,7 @@ def _handle_function_calls_streaming_response(
|
||||
middleware_pipeline=stored_middleware_pipeline,
|
||||
config=config,
|
||||
)
|
||||
approved_function_results = list(results)
|
||||
if any(
|
||||
fcr.exception is not None
|
||||
for fcr in approved_function_results
|
||||
@@ -1976,7 +2043,7 @@ def _handle_function_calls_streaming_response(
|
||||
if function_calls and tools:
|
||||
# Use the stored middleware pipeline instead of extracting from kwargs
|
||||
# because kwargs may have been modified by the underlying function
|
||||
function_call_results: list[Contents] = await _try_execute_function_calls(
|
||||
function_call_results, should_terminate = await _try_execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
function_calls=function_calls,
|
||||
@@ -2005,6 +2072,13 @@ def _handle_function_calls_streaming_response(
|
||||
# the function calls were already yielded.
|
||||
return
|
||||
|
||||
# Check if middleware signaled to terminate the loop (context.terminate=True)
|
||||
# This allows middleware to short-circuit the tool loop without another LLM call
|
||||
if should_terminate:
|
||||
# Yield tool results and return immediately without calling LLM again
|
||||
yield ChatResponseUpdate(contents=function_call_results, role="tool")
|
||||
return
|
||||
|
||||
if any(
|
||||
fcr.exception is not None
|
||||
for fcr in function_call_results
|
||||
|
||||
@@ -1,5 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
@@ -16,6 +18,7 @@ from agent_framework import (
|
||||
TextContent,
|
||||
ai_function,
|
||||
)
|
||||
from agent_framework._middleware import FunctionInvocationContext, FunctionMiddleware
|
||||
|
||||
|
||||
async def test_base_client_with_function_calling(chat_client_base: ChatClientProtocol):
|
||||
@@ -2206,3 +2209,175 @@ async def test_streaming_error_recovery_resets_counter(chat_client_base: ChatCli
|
||||
assert len(error_results) >= 1
|
||||
assert len(success_results) >= 1
|
||||
assert call_count == 2 # Both calls executed
|
||||
|
||||
|
||||
class TerminateLoopMiddleware(FunctionMiddleware):
    """Function middleware that short-circuits every call and requests loop termination.

    ``next_handler`` is never awaited, so this middleware does not pass the call on;
    it stores a fixed result on the context and sets ``context.terminate``.
    """

    async def process(
        self,
        context: FunctionInvocationContext,
        next_handler: Callable[[FunctionInvocationContext], Awaitable[None]],
    ) -> None:
        # A plain value is enough here - the framework will wrap it in FunctionResultContent.
        context.result = "terminated by middleware"
        # Ask the function calling loop to stop.
        context.terminate = True
|
||||
|
||||
|
||||
async def test_terminate_loop_single_function_call(chat_client_base: ChatClientProtocol):
    """Middleware setting ``context.terminate`` ends the function calling loop after one call."""
    # Records every argument the tool body receives; stays empty if the tool never runs.
    invocations: list[str] = []

    @ai_function(name="test_function")
    def ai_func(arg1: str) -> str:
        invocations.append(arg1)
        return f"Processed {arg1}"

    # Two queued responses: a function call, then final text.
    # When termination works, only the first one is consumed.
    function_call_response = ChatResponse(
        messages=ChatMessage(
            role="assistant",
            contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')],
        )
    )
    final_response = ChatResponse(messages=ChatMessage(role="assistant", text="done"))
    chat_client_base.run_responses = [function_call_response, final_response]

    response = await chat_client_base.get_response(
        "hello",
        tool_choice="auto",
        tools=[ai_func],
        middleware=[TerminateLoopMiddleware()],
    )

    # The middleware intercepted the call, so the tool body must not have run.
    assert len(invocations) == 0

    # Exactly two messages: the assistant's function call and the middleware's tool result.
    # No further LLM round trip happened.
    assert len(response.messages) == 2
    assistant_message, tool_message = response.messages
    assert assistant_message.role == Role.ASSISTANT
    assert isinstance(assistant_message.contents[0], FunctionCallContent)
    assert tool_message.role == Role.TOOL
    tool_result = tool_message.contents[0]
    assert isinstance(tool_result, FunctionResultContent)
    assert tool_result.result == "terminated by middleware"

    # The second queued response was left untouched.
    assert len(chat_client_base.run_responses) == 1
|
||||
|
||||
|
||||
class SelectiveTerminateMiddleware(FunctionMiddleware):
    """Function middleware that terminates only for the function named ``terminating_function``.

    Calls to any other function are passed on to ``next_handler``.
    """

    async def process(
        self,
        context: FunctionInvocationContext,
        next_handler: Callable[[FunctionInvocationContext], Awaitable[None]],
    ) -> None:
        # Every function except the designated one goes through the rest of the chain.
        if context.function.name != "terminating_function":
            await next_handler(context)
            return

        # A plain value is enough here - the framework will wrap it in FunctionResultContent.
        context.result = "terminated by middleware"
        context.terminate = True
|
||||
|
||||
|
||||
async def test_terminate_loop_multiple_function_calls_one_terminates(chat_client_base: ChatClientProtocol):
    """One terminating call among parallel function calls is enough to exit the loop."""
    # Per-tool invocation counts, bumped from inside the tool bodies.
    call_counts = {"normal": 0, "terminating": 0}

    @ai_function(name="normal_function")
    def normal_func(arg1: str) -> str:
        call_counts["normal"] += 1
        return f"Normal {arg1}"

    @ai_function(name="terminating_function")
    def terminating_func(arg1: str) -> str:
        call_counts["terminating"] += 1
        return f"Terminating {arg1}"

    # Two queued responses: a pair of parallel function calls, then final text.
    parallel_calls_response = ChatResponse(
        messages=ChatMessage(
            role="assistant",
            contents=[
                FunctionCallContent(call_id="1", name="normal_function", arguments='{"arg1": "value1"}'),
                FunctionCallContent(call_id="2", name="terminating_function", arguments='{"arg1": "value2"}'),
            ],
        )
    )
    final_response = ChatResponse(messages=ChatMessage(role="assistant", text="done"))
    chat_client_base.run_responses = [parallel_calls_response, final_response]

    response = await chat_client_base.get_response(
        "hello",
        tool_choice="auto",
        tools=[normal_func, terminating_func],
        middleware=[SelectiveTerminateMiddleware()],
    )

    # normal_function ran (the middleware forwarded it to next_handler);
    # terminating_function did not (the middleware intercepted it).
    assert call_counts["normal"] == 1
    assert call_counts["terminating"] == 0

    # Exactly two messages: the assistant's function calls and the tool results.
    # No further LLM round trip happened.
    assert len(response.messages) == 2
    assistant_message, tool_message = response.messages
    assert assistant_message.role == Role.ASSISTANT
    assert len(assistant_message.contents) == 2
    assert tool_message.role == Role.TOOL
    # Results for both function calls are present.
    assert len(tool_message.contents) == 2

    # The second queued response was left untouched.
    assert len(chat_client_base.run_responses) == 1
|
||||
|
||||
|
||||
async def test_terminate_loop_streaming_single_function_call(chat_client_base: ChatClientProtocol):
    """Middleware setting ``context.terminate`` ends the streaming function calling loop."""
    # Records every argument the tool body receives; stays empty if the tool never runs.
    invocations: list[str] = []

    @ai_function(name="test_function")
    def ai_func(arg1: str) -> str:
        invocations.append(arg1)
        return f"Processed {arg1}"

    # Two queued streams: one carrying the function call, one carrying the final text.
    function_call_stream = [
        ChatResponseUpdate(
            contents=[FunctionCallContent(call_id="1", name="test_function", arguments='{"arg1": "value1"}')],
            role="assistant",
        ),
    ]
    final_text_stream = [
        ChatResponseUpdate(
            contents=[TextContent(text="done")],
            role="assistant",
        )
    ]
    chat_client_base.streaming_responses = [function_call_stream, final_text_stream]

    updates = [
        update
        async for update in chat_client_base.get_streaming_response(
            "hello",
            tool_choice="auto",
            tools=[ai_func],
            middleware=[TerminateLoopMiddleware()],
        )
    ]

    # The middleware intercepted the call, so the tool body must not have run.
    assert len(invocations) == 0

    # One update for the function call and one for the function result;
    # no further LLM round trip happened.
    assert len(updates) == 2

    # The second queued stream was left untouched.
    assert len(chat_client_base.streaming_responses) == 1
|
||||
|
||||
@@ -193,7 +193,8 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
# Create a message to start the conversation
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
|
||||
# Set up chat client to return a function call
|
||||
# Set up chat client to return a function call, then a final response
|
||||
# If terminate works correctly, only the first response should be consumed
|
||||
chat_client.responses = [
|
||||
ChatResponse(
|
||||
messages=[
|
||||
@@ -204,7 +205,8 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="this should not be consumed")]),
|
||||
]
|
||||
|
||||
# Create the test function with the expected signature
|
||||
@@ -222,7 +224,11 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
# Verify that function was not called and only middleware executed
|
||||
assert execution_order == ["middleware_before", "middleware_after"]
|
||||
assert "function_called" not in execution_order
|
||||
assert execution_order == ["middleware_before", "middleware_after"]
|
||||
|
||||
# Verify the chat client was only called once (no extra LLM call after termination)
|
||||
assert chat_client.call_count == 1
|
||||
# Verify the second response is still in the queue (wasn't consumed)
|
||||
assert len(chat_client.responses) == 1
|
||||
|
||||
async def test_function_middleware_with_post_termination(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test that function middleware can terminate execution after calling next()."""
|
||||
@@ -242,7 +248,8 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
# Create a message to start the conversation
|
||||
messages = [ChatMessage(role=Role.USER, text="test message")]
|
||||
|
||||
# Set up chat client to return a function call
|
||||
# Set up chat client to return a function call, then a final response
|
||||
# If terminate works correctly, only the first response should be consumed
|
||||
chat_client.responses = [
|
||||
ChatResponse(
|
||||
messages=[
|
||||
@@ -253,7 +260,8 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
],
|
||||
)
|
||||
]
|
||||
)
|
||||
),
|
||||
ChatResponse(messages=[ChatMessage(role=Role.ASSISTANT, text="this should not be consumed")]),
|
||||
]
|
||||
|
||||
# Create the test function with the expected signature
|
||||
@@ -273,6 +281,11 @@ class TestChatAgentFunctionBasedMiddleware:
|
||||
assert "function_called" in execution_order
|
||||
assert execution_order == ["middleware_before", "function_called", "middleware_after"]
|
||||
|
||||
# Verify the chat client was only called once (no extra LLM call after termination)
|
||||
assert chat_client.call_count == 1
|
||||
# Verify the second response is still in the queue (wasn't consumed)
|
||||
assert len(chat_client.responses) == 1
|
||||
|
||||
async def test_function_based_agent_middleware_with_chat_agent(self, chat_client: "MockChatClient") -> None:
|
||||
"""Test function-based agent middleware with ChatAgent."""
|
||||
execution_order: list[str] = []
|
||||
|
||||
Reference in New Issue
Block a user