Python: [BREAKING] Remove deprecated kwargs compatibility paths (#4858)

* [BREAKING] Remove deprecated kwargs compatibility paths

Remove the deprecated kwargs compatibility shims across core agents, clients, tools, middleware, and telemetry.

Keep workflow kwargs behavior intact in this branch and follow up separately in #4850.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix PR CI fallout for kwargs removal

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Address PR review feedback

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* updates

* Fix Azure AI CI fallout

Remove the stale _get_current_conversation_id override from the Azure AI client after the OpenAI base helper was deleted.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* fixed new classes

* Fix Assistants deprecated import gating

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Fix integration replay regressions

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Switch multi-agent hosting samples to Azure chat completions

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Simplify Azure multi-agent sample config

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-03-27 22:00:12 +01:00
committed by GitHub
Unverified
parent ca6cdd142e
commit b1b528e4a8
52 changed files with 1136 additions and 971 deletions
@@ -210,6 +210,7 @@ jobs:
OPENAI_API_KEY: ${{ secrets.OPENAI__APIKEY }}
AZURE_OPENAI_ENDPOINT: ${{ vars.AZUREOPENAI__ENDPOINT }}
AZURE_OPENAI_DEPLOYMENT_NAME: ${{ vars.AZUREOPENAI__RESPONSESDEPLOYMENTNAME }}
AZURE_OPENAI_CHAT_DEPLOYMENT_NAME: ${{ vars.AZUREOPENAI__CHATDEPLOYMENTNAME }}
FOUNDRY_PROJECT_ENDPOINT: ${{ vars.FOUNDRY_PROJECT_ENDPOINT }}
FOUNDRY_MODEL: ${{ vars.FOUNDRY_MODEL }}
FUNCTIONS_WORKER_RUNTIME: "python"
+1
View File
@@ -341,6 +341,7 @@ jobs:
OPENAI_API_KEY: ${{ secrets.OPENAI__APIKEY }}
AZURE_OPENAI_ENDPOINT: ${{ vars.AZUREOPENAI__ENDPOINT }}
AZURE_OPENAI_DEPLOYMENT_NAME: ${{ vars.AZUREOPENAI__RESPONSESDEPLOYMENTNAME }}
AZURE_OPENAI_CHAT_DEPLOYMENT_NAME: ${{ vars.AZUREOPENAI__CHATDEPLOYMENTNAME }}
FOUNDRY_PROJECT_ENDPOINT: ${{ vars.FOUNDRY_PROJECT_ENDPOINT }}
FOUNDRY_MODEL: ${{ vars.FOUNDRY_MODEL }}
FUNCTIONS_WORKER_RUNTIME: "python"
+8 -8
View File
@@ -366,7 +366,7 @@ def test_get_uri_data_invalid_uri() -> None:
def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None:
"""Test A2A parts to contents conversion."""
agent = A2AAgent(name="Test Agent", client=MockA2AClient(), _http_client=None)
agent = A2AAgent(name="Test Agent", client=MockA2AClient(), http_client=None)
# Create A2A parts
parts = [Part(root=TextPart(text="First part")), Part(root=TextPart(text="Second part"))]
@@ -485,7 +485,7 @@ async def test_context_manager_no_cleanup_when_no_http_client() -> None:
mock_a2a_client = MagicMock()
agent = A2AAgent(client=mock_a2a_client, _http_client=None)
agent = A2AAgent(client=mock_a2a_client, http_client=None)
# This should not raise any errors
async with agent:
@@ -495,7 +495,7 @@ async def test_context_manager_no_cleanup_when_no_http_client() -> None:
def test_prepare_message_for_a2a_with_multiple_contents() -> None:
"""Test conversion of Message with multiple contents."""
agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)
# Create message with multiple content types
message = Message(
@@ -523,7 +523,7 @@ def test_prepare_message_for_a2a_with_multiple_contents() -> None:
def test_prepare_message_for_a2a_forwards_context_id() -> None:
"""Test conversion of Message preserves context_id without duplicating it in metadata."""
agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)
message = Message(
role="user",
@@ -540,7 +540,7 @@ def test_prepare_message_for_a2a_forwards_context_id() -> None:
def test_parse_contents_from_a2a_with_data_part() -> None:
"""Test conversion of A2A DataPart."""
agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)
# Create DataPart
data_part = Part(root=DataPart(data={"key": "value", "number": 42}, metadata={"source": "test"}))
@@ -556,7 +556,7 @@ def test_parse_contents_from_a2a_with_data_part() -> None:
def test_parse_contents_from_a2a_unknown_part_kind() -> None:
"""Test error handling for unknown A2A part kind."""
agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)
# Create a mock part with unknown kind
mock_part = MagicMock()
@@ -569,7 +569,7 @@ def test_parse_contents_from_a2a_unknown_part_kind() -> None:
def test_prepare_message_for_a2a_with_hosted_file() -> None:
"""Test conversion of Message with HostedFileContent to A2A message."""
agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)
# Create message with hosted file content
message = Message(
@@ -595,7 +595,7 @@ def test_prepare_message_for_a2a_with_hosted_file() -> None:
def test_parse_contents_from_a2a_with_hosted_file_uri() -> None:
"""Test conversion of A2A FilePart with hosted file URI back to UriContent."""
agent = A2AAgent(client=MagicMock(), _http_client=None)
agent = A2AAgent(client=MagicMock(), http_client=None)
# Create FilePart with hosted file URI (simulating what A2A would send back)
file_part = Part(
@@ -445,6 +445,8 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
# Merge tools: convert agent's hosted tools + user-provided function tools
merged_tools = self._merge_tools(agent.tools, provided_tools)
merged_default_options: dict[str, Any] = dict(default_options) if default_options is not None else {}
merged_default_options.setdefault("model_id", agent.model)
return Agent( # type: ignore[return-value]
client=client,
@@ -452,9 +454,8 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
name=agent.name,
description=agent.description,
instructions=agent.instructions,
model_id=agent.model,
tools=merged_tools,
default_options=default_options, # type: ignore[arg-type]
default_options=cast(Any, merged_default_options),
middleware=middleware,
context_providers=context_providers,
)
@@ -603,11 +603,6 @@ class RawAzureAIClient(RawOpenAIChatClient[AzureAIClientOptionsT], Generic[Azure
return transformed
@override
def _get_current_conversation_id(self, options: Mapping[str, Any], **kwargs: Any) -> str | None:
"""Get the current conversation ID from chat options or kwargs."""
return options.get("conversation_id") or kwargs.get("conversation_id") or self.conversation_id
@override
def _parse_response_from_openai(
self,
@@ -24,7 +24,10 @@ from agent_framework._telemetry import AGENT_FRAMEWORK_USER_AGENT, APP_INFO, pre
from agent_framework._tools import FunctionInvocationConfiguration, FunctionInvocationLayer
from agent_framework._types import Annotation, Content
from agent_framework.observability import ChatTelemetryLayer, EmbeddingTelemetryLayer
from agent_framework_openai._assistants_client import OpenAIAssistantsClient, OpenAIAssistantsOptions
from agent_framework_openai._assistants_client import (
OpenAIAssistantsClient, # type: ignore[reportDeprecated]
OpenAIAssistantsOptions,
)
from agent_framework_openai._chat_client import OpenAIChatOptions, RawOpenAIChatClient
from agent_framework_openai._chat_completion_client import OpenAIChatCompletionOptions, RawOpenAIChatCompletionClient
from agent_framework_openai._embedding_client import OpenAIEmbeddingOptions, RawOpenAIEmbeddingClient
@@ -673,7 +676,8 @@ AzureOpenAIAssistantsOptions = OpenAIAssistantsOptions
"Use OpenAIAssistantsClient (also deprecated) or migrate to OpenAIChatClient."
)
class AzureOpenAIAssistantsClient(
OpenAIAssistantsClient[AzureOpenAIAssistantsOptionsT], Generic[AzureOpenAIAssistantsOptionsT]
OpenAIAssistantsClient[AzureOpenAIAssistantsOptionsT], # type: ignore[reportDeprecated]
Generic[AzureOpenAIAssistantsOptionsT],
):
"""Deprecated Azure OpenAI Assistants client. Use OpenAIAssistantsClient or migrate to OpenAIChatClient."""
@@ -5,7 +5,7 @@ from __future__ import annotations
import logging
import sys
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from typing import Any, Generic
from typing import Any, Generic, cast
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
@@ -398,6 +398,8 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
# from_azure_ai_tools converts hosted tools (MCP, code interpreter, file search, web search)
# but function tools need the actual implementations from provided_tools
merged_tools = self._merge_tools(details.definition.tools, provided_tools)
merged_default_options: dict[str, Any] = dict(default_options) if default_options is not None else {}
merged_default_options.setdefault("model_id", details.definition.model)
return Agent( # type: ignore[return-value]
client=client,
@@ -405,9 +407,8 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
name=details.name,
description=details.description,
instructions=details.definition.instructions,
model_id=details.definition.model,
tools=merged_tools,
default_options=default_options, # type: ignore[arg-type]
default_options=cast(Any, merged_default_options),
middleware=middleware,
context_providers=context_providers,
)
@@ -477,7 +477,9 @@ async def test_integration_client_agent_existing_session():
) as first_agent:
# Start a conversation and capture the session
session = first_agent.create_session()
first_response = await first_agent.run("My hobby is photography. Remember this.", session=session, store=True)
first_response = await first_agent.run(
"My hobby is photography. Remember this.", session=session, options={"store": True}
)
assert isinstance(first_response, AgentResponse)
assert first_response.text is not None
@@ -492,7 +494,9 @@ async def test_integration_client_agent_existing_session():
instructions="You are a helpful assistant with good memory.",
) as second_agent:
# Reuse the preserved session
second_response = await second_agent.run("What is my hobby?", session=preserved_session)
second_response = await second_agent.run(
"What is my hobby?", session=preserved_session, options={"store": True}
)
assert isinstance(second_response, AgentResponse)
assert second_response.text is not None
@@ -7,7 +7,7 @@ import logging
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
from pathlib import Path
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, cast, overload
from agent_framework import (
AgentMiddlewareTypes,
@@ -584,7 +584,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
return AgentResponse.from_updates(updates, value=structured_output)
@overload
def run(
def run( # type: ignore[override]
self,
messages: AgentRunInputs | None = None,
*,
@@ -595,7 +595,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
def run( # type: ignore[override]
self,
messages: AgentRunInputs | None = None,
*,
@@ -747,3 +747,71 @@ class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[Options
response = await agent.run("Hello!")
print(response.text)
"""
@overload # type: ignore[override]
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
options: OptionsT | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
compaction_strategy: Any = None,
tokenizer: Any = None,
function_invocation_kwargs: dict[str, Any] | None = None,
client_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload # type: ignore[override]
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
options: OptionsT | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
compaction_strategy: Any = None,
tokenizer: Any = None,
function_invocation_kwargs: dict[str, Any] | None = None,
client_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run( # pyright: ignore[reportIncompatibleMethodOverride] # type: ignore[override]
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
middleware: Sequence[AgentMiddlewareTypes] | None = None,
options: OptionsT | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
compaction_strategy: Any = None,
tokenizer: Any = None,
function_invocation_kwargs: dict[str, Any] | None = None,
client_kwargs: dict[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Run the Claude agent with telemetry enabled."""
super_run = cast(
"Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]",
super().run,
)
return super_run(
messages=messages,
stream=stream,
session=session,
middleware=middleware,
options=options,
tools=tools,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
**kwargs,
)
+52 -55
View File
@@ -5,7 +5,6 @@ from __future__ import annotations
import logging
import re
import sys
import warnings
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from copy import deepcopy
@@ -248,7 +247,6 @@ class SupportsAgentRun(Protocol):
session: AgentSession | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]:
"""Get a response from the agent (non-streaming)."""
...
@@ -262,7 +260,6 @@ class SupportsAgentRun(Protocol):
session: AgentSession | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Get a streaming response from the agent."""
...
@@ -275,7 +272,6 @@ class SupportsAgentRun(Protocol):
session: AgentSession | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Get a response from the agent.
@@ -291,7 +287,6 @@ class SupportsAgentRun(Protocol):
session: The conversation session associated with the message(s).
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
client_kwargs: Additional client-specific keyword arguments.
kwargs: Additional keyword arguments.
Returns:
When stream=False: An AgentResponse with the final result.
@@ -334,7 +329,15 @@ class BaseAgent(SerializationMixin):
# Create a concrete subclass that implements the protocol
class SimpleAgent(BaseAgent):
async def run(self, messages=None, *, stream=False, session=None, **kwargs):
async def run(
self,
messages=None,
*,
stream=False,
session=None,
function_invocation_kwargs=None,
client_kwargs=None,
):
if stream:
async def _stream():
@@ -373,7 +376,6 @@ class BaseAgent(SerializationMixin):
context_providers: Sequence[BaseContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
additional_properties: MutableMapping[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize a BaseAgent instance.
@@ -385,15 +387,7 @@ class BaseAgent(SerializationMixin):
context_providers: Context providers to include during agent invocation.
middleware: List of middleware.
additional_properties: Additional properties set on the agent.
kwargs: Additional keyword arguments (merged into additional_properties).
"""
if kwargs:
warnings.warn(
"Passing additional properties as direct keyword arguments to BaseAgent is deprecated; "
"pass them via additional_properties instead.",
DeprecationWarning,
stacklevel=3,
)
if id is None:
id = str(uuid4())
self.id = id
@@ -403,10 +397,7 @@ class BaseAgent(SerializationMixin):
self.middleware: list[MiddlewareTypes] | None = (
cast(list[MiddlewareTypes], middleware) if middleware is not None else None
)
# Merge kwargs into additional_properties
self.additional_properties: dict[str, Any] = cast(dict[str, Any], additional_properties or {})
self.additional_properties.update(kwargs)
def create_session(self, *, session_id: str | None = None) -> AgentSession:
"""Create a new lightweight session.
@@ -666,9 +657,10 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
default_options: OptionsCoT | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
additional_properties: MutableMapping[str, Any] | None = None,
) -> None:
"""Initialize a Agent instance.
@@ -695,7 +687,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
If both this and a compaction_strategy on the underlying client are set, this one is used.
tokenizer: Optional agent-level tokenizer.
If both this and a tokenizer on the underlying client are set, this one is used.
kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``.
additional_properties: Additional properties stored on the agent.
"""
opts = dict(default_options) if default_options else {}
@@ -709,7 +701,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
name=name,
description=description,
context_providers=context_providers,
**kwargs,
middleware=middleware,
additional_properties=additional_properties,
)
self.client = client
self.compaction_strategy = compaction_strategy
@@ -812,7 +805,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
@overload
@@ -828,7 +820,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
@@ -844,7 +835,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
@@ -859,7 +849,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Run the agent with the given messages and options.
@@ -890,21 +879,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
is used, falling back to the client default.
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
client_kwargs: Additional client-specific keyword arguments for the chat client.
kwargs: Deprecated additional keyword arguments for the agent.
They are forwarded to both tool invocation and the chat client for compatibility.
Returns:
When stream=False: An Awaitable[AgentResponse] containing the agent's response.
When stream=True: A ResponseStream of AgentResponseUpdate items with
``get_final_response()`` for the final AgentResponse.
"""
if kwargs:
warnings.warn(
"Passing runtime keyword arguments directly to run() is deprecated; pass tool values via "
"function_invocation_kwargs and client-specific values via client_kwargs instead.",
DeprecationWarning,
stacklevel=2,
)
if not stream:
async def _run_non_streaming() -> AgentResponse[Any]:
@@ -915,7 +895,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
legacy_kwargs=kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
@@ -1003,7 +982,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
legacy_kwargs=kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
@@ -1103,7 +1081,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options: Mapping[str, Any] | None,
compaction_strategy: CompactionStrategy | None,
tokenizer: TokenizerProtocol | None,
legacy_kwargs: Mapping[str, Any],
function_invocation_kwargs: Mapping[str, Any] | None,
client_kwargs: Mapping[str, Any] | None,
) -> _RunContext:
@@ -1176,12 +1153,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
duplicate_error_message=mcp_duplicate_message,
)
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
# Legacy compatibility still fans out direct run kwargs into tool runtime kwargs.
effective_function_invocation_kwargs = {
**dict(legacy_kwargs),
**(dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}),
}
effective_function_invocation_kwargs = (
dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}
)
additional_function_arguments = {**effective_function_invocation_kwargs, **existing_additional_args}
# Build options dict from run() options merged with provided options
@@ -1214,12 +1188,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
# Build session_messages from session context: context messages + input messages
session_messages: list[Message] = session_context.get_messages(include_input=True)
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
# Legacy compatibility still fans out direct run kwargs into client kwargs.
effective_client_kwargs = {
**dict(legacy_kwargs),
**(dict(client_kwargs) if client_kwargs is not None else {}),
}
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
if active_session is not None:
effective_client_kwargs["session"] = active_session
@@ -1499,9 +1468,29 @@ class Agent(
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: OptionsCoT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
@@ -1511,9 +1500,13 @@ class Agent(
*,
stream: Literal[True],
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
@@ -1523,10 +1516,12 @@ class Agent(
stream: bool = False,
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Run the agent."""
super_run = cast(
@@ -1538,10 +1533,12 @@ class Agent(
stream=stream,
session=session,
middleware=middleware,
tools=tools,
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
**kwargs,
)
def __init__(
@@ -1558,7 +1555,7 @@ class Agent(
middleware: Sequence[MiddlewareTypes] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
additional_properties: MutableMapping[str, Any] | None = None,
) -> None:
"""Initialize a Agent instance."""
super().__init__(
@@ -1573,7 +1570,7 @@ class Agent(
middleware=middleware,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
additional_properties=additional_properties,
)
@@ -4,7 +4,6 @@ from __future__ import annotations
import logging
import sys
import warnings
from abc import ABC, abstractmethod
from collections.abc import (
AsyncIterable,
@@ -139,7 +138,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@overload
@@ -153,7 +153,6 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@overload
@@ -167,7 +166,6 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
def get_response(
@@ -180,7 +178,6 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Send input and return the response.
@@ -192,7 +189,6 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
tokenizer: Optional per-call tokenizer override.
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
client_kwargs: Additional client-specific keyword arguments.
**kwargs: Deprecated additional client-specific keyword arguments.
Returns:
When stream=False: An awaitable ChatResponse from the client.
@@ -296,7 +292,6 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize a BaseChatClient instance.
@@ -304,19 +299,10 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
compaction_strategy: Optional compaction strategy to apply before model calls.
tokenizer: Optional tokenizer used by token-aware compaction strategies.
additional_properties: Additional properties for the client.
kwargs: Additional keyword arguments (merged into additional_properties for now).
"""
self.additional_properties = additional_properties or {}
self.compaction_strategy = compaction_strategy
self.tokenizer = tokenizer
if kwargs:
warnings.warn(
"Passing additional properties as direct keyword arguments to BaseChatClient is deprecated; "
"pass them via additional_properties instead.",
DeprecationWarning,
stacklevel=3,
)
self.additional_properties.update(kwargs)
super().__init__()
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
@@ -457,7 +443,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@overload
@@ -469,7 +456,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
options: OptionsCoT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[Any]]: ...
@overload
@@ -481,7 +469,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
def get_response(
@@ -492,7 +481,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Get a response from a chat client.
@@ -504,13 +494,9 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
When omitted, the client-level default is used.
tokenizer: Optional per-call tokenizer override. When omitted, the
client-level default is used.
**kwargs: Additional compatibility keyword arguments. Lower chat-client layers do not
consume ``function_invocation_kwargs`` directly; if present, it is ignored here
because function invocation has already been handled by upper layers. If a
``client_kwargs`` mapping is present, it is flattened into standard keyword
arguments before forwarding to ``_inner_get_response()`` so client implementations
can leverage those values, while implementations that ignore
extra kwargs remain compatible.
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
client_kwargs: Additional client-specific keyword arguments forwarded to
``_inner_get_response()``.
Returns:
When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse.
@@ -519,14 +505,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
)
compatibility_client_kwargs = kwargs.pop("client_kwargs", None)
kwargs.pop("function_invocation_kwargs", None)
merged_client_kwargs = (
dict(cast(Mapping[str, Any], compatibility_client_kwargs))
if isinstance(compatibility_client_kwargs, Mapping)
else {}
)
merged_client_kwargs.update(kwargs)
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
if not compaction_overrides:
return self._inner_get_response(
+2 -1
View File
@@ -768,7 +768,8 @@ class MCPTool:
options["stop"] = params.stopSequences
try:
response = await self.client.get_response(
chat_client: Any = self.client
response: Any = await chat_client.get_response(
messages,
options=options or None,
)
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
from ._clients import SupportsChatGetResponse
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._sessions import AgentSession
from ._tools import FunctionTool
from ._tools import FunctionTool, ToolTypes
from ._types import ChatOptions, ChatResponse, ChatResponseUpdate
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
@@ -100,6 +100,7 @@ class AgentContext:
agent: The agent being invoked.
messages: The messages being sent to the agent.
session: The agent session for this invocation, if any.
tools: Run-level tool overrides for this invocation, if any.
options: The options for the agent invocation as a dict.
stream: Whether this is a streaming invocation.
compaction_strategy: Optional per-run compaction override.
@@ -142,6 +143,7 @@ class AgentContext:
agent: SupportsAgentRun,
messages: list[Message],
session: AgentSession | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: Mapping[str, Any] | None = None,
stream: bool = False,
compaction_strategy: CompactionStrategy | None = None,
@@ -165,6 +167,7 @@ class AgentContext:
agent: The agent being invoked.
messages: The messages being sent to the agent.
session: The agent session for this invocation, if any.
tools: Run-level tool overrides for this invocation, if any.
options: The options for the agent invocation as a dict.
stream: Whether this is a streaming invocation.
compaction_strategy: Optional per-run compaction override.
@@ -181,6 +184,7 @@ class AgentContext:
self.agent = agent
self.messages = messages
self.session = session
self.tools = tools
self.options = options
self.stream = stream
self.compaction_strategy = compaction_strategy
@@ -1025,7 +1029,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@overload
@@ -1039,7 +1043,6 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@overload
@@ -1053,7 +1056,6 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
def get_response(
@@ -1066,27 +1068,26 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Execute the chat pipeline if middleware is configured."""
super_get_response = super().get_response # type: ignore[misc]
if compaction_strategy is not None:
kwargs["compaction_strategy"] = compaction_strategy
if tokenizer is not None:
kwargs["tokenizer"] = tokenizer
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
call_middleware = effective_client_kwargs.pop("middleware", [])
context_kwargs = dict(effective_client_kwargs)
if compaction_strategy is not None:
context_kwargs["compaction_strategy"] = compaction_strategy
if tokenizer is not None:
context_kwargs["tokenizer"] = tokenizer
pipeline = self._get_chat_middleware_pipeline(call_middleware) # type: ignore[reportUnknownArgumentType]
if not pipeline.has_middlewares:
return super_get_response( # type: ignore[no-any-return]
messages=messages,
stream=stream,
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=effective_client_kwargs,
**kwargs,
)
context = ChatContext(
@@ -1094,7 +1095,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
messages=list(messages),
options=options,
stream=stream,
kwargs={**effective_client_kwargs, **kwargs},
kwargs=context_kwargs,
function_invocation_kwargs=function_invocation_kwargs,
)
@@ -1180,12 +1181,12 @@ class AgentMiddlewareLayer:
stream: Literal[False] = ...,
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
@overload
@@ -1196,12 +1197,12 @@ class AgentMiddlewareLayer:
stream: Literal[False] = ...,
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
@@ -1212,12 +1213,12 @@ class AgentMiddlewareLayer:
stream: Literal[True],
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
@@ -1227,12 +1228,12 @@ class AgentMiddlewareLayer:
stream: bool = False,
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""MiddlewareTypes-enabled unified run method."""
# Re-categorize self.middleware at runtime to support dynamic changes
@@ -1263,23 +1264,23 @@ class AgentMiddlewareLayer:
messages,
stream=stream,
session=session,
tools=tools,
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=effective_function_invocation_kwargs,
client_kwargs=effective_client_kwargs,
**kwargs,
)
context = AgentContext(
agent=self, # type: ignore[arg-type]
messages=normalize_messages(messages),
session=session,
tools=tools,
options=options,
stream=stream,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
kwargs=kwargs,
client_kwargs=effective_client_kwargs,
function_invocation_kwargs=effective_function_invocation_kwargs,
)
@@ -1313,22 +1314,16 @@ class AgentMiddlewareLayer:
def _middleware_handler(
self, context: AgentContext
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
client_kwargs = {**context.client_kwargs, **context.kwargs}
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
function_invocation_kwargs = {
**context.function_invocation_kwargs,
**{k: v for k, v in context.kwargs.items() if k != "middleware"},
}
return super().run( # type: ignore[misc, no-any-return]
context.messages,
stream=context.stream,
session=context.session,
tools=context.tools,
options=context.options,
compaction_strategy=context.compaction_strategy,
tokenizer=context.tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
function_invocation_kwargs=context.function_invocation_kwargs,
client_kwargs=context.client_kwargs,
)
+25 -52
View File
@@ -8,7 +8,6 @@ import json
import logging
import sys
import typing
import warnings
from collections.abc import (
AsyncIterable,
Awaitable,
@@ -344,8 +343,6 @@ class FunctionTool(SerializationMixin):
self._instance = None # Store the instance for bound methods
self._context_parameter_name: str | None = None
self._input_model_explicitly_provided = input_model is not None
# TODO(Copilot): Delete once legacy ``**kwargs`` runtime injection is removed.
self._forward_runtime_kwargs: bool = False
if self.func:
self._discover_injected_parameters()
@@ -390,10 +387,6 @@ class FunctionTool(SerializationMixin):
for name, param in signature.parameters.items():
if name in {"self", "cls"}:
continue
if param.kind == inspect.Parameter.VAR_KEYWORD:
self._forward_runtime_kwargs = True
continue
annotation = type_hints.get(name, param.annotation)
if self._is_context_parameter(name, annotation):
if self._context_parameter_name is not None:
@@ -518,6 +511,7 @@ class FunctionTool(SerializationMixin):
*,
arguments: BaseModel | Mapping[str, Any] | None = None,
context: FunctionInvocationContext | None = None,
tool_call_id: str | None = None,
**kwargs: Any,
) -> list[Content]:
"""Run the AI function with the provided arguments as a Pydantic model.
@@ -530,7 +524,10 @@ class FunctionTool(SerializationMixin):
Keyword Args:
arguments: A mapping or model instance containing the arguments for the function.
context: Explicit function invocation context carrying runtime kwargs.
kwargs: Deprecated keyword arguments to pass to the function. Use ``context`` instead.
tool_call_id: Optional tool call identifier used for telemetry and tracing.
kwargs: Direct function argument values. When provided, every keyword
must match a declared tool parameter. Runtime data must be passed
via ``context``.
Returns:
A list of Content items representing the tool output.
@@ -552,18 +549,13 @@ class FunctionTool(SerializationMixin):
{key: value for key, value in kwargs.items() if key in parameter_names} if arguments is None else {}
)
runtime_kwargs = dict(context.kwargs) if context is not None else {}
deprecated_runtime_kwargs = {
key: value for key, value in kwargs.items() if key not in direct_argument_kwargs and key != "tool_call_id"
}
if deprecated_runtime_kwargs:
warnings.warn(
"Passing runtime keyword arguments directly to FunctionTool.invoke() is deprecated; "
"pass them via FunctionInvocationContext instead.",
DeprecationWarning,
stacklevel=2,
unexpected_kwargs = {key: value for key, value in kwargs.items() if key not in direct_argument_kwargs}
if unexpected_kwargs:
unexpected_names = ", ".join(sorted(unexpected_kwargs))
raise TypeError(
f"Unexpected keyword argument(s) for tool '{self.name}': {unexpected_names}. "
"Pass runtime data via FunctionInvocationContext instead."
)
runtime_kwargs.update(deprecated_runtime_kwargs)
tool_call_id = kwargs.get("tool_call_id", runtime_kwargs.pop("tool_call_id", None))
if arguments is None and direct_argument_kwargs:
arguments = direct_argument_kwargs
if arguments is None and context is not None:
@@ -614,17 +606,6 @@ class FunctionTool(SerializationMixin):
call_kwargs = dict(validated_arguments)
observable_kwargs = dict(validated_arguments)
# Legacy runtime kwargs injection path retained for backwards compatibility with tools
# that still declare ``**kwargs``. New tools should consume runtime data via ``ctx``.
legacy_runtime_kwargs = dict(runtime_kwargs)
if self._forward_runtime_kwargs and legacy_runtime_kwargs:
for key, value in legacy_runtime_kwargs.items():
if key not in call_kwargs:
call_kwargs[key] = value
if key not in observable_kwargs:
observable_kwargs[key] = value
if self._context_parameter_name is not None and effective_context is not None:
call_kwargs[self._context_parameter_name] = effective_context
@@ -1420,7 +1401,7 @@ async def _auto_invoke_function(
# No middleware - execute directly
try:
direct_context = None
if getattr(tool, "_forward_runtime_kwargs", False) or getattr(tool, "_context_parameter_name", None):
if getattr(tool, "_context_parameter_name", None):
direct_context = FunctionInvocationContext(
function=tool,
arguments=args,
@@ -2078,7 +2059,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@overload
@@ -2093,7 +2073,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@overload
@@ -2108,7 +2087,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
def get_response(
@@ -2122,7 +2100,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
from ._middleware import categorize_middleware
from ._types import (
@@ -2133,14 +2110,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
)
super_get_response = super().get_response # type: ignore[misc]
if kwargs:
warnings.warn(
"Passing client-specific keyword arguments directly to get_response() is deprecated; "
"pass them via client_kwargs instead.",
DeprecationWarning,
stacklevel=2,
)
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
if middleware is not None:
existing = effective_client_kwargs.get("middleware", [])
@@ -2176,19 +2145,23 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
invocation_session=invocation_session,
middleware_pipeline=function_middleware_pipeline,
)
filtered_kwargs = {k: v for k, v in {**effective_client_kwargs, **kwargs}.items() if k != "session"}
filtered_kwargs = {k: v for k, v in effective_client_kwargs.items() if k != "session"}
# Make options mutable so we can update conversation_id during function invocation loop
mutable_options: dict[str, Any] = dict(options) if options else {}
# Remove additional_function_arguments from options passed to underlying chat client
# It's for tool invocation only and not recognized by chat service APIs
mutable_options.pop("additional_function_arguments", None)
# Support tools passed via kwargs in direct client.get_response(...) calls.
if "tools" in filtered_kwargs:
if mutable_options.get("tools") is None:
mutable_options["tools"] = filtered_kwargs["tools"]
filtered_kwargs.pop("tools", None)
if not self.function_invocation_configuration.get("enabled", True):
return super_get_response( # type: ignore[no-any-return]
messages=messages,
stream=stream,
options=mutable_options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=filtered_kwargs,
)
if not stream:
async def _get_response() -> ChatResponse[Any]:
@@ -2235,7 +2208,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
aggregated_usage = add_usage_details(aggregated_usage, response.usage_details)
if response.conversation_id is not None:
_update_conversation_id(kwargs, response.conversation_id, mutable_options)
_update_conversation_id(filtered_kwargs, response.conversation_id, mutable_options)
prepped_messages = []
result = await _process_function_requests(
@@ -2379,7 +2352,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
return
if response.conversation_id is not None:
_update_conversation_id(kwargs, response.conversation_id, mutable_options)
_update_conversation_id(filtered_kwargs, response.conversation_id, mutable_options)
prepped_messages = []
result = await _process_function_requests(
@@ -12,7 +12,7 @@ from agent_framework import Content
from .._agents import SupportsAgentRun
from .._sessions import AgentSession
from .._types import AgentResponse, AgentResponseUpdate, Message
from .._types import AgentResponse, AgentResponseUpdate, Message, ResponseStream
from ._agent_utils import resolve_agent_id
from ._const import WORKFLOW_RUN_KWARGS_KEY
from ._executor import Executor, handler
@@ -352,7 +352,8 @@ class AgentExecutor(Executor):
"""
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}))
response = await self._agent.run(
run_agent = cast(Callable[..., Awaitable[AgentResponse[Any]]], self._agent.run)
response = await run_agent(
self._cache,
stream=False,
session=self._session,
@@ -383,7 +384,8 @@ class AgentExecutor(Executor):
updates: list[AgentResponseUpdate] = []
streamed_user_input_requests: list[Content] = []
stream = self._agent.run(
run_agent_stream = cast(Callable[..., ResponseStream[AgentResponseUpdate, AgentResponse[Any]]], self._agent.run)
stream = run_agent_stream(
self._cache,
stream=True,
session=self._session,
@@ -49,8 +49,9 @@ if TYPE_CHECKING: # pragma: no cover
from ._agents import SupportsAgentRun
from ._clients import SupportsChatGetResponse
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._middleware import MiddlewareTypes
from ._sessions import AgentSession
from ._tools import FunctionTool
from ._tools import FunctionTool, ToolTypes
from ._types import (
AgentResponse,
AgentResponseUpdate,
@@ -1191,7 +1192,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@overload
@@ -1203,7 +1205,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
options: OptionsCoT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[Any]]: ...
@overload
@@ -1215,7 +1218,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
def get_response(
@@ -1226,7 +1230,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Trace chat responses with OpenTelemetry spans and metrics.
@@ -1238,25 +1243,14 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
tokenizer: Optional tokenizer used by token-aware compaction strategies.
Keyword Args:
kwargs: Compatibility keyword arguments from higher client layers. This layer does
not consume ``function_invocation_kwargs`` directly; if present, it is ignored
because function invocation has already been processed above. If a ``client_kwargs``
mapping is present, it is flattened into ordinary keyword arguments for tracing and
forwarding so clients that use those values continue to work while clients that
ignore extra kwargs remain compatible.
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
client_kwargs: Additional client-specific keyword arguments for downstream chat clients.
"""
from ._types import ChatResponse, ChatResponseUpdate, ResponseStream # type: ignore[reportUnusedImport]
global OBSERVABILITY_SETTINGS
super_get_response = super().get_response # type: ignore[misc]
compatibility_client_kwargs = kwargs.pop("client_kwargs", None)
kwargs.pop("function_invocation_kwargs", None)
merged_client_kwargs = (
dict(cast(Mapping[str, Any], compatibility_client_kwargs))
if isinstance(compatibility_client_kwargs, Mapping)
else {}
)
merged_client_kwargs.update(kwargs)
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
if not OBSERVABILITY_SETTINGS.ENABLED:
return super_get_response( # type: ignore[no-any-return]
@@ -1265,7 +1259,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**merged_client_kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=merged_client_kwargs,
)
opts: dict[str, Any] = options or {} # type: ignore[assignment]
@@ -1292,7 +1287,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
options=opts,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**merged_client_kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=merged_client_kwargs,
),
)
@@ -1384,7 +1380,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
options=opts,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**merged_client_kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=merged_client_kwargs,
),
)
except Exception as exception:
@@ -1512,11 +1509,29 @@ class AgentTelemetryLayer:
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
@@ -1526,11 +1541,13 @@ class AgentTelemetryLayer:
*,
stream: Literal[True],
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
@@ -1539,11 +1556,13 @@ class AgentTelemetryLayer:
*,
stream: bool = False,
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
options: ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Trace agent runs with OpenTelemetry spans and metrics."""
global OBSERVABILITY_SETTINGS
@@ -1554,23 +1573,27 @@ class AgentTelemetryLayer:
super().run, # type: ignore[misc]
)
provider_name = str(self.otel_provider_name)
super_run_kwargs: dict[str, Any] = {
"messages": messages,
"stream": stream,
"session": session,
"tools": tools,
"options": options,
"compaction_strategy": compaction_strategy,
"tokenizer": tokenizer,
"function_invocation_kwargs": function_invocation_kwargs,
"client_kwargs": client_kwargs,
}
if middleware is not None:
super_run_kwargs["middleware"] = middleware
if not OBSERVABILITY_SETTINGS.ENABLED:
return super_run( # type: ignore[no-any-return]
messages=messages,
stream=stream,
session=session,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
**kwargs,
)
return super_run(**super_run_kwargs) # type: ignore[no-any-return]
default_options = getattr(self, "default_options", {})
options = kwargs.get("options")
default_options = dict(getattr(self, "default_options", {}))
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
merged_client_kwargs.update(kwargs)
merged_options: dict[str, Any] = merge_chat_options(default_options, options or {})
merged_options: dict[str, Any] = merge_chat_options(
default_options, dict(options) if options is not None else {}
)
attributes = _get_span_attributes(
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
provider_name=provider_name,
@@ -1590,16 +1613,7 @@ class AgentTelemetryLayer:
if stream:
try:
run_result: object = super_run(
messages=messages,
stream=True,
session=session,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
**kwargs,
)
run_result: object = super_run(**super_run_kwargs)
if isinstance(run_result, ResponseStream):
result_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType]
elif isinstance(run_result, Awaitable):
@@ -1693,16 +1707,7 @@ class AgentTelemetryLayer:
)
start_time_stamp = perf_counter()
try:
response: AgentResponse[Any] = await super_run(
messages=messages,
stream=False,
session=session,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
**kwargs,
)
response: AgentResponse[Any] = await super_run(**super_run_kwargs)
except Exception as exception:
capture_exception(span=span, exception=exception, timestamp=time_ns())
raise
+8 -12
View File
@@ -148,11 +148,9 @@ async def test_chat_client_agent_init_with_name(
assert agent.description == "Test"
def test_agent_init_warns_for_direct_additional_properties(client: SupportsChatGetResponse) -> None:
with pytest.warns(DeprecationWarning, match="additional_properties"):
agent = Agent(client=client, legacy_key="legacy-value")
assert agent.additional_properties["legacy_key"] == "legacy-value"
def test_agent_init_rejects_direct_additional_properties(client: SupportsChatGetResponse) -> None:
with pytest.raises(TypeError):
Agent(client=client, legacy_key="legacy-value")
async def test_chat_client_agent_run(client: SupportsChatGetResponse) -> None:
@@ -303,7 +301,6 @@ async def test_prepare_run_context_handles_function_kwargs(
},
compaction_strategy=None,
tokenizer=None,
legacy_kwargs={"legacy_key": "legacy-value"},
function_invocation_kwargs={"runtime_key": "runtime-value"},
client_kwargs={"client_key": "client-value"},
)
@@ -311,7 +308,6 @@ async def test_prepare_run_context_handles_function_kwargs(
assert ctx["chat_options"]["temperature"] == 0.4
assert "additional_function_arguments" not in ctx["chat_options"]
assert ctx["function_invocation_kwargs"]["from_options"] == "options-value"
assert ctx["function_invocation_kwargs"]["legacy_key"] == "legacy-value"
assert ctx["function_invocation_kwargs"]["runtime_key"] == "runtime-value"
assert "session" not in ctx["function_invocation_kwargs"]
assert ctx["client_kwargs"]["client_key"] == "client-value"
@@ -1181,8 +1177,8 @@ async def test_agent_run_accepts_prefixed_mcp_tools(chat_client_base: Any) -> No
assert tool_names == ["search", "docs_search"]
async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> None:
"""Verify legacy **kwargs tools receive the session when agent.run() is called with one."""
async def test_agent_tool_without_context_does_not_receive_session(chat_client_base: Any) -> None:
"""Verify tools without FunctionInvocationContext no longer receive injected session kwargs."""
captured: dict[str, Any] = {}
@@ -1215,8 +1211,8 @@ async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> N
result = await agent.run("hello", session=session)
assert result.text == "done"
assert captured.get("has_session") is True
assert captured.get("has_state") is True
assert captured.get("has_session") is False
assert captured.get("has_state") is False
async def test_agent_tool_receives_explicit_session_via_function_invocation_context_kwargs(
@@ -1278,7 +1274,7 @@ async def test_chat_agent_tool_choice_run_level_overrides_agent_level(chat_clien
agent = Agent(
client=chat_client_base,
tools=[tool_tool],
options={"tool_choice": "auto"},
default_options={"tool_choice": "auto"},
)
# Run with run-level tool_choice="required"
@@ -1,7 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
import inspect
from typing import Any
from unittest.mock import patch
@@ -15,11 +14,6 @@ from agent_framework import (
Message,
SlidingWindowStrategy,
SupportsChatGetResponse,
SupportsCodeInterpreterTool,
SupportsFileSearchTool,
SupportsImageGenerationTool,
SupportsMCPTool,
SupportsWebSearchTool,
TruncationStrategy,
)
@@ -53,11 +47,9 @@ def test_base_client(chat_client_base: SupportsChatGetResponse):
assert isinstance(chat_client_base, SupportsChatGetResponse)
def test_base_client_warns_for_direct_additional_properties(chat_client_base: SupportsChatGetResponse) -> None:
with pytest.warns(DeprecationWarning, match="additional_properties"):
client = type(chat_client_base)(legacy_key="legacy-value")
assert client.additional_properties["legacy_key"] == "legacy-value"
def test_base_client_rejects_direct_additional_properties(chat_client_base: SupportsChatGetResponse) -> None:
with pytest.raises(TypeError):
type(chat_client_base)(legacy_key="legacy-value")
def test_base_client_as_agent_uses_explicit_additional_properties(chat_client_base: SupportsChatGetResponse) -> None:
@@ -66,27 +58,6 @@ def test_base_client_as_agent_uses_explicit_additional_properties(chat_client_ba
assert agent.additional_properties == {"team": "core"}
def test_openai_chat_completion_client_get_response_docstring_surfaces_layered_runtime_docs() -> None:
from agent_framework.openai import OpenAIChatCompletionClient
docstring = inspect.getdoc(OpenAIChatCompletionClient.get_response)
assert docstring is not None
assert "Get a response from a chat client." in docstring
assert "function_invocation_kwargs" in docstring
assert "middleware: Optional per-call chat and function middleware." in docstring
assert "function_middleware: Optional per-call function middleware." not in docstring
def test_openai_chat_completion_client_get_response_is_defined_on_openai_class() -> None:
from agent_framework.openai import OpenAIChatCompletionClient
signature = inspect.signature(OpenAIChatCompletionClient.get_response)
assert OpenAIChatCompletionClient.get_response.__qualname__ == "OpenAIChatCompletionClient.get_response"
assert "middleware" in signature.parameters
async def test_base_client_get_response_uses_explicit_client_kwargs(chat_client_base: SupportsChatGetResponse) -> None:
async def fake_inner_get_response(**kwargs):
assert kwargs["trace_id"] == "trace-123"
@@ -333,66 +304,3 @@ async def test_chat_client_instructions_handling(chat_client_base: SupportsChatG
assert appended_messages[0].text == "You are a helpful assistant."
assert appended_messages[1].role == "user"
assert appended_messages[1].text == "hello"
# region Tool Support Protocol Tests
def test_openai_responses_client_supports_all_tool_protocols():
"""Test that OpenAIResponsesClient supports all hosted tool protocols."""
from agent_framework.openai import OpenAIResponsesClient
assert isinstance(OpenAIResponsesClient, SupportsCodeInterpreterTool)
assert isinstance(OpenAIResponsesClient, SupportsWebSearchTool)
assert isinstance(OpenAIResponsesClient, SupportsImageGenerationTool)
assert isinstance(OpenAIResponsesClient, SupportsMCPTool)
assert isinstance(OpenAIResponsesClient, SupportsFileSearchTool)
def test_openai_chat_completion_client_supports_web_search_only():
"""Test that OpenAIChatClient only supports web search tool."""
from agent_framework.openai import OpenAIChatCompletionClient
assert not isinstance(OpenAIChatCompletionClient, SupportsCodeInterpreterTool)
assert isinstance(OpenAIChatCompletionClient, SupportsWebSearchTool)
assert not isinstance(OpenAIChatCompletionClient, SupportsImageGenerationTool)
assert not isinstance(OpenAIChatCompletionClient, SupportsMCPTool)
assert not isinstance(OpenAIChatCompletionClient, SupportsFileSearchTool)
def test_openai_assistants_client_supports_code_interpreter_and_file_search():
"""Test that OpenAIAssistantsClient supports code interpreter and file search."""
from agent_framework.openai import OpenAIAssistantsClient
assert isinstance(OpenAIAssistantsClient, SupportsCodeInterpreterTool)
assert not isinstance(OpenAIAssistantsClient, SupportsWebSearchTool)
assert not isinstance(OpenAIAssistantsClient, SupportsImageGenerationTool)
assert not isinstance(OpenAIAssistantsClient, SupportsMCPTool)
assert isinstance(OpenAIAssistantsClient, SupportsFileSearchTool)
def test_protocol_isinstance_with_client_instance():
"""Test that protocol isinstance works with client instances."""
from agent_framework.openai import OpenAIResponsesClient
# Create mock client instance (won't connect to API)
client = OpenAIResponsesClient.__new__(OpenAIResponsesClient)
assert isinstance(client, SupportsCodeInterpreterTool)
assert isinstance(client, SupportsWebSearchTool)
def test_protocol_tool_methods_return_dict():
"""Test that static tool methods return dict[str, Any]."""
from agent_framework.openai import OpenAIResponsesClient
code_tool = OpenAIResponsesClient.get_code_interpreter_tool()
assert isinstance(code_tool, dict)
assert code_tool.get("type") == "code_interpreter"
web_tool = OpenAIResponsesClient.get_web_search_tool()
assert isinstance(web_tool, dict)
assert web_tool.get("type") == "web_search"
# endregion
@@ -13,6 +13,7 @@ from agent_framework import (
Content,
Message,
SupportsChatGetResponse,
chat_middleware,
tool,
)
from agent_framework._compaction import (
@@ -74,7 +75,7 @@ async def test_base_client_with_function_calling(chat_client_base: SupportsChatG
assert response.messages[2].text == "done"
async def test_base_client_with_function_calling_tools_in_kwargs(chat_client_base: SupportsChatGetResponse):
async def test_base_client_with_function_calling_string_input(chat_client_base: SupportsChatGetResponse):
exec_counter = 0
@tool(name="test_function", approval_mode="never_require")
@@ -95,7 +96,7 @@ async def test_base_client_with_function_calling_tools_in_kwargs(chat_client_bas
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response("hello", tools=[ai_func])
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
assert exec_counter == 1
assert len(response.messages) == 3
@@ -1429,6 +1430,36 @@ async def test_function_invocation_config_enabled_false(chat_client_base: Suppor
assert len(response.messages) > 0
async def test_function_invocation_config_enabled_false_preserves_invocation_kwargs(
chat_client_base: SupportsChatGetResponse,
):
"""Test disabled function invocation still forwards invocation kwargs downstream."""
captured_kwargs: dict[str, Any] = {}
@tool(name="test_function")
def ai_func(arg1: str) -> str:
return f"Processed {arg1}"
@chat_middleware
async def capture_middleware(context, call_next):
captured_kwargs.update(context.function_invocation_kwargs or {})
await call_next()
chat_client_base.chat_middleware = [capture_middleware]
chat_client_base.run_responses = [
ChatResponse(messages=Message(role="assistant", text="response without function calling")),
]
chat_client_base.function_invocation_configuration["enabled"] = False
await chat_client_base.get_response(
[Message(role="user", text="hello")],
options={"tool_choice": "auto", "tools": [ai_func]},
function_invocation_kwargs={"tool_request_id": "tool-123"},
)
assert captured_kwargs == {"tool_request_id": "tool-123"}
@pytest.mark.skip(reason="Error handling and failsafe behavior needs investigation in unified API")
async def test_function_invocation_config_max_consecutive_errors(chat_client_base: SupportsChatGetResponse):
"""Test that max_consecutive_errors_per_request limits error retries."""
@@ -1523,7 +1554,7 @@ async def test_function_invocation_stop_clears_conversation_id_non_stream(chat_c
response = await chat_client_base.get_response(
[Message(role="user", text="hello")],
options={"tool_choice": "auto", "tools": [error_func]},
session=session_stub,
client_kwargs={"session": session_stub},
)
assert response.conversation_id is None
@@ -1881,8 +1912,7 @@ async def test_hosted_tool_approval_response(chat_client_base: SupportsChatGetRe
# Send the approval response
response = await chat_client_base.get_response(
[Message(role="user", contents=[approval_response])],
tool_choice="auto",
tools=[local_func],
options={"tool_choice": "auto", "tools": [local_func]},
)
# The hosted tool approval should be returned as-is (not executed)
@@ -1930,8 +1960,7 @@ async def test_hosted_mcp_approval_response_passthrough(chat_client_base: Suppor
response = await chat_client_base.get_response(
messages,
tool_choice="auto",
tools=[local_func],
options={"tool_choice": "auto", "tools": [local_func]},
)
# The response should succeed without errors
@@ -2024,8 +2053,7 @@ async def test_mixed_local_and_hosted_approval_flow(chat_client_base: SupportsCh
response = await chat_client_base.get_response(
messages,
tool_choice="auto",
tools=[local_func],
options={"tool_choice": "auto", "tools": [local_func]},
)
assert response is not None
@@ -2799,7 +2827,7 @@ async def test_streaming_function_invocation_stop_clears_conversation_id(chat_cl
"hello",
options={"tool_choice": "auto", "tools": [error_func]},
stream=True,
session=session_stub,
client_kwargs={"session": session_stub},
)
async for _ in stream:
pass
@@ -1,351 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for kwargs propagation from get_response() to @tool functions."""
from collections.abc import AsyncIterable, Awaitable, MutableSequence, Sequence
from typing import Any
from agent_framework import (
Agent,
BaseChatClient,
ChatMiddlewareLayer,
ChatResponse,
ChatResponseUpdate,
Content,
FunctionInvocationContext,
FunctionInvocationLayer,
Message,
ResponseStream,
tool,
)
from agent_framework.observability import ChatTelemetryLayer
class _MockBaseChatClient(BaseChatClient[Any]):
"""Mock chat client for testing function invocation."""
def __init__(self) -> None:
super().__init__()
self.run_responses: list[ChatResponse] = []
self.streaming_responses: list[list[ChatResponseUpdate]] = []
self.call_count: int = 0
def _inner_get_response(
self,
*,
messages: MutableSequence[Message],
stream: bool,
options: dict[str, Any],
**kwargs: Any,
) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
if stream:
return self._get_streaming_response(messages=messages, options=options, **kwargs)
async def _get() -> ChatResponse:
return await self._get_non_streaming_response(messages=messages, options=options, **kwargs)
return _get()
async def _get_non_streaming_response(
self,
*,
messages: MutableSequence[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
self.call_count += 1
if self.run_responses:
return self.run_responses.pop(0)
return ChatResponse(messages=Message(role="assistant", text="default response"))
def _get_streaming_response(
self,
*,
messages: MutableSequence[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse]:
async def _stream() -> AsyncIterable[ChatResponseUpdate]:
self.call_count += 1
if self.streaming_responses:
for update in self.streaming_responses.pop(0):
yield update
else:
yield ChatResponseUpdate(
contents=[Content.from_text("default streaming response")], role="assistant", finish_reason="stop"
)
def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
response_format = options.get("response_format")
output_format_type = response_format if isinstance(response_format, type) else None
return ChatResponse.from_updates(updates, output_format_type=output_format_type)
return ResponseStream(_stream(), finalizer=_finalize)
class FunctionInvokingMockClient(
FunctionInvocationLayer[Any],
ChatMiddlewareLayer[Any],
ChatTelemetryLayer[Any],
_MockBaseChatClient,
):
"""Mock client with function invocation support."""
pass
class TestKwargsPropagationToFunctionTool:
"""Test cases for kwargs flowing from get_response() to @tool functions."""
async def test_kwargs_propagate_to_tool_with_kwargs(self) -> None:
"""Test that kwargs passed to get_response() are available in @tool **kwargs."""
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
captured_kwargs: dict[str, Any] = {}
@tool(approval_mode="never_require")
def capture_kwargs_tool(x: int, **kwargs: Any) -> str:
"""A tool that captures kwargs for testing."""
captured_kwargs.update(kwargs)
return f"result: x={x}"
client = FunctionInvokingMockClient()
client.run_responses = [
# First response: function call
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="capture_kwargs_tool", arguments='{"x": 42}'
)
],
)
]
),
# Second response: final answer
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]
result = await client.get_response(
messages=[Message(role="user", text="Test")],
stream=False,
options={
"tools": [capture_kwargs_tool],
"additional_function_arguments": {
"user_id": "user-123",
"session_token": "secret-token",
"custom_data": {"key": "value"},
},
},
)
# Verify the tool was called and received the kwargs
assert "user_id" in captured_kwargs, f"Expected 'user_id' in captured kwargs: {captured_kwargs}"
assert captured_kwargs["user_id"] == "user-123"
assert "session_token" in captured_kwargs
assert captured_kwargs["session_token"] == "secret-token"
assert "custom_data" in captured_kwargs
assert captured_kwargs["custom_data"] == {"key": "value"}
# Verify result
assert result.messages[-1].text == "Done!"
async def test_kwargs_not_forwarded_to_tool_without_kwargs(self) -> None:
"""Test that kwargs are NOT forwarded to @tool that doesn't accept **kwargs."""
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
@tool(approval_mode="never_require")
def simple_tool(x: int) -> str:
"""A simple tool without **kwargs."""
return f"result: x={x}"
client = FunctionInvokingMockClient()
client.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(call_id="call_1", name="simple_tool", arguments='{"x": 99}')
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Completed!")]),
]
# Call with additional_function_arguments - the tool should work but not receive them
result = await client.get_response(
messages=[Message(role="user", text="Test")],
stream=False,
options={
"tools": [simple_tool],
"additional_function_arguments": {"user_id": "user-123"},
},
)
# Verify the tool was called successfully (no error from extra kwargs)
assert result.messages[-1].text == "Completed!"
async def test_kwargs_isolated_between_function_calls(self) -> None:
"""Test that kwargs are consistent across multiple function call invocations."""
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
invocation_kwargs: list[dict[str, Any]] = []
@tool(approval_mode="never_require")
def tracking_tool(name: str, **kwargs: Any) -> str:
"""A tool that tracks kwargs from each invocation."""
invocation_kwargs.append(dict(kwargs))
return f"called with {name}"
client = FunctionInvokingMockClient()
client.run_responses = [
# Two function calls in one response
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1", name="tracking_tool", arguments='{"name": "first"}'
),
Content.from_function_call(
call_id="call_2", name="tracking_tool", arguments='{"name": "second"}'
),
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="All done!")]),
]
result = await client.get_response(
messages=[Message(role="user", text="Test")],
stream=False,
options={
"tools": [tracking_tool],
"additional_function_arguments": {
"request_id": "req-001",
"trace_context": {"trace_id": "abc"},
},
},
)
# Both invocations should have received the same kwargs
assert len(invocation_kwargs) == 2
for kwargs in invocation_kwargs:
assert kwargs.get("request_id") == "req-001"
assert kwargs.get("trace_context") == {"trace_id": "abc"}
assert result.messages[-1].text == "All done!"
async def test_streaming_response_kwargs_propagation(self) -> None:
"""Test that kwargs propagate to @tool in streaming mode."""
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
captured_kwargs: dict[str, Any] = {}
@tool(approval_mode="never_require")
def streaming_capture_tool(value: str, **kwargs: Any) -> str:
"""A tool that captures kwargs during streaming."""
captured_kwargs.update(kwargs)
return f"processed: {value}"
client = FunctionInvokingMockClient()
client.streaming_responses = [
# First stream: function call
[
ChatResponseUpdate(
role="assistant",
contents=[
Content.from_function_call(
call_id="stream_call_1",
name="streaming_capture_tool",
arguments='{"value": "streaming-test"}',
)
],
finish_reason="stop",
)
],
# Second stream: final response
[
ChatResponseUpdate(
contents=[Content.from_text("Stream complete!")], role="assistant", finish_reason="stop"
)
],
]
# Collect streaming updates
updates: list[ChatResponseUpdate] = []
stream = client.get_response(
messages=[Message(role="user", text="Test")],
stream=True,
options={
"tools": [streaming_capture_tool],
"additional_function_arguments": {
"streaming_session": "session-xyz",
"correlation_id": "corr-123",
},
},
)
async for update in stream:
updates.append(update)
# Verify kwargs were captured by the tool
assert "streaming_session" in captured_kwargs, f"Expected 'streaming_session' in {captured_kwargs}"
assert captured_kwargs["streaming_session"] == "session-xyz"
assert captured_kwargs["correlation_id"] == "corr-123"
async def test_agent_run_injects_function_invocation_context(self) -> None:
"""Test that Agent.run injects FunctionInvocationContext for ctx-based tools."""
captured_context_kwargs: dict[str, Any] = {}
captured_client_kwargs: dict[str, Any] = {}
captured_options: dict[str, Any] = {}
@tool(approval_mode="never_require")
def capture_context_tool(x: int, ctx: FunctionInvocationContext) -> str:
captured_context_kwargs.update(ctx.kwargs)
return f"result: x={x}"
class CapturingFunctionInvokingMockClient(FunctionInvokingMockClient):
async def _get_non_streaming_response(
self,
*,
messages: MutableSequence[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
captured_options.update(options)
captured_client_kwargs.update(kwargs)
return await super()._get_non_streaming_response(messages=messages, options=options, **kwargs)
client = CapturingFunctionInvokingMockClient()
client.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1",
name="capture_context_tool",
arguments='{"x": 42}',
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]
agent = Agent(client=client, tools=[capture_context_tool])
result = await agent.run(
[Message(role="user", text="Test")],
function_invocation_kwargs={"tool_request_id": "tool-123"},
client_kwargs={"client_request_id": "client-456"},
)
assert captured_context_kwargs["tool_request_id"] == "tool-123"
assert "client_request_id" not in captured_context_kwargs
assert captured_client_kwargs["client_request_id"] == "client-456"
assert "tool_request_id" not in captured_client_kwargs
assert "additional_function_arguments" not in captured_options
assert result.messages[-1].text == "Done!"
+16 -8
View File
@@ -1751,6 +1751,9 @@ async def test_mcp_tool_sampling_callback_no_valid_content():
assert isinstance(result, types.ErrorData)
assert result.code == types.INTERNAL_ERROR
assert "Failed to get right content types from the response." in result.message
mock_chat_client.get_response.assert_awaited_once()
_, kwargs = mock_chat_client.get_response.await_args
assert kwargs["options"] == {"max_tokens": None}
async def test_mcp_tool_sampling_callback_no_response_and_successful_message_creation():
@@ -3704,14 +3707,19 @@ async def test_mcp_tool_filters_framework_kwargs():
# Invoke the tool with framework kwargs that should be filtered out
await func.invoke(
param="test_value",
response_format=MockResponseFormat, # Should be filtered
chat_options={"some": "option"}, # Should be filtered
tools=[Mock()], # Should be filtered
tool_choice="auto", # Should be filtered
session=Mock(), # Should be filtered
conversation_id="conv-123", # Should be filtered
options={"metadata": "value"}, # Should be filtered
context=FunctionInvocationContext(
function=func,
arguments={"param": "test_value"},
kwargs={
"response_format": MockResponseFormat, # Should be filtered
"chat_options": {"some": "option"}, # Should be filtered
"tools": [Mock()], # Should be filtered
"tool_choice": "auto", # Should be filtered
"session": Mock(), # Should be filtered
"conversation_id": "conv-123", # Should be filtered
"options": {"metadata": "value"}, # Should be filtered
},
),
)
# Verify call_tool was called with only the valid argument
@@ -789,9 +789,10 @@ class TestChatAgentFunctionMiddlewareWithTools:
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value"
async def test_run_kwargs_available_in_function_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
"""Test that kwargs passed directly to agent.run() appear in FunctionInvocationContext.kwargs,
including complex nested values like dicts."""
async def test_function_invocation_kwargs_available_in_function_middleware(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that function_invocation_kwargs appear in FunctionInvocationContext.kwargs."""
captured_kwargs: dict[str, Any] = {}
@function_middleware
@@ -822,18 +823,20 @@ class TestChatAgentFunctionMiddlewareWithTools:
session_metadata = {"tenant": "acme-corp", "region": "us-west"}
await agent.run(
[Message(role="user", text="Get weather")],
user_id="user-456",
session_metadata=session_metadata,
function_invocation_kwargs={
"user_id": "user-456",
"session_metadata": session_metadata,
},
)
assert "user_id" in captured_kwargs, f"Expected 'user_id' in kwargs: {captured_kwargs}"
assert captured_kwargs["user_id"] == "user-456"
assert captured_kwargs["session_metadata"] == {"tenant": "acme-corp", "region": "us-west"}
async def test_run_kwargs_merged_with_additional_function_arguments(
async def test_function_invocation_kwargs_merged_with_additional_function_arguments(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that explicit additional_function_arguments in options take precedence over run kwargs."""
"""Test that explicit additional_function_arguments in options take precedence."""
captured_kwargs: dict[str, Any] = {}
@function_middleware
@@ -863,9 +866,10 @@ class TestChatAgentFunctionMiddlewareWithTools:
await agent.run(
[Message(role="user", text="Get weather")],
# This kwarg should be overridden by additional_function_arguments
user_id="from-kwargs",
tenant_id="from-kwargs",
function_invocation_kwargs={
"user_id": "from-kwargs",
"tenant_id": "from-kwargs",
},
options={
"additional_function_arguments": {
"user_id": "from-options",
@@ -876,15 +880,15 @@ class TestChatAgentFunctionMiddlewareWithTools:
# additional_function_arguments takes precedence for overlapping keys
assert captured_kwargs["user_id"] == "from-options"
# Non-overlapping kwargs from run() still come through
# Non-overlapping function_invocation_kwargs still come through
assert captured_kwargs["tenant_id"] == "from-kwargs"
# Keys only in additional_function_arguments are present
assert captured_kwargs["extra_key"] == "only-in-options"
async def test_run_kwargs_consistent_across_multiple_tool_calls(
async def test_function_invocation_kwargs_consistent_across_multiple_tool_calls(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that kwargs are consistent across multiple tool invocations in a single run."""
"""Test that function_invocation_kwargs are consistent across tool invocations."""
invocation_kwargs: list[dict[str, Any]] = []
@function_middleware
@@ -917,8 +921,10 @@ class TestChatAgentFunctionMiddlewareWithTools:
await agent.run(
[Message(role="user", text="Get weather for both cities")],
user_id="user-456",
request_id="req-001",
function_invocation_kwargs={
"user_id": "user-456",
"request_id": "req-001",
},
)
assert len(invocation_kwargs) == 2
@@ -2060,23 +2066,21 @@ class TestChatAgentChatMiddleware:
"agent_middleware_after",
]
async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> None:
"""Test that agent middleware can access and override custom parameters like temperature."""
captured_kwargs: dict[str, Any] = {}
modified_kwargs: dict[str, Any] = {}
async def test_agent_middleware_can_access_and_override_options(self) -> None:
"""Test that agent middleware can access and override runtime options."""
captured_options: dict[str, Any] = {}
modified_options: dict[str, Any] = {}
@agent_middleware
async def kwargs_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture the original kwargs
captured_kwargs.update(context.kwargs)
assert isinstance(context.options, dict)
captured_options.update(context.options)
# Modify some kwargs
context.kwargs["temperature"] = 0.9
context.kwargs["max_tokens"] = 500
context.kwargs["new_param"] = "added_by_middleware"
context.options["temperature"] = 0.9
context.options["max_tokens"] = 500
context.options["new_param"] = "added_by_middleware"
# Store modified kwargs for verification
modified_kwargs.update(context.kwargs)
modified_options.update(context.options)
await call_next()
@@ -2084,24 +2088,25 @@ class TestChatAgentChatMiddleware:
client = MockBaseChatClient()
agent = Agent(client=client, middleware=[kwargs_middleware])
# Execute the agent with custom parameters
# Execute the agent with runtime options
messages = [Message(role="user", text="test message")]
response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value")
response = await agent.run(
messages,
options={"temperature": 0.7, "max_tokens": 100, "custom_param": "test_value"},
)
# Verify response
assert response is not None
assert len(response.messages) > 0
# Verify middleware captured the original kwargs
assert captured_kwargs["temperature"] == 0.7
assert captured_kwargs["max_tokens"] == 100
assert captured_kwargs["custom_param"] == "test_value"
assert captured_options["temperature"] == 0.7
assert captured_options["max_tokens"] == 100
assert captured_options["custom_param"] == "test_value"
# Verify middleware could modify the kwargs
assert modified_kwargs["temperature"] == 0.9
assert modified_kwargs["max_tokens"] == 500
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value" # Should still be there
assert modified_options["temperature"] == 0.9
assert modified_options["max_tokens"] == 500
assert modified_options["new_param"] == "added_by_middleware"
assert modified_options["custom_param"] == "test_value"
# class TestMiddlewareWithProtocolOnlyAgent:
@@ -2,6 +2,7 @@
from collections.abc import Awaitable, Callable
from typing import Any
from unittest.mock import patch
from agent_framework import (
Agent,
@@ -296,50 +297,77 @@ class TestChatMiddleware:
assert response3 is not None
assert execution_count["count"] == 2 # Should be 2 now
async def test_chat_client_middleware_can_access_and_override_custom_kwargs(
async def test_run_level_middleware_is_not_forwarded_to_inner_client(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that chat client middleware can access and override custom parameters like temperature."""
captured_kwargs: dict[str, Any] = {}
modified_kwargs: dict[str, Any] = {}
"""Test that run-level middleware stays in the middleware pipeline only."""
observed_context_kwargs: dict[str, Any] = {}
@chat_middleware
async def inspecting_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
observed_context_kwargs.update(context.kwargs)
await call_next()
async def fake_inner_get_response(**kwargs: Any) -> ChatResponse:
assert "middleware" not in kwargs
return ChatResponse(messages=[Message(role="assistant", text="ok")])
with patch.object(
chat_client_base,
"_inner_get_response",
side_effect=fake_inner_get_response,
) as mock_inner_get_response:
response = await chat_client_base.get_response(
[Message(role="user", text="hello")],
client_kwargs={"middleware": [inspecting_middleware], "trace_id": "trace-123"},
)
assert response.messages[0].text == "ok"
assert observed_context_kwargs == {"trace_id": "trace-123"}
mock_inner_get_response.assert_called_once()
async def test_chat_client_middleware_can_access_and_override_options(
self, chat_client_base: "MockBaseChatClient"
) -> None:
"""Test that chat client middleware can access and override runtime options."""
captured_options: dict[str, Any] = {}
modified_options: dict[str, Any] = {}
@chat_middleware
async def kwargs_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture the original kwargs
captured_kwargs.update(context.kwargs)
assert isinstance(context.options, dict)
captured_options.update(context.options)
# Modify some kwargs
context.kwargs["temperature"] = 0.9
context.kwargs["max_tokens"] = 500
context.kwargs["new_param"] = "added_by_middleware"
context.options["temperature"] = 0.9
context.options["max_tokens"] = 500
context.options["new_param"] = "added_by_middleware"
# Store modified kwargs for verification
modified_kwargs.update(context.kwargs)
modified_options.update(context.options)
await call_next()
# Add middleware to chat client
chat_client_base.chat_middleware = [kwargs_middleware]
# Execute chat client with custom parameters
# Execute chat client with runtime options
messages = [Message(role="user", text="test message")]
response = await chat_client_base.get_response(
messages, temperature=0.7, max_tokens=100, custom_param="test_value"
messages,
options={"temperature": 0.7, "max_tokens": 100, "custom_param": "test_value"},
)
# Verify response
assert response is not None
assert len(response.messages) > 0
assert captured_kwargs["temperature"] == 0.7
assert captured_kwargs["max_tokens"] == 100
assert captured_kwargs["custom_param"] == "test_value"
assert captured_options["temperature"] == 0.7
assert captured_options["max_tokens"] == 100
assert captured_options["custom_param"] == "test_value"
# Verify middleware could modify the kwargs
assert modified_kwargs["temperature"] == 0.9
assert modified_kwargs["max_tokens"] == 500
assert modified_kwargs["new_param"] == "added_by_middleware"
assert modified_kwargs["custom_param"] == "test_value" # Should still be there
assert modified_options["temperature"] == 0.9
assert modified_options["max_tokens"] == 500
assert modified_options["new_param"] == "added_by_middleware"
assert modified_options["custom_param"] == "test_value"
def test_chat_middleware_pipeline_cache_reuses_matching_middleware(
self,
@@ -207,7 +207,7 @@ async def test_chat_client_observability(mock_chat_client, span_exporter: InMemo
messages = [Message(role="user", text="Test message")]
span_exporter.clear()
response = await client.get_response(messages=messages, model_id="Test")
response = await client.get_response(messages=messages, options={"model_id": "Test"})
assert response is not None
spans = span_exporter.get_finished_spans()
assert len(spans) == 1
@@ -232,7 +232,7 @@ async def test_chat_client_streaming_observability(
span_exporter.clear()
# Collect all yielded updates
updates = []
stream = client.get_response(stream=True, messages=messages, model_id="Test")
stream = client.get_response(stream=True, messages=messages, options={"model_id": "Test"})
async for update in stream:
updates.append(update)
await stream.get_final_response()
@@ -1540,7 +1540,7 @@ async def test_chat_client_observability_exception(mock_chat_client, span_export
span_exporter.clear()
with pytest.raises(ValueError, match="Test error"):
await client.get_response(messages=messages, model_id="Test")
await client.get_response(messages=messages, options={"model_id": "Test"})
spans = span_exporter.get_finished_spans()
assert len(spans) == 1
@@ -1570,7 +1570,7 @@ async def test_chat_client_streaming_observability_exception(mock_chat_client, s
span_exporter.clear()
with pytest.raises(ValueError, match="Streaming error"):
async for _ in client.get_response(messages=messages, stream=True, model_id="Test"):
async for _ in client.get_response(messages=messages, stream=True, options={"model_id": "Test"}):
pass
spans = span_exporter.get_finished_spans()
@@ -2075,7 +2075,7 @@ async def test_capture_messages_with_finish_reason(mock_chat_client, span_export
messages = [Message(role="user", text="Test")]
span_exporter.clear()
response = await client.get_response(messages=messages, model_id="Test")
response = await client.get_response(messages=messages, options={"model_id": "Test"})
assert response is not None
assert response.finish_reason == "stop"
@@ -2165,7 +2165,7 @@ async def test_chat_client_when_disabled(mock_chat_client, span_exporter: InMemo
messages = [Message(role="user", text="Test")]
span_exporter.clear()
response = await client.get_response(messages=messages, model_id="Test")
response = await client.get_response(messages=messages, options={"model_id": "Test"})
assert response is not None
spans = span_exporter.get_finished_spans()
@@ -2181,7 +2181,7 @@ async def test_chat_client_streaming_when_disabled(mock_chat_client, span_export
span_exporter.clear()
updates = []
async for update in client.get_response(messages=messages, stream=True, model_id="Test"):
async for update in client.get_response(messages=messages, stream=True, options={"model_id": "Test"}):
updates.append(update)
assert len(updates) == 2 # Still works functionally
@@ -2661,7 +2661,7 @@ async def test_capture_messages_preserves_non_ascii_characters(mock_chat_client,
messages = [Message(role="user", text=japanese_text)]
span_exporter.clear()
response = await client.get_response(messages=messages, model_id="Test")
response = await client.get_response(messages=messages, options={"model_id": "Test"})
assert response is not None
spans = span_exporter.get_finished_spans()
+15 -20
View File
@@ -594,8 +594,8 @@ async def test_tool_invoke_telemetry_sensitive_disabled(span_exporter: InMemoryS
assert attributes[OtelAttr.TOOL_CALL_ID] == "test_call_id"
async def test_tool_invoke_ignores_additional_kwargs() -> None:
"""Ensure tools drop unknown kwargs when invoked with validated arguments."""
async def test_tool_invoke_rejects_unexpected_runtime_kwargs() -> None:
"""Ensure invoke() requires runtime data to flow through FunctionInvocationContext."""
@tool
async def simple_tool(message: str) -> str:
@@ -604,15 +604,12 @@ async def test_tool_invoke_ignores_additional_kwargs() -> None:
args = simple_tool.input_model(message="hello world")
# These kwargs simulate runtime context passed through function invocation.
result = await simple_tool.invoke(
arguments=args,
api_token="secret-token",
options={"model_id": "dummy"},
)
assert isinstance(result, list)
assert result[0].text == "HELLO WORLD"
with pytest.raises(TypeError, match="Unexpected keyword argument"):
await simple_tool.invoke(
arguments=args,
api_token="secret-token",
options={"model_id": "dummy"},
)
async def test_tool_invoke_telemetry_with_pydantic_args(span_exporter: InMemorySpanExporter):
@@ -917,8 +914,8 @@ def test_parse_inputs_unsupported_type():
# endregion
async def test_ai_function_with_kwargs_injection():
"""Test that ai_function correctly handles kwargs injection and hides them from schema."""
async def test_ai_function_with_kwargs_rejects_runtime_invoke_kwargs():
"""Test that runtime kwargs must be passed through FunctionInvocationContext."""
@tool
def tool_with_kwargs(x: int, **kwargs: Any) -> str:
@@ -937,13 +934,11 @@ async def test_ai_function_with_kwargs_injection():
# Verify direct invocation works
assert tool_with_kwargs(1, user_id="user1") == "x=1, user=user1"
# Verify invoke works with injected args
result = await tool_with_kwargs.invoke(
arguments=tool_with_kwargs.input_model(x=5),
user_id="user2",
)
assert isinstance(result, list)
assert result[0].text == "x=5, user=user2"
with pytest.raises(TypeError, match="Unexpected keyword argument"):
await tool_with_kwargs.invoke(
arguments=tool_with_kwargs.input_model(x=5),
user_id="user2",
)
# Verify invoke works without injected args (uses default)
result_default = await tool_with_kwargs.invoke(
@@ -446,7 +446,7 @@ async def executor_with_real_agent() -> tuple[AgentFrameworkExecutor, str, MockB
name="Test Chat Agent",
description="A real Agent for testing execution flow",
client=mock_client,
system_message="You are a helpful test assistant.",
instructions="You are a helpful test assistant.",
)
# Register the real agent
@@ -478,14 +478,14 @@ async def sequential_workflow() -> tuple[AgentFrameworkExecutor, str, MockBaseCh
name="Writer",
description="Content writer agent",
client=mock_client,
system_message="You are a content writer. Create clear, engaging content.",
instructions="You are a content writer. Create clear, engaging content.",
)
reviewer = Agent(
id="reviewer",
name="Reviewer",
description="Content reviewer agent",
client=mock_client,
system_message="You are a reviewer. Provide constructive feedback.",
instructions="You are a reviewer. Provide constructive feedback.",
)
workflow = SequentialBuilder(participants=[writer, reviewer]).build()
@@ -523,21 +523,21 @@ async def concurrent_workflow() -> tuple[AgentFrameworkExecutor, str, MockBaseCh
name="Researcher",
description="Research agent",
client=mock_client,
system_message="You are a researcher. Find key data and insights.",
instructions="You are a researcher. Find key data and insights.",
)
analyst = Agent(
id="analyst",
name="Analyst",
description="Analysis agent",
client=mock_client,
system_message="You are an analyst. Identify trends and patterns.",
instructions="You are an analyst. Identify trends and patterns.",
)
summarizer = Agent(
id="summarizer",
name="Summarizer",
description="Summary agent",
client=mock_client,
system_message="You are a summarizer. Provide concise summaries.",
instructions="You are a summarizer. Provide concise summaries.",
)
workflow = ConcurrentBuilder(participants=[researcher, analyst, summarizer]).build()
@@ -309,7 +309,7 @@ async def test_full_pipeline_workflow_events_are_json_serializable():
name="Serialization Test Agent",
description="Agent for testing serialization",
client=mock_client,
system_message="You are a test assistant.",
instructions="You are a test assistant.",
)
agent_executor = AgentExecutor(id="agent_node", agent=agent)
@@ -23,6 +23,7 @@ from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
from ._durable_agent_state import (
DurableAgentState,
DurableAgentStateEntry,
DurableAgentStateMessage,
DurableAgentStateRequest,
DurableAgentStateResponse,
)
@@ -151,10 +152,11 @@ class AgentEntity:
try:
chat_messages: list[Message] = [
m.to_chat_message()
replayable_message
for entry in self.state.data.conversation_history
if not self._is_error_response(entry)
for m in entry.messages
if (replayable_message := self._to_replayable_message(m)) is not None
]
run_kwargs: dict[str, Any] = {"messages": chat_messages, "options": options}
@@ -190,6 +192,21 @@ class AgentEntity:
return error_response
@staticmethod
def _to_replayable_message(message: DurableAgentStateMessage) -> Message | None:
"""Convert persisted history into a message safe to replay into chat clients."""
chat_message = message.to_chat_message()
replayable_contents = [content for content in chat_message.contents if content.type != "reasoning"]
if not replayable_contents:
return None
return Message(
role=chat_message.role,
contents=replayable_contents,
author_name=chat_message.author_name,
additional_properties=chat_message.additional_properties,
)
async def _invoke_agent(
self,
run_kwargs: dict[str, Any],
@@ -21,7 +21,9 @@ from agent_framework_durabletask import (
DurableAgentStateData,
DurableAgentStateMessage,
DurableAgentStateRequest,
DurableAgentStateResponse,
DurableAgentStateTextContent,
DurableAgentStateTextReasoningContent,
RunRequest,
)
from agent_framework_durabletask._entities import DurableTaskEntityStateProvider
@@ -391,6 +393,54 @@ class TestAgentEntityRunAgent:
assert len(history) == 6
assert entity.state.message_count == 6
async def test_run_filters_reasoning_content_from_replayed_history(self) -> None:
"""Replayed durable history should not include reasoning-only content items."""
captured_messages: list[Message] = []
async def mock_run(*args, stream=False, **kwargs):
if stream:
raise TypeError("streaming not supported")
captured_messages.extend(kwargs["messages"])
return _agent_response("Response")
mock_agent = Mock()
mock_agent.run = mock_run
entity = _make_entity(mock_agent)
entity.state.data = DurableAgentStateData(
conversation_history=[
DurableAgentStateRequest(
correlation_id="corr-entity-prev-request",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="user",
contents=[DurableAgentStateTextContent(text="Hi")],
)
],
),
DurableAgentStateResponse(
correlation_id="corr-entity-prev-response",
created_at=datetime.now(),
messages=[
DurableAgentStateMessage(
role="assistant",
contents=[
DurableAgentStateTextReasoningContent(text="Let me think."),
DurableAgentStateTextContent(text="Hello there."),
],
)
],
),
]
)
await entity.run({"message": "What next?", "correlationId": "corr-entity-replay"})
assert captured_messages
assert all(content.type != "reasoning" for message in captured_messages for content in message.contents)
assert [message.text for message in captured_messages] == ["Hi", "Hello there.", "What next?"]
class TestAgentEntityReset:
"""Test suite for the reset operation."""
@@ -27,6 +27,7 @@ from agent_framework import (
RawAgent,
load_settings,
)
from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
from agent_framework.observability import AgentTelemetryLayer, ChatTelemetryLayer
from agent_framework_openai._chat_client import OpenAIChatOptions, RawOpenAIChatClient
from azure.ai.projects.aio import AIProjectClient
@@ -125,9 +126,13 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
credential: AzureCredentialTypes | None = None,
project_client: AIProjectClient | None = None,
allow_preview: bool | None = None,
default_headers: Mapping[str, str] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
instruction_role: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
) -> None:
"""Initialize a raw Foundry Agent client.
@@ -141,9 +146,13 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
credential: Azure credential for authentication.
project_client: An existing AIProjectClient to use.
allow_preview: Enables preview opt-in on internally-created AIProjectClient.
default_headers: Additional HTTP headers for requests made through the OpenAI client.
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
kwargs: Additional keyword arguments.
instruction_role: The role to use for 'instruction' messages.
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
additional_properties: Additional properties stored on the client instance.
"""
settings = load_settings(
FoundryAgentSettings,
@@ -189,7 +198,14 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
# Get OpenAI client from project
async_client = self.project_client.get_openai_client()
super().__init__(async_client=async_client, **kwargs)
super().__init__(
async_client=async_client,
default_headers=default_headers,
instruction_role=instruction_role,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=additional_properties,
)
def _get_agent_reference(self) -> dict[str, str]:
"""Build the agent reference dict for the Responses API."""
@@ -210,7 +226,10 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
**kwargs: Any,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: Mapping[str, Any] | None = None,
) -> Agent[FoundryAgentOptionsT]:
"""Create a FoundryAgent that reuses this client's Foundry configuration."""
function_tools = cast(
@@ -233,7 +252,10 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
description=description,
instructions=instructions,
default_options=default_options,
**kwargs,
function_invocation_configuration=function_invocation_configuration,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=additional_properties,
),
)
@@ -365,11 +387,15 @@ class _FoundryAgentChatClient( # type: ignore[misc]
credential: AzureCredentialTypes | None = None,
project_client: AIProjectClient | None = None,
allow_preview: bool | None = None,
default_headers: Mapping[str, str] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
instruction_role: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
middleware: (Sequence[ChatAndFunctionMiddlewareTypes] | None) = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
**kwargs: Any,
) -> None:
"""Initialize a Foundry Agent client with full middleware support.
@@ -380,11 +406,15 @@ class _FoundryAgentChatClient( # type: ignore[misc]
credential: Azure credential for authentication.
project_client: An existing AIProjectClient to use.
allow_preview: Enables preview opt-in on internally-created AIProjectClient.
default_headers: Additional HTTP headers for requests made through the OpenAI client.
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
instruction_role: The role to use for 'instruction' messages.
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
additional_properties: Additional properties stored on the client instance.
middleware: Optional sequence of middleware.
function_invocation_configuration: Optional function invocation configuration.
kwargs: Additional keyword arguments.
"""
super().__init__(
project_endpoint=project_endpoint,
@@ -393,11 +423,15 @@ class _FoundryAgentChatClient( # type: ignore[misc]
credential=credential,
project_client=project_client,
allow_preview=allow_preview,
default_headers=default_headers,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
instruction_role=instruction_role,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=additional_properties,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
**kwargs,
)
@@ -435,10 +469,19 @@ class RawFoundryAgent( # type: ignore[misc]
allow_preview: bool | None = None,
tools: FunctionTool | Callable[..., Any] | Sequence[FunctionTool | Callable[..., Any]] | None = None,
context_providers: Sequence[BaseContextProvider] | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
client_type: type[RawFoundryAgentChatClient] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
id: str | None = None,
name: str | None = None,
description: str | None = None,
instructions: str | None = None,
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: Mapping[str, Any] | None = None,
) -> None:
"""Initialize a Foundry Agent.
@@ -454,11 +497,20 @@ class RawFoundryAgent( # type: ignore[misc]
allow_preview: Enables preview opt-in on internally-created AIProjectClient.
tools: Function tools to provide to the agent. Only ``FunctionTool`` objects are accepted.
context_providers: Optional context providers for injecting dynamic context.
middleware: Optional agent-level middleware.
client_type: Custom client class to use (must be a subclass of ``RawFoundryAgentChatClient``).
Defaults to ``_FoundryAgentChatClient`` (full client middleware).
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
kwargs: Additional keyword arguments passed to the Agent base class.
id: Optional local agent identifier.
name: Optional display name for the local agent wrapper.
description: Optional local description for the local agent wrapper.
instructions: Optional instructions for the local agent wrapper.
default_options: Default chat options for the local agent wrapper.
function_invocation_configuration: Optional function invocation configuration override.
compaction_strategy: Optional agent-level in-run compaction override.
tokenizer: Optional agent-level tokenizer override.
additional_properties: Additional properties stored on the local agent wrapper.
"""
# Create the client
actual_client_type = client_type or _FoundryAgentChatClient
@@ -467,22 +519,38 @@ class RawFoundryAgent( # type: ignore[misc]
f"client_type must be a subclass of RawFoundryAgentChatClient, got {actual_client_type.__name__}"
)
client = actual_client_type(
project_endpoint=project_endpoint,
agent_name=agent_name,
agent_version=agent_version,
credential=credential,
project_client=project_client,
allow_preview=allow_preview,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
client_kwargs: dict[str, Any] = {
"project_endpoint": project_endpoint,
"agent_name": agent_name,
"agent_version": agent_version,
"credential": credential,
"project_client": project_client,
"allow_preview": allow_preview,
"env_file_path": env_file_path,
"env_file_encoding": env_file_encoding,
}
if function_invocation_configuration is not None:
if not issubclass(actual_client_type, FunctionInvocationLayer):
raise TypeError(
"function_invocation_configuration requires a FunctionInvocationLayer-based client_type."
)
client_kwargs["function_invocation_configuration"] = function_invocation_configuration
client = actual_client_type(**client_kwargs)
super().__init__(
client=client, # type: ignore[arg-type]
instructions=instructions,
id=id,
name=name,
description=description,
tools=tools, # type: ignore[arg-type]
default_options=cast(FoundryAgentOptionsT | None, default_options),
context_providers=context_providers,
**kwargs,
middleware=middleware,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=dict(additional_properties) if additional_properties is not None else None,
)
async def configure_azure_monitor(
@@ -598,7 +666,15 @@ class FoundryAgent( # type: ignore[misc]
client_type: type[RawFoundryAgentChatClient] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
id: str | None = None,
name: str | None = None,
description: str | None = None,
instructions: str | None = None,
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: Mapping[str, Any] | None = None,
) -> None:
"""Initialize a Foundry Agent with full middleware and telemetry.
@@ -615,7 +691,15 @@ class FoundryAgent( # type: ignore[misc]
client_type: Custom client class (must subclass ``RawFoundryAgentChatClient``).
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
kwargs: Additional keyword arguments.
id: Optional local agent identifier.
name: Optional display name for the local agent wrapper.
description: Optional local description for the local agent wrapper.
instructions: Optional instructions for the local agent wrapper.
default_options: Default chat options for the local agent wrapper.
function_invocation_configuration: Optional function invocation configuration override.
compaction_strategy: Optional agent-level in-run compaction override.
tokenizer: Optional agent-level tokenizer override.
additional_properties: Additional properties stored on the local agent wrapper.
"""
super().__init__(
project_endpoint=project_endpoint,
@@ -630,5 +714,13 @@ class FoundryAgent( # type: ignore[misc]
client_type=client_type,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
**kwargs,
id=id,
name=name,
description=description,
instructions=instructions,
default_options=default_options,
function_invocation_configuration=function_invocation_configuration,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=additional_properties,
)
@@ -4,7 +4,7 @@ from __future__ import annotations
import logging
import sys
from collections.abc import Awaitable, Callable, Sequence
from collections.abc import Awaitable, Callable, Mapping, Sequence
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal
from agent_framework import (
@@ -15,6 +15,7 @@ from agent_framework import (
FunctionInvocationLayer,
load_settings,
)
from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
from agent_framework.observability import ChatTelemetryLayer
from agent_framework_openai._chat_client import OpenAIChatOptions, RawOpenAIChatClient
from azure.ai.projects.aio import AIProjectClient
@@ -132,10 +133,13 @@ class RawFoundryChatClient( # type: ignore[misc]
model: str | None = None,
credential: AzureCredentialTypes | AzureTokenProvider | None = None,
allow_preview: bool | None = None,
default_headers: Mapping[str, str] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
instruction_role: str | None = None,
**kwargs: Any,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
) -> None:
"""Initialize a raw Microsoft Foundry chat client.
@@ -149,10 +153,13 @@ class RawFoundryChatClient( # type: ignore[misc]
credential: Azure credential or token provider for authentication.
Required when using ``project_endpoint`` without a ``project_client``.
allow_preview: Enables preview opt-in on internally-created AIProjectClient.
default_headers: Additional HTTP headers for requests made through the OpenAI client.
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
instruction_role: The role to use for 'instruction' messages.
kwargs: Additional keyword arguments.
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
additional_properties: Additional properties stored on the client instance.
"""
foundry_settings = load_settings(
FoundrySettings,
@@ -195,8 +202,11 @@ class RawFoundryChatClient( # type: ignore[misc]
super().__init__(
model=resolved_model,
async_client=project_client.get_openai_client(),
default_headers=default_headers,
instruction_role=instruction_role,
**kwargs,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=additional_properties,
)
self.project_client = project_client
@@ -516,12 +526,15 @@ class FoundryChatClient( # type: ignore[misc]
model: str | None = None,
credential: AzureCredentialTypes | AzureTokenProvider | None = None,
allow_preview: bool | None = None,
default_headers: Mapping[str, str] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
instruction_role: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
middleware: (Sequence[ChatAndFunctionMiddlewareTypes] | None) = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
**kwargs: Any,
) -> None:
"""Initialize a Foundry chat client.
@@ -533,12 +546,15 @@ class FoundryChatClient( # type: ignore[misc]
Can also be set via environment variable ``FOUNDRY_MODEL``.
credential: Azure credential or token provider for authentication.
allow_preview: Enables preview opt-in on internally-created AIProjectClient.
default_headers: Additional HTTP headers for requests made through the OpenAI client.
env_file_path: Path to .env file for settings.
env_file_encoding: Encoding for .env file.
instruction_role: The role to use for 'instruction' messages.
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
additional_properties: Additional properties stored on the client instance.
middleware: Optional sequence of middleware.
function_invocation_configuration: Optional function invocation configuration.
kwargs: Additional keyword arguments.
"""
super().__init__(
project_endpoint=project_endpoint,
@@ -546,10 +562,13 @@ class FoundryChatClient( # type: ignore[misc]
model=model,
credential=credential,
allow_preview=allow_preview,
default_headers=default_headers,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
instruction_role=instruction_role,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=additional_properties,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
**kwargs,
)
@@ -2,6 +2,7 @@
from __future__ import annotations
import inspect
import os
import sys
from typing import Any
@@ -68,6 +69,17 @@ def test_raw_foundry_agent_chat_client_init_with_agent_name() -> None:
assert client.agent_version == "1.0"
def test_raw_foundry_agent_chat_client_init_uses_explicit_parameters() -> None:
signature = inspect.signature(RawFoundryAgentChatClient.__init__)
assert "default_headers" in signature.parameters
assert "instruction_role" in signature.parameters
assert "compaction_strategy" in signature.parameters
assert "tokenizer" in signature.parameters
assert "additional_properties" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_raw_foundry_agent_chat_client_get_agent_reference_with_version() -> None:
"""Test agent reference includes version when provided."""
@@ -129,6 +141,15 @@ def test_raw_foundry_agent_chat_client_as_agent_preserves_client_type() -> None:
assert named_agent.client.agent_name == "test-agent"
def test_raw_foundry_agent_chat_client_as_agent_uses_explicit_parameters() -> None:
signature = inspect.signature(RawFoundryAgentChatClient.as_agent)
assert "compaction_strategy" in signature.parameters
assert "tokenizer" in signature.parameters
assert "additional_properties" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
async def test_raw_foundry_agent_chat_client_prepare_options_validates_tools() -> None:
"""Test that _prepare_options rejects non-FunctionTool objects."""
@@ -210,6 +231,17 @@ def test_foundry_agent_chat_client_init() -> None:
assert client.agent_name == "test-agent"
def test_foundry_agent_chat_client_init_uses_explicit_parameters() -> None:
signature = inspect.signature(_FoundryAgentChatClient.__init__)
assert "default_headers" in signature.parameters
assert "instruction_role" in signature.parameters
assert "compaction_strategy" in signature.parameters
assert "tokenizer" in signature.parameters
assert "additional_properties" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_raw_foundry_agent_init_creates_client() -> None:
"""Test that RawFoundryAgent creates a client internally."""
@@ -241,6 +273,28 @@ def test_raw_foundry_agent_init_with_custom_client_type() -> None:
assert isinstance(agent.client, RawFoundryAgentChatClient)
def test_raw_foundry_agent_init_uses_explicit_parameters() -> None:
signature = inspect.signature(RawFoundryAgent.__init__)
assert "instructions" in signature.parameters
assert "default_options" in signature.parameters
assert "compaction_strategy" in signature.parameters
assert "tokenizer" in signature.parameters
assert "additional_properties" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_foundry_agent_init_uses_explicit_parameters() -> None:
signature = inspect.signature(FoundryAgent.__init__)
assert "instructions" in signature.parameters
assert "default_options" in signature.parameters
assert "compaction_strategy" in signature.parameters
assert "tokenizer" in signature.parameters
assert "additional_properties" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_raw_foundry_agent_init_rejects_invalid_client_type() -> None:
"""Test that invalid client_type raises TypeError."""
@@ -2,6 +2,7 @@
from __future__ import annotations
import inspect
import json
import os
import sys
@@ -140,6 +141,26 @@ def test_init() -> None:
assert client.project_client is mock_project_client
def test_raw_foundry_chat_client_init_uses_explicit_parameters() -> None:
signature = inspect.signature(RawFoundryChatClient.__init__)
assert "default_headers" in signature.parameters
assert "compaction_strategy" in signature.parameters
assert "tokenizer" in signature.parameters
assert "additional_properties" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_foundry_chat_client_init_uses_explicit_parameters() -> None:
signature = inspect.signature(FoundryChatClient.__init__)
assert "default_headers" in signature.parameters
assert "compaction_strategy" in signature.parameters
assert "tokenizer" in signature.parameters
assert "additional_properties" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_init_with_default_header() -> None:
default_headers = {"X-Unit-Test": "test-guid"}
mock_openai_client = _make_mock_openai_client()
@@ -3,15 +3,21 @@
from __future__ import annotations
import sys
from collections.abc import Sequence
from typing import Any, Generic
from collections.abc import Awaitable, Callable, Mapping, Sequence
from typing import Any, Generic, Literal, cast, overload
from agent_framework import (
ChatAndFunctionMiddlewareTypes,
ChatMiddlewareLayer,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
CompactionStrategy,
FunctionInvocationConfiguration,
FunctionInvocationLayer,
Message,
ResponseStream,
TokenizerProtocol,
)
from agent_framework._settings import load_settings
from agent_framework.observability import ChatTelemetryLayer
@@ -122,8 +128,8 @@ class FoundryLocalSettings(TypedDict, total=False):
'FOUNDRY_LOCAL_'.
Keys:
model_id: The name of the model deployment to use.
(Env var FOUNDRY_LOCAL_MODEL_ID)
model: The name of the model deployment to use.
(Env var FOUNDRY_LOCAL_MODEL)
"""
model: str | None
@@ -138,6 +144,78 @@ class FoundryLocalClient(
):
"""Foundry Local Chat completion class with middleware, telemetry, and function invocation support."""
@overload
def get_response(
self,
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
) -> Awaitable[ChatResponse[ResponseModelT]]: ...
@overload
def get_response(
self,
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: FoundryLocalChatOptionsT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
) -> Awaitable[ChatResponse[Any]]: ...
@overload
def get_response(
self,
messages: Sequence[Message],
*,
stream: Literal[True],
options: FoundryLocalChatOptionsT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
def get_response(
self,
messages: Sequence[Message],
*,
stream: bool = False,
options: FoundryLocalChatOptionsT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Get a response from the Foundry Local chat client with all standard layers enabled."""
super_get_response = cast(
"Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]",
super().get_response,
)
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
if middleware is not None:
effective_client_kwargs["middleware"] = middleware
return super_get_response(
messages=messages,
stream=stream,
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=effective_client_kwargs,
)
def __init__(
self,
model: str | None = None,
@@ -182,7 +260,7 @@ class FoundryLocalClient(
# Create a FoundryLocalClient with a specific model ID:
from agent_framework.foundry import FoundryLocalClient
client = FoundryLocalClient(model_id="phi-4-mini")
client = FoundryLocalClient(model="phi-4-mini")
agent = client.as_agent(
name="LocalAgent",
@@ -192,7 +270,7 @@ class FoundryLocalClient(
response = await agent.run("What's the weather like in Seattle?")
# Or you can set the model id in the environment:
os.environ["FOUNDRY_LOCAL_MODEL_ID"] = "phi-4-mini"
os.environ["FOUNDRY_LOCAL_MODEL"] = "phi-4-mini"
client = FoundryLocalClient()
# A FoundryLocalManager is created and if set, the service is started.
@@ -205,12 +283,12 @@ class FoundryLocalClient(
from foundry_local.models import DeviceType
client = FoundryLocalClient(
model_id="phi-4-mini",
model="phi-4-mini",
device=DeviceType.GPU,
)
# and choosing if the model should be prepared on initialization:
client = FoundryLocalClient(
model_id="phi-4-mini",
model="phi-4-mini",
prepare_model=False,
)
# Beware, in this case the first request to generate a completion
@@ -230,7 +308,7 @@ class FoundryLocalClient(
class MyOptions(FoundryLocalChatOptions, total=False):
my_custom_option: str
client: FoundryLocalClient[MyOptions] = FoundryLocalClient(model_id="phi-4-mini")
client: FoundryLocalClient[MyOptions] = FoundryLocalClient(model="phi-4-mini")
response = await client.get_response("Hello", options={"my_custom_option": "value"})
Raises:
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
import inspect
from unittest.mock import MagicMock, patch
import pytest
@@ -66,6 +67,15 @@ def test_foundry_local_client_init(mock_foundry_local_manager: MagicMock) -> Non
assert isinstance(client, SupportsChatGetResponse)
def test_foundry_local_client_get_response_uses_explicit_runtime_buckets() -> None:
"""Foundry Local should expose explicit runtime buckets instead of raw kwargs."""
signature = inspect.signature(FoundryLocalClient.get_response)
assert "client_kwargs" in signature.parameters
assert "function_invocation_kwargs" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_foundry_local_client_init_with_bootstrap_false(mock_foundry_local_manager: MagicMock) -> None:
"""Test FoundryLocalClient initialization with bootstrap=False."""
with patch(
@@ -211,7 +211,7 @@ class TaskRunner:
client=assistant_chat_client,
instructions=assistant_system_prompt,
tools=tools,
temperature=self.assistant_sampling_temperature,
default_options={"temperature": self.assistant_sampling_temperature},
context_providers=[
SlidingWindowHistoryProvider(
system_message=assistant_system_prompt,
@@ -246,7 +246,7 @@ class TaskRunner:
return Agent(
client=user_simuator_chat_client,
instructions=user_sim_system_prompt,
temperature=0.0,
default_options={"temperature": 0.0},
# No sliding window for user simulator to maintain full conversation context
# TODO(yuge): Consider adding user tools in future for more realistic scenarios
)
@@ -17,7 +17,7 @@ else:
from ._assistant_provider import OpenAIAssistantProvider
from ._assistants_client import (
AssistantToolResources,
OpenAIAssistantsClient,
OpenAIAssistantsClient, # type: ignore[reportDeprecated]
OpenAIAssistantsOptions,
)
from ._chat_client import (
@@ -15,7 +15,7 @@ from openai import AsyncOpenAI
from openai.types.beta.assistant import Assistant
from pydantic import BaseModel
from ._assistants_client import OpenAIAssistantsClient
from ._assistants_client import OpenAIAssistantsClient # type: ignore[reportDeprecated]
from ._shared import OpenAISettings, from_assistant_tools, to_assistant_tools
if TYPE_CHECKING:
@@ -538,7 +538,7 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
A configured Agent instance.
"""
# Create the chat client with the assistant
client = OpenAIAssistantsClient(
client = OpenAIAssistantsClient( # type: ignore[reportDeprecated]
model=assistant.model,
assistant_id=assistant.id,
assistant_name=assistant.name,
@@ -70,6 +70,11 @@ if sys.version_info >= (3, 12):
else:
from typing_extensions import override # type: ignore # pragma: no cover
if sys.version_info >= (3, 13):
from warnings import deprecated # type: ignore # pragma: no cover
else:
from typing_extensions import deprecated # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self, TypedDict # type: ignore # pragma: no cover
else:
@@ -208,6 +213,7 @@ OpenAIAssistantsOptionsT = TypeVar(
# endregion
@deprecated("OpenAIAssistantsClient is deprecated. Use OpenAIChatClient instead.")
class OpenAIAssistantsClient( # type: ignore[misc]
OpenAIConfigMixin,
FunctionInvocationLayer[OpenAIAssistantsOptionsT],
@@ -216,7 +222,11 @@ class OpenAIAssistantsClient( # type: ignore[misc]
BaseChatClient[OpenAIAssistantsOptionsT],
Generic[OpenAIAssistantsOptionsT],
):
"""OpenAI Assistants client with middleware, telemetry, and function invocation support."""
"""OpenAI Assistants client with middleware, telemetry, and function invocation support.
.. deprecated::
OpenAIAssistantsClient is deprecated. Use :class:`OpenAIChatClient` instead.
"""
# region Hosted Tool Factory Methods
@@ -29,6 +29,7 @@ from typing import (
)
from agent_framework._clients import BaseChatClient
from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
from agent_framework._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
from agent_framework._settings import SecretString
from agent_framework._telemetry import USER_AGENT_KEY
@@ -278,6 +279,9 @@ class RawOpenAIChatClient( # type: ignore[misc]
default_headers: Mapping[str, str] | None = None,
async_client: AsyncOpenAI | None = None,
instruction_role: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
@@ -295,6 +299,9 @@ class RawOpenAIChatClient( # type: ignore[misc]
default_headers: Additional HTTP headers.
async_client: Pre-configured OpenAI client.
instruction_role: Role for instruction messages (for example ``"system"``).
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
additional_properties: Additional properties stored on the client instance.
env_file_path: Optional ``.env`` file that is checked before the process environment
for ``OPENAI_*`` values.
env_file_encoding: Encoding for the ``.env`` file.
@@ -314,6 +321,9 @@ class RawOpenAIChatClient( # type: ignore[misc]
default_headers: Mapping[str, str] | None = None,
async_client: AsyncAzureOpenAI | AsyncOpenAI | None = None,
instruction_role: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
@@ -338,6 +348,9 @@ class RawOpenAIChatClient( # type: ignore[misc]
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
instruction_role: Role for instruction messages (for example ``"system"``).
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
additional_properties: Additional properties stored on the client instance.
env_file_path: Optional ``.env`` file that is checked before process environment
variables for ``AZURE_OPENAI_*`` values.
env_file_encoding: Encoding for the ``.env`` file.
@@ -358,9 +371,11 @@ class RawOpenAIChatClient( # type: ignore[misc]
default_headers: Mapping[str, str] | None = None,
async_client: AsyncOpenAI | None = None,
instruction_role: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a raw OpenAI Chat client.
@@ -391,11 +406,13 @@ class RawOpenAIChatClient( # type: ignore[misc]
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
instruction_role: Role for instruction messages (for example ``"system"``).
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
additional_properties: Additional properties stored on the client instance.
env_file_path: Optional ``.env`` file that is checked before process environment
variables. The same file is used for both ``OPENAI_*`` and ``AZURE_OPENAI_*``
lookups.
env_file_encoding: Encoding for the ``.env`` file.
kwargs: Additional keyword arguments forwarded to ``BaseChatClient``.
Notes:
Environment resolution and routing precedence are:
@@ -452,7 +469,11 @@ class RawOpenAIChatClient( # type: ignore[misc]
if use_azure_client:
self.OTEL_PROVIDER_NAME = "azure.ai.openai" # type: ignore[misc]
super().__init__(**kwargs)
super().__init__(
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=additional_properties,
)
# region Inner Methods
@@ -460,7 +481,6 @@ class RawOpenAIChatClient( # type: ignore[misc]
self,
messages: Sequence[Message],
options: Mapping[str, Any],
**kwargs: Any,
) -> tuple[AsyncOpenAI, dict[str, Any], dict[str, Any]]:
"""Validate options and prepare the request.
@@ -469,7 +489,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
"""
client = self.client
validated_options = await self._validate_options(options)
run_options = await self._prepare_options(messages, validated_options, **kwargs)
run_options = await self._prepare_options(messages, validated_options)
return client, run_options, validated_options
def _handle_request_error(self, ex: Exception) -> NoReturn:
@@ -526,7 +546,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
client,
run_options,
validated_options,
) = await self._prepare_request(messages, options, **kwargs)
) = await self._prepare_request(messages, options)
try:
if "text_format" in run_options:
async with client.responses.stream(**run_options) as response:
@@ -560,7 +580,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
except Exception as ex:
self._handle_request_error(ex)
return self._parse_response_from_openai(response, options=validated_options)
client, run_options, validated_options = await self._prepare_request(messages, options, **kwargs)
client, run_options, validated_options = await self._prepare_request(messages, options)
try:
if "text_format" in run_options:
response = await client.responses.parse(stream=False, **run_options)
@@ -1121,7 +1141,6 @@ class RawOpenAIChatClient( # type: ignore[misc]
self,
messages: Sequence[Message],
options: Mapping[str, Any],
**kwargs: Any,
) -> dict[str, Any]:
"""Take options dict and create the specific options for Responses API."""
# Exclude keys that are not supported or handled separately
@@ -1143,7 +1162,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
# messages
# Handle instructions by prepending to messages as system message
# Only prepend instructions for the first turn (when no conversation/response ID exists)
conversation_id = self._get_current_conversation_id(options, **kwargs)
conversation_id = options.get("conversation_id")
if (instructions := options.get("instructions")) and not conversation_id:
# First turn: prepend instructions as system message
messages = prepend_instructions_to_messages(list(messages), instructions, role="system")
@@ -1151,7 +1170,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
request_input = self._prepare_messages_for_openai(messages)
if not request_input:
raise ChatClientInvalidRequestException("Messages are required for chat completions")
conversation_id = self._get_current_conversation_id(options, **kwargs)
conversation_id = options.get("conversation_id")
run_options["input"] = request_input
# model id
@@ -1169,7 +1188,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
run_options[new_key] = run_options.pop(old_key)
# Handle different conversation ID formats
if conversation_id := self._get_current_conversation_id(options, **kwargs):
if conversation_id := options.get("conversation_id"):
if conversation_id.startswith("resp_"):
# For response IDs, set previous_response_id and remove conversation property
run_options["previous_response_id"] = conversation_id
@@ -1223,14 +1242,6 @@ class RawOpenAIChatClient( # type: ignore[misc]
raise ValueError("model must be a non-empty string")
options["model"] = self.model
def _get_current_conversation_id(self, options: Mapping[str, Any], **kwargs: Any) -> str | None:
"""Get the current conversation ID, preferring kwargs over options.
This ensures runtime-updated conversation IDs (for example, from tool execution
loops) take precedence over the initial configuration provided in options.
"""
return kwargs.get("conversation_id") or options.get("conversation_id")
def _prepare_messages_for_openai(self, chat_messages: Sequence[Message]) -> list[dict[str, Any]]:
"""Prepare the chat messages for a request.
@@ -2490,10 +2501,13 @@ class OpenAIChatClient( # type: ignore[misc]
default_headers: Mapping[str, str] | None = None,
async_client: AsyncOpenAI | None = None,
instruction_role: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
"""Initialize an OpenAI Responses client.
@@ -2509,11 +2523,14 @@ class OpenAIChatClient( # type: ignore[misc]
default_headers: Additional HTTP headers.
async_client: Pre-configured OpenAI client.
instruction_role: Role for instruction messages (for example ``"system"``).
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
middleware: Optional middleware to apply to the client.
function_invocation_configuration: Optional function invocation configuration override.
additional_properties: Optional additional properties to include on all requests.
env_file_path: Optional ``.env`` file that is checked before the process environment
for ``OPENAI_*`` values.
env_file_encoding: Encoding for the ``.env`` file.
middleware: Optional middleware to apply to the client.
function_invocation_configuration: Optional function invocation configuration override.
"""
...
@@ -2530,10 +2547,13 @@ class OpenAIChatClient( # type: ignore[misc]
default_headers: Mapping[str, str] | None = None,
async_client: AsyncAzureOpenAI | AsyncOpenAI | None = None,
instruction_role: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
"""Initialize an OpenAI Responses client.
@@ -2556,11 +2576,14 @@ class OpenAIChatClient( # type: ignore[misc]
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
instruction_role: Role for instruction messages (for example ``"system"``).
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
middleware: Optional middleware to apply to the client.
function_invocation_configuration: Optional function invocation configuration override.
additional_properties: Optional additional properties to include on all requests.
env_file_path: Optional ``.env`` file that is checked before process environment
variables for ``AZURE_OPENAI_*`` values.
env_file_encoding: Encoding for the ``.env`` file.
middleware: Optional middleware to apply to the client.
function_invocation_configuration: Optional function invocation configuration override.
"""
...
@@ -2577,11 +2600,13 @@ class OpenAIChatClient( # type: ignore[misc]
default_headers: Mapping[str, str] | None = None,
async_client: AsyncOpenAI | None = None,
instruction_role: str | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
**kwargs: Any,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
"""Initialize an OpenAI Responses client.
@@ -2611,13 +2636,15 @@ class OpenAIChatClient( # type: ignore[misc]
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
instruction_role: Role to use for instruction messages (for example ``"system"``).
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
middleware: Optional middleware to apply to the client.
function_invocation_configuration: Optional function invocation configuration override.
additional_properties: Additional properties stored on the client instance.
env_file_path: Optional ``.env`` file that is checked before process environment
variables. The same file is used for both ``OPENAI_*`` and ``AZURE_OPENAI_*``
lookups.
env_file_encoding: Encoding for the ``.env`` file.
middleware: Optional middleware to apply to the client.
function_invocation_configuration: Optional function invocation configuration override.
kwargs: Other keyword parameters.
Notes:
Environment resolution and routing precedence are:
@@ -2675,7 +2702,9 @@ class OpenAIChatClient( # type: ignore[misc]
env_file_encoding=env_file_encoding,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
**kwargs,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=additional_properties,
)
@@ -18,6 +18,7 @@ from itertools import chain
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, cast, overload
from agent_framework._clients import BaseChatClient
from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
from agent_framework._docstrings import apply_layered_docstring
from agent_framework._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
from agent_framework._settings import SecretString
@@ -193,6 +194,9 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
default_headers: Mapping[str, str] | None = None,
async_client: AsyncOpenAI | None = None,
instruction_role: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
@@ -210,6 +214,9 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
default_headers: Additional HTTP headers.
async_client: Pre-configured OpenAI client.
instruction_role: Role for instruction messages (for example ``"system"``).
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
additional_properties: Additional properties stored on the client instance.
env_file_path: Optional ``.env`` file that is checked before the process environment
for ``OPENAI_*`` values.
env_file_encoding: Encoding for the ``.env`` file.
@@ -229,6 +236,9 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
default_headers: Mapping[str, str] | None = None,
async_client: AsyncAzureOpenAI | AsyncOpenAI | None = None,
instruction_role: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
@@ -253,6 +263,9 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
instruction_role: Role for instruction messages (for example ``"system"``).
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
additional_properties: Additional properties stored on the client instance.
env_file_path: Optional ``.env`` file that is checked before process environment
variables for ``AZURE_OPENAI_*`` values.
env_file_encoding: Encoding for the ``.env`` file.
@@ -273,9 +286,11 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
default_headers: Mapping[str, str] | None = None,
async_client: AsyncOpenAI | None = None,
instruction_role: str | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a raw OpenAI Chat completion client.
@@ -306,11 +321,13 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
instruction_role: Role for instruction messages (for example ``"system"``).
compaction_strategy: Optional per-client compaction override.
tokenizer: Optional tokenizer for compaction strategies.
additional_properties: Additional properties stored on the client instance.
env_file_path: Optional ``.env`` file that is checked before process environment
variables. The same file is used for both ``OPENAI_*`` and ``AZURE_OPENAI_*``
lookups.
env_file_encoding: Encoding for the ``.env`` file.
kwargs: Additional keyword arguments forwarded to ``BaseChatClient``.
Notes:
Environment resolution and routing precedence are:
@@ -366,7 +383,11 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
if use_azure_client:
self.OTEL_PROVIDER_NAME = "azure.ai.openai" # type: ignore[misc]
super().__init__(**kwargs)
super().__init__(
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
additional_properties=additional_properties,
)
# region Hosted Tool Factory Methods
@@ -427,7 +448,10 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
**kwargs: Any,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@overload
@@ -437,7 +461,10 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
*,
stream: Literal[False] = ...,
options: OpenAIChatCompletionOptionsT | ChatOptions[None] | None = None,
**kwargs: Any,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[Any]]: ...
@overload
@@ -447,7 +474,10 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
*,
stream: Literal[True],
options: OpenAIChatCompletionOptionsT | ChatOptions[Any] | None = None,
**kwargs: Any,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@override
@@ -457,7 +487,10 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
*,
stream: bool = False,
options: OpenAIChatCompletionOptionsT | ChatOptions[Any] | None = None,
**kwargs: Any,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Get a response from the raw OpenAI chat client."""
super_get_response = cast(
@@ -468,7 +501,10 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
messages=messages,
stream=stream,
options=options,
**kwargs,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
@override
@@ -1205,10 +1241,11 @@ class OpenAIChatCompletionClient( # type: ignore[misc]
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@overload
@@ -1218,10 +1255,11 @@ class OpenAIChatCompletionClient( # type: ignore[misc]
*,
stream: Literal[False] = ...,
options: OpenAIChatCompletionOptionsT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@overload
@@ -1231,10 +1269,11 @@ class OpenAIChatCompletionClient( # type: ignore[misc]
*,
stream: Literal[True],
options: OpenAIChatCompletionOptionsT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@override
@@ -1244,10 +1283,11 @@ class OpenAIChatCompletionClient( # type: ignore[misc]
*,
stream: bool = False,
options: OpenAIChatCompletionOptionsT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Get a response from the OpenAI chat client with all standard layers enabled."""
super_get_response = cast(
@@ -1261,9 +1301,10 @@ class OpenAIChatCompletionClient( # type: ignore[misc]
messages=messages,
stream=stream,
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=effective_client_kwargs,
**kwargs,
)
@@ -79,6 +79,7 @@ class RawOpenAIEmbeddingClient(
base_url: str | None = None,
default_headers: Mapping[str, str] | None = None,
async_client: AsyncOpenAI | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
@@ -95,6 +96,7 @@ class RawOpenAIEmbeddingClient(
``OPENAI_BASE_URL``.
default_headers: Additional HTTP headers.
async_client: Pre-configured OpenAI client.
additional_properties: Additional properties stored on the client instance.
env_file_path: Optional ``.env`` file that is checked before the process environment
for ``OPENAI_*`` values.
env_file_encoding: Encoding for the ``.env`` file.
@@ -113,6 +115,7 @@ class RawOpenAIEmbeddingClient(
base_url: str | None = None,
default_headers: Mapping[str, str] | None = None,
async_client: AsyncAzureOpenAI | AsyncOpenAI | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
) -> None:
@@ -136,6 +139,7 @@ class RawOpenAIEmbeddingClient(
default_headers: Additional HTTP headers.
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI.
additional_properties: Additional properties stored on the client instance.
env_file_path: Optional ``.env`` file that is checked before process environment
variables for ``AZURE_OPENAI_*`` values.
env_file_encoding: Encoding for the ``.env`` file.
@@ -155,9 +159,9 @@ class RawOpenAIEmbeddingClient(
api_version: str | None = None,
default_headers: Mapping[str, str] | None = None,
async_client: AsyncAzureOpenAI | AsyncOpenAI | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a raw OpenAI embedding client.
@@ -187,11 +191,11 @@ class RawOpenAIEmbeddingClient(
default_headers: Additional HTTP headers.
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI.
additional_properties: Additional properties stored on the client instance.
env_file_path: Optional ``.env`` file that is checked before process environment
variables. The same file is used for both ``OPENAI_*`` and ``AZURE_OPENAI_*``
lookups.
env_file_encoding: Encoding for the ``.env`` file.
kwargs: Additional keyword arguments forwarded to ``BaseEmbeddingClient``.
Notes:
Environment resolution precedence is:
@@ -247,7 +251,7 @@ class RawOpenAIEmbeddingClient(
if use_azure_client:
self.OTEL_PROVIDER_NAME = "azure.ai.openai" # type: ignore[misc]
super().__init__(**kwargs)
super().__init__(additional_properties=additional_properties)
def service_url(self) -> str:
"""Get the URL of the service."""
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
import inspect
import json
import logging
from typing import Annotated, Any
@@ -11,6 +12,11 @@ from agent_framework import (
Content,
Message,
SupportsChatGetResponse,
SupportsCodeInterpreterTool,
SupportsFileSearchTool,
SupportsImageGenerationTool,
SupportsMCPTool,
SupportsWebSearchTool,
tool,
)
from openai.types.beta.threads import (
@@ -30,6 +36,8 @@ from pydantic import Field
from agent_framework_openai import OpenAIAssistantsClient
pytestmark = pytest.mark.filterwarnings("ignore:OpenAIAssistantsClient is deprecated\\..*:DeprecationWarning")
def create_test_openai_assistants_client(
mock_async_openai: MagicMock,
@@ -104,6 +112,25 @@ def mock_async_openai() -> MagicMock:
return mock_client
def test_openai_assistants_client_is_deprecated(mock_async_openai: MagicMock) -> None:
with pytest.warns(DeprecationWarning, match="OpenAIAssistantsClient is deprecated. Use OpenAIChatClient instead."):
OpenAIAssistantsClient(model="gpt-4", api_key="test-api-key", async_client=mock_async_openai)
def test_openai_assistants_client_init_keeps_var_keyword() -> None:
signature = inspect.signature(OpenAIAssistantsClient.__init__)
assert any(parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_openai_assistants_client_supports_code_interpreter_and_file_search() -> None:
assert isinstance(OpenAIAssistantsClient, SupportsCodeInterpreterTool)
assert not isinstance(OpenAIAssistantsClient, SupportsWebSearchTool)
assert not isinstance(OpenAIAssistantsClient, SupportsImageGenerationTool)
assert not isinstance(OpenAIAssistantsClient, SupportsMCPTool)
assert isinstance(OpenAIAssistantsClient, SupportsFileSearchTool)
def test_init_with_client(mock_async_openai: MagicMock) -> None:
"""Test OpenAIAssistantsClient initialization with existing client."""
client = create_test_openai_assistants_client(
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
import base64
import inspect
import json
import os
from datetime import datetime, timezone
@@ -18,6 +19,11 @@ from agent_framework import (
FunctionTool,
Message,
SupportsChatGetResponse,
SupportsCodeInterpreterTool,
SupportsFileSearchTool,
SupportsImageGenerationTool,
SupportsMCPTool,
SupportsWebSearchTool,
tool,
)
from agent_framework._sessions import (
@@ -48,7 +54,7 @@ from openai.types.responses.response_text_delta_event import ResponseTextDeltaEv
from pydantic import BaseModel
from pytest import param
from agent_framework_openai import OpenAIChatClient
from agent_framework_openai import OpenAIChatClient, OpenAIResponsesClient
from agent_framework_openai._chat_client import OPENAI_LOCAL_SHELL_CALL_ITEM_ID_KEY
from agent_framework_openai._exceptions import OpenAIContentFilterException
@@ -110,6 +116,40 @@ def test_init(openai_unit_test_env: dict[str, str]) -> None:
assert isinstance(openai_responses_client, SupportsChatGetResponse)
def test_init_uses_explicit_parameters() -> None:
signature = inspect.signature(OpenAIChatClient.__init__)
assert "additional_properties" in signature.parameters
assert "compaction_strategy" in signature.parameters
assert "tokenizer" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_deprecated_responses_client_supports_all_tool_protocols() -> None:
assert isinstance(OpenAIResponsesClient, SupportsCodeInterpreterTool)
assert isinstance(OpenAIResponsesClient, SupportsWebSearchTool)
assert isinstance(OpenAIResponsesClient, SupportsImageGenerationTool)
assert isinstance(OpenAIResponsesClient, SupportsMCPTool)
assert isinstance(OpenAIResponsesClient, SupportsFileSearchTool)
def test_protocol_isinstance_with_responses_client_instance() -> None:
client = object.__new__(OpenAIResponsesClient)
assert isinstance(client, SupportsCodeInterpreterTool)
assert isinstance(client, SupportsWebSearchTool)
def test_deprecated_responses_client_tool_methods_return_dict() -> None:
code_tool = OpenAIResponsesClient.get_code_interpreter_tool()
assert isinstance(code_tool, dict)
assert code_tool.get("type") == "code_interpreter"
web_tool = OpenAIResponsesClient.get_web_search_tool()
assert isinstance(web_tool, dict)
assert web_tool.get("type") == "web_search"
def test_init_prefers_openai_responses_model(monkeypatch, openai_unit_test_env: dict[str, str]) -> None:
monkeypatch.setenv("OPENAI_RESPONSES_MODEL", "test_responses_model_id")
@@ -3033,20 +3073,6 @@ async def test_prepare_options_store_parameter_handling() -> None:
assert "previous_response_id" not in options
async def test_conversation_id_precedence_kwargs_over_options() -> None:
"""When both kwargs and options contain conversation_id, kwargs wins."""
client = OpenAIChatClient(model="test-model", api_key="test-key")
messages = [Message(role="user", text="Hello")]
# options has a stale response id, kwargs carries the freshest one
opts = {"conversation_id": "resp_old_123"}
run_opts = await client._prepare_options(messages, opts, conversation_id="resp_new_456") # type: ignore
# Verify kwargs takes precedence and maps to previous_response_id for resp_* IDs
assert run_opts.get("previous_response_id") == "resp_new_456"
assert "conversation" not in run_opts
def _create_mock_responses_text_response(*, response_id: str) -> MagicMock:
mock_response = MagicMock()
mock_response.id = response_id
@@ -465,7 +465,7 @@ async def test_integration_client_agent_existing_session() -> None:
first_response = await first_agent.run(
"My hobby is photography. Remember this.",
session=session,
store=True,
options={"store": True},
)
assert isinstance(first_response, AgentResponse)
@@ -476,7 +476,9 @@ async def test_integration_client_agent_existing_session() -> None:
client=OpenAIChatClient(credential=credential),
instructions="You are a helpful assistant with good memory.",
) as second_agent:
second_response = await second_agent.run("What is my hobby?", session=preserved_session)
second_response = await second_agent.run(
"What is my hobby?", session=preserved_session, options={"store": True}
)
assert isinstance(second_response, AgentResponse)
assert second_response.text is not None
@@ -1,5 +1,6 @@
# Copyright (c) Microsoft. All rights reserved.
import inspect
import json
import os
from typing import Any
@@ -11,6 +12,11 @@ from agent_framework import (
Content,
Message,
SupportsChatGetResponse,
SupportsCodeInterpreterTool,
SupportsFileSearchTool,
SupportsImageGenerationTool,
SupportsMCPTool,
SupportsWebSearchTool,
tool,
)
from agent_framework.exceptions import ChatClientException, SettingNotFoundError
@@ -20,7 +26,7 @@ from openai.types.chat.chat_completion_message import ChatCompletionMessage
from pydantic import BaseModel
from pytest import param
from agent_framework_openai import OpenAIChatCompletionClient
from agent_framework_openai import OpenAIChatCompletionClient, RawOpenAIChatCompletionClient
from agent_framework_openai._exceptions import OpenAIContentFilterException
skip_if_openai_integration_tests_disabled = pytest.mark.skipif(
@@ -37,6 +43,41 @@ def test_init(openai_unit_test_env: dict[str, str]) -> None:
assert isinstance(open_ai_chat_completion, SupportsChatGetResponse)
def test_get_response_docstring_surfaces_layered_runtime_docs() -> None:
docstring = inspect.getdoc(OpenAIChatCompletionClient.get_response)
assert docstring is not None
assert "Get a response from a chat client." in docstring
assert "function_invocation_kwargs" in docstring
assert "middleware: Optional per-call chat and function middleware." in docstring
assert "function_middleware: Optional per-call function middleware." not in docstring
def test_get_response_is_defined_on_openai_class() -> None:
signature = inspect.signature(OpenAIChatCompletionClient.get_response)
assert OpenAIChatCompletionClient.get_response.__qualname__ == "OpenAIChatCompletionClient.get_response"
assert "middleware" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_init_uses_explicit_parameters() -> None:
signature = inspect.signature(RawOpenAIChatCompletionClient.__init__)
assert "additional_properties" in signature.parameters
assert "compaction_strategy" in signature.parameters
assert "tokenizer" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_supports_web_search_only() -> None:
assert not isinstance(OpenAIChatCompletionClient, SupportsCodeInterpreterTool)
assert isinstance(OpenAIChatCompletionClient, SupportsWebSearchTool)
assert not isinstance(OpenAIChatCompletionClient, SupportsImageGenerationTool)
assert not isinstance(OpenAIChatCompletionClient, SupportsMCPTool)
assert not isinstance(OpenAIChatCompletionClient, SupportsFileSearchTool)
def test_init_prefers_openai_chat_model(monkeypatch, openai_unit_test_env: dict[str, str]) -> None:
monkeypatch.setenv("OPENAI_CHAT_MODEL", "test_chat_model_id")
@@ -138,7 +138,7 @@ async def test_cmc_structured_output_no_fcc(
openai_chat_completion = OpenAIChatCompletionClient()
await openai_chat_completion.get_response(
messages=chat_history,
response_format=Test,
options={"response_format": Test},
)
mock_create.assert_awaited_once()
@@ -322,7 +322,7 @@ async def test_get_streaming_structured_output_no_fcc(
async for msg in openai_chat_completion.get_response(
stream=True,
messages=chat_history,
response_format=Test,
options={"response_format": Test},
):
assert isinstance(msg, ChatResponseUpdate)
mock_create.assert_awaited_once()
@@ -2,6 +2,7 @@
from __future__ import annotations
import inspect
import os
from unittest.mock import AsyncMock, MagicMock
@@ -15,6 +16,7 @@ from agent_framework_openai import (
OpenAIEmbeddingClient,
OpenAIEmbeddingOptions,
)
from agent_framework_openai._embedding_client import RawOpenAIEmbeddingClient
def _make_openai_response(
@@ -44,6 +46,13 @@ def test_openai_construction_with_explicit_params() -> None:
assert client.model == "text-embedding-3-small"
def test_raw_openai_embedding_client_init_uses_explicit_parameters() -> None:
signature = inspect.signature(RawOpenAIEmbeddingClient.__init__)
assert "additional_properties" in signature.parameters
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
def test_openai_construction_from_env(openai_unit_test_env: dict[str, str]) -> None:
client = OpenAIEmbeddingClient()
assert client.model == openai_unit_test_env["OPENAI_EMBEDDING_MODEL"]
@@ -1,21 +1,20 @@
# Copyright (c) Microsoft. All rights reserved.
"""Host multiple Foundry-powered agents inside a single Azure Functions app.
"""Host multiple Azure OpenAI-powered agents inside a single Azure Functions app.
Components used in this sample:
- FoundryChatClient to create agents bound to a shared Foundry deployment.
- OpenAIChatCompletionClient configured for Azure OpenAI.
- AgentFunctionApp to register multiple agents and expose dedicated HTTP endpoints.
- Custom tool functions to demonstrate tool invocation from different agents.
Prerequisites: set `FOUNDRY_PROJECT_ENDPOINT`, `FOUNDRY_MODEL`, and sign in with Azure CLI before starting the Functions host."""
Prerequisites: set `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME`, and sign in with Azure CLI before starting the Functions host."""
import logging
import os
from typing import Any
from agent_framework import Agent, tool
from agent_framework.azure import AgentFunctionApp
from agent_framework.foundry import FoundryChatClient
from agent_framework.openai import OpenAIChatCompletionClient
from azure.identity.aio import AzureCliCredential
from dotenv import load_dotenv
@@ -60,9 +59,7 @@ def calculate_tip(bill_amount: float, tip_percentage: float = 15.0) -> dict[str,
# 1. Create multiple agents, each with its own instruction set and tools.
client = FoundryChatClient(
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
model=os.environ["FOUNDRY_MODEL"],
client = OpenAIChatCompletionClient(
credential=AzureCliCredential(),
)
@@ -8,7 +8,7 @@ each with their own specialized capabilities and tools.
Prerequisites:
- The worker must be running with both agents registered
- Set FOUNDRY_PROJECT_ENDPOINT and FOUNDRY_MODEL
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_DEPLOYMENT_NAME when running the worker
- Sign in with Azure CLI for AzureCliCredential authentication
- Durable Task Scheduler must be running
"""
@@ -5,7 +5,7 @@ This sample demonstrates running both the worker and client in a single process
for multiple agents with different tools. The worker registers two agents
(WeatherAgent and MathAgent), each with their own specialized capabilities.
Prerequisites:
- Set FOUNDRY_PROJECT_ENDPOINT and FOUNDRY_MODEL
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_DEPLOYMENT_NAME
- Sign in with Azure CLI for AzureCliCredential authentication
- Durable Task Scheduler must be running (e.g., using Docker)
To run this sample:
@@ -1,13 +1,13 @@
# Copyright (c) Microsoft. All rights reserved.
"""Worker process for hosting multiple agents with different tools using Durable Task.
"""Worker process for hosting multiple Azure OpenAI agents with different tools using Durable Task.
This worker registers two agents - a weather assistant and a math assistant - each
with their own specialized tools. This demonstrates how to host multiple agents
with different capabilities in a single worker process.
Prerequisites:
- Set FOUNDRY_PROJECT_ENDPOINT and FOUNDRY_MODEL
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_DEPLOYMENT_NAME
- Sign in with Azure CLI for AzureCliCredential authentication
- Start a Durable Task Scheduler (e.g., using Docker)
"""
@@ -19,7 +19,7 @@ from typing import Any
from agent_framework import Agent, tool
from agent_framework.azure import DurableAIAgentWorker
from agent_framework.foundry import FoundryChatClient
from agent_framework.openai import OpenAIChatCompletionClient
from azure.identity import AzureCliCredential
from azure.identity.aio import AzureCliCredential as AsyncAzureCliCredential
from dotenv import load_dotenv
@@ -73,13 +73,10 @@ def create_weather_agent():
Returns:
Agent: The configured Weather agent with weather tool
"""
_client = FoundryChatClient(
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
model=os.environ["FOUNDRY_MODEL"],
credential=AsyncAzureCliCredential(),
)
return Agent(
client=_client,
client=OpenAIChatCompletionClient(
credential=AsyncAzureCliCredential(),
),
name=WEATHER_AGENT_NAME,
instructions="You are a helpful weather assistant. Provide current weather information.",
tools=[get_weather],
@@ -92,13 +89,10 @@ def create_math_agent():
Returns:
Agent: The configured Math agent with calculation tools
"""
_client = FoundryChatClient(
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
model=os.environ["FOUNDRY_MODEL"],
credential=AsyncAzureCliCredential(),
)
return Agent(
client=_client,
client=OpenAIChatCompletionClient(
credential=AsyncAzureCliCredential(),
),
name=MATH_AGENT_NAME,
instructions="You are a helpful math assistant. Help users with calculations like tip calculations.",
tools=[calculate_tip],