Python: Added custom args and thread object to ai_function kwargs (#2769)

* Added an example of using kwargs in ai_function

* Added thread object to ai_function kwargs

* Updated docs

* Small fix

* Added thread parameter filtering
This commit is contained in:
Dmytro Struk
2025-12-11 17:53:04 -08:00
committed by GitHub
Unverified
parent eb1117fff4
commit d7434d59ce
7 changed files with 216 additions and 6 deletions
@@ -878,6 +878,9 @@ class ChatAgent(BaseAgent):
user=user,
additional_properties=merged_additional_options, # type: ignore[arg-type]
)
# Ensure thread is forwarded in kwargs for tool invocation
kwargs["thread"] = thread
# Filter chat_options from kwargs to prevent duplicate keyword argument
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
response = await self.chat_client.get_response(
@@ -895,7 +898,12 @@ class ChatAgent(BaseAgent):
# Only notify the thread of new messages if the chatResponse was successful
# to avoid inconsistent messages state in the thread.
await self._notify_thread_of_new_messages(thread, input_messages, response.messages)
await self._notify_thread_of_new_messages(
thread,
input_messages,
response.messages,
**{k: v for k, v in kwargs.items() if k != "thread"},
)
return AgentRunResponse(
messages=response.messages,
response_id=response.response_id,
@@ -1017,6 +1025,8 @@ class ChatAgent(BaseAgent):
additional_properties=merged_additional_options, # type: ignore[arg-type]
)
# Ensure thread is forwarded in kwargs for tool invocation
kwargs["thread"] = thread
# Filter chat_options from kwargs to prevent duplicate keyword argument
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "chat_options"}
response_updates: list[ChatResponseUpdate] = []
@@ -1043,7 +1053,13 @@ class ChatAgent(BaseAgent):
response = ChatResponse.from_chat_response_updates(response_updates, output_format_type=co.response_format)
await self._update_thread_with_type_and_conversation_id(thread, response.conversation_id)
await self._notify_thread_of_new_messages(thread, input_messages, response.messages, **kwargs)
await self._notify_thread_of_new_messages(
thread,
input_messages,
response.messages,
**{k: v for k, v in kwargs.items() if k != "thread"},
)
@override
def get_new_thread(
+20 -4
View File
@@ -627,6 +627,12 @@ class AIFunction(BaseTool, Generic[ArgsT, ReturnT]):
self._invocation_duration_histogram = _default_histogram()
self.type: Literal["ai_function"] = "ai_function"
self._forward_runtime_kwargs: bool = False
if self.func:
sig = inspect.signature(self.func)
for param in sig.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD:
self._forward_runtime_kwargs = True
break
@property
def declaration_only(self) -> bool:
@@ -915,6 +921,7 @@ def _create_input_model_from_func(func: Callable[..., Any], name: str) -> type[B
)
for pname, param in sig.parameters.items()
if pname not in {"self", "cls"}
and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
}
return create_model(f"{name}_input", **fields) # type: ignore[call-overload, no-any-return]
@@ -1744,7 +1751,9 @@ def _handle_function_calls_response(
break
_replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results)
response = await func(self, messages=prepped_messages, **kwargs)
# Filter out internal framework kwargs before passing to clients.
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
response = await func(self, messages=prepped_messages, **filtered_kwargs)
# if there are function calls, we will handle them first
function_results = {
it.call_id for it in response.messages[0].contents if isinstance(it, FunctionResultContent)
@@ -1833,7 +1842,10 @@ def _handle_function_calls_response(
# Failsafe: give up on tools, ask model for plain answer
kwargs["tool_choice"] = "none"
response = await func(self, messages=prepped_messages, **kwargs)
# Filter out internal framework kwargs before passing to clients.
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
response = await func(self, messages=prepped_messages, **filtered_kwargs)
if fcc_messages:
for msg in reversed(fcc_messages):
response.messages.insert(0, msg)
@@ -1920,7 +1932,9 @@ def _handle_function_calls_streaming_response(
_replace_approval_contents_with_results(prepped_messages, fcc_todo, approved_function_results)
all_updates: list["ChatResponseUpdate"] = []
async for update in func(self, messages=prepped_messages, **kwargs):
# Filter out internal framework kwargs before passing to clients.
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
async for update in func(self, messages=prepped_messages, **filtered_kwargs):
all_updates.append(update)
yield update
@@ -2031,7 +2045,9 @@ def _handle_function_calls_streaming_response(
# Failsafe: give up on tools, ask model for plain answer
kwargs["tool_choice"] = "none"
async for update in func(self, messages=prepped_messages, **kwargs):
# Filter out internal framework kwargs before passing to clients.
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "thread"}
async for update in func(self, messages=prepped_messages, **filtered_kwargs):
yield update
return streaming_function_invocation_wrapper
@@ -21,9 +21,11 @@ from agent_framework import (
ChatResponse,
Context,
ContextProvider,
FunctionCallContent,
HostedCodeInterpreterTool,
Role,
TextContent,
ai_function,
)
from agent_framework._mcp import MCPTool
from agent_framework.exceptions import AgentExecutionException
@@ -595,3 +597,38 @@ async def test_chat_agent_with_local_mcp_tools(chat_client: ChatClientProtocol)
# Test async context manager with MCP tools
async with agent:
pass
async def test_agent_tool_receives_thread_in_kwargs(chat_client_base: Any) -> None:
"""Verify tool execution receives 'thread' inside **kwargs when function is called by client."""
captured: dict[str, Any] = {}
@ai_function(name="echo_thread_info")
def echo_thread_info(text: str, **kwargs: Any) -> str: # type: ignore[reportUnknownParameterType]
thread = kwargs.get("thread")
captured["has_thread"] = thread is not None
captured["has_message_store"] = thread.message_store is not None if isinstance(thread, AgentThread) else False
return f"echo: {text}"
# Make the base client emit a function call for our tool
chat_client_base.run_responses = [
ChatResponse(
messages=ChatMessage(
role="assistant",
contents=[FunctionCallContent(call_id="1", name="echo_thread_info", arguments='{"text": "hello"}')],
)
),
ChatResponse(messages=ChatMessage(role="assistant", text="done")),
]
agent = ChatAgent(
chat_client=chat_client_base, tools=[echo_thread_info], chat_message_store_factory=ChatMessageStore
)
thread = agent.get_new_thread()
result = await agent.run("hello", thread=thread)
assert result.text == "done"
assert captured.get("has_thread") is True
assert captured.get("has_message_store") is True
@@ -1334,3 +1334,37 @@ async def test_streaming_two_functions_mixed_approval():
assert updates[2].role == Role.ASSISTANT
assert len(updates[2].contents) == 2
assert all(isinstance(c, FunctionApprovalRequestContent) for c in updates[2].contents)
async def test_ai_function_with_kwargs_injection():
"""Test that ai_function correctly handles kwargs injection and hides them from schema."""
@ai_function
def tool_with_kwargs(x: int, **kwargs: Any) -> str:
"""A tool that accepts kwargs."""
user_id = kwargs.get("user_id", "unknown")
return f"x={x}, user={user_id}"
# Verify schema does not include kwargs
assert tool_with_kwargs.parameters() == {
"properties": {"x": {"title": "X", "type": "integer"}},
"required": ["x"],
"title": "tool_with_kwargs_input",
"type": "object",
}
# Verify direct invocation works
assert tool_with_kwargs(1, user_id="user1") == "x=1, user=user1"
# Verify invoke works with injected args
result = await tool_with_kwargs.invoke(
arguments=tool_with_kwargs.input_model(x=5),
user_id="user2",
)
assert result == "x=5, user=user2"
# Verify invoke works without injected args (uses default)
result_default = await tool_with_kwargs.invoke(
arguments=tool_with_kwargs.input_model(x=10),
)
assert result_default == "x=10, user=unknown"
@@ -11,6 +11,8 @@ This folder contains examples demonstrating how to use AI functions (tools) with
| [`ai_function_recover_from_failures.py`](ai_function_recover_from_failures.py) | Demonstrates graceful error handling when tools raise exceptions. Shows how agents receive error information and can recover from failures, deciding whether to retry or respond differently based on the exception. |
| [`ai_function_with_approval.py`](ai_function_with_approval.py) | Shows how to implement user approval workflows for function calls without using threads. Demonstrates both streaming and non-streaming approval patterns where users can approve or reject function executions before they run. |
| [`ai_function_with_approval_and_threads.py`](ai_function_with_approval_and_threads.py) | Demonstrates tool approval workflows using threads for automatic conversation history management. Shows how threads simplify approval workflows by automatically storing and retrieving conversation context. Includes both approval and rejection examples. |
| [`ai_function_with_kwargs.py`](ai_function_with_kwargs.py) | Demonstrates how to inject custom arguments (context) into an AI function from the agent's run method. Useful for passing runtime information like access tokens or user IDs that the tool needs but the model shouldn't see. |
| [`ai_function_with_thread_injection.py`](ai_function_with_thread_injection.py) | Shows how to access the current `thread` object inside an AI function via `**kwargs`. |
| [`ai_function_with_max_exceptions.py`](ai_function_with_max_exceptions.py) | Shows how to limit the number of times a tool can fail with exceptions using `max_invocation_exceptions`. Useful for preventing expensive tools from being called repeatedly when they keep failing. |
| [`ai_function_with_max_invocations.py`](ai_function_with_max_invocations.py) | Demonstrates limiting the total number of times a tool can be invoked using `max_invocations`. Useful for rate-limiting expensive operations or ensuring tools are only called a specific number of times per conversation. |
| [`ai_functions_in_class.py`](ai_functions_in_class.py) | Shows how to use `ai_function` decorator with class methods to create stateful tools. Demonstrates how class state can control tool behavior dynamically, allowing you to adjust tool functionality at runtime by modifying class properties. |
@@ -0,0 +1,53 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from typing import Annotated, Any
from agent_framework import ai_function
from agent_framework.openai import OpenAIResponsesClient
from pydantic import Field
"""
AI Function with kwargs Example
This example demonstrates how to inject custom keyword arguments (kwargs) into an AI function
from the agent's run method, without exposing them to the AI model.
This is useful for passing runtime information like access tokens, user IDs, or
request-specific context that the tool needs but the model shouldn't know about
or provide.
"""
# Define the function tool with **kwargs to accept injected arguments
@ai_function
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
**kwargs: Any,
) -> str:
"""Get the weather for a given location."""
# Extract the injected argument from kwargs
user_id = kwargs.get("user_id", "unknown")
# Simulate using the user_id for logging or personalization
print(f"Getting weather for user: {user_id}")
return f"The weather in {location} is cloudy with a high of 15°C."
async def main() -> None:
agent = OpenAIResponsesClient().create_agent(
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=[get_weather],
)
# Pass the injected argument when running the agent
# The 'user_id' kwarg will be passed down to the tool execution via **kwargs
response = await agent.run("What is the weather like in Amsterdam?", user_id="user_123")
print(f"Agent: {response.text}")
if __name__ == "__main__":
asyncio.run(main())
@@ -0,0 +1,52 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from typing import Annotated, Any
from agent_framework import AgentThread, ai_function
from agent_framework.openai import OpenAIChatClient
from pydantic import Field
"""
AI Function with Thread Injection Example
This example demonstrates the behavior when passing 'thread' to agent.run()
and accessing that thread in AI function.
"""
# Define the function tool with **kwargs
@ai_function
async def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
**kwargs: Any,
) -> str:
"""Get the weather for a given location."""
# Get thread object from kwargs
thread = kwargs.get("thread")
if thread and isinstance(thread, AgentThread):
if thread.message_store:
messages = await thread.message_store.list_messages()
print(f"Thread contains {len(messages)} messages.")
elif thread.service_thread_id:
print(f"Thread ID: {thread.service_thread_id}.")
return f"The weather in {location} is cloudy."
async def main() -> None:
agent = OpenAIChatClient().create_agent(
name="WeatherAgent", instructions="You are a helpful weather assistant.", tools=[get_weather]
)
# Create a thread
thread = agent.get_new_thread()
# Run the agent with the thread
print(f"Agent: {await agent.run('What is the weather in London?', thread=thread)}")
print(f"Agent: {await agent.run('What is the weather in Amsterdam?', thread=thread)}")
print(f"Agent: {await agent.run('What cities did I ask about?', thread=thread)}")
if __name__ == "__main__":
asyncio.run(main())