mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: [BREAKING] Remove deprecated kwargs compatibility paths (#4858)
* [BREAKING] Remove deprecated kwargs compatibility paths Remove the deprecated kwargs compatibility shims across core agents, clients, tools, middleware, and telemetry. Keep workflow kwargs behavior intact in this branch and follow up separately in #4850. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix PR CI fallout for kwargs removal Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Address PR review feedback Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * updates * Fix Azure AI CI fallout Remove the stale _get_current_conversation_id override from the Azure AI client after the OpenAI base helper was deleted. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * fixed new classes * Fix Assistants deprecated import gating Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Fix integration replay regressions Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Switch multi-agent hosting samples to Azure chat completions Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Simplify Azure multi-agent sample config Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
ca6cdd142e
commit
b1b528e4a8
@@ -210,6 +210,7 @@ jobs:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI__APIKEY }}
|
||||
AZURE_OPENAI_ENDPOINT: ${{ vars.AZUREOPENAI__ENDPOINT }}
|
||||
AZURE_OPENAI_DEPLOYMENT_NAME: ${{ vars.AZUREOPENAI__RESPONSESDEPLOYMENTNAME }}
|
||||
AZURE_OPENAI_CHAT_DEPLOYMENT_NAME: ${{ vars.AZUREOPENAI__CHATDEPLOYMENTNAME }}
|
||||
FOUNDRY_PROJECT_ENDPOINT: ${{ vars.FOUNDRY_PROJECT_ENDPOINT }}
|
||||
FOUNDRY_MODEL: ${{ vars.FOUNDRY_MODEL }}
|
||||
FUNCTIONS_WORKER_RUNTIME: "python"
|
||||
|
||||
@@ -341,6 +341,7 @@ jobs:
|
||||
OPENAI_API_KEY: ${{ secrets.OPENAI__APIKEY }}
|
||||
AZURE_OPENAI_ENDPOINT: ${{ vars.AZUREOPENAI__ENDPOINT }}
|
||||
AZURE_OPENAI_DEPLOYMENT_NAME: ${{ vars.AZUREOPENAI__RESPONSESDEPLOYMENTNAME }}
|
||||
AZURE_OPENAI_CHAT_DEPLOYMENT_NAME: ${{ vars.AZUREOPENAI__CHATDEPLOYMENTNAME }}
|
||||
FOUNDRY_PROJECT_ENDPOINT: ${{ vars.FOUNDRY_PROJECT_ENDPOINT }}
|
||||
FOUNDRY_MODEL: ${{ vars.FOUNDRY_MODEL }}
|
||||
FUNCTIONS_WORKER_RUNTIME: "python"
|
||||
|
||||
@@ -366,7 +366,7 @@ def test_get_uri_data_invalid_uri() -> None:
|
||||
def test_parse_contents_from_a2a_conversion(a2a_agent: A2AAgent) -> None:
|
||||
"""Test A2A parts to contents conversion."""
|
||||
|
||||
agent = A2AAgent(name="Test Agent", client=MockA2AClient(), _http_client=None)
|
||||
agent = A2AAgent(name="Test Agent", client=MockA2AClient(), http_client=None)
|
||||
|
||||
# Create A2A parts
|
||||
parts = [Part(root=TextPart(text="First part")), Part(root=TextPart(text="Second part"))]
|
||||
@@ -485,7 +485,7 @@ async def test_context_manager_no_cleanup_when_no_http_client() -> None:
|
||||
|
||||
mock_a2a_client = MagicMock()
|
||||
|
||||
agent = A2AAgent(client=mock_a2a_client, _http_client=None)
|
||||
agent = A2AAgent(client=mock_a2a_client, http_client=None)
|
||||
|
||||
# This should not raise any errors
|
||||
async with agent:
|
||||
@@ -495,7 +495,7 @@ async def test_context_manager_no_cleanup_when_no_http_client() -> None:
|
||||
def test_prepare_message_for_a2a_with_multiple_contents() -> None:
|
||||
"""Test conversion of Message with multiple contents."""
|
||||
|
||||
agent = A2AAgent(client=MagicMock(), _http_client=None)
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
# Create message with multiple content types
|
||||
message = Message(
|
||||
@@ -523,7 +523,7 @@ def test_prepare_message_for_a2a_with_multiple_contents() -> None:
|
||||
def test_prepare_message_for_a2a_forwards_context_id() -> None:
|
||||
"""Test conversion of Message preserves context_id without duplicating it in metadata."""
|
||||
|
||||
agent = A2AAgent(client=MagicMock(), _http_client=None)
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
message = Message(
|
||||
role="user",
|
||||
@@ -540,7 +540,7 @@ def test_prepare_message_for_a2a_forwards_context_id() -> None:
|
||||
def test_parse_contents_from_a2a_with_data_part() -> None:
|
||||
"""Test conversion of A2A DataPart."""
|
||||
|
||||
agent = A2AAgent(client=MagicMock(), _http_client=None)
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
# Create DataPart
|
||||
data_part = Part(root=DataPart(data={"key": "value", "number": 42}, metadata={"source": "test"}))
|
||||
@@ -556,7 +556,7 @@ def test_parse_contents_from_a2a_with_data_part() -> None:
|
||||
|
||||
def test_parse_contents_from_a2a_unknown_part_kind() -> None:
|
||||
"""Test error handling for unknown A2A part kind."""
|
||||
agent = A2AAgent(client=MagicMock(), _http_client=None)
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
# Create a mock part with unknown kind
|
||||
mock_part = MagicMock()
|
||||
@@ -569,7 +569,7 @@ def test_parse_contents_from_a2a_unknown_part_kind() -> None:
|
||||
def test_prepare_message_for_a2a_with_hosted_file() -> None:
|
||||
"""Test conversion of Message with HostedFileContent to A2A message."""
|
||||
|
||||
agent = A2AAgent(client=MagicMock(), _http_client=None)
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
# Create message with hosted file content
|
||||
message = Message(
|
||||
@@ -595,7 +595,7 @@ def test_prepare_message_for_a2a_with_hosted_file() -> None:
|
||||
def test_parse_contents_from_a2a_with_hosted_file_uri() -> None:
|
||||
"""Test conversion of A2A FilePart with hosted file URI back to UriContent."""
|
||||
|
||||
agent = A2AAgent(client=MagicMock(), _http_client=None)
|
||||
agent = A2AAgent(client=MagicMock(), http_client=None)
|
||||
|
||||
# Create FilePart with hosted file URI (simulating what A2A would send back)
|
||||
file_part = Part(
|
||||
|
||||
@@ -445,6 +445,8 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
|
||||
# Merge tools: convert agent's hosted tools + user-provided function tools
|
||||
merged_tools = self._merge_tools(agent.tools, provided_tools)
|
||||
merged_default_options: dict[str, Any] = dict(default_options) if default_options is not None else {}
|
||||
merged_default_options.setdefault("model_id", agent.model)
|
||||
|
||||
return Agent( # type: ignore[return-value]
|
||||
client=client,
|
||||
@@ -452,9 +454,8 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
|
||||
name=agent.name,
|
||||
description=agent.description,
|
||||
instructions=agent.instructions,
|
||||
model_id=agent.model,
|
||||
tools=merged_tools,
|
||||
default_options=default_options, # type: ignore[arg-type]
|
||||
default_options=cast(Any, merged_default_options),
|
||||
middleware=middleware,
|
||||
context_providers=context_providers,
|
||||
)
|
||||
|
||||
@@ -603,11 +603,6 @@ class RawAzureAIClient(RawOpenAIChatClient[AzureAIClientOptionsT], Generic[Azure
|
||||
|
||||
return transformed
|
||||
|
||||
@override
|
||||
def _get_current_conversation_id(self, options: Mapping[str, Any], **kwargs: Any) -> str | None:
|
||||
"""Get the current conversation ID from chat options or kwargs."""
|
||||
return options.get("conversation_id") or kwargs.get("conversation_id") or self.conversation_id
|
||||
|
||||
@override
|
||||
def _parse_response_from_openai(
|
||||
self,
|
||||
|
||||
@@ -24,7 +24,10 @@ from agent_framework._telemetry import AGENT_FRAMEWORK_USER_AGENT, APP_INFO, pre
|
||||
from agent_framework._tools import FunctionInvocationConfiguration, FunctionInvocationLayer
|
||||
from agent_framework._types import Annotation, Content
|
||||
from agent_framework.observability import ChatTelemetryLayer, EmbeddingTelemetryLayer
|
||||
from agent_framework_openai._assistants_client import OpenAIAssistantsClient, OpenAIAssistantsOptions
|
||||
from agent_framework_openai._assistants_client import (
|
||||
OpenAIAssistantsClient, # type: ignore[reportDeprecated]
|
||||
OpenAIAssistantsOptions,
|
||||
)
|
||||
from agent_framework_openai._chat_client import OpenAIChatOptions, RawOpenAIChatClient
|
||||
from agent_framework_openai._chat_completion_client import OpenAIChatCompletionOptions, RawOpenAIChatCompletionClient
|
||||
from agent_framework_openai._embedding_client import OpenAIEmbeddingOptions, RawOpenAIEmbeddingClient
|
||||
@@ -673,7 +676,8 @@ AzureOpenAIAssistantsOptions = OpenAIAssistantsOptions
|
||||
"Use OpenAIAssistantsClient (also deprecated) or migrate to OpenAIChatClient."
|
||||
)
|
||||
class AzureOpenAIAssistantsClient(
|
||||
OpenAIAssistantsClient[AzureOpenAIAssistantsOptionsT], Generic[AzureOpenAIAssistantsOptionsT]
|
||||
OpenAIAssistantsClient[AzureOpenAIAssistantsOptionsT], # type: ignore[reportDeprecated]
|
||||
Generic[AzureOpenAIAssistantsOptionsT],
|
||||
):
|
||||
"""Deprecated Azure OpenAI Assistants client. Use OpenAIAssistantsClient or migrate to OpenAIChatClient."""
|
||||
|
||||
|
||||
@@ -5,7 +5,7 @@ from __future__ import annotations
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Callable, Mapping, MutableMapping, Sequence
|
||||
from typing import Any, Generic
|
||||
from typing import Any, Generic, cast
|
||||
|
||||
from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
@@ -398,6 +398,8 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
|
||||
# from_azure_ai_tools converts hosted tools (MCP, code interpreter, file search, web search)
|
||||
# but function tools need the actual implementations from provided_tools
|
||||
merged_tools = self._merge_tools(details.definition.tools, provided_tools)
|
||||
merged_default_options: dict[str, Any] = dict(default_options) if default_options is not None else {}
|
||||
merged_default_options.setdefault("model_id", details.definition.model)
|
||||
|
||||
return Agent( # type: ignore[return-value]
|
||||
client=client,
|
||||
@@ -405,9 +407,8 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
|
||||
name=details.name,
|
||||
description=details.description,
|
||||
instructions=details.definition.instructions,
|
||||
model_id=details.definition.model,
|
||||
tools=merged_tools,
|
||||
default_options=default_options, # type: ignore[arg-type]
|
||||
default_options=cast(Any, merged_default_options),
|
||||
middleware=middleware,
|
||||
context_providers=context_providers,
|
||||
)
|
||||
|
||||
@@ -477,7 +477,9 @@ async def test_integration_client_agent_existing_session():
|
||||
) as first_agent:
|
||||
# Start a conversation and capture the session
|
||||
session = first_agent.create_session()
|
||||
first_response = await first_agent.run("My hobby is photography. Remember this.", session=session, store=True)
|
||||
first_response = await first_agent.run(
|
||||
"My hobby is photography. Remember this.", session=session, options={"store": True}
|
||||
)
|
||||
|
||||
assert isinstance(first_response, AgentResponse)
|
||||
assert first_response.text is not None
|
||||
@@ -492,7 +494,9 @@ async def test_integration_client_agent_existing_session():
|
||||
instructions="You are a helpful assistant with good memory.",
|
||||
) as second_agent:
|
||||
# Reuse the preserved session
|
||||
second_response = await second_agent.run("What is my hobby?", session=preserved_session)
|
||||
second_response = await second_agent.run(
|
||||
"What is my hobby?", session=preserved_session, options={"store": True}
|
||||
)
|
||||
|
||||
assert isinstance(second_response, AgentResponse)
|
||||
assert second_response.text is not None
|
||||
|
||||
@@ -7,7 +7,7 @@ import logging
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, overload
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, cast, overload
|
||||
|
||||
from agent_framework import (
|
||||
AgentMiddlewareTypes,
|
||||
@@ -584,7 +584,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
return AgentResponse.from_updates(updates, value=structured_output)
|
||||
|
||||
@overload
|
||||
def run(
|
||||
def run( # type: ignore[override]
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
@@ -595,7 +595,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
def run( # type: ignore[override]
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
@@ -747,3 +747,71 @@ class ClaudeAgent(AgentTelemetryLayer, RawClaudeAgent[OptionsT], Generic[Options
|
||||
response = await agent.run("Hello!")
|
||||
print(response.text)
|
||||
"""
|
||||
|
||||
@overload # type: ignore[override]
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[AgentMiddlewareTypes] | None = None,
|
||||
options: OptionsT | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
compaction_strategy: Any = None,
|
||||
tokenizer: Any = None,
|
||||
function_invocation_kwargs: dict[str, Any] | None = None,
|
||||
client_kwargs: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@overload # type: ignore[override]
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[AgentMiddlewareTypes] | None = None,
|
||||
options: OptionsT | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
compaction_strategy: Any = None,
|
||||
tokenizer: Any = None,
|
||||
function_invocation_kwargs: dict[str, Any] | None = None,
|
||||
client_kwargs: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run( # pyright: ignore[reportIncompatibleMethodOverride] # type: ignore[override]
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[AgentMiddlewareTypes] | None = None,
|
||||
options: OptionsT | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
compaction_strategy: Any = None,
|
||||
tokenizer: Any = None,
|
||||
function_invocation_kwargs: dict[str, Any] | None = None,
|
||||
client_kwargs: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Run the Claude agent with telemetry enabled."""
|
||||
super_run = cast(
|
||||
"Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]",
|
||||
super().run,
|
||||
)
|
||||
return super_run(
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
session=session,
|
||||
middleware=middleware,
|
||||
options=options,
|
||||
tools=tools,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,6 @@ from __future__ import annotations
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||
from copy import deepcopy
|
||||
@@ -248,7 +247,6 @@ class SupportsAgentRun(Protocol):
|
||||
session: AgentSession | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]:
|
||||
"""Get a response from the agent (non-streaming)."""
|
||||
...
|
||||
@@ -262,7 +260,6 @@ class SupportsAgentRun(Protocol):
|
||||
session: AgentSession | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Get a streaming response from the agent."""
|
||||
...
|
||||
@@ -275,7 +272,6 @@ class SupportsAgentRun(Protocol):
|
||||
session: AgentSession | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Get a response from the agent.
|
||||
|
||||
@@ -291,7 +287,6 @@ class SupportsAgentRun(Protocol):
|
||||
session: The conversation session associated with the message(s).
|
||||
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
|
||||
client_kwargs: Additional client-specific keyword arguments.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
When stream=False: An AgentResponse with the final result.
|
||||
@@ -334,7 +329,15 @@ class BaseAgent(SerializationMixin):
|
||||
|
||||
# Create a concrete subclass that implements the protocol
|
||||
class SimpleAgent(BaseAgent):
|
||||
async def run(self, messages=None, *, stream=False, session=None, **kwargs):
|
||||
async def run(
|
||||
self,
|
||||
messages=None,
|
||||
*,
|
||||
stream=False,
|
||||
session=None,
|
||||
function_invocation_kwargs=None,
|
||||
client_kwargs=None,
|
||||
):
|
||||
if stream:
|
||||
|
||||
async def _stream():
|
||||
@@ -373,7 +376,6 @@ class BaseAgent(SerializationMixin):
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
additional_properties: MutableMapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a BaseAgent instance.
|
||||
|
||||
@@ -385,15 +387,7 @@ class BaseAgent(SerializationMixin):
|
||||
context_providers: Context providers to include during agent invocation.
|
||||
middleware: List of middleware.
|
||||
additional_properties: Additional properties set on the agent.
|
||||
kwargs: Additional keyword arguments (merged into additional_properties).
|
||||
"""
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
"Passing additional properties as direct keyword arguments to BaseAgent is deprecated; "
|
||||
"pass them via additional_properties instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
if id is None:
|
||||
id = str(uuid4())
|
||||
self.id = id
|
||||
@@ -403,10 +397,7 @@ class BaseAgent(SerializationMixin):
|
||||
self.middleware: list[MiddlewareTypes] | None = (
|
||||
cast(list[MiddlewareTypes], middleware) if middleware is not None else None
|
||||
)
|
||||
|
||||
# Merge kwargs into additional_properties
|
||||
self.additional_properties: dict[str, Any] = cast(dict[str, Any], additional_properties or {})
|
||||
self.additional_properties.update(kwargs)
|
||||
|
||||
def create_session(self, *, session_id: str | None = None) -> AgentSession:
|
||||
"""Create a new lightweight session.
|
||||
@@ -666,9 +657,10 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
default_options: OptionsCoT | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
additional_properties: MutableMapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize a Agent instance.
|
||||
|
||||
@@ -695,7 +687,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
If both this and a compaction_strategy on the underlying client are set, this one is used.
|
||||
tokenizer: Optional agent-level tokenizer.
|
||||
If both this and a tokenizer on the underlying client are set, this one is used.
|
||||
kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``.
|
||||
additional_properties: Additional properties stored on the agent.
|
||||
"""
|
||||
opts = dict(default_options) if default_options else {}
|
||||
|
||||
@@ -709,7 +701,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
name=name,
|
||||
description=description,
|
||||
context_providers=context_providers,
|
||||
**kwargs,
|
||||
middleware=middleware,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
self.client = client
|
||||
self.compaction_strategy = compaction_strategy
|
||||
@@ -812,7 +805,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
@@ -828,7 +820,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
@@ -844,7 +835,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
@@ -859,7 +849,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Run the agent with the given messages and options.
|
||||
|
||||
@@ -890,21 +879,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
is used, falling back to the client default.
|
||||
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
|
||||
client_kwargs: Additional client-specific keyword arguments for the chat client.
|
||||
kwargs: Deprecated additional keyword arguments for the agent.
|
||||
They are forwarded to both tool invocation and the chat client for compatibility.
|
||||
|
||||
Returns:
|
||||
When stream=False: An Awaitable[AgentResponse] containing the agent's response.
|
||||
When stream=True: A ResponseStream of AgentResponseUpdate items with
|
||||
``get_final_response()`` for the final AgentResponse.
|
||||
"""
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
"Passing runtime keyword arguments directly to run() is deprecated; pass tool values via "
|
||||
"function_invocation_kwargs and client-specific values via client_kwargs instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not stream:
|
||||
|
||||
async def _run_non_streaming() -> AgentResponse[Any]:
|
||||
@@ -915,7 +895,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
legacy_kwargs=kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
@@ -1003,7 +982,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
legacy_kwargs=kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
@@ -1103,7 +1081,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options: Mapping[str, Any] | None,
|
||||
compaction_strategy: CompactionStrategy | None,
|
||||
tokenizer: TokenizerProtocol | None,
|
||||
legacy_kwargs: Mapping[str, Any],
|
||||
function_invocation_kwargs: Mapping[str, Any] | None,
|
||||
client_kwargs: Mapping[str, Any] | None,
|
||||
) -> _RunContext:
|
||||
@@ -1176,12 +1153,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
duplicate_error_message=mcp_duplicate_message,
|
||||
)
|
||||
|
||||
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
|
||||
# Legacy compatibility still fans out direct run kwargs into tool runtime kwargs.
|
||||
effective_function_invocation_kwargs = {
|
||||
**dict(legacy_kwargs),
|
||||
**(dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}),
|
||||
}
|
||||
effective_function_invocation_kwargs = (
|
||||
dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}
|
||||
)
|
||||
additional_function_arguments = {**effective_function_invocation_kwargs, **existing_additional_args}
|
||||
|
||||
# Build options dict from run() options merged with provided options
|
||||
@@ -1214,12 +1188,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
# Build session_messages from session context: context messages + input messages
|
||||
session_messages: list[Message] = session_context.get_messages(include_input=True)
|
||||
|
||||
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
|
||||
# Legacy compatibility still fans out direct run kwargs into client kwargs.
|
||||
effective_client_kwargs = {
|
||||
**dict(legacy_kwargs),
|
||||
**(dict(client_kwargs) if client_kwargs is not None else {}),
|
||||
}
|
||||
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
if active_session is not None:
|
||||
effective_client_kwargs["session"] = active_session
|
||||
|
||||
@@ -1499,9 +1468,29 @@ class Agent(
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
@@ -1511,9 +1500,13 @@ class Agent(
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
@@ -1523,10 +1516,12 @@ class Agent(
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Run the agent."""
|
||||
super_run = cast(
|
||||
@@ -1538,10 +1533,12 @@ class Agent(
|
||||
stream=stream,
|
||||
session=session,
|
||||
middleware=middleware,
|
||||
tools=tools,
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
@@ -1558,7 +1555,7 @@ class Agent(
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
additional_properties: MutableMapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize a Agent instance."""
|
||||
super().__init__(
|
||||
@@ -1573,7 +1570,7 @@ class Agent(
|
||||
middleware=middleware,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -4,7 +4,6 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
@@ -139,7 +138,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
@@ -153,7 +153,6 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
@@ -167,7 +166,6 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
def get_response(
|
||||
@@ -180,7 +178,6 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Send input and return the response.
|
||||
|
||||
@@ -192,7 +189,6 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
tokenizer: Optional per-call tokenizer override.
|
||||
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
|
||||
client_kwargs: Additional client-specific keyword arguments.
|
||||
**kwargs: Deprecated additional client-specific keyword arguments.
|
||||
|
||||
Returns:
|
||||
When stream=False: An awaitable ChatResponse from the client.
|
||||
@@ -296,7 +292,6 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a BaseChatClient instance.
|
||||
|
||||
@@ -304,19 +299,10 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
compaction_strategy: Optional compaction strategy to apply before model calls.
|
||||
tokenizer: Optional tokenizer used by token-aware compaction strategies.
|
||||
additional_properties: Additional properties for the client.
|
||||
kwargs: Additional keyword arguments (merged into additional_properties for now).
|
||||
"""
|
||||
self.additional_properties = additional_properties or {}
|
||||
self.compaction_strategy = compaction_strategy
|
||||
self.tokenizer = tokenizer
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
"Passing additional properties as direct keyword arguments to BaseChatClient is deprecated; "
|
||||
"pass them via additional_properties instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
self.additional_properties.update(kwargs)
|
||||
super().__init__()
|
||||
|
||||
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
|
||||
@@ -457,7 +443,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
@@ -469,7 +456,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
@@ -481,7 +469,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
def get_response(
|
||||
@@ -492,7 +481,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Get a response from a chat client.
|
||||
|
||||
@@ -504,13 +494,9 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
When omitted, the client-level default is used.
|
||||
tokenizer: Optional per-call tokenizer override. When omitted, the
|
||||
client-level default is used.
|
||||
**kwargs: Additional compatibility keyword arguments. Lower chat-client layers do not
|
||||
consume ``function_invocation_kwargs`` directly; if present, it is ignored here
|
||||
because function invocation has already been handled by upper layers. If a
|
||||
``client_kwargs`` mapping is present, it is flattened into standard keyword
|
||||
arguments before forwarding to ``_inner_get_response()`` so client implementations
|
||||
can leverage those values, while implementations that ignore
|
||||
extra kwargs remain compatible.
|
||||
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
|
||||
client_kwargs: Additional client-specific keyword arguments forwarded to
|
||||
``_inner_get_response()``.
|
||||
|
||||
Returns:
|
||||
When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse.
|
||||
@@ -519,14 +505,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
compatibility_client_kwargs = kwargs.pop("client_kwargs", None)
|
||||
kwargs.pop("function_invocation_kwargs", None)
|
||||
merged_client_kwargs = (
|
||||
dict(cast(Mapping[str, Any], compatibility_client_kwargs))
|
||||
if isinstance(compatibility_client_kwargs, Mapping)
|
||||
else {}
|
||||
)
|
||||
merged_client_kwargs.update(kwargs)
|
||||
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
|
||||
if not compaction_overrides:
|
||||
return self._inner_get_response(
|
||||
|
||||
@@ -768,7 +768,8 @@ class MCPTool:
|
||||
options["stop"] = params.stopSequences
|
||||
|
||||
try:
|
||||
response = await self.client.get_response(
|
||||
chat_client: Any = self.client
|
||||
response: Any = await chat_client.get_response(
|
||||
messages,
|
||||
options=options or None,
|
||||
)
|
||||
|
||||
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
|
||||
from ._clients import SupportsChatGetResponse
|
||||
from ._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from ._sessions import AgentSession
|
||||
from ._tools import FunctionTool
|
||||
from ._tools import FunctionTool, ToolTypes
|
||||
from ._types import ChatOptions, ChatResponse, ChatResponseUpdate
|
||||
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
@@ -100,6 +100,7 @@ class AgentContext:
|
||||
agent: The agent being invoked.
|
||||
messages: The messages being sent to the agent.
|
||||
session: The agent session for this invocation, if any.
|
||||
tools: Run-level tool overrides for this invocation, if any.
|
||||
options: The options for the agent invocation as a dict.
|
||||
stream: Whether this is a streaming invocation.
|
||||
compaction_strategy: Optional per-run compaction override.
|
||||
@@ -142,6 +143,7 @@ class AgentContext:
|
||||
agent: SupportsAgentRun,
|
||||
messages: list[Message],
|
||||
session: AgentSession | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: Mapping[str, Any] | None = None,
|
||||
stream: bool = False,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
@@ -165,6 +167,7 @@ class AgentContext:
|
||||
agent: The agent being invoked.
|
||||
messages: The messages being sent to the agent.
|
||||
session: The agent session for this invocation, if any.
|
||||
tools: Run-level tool overrides for this invocation, if any.
|
||||
options: The options for the agent invocation as a dict.
|
||||
stream: Whether this is a streaming invocation.
|
||||
compaction_strategy: Optional per-run compaction override.
|
||||
@@ -181,6 +184,7 @@ class AgentContext:
|
||||
self.agent = agent
|
||||
self.messages = messages
|
||||
self.session = session
|
||||
self.tools = tools
|
||||
self.options = options
|
||||
self.stream = stream
|
||||
self.compaction_strategy = compaction_strategy
|
||||
@@ -1025,7 +1029,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
@@ -1039,7 +1043,6 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
@@ -1053,7 +1056,6 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
def get_response(
|
||||
@@ -1066,27 +1068,26 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Execute the chat pipeline if middleware is configured."""
|
||||
super_get_response = super().get_response # type: ignore[misc]
|
||||
|
||||
if compaction_strategy is not None:
|
||||
kwargs["compaction_strategy"] = compaction_strategy
|
||||
if tokenizer is not None:
|
||||
kwargs["tokenizer"] = tokenizer
|
||||
|
||||
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
call_middleware = effective_client_kwargs.pop("middleware", [])
|
||||
context_kwargs = dict(effective_client_kwargs)
|
||||
if compaction_strategy is not None:
|
||||
context_kwargs["compaction_strategy"] = compaction_strategy
|
||||
if tokenizer is not None:
|
||||
context_kwargs["tokenizer"] = tokenizer
|
||||
pipeline = self._get_chat_middleware_pipeline(call_middleware) # type: ignore[reportUnknownArgumentType]
|
||||
if not pipeline.has_middlewares:
|
||||
return super_get_response( # type: ignore[no-any-return]
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=effective_client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
context = ChatContext(
|
||||
@@ -1094,7 +1095,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
messages=list(messages),
|
||||
options=options,
|
||||
stream=stream,
|
||||
kwargs={**effective_client_kwargs, **kwargs},
|
||||
kwargs=context_kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
)
|
||||
|
||||
@@ -1180,12 +1181,12 @@ class AgentMiddlewareLayer:
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
@@ -1196,12 +1197,12 @@ class AgentMiddlewareLayer:
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
@@ -1212,12 +1213,12 @@ class AgentMiddlewareLayer:
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
@@ -1227,12 +1228,12 @@ class AgentMiddlewareLayer:
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""MiddlewareTypes-enabled unified run method."""
|
||||
# Re-categorize self.middleware at runtime to support dynamic changes
|
||||
@@ -1263,23 +1264,23 @@ class AgentMiddlewareLayer:
|
||||
messages,
|
||||
stream=stream,
|
||||
session=session,
|
||||
tools=tools,
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=effective_function_invocation_kwargs,
|
||||
client_kwargs=effective_client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
context = AgentContext(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
messages=normalize_messages(messages),
|
||||
session=session,
|
||||
tools=tools,
|
||||
options=options,
|
||||
stream=stream,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
kwargs=kwargs,
|
||||
client_kwargs=effective_client_kwargs,
|
||||
function_invocation_kwargs=effective_function_invocation_kwargs,
|
||||
)
|
||||
@@ -1313,22 +1314,16 @@ class AgentMiddlewareLayer:
|
||||
def _middleware_handler(
|
||||
self, context: AgentContext
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
|
||||
client_kwargs = {**context.client_kwargs, **context.kwargs}
|
||||
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
|
||||
function_invocation_kwargs = {
|
||||
**context.function_invocation_kwargs,
|
||||
**{k: v for k, v in context.kwargs.items() if k != "middleware"},
|
||||
}
|
||||
return super().run( # type: ignore[misc, no-any-return]
|
||||
context.messages,
|
||||
stream=context.stream,
|
||||
session=context.session,
|
||||
tools=context.tools,
|
||||
options=context.options,
|
||||
compaction_strategy=context.compaction_strategy,
|
||||
tokenizer=context.tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
function_invocation_kwargs=context.function_invocation_kwargs,
|
||||
client_kwargs=context.client_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -8,7 +8,6 @@ import json
|
||||
import logging
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
Awaitable,
|
||||
@@ -344,8 +343,6 @@ class FunctionTool(SerializationMixin):
|
||||
self._instance = None # Store the instance for bound methods
|
||||
self._context_parameter_name: str | None = None
|
||||
self._input_model_explicitly_provided = input_model is not None
|
||||
# TODO(Copilot): Delete once legacy ``**kwargs`` runtime injection is removed.
|
||||
self._forward_runtime_kwargs: bool = False
|
||||
if self.func:
|
||||
self._discover_injected_parameters()
|
||||
|
||||
@@ -390,10 +387,6 @@ class FunctionTool(SerializationMixin):
|
||||
for name, param in signature.parameters.items():
|
||||
if name in {"self", "cls"}:
|
||||
continue
|
||||
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
self._forward_runtime_kwargs = True
|
||||
continue
|
||||
|
||||
annotation = type_hints.get(name, param.annotation)
|
||||
if self._is_context_parameter(name, annotation):
|
||||
if self._context_parameter_name is not None:
|
||||
@@ -518,6 +511,7 @@ class FunctionTool(SerializationMixin):
|
||||
*,
|
||||
arguments: BaseModel | Mapping[str, Any] | None = None,
|
||||
context: FunctionInvocationContext | None = None,
|
||||
tool_call_id: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Content]:
|
||||
"""Run the AI function with the provided arguments as a Pydantic model.
|
||||
@@ -530,7 +524,10 @@ class FunctionTool(SerializationMixin):
|
||||
Keyword Args:
|
||||
arguments: A mapping or model instance containing the arguments for the function.
|
||||
context: Explicit function invocation context carrying runtime kwargs.
|
||||
kwargs: Deprecated keyword arguments to pass to the function. Use ``context`` instead.
|
||||
tool_call_id: Optional tool call identifier used for telemetry and tracing.
|
||||
kwargs: Direct function argument values. When provided, every keyword
|
||||
must match a declared tool parameter. Runtime data must be passed
|
||||
via ``context``.
|
||||
|
||||
Returns:
|
||||
A list of Content items representing the tool output.
|
||||
@@ -552,18 +549,13 @@ class FunctionTool(SerializationMixin):
|
||||
{key: value for key, value in kwargs.items() if key in parameter_names} if arguments is None else {}
|
||||
)
|
||||
runtime_kwargs = dict(context.kwargs) if context is not None else {}
|
||||
deprecated_runtime_kwargs = {
|
||||
key: value for key, value in kwargs.items() if key not in direct_argument_kwargs and key != "tool_call_id"
|
||||
}
|
||||
if deprecated_runtime_kwargs:
|
||||
warnings.warn(
|
||||
"Passing runtime keyword arguments directly to FunctionTool.invoke() is deprecated; "
|
||||
"pass them via FunctionInvocationContext instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
unexpected_kwargs = {key: value for key, value in kwargs.items() if key not in direct_argument_kwargs}
|
||||
if unexpected_kwargs:
|
||||
unexpected_names = ", ".join(sorted(unexpected_kwargs))
|
||||
raise TypeError(
|
||||
f"Unexpected keyword argument(s) for tool '{self.name}': {unexpected_names}. "
|
||||
"Pass runtime data via FunctionInvocationContext instead."
|
||||
)
|
||||
runtime_kwargs.update(deprecated_runtime_kwargs)
|
||||
tool_call_id = kwargs.get("tool_call_id", runtime_kwargs.pop("tool_call_id", None))
|
||||
if arguments is None and direct_argument_kwargs:
|
||||
arguments = direct_argument_kwargs
|
||||
if arguments is None and context is not None:
|
||||
@@ -614,17 +606,6 @@ class FunctionTool(SerializationMixin):
|
||||
|
||||
call_kwargs = dict(validated_arguments)
|
||||
observable_kwargs = dict(validated_arguments)
|
||||
|
||||
# Legacy runtime kwargs injection path retained for backwards compatibility with tools
|
||||
# that still declare ``**kwargs``. New tools should consume runtime data via ``ctx``.
|
||||
legacy_runtime_kwargs = dict(runtime_kwargs)
|
||||
if self._forward_runtime_kwargs and legacy_runtime_kwargs:
|
||||
for key, value in legacy_runtime_kwargs.items():
|
||||
if key not in call_kwargs:
|
||||
call_kwargs[key] = value
|
||||
if key not in observable_kwargs:
|
||||
observable_kwargs[key] = value
|
||||
|
||||
if self._context_parameter_name is not None and effective_context is not None:
|
||||
call_kwargs[self._context_parameter_name] = effective_context
|
||||
|
||||
@@ -1420,7 +1401,7 @@ async def _auto_invoke_function(
|
||||
# No middleware - execute directly
|
||||
try:
|
||||
direct_context = None
|
||||
if getattr(tool, "_forward_runtime_kwargs", False) or getattr(tool, "_context_parameter_name", None):
|
||||
if getattr(tool, "_context_parameter_name", None):
|
||||
direct_context = FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments=args,
|
||||
@@ -2078,7 +2059,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
@@ -2093,7 +2073,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
@@ -2108,7 +2087,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
def get_response(
|
||||
@@ -2122,7 +2100,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
from ._middleware import categorize_middleware
|
||||
from ._types import (
|
||||
@@ -2133,14 +2110,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
)
|
||||
|
||||
super_get_response = super().get_response # type: ignore[misc]
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
"Passing client-specific keyword arguments directly to get_response() is deprecated; "
|
||||
"pass them via client_kwargs instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
if middleware is not None:
|
||||
existing = effective_client_kwargs.get("middleware", [])
|
||||
@@ -2176,19 +2145,23 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
invocation_session=invocation_session,
|
||||
middleware_pipeline=function_middleware_pipeline,
|
||||
)
|
||||
filtered_kwargs = {k: v for k, v in {**effective_client_kwargs, **kwargs}.items() if k != "session"}
|
||||
filtered_kwargs = {k: v for k, v in effective_client_kwargs.items() if k != "session"}
|
||||
|
||||
# Make options mutable so we can update conversation_id during function invocation loop
|
||||
mutable_options: dict[str, Any] = dict(options) if options else {}
|
||||
# Remove additional_function_arguments from options passed to underlying chat client
|
||||
# It's for tool invocation only and not recognized by chat service APIs
|
||||
mutable_options.pop("additional_function_arguments", None)
|
||||
# Support tools passed via kwargs in direct client.get_response(...) calls.
|
||||
if "tools" in filtered_kwargs:
|
||||
if mutable_options.get("tools") is None:
|
||||
mutable_options["tools"] = filtered_kwargs["tools"]
|
||||
filtered_kwargs.pop("tools", None)
|
||||
|
||||
if not self.function_invocation_configuration.get("enabled", True):
|
||||
return super_get_response( # type: ignore[no-any-return]
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=mutable_options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=filtered_kwargs,
|
||||
)
|
||||
if not stream:
|
||||
|
||||
async def _get_response() -> ChatResponse[Any]:
|
||||
@@ -2235,7 +2208,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
aggregated_usage = add_usage_details(aggregated_usage, response.usage_details)
|
||||
|
||||
if response.conversation_id is not None:
|
||||
_update_conversation_id(kwargs, response.conversation_id, mutable_options)
|
||||
_update_conversation_id(filtered_kwargs, response.conversation_id, mutable_options)
|
||||
prepped_messages = []
|
||||
|
||||
result = await _process_function_requests(
|
||||
@@ -2379,7 +2352,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
return
|
||||
|
||||
if response.conversation_id is not None:
|
||||
_update_conversation_id(kwargs, response.conversation_id, mutable_options)
|
||||
_update_conversation_id(filtered_kwargs, response.conversation_id, mutable_options)
|
||||
prepped_messages = []
|
||||
|
||||
result = await _process_function_requests(
|
||||
|
||||
@@ -12,7 +12,7 @@ from agent_framework import Content
|
||||
|
||||
from .._agents import SupportsAgentRun
|
||||
from .._sessions import AgentSession
|
||||
from .._types import AgentResponse, AgentResponseUpdate, Message
|
||||
from .._types import AgentResponse, AgentResponseUpdate, Message, ResponseStream
|
||||
from ._agent_utils import resolve_agent_id
|
||||
from ._const import WORKFLOW_RUN_KWARGS_KEY
|
||||
from ._executor import Executor, handler
|
||||
@@ -352,7 +352,8 @@ class AgentExecutor(Executor):
|
||||
"""
|
||||
run_kwargs, options = self._prepare_agent_run_args(ctx.get_state(WORKFLOW_RUN_KWARGS_KEY, {}))
|
||||
|
||||
response = await self._agent.run(
|
||||
run_agent = cast(Callable[..., Awaitable[AgentResponse[Any]]], self._agent.run)
|
||||
response = await run_agent(
|
||||
self._cache,
|
||||
stream=False,
|
||||
session=self._session,
|
||||
@@ -383,7 +384,8 @@ class AgentExecutor(Executor):
|
||||
|
||||
updates: list[AgentResponseUpdate] = []
|
||||
streamed_user_input_requests: list[Content] = []
|
||||
stream = self._agent.run(
|
||||
run_agent_stream = cast(Callable[..., ResponseStream[AgentResponseUpdate, AgentResponse[Any]]], self._agent.run)
|
||||
stream = run_agent_stream(
|
||||
self._cache,
|
||||
stream=True,
|
||||
session=self._session,
|
||||
|
||||
@@ -49,8 +49,9 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
from ._agents import SupportsAgentRun
|
||||
from ._clients import SupportsChatGetResponse
|
||||
from ._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from ._middleware import MiddlewareTypes
|
||||
from ._sessions import AgentSession
|
||||
from ._tools import FunctionTool
|
||||
from ._tools import FunctionTool, ToolTypes
|
||||
from ._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
@@ -1191,7 +1192,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
@@ -1203,7 +1205,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
@@ -1215,7 +1218,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
def get_response(
|
||||
@@ -1226,7 +1230,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Trace chat responses with OpenTelemetry spans and metrics.
|
||||
|
||||
@@ -1238,25 +1243,14 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
tokenizer: Optional tokenizer used by token-aware compaction strategies.
|
||||
|
||||
Keyword Args:
|
||||
kwargs: Compatibility keyword arguments from higher client layers. This layer does
|
||||
not consume ``function_invocation_kwargs`` directly; if present, it is ignored
|
||||
because function invocation has already been processed above. If a ``client_kwargs``
|
||||
mapping is present, it is flattened into ordinary keyword arguments for tracing and
|
||||
forwarding so clients that use those values continue to work while clients that
|
||||
ignore extra kwargs remain compatible.
|
||||
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
|
||||
client_kwargs: Additional client-specific keyword arguments for downstream chat clients.
|
||||
"""
|
||||
from ._types import ChatResponse, ChatResponseUpdate, ResponseStream # type: ignore[reportUnusedImport]
|
||||
|
||||
global OBSERVABILITY_SETTINGS
|
||||
super_get_response = super().get_response # type: ignore[misc]
|
||||
compatibility_client_kwargs = kwargs.pop("client_kwargs", None)
|
||||
kwargs.pop("function_invocation_kwargs", None)
|
||||
merged_client_kwargs = (
|
||||
dict(cast(Mapping[str, Any], compatibility_client_kwargs))
|
||||
if isinstance(compatibility_client_kwargs, Mapping)
|
||||
else {}
|
||||
)
|
||||
merged_client_kwargs.update(kwargs)
|
||||
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
|
||||
if not OBSERVABILITY_SETTINGS.ENABLED:
|
||||
return super_get_response( # type: ignore[no-any-return]
|
||||
@@ -1265,7 +1259,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**merged_client_kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=merged_client_kwargs,
|
||||
)
|
||||
|
||||
opts: dict[str, Any] = options or {} # type: ignore[assignment]
|
||||
@@ -1292,7 +1287,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
options=opts,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**merged_client_kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=merged_client_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1384,7 +1380,8 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
options=opts,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**merged_client_kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=merged_client_kwargs,
|
||||
),
|
||||
)
|
||||
except Exception as exception:
|
||||
@@ -1512,11 +1509,29 @@ class AgentTelemetryLayer:
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
@@ -1526,11 +1541,13 @@ class AgentTelemetryLayer:
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
@@ -1539,11 +1556,13 @@ class AgentTelemetryLayer:
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]] | None = None,
|
||||
options: ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Trace agent runs with OpenTelemetry spans and metrics."""
|
||||
global OBSERVABILITY_SETTINGS
|
||||
@@ -1554,23 +1573,27 @@ class AgentTelemetryLayer:
|
||||
super().run, # type: ignore[misc]
|
||||
)
|
||||
provider_name = str(self.otel_provider_name)
|
||||
super_run_kwargs: dict[str, Any] = {
|
||||
"messages": messages,
|
||||
"stream": stream,
|
||||
"session": session,
|
||||
"tools": tools,
|
||||
"options": options,
|
||||
"compaction_strategy": compaction_strategy,
|
||||
"tokenizer": tokenizer,
|
||||
"function_invocation_kwargs": function_invocation_kwargs,
|
||||
"client_kwargs": client_kwargs,
|
||||
}
|
||||
if middleware is not None:
|
||||
super_run_kwargs["middleware"] = middleware
|
||||
if not OBSERVABILITY_SETTINGS.ENABLED:
|
||||
return super_run( # type: ignore[no-any-return]
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
session=session,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
return super_run(**super_run_kwargs) # type: ignore[no-any-return]
|
||||
|
||||
default_options = getattr(self, "default_options", {})
|
||||
options = kwargs.get("options")
|
||||
default_options = dict(getattr(self, "default_options", {}))
|
||||
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
merged_client_kwargs.update(kwargs)
|
||||
merged_options: dict[str, Any] = merge_chat_options(default_options, options or {})
|
||||
merged_options: dict[str, Any] = merge_chat_options(
|
||||
default_options, dict(options) if options is not None else {}
|
||||
)
|
||||
attributes = _get_span_attributes(
|
||||
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
|
||||
provider_name=provider_name,
|
||||
@@ -1590,16 +1613,7 @@ class AgentTelemetryLayer:
|
||||
|
||||
if stream:
|
||||
try:
|
||||
run_result: object = super_run(
|
||||
messages=messages,
|
||||
stream=True,
|
||||
session=session,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
run_result: object = super_run(**super_run_kwargs)
|
||||
if isinstance(run_result, ResponseStream):
|
||||
result_stream: ResponseStream[AgentResponseUpdate, AgentResponse[Any]] = run_result # pyright: ignore[reportUnknownVariableType]
|
||||
elif isinstance(run_result, Awaitable):
|
||||
@@ -1693,16 +1707,7 @@ class AgentTelemetryLayer:
|
||||
)
|
||||
start_time_stamp = perf_counter()
|
||||
try:
|
||||
response: AgentResponse[Any] = await super_run(
|
||||
messages=messages,
|
||||
stream=False,
|
||||
session=session,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
response: AgentResponse[Any] = await super_run(**super_run_kwargs)
|
||||
except Exception as exception:
|
||||
capture_exception(span=span, exception=exception, timestamp=time_ns())
|
||||
raise
|
||||
|
||||
@@ -148,11 +148,9 @@ async def test_chat_client_agent_init_with_name(
|
||||
assert agent.description == "Test"
|
||||
|
||||
|
||||
def test_agent_init_warns_for_direct_additional_properties(client: SupportsChatGetResponse) -> None:
|
||||
with pytest.warns(DeprecationWarning, match="additional_properties"):
|
||||
agent = Agent(client=client, legacy_key="legacy-value")
|
||||
|
||||
assert agent.additional_properties["legacy_key"] == "legacy-value"
|
||||
def test_agent_init_rejects_direct_additional_properties(client: SupportsChatGetResponse) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
Agent(client=client, legacy_key="legacy-value")
|
||||
|
||||
|
||||
async def test_chat_client_agent_run(client: SupportsChatGetResponse) -> None:
|
||||
@@ -303,7 +301,6 @@ async def test_prepare_run_context_handles_function_kwargs(
|
||||
},
|
||||
compaction_strategy=None,
|
||||
tokenizer=None,
|
||||
legacy_kwargs={"legacy_key": "legacy-value"},
|
||||
function_invocation_kwargs={"runtime_key": "runtime-value"},
|
||||
client_kwargs={"client_key": "client-value"},
|
||||
)
|
||||
@@ -311,7 +308,6 @@ async def test_prepare_run_context_handles_function_kwargs(
|
||||
assert ctx["chat_options"]["temperature"] == 0.4
|
||||
assert "additional_function_arguments" not in ctx["chat_options"]
|
||||
assert ctx["function_invocation_kwargs"]["from_options"] == "options-value"
|
||||
assert ctx["function_invocation_kwargs"]["legacy_key"] == "legacy-value"
|
||||
assert ctx["function_invocation_kwargs"]["runtime_key"] == "runtime-value"
|
||||
assert "session" not in ctx["function_invocation_kwargs"]
|
||||
assert ctx["client_kwargs"]["client_key"] == "client-value"
|
||||
@@ -1181,8 +1177,8 @@ async def test_agent_run_accepts_prefixed_mcp_tools(chat_client_base: Any) -> No
|
||||
assert tool_names == ["search", "docs_search"]
|
||||
|
||||
|
||||
async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> None:
|
||||
"""Verify legacy **kwargs tools receive the session when agent.run() is called with one."""
|
||||
async def test_agent_tool_without_context_does_not_receive_session(chat_client_base: Any) -> None:
|
||||
"""Verify tools without FunctionInvocationContext no longer receive injected session kwargs."""
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
@@ -1215,8 +1211,8 @@ async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> N
|
||||
result = await agent.run("hello", session=session)
|
||||
|
||||
assert result.text == "done"
|
||||
assert captured.get("has_session") is True
|
||||
assert captured.get("has_state") is True
|
||||
assert captured.get("has_session") is False
|
||||
assert captured.get("has_state") is False
|
||||
|
||||
|
||||
async def test_agent_tool_receives_explicit_session_via_function_invocation_context_kwargs(
|
||||
@@ -1278,7 +1274,7 @@ async def test_chat_agent_tool_choice_run_level_overrides_agent_level(chat_clien
|
||||
agent = Agent(
|
||||
client=chat_client_base,
|
||||
tools=[tool_tool],
|
||||
options={"tool_choice": "auto"},
|
||||
default_options={"tool_choice": "auto"},
|
||||
)
|
||||
|
||||
# Run with run-level tool_choice="required"
|
||||
|
||||
@@ -1,7 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
@@ -15,11 +14,6 @@ from agent_framework import (
|
||||
Message,
|
||||
SlidingWindowStrategy,
|
||||
SupportsChatGetResponse,
|
||||
SupportsCodeInterpreterTool,
|
||||
SupportsFileSearchTool,
|
||||
SupportsImageGenerationTool,
|
||||
SupportsMCPTool,
|
||||
SupportsWebSearchTool,
|
||||
TruncationStrategy,
|
||||
)
|
||||
|
||||
@@ -53,11 +47,9 @@ def test_base_client(chat_client_base: SupportsChatGetResponse):
|
||||
assert isinstance(chat_client_base, SupportsChatGetResponse)
|
||||
|
||||
|
||||
def test_base_client_warns_for_direct_additional_properties(chat_client_base: SupportsChatGetResponse) -> None:
|
||||
with pytest.warns(DeprecationWarning, match="additional_properties"):
|
||||
client = type(chat_client_base)(legacy_key="legacy-value")
|
||||
|
||||
assert client.additional_properties["legacy_key"] == "legacy-value"
|
||||
def test_base_client_rejects_direct_additional_properties(chat_client_base: SupportsChatGetResponse) -> None:
|
||||
with pytest.raises(TypeError):
|
||||
type(chat_client_base)(legacy_key="legacy-value")
|
||||
|
||||
|
||||
def test_base_client_as_agent_uses_explicit_additional_properties(chat_client_base: SupportsChatGetResponse) -> None:
|
||||
@@ -66,27 +58,6 @@ def test_base_client_as_agent_uses_explicit_additional_properties(chat_client_ba
|
||||
assert agent.additional_properties == {"team": "core"}
|
||||
|
||||
|
||||
def test_openai_chat_completion_client_get_response_docstring_surfaces_layered_runtime_docs() -> None:
|
||||
from agent_framework.openai import OpenAIChatCompletionClient
|
||||
|
||||
docstring = inspect.getdoc(OpenAIChatCompletionClient.get_response)
|
||||
|
||||
assert docstring is not None
|
||||
assert "Get a response from a chat client." in docstring
|
||||
assert "function_invocation_kwargs" in docstring
|
||||
assert "middleware: Optional per-call chat and function middleware." in docstring
|
||||
assert "function_middleware: Optional per-call function middleware." not in docstring
|
||||
|
||||
|
||||
def test_openai_chat_completion_client_get_response_is_defined_on_openai_class() -> None:
|
||||
from agent_framework.openai import OpenAIChatCompletionClient
|
||||
|
||||
signature = inspect.signature(OpenAIChatCompletionClient.get_response)
|
||||
|
||||
assert OpenAIChatCompletionClient.get_response.__qualname__ == "OpenAIChatCompletionClient.get_response"
|
||||
assert "middleware" in signature.parameters
|
||||
|
||||
|
||||
async def test_base_client_get_response_uses_explicit_client_kwargs(chat_client_base: SupportsChatGetResponse) -> None:
|
||||
async def fake_inner_get_response(**kwargs):
|
||||
assert kwargs["trace_id"] == "trace-123"
|
||||
@@ -333,66 +304,3 @@ async def test_chat_client_instructions_handling(chat_client_base: SupportsChatG
|
||||
assert appended_messages[0].text == "You are a helpful assistant."
|
||||
assert appended_messages[1].role == "user"
|
||||
assert appended_messages[1].text == "hello"
|
||||
|
||||
|
||||
# region Tool Support Protocol Tests
|
||||
|
||||
|
||||
def test_openai_responses_client_supports_all_tool_protocols():
|
||||
"""Test that OpenAIResponsesClient supports all hosted tool protocols."""
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
|
||||
assert isinstance(OpenAIResponsesClient, SupportsCodeInterpreterTool)
|
||||
assert isinstance(OpenAIResponsesClient, SupportsWebSearchTool)
|
||||
assert isinstance(OpenAIResponsesClient, SupportsImageGenerationTool)
|
||||
assert isinstance(OpenAIResponsesClient, SupportsMCPTool)
|
||||
assert isinstance(OpenAIResponsesClient, SupportsFileSearchTool)
|
||||
|
||||
|
||||
def test_openai_chat_completion_client_supports_web_search_only():
|
||||
"""Test that OpenAIChatClient only supports web search tool."""
|
||||
from agent_framework.openai import OpenAIChatCompletionClient
|
||||
|
||||
assert not isinstance(OpenAIChatCompletionClient, SupportsCodeInterpreterTool)
|
||||
assert isinstance(OpenAIChatCompletionClient, SupportsWebSearchTool)
|
||||
assert not isinstance(OpenAIChatCompletionClient, SupportsImageGenerationTool)
|
||||
assert not isinstance(OpenAIChatCompletionClient, SupportsMCPTool)
|
||||
assert not isinstance(OpenAIChatCompletionClient, SupportsFileSearchTool)
|
||||
|
||||
|
||||
def test_openai_assistants_client_supports_code_interpreter_and_file_search():
|
||||
"""Test that OpenAIAssistantsClient supports code interpreter and file search."""
|
||||
from agent_framework.openai import OpenAIAssistantsClient
|
||||
|
||||
assert isinstance(OpenAIAssistantsClient, SupportsCodeInterpreterTool)
|
||||
assert not isinstance(OpenAIAssistantsClient, SupportsWebSearchTool)
|
||||
assert not isinstance(OpenAIAssistantsClient, SupportsImageGenerationTool)
|
||||
assert not isinstance(OpenAIAssistantsClient, SupportsMCPTool)
|
||||
assert isinstance(OpenAIAssistantsClient, SupportsFileSearchTool)
|
||||
|
||||
|
||||
def test_protocol_isinstance_with_client_instance():
|
||||
"""Test that protocol isinstance works with client instances."""
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
|
||||
# Create mock client instance (won't connect to API)
|
||||
client = OpenAIResponsesClient.__new__(OpenAIResponsesClient)
|
||||
|
||||
assert isinstance(client, SupportsCodeInterpreterTool)
|
||||
assert isinstance(client, SupportsWebSearchTool)
|
||||
|
||||
|
||||
def test_protocol_tool_methods_return_dict():
|
||||
"""Test that static tool methods return dict[str, Any]."""
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
|
||||
code_tool = OpenAIResponsesClient.get_code_interpreter_tool()
|
||||
assert isinstance(code_tool, dict)
|
||||
assert code_tool.get("type") == "code_interpreter"
|
||||
|
||||
web_tool = OpenAIResponsesClient.get_web_search_tool()
|
||||
assert isinstance(web_tool, dict)
|
||||
assert web_tool.get("type") == "web_search"
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -13,6 +13,7 @@ from agent_framework import (
|
||||
Content,
|
||||
Message,
|
||||
SupportsChatGetResponse,
|
||||
chat_middleware,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._compaction import (
|
||||
@@ -74,7 +75,7 @@ async def test_base_client_with_function_calling(chat_client_base: SupportsChatG
|
||||
assert response.messages[2].text == "done"
|
||||
|
||||
|
||||
async def test_base_client_with_function_calling_tools_in_kwargs(chat_client_base: SupportsChatGetResponse):
|
||||
async def test_base_client_with_function_calling_string_input(chat_client_base: SupportsChatGetResponse):
|
||||
exec_counter = 0
|
||||
|
||||
@tool(name="test_function", approval_mode="never_require")
|
||||
@@ -95,7 +96,7 @@ async def test_base_client_with_function_calling_tools_in_kwargs(chat_client_bas
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response("hello", tools=[ai_func])
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
|
||||
|
||||
assert exec_counter == 1
|
||||
assert len(response.messages) == 3
|
||||
@@ -1429,6 +1430,36 @@ async def test_function_invocation_config_enabled_false(chat_client_base: Suppor
|
||||
assert len(response.messages) > 0
|
||||
|
||||
|
||||
async def test_function_invocation_config_enabled_false_preserves_invocation_kwargs(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
):
|
||||
"""Test disabled function invocation still forwards invocation kwargs downstream."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@tool(name="test_function")
|
||||
def ai_func(arg1: str) -> str:
|
||||
return f"Processed {arg1}"
|
||||
|
||||
@chat_middleware
|
||||
async def capture_middleware(context, call_next):
|
||||
captured_kwargs.update(context.function_invocation_kwargs or {})
|
||||
await call_next()
|
||||
|
||||
chat_client_base.chat_middleware = [capture_middleware]
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(messages=Message(role="assistant", text="response without function calling")),
|
||||
]
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False
|
||||
|
||||
await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")],
|
||||
options={"tool_choice": "auto", "tools": [ai_func]},
|
||||
function_invocation_kwargs={"tool_request_id": "tool-123"},
|
||||
)
|
||||
|
||||
assert captured_kwargs == {"tool_request_id": "tool-123"}
|
||||
|
||||
|
||||
@pytest.mark.skip(reason="Error handling and failsafe behavior needs investigation in unified API")
|
||||
async def test_function_invocation_config_max_consecutive_errors(chat_client_base: SupportsChatGetResponse):
|
||||
"""Test that max_consecutive_errors_per_request limits error retries."""
|
||||
@@ -1523,7 +1554,7 @@ async def test_function_invocation_stop_clears_conversation_id_non_stream(chat_c
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")],
|
||||
options={"tool_choice": "auto", "tools": [error_func]},
|
||||
session=session_stub,
|
||||
client_kwargs={"session": session_stub},
|
||||
)
|
||||
|
||||
assert response.conversation_id is None
|
||||
@@ -1881,8 +1912,7 @@ async def test_hosted_tool_approval_response(chat_client_base: SupportsChatGetRe
|
||||
# Send the approval response
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", contents=[approval_response])],
|
||||
tool_choice="auto",
|
||||
tools=[local_func],
|
||||
options={"tool_choice": "auto", "tools": [local_func]},
|
||||
)
|
||||
|
||||
# The hosted tool approval should be returned as-is (not executed)
|
||||
@@ -1930,8 +1960,7 @@ async def test_hosted_mcp_approval_response_passthrough(chat_client_base: Suppor
|
||||
|
||||
response = await chat_client_base.get_response(
|
||||
messages,
|
||||
tool_choice="auto",
|
||||
tools=[local_func],
|
||||
options={"tool_choice": "auto", "tools": [local_func]},
|
||||
)
|
||||
|
||||
# The response should succeed without errors
|
||||
@@ -2024,8 +2053,7 @@ async def test_mixed_local_and_hosted_approval_flow(chat_client_base: SupportsCh
|
||||
|
||||
response = await chat_client_base.get_response(
|
||||
messages,
|
||||
tool_choice="auto",
|
||||
tools=[local_func],
|
||||
options={"tool_choice": "auto", "tools": [local_func]},
|
||||
)
|
||||
|
||||
assert response is not None
|
||||
@@ -2799,7 +2827,7 @@ async def test_streaming_function_invocation_stop_clears_conversation_id(chat_cl
|
||||
"hello",
|
||||
options={"tool_choice": "auto", "tools": [error_func]},
|
||||
stream=True,
|
||||
session=session_stub,
|
||||
client_kwargs={"session": session_stub},
|
||||
)
|
||||
async for _ in stream:
|
||||
pass
|
||||
|
||||
@@ -1,351 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Tests for kwargs propagation from get_response() to @tool functions."""
|
||||
|
||||
from collections.abc import AsyncIterable, Awaitable, MutableSequence, Sequence
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
Agent,
|
||||
BaseChatClient,
|
||||
ChatMiddlewareLayer,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
FunctionInvocationContext,
|
||||
FunctionInvocationLayer,
|
||||
Message,
|
||||
ResponseStream,
|
||||
tool,
|
||||
)
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
|
||||
|
||||
class _MockBaseChatClient(BaseChatClient[Any]):
|
||||
"""Mock chat client for testing function invocation."""
|
||||
|
||||
def __init__(self) -> None:
|
||||
super().__init__()
|
||||
self.run_responses: list[ChatResponse] = []
|
||||
self.streaming_responses: list[list[ChatResponseUpdate]] = []
|
||||
self.call_count: int = 0
|
||||
|
||||
def _inner_get_response(
|
||||
self,
|
||||
*,
|
||||
messages: MutableSequence[Message],
|
||||
stream: bool,
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
|
||||
if stream:
|
||||
return self._get_streaming_response(messages=messages, options=options, **kwargs)
|
||||
|
||||
async def _get() -> ChatResponse:
|
||||
return await self._get_non_streaming_response(messages=messages, options=options, **kwargs)
|
||||
|
||||
return _get()
|
||||
|
||||
async def _get_non_streaming_response(
|
||||
self,
|
||||
*,
|
||||
messages: MutableSequence[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
self.call_count += 1
|
||||
if self.run_responses:
|
||||
return self.run_responses.pop(0)
|
||||
return ChatResponse(messages=Message(role="assistant", text="default response"))
|
||||
|
||||
def _get_streaming_response(
|
||||
self,
|
||||
*,
|
||||
messages: MutableSequence[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse]:
|
||||
async def _stream() -> AsyncIterable[ChatResponseUpdate]:
|
||||
self.call_count += 1
|
||||
if self.streaming_responses:
|
||||
for update in self.streaming_responses.pop(0):
|
||||
yield update
|
||||
else:
|
||||
yield ChatResponseUpdate(
|
||||
contents=[Content.from_text("default streaming response")], role="assistant", finish_reason="stop"
|
||||
)
|
||||
|
||||
def _finalize(updates: Sequence[ChatResponseUpdate]) -> ChatResponse:
|
||||
response_format = options.get("response_format")
|
||||
output_format_type = response_format if isinstance(response_format, type) else None
|
||||
return ChatResponse.from_updates(updates, output_format_type=output_format_type)
|
||||
|
||||
return ResponseStream(_stream(), finalizer=_finalize)
|
||||
|
||||
|
||||
class FunctionInvokingMockClient(
|
||||
FunctionInvocationLayer[Any],
|
||||
ChatMiddlewareLayer[Any],
|
||||
ChatTelemetryLayer[Any],
|
||||
_MockBaseChatClient,
|
||||
):
|
||||
"""Mock client with function invocation support."""
|
||||
|
||||
pass
|
||||
|
||||
|
||||
class TestKwargsPropagationToFunctionTool:
|
||||
"""Test cases for kwargs flowing from get_response() to @tool functions."""
|
||||
|
||||
async def test_kwargs_propagate_to_tool_with_kwargs(self) -> None:
|
||||
"""Test that kwargs passed to get_response() are available in @tool **kwargs."""
|
||||
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
def capture_kwargs_tool(x: int, **kwargs: Any) -> str:
|
||||
"""A tool that captures kwargs for testing."""
|
||||
captured_kwargs.update(kwargs)
|
||||
return f"result: x={x}"
|
||||
|
||||
client = FunctionInvokingMockClient()
|
||||
client.run_responses = [
|
||||
# First response: function call
|
||||
ChatResponse(
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1", name="capture_kwargs_tool", arguments='{"x": 42}'
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
# Second response: final answer
|
||||
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
|
||||
]
|
||||
|
||||
result = await client.get_response(
|
||||
messages=[Message(role="user", text="Test")],
|
||||
stream=False,
|
||||
options={
|
||||
"tools": [capture_kwargs_tool],
|
||||
"additional_function_arguments": {
|
||||
"user_id": "user-123",
|
||||
"session_token": "secret-token",
|
||||
"custom_data": {"key": "value"},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the tool was called and received the kwargs
|
||||
assert "user_id" in captured_kwargs, f"Expected 'user_id' in captured kwargs: {captured_kwargs}"
|
||||
assert captured_kwargs["user_id"] == "user-123"
|
||||
assert "session_token" in captured_kwargs
|
||||
assert captured_kwargs["session_token"] == "secret-token"
|
||||
assert "custom_data" in captured_kwargs
|
||||
assert captured_kwargs["custom_data"] == {"key": "value"}
|
||||
# Verify result
|
||||
assert result.messages[-1].text == "Done!"
|
||||
|
||||
async def test_kwargs_not_forwarded_to_tool_without_kwargs(self) -> None:
|
||||
"""Test that kwargs are NOT forwarded to @tool that doesn't accept **kwargs."""
|
||||
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
def simple_tool(x: int) -> str:
|
||||
"""A simple tool without **kwargs."""
|
||||
return f"result: x={x}"
|
||||
|
||||
client = FunctionInvokingMockClient()
|
||||
client.run_responses = [
|
||||
ChatResponse(
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="call_1", name="simple_tool", arguments='{"x": 99}')
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
ChatResponse(messages=[Message(role="assistant", text="Completed!")]),
|
||||
]
|
||||
|
||||
# Call with additional_function_arguments - the tool should work but not receive them
|
||||
result = await client.get_response(
|
||||
messages=[Message(role="user", text="Test")],
|
||||
stream=False,
|
||||
options={
|
||||
"tools": [simple_tool],
|
||||
"additional_function_arguments": {"user_id": "user-123"},
|
||||
},
|
||||
)
|
||||
|
||||
# Verify the tool was called successfully (no error from extra kwargs)
|
||||
assert result.messages[-1].text == "Completed!"
|
||||
|
||||
async def test_kwargs_isolated_between_function_calls(self) -> None:
|
||||
"""Test that kwargs are consistent across multiple function call invocations."""
|
||||
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
|
||||
invocation_kwargs: list[dict[str, Any]] = []
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
def tracking_tool(name: str, **kwargs: Any) -> str:
|
||||
"""A tool that tracks kwargs from each invocation."""
|
||||
invocation_kwargs.append(dict(kwargs))
|
||||
return f"called with {name}"
|
||||
|
||||
client = FunctionInvokingMockClient()
|
||||
client.run_responses = [
|
||||
# Two function calls in one response
|
||||
ChatResponse(
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1", name="tracking_tool", arguments='{"name": "first"}'
|
||||
),
|
||||
Content.from_function_call(
|
||||
call_id="call_2", name="tracking_tool", arguments='{"name": "second"}'
|
||||
),
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
ChatResponse(messages=[Message(role="assistant", text="All done!")]),
|
||||
]
|
||||
|
||||
result = await client.get_response(
|
||||
messages=[Message(role="user", text="Test")],
|
||||
stream=False,
|
||||
options={
|
||||
"tools": [tracking_tool],
|
||||
"additional_function_arguments": {
|
||||
"request_id": "req-001",
|
||||
"trace_context": {"trace_id": "abc"},
|
||||
},
|
||||
},
|
||||
)
|
||||
|
||||
# Both invocations should have received the same kwargs
|
||||
assert len(invocation_kwargs) == 2
|
||||
for kwargs in invocation_kwargs:
|
||||
assert kwargs.get("request_id") == "req-001"
|
||||
assert kwargs.get("trace_context") == {"trace_id": "abc"}
|
||||
assert result.messages[-1].text == "All done!"
|
||||
|
||||
async def test_streaming_response_kwargs_propagation(self) -> None:
|
||||
"""Test that kwargs propagate to @tool in streaming mode."""
|
||||
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
def streaming_capture_tool(value: str, **kwargs: Any) -> str:
|
||||
"""A tool that captures kwargs during streaming."""
|
||||
captured_kwargs.update(kwargs)
|
||||
return f"processed: {value}"
|
||||
|
||||
client = FunctionInvokingMockClient()
|
||||
client.streaming_responses = [
|
||||
# First stream: function call
|
||||
[
|
||||
ChatResponseUpdate(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="stream_call_1",
|
||||
name="streaming_capture_tool",
|
||||
arguments='{"value": "streaming-test"}',
|
||||
)
|
||||
],
|
||||
finish_reason="stop",
|
||||
)
|
||||
],
|
||||
# Second stream: final response
|
||||
[
|
||||
ChatResponseUpdate(
|
||||
contents=[Content.from_text("Stream complete!")], role="assistant", finish_reason="stop"
|
||||
)
|
||||
],
|
||||
]
|
||||
|
||||
# Collect streaming updates
|
||||
updates: list[ChatResponseUpdate] = []
|
||||
stream = client.get_response(
|
||||
messages=[Message(role="user", text="Test")],
|
||||
stream=True,
|
||||
options={
|
||||
"tools": [streaming_capture_tool],
|
||||
"additional_function_arguments": {
|
||||
"streaming_session": "session-xyz",
|
||||
"correlation_id": "corr-123",
|
||||
},
|
||||
},
|
||||
)
|
||||
async for update in stream:
|
||||
updates.append(update)
|
||||
|
||||
# Verify kwargs were captured by the tool
|
||||
assert "streaming_session" in captured_kwargs, f"Expected 'streaming_session' in {captured_kwargs}"
|
||||
assert captured_kwargs["streaming_session"] == "session-xyz"
|
||||
assert captured_kwargs["correlation_id"] == "corr-123"
|
||||
|
||||
async def test_agent_run_injects_function_invocation_context(self) -> None:
|
||||
"""Test that Agent.run injects FunctionInvocationContext for ctx-based tools."""
|
||||
captured_context_kwargs: dict[str, Any] = {}
|
||||
captured_client_kwargs: dict[str, Any] = {}
|
||||
captured_options: dict[str, Any] = {}
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
def capture_context_tool(x: int, ctx: FunctionInvocationContext) -> str:
|
||||
captured_context_kwargs.update(ctx.kwargs)
|
||||
return f"result: x={x}"
|
||||
|
||||
class CapturingFunctionInvokingMockClient(FunctionInvokingMockClient):
|
||||
async def _get_non_streaming_response(
|
||||
self,
|
||||
*,
|
||||
messages: MutableSequence[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
captured_options.update(options)
|
||||
captured_client_kwargs.update(kwargs)
|
||||
return await super()._get_non_streaming_response(messages=messages, options=options, **kwargs)
|
||||
|
||||
client = CapturingFunctionInvokingMockClient()
|
||||
client.run_responses = [
|
||||
ChatResponse(
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1",
|
||||
name="capture_context_tool",
|
||||
arguments='{"x": 42}',
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
|
||||
]
|
||||
|
||||
agent = Agent(client=client, tools=[capture_context_tool])
|
||||
result = await agent.run(
|
||||
[Message(role="user", text="Test")],
|
||||
function_invocation_kwargs={"tool_request_id": "tool-123"},
|
||||
client_kwargs={"client_request_id": "client-456"},
|
||||
)
|
||||
|
||||
assert captured_context_kwargs["tool_request_id"] == "tool-123"
|
||||
assert "client_request_id" not in captured_context_kwargs
|
||||
assert captured_client_kwargs["client_request_id"] == "client-456"
|
||||
assert "tool_request_id" not in captured_client_kwargs
|
||||
assert "additional_function_arguments" not in captured_options
|
||||
assert result.messages[-1].text == "Done!"
|
||||
@@ -1751,6 +1751,9 @@ async def test_mcp_tool_sampling_callback_no_valid_content():
|
||||
assert isinstance(result, types.ErrorData)
|
||||
assert result.code == types.INTERNAL_ERROR
|
||||
assert "Failed to get right content types from the response." in result.message
|
||||
mock_chat_client.get_response.assert_awaited_once()
|
||||
_, kwargs = mock_chat_client.get_response.await_args
|
||||
assert kwargs["options"] == {"max_tokens": None}
|
||||
|
||||
|
||||
async def test_mcp_tool_sampling_callback_no_response_and_successful_message_creation():
|
||||
@@ -3704,14 +3707,19 @@ async def test_mcp_tool_filters_framework_kwargs():
|
||||
|
||||
# Invoke the tool with framework kwargs that should be filtered out
|
||||
await func.invoke(
|
||||
param="test_value",
|
||||
response_format=MockResponseFormat, # Should be filtered
|
||||
chat_options={"some": "option"}, # Should be filtered
|
||||
tools=[Mock()], # Should be filtered
|
||||
tool_choice="auto", # Should be filtered
|
||||
session=Mock(), # Should be filtered
|
||||
conversation_id="conv-123", # Should be filtered
|
||||
options={"metadata": "value"}, # Should be filtered
|
||||
context=FunctionInvocationContext(
|
||||
function=func,
|
||||
arguments={"param": "test_value"},
|
||||
kwargs={
|
||||
"response_format": MockResponseFormat, # Should be filtered
|
||||
"chat_options": {"some": "option"}, # Should be filtered
|
||||
"tools": [Mock()], # Should be filtered
|
||||
"tool_choice": "auto", # Should be filtered
|
||||
"session": Mock(), # Should be filtered
|
||||
"conversation_id": "conv-123", # Should be filtered
|
||||
"options": {"metadata": "value"}, # Should be filtered
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Verify call_tool was called with only the valid argument
|
||||
|
||||
@@ -789,9 +789,10 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
assert modified_kwargs["new_param"] == "added_by_middleware"
|
||||
assert modified_kwargs["custom_param"] == "test_value"
|
||||
|
||||
async def test_run_kwargs_available_in_function_middleware(self, chat_client_base: "MockBaseChatClient") -> None:
|
||||
"""Test that kwargs passed directly to agent.run() appear in FunctionInvocationContext.kwargs,
|
||||
including complex nested values like dicts."""
|
||||
async def test_function_invocation_kwargs_available_in_function_middleware(
|
||||
self, chat_client_base: "MockBaseChatClient"
|
||||
) -> None:
|
||||
"""Test that function_invocation_kwargs appear in FunctionInvocationContext.kwargs."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@function_middleware
|
||||
@@ -822,18 +823,20 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
session_metadata = {"tenant": "acme-corp", "region": "us-west"}
|
||||
await agent.run(
|
||||
[Message(role="user", text="Get weather")],
|
||||
user_id="user-456",
|
||||
session_metadata=session_metadata,
|
||||
function_invocation_kwargs={
|
||||
"user_id": "user-456",
|
||||
"session_metadata": session_metadata,
|
||||
},
|
||||
)
|
||||
|
||||
assert "user_id" in captured_kwargs, f"Expected 'user_id' in kwargs: {captured_kwargs}"
|
||||
assert captured_kwargs["user_id"] == "user-456"
|
||||
assert captured_kwargs["session_metadata"] == {"tenant": "acme-corp", "region": "us-west"}
|
||||
|
||||
async def test_run_kwargs_merged_with_additional_function_arguments(
|
||||
async def test_function_invocation_kwargs_merged_with_additional_function_arguments(
|
||||
self, chat_client_base: "MockBaseChatClient"
|
||||
) -> None:
|
||||
"""Test that explicit additional_function_arguments in options take precedence over run kwargs."""
|
||||
"""Test that explicit additional_function_arguments in options take precedence."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@function_middleware
|
||||
@@ -863,9 +866,10 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
|
||||
await agent.run(
|
||||
[Message(role="user", text="Get weather")],
|
||||
# This kwarg should be overridden by additional_function_arguments
|
||||
user_id="from-kwargs",
|
||||
tenant_id="from-kwargs",
|
||||
function_invocation_kwargs={
|
||||
"user_id": "from-kwargs",
|
||||
"tenant_id": "from-kwargs",
|
||||
},
|
||||
options={
|
||||
"additional_function_arguments": {
|
||||
"user_id": "from-options",
|
||||
@@ -876,15 +880,15 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
|
||||
# additional_function_arguments takes precedence for overlapping keys
|
||||
assert captured_kwargs["user_id"] == "from-options"
|
||||
# Non-overlapping kwargs from run() still come through
|
||||
# Non-overlapping function_invocation_kwargs still come through
|
||||
assert captured_kwargs["tenant_id"] == "from-kwargs"
|
||||
# Keys only in additional_function_arguments are present
|
||||
assert captured_kwargs["extra_key"] == "only-in-options"
|
||||
|
||||
async def test_run_kwargs_consistent_across_multiple_tool_calls(
|
||||
async def test_function_invocation_kwargs_consistent_across_multiple_tool_calls(
|
||||
self, chat_client_base: "MockBaseChatClient"
|
||||
) -> None:
|
||||
"""Test that kwargs are consistent across multiple tool invocations in a single run."""
|
||||
"""Test that function_invocation_kwargs are consistent across tool invocations."""
|
||||
invocation_kwargs: list[dict[str, Any]] = []
|
||||
|
||||
@function_middleware
|
||||
@@ -917,8 +921,10 @@ class TestChatAgentFunctionMiddlewareWithTools:
|
||||
|
||||
await agent.run(
|
||||
[Message(role="user", text="Get weather for both cities")],
|
||||
user_id="user-456",
|
||||
request_id="req-001",
|
||||
function_invocation_kwargs={
|
||||
"user_id": "user-456",
|
||||
"request_id": "req-001",
|
||||
},
|
||||
)
|
||||
|
||||
assert len(invocation_kwargs) == 2
|
||||
@@ -2060,23 +2066,21 @@ class TestChatAgentChatMiddleware:
|
||||
"agent_middleware_after",
|
||||
]
|
||||
|
||||
async def test_agent_middleware_can_access_and_override_custom_kwargs(self) -> None:
|
||||
"""Test that agent middleware can access and override custom parameters like temperature."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
modified_kwargs: dict[str, Any] = {}
|
||||
async def test_agent_middleware_can_access_and_override_options(self) -> None:
|
||||
"""Test that agent middleware can access and override runtime options."""
|
||||
captured_options: dict[str, Any] = {}
|
||||
modified_options: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def kwargs_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Capture the original kwargs
|
||||
captured_kwargs.update(context.kwargs)
|
||||
assert isinstance(context.options, dict)
|
||||
captured_options.update(context.options)
|
||||
|
||||
# Modify some kwargs
|
||||
context.kwargs["temperature"] = 0.9
|
||||
context.kwargs["max_tokens"] = 500
|
||||
context.kwargs["new_param"] = "added_by_middleware"
|
||||
context.options["temperature"] = 0.9
|
||||
context.options["max_tokens"] = 500
|
||||
context.options["new_param"] = "added_by_middleware"
|
||||
|
||||
# Store modified kwargs for verification
|
||||
modified_kwargs.update(context.kwargs)
|
||||
modified_options.update(context.options)
|
||||
|
||||
await call_next()
|
||||
|
||||
@@ -2084,24 +2088,25 @@ class TestChatAgentChatMiddleware:
|
||||
client = MockBaseChatClient()
|
||||
agent = Agent(client=client, middleware=[kwargs_middleware])
|
||||
|
||||
# Execute the agent with custom parameters
|
||||
# Execute the agent with runtime options
|
||||
messages = [Message(role="user", text="test message")]
|
||||
response = await agent.run(messages, temperature=0.7, max_tokens=100, custom_param="test_value")
|
||||
response = await agent.run(
|
||||
messages,
|
||||
options={"temperature": 0.7, "max_tokens": 100, "custom_param": "test_value"},
|
||||
)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
|
||||
# Verify middleware captured the original kwargs
|
||||
assert captured_kwargs["temperature"] == 0.7
|
||||
assert captured_kwargs["max_tokens"] == 100
|
||||
assert captured_kwargs["custom_param"] == "test_value"
|
||||
assert captured_options["temperature"] == 0.7
|
||||
assert captured_options["max_tokens"] == 100
|
||||
assert captured_options["custom_param"] == "test_value"
|
||||
|
||||
# Verify middleware could modify the kwargs
|
||||
assert modified_kwargs["temperature"] == 0.9
|
||||
assert modified_kwargs["max_tokens"] == 500
|
||||
assert modified_kwargs["new_param"] == "added_by_middleware"
|
||||
assert modified_kwargs["custom_param"] == "test_value" # Should still be there
|
||||
assert modified_options["temperature"] == 0.9
|
||||
assert modified_options["max_tokens"] == 500
|
||||
assert modified_options["new_param"] == "added_by_middleware"
|
||||
assert modified_options["custom_param"] == "test_value"
|
||||
|
||||
|
||||
# class TestMiddlewareWithProtocolOnlyAgent:
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
from agent_framework import (
|
||||
Agent,
|
||||
@@ -296,50 +297,77 @@ class TestChatMiddleware:
|
||||
assert response3 is not None
|
||||
assert execution_count["count"] == 2 # Should be 2 now
|
||||
|
||||
async def test_chat_client_middleware_can_access_and_override_custom_kwargs(
|
||||
async def test_run_level_middleware_is_not_forwarded_to_inner_client(
|
||||
self, chat_client_base: "MockBaseChatClient"
|
||||
) -> None:
|
||||
"""Test that chat client middleware can access and override custom parameters like temperature."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
modified_kwargs: dict[str, Any] = {}
|
||||
"""Test that run-level middleware stays in the middleware pipeline only."""
|
||||
observed_context_kwargs: dict[str, Any] = {}
|
||||
|
||||
@chat_middleware
|
||||
async def inspecting_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
observed_context_kwargs.update(context.kwargs)
|
||||
await call_next()
|
||||
|
||||
async def fake_inner_get_response(**kwargs: Any) -> ChatResponse:
|
||||
assert "middleware" not in kwargs
|
||||
return ChatResponse(messages=[Message(role="assistant", text="ok")])
|
||||
|
||||
with patch.object(
|
||||
chat_client_base,
|
||||
"_inner_get_response",
|
||||
side_effect=fake_inner_get_response,
|
||||
) as mock_inner_get_response:
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")],
|
||||
client_kwargs={"middleware": [inspecting_middleware], "trace_id": "trace-123"},
|
||||
)
|
||||
|
||||
assert response.messages[0].text == "ok"
|
||||
assert observed_context_kwargs == {"trace_id": "trace-123"}
|
||||
mock_inner_get_response.assert_called_once()
|
||||
|
||||
async def test_chat_client_middleware_can_access_and_override_options(
|
||||
self, chat_client_base: "MockBaseChatClient"
|
||||
) -> None:
|
||||
"""Test that chat client middleware can access and override runtime options."""
|
||||
captured_options: dict[str, Any] = {}
|
||||
modified_options: dict[str, Any] = {}
|
||||
|
||||
@chat_middleware
|
||||
async def kwargs_middleware(context: ChatContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Capture the original kwargs
|
||||
captured_kwargs.update(context.kwargs)
|
||||
assert isinstance(context.options, dict)
|
||||
captured_options.update(context.options)
|
||||
|
||||
# Modify some kwargs
|
||||
context.kwargs["temperature"] = 0.9
|
||||
context.kwargs["max_tokens"] = 500
|
||||
context.kwargs["new_param"] = "added_by_middleware"
|
||||
context.options["temperature"] = 0.9
|
||||
context.options["max_tokens"] = 500
|
||||
context.options["new_param"] = "added_by_middleware"
|
||||
|
||||
# Store modified kwargs for verification
|
||||
modified_kwargs.update(context.kwargs)
|
||||
modified_options.update(context.options)
|
||||
|
||||
await call_next()
|
||||
|
||||
# Add middleware to chat client
|
||||
chat_client_base.chat_middleware = [kwargs_middleware]
|
||||
|
||||
# Execute chat client with custom parameters
|
||||
# Execute chat client with runtime options
|
||||
messages = [Message(role="user", text="test message")]
|
||||
response = await chat_client_base.get_response(
|
||||
messages, temperature=0.7, max_tokens=100, custom_param="test_value"
|
||||
messages,
|
||||
options={"temperature": 0.7, "max_tokens": 100, "custom_param": "test_value"},
|
||||
)
|
||||
|
||||
# Verify response
|
||||
assert response is not None
|
||||
assert len(response.messages) > 0
|
||||
|
||||
assert captured_kwargs["temperature"] == 0.7
|
||||
assert captured_kwargs["max_tokens"] == 100
|
||||
assert captured_kwargs["custom_param"] == "test_value"
|
||||
assert captured_options["temperature"] == 0.7
|
||||
assert captured_options["max_tokens"] == 100
|
||||
assert captured_options["custom_param"] == "test_value"
|
||||
|
||||
# Verify middleware could modify the kwargs
|
||||
assert modified_kwargs["temperature"] == 0.9
|
||||
assert modified_kwargs["max_tokens"] == 500
|
||||
assert modified_kwargs["new_param"] == "added_by_middleware"
|
||||
assert modified_kwargs["custom_param"] == "test_value" # Should still be there
|
||||
assert modified_options["temperature"] == 0.9
|
||||
assert modified_options["max_tokens"] == 500
|
||||
assert modified_options["new_param"] == "added_by_middleware"
|
||||
assert modified_options["custom_param"] == "test_value"
|
||||
|
||||
def test_chat_middleware_pipeline_cache_reuses_matching_middleware(
|
||||
self,
|
||||
|
||||
@@ -207,7 +207,7 @@ async def test_chat_client_observability(mock_chat_client, span_exporter: InMemo
|
||||
|
||||
messages = [Message(role="user", text="Test message")]
|
||||
span_exporter.clear()
|
||||
response = await client.get_response(messages=messages, model_id="Test")
|
||||
response = await client.get_response(messages=messages, options={"model_id": "Test"})
|
||||
assert response is not None
|
||||
spans = span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
@@ -232,7 +232,7 @@ async def test_chat_client_streaming_observability(
|
||||
span_exporter.clear()
|
||||
# Collect all yielded updates
|
||||
updates = []
|
||||
stream = client.get_response(stream=True, messages=messages, model_id="Test")
|
||||
stream = client.get_response(stream=True, messages=messages, options={"model_id": "Test"})
|
||||
async for update in stream:
|
||||
updates.append(update)
|
||||
await stream.get_final_response()
|
||||
@@ -1540,7 +1540,7 @@ async def test_chat_client_observability_exception(mock_chat_client, span_export
|
||||
|
||||
span_exporter.clear()
|
||||
with pytest.raises(ValueError, match="Test error"):
|
||||
await client.get_response(messages=messages, model_id="Test")
|
||||
await client.get_response(messages=messages, options={"model_id": "Test"})
|
||||
|
||||
spans = span_exporter.get_finished_spans()
|
||||
assert len(spans) == 1
|
||||
@@ -1570,7 +1570,7 @@ async def test_chat_client_streaming_observability_exception(mock_chat_client, s
|
||||
|
||||
span_exporter.clear()
|
||||
with pytest.raises(ValueError, match="Streaming error"):
|
||||
async for _ in client.get_response(messages=messages, stream=True, model_id="Test"):
|
||||
async for _ in client.get_response(messages=messages, stream=True, options={"model_id": "Test"}):
|
||||
pass
|
||||
|
||||
spans = span_exporter.get_finished_spans()
|
||||
@@ -2075,7 +2075,7 @@ async def test_capture_messages_with_finish_reason(mock_chat_client, span_export
|
||||
messages = [Message(role="user", text="Test")]
|
||||
|
||||
span_exporter.clear()
|
||||
response = await client.get_response(messages=messages, model_id="Test")
|
||||
response = await client.get_response(messages=messages, options={"model_id": "Test"})
|
||||
|
||||
assert response is not None
|
||||
assert response.finish_reason == "stop"
|
||||
@@ -2165,7 +2165,7 @@ async def test_chat_client_when_disabled(mock_chat_client, span_exporter: InMemo
|
||||
messages = [Message(role="user", text="Test")]
|
||||
|
||||
span_exporter.clear()
|
||||
response = await client.get_response(messages=messages, model_id="Test")
|
||||
response = await client.get_response(messages=messages, options={"model_id": "Test"})
|
||||
|
||||
assert response is not None
|
||||
spans = span_exporter.get_finished_spans()
|
||||
@@ -2181,7 +2181,7 @@ async def test_chat_client_streaming_when_disabled(mock_chat_client, span_export
|
||||
|
||||
span_exporter.clear()
|
||||
updates = []
|
||||
async for update in client.get_response(messages=messages, stream=True, model_id="Test"):
|
||||
async for update in client.get_response(messages=messages, stream=True, options={"model_id": "Test"}):
|
||||
updates.append(update)
|
||||
|
||||
assert len(updates) == 2 # Still works functionally
|
||||
@@ -2661,7 +2661,7 @@ async def test_capture_messages_preserves_non_ascii_characters(mock_chat_client,
|
||||
messages = [Message(role="user", text=japanese_text)]
|
||||
|
||||
span_exporter.clear()
|
||||
response = await client.get_response(messages=messages, model_id="Test")
|
||||
response = await client.get_response(messages=messages, options={"model_id": "Test"})
|
||||
|
||||
assert response is not None
|
||||
spans = span_exporter.get_finished_spans()
|
||||
|
||||
@@ -594,8 +594,8 @@ async def test_tool_invoke_telemetry_sensitive_disabled(span_exporter: InMemoryS
|
||||
assert attributes[OtelAttr.TOOL_CALL_ID] == "test_call_id"
|
||||
|
||||
|
||||
async def test_tool_invoke_ignores_additional_kwargs() -> None:
|
||||
"""Ensure tools drop unknown kwargs when invoked with validated arguments."""
|
||||
async def test_tool_invoke_rejects_unexpected_runtime_kwargs() -> None:
|
||||
"""Ensure invoke() requires runtime data to flow through FunctionInvocationContext."""
|
||||
|
||||
@tool
|
||||
async def simple_tool(message: str) -> str:
|
||||
@@ -604,15 +604,12 @@ async def test_tool_invoke_ignores_additional_kwargs() -> None:
|
||||
|
||||
args = simple_tool.input_model(message="hello world")
|
||||
|
||||
# These kwargs simulate runtime context passed through function invocation.
|
||||
result = await simple_tool.invoke(
|
||||
arguments=args,
|
||||
api_token="secret-token",
|
||||
options={"model_id": "dummy"},
|
||||
)
|
||||
|
||||
assert isinstance(result, list)
|
||||
assert result[0].text == "HELLO WORLD"
|
||||
with pytest.raises(TypeError, match="Unexpected keyword argument"):
|
||||
await simple_tool.invoke(
|
||||
arguments=args,
|
||||
api_token="secret-token",
|
||||
options={"model_id": "dummy"},
|
||||
)
|
||||
|
||||
|
||||
async def test_tool_invoke_telemetry_with_pydantic_args(span_exporter: InMemorySpanExporter):
|
||||
@@ -917,8 +914,8 @@ def test_parse_inputs_unsupported_type():
|
||||
# endregion
|
||||
|
||||
|
||||
async def test_ai_function_with_kwargs_injection():
|
||||
"""Test that ai_function correctly handles kwargs injection and hides them from schema."""
|
||||
async def test_ai_function_with_kwargs_rejects_runtime_invoke_kwargs():
|
||||
"""Test that runtime kwargs must be passed through FunctionInvocationContext."""
|
||||
|
||||
@tool
|
||||
def tool_with_kwargs(x: int, **kwargs: Any) -> str:
|
||||
@@ -937,13 +934,11 @@ async def test_ai_function_with_kwargs_injection():
|
||||
# Verify direct invocation works
|
||||
assert tool_with_kwargs(1, user_id="user1") == "x=1, user=user1"
|
||||
|
||||
# Verify invoke works with injected args
|
||||
result = await tool_with_kwargs.invoke(
|
||||
arguments=tool_with_kwargs.input_model(x=5),
|
||||
user_id="user2",
|
||||
)
|
||||
assert isinstance(result, list)
|
||||
assert result[0].text == "x=5, user=user2"
|
||||
with pytest.raises(TypeError, match="Unexpected keyword argument"):
|
||||
await tool_with_kwargs.invoke(
|
||||
arguments=tool_with_kwargs.input_model(x=5),
|
||||
user_id="user2",
|
||||
)
|
||||
|
||||
# Verify invoke works without injected args (uses default)
|
||||
result_default = await tool_with_kwargs.invoke(
|
||||
|
||||
@@ -446,7 +446,7 @@ async def executor_with_real_agent() -> tuple[AgentFrameworkExecutor, str, MockB
|
||||
name="Test Chat Agent",
|
||||
description="A real Agent for testing execution flow",
|
||||
client=mock_client,
|
||||
system_message="You are a helpful test assistant.",
|
||||
instructions="You are a helpful test assistant.",
|
||||
)
|
||||
|
||||
# Register the real agent
|
||||
@@ -478,14 +478,14 @@ async def sequential_workflow() -> tuple[AgentFrameworkExecutor, str, MockBaseCh
|
||||
name="Writer",
|
||||
description="Content writer agent",
|
||||
client=mock_client,
|
||||
system_message="You are a content writer. Create clear, engaging content.",
|
||||
instructions="You are a content writer. Create clear, engaging content.",
|
||||
)
|
||||
reviewer = Agent(
|
||||
id="reviewer",
|
||||
name="Reviewer",
|
||||
description="Content reviewer agent",
|
||||
client=mock_client,
|
||||
system_message="You are a reviewer. Provide constructive feedback.",
|
||||
instructions="You are a reviewer. Provide constructive feedback.",
|
||||
)
|
||||
|
||||
workflow = SequentialBuilder(participants=[writer, reviewer]).build()
|
||||
@@ -523,21 +523,21 @@ async def concurrent_workflow() -> tuple[AgentFrameworkExecutor, str, MockBaseCh
|
||||
name="Researcher",
|
||||
description="Research agent",
|
||||
client=mock_client,
|
||||
system_message="You are a researcher. Find key data and insights.",
|
||||
instructions="You are a researcher. Find key data and insights.",
|
||||
)
|
||||
analyst = Agent(
|
||||
id="analyst",
|
||||
name="Analyst",
|
||||
description="Analysis agent",
|
||||
client=mock_client,
|
||||
system_message="You are an analyst. Identify trends and patterns.",
|
||||
instructions="You are an analyst. Identify trends and patterns.",
|
||||
)
|
||||
summarizer = Agent(
|
||||
id="summarizer",
|
||||
name="Summarizer",
|
||||
description="Summary agent",
|
||||
client=mock_client,
|
||||
system_message="You are a summarizer. Provide concise summaries.",
|
||||
instructions="You are a summarizer. Provide concise summaries.",
|
||||
)
|
||||
|
||||
workflow = ConcurrentBuilder(participants=[researcher, analyst, summarizer]).build()
|
||||
|
||||
@@ -309,7 +309,7 @@ async def test_full_pipeline_workflow_events_are_json_serializable():
|
||||
name="Serialization Test Agent",
|
||||
description="Agent for testing serialization",
|
||||
client=mock_client,
|
||||
system_message="You are a test assistant.",
|
||||
instructions="You are a test assistant.",
|
||||
)
|
||||
|
||||
agent_executor = AgentExecutor(id="agent_node", agent=agent)
|
||||
|
||||
@@ -23,6 +23,7 @@ from ._callbacks import AgentCallbackContext, AgentResponseCallbackProtocol
|
||||
from ._durable_agent_state import (
|
||||
DurableAgentState,
|
||||
DurableAgentStateEntry,
|
||||
DurableAgentStateMessage,
|
||||
DurableAgentStateRequest,
|
||||
DurableAgentStateResponse,
|
||||
)
|
||||
@@ -151,10 +152,11 @@ class AgentEntity:
|
||||
|
||||
try:
|
||||
chat_messages: list[Message] = [
|
||||
m.to_chat_message()
|
||||
replayable_message
|
||||
for entry in self.state.data.conversation_history
|
||||
if not self._is_error_response(entry)
|
||||
for m in entry.messages
|
||||
if (replayable_message := self._to_replayable_message(m)) is not None
|
||||
]
|
||||
|
||||
run_kwargs: dict[str, Any] = {"messages": chat_messages, "options": options}
|
||||
@@ -190,6 +192,21 @@ class AgentEntity:
|
||||
|
||||
return error_response
|
||||
|
||||
@staticmethod
|
||||
def _to_replayable_message(message: DurableAgentStateMessage) -> Message | None:
|
||||
"""Convert persisted history into a message safe to replay into chat clients."""
|
||||
chat_message = message.to_chat_message()
|
||||
replayable_contents = [content for content in chat_message.contents if content.type != "reasoning"]
|
||||
if not replayable_contents:
|
||||
return None
|
||||
|
||||
return Message(
|
||||
role=chat_message.role,
|
||||
contents=replayable_contents,
|
||||
author_name=chat_message.author_name,
|
||||
additional_properties=chat_message.additional_properties,
|
||||
)
|
||||
|
||||
async def _invoke_agent(
|
||||
self,
|
||||
run_kwargs: dict[str, Any],
|
||||
|
||||
@@ -21,7 +21,9 @@ from agent_framework_durabletask import (
|
||||
DurableAgentStateData,
|
||||
DurableAgentStateMessage,
|
||||
DurableAgentStateRequest,
|
||||
DurableAgentStateResponse,
|
||||
DurableAgentStateTextContent,
|
||||
DurableAgentStateTextReasoningContent,
|
||||
RunRequest,
|
||||
)
|
||||
from agent_framework_durabletask._entities import DurableTaskEntityStateProvider
|
||||
@@ -391,6 +393,54 @@ class TestAgentEntityRunAgent:
|
||||
assert len(history) == 6
|
||||
assert entity.state.message_count == 6
|
||||
|
||||
async def test_run_filters_reasoning_content_from_replayed_history(self) -> None:
|
||||
"""Replayed durable history should not include reasoning-only content items."""
|
||||
captured_messages: list[Message] = []
|
||||
|
||||
async def mock_run(*args, stream=False, **kwargs):
|
||||
if stream:
|
||||
raise TypeError("streaming not supported")
|
||||
captured_messages.extend(kwargs["messages"])
|
||||
return _agent_response("Response")
|
||||
|
||||
mock_agent = Mock()
|
||||
mock_agent.run = mock_run
|
||||
|
||||
entity = _make_entity(mock_agent)
|
||||
entity.state.data = DurableAgentStateData(
|
||||
conversation_history=[
|
||||
DurableAgentStateRequest(
|
||||
correlation_id="corr-entity-prev-request",
|
||||
created_at=datetime.now(),
|
||||
messages=[
|
||||
DurableAgentStateMessage(
|
||||
role="user",
|
||||
contents=[DurableAgentStateTextContent(text="Hi")],
|
||||
)
|
||||
],
|
||||
),
|
||||
DurableAgentStateResponse(
|
||||
correlation_id="corr-entity-prev-response",
|
||||
created_at=datetime.now(),
|
||||
messages=[
|
||||
DurableAgentStateMessage(
|
||||
role="assistant",
|
||||
contents=[
|
||||
DurableAgentStateTextReasoningContent(text="Let me think."),
|
||||
DurableAgentStateTextContent(text="Hello there."),
|
||||
],
|
||||
)
|
||||
],
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
await entity.run({"message": "What next?", "correlationId": "corr-entity-replay"})
|
||||
|
||||
assert captured_messages
|
||||
assert all(content.type != "reasoning" for message in captured_messages for content in message.contents)
|
||||
assert [message.text for message in captured_messages] == ["Hi", "Hello there.", "What next?"]
|
||||
|
||||
|
||||
class TestAgentEntityReset:
|
||||
"""Test suite for the reset operation."""
|
||||
|
||||
@@ -27,6 +27,7 @@ from agent_framework import (
|
||||
RawAgent,
|
||||
load_settings,
|
||||
)
|
||||
from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from agent_framework.observability import AgentTelemetryLayer, ChatTelemetryLayer
|
||||
from agent_framework_openai._chat_client import OpenAIChatOptions, RawOpenAIChatClient
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
@@ -125,9 +126,13 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
|
||||
credential: AzureCredentialTypes | None = None,
|
||||
project_client: AIProjectClient | None = None,
|
||||
allow_preview: bool | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
instruction_role: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize a raw Foundry Agent client.
|
||||
|
||||
@@ -141,9 +146,13 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
|
||||
credential: Azure credential for authentication.
|
||||
project_client: An existing AIProjectClient to use.
|
||||
allow_preview: Enables preview opt-in on internally-created AIProjectClient.
|
||||
default_headers: Additional HTTP headers for requests made through the OpenAI client.
|
||||
env_file_path: Path to .env file for settings.
|
||||
env_file_encoding: Encoding for .env file.
|
||||
kwargs: Additional keyword arguments.
|
||||
instruction_role: The role to use for 'instruction' messages.
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
"""
|
||||
settings = load_settings(
|
||||
FoundryAgentSettings,
|
||||
@@ -189,7 +198,14 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
|
||||
# Get OpenAI client from project
|
||||
async_client = self.project_client.get_openai_client()
|
||||
|
||||
super().__init__(async_client=async_client, **kwargs)
|
||||
super().__init__(
|
||||
async_client=async_client,
|
||||
default_headers=default_headers,
|
||||
instruction_role=instruction_role,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
|
||||
def _get_agent_reference(self) -> dict[str, str]:
|
||||
"""Build the agent reference dict for the Responses API."""
|
||||
@@ -210,7 +226,10 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
|
||||
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
**kwargs: Any,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: Mapping[str, Any] | None = None,
|
||||
) -> Agent[FoundryAgentOptionsT]:
|
||||
"""Create a FoundryAgent that reuses this client's Foundry configuration."""
|
||||
function_tools = cast(
|
||||
@@ -233,7 +252,10 @@ class RawFoundryAgentChatClient( # type: ignore[misc]
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
default_options=default_options,
|
||||
**kwargs,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=additional_properties,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -365,11 +387,15 @@ class _FoundryAgentChatClient( # type: ignore[misc]
|
||||
credential: AzureCredentialTypes | None = None,
|
||||
project_client: AIProjectClient | None = None,
|
||||
allow_preview: bool | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
instruction_role: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
middleware: (Sequence[ChatAndFunctionMiddlewareTypes] | None) = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a Foundry Agent client with full middleware support.
|
||||
|
||||
@@ -380,11 +406,15 @@ class _FoundryAgentChatClient( # type: ignore[misc]
|
||||
credential: Azure credential for authentication.
|
||||
project_client: An existing AIProjectClient to use.
|
||||
allow_preview: Enables preview opt-in on internally-created AIProjectClient.
|
||||
default_headers: Additional HTTP headers for requests made through the OpenAI client.
|
||||
env_file_path: Path to .env file for settings.
|
||||
env_file_encoding: Encoding for .env file.
|
||||
instruction_role: The role to use for 'instruction' messages.
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
middleware: Optional sequence of middleware.
|
||||
function_invocation_configuration: Optional function invocation configuration.
|
||||
kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(
|
||||
project_endpoint=project_endpoint,
|
||||
@@ -393,11 +423,15 @@ class _FoundryAgentChatClient( # type: ignore[misc]
|
||||
credential=credential,
|
||||
project_client=project_client,
|
||||
allow_preview=allow_preview,
|
||||
default_headers=default_headers,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
instruction_role=instruction_role,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=additional_properties,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -435,10 +469,19 @@ class RawFoundryAgent( # type: ignore[misc]
|
||||
allow_preview: bool | None = None,
|
||||
tools: FunctionTool | Callable[..., Any] | Sequence[FunctionTool | Callable[..., Any]] | None = None,
|
||||
context_providers: Sequence[BaseContextProvider] | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
client_type: type[RawFoundryAgentChatClient] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
instructions: str | None = None,
|
||||
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize a Foundry Agent.
|
||||
|
||||
@@ -454,11 +497,20 @@ class RawFoundryAgent( # type: ignore[misc]
|
||||
allow_preview: Enables preview opt-in on internally-created AIProjectClient.
|
||||
tools: Function tools to provide to the agent. Only ``FunctionTool`` objects are accepted.
|
||||
context_providers: Optional context providers for injecting dynamic context.
|
||||
middleware: Optional agent-level middleware.
|
||||
client_type: Custom client class to use (must be a subclass of ``RawFoundryAgentChatClient``).
|
||||
Defaults to ``_FoundryAgentChatClient`` (full client middleware).
|
||||
env_file_path: Path to .env file for settings.
|
||||
env_file_encoding: Encoding for .env file.
|
||||
kwargs: Additional keyword arguments passed to the Agent base class.
|
||||
id: Optional local agent identifier.
|
||||
name: Optional display name for the local agent wrapper.
|
||||
description: Optional local description for the local agent wrapper.
|
||||
instructions: Optional instructions for the local agent wrapper.
|
||||
default_options: Default chat options for the local agent wrapper.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
compaction_strategy: Optional agent-level in-run compaction override.
|
||||
tokenizer: Optional agent-level tokenizer override.
|
||||
additional_properties: Additional properties stored on the local agent wrapper.
|
||||
"""
|
||||
# Create the client
|
||||
actual_client_type = client_type or _FoundryAgentChatClient
|
||||
@@ -467,22 +519,38 @@ class RawFoundryAgent( # type: ignore[misc]
|
||||
f"client_type must be a subclass of RawFoundryAgentChatClient, got {actual_client_type.__name__}"
|
||||
)
|
||||
|
||||
client = actual_client_type(
|
||||
project_endpoint=project_endpoint,
|
||||
agent_name=agent_name,
|
||||
agent_version=agent_version,
|
||||
credential=credential,
|
||||
project_client=project_client,
|
||||
allow_preview=allow_preview,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
)
|
||||
client_kwargs: dict[str, Any] = {
|
||||
"project_endpoint": project_endpoint,
|
||||
"agent_name": agent_name,
|
||||
"agent_version": agent_version,
|
||||
"credential": credential,
|
||||
"project_client": project_client,
|
||||
"allow_preview": allow_preview,
|
||||
"env_file_path": env_file_path,
|
||||
"env_file_encoding": env_file_encoding,
|
||||
}
|
||||
if function_invocation_configuration is not None:
|
||||
if not issubclass(actual_client_type, FunctionInvocationLayer):
|
||||
raise TypeError(
|
||||
"function_invocation_configuration requires a FunctionInvocationLayer-based client_type."
|
||||
)
|
||||
client_kwargs["function_invocation_configuration"] = function_invocation_configuration
|
||||
|
||||
client = actual_client_type(**client_kwargs)
|
||||
|
||||
super().__init__(
|
||||
client=client, # type: ignore[arg-type]
|
||||
instructions=instructions,
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
tools=tools, # type: ignore[arg-type]
|
||||
default_options=cast(FoundryAgentOptionsT | None, default_options),
|
||||
context_providers=context_providers,
|
||||
**kwargs,
|
||||
middleware=middleware,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=dict(additional_properties) if additional_properties is not None else None,
|
||||
)
|
||||
|
||||
async def configure_azure_monitor(
|
||||
@@ -598,7 +666,15 @@ class FoundryAgent( # type: ignore[misc]
|
||||
client_type: type[RawFoundryAgentChatClient] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
instructions: str | None = None,
|
||||
default_options: FoundryAgentOptionsT | Mapping[str, Any] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: Mapping[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize a Foundry Agent with full middleware and telemetry.
|
||||
|
||||
@@ -615,7 +691,15 @@ class FoundryAgent( # type: ignore[misc]
|
||||
client_type: Custom client class (must subclass ``RawFoundryAgentChatClient``).
|
||||
env_file_path: Path to .env file for settings.
|
||||
env_file_encoding: Encoding for .env file.
|
||||
kwargs: Additional keyword arguments.
|
||||
id: Optional local agent identifier.
|
||||
name: Optional display name for the local agent wrapper.
|
||||
description: Optional local description for the local agent wrapper.
|
||||
instructions: Optional instructions for the local agent wrapper.
|
||||
default_options: Default chat options for the local agent wrapper.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
compaction_strategy: Optional agent-level in-run compaction override.
|
||||
tokenizer: Optional agent-level tokenizer override.
|
||||
additional_properties: Additional properties stored on the local agent wrapper.
|
||||
"""
|
||||
super().__init__(
|
||||
project_endpoint=project_endpoint,
|
||||
@@ -630,5 +714,13 @@ class FoundryAgent( # type: ignore[misc]
|
||||
client_type=client_type,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
**kwargs,
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
default_options=default_options,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
|
||||
@@ -4,7 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, Sequence
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal
|
||||
|
||||
from agent_framework import (
|
||||
@@ -15,6 +15,7 @@ from agent_framework import (
|
||||
FunctionInvocationLayer,
|
||||
load_settings,
|
||||
)
|
||||
from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
from agent_framework_openai._chat_client import OpenAIChatOptions, RawOpenAIChatClient
|
||||
from azure.ai.projects.aio import AIProjectClient
|
||||
@@ -132,10 +133,13 @@ class RawFoundryChatClient( # type: ignore[misc]
|
||||
model: str | None = None,
|
||||
credential: AzureCredentialTypes | AzureTokenProvider | None = None,
|
||||
allow_preview: bool | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
instruction_role: str | None = None,
|
||||
**kwargs: Any,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
) -> None:
|
||||
"""Initialize a raw Microsoft Foundry chat client.
|
||||
|
||||
@@ -149,10 +153,13 @@ class RawFoundryChatClient( # type: ignore[misc]
|
||||
credential: Azure credential or token provider for authentication.
|
||||
Required when using ``project_endpoint`` without a ``project_client``.
|
||||
allow_preview: Enables preview opt-in on internally-created AIProjectClient.
|
||||
default_headers: Additional HTTP headers for requests made through the OpenAI client.
|
||||
env_file_path: Path to .env file for settings.
|
||||
env_file_encoding: Encoding for .env file.
|
||||
instruction_role: The role to use for 'instruction' messages.
|
||||
kwargs: Additional keyword arguments.
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
"""
|
||||
foundry_settings = load_settings(
|
||||
FoundrySettings,
|
||||
@@ -195,8 +202,11 @@ class RawFoundryChatClient( # type: ignore[misc]
|
||||
super().__init__(
|
||||
model=resolved_model,
|
||||
async_client=project_client.get_openai_client(),
|
||||
default_headers=default_headers,
|
||||
instruction_role=instruction_role,
|
||||
**kwargs,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
self.project_client = project_client
|
||||
|
||||
@@ -516,12 +526,15 @@ class FoundryChatClient( # type: ignore[misc]
|
||||
model: str | None = None,
|
||||
credential: AzureCredentialTypes | AzureTokenProvider | None = None,
|
||||
allow_preview: bool | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
instruction_role: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
middleware: (Sequence[ChatAndFunctionMiddlewareTypes] | None) = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a Foundry chat client.
|
||||
|
||||
@@ -533,12 +546,15 @@ class FoundryChatClient( # type: ignore[misc]
|
||||
Can also be set via environment variable ``FOUNDRY_MODEL``.
|
||||
credential: Azure credential or token provider for authentication.
|
||||
allow_preview: Enables preview opt-in on internally-created AIProjectClient.
|
||||
default_headers: Additional HTTP headers for requests made through the OpenAI client.
|
||||
env_file_path: Path to .env file for settings.
|
||||
env_file_encoding: Encoding for .env file.
|
||||
instruction_role: The role to use for 'instruction' messages.
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
middleware: Optional sequence of middleware.
|
||||
function_invocation_configuration: Optional function invocation configuration.
|
||||
kwargs: Additional keyword arguments.
|
||||
"""
|
||||
super().__init__(
|
||||
project_endpoint=project_endpoint,
|
||||
@@ -546,10 +562,13 @@ class FoundryChatClient( # type: ignore[misc]
|
||||
model=model,
|
||||
credential=credential,
|
||||
allow_preview=allow_preview,
|
||||
default_headers=default_headers,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
instruction_role=instruction_role,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=additional_properties,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
from typing import Any
|
||||
@@ -68,6 +69,17 @@ def test_raw_foundry_agent_chat_client_init_with_agent_name() -> None:
|
||||
assert client.agent_version == "1.0"
|
||||
|
||||
|
||||
def test_raw_foundry_agent_chat_client_init_uses_explicit_parameters() -> None:
|
||||
signature = inspect.signature(RawFoundryAgentChatClient.__init__)
|
||||
|
||||
assert "default_headers" in signature.parameters
|
||||
assert "instruction_role" in signature.parameters
|
||||
assert "compaction_strategy" in signature.parameters
|
||||
assert "tokenizer" in signature.parameters
|
||||
assert "additional_properties" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_raw_foundry_agent_chat_client_get_agent_reference_with_version() -> None:
|
||||
"""Test agent reference includes version when provided."""
|
||||
|
||||
@@ -129,6 +141,15 @@ def test_raw_foundry_agent_chat_client_as_agent_preserves_client_type() -> None:
|
||||
assert named_agent.client.agent_name == "test-agent"
|
||||
|
||||
|
||||
def test_raw_foundry_agent_chat_client_as_agent_uses_explicit_parameters() -> None:
|
||||
signature = inspect.signature(RawFoundryAgentChatClient.as_agent)
|
||||
|
||||
assert "compaction_strategy" in signature.parameters
|
||||
assert "tokenizer" in signature.parameters
|
||||
assert "additional_properties" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
async def test_raw_foundry_agent_chat_client_prepare_options_validates_tools() -> None:
|
||||
"""Test that _prepare_options rejects non-FunctionTool objects."""
|
||||
|
||||
@@ -210,6 +231,17 @@ def test_foundry_agent_chat_client_init() -> None:
|
||||
assert client.agent_name == "test-agent"
|
||||
|
||||
|
||||
def test_foundry_agent_chat_client_init_uses_explicit_parameters() -> None:
|
||||
signature = inspect.signature(_FoundryAgentChatClient.__init__)
|
||||
|
||||
assert "default_headers" in signature.parameters
|
||||
assert "instruction_role" in signature.parameters
|
||||
assert "compaction_strategy" in signature.parameters
|
||||
assert "tokenizer" in signature.parameters
|
||||
assert "additional_properties" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_raw_foundry_agent_init_creates_client() -> None:
|
||||
"""Test that RawFoundryAgent creates a client internally."""
|
||||
|
||||
@@ -241,6 +273,28 @@ def test_raw_foundry_agent_init_with_custom_client_type() -> None:
|
||||
assert isinstance(agent.client, RawFoundryAgentChatClient)
|
||||
|
||||
|
||||
def test_raw_foundry_agent_init_uses_explicit_parameters() -> None:
|
||||
signature = inspect.signature(RawFoundryAgent.__init__)
|
||||
|
||||
assert "instructions" in signature.parameters
|
||||
assert "default_options" in signature.parameters
|
||||
assert "compaction_strategy" in signature.parameters
|
||||
assert "tokenizer" in signature.parameters
|
||||
assert "additional_properties" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_foundry_agent_init_uses_explicit_parameters() -> None:
|
||||
signature = inspect.signature(FoundryAgent.__init__)
|
||||
|
||||
assert "instructions" in signature.parameters
|
||||
assert "default_options" in signature.parameters
|
||||
assert "compaction_strategy" in signature.parameters
|
||||
assert "tokenizer" in signature.parameters
|
||||
assert "additional_properties" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_raw_foundry_agent_init_rejects_invalid_client_type() -> None:
|
||||
"""Test that invalid client_type raises TypeError."""
|
||||
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
@@ -140,6 +141,26 @@ def test_init() -> None:
|
||||
assert client.project_client is mock_project_client
|
||||
|
||||
|
||||
def test_raw_foundry_chat_client_init_uses_explicit_parameters() -> None:
|
||||
signature = inspect.signature(RawFoundryChatClient.__init__)
|
||||
|
||||
assert "default_headers" in signature.parameters
|
||||
assert "compaction_strategy" in signature.parameters
|
||||
assert "tokenizer" in signature.parameters
|
||||
assert "additional_properties" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_foundry_chat_client_init_uses_explicit_parameters() -> None:
|
||||
signature = inspect.signature(FoundryChatClient.__init__)
|
||||
|
||||
assert "default_headers" in signature.parameters
|
||||
assert "compaction_strategy" in signature.parameters
|
||||
assert "tokenizer" in signature.parameters
|
||||
assert "additional_properties" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_init_with_default_header() -> None:
|
||||
default_headers = {"X-Unit-Test": "test-guid"}
|
||||
mock_openai_client = _make_mock_openai_client()
|
||||
|
||||
+87
-9
@@ -3,15 +3,21 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import sys
|
||||
from collections.abc import Sequence
|
||||
from typing import Any, Generic
|
||||
from collections.abc import Awaitable, Callable, Mapping, Sequence
|
||||
from typing import Any, Generic, Literal, cast, overload
|
||||
|
||||
from agent_framework import (
|
||||
ChatAndFunctionMiddlewareTypes,
|
||||
ChatMiddlewareLayer,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
CompactionStrategy,
|
||||
FunctionInvocationConfiguration,
|
||||
FunctionInvocationLayer,
|
||||
Message,
|
||||
ResponseStream,
|
||||
TokenizerProtocol,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.observability import ChatTelemetryLayer
|
||||
@@ -122,8 +128,8 @@ class FoundryLocalSettings(TypedDict, total=False):
|
||||
'FOUNDRY_LOCAL_'.
|
||||
|
||||
Keys:
|
||||
model_id: The name of the model deployment to use.
|
||||
(Env var FOUNDRY_LOCAL_MODEL_ID)
|
||||
model: The name of the model deployment to use.
|
||||
(Env var FOUNDRY_LOCAL_MODEL)
|
||||
"""
|
||||
|
||||
model: str | None
|
||||
@@ -138,6 +144,78 @@ class FoundryLocalClient(
|
||||
):
|
||||
"""Foundry Local Chat completion class with middleware, telemetry, and function invocation support."""
|
||||
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
) -> Awaitable[ChatResponse[ResponseModelT]]: ...
|
||||
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: FoundryLocalChatOptionsT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: FoundryLocalChatOptionsT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: FoundryLocalChatOptionsT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Get a response from the Foundry Local chat client with all standard layers enabled."""
|
||||
super_get_response = cast(
|
||||
"Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]",
|
||||
super().get_response,
|
||||
)
|
||||
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
if middleware is not None:
|
||||
effective_client_kwargs["middleware"] = middleware
|
||||
return super_get_response(
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=effective_client_kwargs,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: str | None = None,
|
||||
@@ -182,7 +260,7 @@ class FoundryLocalClient(
|
||||
# Create a FoundryLocalClient with a specific model ID:
|
||||
from agent_framework.foundry import FoundryLocalClient
|
||||
|
||||
client = FoundryLocalClient(model_id="phi-4-mini")
|
||||
client = FoundryLocalClient(model="phi-4-mini")
|
||||
|
||||
agent = client.as_agent(
|
||||
name="LocalAgent",
|
||||
@@ -192,7 +270,7 @@ class FoundryLocalClient(
|
||||
response = await agent.run("What's the weather like in Seattle?")
|
||||
|
||||
# Or you can set the model id in the environment:
|
||||
os.environ["FOUNDRY_LOCAL_MODEL_ID"] = "phi-4-mini"
|
||||
os.environ["FOUNDRY_LOCAL_MODEL"] = "phi-4-mini"
|
||||
client = FoundryLocalClient()
|
||||
|
||||
# A FoundryLocalManager is created and if set, the service is started.
|
||||
@@ -205,12 +283,12 @@ class FoundryLocalClient(
|
||||
from foundry_local.models import DeviceType
|
||||
|
||||
client = FoundryLocalClient(
|
||||
model_id="phi-4-mini",
|
||||
model="phi-4-mini",
|
||||
device=DeviceType.GPU,
|
||||
)
|
||||
# and choosing if the model should be prepared on initialization:
|
||||
client = FoundryLocalClient(
|
||||
model_id="phi-4-mini",
|
||||
model="phi-4-mini",
|
||||
prepare_model=False,
|
||||
)
|
||||
# Beware, in this case the first request to generate a completion
|
||||
@@ -230,7 +308,7 @@ class FoundryLocalClient(
|
||||
class MyOptions(FoundryLocalChatOptions, total=False):
|
||||
my_custom_option: str
|
||||
|
||||
client: FoundryLocalClient[MyOptions] = FoundryLocalClient(model_id="phi-4-mini")
|
||||
client: FoundryLocalClient[MyOptions] = FoundryLocalClient(model="phi-4-mini")
|
||||
response = await client.get_response("Hello", options={"my_custom_option": "value"})
|
||||
|
||||
Raises:
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import inspect
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
import pytest
|
||||
@@ -66,6 +67,15 @@ def test_foundry_local_client_init(mock_foundry_local_manager: MagicMock) -> Non
|
||||
assert isinstance(client, SupportsChatGetResponse)
|
||||
|
||||
|
||||
def test_foundry_local_client_get_response_uses_explicit_runtime_buckets() -> None:
|
||||
"""Foundry Local should expose explicit runtime buckets instead of raw kwargs."""
|
||||
signature = inspect.signature(FoundryLocalClient.get_response)
|
||||
|
||||
assert "client_kwargs" in signature.parameters
|
||||
assert "function_invocation_kwargs" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_foundry_local_client_init_with_bootstrap_false(mock_foundry_local_manager: MagicMock) -> None:
|
||||
"""Test FoundryLocalClient initialization with bootstrap=False."""
|
||||
with patch(
|
||||
|
||||
@@ -211,7 +211,7 @@ class TaskRunner:
|
||||
client=assistant_chat_client,
|
||||
instructions=assistant_system_prompt,
|
||||
tools=tools,
|
||||
temperature=self.assistant_sampling_temperature,
|
||||
default_options={"temperature": self.assistant_sampling_temperature},
|
||||
context_providers=[
|
||||
SlidingWindowHistoryProvider(
|
||||
system_message=assistant_system_prompt,
|
||||
@@ -246,7 +246,7 @@ class TaskRunner:
|
||||
return Agent(
|
||||
client=user_simuator_chat_client,
|
||||
instructions=user_sim_system_prompt,
|
||||
temperature=0.0,
|
||||
default_options={"temperature": 0.0},
|
||||
# No sliding window for user simulator to maintain full conversation context
|
||||
# TODO(yuge): Consider adding user tools in future for more realistic scenarios
|
||||
)
|
||||
|
||||
@@ -17,7 +17,7 @@ else:
|
||||
from ._assistant_provider import OpenAIAssistantProvider
|
||||
from ._assistants_client import (
|
||||
AssistantToolResources,
|
||||
OpenAIAssistantsClient,
|
||||
OpenAIAssistantsClient, # type: ignore[reportDeprecated]
|
||||
OpenAIAssistantsOptions,
|
||||
)
|
||||
from ._chat_client import (
|
||||
|
||||
@@ -15,7 +15,7 @@ from openai import AsyncOpenAI
|
||||
from openai.types.beta.assistant import Assistant
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._assistants_client import OpenAIAssistantsClient
|
||||
from ._assistants_client import OpenAIAssistantsClient # type: ignore[reportDeprecated]
|
||||
from ._shared import OpenAISettings, from_assistant_tools, to_assistant_tools
|
||||
|
||||
if TYPE_CHECKING:
|
||||
@@ -538,7 +538,7 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
|
||||
A configured Agent instance.
|
||||
"""
|
||||
# Create the chat client with the assistant
|
||||
client = OpenAIAssistantsClient(
|
||||
client = OpenAIAssistantsClient( # type: ignore[reportDeprecated]
|
||||
model=assistant.model,
|
||||
assistant_id=assistant.id,
|
||||
assistant_name=assistant.name,
|
||||
|
||||
@@ -70,6 +70,11 @@ if sys.version_info >= (3, 12):
|
||||
else:
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from warnings import deprecated # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import deprecated # type: ignore # pragma: no cover
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self, TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
@@ -208,6 +213,7 @@ OpenAIAssistantsOptionsT = TypeVar(
|
||||
# endregion
|
||||
|
||||
|
||||
@deprecated("OpenAIAssistantsClient is deprecated. Use OpenAIChatClient instead.")
|
||||
class OpenAIAssistantsClient( # type: ignore[misc]
|
||||
OpenAIConfigMixin,
|
||||
FunctionInvocationLayer[OpenAIAssistantsOptionsT],
|
||||
@@ -216,7 +222,11 @@ class OpenAIAssistantsClient( # type: ignore[misc]
|
||||
BaseChatClient[OpenAIAssistantsOptionsT],
|
||||
Generic[OpenAIAssistantsOptionsT],
|
||||
):
|
||||
"""OpenAI Assistants client with middleware, telemetry, and function invocation support."""
|
||||
"""OpenAI Assistants client with middleware, telemetry, and function invocation support.
|
||||
|
||||
.. deprecated::
|
||||
OpenAIAssistantsClient is deprecated. Use :class:`OpenAIChatClient` instead.
|
||||
"""
|
||||
|
||||
# region Hosted Tool Factory Methods
|
||||
|
||||
|
||||
@@ -29,6 +29,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from agent_framework._clients import BaseChatClient
|
||||
from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from agent_framework._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
|
||||
from agent_framework._settings import SecretString
|
||||
from agent_framework._telemetry import USER_AGENT_KEY
|
||||
@@ -278,6 +279,9 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
@@ -295,6 +299,9 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
default_headers: Additional HTTP headers.
|
||||
async_client: Pre-configured OpenAI client.
|
||||
instruction_role: Role for instruction messages (for example ``"system"``).
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Optional ``.env`` file that is checked before the process environment
|
||||
for ``OPENAI_*`` values.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
@@ -314,6 +321,9 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncAzureOpenAI | AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
@@ -338,6 +348,9 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
|
||||
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
|
||||
instruction_role: Role for instruction messages (for example ``"system"``).
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Optional ``.env`` file that is checked before process environment
|
||||
variables for ``AZURE_OPENAI_*`` values.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
@@ -358,9 +371,11 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a raw OpenAI Chat client.
|
||||
|
||||
@@ -391,11 +406,13 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
|
||||
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
|
||||
instruction_role: Role for instruction messages (for example ``"system"``).
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Optional ``.env`` file that is checked before process environment
|
||||
variables. The same file is used for both ``OPENAI_*`` and ``AZURE_OPENAI_*``
|
||||
lookups.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
kwargs: Additional keyword arguments forwarded to ``BaseChatClient``.
|
||||
|
||||
Notes:
|
||||
Environment resolution and routing precedence are:
|
||||
@@ -452,7 +469,11 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
if use_azure_client:
|
||||
self.OTEL_PROVIDER_NAME = "azure.ai.openai" # type: ignore[misc]
|
||||
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
|
||||
# region Inner Methods
|
||||
|
||||
@@ -460,7 +481,6 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
options: Mapping[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> tuple[AsyncOpenAI, dict[str, Any], dict[str, Any]]:
|
||||
"""Validate options and prepare the request.
|
||||
|
||||
@@ -469,7 +489,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
"""
|
||||
client = self.client
|
||||
validated_options = await self._validate_options(options)
|
||||
run_options = await self._prepare_options(messages, validated_options, **kwargs)
|
||||
run_options = await self._prepare_options(messages, validated_options)
|
||||
return client, run_options, validated_options
|
||||
|
||||
def _handle_request_error(self, ex: Exception) -> NoReturn:
|
||||
@@ -526,7 +546,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
client,
|
||||
run_options,
|
||||
validated_options,
|
||||
) = await self._prepare_request(messages, options, **kwargs)
|
||||
) = await self._prepare_request(messages, options)
|
||||
try:
|
||||
if "text_format" in run_options:
|
||||
async with client.responses.stream(**run_options) as response:
|
||||
@@ -560,7 +580,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
except Exception as ex:
|
||||
self._handle_request_error(ex)
|
||||
return self._parse_response_from_openai(response, options=validated_options)
|
||||
client, run_options, validated_options = await self._prepare_request(messages, options, **kwargs)
|
||||
client, run_options, validated_options = await self._prepare_request(messages, options)
|
||||
try:
|
||||
if "text_format" in run_options:
|
||||
response = await client.responses.parse(stream=False, **run_options)
|
||||
@@ -1121,7 +1141,6 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
options: Mapping[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> dict[str, Any]:
|
||||
"""Take options dict and create the specific options for Responses API."""
|
||||
# Exclude keys that are not supported or handled separately
|
||||
@@ -1143,7 +1162,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
# messages
|
||||
# Handle instructions by prepending to messages as system message
|
||||
# Only prepend instructions for the first turn (when no conversation/response ID exists)
|
||||
conversation_id = self._get_current_conversation_id(options, **kwargs)
|
||||
conversation_id = options.get("conversation_id")
|
||||
if (instructions := options.get("instructions")) and not conversation_id:
|
||||
# First turn: prepend instructions as system message
|
||||
messages = prepend_instructions_to_messages(list(messages), instructions, role="system")
|
||||
@@ -1151,7 +1170,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
request_input = self._prepare_messages_for_openai(messages)
|
||||
if not request_input:
|
||||
raise ChatClientInvalidRequestException("Messages are required for chat completions")
|
||||
conversation_id = self._get_current_conversation_id(options, **kwargs)
|
||||
conversation_id = options.get("conversation_id")
|
||||
run_options["input"] = request_input
|
||||
|
||||
# model id
|
||||
@@ -1169,7 +1188,7 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
run_options[new_key] = run_options.pop(old_key)
|
||||
|
||||
# Handle different conversation ID formats
|
||||
if conversation_id := self._get_current_conversation_id(options, **kwargs):
|
||||
if conversation_id := options.get("conversation_id"):
|
||||
if conversation_id.startswith("resp_"):
|
||||
# For response IDs, set previous_response_id and remove conversation property
|
||||
run_options["previous_response_id"] = conversation_id
|
||||
@@ -1223,14 +1242,6 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
raise ValueError("model must be a non-empty string")
|
||||
options["model"] = self.model
|
||||
|
||||
def _get_current_conversation_id(self, options: Mapping[str, Any], **kwargs: Any) -> str | None:
|
||||
"""Get the current conversation ID, preferring kwargs over options.
|
||||
|
||||
This ensures runtime-updated conversation IDs (for example, from tool execution
|
||||
loops) take precedence over the initial configuration provided in options.
|
||||
"""
|
||||
return kwargs.get("conversation_id") or options.get("conversation_id")
|
||||
|
||||
def _prepare_messages_for_openai(self, chat_messages: Sequence[Message]) -> list[dict[str, Any]]:
|
||||
"""Prepare the chat messages for a request.
|
||||
|
||||
@@ -2490,10 +2501,13 @@ class OpenAIChatClient( # type: ignore[misc]
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize an OpenAI Responses client.
|
||||
|
||||
@@ -2509,11 +2523,14 @@ class OpenAIChatClient( # type: ignore[misc]
|
||||
default_headers: Additional HTTP headers.
|
||||
async_client: Pre-configured OpenAI client.
|
||||
instruction_role: Role for instruction messages (for example ``"system"``).
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
middleware: Optional middleware to apply to the client.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
additional_properties: Optional additional properties to include on all requests.
|
||||
env_file_path: Optional ``.env`` file that is checked before the process environment
|
||||
for ``OPENAI_*`` values.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
middleware: Optional middleware to apply to the client.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -2530,10 +2547,13 @@ class OpenAIChatClient( # type: ignore[misc]
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncAzureOpenAI | AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize an OpenAI Responses client.
|
||||
|
||||
@@ -2556,11 +2576,14 @@ class OpenAIChatClient( # type: ignore[misc]
|
||||
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
|
||||
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
|
||||
instruction_role: Role for instruction messages (for example ``"system"``).
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
middleware: Optional middleware to apply to the client.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
additional_properties: Optional additional properties to include on all requests.
|
||||
env_file_path: Optional ``.env`` file that is checked before process environment
|
||||
variables for ``AZURE_OPENAI_*`` values.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
middleware: Optional middleware to apply to the client.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
"""
|
||||
...
|
||||
|
||||
@@ -2577,11 +2600,13 @@ class OpenAIChatClient( # type: ignore[misc]
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
**kwargs: Any,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
"""Initialize an OpenAI Responses client.
|
||||
|
||||
@@ -2611,13 +2636,15 @@ class OpenAIChatClient( # type: ignore[misc]
|
||||
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
|
||||
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
|
||||
instruction_role: Role to use for instruction messages (for example ``"system"``).
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
middleware: Optional middleware to apply to the client.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Optional ``.env`` file that is checked before process environment
|
||||
variables. The same file is used for both ``OPENAI_*`` and ``AZURE_OPENAI_*``
|
||||
lookups.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
middleware: Optional middleware to apply to the client.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
kwargs: Other keyword parameters.
|
||||
|
||||
Notes:
|
||||
Environment resolution and routing precedence are:
|
||||
@@ -2675,7 +2702,9 @@ class OpenAIChatClient( # type: ignore[misc]
|
||||
env_file_encoding=env_file_encoding,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
**kwargs,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -18,6 +18,7 @@ from itertools import chain
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, Literal, cast, overload
|
||||
|
||||
from agent_framework._clients import BaseChatClient
|
||||
from agent_framework._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from agent_framework._docstrings import apply_layered_docstring
|
||||
from agent_framework._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
|
||||
from agent_framework._settings import SecretString
|
||||
@@ -193,6 +194,9 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
@@ -210,6 +214,9 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
default_headers: Additional HTTP headers.
|
||||
async_client: Pre-configured OpenAI client.
|
||||
instruction_role: Role for instruction messages (for example ``"system"``).
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Optional ``.env`` file that is checked before the process environment
|
||||
for ``OPENAI_*`` values.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
@@ -229,6 +236,9 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncAzureOpenAI | AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
@@ -253,6 +263,9 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
|
||||
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
|
||||
instruction_role: Role for instruction messages (for example ``"system"``).
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Optional ``.env`` file that is checked before process environment
|
||||
variables for ``AZURE_OPENAI_*`` values.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
@@ -273,9 +286,11 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncOpenAI | None = None,
|
||||
instruction_role: str | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a raw OpenAI Chat completion client.
|
||||
|
||||
@@ -306,11 +321,13 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
|
||||
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI and bypasses env lookup.
|
||||
instruction_role: Role for instruction messages (for example ``"system"``).
|
||||
compaction_strategy: Optional per-client compaction override.
|
||||
tokenizer: Optional tokenizer for compaction strategies.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Optional ``.env`` file that is checked before process environment
|
||||
variables. The same file is used for both ``OPENAI_*`` and ``AZURE_OPENAI_*``
|
||||
lookups.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
kwargs: Additional keyword arguments forwarded to ``BaseChatClient``.
|
||||
|
||||
Notes:
|
||||
Environment resolution and routing precedence are:
|
||||
@@ -366,7 +383,11 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
if use_azure_client:
|
||||
self.OTEL_PROVIDER_NAME = "azure.ai.openai" # type: ignore[misc]
|
||||
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
|
||||
# region Hosted Tool Factory Methods
|
||||
|
||||
@@ -427,7 +448,10 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
**kwargs: Any,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
@@ -437,7 +461,10 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OpenAIChatCompletionOptionsT | ChatOptions[None] | None = None,
|
||||
**kwargs: Any,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
@@ -447,7 +474,10 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OpenAIChatCompletionOptionsT | ChatOptions[Any] | None = None,
|
||||
**kwargs: Any,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@override
|
||||
@@ -457,7 +487,10 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OpenAIChatCompletionOptionsT | ChatOptions[Any] | None = None,
|
||||
**kwargs: Any,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Get a response from the raw OpenAI chat client."""
|
||||
super_get_response = cast(
|
||||
@@ -468,7 +501,10 @@ class RawOpenAIChatCompletionClient( # type: ignore[misc]
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options,
|
||||
**kwargs,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
|
||||
@override
|
||||
@@ -1205,10 +1241,11 @@ class OpenAIChatCompletionClient( # type: ignore[misc]
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
@@ -1218,10 +1255,11 @@ class OpenAIChatCompletionClient( # type: ignore[misc]
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OpenAIChatCompletionOptionsT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
@@ -1231,10 +1269,11 @@ class OpenAIChatCompletionClient( # type: ignore[misc]
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OpenAIChatCompletionOptionsT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@override
|
||||
@@ -1244,10 +1283,11 @@ class OpenAIChatCompletionClient( # type: ignore[misc]
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OpenAIChatCompletionOptionsT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Get a response from the OpenAI chat client with all standard layers enabled."""
|
||||
super_get_response = cast(
|
||||
@@ -1261,9 +1301,10 @@ class OpenAIChatCompletionClient( # type: ignore[misc]
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=effective_client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -79,6 +79,7 @@ class RawOpenAIEmbeddingClient(
|
||||
base_url: str | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncOpenAI | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
@@ -95,6 +96,7 @@ class RawOpenAIEmbeddingClient(
|
||||
``OPENAI_BASE_URL``.
|
||||
default_headers: Additional HTTP headers.
|
||||
async_client: Pre-configured OpenAI client.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Optional ``.env`` file that is checked before the process environment
|
||||
for ``OPENAI_*`` values.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
@@ -113,6 +115,7 @@ class RawOpenAIEmbeddingClient(
|
||||
base_url: str | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncAzureOpenAI | AsyncOpenAI | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
) -> None:
|
||||
@@ -136,6 +139,7 @@ class RawOpenAIEmbeddingClient(
|
||||
default_headers: Additional HTTP headers.
|
||||
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
|
||||
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Optional ``.env`` file that is checked before process environment
|
||||
variables for ``AZURE_OPENAI_*`` values.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
@@ -155,9 +159,9 @@ class RawOpenAIEmbeddingClient(
|
||||
api_version: str | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncAzureOpenAI | AsyncOpenAI | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a raw OpenAI embedding client.
|
||||
|
||||
@@ -187,11 +191,11 @@ class RawOpenAIEmbeddingClient(
|
||||
default_headers: Additional HTTP headers.
|
||||
async_client: Pre-configured client. Passing ``AsyncAzureOpenAI`` keeps the client on
|
||||
Azure; passing ``AsyncOpenAI`` keeps the client on OpenAI.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Optional ``.env`` file that is checked before process environment
|
||||
variables. The same file is used for both ``OPENAI_*`` and ``AZURE_OPENAI_*``
|
||||
lookups.
|
||||
env_file_encoding: Encoding for the ``.env`` file.
|
||||
kwargs: Additional keyword arguments forwarded to ``BaseEmbeddingClient``.
|
||||
|
||||
Notes:
|
||||
Environment resolution precedence is:
|
||||
@@ -247,7 +251,7 @@ class RawOpenAIEmbeddingClient(
|
||||
if use_azure_client:
|
||||
self.OTEL_PROVIDER_NAME = "azure.ai.openai" # type: ignore[misc]
|
||||
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(additional_properties=additional_properties)
|
||||
|
||||
def service_url(self) -> str:
|
||||
"""Get the URL of the service."""
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from typing import Annotated, Any
|
||||
@@ -11,6 +12,11 @@ from agent_framework import (
|
||||
Content,
|
||||
Message,
|
||||
SupportsChatGetResponse,
|
||||
SupportsCodeInterpreterTool,
|
||||
SupportsFileSearchTool,
|
||||
SupportsImageGenerationTool,
|
||||
SupportsMCPTool,
|
||||
SupportsWebSearchTool,
|
||||
tool,
|
||||
)
|
||||
from openai.types.beta.threads import (
|
||||
@@ -30,6 +36,8 @@ from pydantic import Field
|
||||
|
||||
from agent_framework_openai import OpenAIAssistantsClient
|
||||
|
||||
pytestmark = pytest.mark.filterwarnings("ignore:OpenAIAssistantsClient is deprecated\\..*:DeprecationWarning")
|
||||
|
||||
|
||||
def create_test_openai_assistants_client(
|
||||
mock_async_openai: MagicMock,
|
||||
@@ -104,6 +112,25 @@ def mock_async_openai() -> MagicMock:
|
||||
return mock_client
|
||||
|
||||
|
||||
def test_openai_assistants_client_is_deprecated(mock_async_openai: MagicMock) -> None:
|
||||
with pytest.warns(DeprecationWarning, match="OpenAIAssistantsClient is deprecated. Use OpenAIChatClient instead."):
|
||||
OpenAIAssistantsClient(model="gpt-4", api_key="test-api-key", async_client=mock_async_openai)
|
||||
|
||||
|
||||
def test_openai_assistants_client_init_keeps_var_keyword() -> None:
|
||||
signature = inspect.signature(OpenAIAssistantsClient.__init__)
|
||||
|
||||
assert any(parameter.kind == inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_openai_assistants_client_supports_code_interpreter_and_file_search() -> None:
|
||||
assert isinstance(OpenAIAssistantsClient, SupportsCodeInterpreterTool)
|
||||
assert not isinstance(OpenAIAssistantsClient, SupportsWebSearchTool)
|
||||
assert not isinstance(OpenAIAssistantsClient, SupportsImageGenerationTool)
|
||||
assert not isinstance(OpenAIAssistantsClient, SupportsMCPTool)
|
||||
assert isinstance(OpenAIAssistantsClient, SupportsFileSearchTool)
|
||||
|
||||
|
||||
def test_init_with_client(mock_async_openai: MagicMock) -> None:
|
||||
"""Test OpenAIAssistantsClient initialization with existing client."""
|
||||
client = create_test_openai_assistants_client(
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from datetime import datetime, timezone
|
||||
@@ -18,6 +19,11 @@ from agent_framework import (
|
||||
FunctionTool,
|
||||
Message,
|
||||
SupportsChatGetResponse,
|
||||
SupportsCodeInterpreterTool,
|
||||
SupportsFileSearchTool,
|
||||
SupportsImageGenerationTool,
|
||||
SupportsMCPTool,
|
||||
SupportsWebSearchTool,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._sessions import (
|
||||
@@ -48,7 +54,7 @@ from openai.types.responses.response_text_delta_event import ResponseTextDeltaEv
|
||||
from pydantic import BaseModel
|
||||
from pytest import param
|
||||
|
||||
from agent_framework_openai import OpenAIChatClient
|
||||
from agent_framework_openai import OpenAIChatClient, OpenAIResponsesClient
|
||||
from agent_framework_openai._chat_client import OPENAI_LOCAL_SHELL_CALL_ITEM_ID_KEY
|
||||
from agent_framework_openai._exceptions import OpenAIContentFilterException
|
||||
|
||||
@@ -110,6 +116,40 @@ def test_init(openai_unit_test_env: dict[str, str]) -> None:
|
||||
assert isinstance(openai_responses_client, SupportsChatGetResponse)
|
||||
|
||||
|
||||
def test_init_uses_explicit_parameters() -> None:
|
||||
signature = inspect.signature(OpenAIChatClient.__init__)
|
||||
|
||||
assert "additional_properties" in signature.parameters
|
||||
assert "compaction_strategy" in signature.parameters
|
||||
assert "tokenizer" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_deprecated_responses_client_supports_all_tool_protocols() -> None:
|
||||
assert isinstance(OpenAIResponsesClient, SupportsCodeInterpreterTool)
|
||||
assert isinstance(OpenAIResponsesClient, SupportsWebSearchTool)
|
||||
assert isinstance(OpenAIResponsesClient, SupportsImageGenerationTool)
|
||||
assert isinstance(OpenAIResponsesClient, SupportsMCPTool)
|
||||
assert isinstance(OpenAIResponsesClient, SupportsFileSearchTool)
|
||||
|
||||
|
||||
def test_protocol_isinstance_with_responses_client_instance() -> None:
|
||||
client = object.__new__(OpenAIResponsesClient)
|
||||
|
||||
assert isinstance(client, SupportsCodeInterpreterTool)
|
||||
assert isinstance(client, SupportsWebSearchTool)
|
||||
|
||||
|
||||
def test_deprecated_responses_client_tool_methods_return_dict() -> None:
|
||||
code_tool = OpenAIResponsesClient.get_code_interpreter_tool()
|
||||
assert isinstance(code_tool, dict)
|
||||
assert code_tool.get("type") == "code_interpreter"
|
||||
|
||||
web_tool = OpenAIResponsesClient.get_web_search_tool()
|
||||
assert isinstance(web_tool, dict)
|
||||
assert web_tool.get("type") == "web_search"
|
||||
|
||||
|
||||
def test_init_prefers_openai_responses_model(monkeypatch, openai_unit_test_env: dict[str, str]) -> None:
|
||||
monkeypatch.setenv("OPENAI_RESPONSES_MODEL", "test_responses_model_id")
|
||||
|
||||
@@ -3033,20 +3073,6 @@ async def test_prepare_options_store_parameter_handling() -> None:
|
||||
assert "previous_response_id" not in options
|
||||
|
||||
|
||||
async def test_conversation_id_precedence_kwargs_over_options() -> None:
|
||||
"""When both kwargs and options contain conversation_id, kwargs wins."""
|
||||
client = OpenAIChatClient(model="test-model", api_key="test-key")
|
||||
messages = [Message(role="user", text="Hello")]
|
||||
|
||||
# options has a stale response id, kwargs carries the freshest one
|
||||
opts = {"conversation_id": "resp_old_123"}
|
||||
run_opts = await client._prepare_options(messages, opts, conversation_id="resp_new_456") # type: ignore
|
||||
|
||||
# Verify kwargs takes precedence and maps to previous_response_id for resp_* IDs
|
||||
assert run_opts.get("previous_response_id") == "resp_new_456"
|
||||
assert "conversation" not in run_opts
|
||||
|
||||
|
||||
def _create_mock_responses_text_response(*, response_id: str) -> MagicMock:
|
||||
mock_response = MagicMock()
|
||||
mock_response.id = response_id
|
||||
|
||||
@@ -465,7 +465,7 @@ async def test_integration_client_agent_existing_session() -> None:
|
||||
first_response = await first_agent.run(
|
||||
"My hobby is photography. Remember this.",
|
||||
session=session,
|
||||
store=True,
|
||||
options={"store": True},
|
||||
)
|
||||
|
||||
assert isinstance(first_response, AgentResponse)
|
||||
@@ -476,7 +476,9 @@ async def test_integration_client_agent_existing_session() -> None:
|
||||
client=OpenAIChatClient(credential=credential),
|
||||
instructions="You are a helpful assistant with good memory.",
|
||||
) as second_agent:
|
||||
second_response = await second_agent.run("What is my hobby?", session=preserved_session)
|
||||
second_response = await second_agent.run(
|
||||
"What is my hobby?", session=preserved_session, options={"store": True}
|
||||
)
|
||||
|
||||
assert isinstance(second_response, AgentResponse)
|
||||
assert second_response.text is not None
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
from typing import Any
|
||||
@@ -11,6 +12,11 @@ from agent_framework import (
|
||||
Content,
|
||||
Message,
|
||||
SupportsChatGetResponse,
|
||||
SupportsCodeInterpreterTool,
|
||||
SupportsFileSearchTool,
|
||||
SupportsImageGenerationTool,
|
||||
SupportsMCPTool,
|
||||
SupportsWebSearchTool,
|
||||
tool,
|
||||
)
|
||||
from agent_framework.exceptions import ChatClientException, SettingNotFoundError
|
||||
@@ -20,7 +26,7 @@ from openai.types.chat.chat_completion_message import ChatCompletionMessage
|
||||
from pydantic import BaseModel
|
||||
from pytest import param
|
||||
|
||||
from agent_framework_openai import OpenAIChatCompletionClient
|
||||
from agent_framework_openai import OpenAIChatCompletionClient, RawOpenAIChatCompletionClient
|
||||
from agent_framework_openai._exceptions import OpenAIContentFilterException
|
||||
|
||||
skip_if_openai_integration_tests_disabled = pytest.mark.skipif(
|
||||
@@ -37,6 +43,41 @@ def test_init(openai_unit_test_env: dict[str, str]) -> None:
|
||||
assert isinstance(open_ai_chat_completion, SupportsChatGetResponse)
|
||||
|
||||
|
||||
def test_get_response_docstring_surfaces_layered_runtime_docs() -> None:
|
||||
docstring = inspect.getdoc(OpenAIChatCompletionClient.get_response)
|
||||
|
||||
assert docstring is not None
|
||||
assert "Get a response from a chat client." in docstring
|
||||
assert "function_invocation_kwargs" in docstring
|
||||
assert "middleware: Optional per-call chat and function middleware." in docstring
|
||||
assert "function_middleware: Optional per-call function middleware." not in docstring
|
||||
|
||||
|
||||
def test_get_response_is_defined_on_openai_class() -> None:
|
||||
signature = inspect.signature(OpenAIChatCompletionClient.get_response)
|
||||
|
||||
assert OpenAIChatCompletionClient.get_response.__qualname__ == "OpenAIChatCompletionClient.get_response"
|
||||
assert "middleware" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_init_uses_explicit_parameters() -> None:
|
||||
signature = inspect.signature(RawOpenAIChatCompletionClient.__init__)
|
||||
|
||||
assert "additional_properties" in signature.parameters
|
||||
assert "compaction_strategy" in signature.parameters
|
||||
assert "tokenizer" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_supports_web_search_only() -> None:
|
||||
assert not isinstance(OpenAIChatCompletionClient, SupportsCodeInterpreterTool)
|
||||
assert isinstance(OpenAIChatCompletionClient, SupportsWebSearchTool)
|
||||
assert not isinstance(OpenAIChatCompletionClient, SupportsImageGenerationTool)
|
||||
assert not isinstance(OpenAIChatCompletionClient, SupportsMCPTool)
|
||||
assert not isinstance(OpenAIChatCompletionClient, SupportsFileSearchTool)
|
||||
|
||||
|
||||
def test_init_prefers_openai_chat_model(monkeypatch, openai_unit_test_env: dict[str, str]) -> None:
|
||||
monkeypatch.setenv("OPENAI_CHAT_MODEL", "test_chat_model_id")
|
||||
|
||||
|
||||
@@ -138,7 +138,7 @@ async def test_cmc_structured_output_no_fcc(
|
||||
openai_chat_completion = OpenAIChatCompletionClient()
|
||||
await openai_chat_completion.get_response(
|
||||
messages=chat_history,
|
||||
response_format=Test,
|
||||
options={"response_format": Test},
|
||||
)
|
||||
mock_create.assert_awaited_once()
|
||||
|
||||
@@ -322,7 +322,7 @@ async def test_get_streaming_structured_output_no_fcc(
|
||||
async for msg in openai_chat_completion.get_response(
|
||||
stream=True,
|
||||
messages=chat_history,
|
||||
response_format=Test,
|
||||
options={"response_format": Test},
|
||||
):
|
||||
assert isinstance(msg, ChatResponseUpdate)
|
||||
mock_create.assert_awaited_once()
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import os
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
|
||||
@@ -15,6 +16,7 @@ from agent_framework_openai import (
|
||||
OpenAIEmbeddingClient,
|
||||
OpenAIEmbeddingOptions,
|
||||
)
|
||||
from agent_framework_openai._embedding_client import RawOpenAIEmbeddingClient
|
||||
|
||||
|
||||
def _make_openai_response(
|
||||
@@ -44,6 +46,13 @@ def test_openai_construction_with_explicit_params() -> None:
|
||||
assert client.model == "text-embedding-3-small"
|
||||
|
||||
|
||||
def test_raw_openai_embedding_client_init_uses_explicit_parameters() -> None:
|
||||
signature = inspect.signature(RawOpenAIEmbeddingClient.__init__)
|
||||
|
||||
assert "additional_properties" in signature.parameters
|
||||
assert all(parameter.kind != inspect.Parameter.VAR_KEYWORD for parameter in signature.parameters.values())
|
||||
|
||||
|
||||
def test_openai_construction_from_env(openai_unit_test_env: dict[str, str]) -> None:
|
||||
client = OpenAIEmbeddingClient()
|
||||
assert client.model == openai_unit_test_env["OPENAI_EMBEDDING_MODEL"]
|
||||
|
||||
@@ -1,21 +1,20 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Host multiple Foundry-powered agents inside a single Azure Functions app.
|
||||
"""Host multiple Azure OpenAI-powered agents inside a single Azure Functions app.
|
||||
|
||||
Components used in this sample:
|
||||
- FoundryChatClient to create agents bound to a shared Foundry deployment.
|
||||
- OpenAIChatCompletionClient configured for Azure OpenAI.
|
||||
- AgentFunctionApp to register multiple agents and expose dedicated HTTP endpoints.
|
||||
- Custom tool functions to demonstrate tool invocation from different agents.
|
||||
|
||||
Prerequisites: set `FOUNDRY_PROJECT_ENDPOINT`, `FOUNDRY_MODEL`, and sign in with Azure CLI before starting the Functions host."""
|
||||
Prerequisites: set `AZURE_OPENAI_ENDPOINT`, `AZURE_OPENAI_DEPLOYMENT_NAME`, and sign in with Azure CLI before starting the Functions host."""
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Agent, tool
|
||||
from agent_framework.azure import AgentFunctionApp
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework.openai import OpenAIChatCompletionClient
|
||||
from azure.identity.aio import AzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -60,9 +59,7 @@ def calculate_tip(bill_amount: float, tip_percentage: float = 15.0) -> dict[str,
|
||||
|
||||
|
||||
# 1. Create multiple agents, each with its own instruction set and tools.
|
||||
client = FoundryChatClient(
|
||||
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
|
||||
model=os.environ["FOUNDRY_MODEL"],
|
||||
client = OpenAIChatCompletionClient(
|
||||
credential=AzureCliCredential(),
|
||||
)
|
||||
|
||||
|
||||
@@ -8,7 +8,7 @@ each with their own specialized capabilities and tools.
|
||||
|
||||
Prerequisites:
|
||||
- The worker must be running with both agents registered
|
||||
- Set FOUNDRY_PROJECT_ENDPOINT and FOUNDRY_MODEL
|
||||
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_DEPLOYMENT_NAME when running the worker
|
||||
- Sign in with Azure CLI for AzureCliCredential authentication
|
||||
- Durable Task Scheduler must be running
|
||||
"""
|
||||
|
||||
@@ -5,7 +5,7 @@ This sample demonstrates running both the worker and client in a single process
|
||||
for multiple agents with different tools. The worker registers two agents
|
||||
(WeatherAgent and MathAgent), each with their own specialized capabilities.
|
||||
Prerequisites:
|
||||
- Set FOUNDRY_PROJECT_ENDPOINT and FOUNDRY_MODEL
|
||||
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_DEPLOYMENT_NAME
|
||||
- Sign in with Azure CLI for AzureCliCredential authentication
|
||||
- Durable Task Scheduler must be running (e.g., using Docker)
|
||||
To run this sample:
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
"""Worker process for hosting multiple agents with different tools using Durable Task.
|
||||
"""Worker process for hosting multiple Azure OpenAI agents with different tools using Durable Task.
|
||||
|
||||
This worker registers two agents - a weather assistant and a math assistant - each
|
||||
with their own specialized tools. This demonstrates how to host multiple agents
|
||||
with different capabilities in a single worker process.
|
||||
|
||||
Prerequisites:
|
||||
- Set FOUNDRY_PROJECT_ENDPOINT and FOUNDRY_MODEL
|
||||
- Set AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_DEPLOYMENT_NAME
|
||||
- Sign in with Azure CLI for AzureCliCredential authentication
|
||||
- Start a Durable Task Scheduler (e.g., using Docker)
|
||||
"""
|
||||
@@ -19,7 +19,7 @@ from typing import Any
|
||||
|
||||
from agent_framework import Agent, tool
|
||||
from agent_framework.azure import DurableAIAgentWorker
|
||||
from agent_framework.foundry import FoundryChatClient
|
||||
from agent_framework.openai import OpenAIChatCompletionClient
|
||||
from azure.identity import AzureCliCredential
|
||||
from azure.identity.aio import AzureCliCredential as AsyncAzureCliCredential
|
||||
from dotenv import load_dotenv
|
||||
@@ -73,13 +73,10 @@ def create_weather_agent():
|
||||
Returns:
|
||||
Agent: The configured Weather agent with weather tool
|
||||
"""
|
||||
_client = FoundryChatClient(
|
||||
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
|
||||
model=os.environ["FOUNDRY_MODEL"],
|
||||
credential=AsyncAzureCliCredential(),
|
||||
)
|
||||
return Agent(
|
||||
client=_client,
|
||||
client=OpenAIChatCompletionClient(
|
||||
credential=AsyncAzureCliCredential(),
|
||||
),
|
||||
name=WEATHER_AGENT_NAME,
|
||||
instructions="You are a helpful weather assistant. Provide current weather information.",
|
||||
tools=[get_weather],
|
||||
@@ -92,13 +89,10 @@ def create_math_agent():
|
||||
Returns:
|
||||
Agent: The configured Math agent with calculation tools
|
||||
"""
|
||||
_client = FoundryChatClient(
|
||||
project_endpoint=os.environ["FOUNDRY_PROJECT_ENDPOINT"],
|
||||
model=os.environ["FOUNDRY_MODEL"],
|
||||
credential=AsyncAzureCliCredential(),
|
||||
)
|
||||
return Agent(
|
||||
client=_client,
|
||||
client=OpenAIChatCompletionClient(
|
||||
credential=AsyncAzureCliCredential(),
|
||||
),
|
||||
name=MATH_AGENT_NAME,
|
||||
instructions="You are a helpful math assistant. Help users with calculations like tip calculations.",
|
||||
tools=[calculate_tip],
|
||||
|
||||
Reference in New Issue
Block a user