mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
* Fix thread corruption when max_iterations exhausted (#1366) When the function invocation loop exhausts max_iterations while the model keeps requesting tools, the failsafe code path (calling the model with tool_choice='none' and prepending fcc_messages) was unreachable because 'if response is not None: return response' short-circuited before it. The fix removes the premature return so the failsafe always runs after loop exhaustion, making a final model call with tool_choice='none' to produce a clean text answer and prepending accumulated fcc_messages from prior iterations. This matches the existing pattern used by the error threshold and max_function_calls paths. Also unskips test_max_iterations_limit and test_streaming_max_iterations_limit which were previously skipped with 'needs investigation in unified API'. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add fix report for issue #1366 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix ruff formatting in _tools.py and test_issue_1366_thread_corruption.py Apply ruff format to fix multi-line string concatenation and function call formatting issues flagged by the linter. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add quality review for issue #1366 fix Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Remove temporary investigation docs. * Address PR review: explicit enabled check in log condition, clarify mock behavior in test - Add explicit function_invocation_configuration['enabled'] check to the 'Maximum iterations reached' log condition in both non-streaming and streaming paths, making intent clearer when function invocation is disabled. - Add comment in test_thread_safe_after_max_iterations_with_agent explaining that the failsafe response (tool_choice='none') is provided automatically by the mock client, not from run_responses. * Blend fix and tests into project without issue-specific callouts - Remove issue #1366 references from _tools.py comments - Move regression tests from standalone test_issue_1366_thread_corruption.py into test_function_invocation_logic.py alongside existing max_iterations tests - Clean up test docstrings to describe behavior generically - Delete the standalone issue-specific test file --------- Co-authored-by: alliscode <bentho@microsoft.com> Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
40d2fac29c
commit
23fe2c16b3
@@ -2192,9 +2192,15 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
prepped_messages.extend(response.messages)
|
||||
continue
|
||||
|
||||
if response is not None:
|
||||
return response
|
||||
|
||||
# Loop exhausted all iterations (or function invocation disabled).
|
||||
# Make a final model call with tool_choice="none" so the model
|
||||
# produces a plain text answer instead of leaving orphaned
|
||||
# function_call items without matching results.
|
||||
if response is not None and self.function_invocation_configuration["enabled"]:
|
||||
logger.info(
|
||||
"Maximum iterations reached (%d). Requesting final response without tools.",
|
||||
self.function_invocation_configuration["max_iterations"],
|
||||
)
|
||||
mutable_options["tool_choice"] = "none"
|
||||
response = await super_get_response(
|
||||
messages=prepped_messages,
|
||||
@@ -2325,9 +2331,15 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
prepped_messages.extend(response.messages)
|
||||
continue
|
||||
|
||||
if response is not None:
|
||||
return
|
||||
|
||||
# Loop exhausted all iterations (or function invocation disabled).
|
||||
# Make a final model call with tool_choice="none" so the model
|
||||
# produces a plain text answer instead of leaving orphaned
|
||||
# function_call items without matching results.
|
||||
if response is not None and self.function_invocation_configuration["enabled"]:
|
||||
logger.info(
|
||||
"Maximum iterations reached (%d). Requesting final response without tools.",
|
||||
self.function_invocation_configuration["max_iterations"],
|
||||
)
|
||||
mutable_options["tool_choice"] = "none"
|
||||
inner_stream = await _ensure_response_stream(
|
||||
super_get_response(
|
||||
|
||||
@@ -831,8 +831,6 @@ async def test_rejection_result_uses_function_call_id(chat_client_base: Supports
|
||||
assert "rejected" in rejection_result.result.lower()
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Failsafe behavior with max_iterations needs investigation in unified API")
|
||||
@pytest.mark.skip(reason="Failsafe behavior with max_iterations needs investigation in unified API")
|
||||
async def test_max_iterations_limit(chat_client_base: SupportsChatGetResponse):
|
||||
"""Test that MAX_ITERATIONS in additional_properties limits function call loops."""
|
||||
exec_counter = 0
|
||||
@@ -880,6 +878,256 @@ async def test_max_iterations_limit(chat_client_base: SupportsChatGetResponse):
|
||||
assert response.messages[-1].text == "I broke out of the function invocation loop..." # Failsafe response
|
||||
|
||||
|
||||
async def test_max_iterations_no_orphaned_function_calls(chat_client_base: SupportsChatGetResponse):
|
||||
"""When max_iterations is reached, verify the returned response has no orphaned
|
||||
FunctionCallContent (i.e., every function_call has a matching function_result).
|
||||
"""
|
||||
exec_counter = 0
|
||||
|
||||
@tool(name="test_function", approval_mode="never_require")
|
||||
def ai_func(arg1: str) -> str:
|
||||
nonlocal exec_counter
|
||||
exec_counter += 1
|
||||
return f"Processed {arg1}"
|
||||
|
||||
# Model keeps requesting tool calls on every iteration
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="call_1", name="test_function", arguments='{"arg1": "v1"}')
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="call_2", name="test_function", arguments='{"arg1": "v2"}')
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="call_3", name="test_function", arguments='{"arg1": "v3"}')
|
||||
],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
chat_client_base.function_invocation_configuration["max_iterations"] = 2
|
||||
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")],
|
||||
options={"tool_choice": "auto", "tools": [ai_func]},
|
||||
)
|
||||
|
||||
# Collect all function_call and function_result call_ids from response
|
||||
all_call_ids = set()
|
||||
all_result_ids = set()
|
||||
for msg in response.messages:
|
||||
for content in msg.contents:
|
||||
if content.type == "function_call":
|
||||
all_call_ids.add(content.call_id)
|
||||
elif content.type == "function_result":
|
||||
all_result_ids.add(content.call_id)
|
||||
|
||||
orphaned_calls = all_call_ids - all_result_ids
|
||||
assert not orphaned_calls, (
|
||||
f"Response contains orphaned FunctionCallContent without matching "
|
||||
f"FunctionResultContent: {orphaned_calls}."
|
||||
)
|
||||
|
||||
|
||||
async def test_max_iterations_makes_final_toolchoice_none_call(chat_client_base: SupportsChatGetResponse):
|
||||
"""When max_iterations is reached, verify a final model call is made with
|
||||
tool_choice='none' to produce a clean text response.
|
||||
"""
|
||||
exec_counter = 0
|
||||
|
||||
@tool(name="test_function", approval_mode="never_require")
|
||||
def ai_func(arg1: str) -> str:
|
||||
nonlocal exec_counter
|
||||
exec_counter += 1
|
||||
return f"Processed {arg1}"
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="call_1", name="test_function", arguments='{"arg1": "v1"}')
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="call_2", name="test_function", arguments='{"arg1": "v2"}')
|
||||
],
|
||||
)
|
||||
),
|
||||
# This response should be reached via failsafe (tool_choice="none")
|
||||
ChatResponse(messages=Message(role="assistant", text="Final answer after giving up on tools.")),
|
||||
]
|
||||
|
||||
chat_client_base.function_invocation_configuration["max_iterations"] = 1
|
||||
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")],
|
||||
options={"tool_choice": "auto", "tools": [ai_func]},
|
||||
)
|
||||
|
||||
assert exec_counter == 1, f"Expected 1 function execution, got {exec_counter}"
|
||||
|
||||
# The response should end with a plain text message (from the failsafe call)
|
||||
last_msg = response.messages[-1]
|
||||
has_function_calls = any(c.type == "function_call" for c in last_msg.contents)
|
||||
|
||||
assert not has_function_calls, (
|
||||
f"Last message in response still contains function_call items. "
|
||||
f"Expected a clean text response after max_iterations failsafe. "
|
||||
f"Got message with role={last_msg.role}, contents={[c.type for c in last_msg.contents]}"
|
||||
)
|
||||
|
||||
# The mock client returns "I broke out of the function invocation loop..."
|
||||
# when tool_choice="none"
|
||||
assert last_msg.text == "I broke out of the function invocation loop...", (
|
||||
f"Expected failsafe text response, got: {last_msg.text!r}"
|
||||
)
|
||||
|
||||
|
||||
async def test_max_iterations_preserves_all_fcc_messages(chat_client_base: SupportsChatGetResponse):
|
||||
"""When max_iterations is reached and a final response is produced, all
|
||||
intermediate function call/result messages should be included.
|
||||
"""
|
||||
exec_counter = 0
|
||||
|
||||
@tool(name="test_function", approval_mode="never_require")
|
||||
def ai_func(arg1: str) -> str:
|
||||
nonlocal exec_counter
|
||||
exec_counter += 1
|
||||
return f"Result {exec_counter}"
|
||||
|
||||
# Two iterations of function calls, then failsafe
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="call_1", name="test_function", arguments='{"arg1": "v1"}')
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="call_2", name="test_function", arguments='{"arg1": "v2"}')
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(messages=Message(role="assistant", text="Done")),
|
||||
]
|
||||
|
||||
chat_client_base.function_invocation_configuration["max_iterations"] = 2
|
||||
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")],
|
||||
options={"tool_choice": "auto", "tools": [ai_func]},
|
||||
)
|
||||
|
||||
assert exec_counter == 2, f"Expected 2 function executions, got {exec_counter}"
|
||||
|
||||
# All function calls from both iterations should be present in the response
|
||||
all_call_ids = set()
|
||||
all_result_ids = set()
|
||||
for msg in response.messages:
|
||||
for content in msg.contents:
|
||||
if content.type == "function_call":
|
||||
all_call_ids.add(content.call_id)
|
||||
elif content.type == "function_result":
|
||||
all_result_ids.add(content.call_id)
|
||||
|
||||
assert "call_1" in all_call_ids, "First iteration's function call missing from response"
|
||||
assert "call_2" in all_call_ids, "Second iteration's function call missing from response"
|
||||
|
||||
assert all_call_ids == all_result_ids, (
|
||||
f"Mismatched function calls and results. Calls: {all_call_ids}, Results: {all_result_ids}"
|
||||
)
|
||||
|
||||
|
||||
async def test_max_iterations_thread_integrity_with_agent(chat_client_base: SupportsChatGetResponse):
|
||||
"""Verify that agent.run() does not produce orphaned function calls after
|
||||
max_iterations, which would corrupt the thread and cause API errors on the
|
||||
next call.
|
||||
"""
|
||||
|
||||
@tool(name="browser_snapshot", approval_mode="never_require")
|
||||
def browser_snapshot(url: str) -> str:
|
||||
return f"Screenshot of {url}"
|
||||
|
||||
# Model keeps requesting tool calls on every iteration.
|
||||
# The failsafe call (with tool_choice="none") after the loop is handled
|
||||
# automatically by the mock client, which returns a hardcoded text response
|
||||
# when tool_choice="none" (see conftest.py ChatClientBase.get_response).
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_abc", name="browser_snapshot", arguments='{"url": "https://example.com"}'
|
||||
)
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_xyz", name="browser_snapshot", arguments='{"url": "https://test.com"}'
|
||||
)
|
||||
],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
chat_client_base.function_invocation_configuration["max_iterations"] = 2
|
||||
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
name="test-agent",
|
||||
tools=[browser_snapshot],
|
||||
)
|
||||
|
||||
response = await agent.run(
|
||||
"Take screenshots",
|
||||
options={"tool_choice": "auto"},
|
||||
)
|
||||
|
||||
# Check for orphaned function calls in the response messages
|
||||
all_call_ids = set()
|
||||
all_result_ids = set()
|
||||
for msg in response.messages:
|
||||
for content in msg.contents:
|
||||
if content.type == "function_call":
|
||||
all_call_ids.add(content.call_id)
|
||||
elif content.type == "function_result":
|
||||
all_result_ids.add(content.call_id)
|
||||
|
||||
orphaned_calls = all_call_ids - all_result_ids
|
||||
assert not orphaned_calls, (
|
||||
f"Response contains orphaned function calls {orphaned_calls}. "
|
||||
f"This would cause API errors on the next call."
|
||||
)
|
||||
|
||||
|
||||
@pytest.mark.parametrize("max_iterations", [10])
|
||||
async def test_max_function_calls_limits_parallel_invocations(chat_client_base: SupportsChatGetResponse):
|
||||
"""Test that max_function_calls caps total function invocations across iterations with parallel calls."""
|
||||
@@ -2248,7 +2496,6 @@ async def test_streaming_approval_request_generated(chat_client_base: SupportsCh
|
||||
assert exec_counter == 0 # Function not executed yet due to approval requirement
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Failsafe behavior with max_iterations needs investigation in unified API")
|
||||
async def test_streaming_max_iterations_limit(chat_client_base: SupportsChatGetResponse):
|
||||
"""Test that MAX_ITERATIONS in streaming mode limits function call loops."""
|
||||
exec_counter = 0
|
||||
|
||||
Reference in New Issue
Block a user