mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: Added custom args and thread object to ai_function kwargs (#2769)
* Added an example of using kwargs in ai_function * Added thread object to ai_function kwargs * Updated docs * Small fix * Added thread parameter filtering
This commit is contained in:
committed by
GitHub
Unverified
parent
eb1117fff4
commit
d7434d59ce
@@ -878,6 +878,9 @@ class ChatAgent(BaseAgent):
|
||||
user=user,
|
||||
additional_properties=merged_additional_options, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Ensure thread is forwarded in kwargs for tool invocation
|
||||
kwargs["thread"] = thread
|
||||
# Filter chat_options from kwargs to prevent duplicate keyword argument
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
|
||||
response = await self.chat_client.get_response(
|
||||
@@ -895,7 +898,12 @@ class ChatAgent(BaseAgent):
|
||||
|
||||
# Only notify the thread of new messages if the chatResponse was successful
|
||||
# to avoid inconsistent messages state in the thread.
|
||||
await self._notify_thread_of_new_messages(thread, input_messages, response.messages)
|
||||
await self._notify_thread_of_new_messages(
|
||||
thread,
|
||||
input_messages,
|
||||
response.messages,
|
||||
**{k: v for k, v in kwargs.items() if k != "thread"},
|
||||
)
|
||||
return AgentRunResponse(
|
||||
messages=response.messages,
|
||||
response_id=response.response_id,
|
||||
@@ -1017,6 +1025,8 @@ class ChatAgent(BaseAgent):
|
||||
additional_properties=merged_additional_options, # type: ignore[arg-type]
|
||||
)
|
||||
|
||||
# Ensure thread is forwarded in kwargs for tool invocation
|
||||
kwargs["thread"] = thread
|
||||
# Filter chat_options from kwargs to prevent duplicate keyword argument
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
|
||||
response_updates: list[ChatResponseUpdate] = []
|
||||
@@ -1043,7 +1053,13 @@ class ChatAgent(BaseAgent):
|
||||
|
||||
response = ChatResponse.from_chat_response_updates(response_updates, output_format_type=co.response_format)
|
||||
await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
|
||||
await self._notify_thread_of_new_messages(thread, input_messages, response.messages, **kwargs)
|
||||
|
||||
await self._notify_thread_of_new_messages(
|
||||
thread,
|
||||
input_messages,
|
||||
response.messages,
|
||||
**{k: v for k, v in kwargs.items() if k != "thread"},
|
||||
)
|
||||
|
||||
@override
|
||||
def get_new_thread(
|
||||
|
||||
@@ -627,6 +627,12 @@ class AIFunction(BaseTool, Generic[ArgsT, ReturnT]):
|
||||
self._invocation_duration_histogram = _default_histogram()
|
||||
self.type: Literal["ai_function"] = "ai_function"
|
||||
self._forward_runtime_kwargs: bool = False
|
||||
if self.func:
|
||||
sig = inspect.signature(self.func)
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
self._forward_runtime_kwargs = True
|
||||
break
|
||||
|
||||
@property
|
||||
def declaration_only(self) -> bool:
|
||||
@@ -915,6 +921,7 @@ def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[B
|
||||
)
|
||||
for pname, param in sig.parameters.items()
|
||||
if pname not in {"self", "cls"}
|
||||
and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
|
||||
}
|
||||
return create_model(f"{name}_input", **fields) # type: ignore[call-overload, no-any-return]
|
||||
|
||||
@@ -1744,7 +1751,9 @@ def _handle_function_calls_response(
|
||||
break
|
||||
_replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results)
|
||||
|
||||
response = await func(self, messages=prepped_messages, **kwargs)
|
||||
# Filter out internal framework kwargs before passing to clients.
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
|
||||
response = await func(self, messages=prepped_messages, **filtered_kwargs)
|
||||
# if there are function calls, we will handle them first
|
||||
function_results = {
|
||||
it.call_id for it in response.messages[0].contents if isinstance(it, FunctionResultContent)
|
||||
@@ -1833,7 +1842,10 @@ def _handle_function_calls_response(
|
||||
|
||||
# Failsafe: give up on tools, ask model for plain answer
|
||||
kwargs["tool_choice"] = "none"
|
||||
response = await func(self, messages=prepped_messages, **kwargs)
|
||||
|
||||
# Filter out internal framework kwargs before passing to clients.
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
|
||||
response = await func(self, messages=prepped_messages, **filtered_kwargs)
|
||||
if fcc_messages:
|
||||
for msg in reversed(fcc_messages):
|
||||
response.messages.insert(0, msg)
|
||||
@@ -1920,7 +1932,9 @@ def _handle_function_calls_streaming_response(
|
||||
_replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results)
|
||||
|
||||
all_updates: list["ChatResponseUpdate"] = []
|
||||
async for update in func(self, messages=prepped_messages, **kwargs):
|
||||
# Filter out internal framework kwargs before passing to clients.
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
|
||||
async for update in func(self, messages=prepped_messages, **filtered_kwargs):
|
||||
all_updates.append(update)
|
||||
yield update
|
||||
|
||||
@@ -2031,7 +2045,9 @@ def _handle_function_calls_streaming_response(
|
||||
|
||||
# Failsafe: give up on tools, ask model for plain answer
|
||||
kwargs["tool_choice"] = "none"
|
||||
async for update in func(self, messages=prepped_messages, **kwargs):
|
||||
# Filter out internal framework kwargs before passing to clients.
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
|
||||
async for update in func(self, messages=prepped_messages, **filtered_kwargs):
|
||||
yield update
|
||||
|
||||
return streaming_function_invocation_wrapper
|
||||
|
||||
@@ -21,9 +21,11 @@ from agent_framework import (
|
||||
ChatResponse,
|
||||
Context,
|
||||
ContextProvider,
|
||||
FunctionCallContent,
|
||||
HostedCodeInterpreterTool,
|
||||
Role,
|
||||
TextContent,
|
||||
ai_function,
|
||||
)
|
||||
from agent_framework._mcp import MCPTool
|
||||
from agent_framework.exceptions import AgentExecutionException
|
||||
@@ -595,3 +597,38 @@ async def test_chat_agent_with_local_mcp_tools(chat_client: ChatClientProtocol)
|
||||
# Test async context manager with MCP tools
|
||||
async with agent:
|
||||
pass
|
||||
|
||||
|
||||
async def test_agent_tool_receives_thread_in_kwargs(chat_client_base: Any) -> None:
|
||||
"""Verify tool execution receives 'thread' inside **kwargs when function is called by client."""
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
@ai_function(name="echo_thread_info")
|
||||
def echo_thread_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUnknownParameterType]
|
||||
thread = kwargs.get("thread")
|
||||
captured["has_thread"] = thread is not None
|
||||
captured["has_message_store"] = thread.message_store is not None if isinstance(thread, AgentThread) else False
|
||||
return f"echo: {text}"
|
||||
|
||||
# Make the base client emit a function call for our tool
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=ChatMessage(
|
||||
role="assistant",
|
||||
contents=[FunctionCallContent(call_id="1", name="echo_thread_info", arguments='{"text": "hello"}')],
|
||||
)
|
||||
),
|
||||
ChatResponse(messages=ChatMessage(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
agent = ChatAgent(
|
||||
chat_client=chat_client_base, tools=[echo_thread_info], chat_message_store_factory=ChatMessageStore
|
||||
)
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
result = await agent.run("hello", thread=thread)
|
||||
|
||||
assert result.text == "done"
|
||||
assert captured.get("has_thread") is True
|
||||
assert captured.get("has_message_store") is True
|
||||
|
||||
@@ -1334,3 +1334,37 @@ async def test_streaming_two_functions_mixed_approval():
|
||||
assert updates[2].role == Role.ASSISTANT
|
||||
assert len(updates[2].contents) == 2
|
||||
assert all(isinstance(c, FunctionApprovalRequestContent) for c in updates[2].contents)
|
||||
|
||||
|
||||
async def test_ai_function_with_kwargs_injection():
|
||||
"""Test that ai_function correctly handles kwargs injection and hides them from schema."""
|
||||
|
||||
@ai_function
|
||||
def tool_with_kwargs(x: int, **kwargs: Any) -> str:
|
||||
"""A tool that accepts kwargs."""
|
||||
user_id = kwargs.get("user_id", "unknown")
|
||||
return f"x={x}, user={user_id}"
|
||||
|
||||
# Verify schema does not include kwargs
|
||||
assert tool_with_kwargs.parameters() == {
|
||||
"properties": {"x": {"title": "X", "type": "integer"}},
|
||||
"required": ["x"],
|
||||
"title": "tool_with_kwargs_input",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
# Verify direct invocation works
|
||||
assert tool_with_kwargs(1, user_id="user1") == "x=1, user=user1"
|
||||
|
||||
# Verify invoke works with injected args
|
||||
result = await tool_with_kwargs.invoke(
|
||||
arguments=tool_with_kwargs.input_model(x=5),
|
||||
user_id="user2",
|
||||
)
|
||||
assert result == "x=5, user=user2"
|
||||
|
||||
# Verify invoke works without injected args (uses default)
|
||||
result_default = await tool_with_kwargs.invoke(
|
||||
arguments=tool_with_kwargs.input_model(x=10),
|
||||
)
|
||||
assert result_default == "x=10, user=unknown"
|
||||
|
||||
@@ -11,6 +11,8 @@ This folder contains examples demonstrating how to use AI functions (tools) with
|
||||
| [`ai_function_recover_from_failures.py`](ai_function_recover_from_failures.py) | Demonstrates graceful error handling when tools raise exceptions. Shows how agents receive error information and can recover from failures, deciding whether to retry or respond differently based on the exception. |
|
||||
| [`ai_function_with_approval.py`](ai_function_with_approval.py) | Shows how to implement user approval workflows for function calls without using threads. Demonstrates both streaming and non-streaming approval patterns where users can approve or reject function executions before they run. |
|
||||
| [`ai_function_with_approval_and_threads.py`](ai_function_with_approval_and_threads.py) | Demonstrates tool approval workflows using threads for automatic conversation history management. Shows how threads simplify approval workflows by automatically storing and retrieving conversation context. Includes both approval and rejection examples. |
|
||||
| [`ai_function_with_kwargs.py`](ai_function_with_kwargs.py) | Demonstrates how to inject custom arguments (context) into an AI function from the agent's run method. Useful for passing runtime information like access tokens or user IDs that the tool needs but the model shouldn't see. |
|
||||
| [`ai_function_with_thread_injection.py`](ai_function_with_thread_injection.py) | Shows how to access the current `thread` object inside an AI function via `**kwargs`. |
|
||||
| [`ai_function_with_max_exceptions.py`](ai_function_with_max_exceptions.py) | Shows how to limit the number of times a tool can fail with exceptions using `max_invocation_exceptions`. Useful for preventing expensive tools from being called repeatedly when they keep failing. |
|
||||
| [`ai_function_with_max_invocations.py`](ai_function_with_max_invocations.py) | Demonstrates limiting the total number of times a tool can be invoked using `max_invocations`. Useful for rate-limiting expensive operations or ensuring tools are only called a specific number of times per conversation. |
|
||||
| [`ai_functions_in_class.py`](ai_functions_in_class.py) | Shows how to use `ai_function` decorator with class methods to create stateful tools. Demonstrates how class state can control tool behavior dynamically, allowing you to adjust tool functionality at runtime by modifying class properties. |
|
||||
|
||||
@@ -0,0 +1,53 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from typing import Annotated, Any
|
||||
|
||||
from agent_framework import ai_function
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
AI Function with kwargs Example
|
||||
|
||||
This example demonstrates how to inject custom keyword arguments (kwargs) into an AI function
|
||||
from the agent's run method, without exposing them to the AI model.
|
||||
|
||||
This is useful for passing runtime information like access tokens, user IDs, or
|
||||
request-specific context that the tool needs but the model shouldn't know about
|
||||
or provide.
|
||||
"""
|
||||
|
||||
|
||||
# Define the function tool with **kwargs to accept injected arguments
|
||||
@ai_function
|
||||
def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
# Extract the injected argument from kwargs
|
||||
user_id = kwargs.get("user_id", "unknown")
|
||||
|
||||
# Simulate using the user_id for logging or personalization
|
||||
print(f"Getting weather for user: {user_id}")
|
||||
|
||||
return f"The weather in {location} is cloudy with a high of 15°C."
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
agent = OpenAIResponsesClient().create_agent(
|
||||
name="WeatherAgent",
|
||||
instructions="You are a helpful weather assistant.",
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
# Pass the injected argument when running the agent
|
||||
# The 'user_id' kwarg will be passed down to the tool execution via **kwargs
|
||||
response = await agent.run("What is the weather like in Amsterdam?", user_id="user_123")
|
||||
|
||||
print(f"Agent: {response.text}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
@@ -0,0 +1,52 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from typing import Annotated, Any
|
||||
|
||||
from agent_framework import AgentThread, ai_function
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
from pydantic import Field
|
||||
|
||||
"""
|
||||
AI Function with Thread Injection Example
|
||||
|
||||
This example demonstrates the behavior when passing 'thread' to agent.run()
|
||||
and accessing that thread in AI function.
|
||||
"""
|
||||
|
||||
|
||||
# Define the function tool with **kwargs
|
||||
@ai_function
|
||||
async def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
# Get thread object from kwargs
|
||||
thread = kwargs.get("thread")
|
||||
if thread and isinstance(thread, AgentThread):
|
||||
if thread.message_store:
|
||||
messages = await thread.message_store.list_messages()
|
||||
print(f"Thread contains {len(messages)} messages.")
|
||||
elif thread.service_thread_id:
|
||||
print(f"Thread ID: {thread.service_thread_id}.")
|
||||
|
||||
return f"The weather in {location} is cloudy."
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
agent = OpenAIChatClient().create_agent(
|
||||
name="WeatherAgent", instructions="You are a helpful weather assistant.", tools=[get_weather]
|
||||
)
|
||||
|
||||
# Create a thread
|
||||
thread = agent.get_new_thread()
|
||||
|
||||
# Run the agent with the thread
|
||||
print(f"Agent: {await agent.run('What is the weather in London?', thread=thread)}")
|
||||
print(f"Agent: {await agent.run('What is the weather in Amsterdam?', thread=thread)}")
|
||||
print(f"Agent: {await agent.run('What cities did I ask about?', thread=thread)}")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
asyncio.run(main())
|
||||
Reference in New Issue
Block a user