mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Fix Failure with Function Approval Messages in Chat Clients (#1322)
* Python: set role=tool when processing approval responses * Python: set role=tool when processing approval responses * Fix approval mode with OpenAIChatClient and threads: add approval requests to assistant message, fix deduplication/rejection call_id, filter approval content, add tests and example * update test tools after change * Rename _collect_approval_todos to _collect_approval_responses and filter empty call_ids
This commit is contained in:
committed by
GitHub
Unverified
parent
d6fdb91480
commit
e032fe3993
@@ -1142,16 +1142,17 @@ def _extract_tools(kwargs: dict[str, Any]) -> Any:
|
||||
return tools
|
||||
|
||||
|
||||
def _collect_approval_todos(
|
||||
def _collect_approval_responses(
|
||||
messages: "list[ChatMessage]",
|
||||
) -> dict[str, "FunctionApprovalResponseContent"]:
|
||||
"""Collect approved function calls from messages."""
|
||||
"""Collect approval responses (both approved and rejected) from messages."""
|
||||
from ._types import ChatMessage, FunctionApprovalResponseContent
|
||||
|
||||
fcc_todo: dict[str, FunctionApprovalResponseContent] = {}
|
||||
for msg in messages:
|
||||
for content in msg.contents if isinstance(msg, ChatMessage) else []:
|
||||
if isinstance(content, FunctionApprovalResponseContent) and content.approved:
|
||||
# Collect BOTH approved and rejected responses
|
||||
if isinstance(content, FunctionApprovalResponseContent):
|
||||
fcc_todo[content.id] = content
|
||||
return fcc_todo
|
||||
|
||||
@@ -1162,26 +1163,52 @@ def _replace_approval_contents_with_results(
|
||||
approved_function_results: "list[Contents]",
|
||||
) -> None:
|
||||
"""Replace approval request/response contents with function call/result contents in-place."""
|
||||
from ._types import FunctionApprovalRequestContent, FunctionApprovalResponseContent, FunctionResultContent
|
||||
from ._types import (
|
||||
FunctionApprovalRequestContent,
|
||||
FunctionApprovalResponseContent,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
Role,
|
||||
)
|
||||
|
||||
result_idx = 0
|
||||
for msg in messages:
|
||||
# First pass - collect existing function call IDs to avoid duplicates
|
||||
existing_call_ids = {
|
||||
content.call_id for content in msg.contents if isinstance(content, FunctionCallContent) and content.call_id
|
||||
}
|
||||
|
||||
# Track approval requests that should be removed (duplicates)
|
||||
contents_to_remove = []
|
||||
|
||||
for content_idx, content in enumerate(msg.contents):
|
||||
if isinstance(content, FunctionApprovalRequestContent):
|
||||
# put back the function call content
|
||||
msg.contents[content_idx] = content.function_call
|
||||
if isinstance(content, FunctionApprovalResponseContent):
|
||||
# Don't add the function call if it already exists (would create duplicate)
|
||||
if content.function_call.call_id in existing_call_ids:
|
||||
# Just mark for removal - the function call already exists
|
||||
contents_to_remove.append(content_idx)
|
||||
else:
|
||||
# Put back the function call content only if it doesn't exist
|
||||
msg.contents[content_idx] = content.function_call
|
||||
elif isinstance(content, FunctionApprovalResponseContent):
|
||||
if content.approved and content.id in fcc_todo:
|
||||
# Replace with the corresponding result
|
||||
if result_idx < len(approved_function_results):
|
||||
msg.contents[content_idx] = approved_function_results[result_idx]
|
||||
result_idx += 1
|
||||
msg.role = Role.TOOL
|
||||
else:
|
||||
# Create a "not approved" result for rejected calls
|
||||
# Use function_call.call_id (the function's ID), not content.id (approval's ID)
|
||||
msg.contents[content_idx] = FunctionResultContent(
|
||||
call_id=content.id,
|
||||
call_id=content.function_call.call_id,
|
||||
result="Error: Tool call invocation was rejected by user.",
|
||||
)
|
||||
msg.role = Role.TOOL
|
||||
|
||||
# Remove approval requests that were duplicates (in reverse order to preserve indices)
|
||||
for idx in reversed(contents_to_remove):
|
||||
msg.contents.pop(idx)
|
||||
|
||||
|
||||
def _handle_function_calls_response(
|
||||
@@ -1234,16 +1261,20 @@ def _handle_function_calls_response(
|
||||
response: "ChatResponse | None" = None
|
||||
fcc_messages: "list[ChatMessage]" = []
|
||||
for attempt_idx in range(instance_max_iterations):
|
||||
fcc_todo = _collect_approval_todos(prepped_messages)
|
||||
fcc_todo = _collect_approval_responses(prepped_messages)
|
||||
if fcc_todo:
|
||||
tools = _extract_tools(kwargs)
|
||||
approved_function_results: list[Contents] = await _execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
function_calls=list(fcc_todo.values()),
|
||||
tools=tools, # type: ignore
|
||||
middleware_pipeline=stored_middleware_pipeline,
|
||||
)
|
||||
# Only execute APPROVED function calls, not rejected ones
|
||||
approved_responses = [resp for resp in fcc_todo.values() if resp.approved]
|
||||
approved_function_results: list[Contents] = []
|
||||
if approved_responses:
|
||||
approved_function_results = await _execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
function_calls=approved_responses,
|
||||
tools=tools, # type: ignore
|
||||
middleware_pipeline=stored_middleware_pipeline,
|
||||
)
|
||||
_replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results)
|
||||
|
||||
response = await func(self, messages=prepped_messages, **kwargs)
|
||||
@@ -1273,6 +1304,21 @@ def _handle_function_calls_response(
|
||||
tools=tools, # type: ignore
|
||||
middleware_pipeline=stored_middleware_pipeline,
|
||||
)
|
||||
|
||||
# Check if we have approval requests in the results
|
||||
if any(isinstance(fccr, FunctionApprovalRequestContent) for fccr in function_call_results):
|
||||
# Add approval requests to the existing assistant message (with tool_calls)
|
||||
# instead of creating a separate tool message
|
||||
from ._types import Role
|
||||
|
||||
if response.messages and response.messages[0].role == Role.ASSISTANT:
|
||||
response.messages[0].contents.extend(function_call_results)
|
||||
else:
|
||||
# Fallback: create new assistant message (shouldn't normally happen)
|
||||
result_message = ChatMessage(role="assistant", contents=function_call_results)
|
||||
response.messages.append(result_message)
|
||||
return response
|
||||
|
||||
# add a single ChatMessage to the response with the results
|
||||
result_message = ChatMessage(role="tool", contents=function_call_results)
|
||||
response.messages.append(result_message)
|
||||
@@ -1283,9 +1329,6 @@ def _handle_function_calls_response(
|
||||
# this runs in every but the first run
|
||||
# we need to keep track of all function call messages
|
||||
fcc_messages.extend(response.messages)
|
||||
# and add them as additional context to the messages
|
||||
if any(isinstance(fccr, FunctionApprovalRequestContent) for fccr in function_call_results):
|
||||
return response
|
||||
if getattr(kwargs.get("chat_options"), "store", False):
|
||||
prepped_messages.clear()
|
||||
prepped_messages.append(result_message)
|
||||
@@ -1365,16 +1408,20 @@ def _handle_function_calls_streaming_response(
|
||||
prepped_messages = prepare_messages(messages)
|
||||
fcc_messages: "list[ChatMessage]" = []
|
||||
for attempt_idx in range(instance_max_iterations):
|
||||
fcc_todo = _collect_approval_todos(prepped_messages)
|
||||
fcc_todo = _collect_approval_responses(prepped_messages)
|
||||
if fcc_todo:
|
||||
tools = _extract_tools(kwargs)
|
||||
approved_function_results: list[Contents] = await _execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
function_calls=list(fcc_todo.values()),
|
||||
tools=tools, # type: ignore
|
||||
middleware_pipeline=stored_middleware_pipeline,
|
||||
)
|
||||
# Only execute APPROVED function calls, not rejected ones
|
||||
approved_responses = [resp for resp in fcc_todo.values() if resp.approved]
|
||||
approved_function_results: list[Contents] = []
|
||||
if approved_responses:
|
||||
approved_function_results = await _execute_function_calls(
|
||||
custom_args=kwargs,
|
||||
attempt_idx=attempt_idx,
|
||||
function_calls=approved_responses,
|
||||
tools=tools, # type: ignore
|
||||
middleware_pipeline=stored_middleware_pipeline,
|
||||
)
|
||||
_replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results)
|
||||
|
||||
all_updates: list["ChatResponseUpdate"] = []
|
||||
@@ -1427,6 +1474,24 @@ def _handle_function_calls_streaming_response(
|
||||
tools=tools, # type: ignore
|
||||
middleware_pipeline=stored_middleware_pipeline,
|
||||
)
|
||||
|
||||
# Check if we have approval requests in the results
|
||||
if any(isinstance(fccr, FunctionApprovalRequestContent) for fccr in function_call_results):
|
||||
# Add approval requests to the existing assistant message (with tool_calls)
|
||||
# instead of creating a separate tool message
|
||||
from ._types import Role
|
||||
|
||||
if response.messages and response.messages[0].role == Role.ASSISTANT:
|
||||
response.messages[0].contents.extend(function_call_results)
|
||||
# Yield the approval requests as part of the assistant message
|
||||
yield ChatResponseUpdate(contents=function_call_results, role="assistant")
|
||||
else:
|
||||
# Fallback: create new assistant message (shouldn't normally happen)
|
||||
result_message = ChatMessage(role="assistant", contents=function_call_results)
|
||||
yield ChatResponseUpdate(contents=function_call_results, role="assistant")
|
||||
response.messages.append(result_message)
|
||||
return
|
||||
|
||||
# add a single ChatMessage to the response with the results
|
||||
result_message = ChatMessage(role="tool", contents=function_call_results)
|
||||
yield ChatResponseUpdate(contents=function_call_results, role="tool")
|
||||
@@ -1438,9 +1503,6 @@ def _handle_function_calls_streaming_response(
|
||||
# this runs in every but the first run
|
||||
# we need to keep track of all function call messages
|
||||
fcc_messages.extend(response.messages)
|
||||
# and add them as additional context to the messages
|
||||
if any(isinstance(fccr, FunctionApprovalRequestContent) for fccr in function_call_results):
|
||||
return
|
||||
if getattr(kwargs.get("chat_options"), "store", False):
|
||||
prepped_messages.clear()
|
||||
prepped_messages.append(result_message)
|
||||
|
||||
@@ -28,6 +28,8 @@ from .._types import (
|
||||
Contents,
|
||||
DataContent,
|
||||
FinishReason,
|
||||
FunctionApprovalRequestContent,
|
||||
FunctionApprovalResponseContent,
|
||||
FunctionCallContent,
|
||||
FunctionResultContent,
|
||||
Role,
|
||||
@@ -356,6 +358,10 @@ class OpenAIBaseChatClient(OpenAIBase, BaseChatClient):
|
||||
"""Parse a chat message into the openai format."""
|
||||
all_messages: list[dict[str, Any]] = []
|
||||
for content in message.contents:
|
||||
# Skip approval content - it's internal framework state, not for the LLM
|
||||
if isinstance(content, (FunctionApprovalRequestContent, FunctionApprovalResponseContent)):
|
||||
continue
|
||||
|
||||
args: dict[str, Any] = {
|
||||
"role": message.role.value if isinstance(message.role, Role) else message.role,
|
||||
}
|
||||
|
||||
@@ -252,15 +252,17 @@ async def test_function_invocation_scenarios(
|
||||
# Verify based on scenario (for no thread and local thread cases)
|
||||
if num_functions == 1:
|
||||
if approval_required:
|
||||
# Single function with approval: call + approval request
|
||||
# Single function with approval: assistant message contains both call + approval request
|
||||
if not streaming:
|
||||
assert len(messages) == 2
|
||||
assert len(messages) == 1
|
||||
# Assistant message should have FunctionCallContent + FunctionApprovalRequestContent
|
||||
assert len(messages[0].contents) == 2
|
||||
assert isinstance(messages[0].contents[0], FunctionCallContent)
|
||||
assert isinstance(messages[1].contents[0], FunctionApprovalRequestContent)
|
||||
assert messages[1].contents[0].function_call.name == "approval_func"
|
||||
assert isinstance(messages[0].contents[1], FunctionApprovalRequestContent)
|
||||
assert messages[0].contents[1].function_call.name == "approval_func"
|
||||
assert exec_counter == 0 # Function not executed yet
|
||||
else:
|
||||
# Streaming: 2 function call chunks + 1 approval request
|
||||
# Streaming: 2 function call chunks + 1 approval request update (same assistant message)
|
||||
assert len(messages) == 3
|
||||
assert isinstance(messages[0].contents[0], FunctionCallContent)
|
||||
assert isinstance(messages[1].contents[0], FunctionCallContent)
|
||||
@@ -288,15 +290,16 @@ async def test_function_invocation_scenarios(
|
||||
else: # num_functions == 2
|
||||
# Two functions with mixed approval
|
||||
if not streaming:
|
||||
# Mixed: first message has both calls, second has approval requests for both
|
||||
# Mixed: assistant message has both calls + approval requests (4 items total)
|
||||
# (because when one requires approval, all are batched for approval)
|
||||
assert len(messages) == 2
|
||||
assert len(messages[0].contents) == 2 # Both function calls
|
||||
assert len(messages) == 1
|
||||
# Should have: 2 FunctionCallContent + 2 FunctionApprovalRequestContent
|
||||
assert len(messages[0].contents) == 4
|
||||
assert isinstance(messages[0].contents[0], FunctionCallContent)
|
||||
assert isinstance(messages[0].contents[1], FunctionCallContent)
|
||||
# Both should result in approval requests
|
||||
assert len(messages[1].contents) == 2
|
||||
assert all(isinstance(c, FunctionApprovalRequestContent) for c in messages[1].contents)
|
||||
approval_requests = [c for c in messages[0].contents if isinstance(c, FunctionApprovalRequestContent)]
|
||||
assert len(approval_requests) == 2
|
||||
assert exec_counter == 0 # Neither function executed yet
|
||||
else:
|
||||
# Streaming: 2 function call updates + 1 approval request with 2 contents
|
||||
@@ -344,13 +347,16 @@ async def test_rejected_approval(chat_client_base: ChatClientProtocol):
|
||||
|
||||
# Get the response with approval requests
|
||||
response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[func_approved, func_rejected])
|
||||
assert len(response.messages) == 2
|
||||
assert len(response.messages[1].contents) == 2
|
||||
assert all(isinstance(c, FunctionApprovalRequestContent) for c in response.messages[1].contents)
|
||||
# Approval requests are now added to the assistant message, not a separate message
|
||||
assert len(response.messages) == 1
|
||||
# Assistant message should have: 2 FunctionCallContent + 2 FunctionApprovalRequestContent
|
||||
assert len(response.messages[0].contents) == 4
|
||||
approval_requests = [c for c in response.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)]
|
||||
assert len(approval_requests) == 2
|
||||
|
||||
# Approve one and reject the other
|
||||
approval_req_1 = response.messages[1].contents[0]
|
||||
approval_req_2 = response.messages[1].contents[1]
|
||||
approval_req_1 = approval_requests[0]
|
||||
approval_req_2 = approval_requests[1]
|
||||
|
||||
approved_response = FunctionApprovalResponseContent(
|
||||
id=approval_req_1.id,
|
||||
@@ -391,6 +397,184 @@ async def test_rejected_approval(chat_client_base: ChatClientProtocol):
|
||||
assert rejected_result.result == "Error: Tool call invocation was rejected by user."
|
||||
assert exec_counter_rejected == 0
|
||||
|
||||
# Verify that messages with FunctionResultContent have role="tool"
|
||||
# This ensures the message format is correct for OpenAI's API
|
||||
for msg in all_messages:
|
||||
for content in msg.contents:
|
||||
if isinstance(content, FunctionResultContent):
|
||||
assert msg.role == Role.TOOL, (
|
||||
f"Message with FunctionResultContent must have role='tool', got '{msg.role}'"
|
||||
)
|
||||
|
||||
|
||||
async def test_approval_requests_in_assistant_message(chat_client_base: ChatClientProtocol):
|
||||
"""Approval requests should be added to the assistant message that contains the function call."""
|
||||
exec_counter = 0
|
||||
|
||||
@ai_function(name="test_func", approval_mode="always_require")
|
||||
def func_with_approval(arg1: str) -> str:
|
||||
nonlocal exec_counter
|
||||
exec_counter += 1
|
||||
return f"Result {arg1}"
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=ChatMessage(
|
||||
role="assistant",
|
||||
contents=[
|
||||
FunctionCallContent(call_id="1", name="test_func", arguments='{"arg1": "value1"}'),
|
||||
],
|
||||
)
|
||||
),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response("hello", tool_choice="auto", tools=[func_with_approval])
|
||||
|
||||
# Should have one assistant message containing both the call and approval request
|
||||
assert len(response.messages) == 1
|
||||
assert response.messages[0].role == Role.ASSISTANT
|
||||
assert len(response.messages[0].contents) == 2
|
||||
assert isinstance(response.messages[0].contents[0], FunctionCallContent)
|
||||
assert isinstance(response.messages[0].contents[1], FunctionApprovalRequestContent)
|
||||
assert exec_counter == 0
|
||||
|
||||
|
||||
async def test_persisted_approval_messages_replay_correctly(chat_client_base: ChatClientProtocol):
|
||||
"""Approval flow should work when messages are persisted and sent back (thread scenario)."""
|
||||
from agent_framework import FunctionApprovalResponseContent
|
||||
|
||||
exec_counter = 0
|
||||
|
||||
@ai_function(name="test_func", approval_mode="always_require")
|
||||
def func_with_approval(arg1: str) -> str:
|
||||
nonlocal exec_counter
|
||||
exec_counter += 1
|
||||
return f"Result {arg1}"
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=ChatMessage(
|
||||
role="assistant",
|
||||
contents=[
|
||||
FunctionCallContent(call_id="1", name="test_func", arguments='{"arg1": "value1"}'),
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(messages=ChatMessage(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
# Get approval request
|
||||
response1 = await chat_client_base.get_response("hello", tool_choice="auto", tools=[func_with_approval])
|
||||
|
||||
# Store messages (like a thread would)
|
||||
persisted_messages = [
|
||||
ChatMessage(role="user", contents=[TextContent(text="hello")]),
|
||||
*response1.messages,
|
||||
]
|
||||
|
||||
# Send approval
|
||||
approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0]
|
||||
approval_response = FunctionApprovalResponseContent(
|
||||
id=approval_req.id,
|
||||
function_call=approval_req.function_call,
|
||||
approved=True,
|
||||
)
|
||||
persisted_messages.append(ChatMessage(role="user", contents=[approval_response]))
|
||||
|
||||
# Continue with all persisted messages
|
||||
response2 = await chat_client_base.get_response(persisted_messages, tool_choice="auto", tools=[func_with_approval])
|
||||
|
||||
# Should execute successfully
|
||||
assert response2 is not None
|
||||
assert exec_counter == 1
|
||||
assert response2.messages[-1].text == "done"
|
||||
|
||||
|
||||
async def test_no_duplicate_function_calls_after_approval_processing(chat_client_base: ChatClientProtocol):
|
||||
"""Processing approval should not create duplicate function calls in messages."""
|
||||
from agent_framework import FunctionApprovalResponseContent
|
||||
|
||||
@ai_function(name="test_func", approval_mode="always_require")
|
||||
def func_with_approval(arg1: str) -> str:
|
||||
return f"Result {arg1}"
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=ChatMessage(
|
||||
role="assistant",
|
||||
contents=[
|
||||
FunctionCallContent(call_id="1", name="test_func", arguments='{"arg1": "value1"}'),
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(messages=ChatMessage(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
response1 = await chat_client_base.get_response("hello", tool_choice="auto", tools=[func_with_approval])
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0]
|
||||
approval_response = FunctionApprovalResponseContent(
|
||||
id=approval_req.id,
|
||||
function_call=approval_req.function_call,
|
||||
approved=True,
|
||||
)
|
||||
|
||||
all_messages = response1.messages + [ChatMessage(role="user", contents=[approval_response])]
|
||||
await chat_client_base.get_response(all_messages, tool_choice="auto", tools=[func_with_approval])
|
||||
|
||||
# Count function calls with the same call_id
|
||||
function_call_count = sum(
|
||||
1
|
||||
for msg in all_messages
|
||||
for content in msg.contents
|
||||
if isinstance(content, FunctionCallContent) and content.call_id == "1"
|
||||
)
|
||||
|
||||
assert function_call_count == 1
|
||||
|
||||
|
||||
async def test_rejection_result_uses_function_call_id(chat_client_base: ChatClientProtocol):
|
||||
"""Rejection error result should use the function call's call_id, not the approval's id."""
|
||||
from agent_framework import FunctionApprovalResponseContent
|
||||
|
||||
@ai_function(name="test_func", approval_mode="always_require")
|
||||
def func_with_approval(arg1: str) -> str:
|
||||
return f"Result {arg1}"
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=ChatMessage(
|
||||
role="assistant",
|
||||
contents=[
|
||||
FunctionCallContent(call_id="call_123", name="test_func", arguments='{"arg1": "value1"}'),
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(messages=ChatMessage(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
response1 = await chat_client_base.get_response("hello", tool_choice="auto", tools=[func_with_approval])
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)][0]
|
||||
rejection_response = FunctionApprovalResponseContent(
|
||||
id=approval_req.id,
|
||||
function_call=approval_req.function_call,
|
||||
approved=False,
|
||||
)
|
||||
|
||||
all_messages = response1.messages + [ChatMessage(role="user", contents=[rejection_response])]
|
||||
await chat_client_base.get_response(all_messages, tool_choice="auto", tools=[func_with_approval])
|
||||
|
||||
# Find the rejection result
|
||||
rejection_result = next(
|
||||
(content for msg in all_messages for content in msg.contents if isinstance(content, FunctionResultContent)),
|
||||
None,
|
||||
)
|
||||
|
||||
assert rejection_result is not None
|
||||
assert rejection_result.call_id == "call_123"
|
||||
assert "rejected" in rejection_result.result.lower()
|
||||
|
||||
|
||||
async def test_max_iterations_limit(chat_client_base: ChatClientProtocol):
|
||||
"""Test that MAX_ITERATIONS in additional_properties limits function call loops."""
|
||||
|
||||
@@ -747,13 +747,14 @@ async def test_non_streaming_single_function_requires_approval():
|
||||
# Execute
|
||||
result = await wrapped(mock_client, messages=[], tools=[requires_approval_tool])
|
||||
|
||||
# Verify: should return 2 messages - function call and approval request
|
||||
# Verify: should return 1 message with function call and approval request
|
||||
from agent_framework import FunctionApprovalRequestContent
|
||||
|
||||
assert len(result.messages) == 2
|
||||
assert len(result.messages) == 1
|
||||
assert len(result.messages[0].contents) == 2
|
||||
assert isinstance(result.messages[0].contents[0], FunctionCallContent)
|
||||
assert isinstance(result.messages[1].contents[0], FunctionApprovalRequestContent)
|
||||
assert result.messages[1].contents[0].function_call.name == "requires_approval_tool"
|
||||
assert isinstance(result.messages[0].contents[1], FunctionApprovalRequestContent)
|
||||
assert result.messages[0].contents[1].function_call.name == "requires_approval_tool"
|
||||
|
||||
|
||||
async def test_non_streaming_two_functions_both_no_approval():
|
||||
@@ -838,16 +839,17 @@ async def test_non_streaming_two_functions_both_require_approval():
|
||||
# Execute
|
||||
result = await wrapped(mock_client, messages=[], tools=[requires_approval_tool])
|
||||
|
||||
# Verify: should return 2 messages - function calls and approval requests
|
||||
# Verify: should return 1 message with function calls and approval requests
|
||||
from agent_framework import FunctionApprovalRequestContent
|
||||
|
||||
assert len(result.messages) == 2
|
||||
assert len(result.messages[0].contents) == 2 # Both function calls
|
||||
assert all(isinstance(c, FunctionCallContent) for c in result.messages[0].contents)
|
||||
assert len(result.messages[1].contents) == 2 # Both approval requests
|
||||
assert all(isinstance(c, FunctionApprovalRequestContent) for c in result.messages[1].contents)
|
||||
assert result.messages[1].contents[0].function_call.name == "requires_approval_tool"
|
||||
assert result.messages[1].contents[1].function_call.name == "requires_approval_tool"
|
||||
assert len(result.messages) == 1
|
||||
assert len(result.messages[0].contents) == 4 # 2 function calls + 2 approval requests
|
||||
function_calls = [c for c in result.messages[0].contents if isinstance(c, FunctionCallContent)]
|
||||
approval_requests = [c for c in result.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)]
|
||||
assert len(function_calls) == 2
|
||||
assert len(approval_requests) == 2
|
||||
assert approval_requests[0].function_call.name == "requires_approval_tool"
|
||||
assert approval_requests[1].function_call.name == "requires_approval_tool"
|
||||
|
||||
|
||||
async def test_non_streaming_two_functions_mixed_approval():
|
||||
@@ -886,10 +888,10 @@ async def test_non_streaming_two_functions_mixed_approval():
|
||||
# Verify: should return approval requests for both (when one needs approval, all are sent for approval)
|
||||
from agent_framework import FunctionApprovalRequestContent
|
||||
|
||||
assert len(result.messages) == 2
|
||||
assert len(result.messages[0].contents) == 2 # Both function calls
|
||||
assert len(result.messages[1].contents) == 2 # Both approval requests
|
||||
assert all(isinstance(c, FunctionApprovalRequestContent) for c in result.messages[1].contents)
|
||||
assert len(result.messages) == 1
|
||||
assert len(result.messages[0].contents) == 4 # 2 function calls + 2 approval requests
|
||||
approval_requests = [c for c in result.messages[0].contents if isinstance(c, FunctionApprovalRequestContent)]
|
||||
assert len(approval_requests) == 2
|
||||
|
||||
|
||||
async def test_streaming_single_function_no_approval():
|
||||
@@ -974,7 +976,7 @@ async def test_streaming_single_function_requires_approval():
|
||||
|
||||
assert len(updates) == 2
|
||||
assert isinstance(updates[0].contents[0], FunctionCallContent)
|
||||
assert updates[1].role == Role.TOOL
|
||||
assert updates[1].role == Role.ASSISTANT
|
||||
assert isinstance(updates[1].contents[0], FunctionApprovalRequestContent)
|
||||
|
||||
|
||||
@@ -1069,8 +1071,8 @@ async def test_streaming_two_functions_both_require_approval():
|
||||
assert len(updates) == 3
|
||||
assert isinstance(updates[0].contents[0], FunctionCallContent)
|
||||
assert isinstance(updates[1].contents[0], FunctionCallContent)
|
||||
# Tool update with both approval requests
|
||||
assert updates[2].role == Role.TOOL
|
||||
# Assistant update with both approval requests
|
||||
assert updates[2].role == Role.ASSISTANT
|
||||
assert len(updates[2].contents) == 2
|
||||
assert all(isinstance(c, FunctionApprovalRequestContent) for c in updates[2].contents)
|
||||
|
||||
@@ -1116,7 +1118,7 @@ async def test_streaming_two_functions_mixed_approval():
|
||||
assert len(updates) == 3
|
||||
assert isinstance(updates[0].contents[0], FunctionCallContent)
|
||||
assert isinstance(updates[1].contents[0], FunctionCallContent)
|
||||
# Tool update with both approval requests
|
||||
assert updates[2].role == Role.TOOL
|
||||
# Assistant update with both approval requests
|
||||
assert updates[2].role == Role.ASSISTANT
|
||||
assert len(updates[2].contents) == 2
|
||||
assert all(isinstance(c, FunctionApprovalRequestContent) for c in updates[2].contents)
|
||||
|
||||
@@ -0,0 +1,102 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import ChatAgent, ChatMessage, ai_function
|
||||
from agent_framework.azure import AzureOpenAIChatClient
|
||||
|
||||
"""
|
||||
Tool Approvals with Threads
|
||||
|
||||
This sample demonstrates using tool approvals with threads.
|
||||
With threads, you don't need to manually pass previous messages -
|
||||
the thread stores and retrieves them automatically.
|
||||
"""
|
||||
|
||||
|
||||
@ai_function(approval_mode="always_require")
|
||||
def add_to_calendar(
|
||||
event_name: Annotated[str, "Name of the event"], date: Annotated[str, "Date of the event"]
|
||||
) -> str:
|
||||
"""Add an event to the calendar (requires approval)."""
|
||||
print(f">>> EXECUTING: add_to_calendar(event_name='{event_name}', date='{date}')")
|
||||
return f"Added '{event_name}' to calendar on {date}"
|
||||
|
||||
|
||||
async def approval_example() -> None:
|
||||
"""Example showing approval with threads."""
|
||||
print("=== Tool Approval with Thread ===\n")
|
||||
|
||||
agent = ChatAgent(
|
||||
chat_client=AzureOpenAIChatClient(),
|
||||
name="CalendarAgent",
|
||||
instructions="You are a helpful calendar assistant.",
|
||||
tools=[add_to_calendar],
|
||||
)
|
||||
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# Step 1: Agent requests to call the tool
|
||||
query = "Add a dentist appointment on March 15th"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query, thread=thread)
|
||||
|
||||
# Check for approval requests
|
||||
if result.user_input_requests:
|
||||
for request in result.user_input_requests:
|
||||
print(f"\nApproval needed:")
|
||||
print(f" Function: {request.function_call.name}")
|
||||
print(f" Arguments: {request.function_call.arguments}")
|
||||
|
||||
# User approves (in real app, this would be user input)
|
||||
approved = True # Change to False to see rejection
|
||||
print(f" Decision: {'Approved' if approved else 'Rejected'}")
|
||||
|
||||
# Step 2: Send approval response
|
||||
approval_response = request.create_response(approved=approved)
|
||||
result = await agent.run(ChatMessage(role="user", contents=[approval_response]), thread=thread)
|
||||
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
|
||||
async def rejection_example() -> None:
|
||||
"""Example showing rejection with threads."""
|
||||
print("=== Tool Rejection with Thread ===\n")
|
||||
|
||||
agent = ChatAgent(
|
||||
chat_client=AzureOpenAIChatClient(),
|
||||
name="CalendarAgent",
|
||||
instructions="You are a helpful calendar assistant.",
|
||||
tools=[add_to_calendar],
|
||||
)
|
||||
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
query = "Add a team meeting on December 20th"
|
||||
print(f"User: {query}")
|
||||
result = await agent.run(query, thread=thread)
|
||||
|
||||
if result.user_input_requests:
|
||||
for request in result.user_input_requests:
|
||||
print(f"\nApproval needed:")
|
||||
print(f" Function: {request.function_call.name}")
|
||||
print(f" Arguments: {request.function_call.arguments}")
|
||||
|
||||
# User rejects
|
||||
print(f" Decision: Rejected")
|
||||
|
||||
# Send rejection response
|
||||
rejection_response = request.create_response(approved=False)
|
||||
result = await agent.run(ChatMessage(role="user", contents=[rejection_response]), thread=thread)
|
||||
|
||||
print(f"Agent: {result}\n")
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
await approval_example()
|
||||
await rejection_example()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user