[BREAKING] Python: clean up kwargs across agents, chat clients, tools, and sessions (#4581)

* Python: clean up kwargs across agents, chat clients, tools, and sessions (#3642)

Audit and refactor public **kwargs usage across core agents, chat clients,
tools, sessions, and provider packages per the migration strategy codified
in CODING_STANDARD.md.

Key changes:
- Add explicit runtime buckets: function_invocation_kwargs and client_kwargs
  on RawAgent.run() and chat client get_response() layers.
- Refactor FunctionTool to prefer explicit ctx: FunctionInvocationContext
  injection; legacy **kwargs tools still work via _forward_runtime_kwargs.
- Refactor Agent.as_tool() to use direct JSON schema, always-streaming
  wrapper, approval_mode parameter, and UserInputRequiredException
  propagation (integrates PR #4568 behavior).
- Remove implicit session bleeding into FunctionInvocationContext; tools
  that need a session must receive it via function_invocation_kwargs.
- Lower chat-client layers after FunctionInvocationLayer accept only
  compatibility **kwargs (client_kwargs flattened, function_invocation_kwargs
  ignored).
- Add layered docstring composition from Raw... implementations via
  _docstrings.py helper.
- Clean up provider constructors to use explicit additional_properties.
- Deprecation warnings on legacy direct kwargs paths.
- Update samples, tests, and typing across all 23 packages.

Resolves #3642

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* clarified docstring

* feedback fixes

* Add unit tests for _docstrings.py build/apply helpers

Tests cover: no docstring source, no extra kwargs, appending to existing
Keyword Args section, inserting after Args, inserting in plain docstrings,
multiline descriptions, ordering, and apply_layered_docstring.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Add test for propagate_session TypeError on non-AgentSession values

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Add tests for multi-content and empty UserInputRequiredException propagation

Cover the branching logic in _try_execute_function_calls for:
- Multiple user_input_request items in a single exception (extra_user_input_contents path)
- Empty contents list (fallback function_result path)

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Add tests for DurableAIAgent.get_session forwarding service_session_id

Verifies get_session correctly forwards service_session_id and session_id
to the executor's get_new_session, replacing the removed kwargs test.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* Simplify ag-ui test stub to read session from client_kwargs only

Remove dual-mode detection (client_kwargs vs raw kwargs fallback) from
the test mock. Session is now read exclusively from client_kwargs,
matching the settled public calling convention.

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>

* updated create and get sessions in durable

* fixed docstrings

* fix test

* updated session handling

* updated from main

* updated tests

---------

Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
Eduard van Valkenburg
2026-03-13 09:58:32 +01:00
committed by GitHub
Unverified
parent b7990908fe
commit a4b9539b62
52 changed files with 2060 additions and 562 deletions
+5
View File
@@ -127,7 +127,12 @@ def create_agent(name: str, tool_mode: Literal['auto', 'required', 'none'] | Cha
Avoid `**kwargs` unless absolutely necessary. It should only be used as an escape route, not for well-known flows of data:
- **Prefer named parameters**: If there are known extra arguments being passed, use explicit named parameters instead of kwargs
- **Prefer purpose-specific buckets over generic kwargs**: If a flexible payload is still needed, use an explicit named parameter such as `additional_properties`, `function_invocation_kwargs`, or `client_kwargs` rather than a blanket `**kwargs`
- **Subclassing support**: kwargs is acceptable in methods that are part of classes designed for subclassing, allowing subclass-defined kwargs to pass through without issues. In this case, clearly document that kwargs exists for subclass extensibility and not for passing arbitrary data
- **Make known flows explicit first**: For abstract hooks, move known data flows into explicit parameters before leaving `**kwargs` behind for subclass extensibility (for example, prefer `state=` explicitly instead of passing it through kwargs)
- **Prefer explicit metadata containers**: For constructors that expose metadata, prefer an explicit `additional_properties` parameter.
- **Keep SDK passthroughs narrow and documented**: A kwargs escape hatch may be acceptable for provider helper APIs that pass through to a large or unstable external SDK surface, but it should be documented as SDK passthrough and revisited regularly
- **Do not keep passthrough kwargs on wrappers that do not use them**: Convenience wrappers and session helpers should not accept generic kwargs merely to forward or ignore them
- **Remove when possible**: In other cases, removing kwargs is likely better than keeping it
- **Separate kwargs by purpose**: When combining kwargs for multiple purposes, use specific parameters like `client_kwargs: dict[str, Any]` instead of mixing everything in `**kwargs`
- **Always document**: If kwargs must be used, always document how it's used, either by referencing external documentation or explaining its purpose
@@ -6,7 +6,7 @@ import base64
import json
import re
import uuid
from collections.abc import AsyncIterable, Awaitable, Sequence
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
from typing import Any, Final, Literal, TypeAlias, overload
import httpx
@@ -226,6 +226,8 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
continuation_token: A2AContinuationToken | None = None,
background: bool = False,
**kwargs: Any,
@@ -238,17 +240,21 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
*,
stream: Literal[True],
session: AgentSession | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
continuation_token: A2AContinuationToken | None = None,
background: bool = False,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
def run( # pyright: ignore[reportIncompatibleMethodOverride]
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
continuation_token: A2AContinuationToken | None = None,
background: bool = False,
**kwargs: Any,
@@ -261,17 +267,23 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
Keyword Args:
stream: Whether to stream the response. Defaults to False.
session: The conversation session associated with the message(s).
function_invocation_kwargs: Present for compatibility with the shared agent interface.
A2AAgent does not use these values directly.
client_kwargs: Present for compatibility with the shared agent interface.
A2AAgent does not use these values directly.
kwargs: Additional compatibility keyword arguments.
A2AAgent does not use these values directly.
continuation_token: Optional token to resume a long-running task
instead of starting a new one.
background: When True, in-progress task updates surface continuation
tokens so the caller can poll or resubscribe later. When False
(default), the agent internally waits for the task to complete.
kwargs: Additional keyword arguments.
Returns:
When stream=False: An Awaitable[AgentResponse].
When stream=True: A ResponseStream of AgentResponseUpdate items.
"""
del function_invocation_kwargs, client_kwargs, kwargs
if continuation_token is not None:
a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe(
TaskIdParams(id=continuation_token["task_id"])
@@ -220,7 +220,6 @@ class AGUIChatClient(
additional_properties: dict[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
**kwargs: Any,
) -> None:
"""Initialize the AG-UI chat client.
@@ -231,13 +230,11 @@ class AGUIChatClient(
additional_properties: Additional properties to store
middleware: Optional middleware to apply to the client.
function_invocation_configuration: Optional function invocation configuration override.
**kwargs: Additional arguments passed to BaseChatClient
"""
super().__init__(
additional_properties=additional_properties,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
**kwargs,
)
self._http_service = AGUIHttpService(
endpoint=endpoint,
@@ -98,7 +98,11 @@ class StreamingChatClientStub(
options: OptionsCoT | ChatOptions[Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
self.last_session = kwargs.get("session")
client_kwargs = kwargs.get("client_kwargs")
if isinstance(client_kwargs, Mapping):
self.last_session = cast(AgentSession | None, client_kwargs.get("session"))
else:
self.last_session = None
self.last_service_session_id = self.last_session.service_session_id if self.last_session else None
return cast(
Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
@@ -702,14 +702,9 @@ async def test_agent_with_use_service_session_is_true(streaming_chat_client_stub
"""Test that when use_service_session is True, the AgentSession used to run the agent is set to the service session ID."""
from agent_framework.ag_ui import AgentFrameworkAgent
request_service_session_id: str | None = None
async def stream_fn(
messages: MutableSequence[Message], chat_options: ChatOptions, **kwargs: Any
) -> AsyncIterator[ChatResponseUpdate]:
nonlocal request_service_session_id
session = kwargs.get("session")
request_service_session_id = session.service_session_id if session else None
yield ChatResponseUpdate(
contents=[Content.from_text(text="Response")], response_id="resp_67890", conversation_id="conv_12345"
)
@@ -719,11 +714,22 @@ async def test_agent_with_use_service_session_is_true(streaming_chat_client_stub
input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"}
# Spy on agent.run to capture the session kwarg at call time (before streaming mutates it)
captured_service_session_id: str | None = None
original_run = agent.run
def capturing_run(*args: Any, **kwargs: Any) -> Any:
nonlocal captured_service_session_id
session = kwargs.get("session")
captured_service_session_id = session.service_session_id if session else None
return original_run(*args, **kwargs)
agent.run = capturing_run # type: ignore[assignment, method-assign]
events: list[Any] = []
async for event in wrapper.run(input_data):
events.append(event)
request_service_session_id = agent.client.last_service_session_id
assert request_service_session_id == "conv_123456" # type: ignore[attr-defined] (service_session_id should be set)
assert captured_service_session_id == "conv_123456"
async def test_function_approval_mode_executes_tool(streaming_chat_client_stub):
@@ -228,11 +228,11 @@ class AnthropicClient(
model_id: str | None = None,
anthropic_client: AsyncAnthropic | None = None,
additional_beta_flags: list[str] | None = None,
additional_properties: dict[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize an Anthropic Agent client.
@@ -244,11 +244,11 @@ class AnthropicClient(
For instance if you need to set a different base_url for testing or private deployments.
additional_beta_flags: Additional beta flags to enable on the client.
Default flags are: "mcp-client-2025-04-04", "code-execution-2025-08-25".
additional_properties: Additional properties stored on the client instance.
middleware: Optional middleware to apply to the client.
function_invocation_configuration: Optional function invocation configuration override.
env_file_path: Path to environment file for loading settings.
env_file_encoding: Encoding of the environment file.
kwargs: Additional keyword arguments passed to the parent class.
Examples:
.. code-block:: python
@@ -319,9 +319,9 @@ class AnthropicClient(
# Initialize parent
super().__init__(
additional_properties=additional_properties,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
**kwargs,
)
# Initialize instance variables
@@ -17,10 +17,15 @@ from agent_framework_azure_ai_search._context_provider import AzureAISearchConte
@pytest.fixture(autouse=True)
def clear_azure_search_environment(monkeypatch: pytest.MonkeyPatch) -> None:
for key in tuple(os.environ):
if key.startswith("AZURE_SEARCH_"):
monkeypatch.delenv(key, raising=False)
def clear_azure_search_env(monkeypatch: pytest.MonkeyPatch) -> None:
"""Keep tests isolated from ambient Azure Search environment variables."""
for key in (
"AZURE_SEARCH_ENDPOINT",
"AZURE_SEARCH_INDEX_NAME",
"AZURE_SEARCH_KNOWLEDGE_BASE_NAME",
"AZURE_SEARCH_API_KEY",
):
monkeypatch.delenv(key, raising=False)
class MockSearchResults:
@@ -444,11 +444,11 @@ class AzureAIAgentClient(
model_deployment_name: str | None = None,
credential: AzureCredentialTypes | None = None,
should_cleanup_agent: bool = True,
additional_properties: dict[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize an Azure AI Agent client.
@@ -471,11 +471,11 @@ class AzureAIAgentClient(
should_cleanup_agent: Whether to cleanup (delete) agents created by this client when
the client is closed or context is exited. Defaults to True. Only affects agents
created by this client instance; existing agents passed via agent_id are never deleted.
additional_properties: Additional properties stored on the client instance.
middleware: Optional sequence of middlewares to include.
function_invocation_configuration: Optional function invocation configuration.
env_file_path: Path to environment file for loading settings.
env_file_encoding: Encoding of the environment file.
kwargs: Additional keyword arguments passed to the parent class.
Examples:
.. code-block:: python
@@ -548,9 +548,9 @@ class AzureAIAgentClient(
# Initialize parent
super().__init__(
additional_properties=additional_properties,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
**kwargs,
)
# Initialize instance variables
@@ -119,9 +119,9 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
credential: AzureCredentialTypes | None = None,
use_latest_version: bool | None = None,
allow_preview: bool | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a bare Azure AI client.
@@ -145,9 +145,9 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
use_latest_version: Boolean flag that indicates whether to use latest agent version
if it exists in the service.
allow_preview: Enables preview opt-in on internally-created ``AIProjectClient``.
additional_properties: Additional properties stored on the client instance.
env_file_path: Path to environment file for loading settings.
env_file_encoding: Encoding of the environment file.
kwargs: Additional keyword arguments passed to the parent class.
Examples:
.. code-block:: python
@@ -217,7 +217,7 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
# Initialize parent
super().__init__(
**kwargs,
additional_properties=additional_properties,
)
# Initialize instance variables
@@ -1243,11 +1243,11 @@ class AzureAIClient(
credential: AzureCredentialTypes | None = None,
use_latest_version: bool | None = None,
allow_preview: bool | None = None,
additional_properties: dict[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize an Azure AI client with full layer support.
@@ -1268,11 +1268,11 @@ class AzureAIClient(
use_latest_version: Boolean flag that indicates whether to use latest agent version
if it exists in the service.
allow_preview: Enables preview opt-in on internally-created ``AIProjectClient``
additional_properties: Additional properties stored on the client instance.
middleware: Optional sequence of chat middlewares to include.
function_invocation_configuration: Optional function invocation configuration.
env_file_path: Path to environment file for loading settings.
env_file_encoding: Encoding of the environment file.
kwargs: Additional keyword arguments passed to the parent class.
Examples:
.. code-block:: python
@@ -1319,9 +1319,9 @@ class AzureAIClient(
credential=credential,
use_latest_version=use_latest_version,
allow_preview=allow_preview,
additional_properties=additional_properties,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
**kwargs,
)
@@ -124,9 +124,9 @@ class RawAzureAIInferenceEmbeddingClient(
text_client: EmbeddingsClient | None = None,
image_client: ImageEmbeddingsClient | None = None,
credential: AzureKeyCredential | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a raw Azure AI Inference embedding client."""
settings = load_settings(
@@ -160,7 +160,7 @@ class RawAzureAIInferenceEmbeddingClient(
credential=credential, # type: ignore[arg-type]
)
self._endpoint = resolved_endpoint
super().__init__(**kwargs)
super().__init__(additional_properties=additional_properties)
async def close(self) -> None:
"""Close the underlying SDK clients and release resources."""
@@ -376,9 +376,9 @@ class AzureAIInferenceEmbeddingClient(
image_client: ImageEmbeddingsClient | None = None,
credential: AzureKeyCredential | None = None,
otel_provider_name: str | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize an Azure AI Inference embedding client."""
super().__init__(
@@ -389,8 +389,8 @@ class AzureAIInferenceEmbeddingClient(
text_client=text_client,
image_client=image_client,
credential=credential,
additional_properties=additional_properties,
otel_provider_name=otel_provider_name,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
**kwargs,
)
@@ -124,7 +124,13 @@ class CosmosHistoryProvider(BaseHistoryProvider):
self._database_client = self._cosmos_client.get_database_client(self.database_name)
async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]:
async def get_messages(
self,
session_id: str | None,
*,
state: dict[str, Any] | None = None,
**kwargs: Any,
) -> list[Message]:
"""Retrieve stored messages for this session from Azure Cosmos DB."""
await self._ensure_container_proxy()
session_key = self._session_partition_key(session_id)
@@ -157,7 +163,14 @@ class CosmosHistoryProvider(BaseHistoryProvider):
return messages
async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None:
async def save_messages(
self,
session_id: str | None,
messages: Sequence[Message],
*,
state: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Persist messages for this session to Azure Cosmos DB."""
if not messages:
return
@@ -236,11 +236,11 @@ class BedrockChatClient(
session_token: str | None = None,
client: BaseClient | None = None,
boto3_session: Boto3Session | None = None,
additional_properties: dict[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Create a Bedrock chat client and load AWS credentials.
@@ -252,11 +252,11 @@ class BedrockChatClient(
session_token: Optional AWS session token for temporary credentials.
client: Preconfigured Bedrock runtime client; when omitted a boto3 session is created.
boto3_session: Custom boto3 session used to build the runtime client if provided.
additional_properties: Additional properties stored on the client instance.
middleware: Optional sequence of middlewares to include.
function_invocation_configuration: Optional function invocation configuration
env_file_path: Optional .env file path used by ``BedrockSettings`` to load defaults.
env_file_encoding: Encoding for the optional .env file.
kwargs: Additional arguments forwarded to ``BaseChatClient``.
Examples:
.. code-block:: python
@@ -303,9 +303,9 @@ class BedrockChatClient(
)
super().__init__(
additional_properties=additional_properties,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
**kwargs,
)
self.model_id = chat_model_id
self.region = region
@@ -104,9 +104,9 @@ class RawBedrockEmbeddingClient(
session_token: str | None = None,
client: BaseClient | None = None,
boto3_session: Boto3Session | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a raw Bedrock embedding client."""
settings = load_settings(
@@ -145,7 +145,7 @@ class RawBedrockEmbeddingClient(
self.model_id: str = settings["embedding_model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess]
self.region = resolved_region
super().__init__(**kwargs)
super().__init__(additional_properties=additional_properties)
def service_url(self) -> str:
"""Get the URL of the service."""
@@ -274,9 +274,9 @@ class BedrockEmbeddingClient(
client: BaseClient | None = None,
boto3_session: Boto3Session | None = None,
otel_provider_name: str | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a Bedrock embedding client."""
super().__init__(
@@ -287,8 +287,8 @@ class BedrockEmbeddingClient(
session_token=session_token,
client=client,
boto3_session=boto3_session,
additional_properties=additional_properties,
otel_provider_name=otel_provider_name,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
**kwargs,
)
@@ -590,6 +590,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
options: OptionsT | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@@ -600,6 +601,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
*,
stream: Literal[True],
session: AgentSession | None = None,
options: OptionsT | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
@@ -609,7 +611,8 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
options: OptionsT | None = None,
**kwargs: Any, # type: ignore
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Run the agent with the given messages.
@@ -621,16 +624,16 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
returns an awaitable AgentResponse.
session: The conversation session. If session has service_session_id set,
the agent will resume that session.
kwargs: Additional keyword arguments including 'options' for runtime options
(model, permission_mode can be changed per-request).
options: Runtime options. Model and permission_mode can be changed per request.
kwargs: Additional keyword arguments for compatibility with the shared agent
interface (e.g. compaction_strategy, tokenizer). Not used by ClaudeAgent.
Returns:
When stream=True: An ResponseStream for streaming updates.
When stream=False: An Awaitable[AgentResponse] with the complete response.
"""
options = kwargs.pop("options", None)
response = ResponseStream(
self._get_stream(messages, session=session, options=options, **kwargs),
self._get_stream(messages, session=session, options=options),
finalizer=self._finalize_response,
)
@@ -643,8 +646,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
messages: AgentRunInputs | None = None,
*,
session: AgentSession | None = None,
options: OptionsT | MutableMapping[str, Any] | None = None,
**kwargs: Any,
options: OptionsT | None = None,
) -> AsyncIterable[AgentResponseUpdate]:
"""Internal streaming implementation."""
session = session or self.create_session()
@@ -196,7 +196,6 @@ class CopilotStudioAgent(BaseAgent):
*,
stream: Literal[False] = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse]: ...
@overload
@@ -206,7 +205,6 @@ class CopilotStudioAgent(BaseAgent):
*,
stream: Literal[True],
session: AgentSession | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ...
def run(
@@ -215,7 +213,6 @@ class CopilotStudioAgent(BaseAgent):
*,
stream: bool = False,
session: AgentSession | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
"""Get a response from the agent.
@@ -229,22 +226,20 @@ class CopilotStudioAgent(BaseAgent):
Keyword Args:
stream: Whether to stream the response. Defaults to False.
session: The conversation session associated with the message(s).
kwargs: Additional keyword arguments.
Returns:
When stream=False: An Awaitable[AgentResponse].
When stream=True: A ResponseStream of AgentResponseUpdate items.
"""
if stream:
return self._run_stream_impl(messages=messages, session=session, **kwargs)
return self._run_impl(messages=messages, session=session, **kwargs)
return self._run_stream_impl(messages=messages, session=session)
return self._run_impl(messages=messages, session=session)
async def _run_impl(
self,
messages: AgentRunInputs | None = None,
*,
session: AgentSession | None = None,
**kwargs: Any,
) -> AgentResponse:
"""Non-streaming implementation of run."""
if not session:
@@ -269,7 +264,6 @@ class CopilotStudioAgent(BaseAgent):
messages: AgentRunInputs | None = None,
*,
session: AgentSession | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
"""Streaming implementation of run."""
@@ -215,6 +215,7 @@ from ._workflows._workflow_executor import (
)
from .exceptions import (
MiddlewareException,
UserInputRequiredException,
WorkflowCheckpointException,
WorkflowConvergenceException,
WorkflowException,
@@ -349,6 +350,7 @@ __all__ = [
"TypeCompatibilityError",
"UpdateT",
"UsageDetails",
"UserInputRequiredException",
"ValidationTypeEnum",
"Workflow",
"WorkflowAgent",
+225 -99
View File
@@ -2,10 +2,10 @@
from __future__ import annotations
import inspect
import logging
import re
import sys
import warnings
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from copy import deepcopy
@@ -27,12 +27,13 @@ from uuid import uuid4
from mcp import types
from mcp.server.lowlevel import Server
from mcp.shared.exceptions import McpError
from pydantic import BaseModel, Field, create_model
from pydantic import BaseModel
from . import _tools as _tool_utils # pyright: ignore[reportPrivateUsage]
from ._clients import BaseChatClient, SupportsChatGetResponse
from ._docstrings import apply_layered_docstring
from ._mcp import LOG_LEVEL_MAPPING, MCPTool
from ._middleware import AgentMiddlewareLayer, MiddlewareTypes
from ._middleware import AgentMiddlewareLayer, FunctionInvocationContext, MiddlewareTypes
from ._serialization import SerializationMixin
from ._sessions import (
AgentSession,
@@ -53,7 +54,7 @@ from ._types import (
map_chat_to_agent_update,
normalize_messages,
)
from .exceptions import AgentInvalidResponseException
from .exceptions import AgentInvalidResponseException, UserInputRequiredException
from .observability import AgentTelemetryLayer
if sys.version_info >= (3, 13):
@@ -169,8 +170,8 @@ class _RunContext(TypedDict):
chat_options: MutableMapping[str, Any]
compaction_strategy: CompactionStrategy | None
tokenizer: TokenizerProtocol | None
filtered_kwargs: Mapping[str, Any]
finalize_kwargs: Mapping[str, Any]
client_kwargs: Mapping[str, Any]
function_invocation_kwargs: Mapping[str, Any]
# region Agent Protocol
@@ -218,15 +219,15 @@ class SupportsAgentRun(Protocol):
return AgentResponse(messages=[], response_id="custom-response")
def create_session(self, **kwargs):
def create_session(self, *, session_id: str | None = None):
from agent_framework import AgentSession
return AgentSession(**kwargs)
return AgentSession(session_id=session_id)
def get_session(self, *, service_session_id, **kwargs):
def get_session(self, service_session_id: str, *, session_id: str | None = None):
from agent_framework import AgentSession
return AgentSession(service_session_id=service_session_id, **kwargs)
return AgentSession(service_session_id=service_session_id, session_id=session_id)
# Verify the instance satisfies the protocol
@@ -245,6 +246,8 @@ class SupportsAgentRun(Protocol):
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]:
"""Get a response from the agent (non-streaming)."""
@@ -257,6 +260,8 @@ class SupportsAgentRun(Protocol):
*,
stream: Literal[True],
session: AgentSession | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Get a streaming response from the agent."""
@@ -268,6 +273,8 @@ class SupportsAgentRun(Protocol):
*,
stream: bool = False,
session: AgentSession | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Get a response from the agent.
@@ -282,6 +289,8 @@ class SupportsAgentRun(Protocol):
Keyword Args:
stream: Whether to stream the response. Defaults to False.
session: The conversation session associated with the message(s).
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
client_kwargs: Additional client-specific keyword arguments.
kwargs: Additional keyword arguments.
Returns:
@@ -291,11 +300,11 @@ class SupportsAgentRun(Protocol):
"""
...
def create_session(self, **kwargs: Any) -> AgentSession:
def create_session(self, *, session_id: str | None = None) -> AgentSession:
"""Creates a new conversation session."""
...
def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
"""Gets or creates a session for a service-managed session ID."""
...
@@ -378,6 +387,13 @@ class BaseAgent(SerializationMixin):
additional_properties: Additional properties set on the agent.
kwargs: Additional keyword arguments (merged into additional_properties).
"""
if kwargs:
warnings.warn(
"Passing additional properties as direct keyword arguments to BaseAgent is deprecated; "
"pass them via additional_properties instead.",
DeprecationWarning,
stacklevel=3,
)
if id is None:
id = str(uuid4())
self.id = id
@@ -392,27 +408,40 @@ class BaseAgent(SerializationMixin):
self.additional_properties: dict[str, Any] = cast(dict[str, Any], additional_properties or {})
self.additional_properties.update(kwargs)
def create_session(self, *, session_id: str | None = None, **kwargs: Any) -> AgentSession:
def create_session(self, *, session_id: str | None = None) -> AgentSession:
"""Create a new lightweight session.
This will be used by an agent to hold the persisted session.
This depends on the service used, in some cases, or with store=True
this will add the ``service_session_id`` based on the response,
which is then fed back to the API on the next call.
In other cases, if there is a HistoryProvider setup in the agent,
that is used and it can store state in the session.
If there is no HistoryProvider and store=False or the default of a service is False.
Then a ``InMemoryHistoryProvider`` instance is added to the agent and used with the session automatically.
The ``InMemoryHistoryProvider`` stores the messages as `state` in the session by default.
Keyword Args:
session_id: Optional session ID (generated if not provided).
kwargs: Additional keyword arguments.
Returns:
A new AgentSession instance.
"""
return AgentSession(session_id=session_id)
def get_session(self, *, service_session_id: str, session_id: str | None = None, **kwargs: Any) -> AgentSession:
"""Get or create a session for a service-managed session ID.
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
"""Get a session for a service-managed session ID.
Only use this to create a session continuing that session id from a service.
Otherwise use ``create_session``.
Args:
service_session_id: The service-managed session ID.
Keyword Args:
session_id: Optional local session ID (generated if not provided).
kwargs: Additional keyword arguments.
Returns:
A new AgentSession instance with service_session_id set.
@@ -452,9 +481,8 @@ class BaseAgent(SerializationMixin):
description: str | None = None,
arg_name: str = "task",
arg_description: str | None = None,
stream_callback: Callable[[AgentResponseUpdate], None]
| Callable[[AgentResponseUpdate], Awaitable[None]]
| None = None,
approval_mode: Literal["always_require", "never_require"] = "never_require",
stream_callback: Callable[[AgentResponseUpdate], Awaitable[None] | None] | None = None,
propagate_session: bool = False,
) -> FunctionTool:
"""Create a FunctionTool that wraps this agent.
@@ -465,21 +493,15 @@ class BaseAgent(SerializationMixin):
arg_name: The name of the function argument (default: "task").
arg_description: The description for the function argument.
If None, defaults to "Task for {tool_name}".
approval_mode: Whether this delegated tool requires approval before execution.
stream_callback: Optional callback for streaming responses. If provided, uses run(..., stream=True).
propagate_session: If True, the parent agent's ``AgentSession`` is
forwarded to this sub-agent's ``run()`` call, so both agents
operate within the same logical session (sharing the same
``session_id`` and provider-managed state, such as any stored
conversation history or metadata). Defaults to False, meaning
the sub-agent runs with a new, independent session.
propagate_session: If True, the parent agent's session is forwarded
to this sub-agent's ``run()`` call so both agents share the
same session. Defaults to False.
Returns:
A FunctionTool that can be used as a tool by other agents.
Raises:
TypeError: If the agent does not implement SupportsAgentRun.
ValueError: If the agent tool name cannot be determined.
Examples:
.. code-block:: python
@@ -507,59 +529,46 @@ class BaseAgent(SerializationMixin):
tool_description = description or self.description or ""
argument_description = arg_description or f"Task for {tool_name}"
# Create dynamic input model with the specified argument name
field_info = Field(..., description=argument_description)
model_name = f"{name or _sanitize_agent_name(self.name) or 'agent'}_task"
input_model = create_model(model_name, **{arg_name: (str, field_info)}) # type: ignore[call-overload]
input_schema = {
"type": "object",
"properties": {
arg_name: {
"type": "string",
"description": argument_description,
}
},
"required": [arg_name],
"additionalProperties": False,
}
# Check if callback is async once, outside the wrapper
is_async_callback = stream_callback is not None and inspect.iscoroutinefunction(stream_callback)
async def _agent_wrapper(ctx: FunctionInvocationContext, **kwargs: Any) -> str:
"""Wrapper function that calls the agent.
async def agent_wrapper(**kwargs: Any) -> str:
"""Wrapper function that calls the agent."""
# Extract the input from kwargs using the specified arg_name
input_text = kwargs.get(arg_name, "")
Args:
ctx: the function invocation context used
**kwargs: only used to dynamically load the argument that is defined for this tool.
"""
stream = self.run(
str(kwargs.get(arg_name, "")),
stream=True,
session=ctx.session if propagate_session else None,
function_invocation_kwargs=dict(ctx.kwargs),
)
if stream_callback is not None:
stream.with_transform_hook(stream_callback)
final_response = await stream.get_final_response()
if final_response.user_input_requests:
raise UserInputRequiredException(contents=final_response.user_input_requests)
# TODO(Copilot): update once #4331 merges
return final_response.text
# Extract parent session when propagate_session is enabled
parent_session = kwargs.get("session") if propagate_session else None
# Forward runtime context kwargs, excluding framework-internal keys.
forwarded_kwargs = {
k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options", "session")
}
if stream_callback is None:
# Use non-streaming mode
return (
await self.run(
input_text,
stream=False,
session=parent_session,
**forwarded_kwargs,
)
).text
# Use streaming mode - accumulate updates and create final response
response_updates: list[AgentResponseUpdate] = []
async for update in self.run(input_text, stream=True, session=parent_session, **forwarded_kwargs):
response_updates.append(update)
if is_async_callback:
await stream_callback(update) # type: ignore[misc]
else:
stream_callback(update)
# Create final text from accumulated updates
return AgentResponse.from_updates(response_updates).text
agent_tool: FunctionTool = FunctionTool(
return FunctionTool(
name=tool_name,
description=tool_description,
func=agent_wrapper,
input_model=input_model, # type: ignore
approval_mode="never_require",
func=_agent_wrapper,
input_model=input_schema,
approval_mode=approval_mode,
)
agent_tool._forward_runtime_kwargs = True # type: ignore
return agent_tool
# region Agent
@@ -801,6 +810,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
@@ -815,6 +826,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options: OptionsCoT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@@ -829,6 +842,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
@@ -842,6 +857,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Run the agent with the given messages and options.
@@ -871,14 +888,23 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
tokenizer: Optional per-run tokenizer override passed to
``client.get_response()``. When omitted, the agent-level override
is used, falling back to the client default.
kwargs: Additional keyword arguments for the agent. These are only
passed to functions that are called.
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
client_kwargs: Additional client-specific keyword arguments for the chat client.
kwargs: Deprecated additional keyword arguments for the agent.
They are forwarded to both tool invocation and the chat client for compatibility.
Returns:
When stream=False: An Awaitable[AgentResponse] containing the agent's response.
When stream=True: A ResponseStream of AgentResponseUpdate items with
``get_final_response()`` for the final AgentResponse.
"""
if kwargs:
warnings.warn(
"Passing runtime keyword arguments directly to run() is deprecated; pass tool values via "
"function_invocation_kwargs and client-specific values via client_kwargs instead.",
DeprecationWarning,
stacklevel=2,
)
if not stream:
async def _run_non_streaming() -> AgentResponse[Any]:
@@ -889,7 +915,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
kwargs=kwargs,
legacy_kwargs=kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
response = cast(
ChatResponse[Any],
@@ -899,7 +927,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options=ctx["chat_options"], # type: ignore[reportArgumentType]
compaction_strategy=ctx["compaction_strategy"],
tokenizer=ctx["tokenizer"],
**ctx["filtered_kwargs"],
function_invocation_kwargs=ctx["function_invocation_kwargs"],
client_kwargs=ctx["client_kwargs"],
),
)
@@ -974,7 +1003,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
kwargs=kwargs,
legacy_kwargs=kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it
return self.client.get_response( # type: ignore[call-overload, no-any-return]
@@ -983,7 +1014,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options=ctx["chat_options"], # type: ignore[reportArgumentType]
compaction_strategy=ctx["compaction_strategy"],
tokenizer=ctx["tokenizer"],
**ctx["filtered_kwargs"],
function_invocation_kwargs=ctx["function_invocation_kwargs"],
client_kwargs=ctx["client_kwargs"],
)
def _propagate_conversation_id(
@@ -1071,9 +1103,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
options: Mapping[str, Any] | None,
compaction_strategy: CompactionStrategy | None,
tokenizer: TokenizerProtocol | None,
kwargs: dict[str, Any],
legacy_kwargs: Mapping[str, Any],
function_invocation_kwargs: Mapping[str, Any] | None,
client_kwargs: Mapping[str, Any] | None,
) -> _RunContext:
opts = dict(options) if options else {}
existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {}
# Get tools from options or named parameter (named param takes precedence)
tools_ = tools if tools is not None else opts.pop("tools", None)
@@ -1104,6 +1139,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
input_messages=input_messages,
options=opts,
)
default_additional_args = chat_options.pop("additional_function_arguments", None)
if isinstance(default_additional_args, Mapping):
existing_additional_args = {
**dict(cast(Mapping[str, Any], default_additional_args)),
**existing_additional_args,
}
agent_name = self._get_agent_name()
base_tools = normalize_tools(chat_options.pop("tools", None))
@@ -1135,13 +1176,13 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
duplicate_error_message=mcp_duplicate_message,
)
# Merge runtime kwargs into additional_function_arguments so they're available
# in function middleware context and tool invocation.
existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {}
additional_function_arguments = {**kwargs, **existing_additional_args}
# Include session so as_tool() wrappers with propagate_session=True can access it.
if active_session is not None:
additional_function_arguments["session"] = active_session
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
# Legacy compatibility still fans out direct run kwargs into tool runtime kwargs.
effective_function_invocation_kwargs = {
**dict(legacy_kwargs),
**(dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}),
}
additional_function_arguments = {**effective_function_invocation_kwargs, **existing_additional_args}
# Build options dict from run() options merged with provided options
run_opts: dict[str, Any] = {
@@ -1150,7 +1191,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
if active_session
else opts.pop("conversation_id", None),
"allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None),
"additional_function_arguments": additional_function_arguments or None,
"frequency_penalty": opts.pop("frequency_penalty", None),
"logit_bias": opts.pop("logit_bias", None),
"max_tokens": opts.pop("max_tokens", None),
@@ -1174,11 +1214,14 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
# Build session_messages from session context: context messages + input messages
session_messages: list[Message] = session_context.get_messages(include_input=True)
# Ensure session is forwarded in kwargs for tool invocation
finalize_kwargs = dict(kwargs)
finalize_kwargs["session"] = active_session
# Filter chat_options from kwargs to prevent duplicate keyword argument
filtered_kwargs = {k: v for k, v in finalize_kwargs.items() if k != "chat_options"}
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
# Legacy compatibility still fans out direct run kwargs into client kwargs.
effective_client_kwargs = {
**dict(legacy_kwargs),
**(dict(client_kwargs) if client_kwargs is not None else {}),
}
if active_session is not None:
effective_client_kwargs["session"] = active_session
return {
"session": active_session,
@@ -1189,8 +1232,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
"chat_options": co,
"compaction_strategy": compaction_strategy or self.compaction_strategy,
"tokenizer": tokenizer or self.tokenizer,
"filtered_kwargs": filtered_kwargs,
"finalize_kwargs": finalize_kwargs,
"client_kwargs": effective_client_kwargs,
"function_invocation_kwargs": additional_function_arguments,
}
async def _finalize_response(
@@ -1440,6 +1483,58 @@ class Agent(
For a minimal implementation without these features, use :class:`RawAgent`.
"""
@overload
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@overload
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
def run(
self,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
options: OptionsCoT | ChatOptions[Any] | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Run the agent."""
super_run = cast(
"Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]",
super().run, # type: ignore[misc]
)
return super_run( # type: ignore[no-any-return]
messages=messages,
stream=stream,
session=session,
middleware=middleware,
options=options,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
**kwargs,
)
def __init__(
self,
client: SupportsChatGetResponse[OptionsCoT],
@@ -1471,3 +1566,34 @@ class Agent(
tokenizer=tokenizer,
**kwargs,
)
def _apply_agent_docstrings() -> None:
"""Align public agent docstrings with the raw implementation."""
apply_layered_docstring(
AgentMiddlewareLayer.run,
RawAgent.run,
extra_keyword_args={
"middleware": """
Optional per-run agent, chat, and function middleware.
Agent middleware wraps the run itself, while chat and function middleware are forwarded to the
underlying chat-client stack for this call.
""",
},
)
apply_layered_docstring(AgentTelemetryLayer.run, AgentMiddlewareLayer.run)
apply_layered_docstring(
Agent.run,
RawAgent.run,
extra_keyword_args={
"middleware": """
Optional per-run agent, chat, and function middleware.
Agent middleware wraps the run itself, while chat and function middleware are forwarded to the
underlying chat-client stack for this call.
""",
},
)
apply_layered_docstring(Agent.__init__, RawAgent.__init__)
_apply_agent_docstrings()
@@ -4,6 +4,7 @@ from __future__ import annotations
import logging
import sys
import warnings
from abc import ABC, abstractmethod
from collections.abc import (
AsyncIterable,
@@ -27,6 +28,7 @@ from typing import (
from pydantic import BaseModel
from ._docstrings import apply_layered_docstring
from ._serialization import SerializationMixin
from ._tools import (
FunctionInvocationConfiguration,
@@ -105,7 +107,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
class CustomChatClient:
additional_properties: dict = {}
def get_response(self, messages, *, stream=False, **kwargs):
def get_response(self, messages, *, stream=False, client_kwargs=None, **kwargs):
if stream:
from agent_framework import ChatResponseUpdate, ResponseStream
@@ -149,6 +151,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
options: OptionsContraT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@@ -161,6 +165,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
options: OptionsContraT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@@ -172,6 +178,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
options: OptionsContraT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Send input and return the response.
@@ -182,7 +190,9 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
options: Chat options as a TypedDict.
compaction_strategy: Optional per-call compaction override.
tokenizer: Optional per-call tokenizer override.
**kwargs: Additional chat options.
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
client_kwargs: Additional client-specific keyword arguments.
**kwargs: Deprecated additional client-specific keyword arguments.
Returns:
When stream=False: An awaitable ChatResponse from the client.
@@ -283,23 +293,31 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
def __init__(
self,
*,
additional_properties: dict[str, Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize a BaseChatClient instance.
Keyword Args:
additional_properties: Additional properties for the client.
compaction_strategy: Optional compaction strategy to apply before model calls.
tokenizer: Optional tokenizer used by token-aware compaction strategies.
kwargs: Additional keyword arguments (merged into additional_properties).
additional_properties: Additional properties for the client.
kwargs: Additional keyword arguments (merged into additional_properties for now).
"""
self.additional_properties = additional_properties or {}
self.compaction_strategy = compaction_strategy
self.tokenizer = tokenizer
super().__init__(**kwargs)
if kwargs:
warnings.warn(
"Passing additional properties as direct keyword arguments to BaseChatClient is deprecated; "
"pass them via additional_properties instead.",
DeprecationWarning,
stacklevel=3,
)
self.additional_properties.update(kwargs)
super().__init__()
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
"""Convert the instance to a dictionary.
@@ -486,7 +504,13 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
When omitted, the client-level default is used.
tokenizer: Optional per-call tokenizer override. When omitted, the
client-level default is used.
**kwargs: Other keyword arguments, can be used to pass function specific parameters.
**kwargs: Additional compatibility keyword arguments. Lower chat-client layers do not
consume ``function_invocation_kwargs`` directly; if present, it is ignored here
because function invocation has already been handled by upper layers. If a
``client_kwargs`` mapping is present, it is flattened into standard keyword
arguments before forwarding to ``_inner_get_response()`` so client implementations
can leverage those values, while implementations that ignore
extra kwargs remain compatible.
Returns:
When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse.
@@ -495,12 +519,21 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
)
compatibility_client_kwargs = kwargs.pop("client_kwargs", None)
kwargs.pop("function_invocation_kwargs", None)
merged_client_kwargs = (
dict(cast(Mapping[str, Any], compatibility_client_kwargs))
if isinstance(compatibility_client_kwargs, Mapping)
else {}
)
merged_client_kwargs.update(kwargs)
if not compaction_overrides:
return self._inner_get_response(
messages=messages,
stream=stream,
options=options or {},
**kwargs,
options=options or {}, # type: ignore[arg-type]
**merged_client_kwargs,
)
if stream:
@@ -514,7 +547,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
messages=prepared_messages,
stream=True,
options=options or {},
**kwargs,
**merged_client_kwargs,
)
if isinstance(stream_response, ResponseStream):
return stream_response # type: ignore[reportUnknownVariableType]
@@ -534,7 +567,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
messages=prepared_messages,
stream=False,
options=options or {},
**kwargs,
**merged_client_kwargs,
)
return _get_response()
@@ -564,7 +597,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
additional_properties: Mapping[str, Any] | None = None,
) -> Agent[OptionsCoT]:
"""Create a Agent with this client.
@@ -590,7 +623,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
client-level compaction defaults remain in effect for each call.
tokenizer: Optional agent-level tokenizer override. When omitted,
client-level tokenizer defaults remain in effect for each call.
kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``.
additional_properties: Additional properties stored on the created agent.
Returns:
A Agent instance configured with this chat client.
@@ -615,21 +648,24 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
"""
from ._agents import Agent
return Agent(
client=self,
id=id,
name=name,
description=description,
instructions=instructions,
tools=tools,
default_options=cast(Any, default_options),
context_providers=context_providers,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
)
agent_kwargs: dict[str, Any] = {
"client": self,
"id": id,
"name": name,
"description": description,
"instructions": instructions,
"tools": tools,
"default_options": cast(Any, default_options),
"context_providers": context_providers,
"middleware": middleware,
"compaction_strategy": compaction_strategy,
"tokenizer": tokenizer,
"additional_properties": dict(additional_properties) if additional_properties is not None else None,
}
if function_invocation_configuration is not None:
agent_kwargs["function_invocation_configuration"] = function_invocation_configuration
return Agent(**agent_kwargs)
# endregion
@@ -892,16 +928,14 @@ class BaseEmbeddingClient(SerializationMixin, ABC, Generic[EmbeddingInputT, Embe
self,
*,
additional_properties: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Initialize a BaseEmbeddingClient instance.
Args:
additional_properties: Additional properties to pass to the client.
**kwargs: Additional keyword arguments passed to parent classes (for MRO).
"""
self.additional_properties = additional_properties or {}
super().__init__(**kwargs)
super().__init__()
@abstractmethod
async def get_embeddings(
@@ -923,3 +957,36 @@ class BaseEmbeddingClient(SerializationMixin, ABC, Generic[EmbeddingInputT, Embe
# endregion
def _apply_get_response_docstrings() -> None:
"""Align layered chat-client docstrings with the lowest public implementation."""
from ._middleware import ChatMiddlewareLayer
from ._tools import FunctionInvocationLayer
from .observability import ChatTelemetryLayer
apply_layered_docstring(ChatTelemetryLayer.get_response, BaseChatClient.get_response)
apply_layered_docstring(
FunctionInvocationLayer.get_response,
ChatTelemetryLayer.get_response,
extra_keyword_args={
"function_middleware": """
Optional per-call function middleware.
When omitted, middleware configured on the client or forwarded from higher layers is used.
""",
},
)
apply_layered_docstring(
ChatMiddlewareLayer.get_response,
FunctionInvocationLayer.get_response,
extra_keyword_args={
"middleware": """
Optional per-call chat and function middleware.
This compatibility keyword argument is merged with any ``client_kwargs["middleware"]`` value
before the request is executed.
""",
},
)
_apply_get_response_docstrings()
@@ -0,0 +1,85 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import inspect
from collections.abc import Callable, Mapping
from typing import Any
_GOOGLE_SECTION_HEADERS = (
"Args:",
"Keyword Args:",
"Returns:",
"Raises:",
"Examples:",
"Note:",
"Notes:",
"Warning:",
"Warnings:",
)
def _find_section_index(lines: list[str], header: str) -> int | None:
for index, line in enumerate(lines):
if line == header:
return index
return None
def _find_next_section_index(lines: list[str], start: int) -> int:
for index in range(start, len(lines)):
if lines[index] in _GOOGLE_SECTION_HEADERS:
return index
return len(lines)
def _format_keyword_arg_lines(extra_keyword_args: Mapping[str, str]) -> list[str]:
formatted_lines: list[str] = []
for name, description in extra_keyword_args.items():
description_lines = inspect.cleandoc(description).splitlines()
if not description_lines:
formatted_lines.append(f" {name}:")
continue
formatted_lines.append(f" {name}: {description_lines[0]}")
formatted_lines.extend(f" {line}" for line in description_lines[1:])
return formatted_lines
def build_layered_docstring(
source: Callable[..., Any],
*,
extra_keyword_args: Mapping[str, str] | None = None,
) -> str | None:
"""Build a Google-style docstring from a lower-layer implementation."""
docstring = inspect.getdoc(source)
if not docstring:
return None
if not extra_keyword_args:
return docstring
lines = docstring.splitlines()
formatted_keyword_arg_lines = _format_keyword_arg_lines(extra_keyword_args)
keyword_args_index = _find_section_index(lines, "Keyword Args:")
if keyword_args_index is None:
args_index = _find_section_index(lines, "Args:")
if args_index is not None:
insert_index = _find_next_section_index(lines, args_index + 1)
else:
insert_index = _find_next_section_index(lines, 0)
lines[insert_index:insert_index] = ["", "Keyword Args:", *formatted_keyword_arg_lines]
return "\n".join(lines).rstrip()
insert_index = _find_next_section_index(lines, keyword_args_index + 1)
lines[insert_index:insert_index] = formatted_keyword_arg_lines
return "\n".join(lines).rstrip()
def apply_layered_docstring(
target: Callable[..., Any],
source: Callable[..., Any],
*,
extra_keyword_args: Mapping[str, str] | None = None,
) -> None:
"""Copy a lower-layer docstring onto a wrapper and extend it when needed."""
target.__doc__ = build_layered_docstring(source, extra_keyword_args=extra_keyword_args)
@@ -109,7 +109,9 @@ class AgentContext:
to see the actual execution result or can be set to override the execution result.
For non-streaming: should be AgentResponse.
For streaming: should be ResponseStream[AgentResponseUpdate, AgentResponse].
kwargs: Additional keyword arguments passed to the agent run method.
kwargs: Legacy runtime keyword arguments visible to agent middleware.
client_kwargs: Client-specific keyword arguments for downstream chat clients.
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
Examples:
.. code-block:: python
@@ -147,6 +149,8 @@ class AgentContext:
metadata: Mapping[str, Any] | None = None,
result: AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None = None,
kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
stream_transform_hooks: Sequence[
Callable[[AgentResponseUpdate], AgentResponseUpdate | Awaitable[AgentResponseUpdate]]
]
@@ -167,7 +171,9 @@ class AgentContext:
tokenizer: Optional per-run tokenizer override.
metadata: Metadata dictionary for sharing data between agent middleware.
result: Agent execution result.
kwargs: Additional keyword arguments passed to the agent run method.
kwargs: Legacy runtime keyword arguments visible to agent middleware.
client_kwargs: Client-specific keyword arguments for downstream chat clients.
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
stream_transform_hooks: Hooks to transform streamed updates.
stream_result_hooks: Hooks to process the final result after streaming.
stream_cleanup_hooks: Hooks to run after streaming completes.
@@ -182,6 +188,10 @@ class AgentContext:
self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {}
self.result = result
self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {}
self.client_kwargs: dict[str, Any] = dict(client_kwargs) if client_kwargs is not None else {}
self.function_invocation_kwargs: dict[str, Any] = (
dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}
)
self.stream_transform_hooks = list(stream_transform_hooks or [])
self.stream_result_hooks = list(stream_result_hooks or [])
self.stream_cleanup_hooks = list(stream_cleanup_hooks or [])
@@ -196,11 +206,11 @@ class FunctionInvocationContext:
Attributes:
function: The function being invoked.
arguments: The validated arguments for the function.
session: The agent session for this invocation, if any.
metadata: Metadata dictionary for sharing data between function middleware.
result: Function execution result. Can be observed after calling ``call_next()``
to see the actual execution result or can be set to override the execution result.
kwargs: Additional keyword arguments passed to the chat method that invoked this function.
kwargs: Additional runtime keyword arguments forwarded to the function invocation.
Examples:
.. code-block:: python
@@ -225,6 +235,7 @@ class FunctionInvocationContext:
self,
function: FunctionTool,
arguments: BaseModel | Mapping[str, Any],
session: AgentSession | None = None,
metadata: Mapping[str, Any] | None = None,
result: Any = None,
kwargs: Mapping[str, Any] | None = None,
@@ -234,12 +245,14 @@ class FunctionInvocationContext:
Args:
function: The function being invoked.
arguments: The validated arguments for the function.
session: The agent session for this invocation, if any.
metadata: Metadata dictionary for sharing data between function middleware.
result: Function execution result.
kwargs: Additional keyword arguments passed to the chat method that invoked this function.
kwargs: Additional runtime keyword arguments forwarded to the function invocation.
"""
self.function = function
self.arguments = arguments
self.session = session
self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {}
self.result = result
self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {}
@@ -262,6 +275,7 @@ class ChatContext:
For non-streaming: should be ChatResponse.
For streaming: should be ResponseStream[ChatResponseUpdate, ChatResponse].
kwargs: Additional keyword arguments passed to the chat client.
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
stream_transform_hooks: Hooks applied to transform each streamed update.
stream_result_hooks: Hooks applied to the finalized response (after finalizer).
stream_cleanup_hooks: Hooks executed after stream consumption (before finalizer).
@@ -298,6 +312,7 @@ class ChatContext:
metadata: Mapping[str, Any] | None = None,
result: ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None = None,
kwargs: Mapping[str, Any] | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
stream_transform_hooks: Sequence[
Callable[[ChatResponseUpdate], ChatResponseUpdate | Awaitable[ChatResponseUpdate]]
]
@@ -315,6 +330,7 @@ class ChatContext:
metadata: Metadata dictionary for sharing data between chat middleware.
result: Chat execution result.
kwargs: Additional keyword arguments passed to the chat client.
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
stream_transform_hooks: Transform hooks to apply to each streamed update.
stream_result_hooks: Result hooks to apply to the finalized streaming response.
stream_cleanup_hooks: Cleanup hooks to run after streaming completes.
@@ -326,6 +342,9 @@ class ChatContext:
self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {}
self.result = result
self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {}
self.function_invocation_kwargs: dict[str, Any] = (
dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}
)
self.stream_transform_hooks = list(stream_transform_hooks or [])
self.stream_result_hooks = list(stream_result_hooks or [])
self.stream_cleanup_hooks = list(stream_cleanup_hooks or [])
@@ -980,6 +999,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@@ -992,6 +1012,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
options: OptionsCoT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@@ -1004,6 +1026,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@@ -1015,6 +1039,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Execute the chat pipeline if middleware is configured."""
@@ -1025,9 +1051,10 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
if tokenizer is not None:
kwargs["tokenizer"] = tokenizer
call_middleware = kwargs.pop("middleware", [])
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
call_middleware = kwargs.pop("middleware", effective_client_kwargs.pop("middleware", []))
middleware = categorize_middleware(call_middleware)
kwargs["function_middleware"] = middleware["function"]
effective_client_kwargs["function_middleware"] = middleware["function"]
pipeline = ChatMiddlewarePipeline(
*self.chat_middleware,
@@ -1038,6 +1065,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
messages=messages,
stream=stream,
options=options,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=effective_client_kwargs,
**kwargs,
)
@@ -1046,7 +1075,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
messages=list(messages),
options=options,
stream=stream,
kwargs=kwargs,
kwargs={**effective_client_kwargs, **kwargs},
function_invocation_kwargs=function_invocation_kwargs,
)
async def _execute() -> ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None:
@@ -1079,11 +1109,17 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
self, context: ChatContext
) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
"""Internal middleware handler to adapt to pipeline."""
handler_kwargs = dict(context.kwargs)
compaction_strategy = handler_kwargs.pop("compaction_strategy", None)
tokenizer = handler_kwargs.pop("tokenizer", None)
return super().get_response( # type: ignore[misc, no-any-return]
messages=context.messages,
stream=context.stream,
options=context.options or {},
**context.kwargs,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=context.function_invocation_kwargs,
client_kwargs=handler_kwargs,
)
@@ -1115,6 +1151,8 @@ class AgentMiddlewareLayer:
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
@@ -1129,6 +1167,8 @@ class AgentMiddlewareLayer:
options: ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@@ -1143,6 +1183,8 @@ class AgentMiddlewareLayer:
options: ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
@@ -1156,6 +1198,8 @@ class AgentMiddlewareLayer:
options: ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""MiddlewareTypes-enabled unified run method."""
@@ -1175,9 +1219,12 @@ class AgentMiddlewareLayer:
+ run_middleware_list["function"]
+ run_middleware_list["chat"]
)
combined_kwargs = dict(kwargs)
combined_kwargs["middleware"] = combined_function_chat_middleware if combined_function_chat_middleware else None
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
if combined_function_chat_middleware:
effective_client_kwargs["middleware"] = combined_function_chat_middleware
effective_function_invocation_kwargs = (
dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}
)
# Execute with middleware if available
if not pipeline.has_middlewares:
return super().run( # type: ignore[misc, no-any-return]
@@ -1187,7 +1234,9 @@ class AgentMiddlewareLayer:
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**combined_kwargs,
function_invocation_kwargs=effective_function_invocation_kwargs,
client_kwargs=effective_client_kwargs,
**kwargs,
)
context = AgentContext(
@@ -1198,7 +1247,9 @@ class AgentMiddlewareLayer:
stream=stream,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
kwargs=combined_kwargs,
kwargs=kwargs,
client_kwargs=effective_client_kwargs,
function_invocation_kwargs=effective_function_invocation_kwargs,
)
async def _execute() -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None:
@@ -1230,6 +1281,13 @@ class AgentMiddlewareLayer:
def _middleware_handler(
self, context: AgentContext
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
client_kwargs = {**context.client_kwargs, **context.kwargs}
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
function_invocation_kwargs = {
**context.function_invocation_kwargs,
**{k: v for k, v in context.kwargs.items() if k != "middleware"},
}
return super().run( # type: ignore[misc, no-any-return]
context.messages,
stream=context.stream,
@@ -1237,7 +1295,8 @@ class AgentMiddlewareLayer:
options=context.options,
compaction_strategy=context.compaction_strategy,
tokenizer=context.tokenizer,
**context.kwargs,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
)
@@ -392,12 +392,16 @@ class BaseHistoryProvider(BaseContextProvider):
self.store_outputs = store_outputs
@abstractmethod
async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]:
async def get_messages(
self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
) -> list[Message]:
"""Retrieve stored messages for this session.
Args:
session_id: The session ID to retrieve messages for.
**kwargs: Additional arguments (e.g., ``state`` for in-memory providers).
state: Optional session state for providers that persist in session state.
Not used by all providers.
**kwargs: Additional subclass-specific extensibility arguments.
Returns:
List of stored messages.
@@ -405,13 +409,22 @@ class BaseHistoryProvider(BaseContextProvider):
...
@abstractmethod
async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None:
async def save_messages(
self,
session_id: str | None,
messages: Sequence[Message],
*,
state: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Persist messages for this session.
Args:
session_id: The session ID to store messages for.
messages: The messages to persist.
**kwargs: Additional arguments (e.g., ``state`` for in-memory providers).
state: Optional session state for providers that persist in session state.
Not used by all providers.
**kwargs: Additional subclass-specific extensibility arguments.
"""
...
+217 -41
View File
@@ -7,6 +7,8 @@ import inspect
import json
import logging
import sys
import typing
import warnings
from collections.abc import (
AsyncIterable,
Awaitable,
@@ -37,7 +39,7 @@ from opentelemetry.metrics import Histogram, NoOpHistogram
from pydantic import BaseModel, Field, ValidationError, create_model
from ._serialization import SerializationMixin
from .exceptions import ToolException
from .exceptions import ToolException, UserInputRequiredException
from .observability import (
OPERATION_DURATION_BUCKET_BOUNDARIES,
OtelAttr,
@@ -61,7 +63,8 @@ if TYPE_CHECKING:
from ._clients import SupportsChatGetResponse
from ._compaction import CompactionStrategy, TokenizerProtocol
from ._mcp import MCPTool
from ._middleware import FunctionMiddlewarePipeline, FunctionMiddlewareTypes
from ._middleware import FunctionInvocationContext, FunctionMiddlewarePipeline, FunctionMiddlewareTypes
from ._sessions import AgentSession
from ._types import (
ChatOptions,
ChatResponse,
@@ -187,6 +190,16 @@ def _default_histogram() -> Histogram:
)
def _annotation_includes_function_invocation_context(annotation: Any) -> bool:
"""Check whether an annotation resolves to FunctionInvocationContext."""
from ._middleware import FunctionInvocationContext
candidates = get_args(annotation) or (annotation,)
return any(
candidate is FunctionInvocationContext or candidate == "FunctionInvocationContext" for candidate in candidates
)
ClassT = TypeVar("ClassT", bound="SerializationMixin")
@@ -323,6 +336,12 @@ class FunctionTool(SerializationMixin):
# FunctionTool-specific attributes
self.func = func
self._instance = None # Store the instance for bound methods
self._context_parameter_name: str | None = None
self._input_model_explicitly_provided = input_model is not None
# TODO(Copilot): Delete once legacy ``**kwargs`` runtime injection is removed.
self._forward_runtime_kwargs: bool = False
if self.func:
self._discover_injected_parameters()
# Initialize schema cache (will be lazily populated)
self._input_schema_cached: dict[str, Any] | None = None
@@ -349,13 +368,37 @@ class FunctionTool(SerializationMixin):
self._invocation_duration_histogram = _default_histogram()
self.type: Literal["function_tool"] = "function_tool"
self.result_parser = result_parser
self._forward_runtime_kwargs: bool = False
if self.func:
sig = inspect.signature(self.func)
for param in sig.parameters.values():
if param.kind == inspect.Parameter.VAR_KEYWORD:
self._forward_runtime_kwargs = True
break
def _discover_injected_parameters(self) -> None:
"""Inspect the wrapped function for runtime injection parameters."""
func = self.func.func if isinstance(self.func, FunctionTool) else self.func
if func is None:
return
signature = inspect.signature(func)
try:
type_hints = typing.get_type_hints(func)
except Exception:
type_hints = {name: param.annotation for name, param in signature.parameters.items()}
for name, param in signature.parameters.items():
if name in {"self", "cls"}:
continue
if param.kind == inspect.Parameter.VAR_KEYWORD:
self._forward_runtime_kwargs = True
continue
annotation = type_hints.get(name, param.annotation)
if self._is_context_parameter(name, annotation):
if self._context_parameter_name is not None:
raise ValueError(f"Function '{self.name}' defines multiple FunctionInvocationContext parameters.")
self._context_parameter_name = name
def _is_context_parameter(self, name: str, annotation: Any) -> bool:
"""Check whether a callable parameter should receive FunctionInvocationContext injection."""
if _annotation_includes_function_invocation_context(annotation):
return True
return self._input_model_explicitly_provided and name == "ctx" and annotation is inspect.Parameter.empty
def __str__(self) -> str:
"""Return a string representation of the tool."""
@@ -424,6 +467,7 @@ class FunctionTool(SerializationMixin):
)
for pname, param in sig.parameters.items()
if pname not in {"self", "cls"}
and pname != self._context_parameter_name
and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
}
return create_model(f"{self.name}_input", **fields)
@@ -461,6 +505,7 @@ class FunctionTool(SerializationMixin):
self,
*,
arguments: BaseModel | Mapping[str, Any] | None = None,
context: FunctionInvocationContext | None = None,
**kwargs: Any,
) -> list[Content]:
"""Run the AI function with the provided arguments as a Pydantic model.
@@ -472,7 +517,8 @@ class FunctionTool(SerializationMixin):
Keyword Args:
arguments: A mapping or model instance containing the arguments for the function.
kwargs: Keyword arguments to pass to the function, will not be used if ``arguments`` is provided.
context: Explicit function invocation context carrying runtime kwargs.
kwargs: Deprecated keyword arguments to pass to the function. Use ``context`` instead.
Returns:
A list of Content items representing the tool output.
@@ -483,14 +529,37 @@ class FunctionTool(SerializationMixin):
if self.declaration_only:
raise ToolException(f"Function '{self.name}' is declaration only and cannot be invoked.")
global OBSERVABILITY_SETTINGS
from ._middleware import FunctionInvocationContext
from ._types import Content
from .observability import OBSERVABILITY_SETTINGS
parser = self.result_parser or FunctionTool.parse_result
original_kwargs = dict(kwargs)
tool_call_id = original_kwargs.pop("tool_call_id", None)
if arguments is not None:
parameter_names = set(self.parameters().get("properties", {}).keys())
direct_argument_kwargs = (
{key: value for key, value in kwargs.items() if key in parameter_names} if arguments is None else {}
)
runtime_kwargs = dict(context.kwargs) if context is not None else {}
deprecated_runtime_kwargs = {
key: value for key, value in kwargs.items() if key not in direct_argument_kwargs and key != "tool_call_id"
}
if deprecated_runtime_kwargs:
warnings.warn(
"Passing runtime keyword arguments directly to FunctionTool.invoke() is deprecated; "
"pass them via FunctionInvocationContext instead.",
DeprecationWarning,
stacklevel=2,
)
runtime_kwargs.update(deprecated_runtime_kwargs)
tool_call_id = kwargs.get("tool_call_id", runtime_kwargs.pop("tool_call_id", None))
if arguments is None and direct_argument_kwargs:
arguments = direct_argument_kwargs
if arguments is None and context is not None:
arguments = context.arguments
if arguments is None:
validated_arguments: dict[str, Any] = {}
else:
try:
if isinstance(arguments, Mapping):
parsed_arguments = dict(arguments)
@@ -512,19 +581,45 @@ class FunctionTool(SerializationMixin):
)
except ValidationError as exc:
raise TypeError(f"Invalid arguments for '{self.name}': {exc}") from exc
kwargs = _validate_arguments_against_schema(
validated_arguments = _validate_arguments_against_schema(
arguments=parsed_arguments,
schema=self.parameters(),
tool_name=self.name,
)
if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs:
kwargs.update(original_kwargs)
else:
kwargs = original_kwargs
effective_context = context
if effective_context is None and self._context_parameter_name is not None:
effective_context = FunctionInvocationContext(
function=self,
arguments=validated_arguments,
kwargs=runtime_kwargs,
)
if effective_context is not None:
effective_context.function = self
effective_context.arguments = validated_arguments
effective_context.kwargs = dict(runtime_kwargs)
call_kwargs = dict(validated_arguments)
observable_kwargs = dict(validated_arguments)
# Legacy runtime kwargs injection path retained for backwards compatibility with tools
# that still declare ``**kwargs``. New tools should consume runtime data via ``ctx``.
legacy_runtime_kwargs = dict(runtime_kwargs)
if self._forward_runtime_kwargs and legacy_runtime_kwargs:
for key, value in legacy_runtime_kwargs.items():
if key not in call_kwargs:
call_kwargs[key] = value
if key not in observable_kwargs:
observable_kwargs[key] = value
if self._context_parameter_name is not None and effective_context is not None:
call_kwargs[self._context_parameter_name] = effective_context
if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined]
logger.info(f"Function name: {self.name}")
logger.debug(f"Function arguments: {kwargs}")
res = self.__call__(**kwargs)
logger.debug(f"Function arguments: {observable_kwargs}")
res = self.__call__(**call_kwargs)
result = await res if inspect.isawaitable(res) else res
try:
parsed = parser(result)
@@ -545,7 +640,7 @@ class FunctionTool(SerializationMixin):
# Filter out framework kwargs that are not JSON serializable.
serializable_kwargs = {
k: v
for k, v in kwargs.items()
for k, v in observable_kwargs.items()
if k
not in {
"chat_options",
@@ -571,7 +666,7 @@ class FunctionTool(SerializationMixin):
start_time_stamp = perf_counter()
end_time_stamp: float | None = None
try:
res = self.__call__(**kwargs)
res = self.__call__(**call_kwargs)
result = await res if inspect.isawaitable(res) else res
end_time_stamp = perf_counter()
except Exception as exception:
@@ -1218,9 +1313,10 @@ async def _auto_invoke_function(
*,
config: FunctionInvocationConfiguration,
tool_map: dict[str, FunctionTool],
invocation_session: AgentSession | None = None,
sequence_index: int | None = None,
request_index: int | None = None,
middleware_pipeline: FunctionMiddlewarePipeline | None = None, # Optional MiddlewarePipeline
middleware_pipeline: FunctionMiddlewarePipeline | None = None,
) -> Content:
"""Invoke a function call requested by the agent, applying middleware that is defined.
@@ -1231,6 +1327,7 @@ async def _auto_invoke_function(
Keyword Args:
config: The function invocation configuration.
tool_map: A mapping of tool names to FunctionTool instances.
invocation_session: The agent session for this invocation, if any.
sequence_index: The index of the function call in the sequence.
request_index: The index of the request iteration.
middleware_pipeline: Optional middleware pipeline to apply during execution.
@@ -1282,6 +1379,8 @@ async def _auto_invoke_function(
for key, value in (custom_args or {}).items()
if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"}
}
if invocation_session is not None:
runtime_kwargs["session"] = invocation_session
try:
if not cast(bool, getattr(tool, "_schema_supplied", False)) and tool.input_model is not None:
args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True)
@@ -1303,19 +1402,31 @@ async def _auto_invoke_function(
additional_properties=function_call_content.additional_properties,
)
from ._middleware import FunctionInvocationContext
if middleware_pipeline is None or not middleware_pipeline.has_middlewares:
# No middleware - execute directly
try:
direct_context = None
if getattr(tool, "_forward_runtime_kwargs", False) or getattr(tool, "_context_parameter_name", None):
direct_context = FunctionInvocationContext(
function=tool,
arguments=args,
session=invocation_session,
kwargs=runtime_kwargs.copy(),
)
function_result = await tool.invoke(
arguments=args,
context=direct_context,
tool_call_id=function_call_content.call_id,
**runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {},
)
return Content.from_function_result(
call_id=function_call_content.call_id, # type: ignore[arg-type]
result=function_result,
additional_properties=function_call_content.additional_properties,
)
except UserInputRequiredException:
raise
except Exception as exc:
message = "Error: Function failed."
if config.get("include_detailed_errors", False):
@@ -1327,19 +1438,18 @@ async def _auto_invoke_function(
additional_properties=function_call_content.additional_properties,
)
# Execute through middleware pipeline if available
from ._middleware import FunctionInvocationContext
middleware_context = FunctionInvocationContext(
function=tool,
arguments=args,
session=invocation_session,
kwargs=runtime_kwargs.copy(),
)
async def final_function_handler(context_obj: Any) -> Any:
return await tool.invoke(
arguments=context_obj.arguments,
context=context_obj,
tool_call_id=function_call_content.call_id,
**context_obj.kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {},
)
from ._middleware import MiddlewareTermination
@@ -1362,6 +1472,8 @@ async def _auto_invoke_function(
additional_properties=function_call_content.additional_properties,
)
raise
except UserInputRequiredException:
raise
except Exception as exc:
message = "Error: Function failed."
if config.get("include_detailed_errors", False):
@@ -1390,7 +1502,8 @@ async def _try_execute_function_calls(
function_calls: Sequence[Content],
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]],
config: FunctionInvocationConfiguration,
middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports
invocation_session: AgentSession | None = None,
middleware_pipeline: Any = None,
) -> tuple[Sequence[Content], bool]:
"""Execute multiple function calls concurrently.
@@ -1400,6 +1513,7 @@ async def _try_execute_function_calls(
function_calls: A sequence of FunctionCallContent to execute.
tools: The tools available for execution.
config: Configuration for function invocation.
invocation_session: The agent session for this invocation, if any.
middleware_pipeline: Optional middleware pipeline to apply during execution.
Returns:
@@ -1469,6 +1583,8 @@ async def _try_execute_function_calls(
# Run all function calls concurrently, handling MiddlewareTermination
from ._middleware import MiddlewareTermination
extra_user_input_contents: list[Content] = []
async def invoke_with_termination_handling(
function_call: Content,
seq_idx: int,
@@ -1479,6 +1595,7 @@ async def _try_execute_function_calls(
function_call_content=function_call, # type: ignore[arg-type]
custom_args=custom_args,
tool_map=tool_map,
invocation_session=invocation_session,
sequence_index=seq_idx,
request_index=attempt_idx,
middleware_pipeline=middleware_pipeline,
@@ -1495,6 +1612,26 @@ async def _try_execute_function_calls(
result=exc.result,
)
return (result_content, True)
except UserInputRequiredException as exc:
if exc.contents:
propagated: list[Content] = []
for item in exc.contents:
if isinstance(item, Content):
item.call_id = function_call.call_id # type: ignore[attr-defined]
if not item.id: # type: ignore[attr-defined]
item.id = function_call.call_id # type: ignore[attr-defined]
propagated.append(item)
if propagated:
extra_user_input_contents.extend(propagated[1:])
return (propagated[0], False)
return (
Content.from_function_result(
call_id=function_call.call_id, # type: ignore[arg-type]
result="Tool requires user input but no request details were provided.",
exception="UserInputRequiredException",
),
False,
)
execution_results = await asyncio.gather(*[
invoke_with_termination_handling(function_call, seq_idx) for seq_idx, function_call in enumerate(function_calls)
@@ -1502,6 +1639,7 @@ async def _try_execute_function_calls(
# Unpack results - each is (Content, terminate_flag)
contents: list[Content] = [result[0] for result in execution_results]
contents.extend(extra_user_input_contents)
# If any function requested termination, terminate the loop
should_terminate = any(result[1] for result in execution_results)
return (contents, should_terminate)
@@ -1514,6 +1652,7 @@ async def _execute_function_calls(
function_calls: list[Content],
tool_options: dict[str, Any] | None,
config: FunctionInvocationConfiguration,
invocation_session: AgentSession | None = None,
middleware_pipeline: Any = None,
) -> tuple[list[Content], bool, bool]:
tools = _extract_tools(tool_options)
@@ -1524,6 +1663,7 @@ async def _execute_function_calls(
attempt_idx=attempt_idx,
function_calls=function_calls,
tools=tools, # type: ignore
invocation_session=invocation_session,
middleware_pipeline=middleware_pipeline,
config=config,
)
@@ -1733,7 +1873,10 @@ def _handle_function_call_results(
) -> FunctionRequestResult:
from ._types import Message
if any(fccr.type in {"function_approval_request", "function_call"} for fccr in function_call_results):
if any(
fccr.type in {"function_approval_request", "function_call"} or fccr.user_input_request
for fccr in function_call_results
):
# Only add items that aren't already in the message (e.g. function_approval_request wrappers).
# Declaration-only function_call items are already present from the LLM response.
new_items = [fccr for fccr in function_call_results if fccr.type != "function_call"]
@@ -1901,6 +2044,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
options: ChatOptions[ResponseModelBoundT],
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@@ -1913,6 +2058,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
options: OptionsCoT | ChatOptions[None] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@@ -1925,6 +2072,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
options: OptionsCoT | ChatOptions[Any] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@@ -1937,6 +2086,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
from ._middleware import FunctionMiddlewarePipeline
@@ -1947,28 +2098,45 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
)
super_get_response = super().get_response # type: ignore[misc]
if kwargs:
warnings.warn(
"Passing client-specific keyword arguments directly to get_response() is deprecated; "
"pass them via client_kwargs instead.",
DeprecationWarning,
stacklevel=2,
)
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
effective_function_middleware = function_middleware
if effective_function_middleware is None:
middleware_from_client_kwargs = effective_client_kwargs.pop("function_middleware", None)
if middleware_from_client_kwargs is not None:
effective_function_middleware = cast(Sequence[Any], middleware_from_client_kwargs)
# ChatMiddleware adds this kwarg
function_middleware_pipeline = FunctionMiddlewarePipeline(
*(self.function_middleware), *(function_middleware or [])
*(self.function_middleware), *(effective_function_middleware or [])
)
max_errors = self.function_invocation_configuration.get(
"max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST
)
additional_function_arguments: dict[str, Any] = {}
additional_function_arguments = (
dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}
)
if options and (additional_opts := options.get("additional_function_arguments")): # type: ignore[attr-defined]
additional_function_arguments = additional_opts # type: ignore
additional_function_arguments.update(cast(Mapping[str, Any], additional_opts))
from ._sessions import AgentSession as _AgentSession
raw_session = effective_client_kwargs.get("session")
invocation_session = raw_session if isinstance(raw_session, _AgentSession) else None
execute_function_calls = partial(
_execute_function_calls,
custom_args=additional_function_arguments,
config=self.function_invocation_configuration,
invocation_session=invocation_session,
middleware_pipeline=function_middleware_pipeline,
)
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "session"}
if compaction_strategy is not None:
filtered_kwargs["compaction_strategy"] = compaction_strategy
if tokenizer is not None:
filtered_kwargs["tokenizer"] = tokenizer
filtered_kwargs = {k: v for k, v in {**effective_client_kwargs, **kwargs}.items() if k != "session"}
# Make options mutable so we can update conversation_id during function invocation loop
mutable_options: dict[str, Any] = dict(options) if options else {}
@@ -2018,7 +2186,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
messages=prepped_messages,
stream=False,
options=mutable_options,
**filtered_kwargs,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
client_kwargs=filtered_kwargs,
),
)
@@ -2087,7 +2257,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
messages=prepped_messages,
stream=False,
options=mutable_options,
**filtered_kwargs,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
client_kwargs=filtered_kwargs,
),
)
if fcc_messages:
@@ -2137,7 +2309,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
messages=prepped_messages,
stream=True,
options=mutable_options,
**filtered_kwargs,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
client_kwargs=filtered_kwargs,
),
)
await inner_stream
@@ -2229,7 +2403,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
messages=prepped_messages,
stream=True,
options=mutable_options,
**filtered_kwargs,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
client_kwargs=filtered_kwargs,
),
)
await final_inner_stream
@@ -2698,7 +2698,7 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]):
stream: AsyncIterable[UpdateT] | Awaitable[AsyncIterable[UpdateT]],
*,
finalizer: Callable[[Sequence[UpdateT]], FinalT | Awaitable[FinalT]] | None = None,
transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None]] | None = None,
transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None]] | None = None,
cleanup_hooks: list[Callable[[], Awaitable[None] | None]] | None = None,
result_hooks: list[Callable[[FinalT], FinalT | Awaitable[FinalT | None] | None]] | None = None,
) -> None:
@@ -2722,7 +2722,7 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]):
self._consumed: bool = False
self._finalized: bool = False
self._final_result: FinalT | None = None
self._transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None]] = (
self._transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None]] = (
transform_hooks if transform_hooks is not None else []
)
self._result_hooks: list[Callable[[FinalT], FinalT | Awaitable[FinalT | None] | None]] = (
@@ -2995,7 +2995,7 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]):
def with_transform_hook(
self,
hook: Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None],
hook: Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None],
) -> ResponseStream[UpdateT, FinalT]:
"""Register a transform hook executed for each update during iteration."""
self._transform_hooks.append(hook)
@@ -172,12 +172,12 @@ class AzureOpenAIChatClient( # type: ignore[misc]
credential: AzureCredentialTypes | AzureTokenProvider | None = None,
default_headers: Mapping[str, str] | None = None,
async_client: AsyncAzureOpenAI | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
instruction_role: str | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
**kwargs: Any,
) -> None:
"""Initialize an Azure OpenAI Chat completion client.
@@ -205,13 +205,13 @@ class AzureOpenAIChatClient( # type: ignore[misc]
default_headers: The default headers mapping of string keys to
string values for HTTP requests.
async_client: An existing client to use.
additional_properties: Additional properties stored on the client instance.
env_file_path: Use the environment settings file as a fallback to using env vars.
env_file_encoding: The encoding of the environment settings file, defaults to 'utf-8'.
instruction_role: The role to use for 'instruction' messages, for example, summarization
prompts could use `developer` or `system`.
middleware: Optional sequence of middleware to apply to requests.
function_invocation_configuration: Optional configuration for function invocation behavior.
kwargs: Other keyword parameters.
Examples:
.. code-block:: python
@@ -283,10 +283,10 @@ class AzureOpenAIChatClient( # type: ignore[misc]
credential=credential,
default_headers=default_headers,
client=async_client,
additional_properties=additional_properties,
instruction_role=instruction_role,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
**kwargs,
)
@override
@@ -180,6 +180,34 @@ class ToolExecutionException(ToolException):
pass
class UserInputRequiredException(ToolException):
"""Raised when a tool wrapping a sub-agent requires user input to proceed.
This exception carries the ``user_input_request`` Content items emitted by
the sub-agent (e.g., ``oauth_consent_request``, ``function_approval_request``)
so the tool invocation layer can propagate them to the parent agent's response
instead of swallowing them as a generic tool error.
Args:
contents: The user-input-request Content items from the sub-agent response.
message: Human-readable description of why user input is needed.
"""
def __init__(
self,
contents: list[Any],
message: str = "Tool requires user input to proceed.",
) -> None:
"""Create a UserInputRequiredException.
Args:
contents: The user-input-request Content items from the sub-agent response.
message: Human-readable description of why user input is needed.
"""
super().__init__(message, log_level=None)
self.contents = contents
# endregion
# region Middleware Exceptions
@@ -1162,11 +1162,35 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
tokenizer: TokenizerProtocol | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Trace chat responses with OpenTelemetry spans and metrics."""
"""Trace chat responses with OpenTelemetry spans and metrics.
Args:
messages: The message or messages to send to the model.
stream: Whether to stream the response. Defaults to False.
options: Chat options as a TypedDict.
compaction_strategy: Optional compaction strategy to apply before model calls.
tokenizer: Optional tokenizer used by token-aware compaction strategies.
Keyword Args:
kwargs: Compatibility keyword arguments from higher client layers. This layer does
not consume ``function_invocation_kwargs`` directly; if present, it is ignored
because function invocation has already been processed above. If a ``client_kwargs``
mapping is present, it is flattened into ordinary keyword arguments for tracing and
forwarding so clients that use those values continue to work while clients that
ignore extra kwargs remain compatible.
"""
from ._types import ChatResponse, ChatResponseUpdate, ResponseStream # type: ignore[reportUnusedImport]
global OBSERVABILITY_SETTINGS
super_get_response = super().get_response # type: ignore[misc]
compatibility_client_kwargs = kwargs.pop("client_kwargs", None)
kwargs.pop("function_invocation_kwargs", None)
merged_client_kwargs = (
dict(cast(Mapping[str, Any], compatibility_client_kwargs))
if isinstance(compatibility_client_kwargs, Mapping)
else {}
)
merged_client_kwargs.update(kwargs)
if not OBSERVABILITY_SETTINGS.ENABLED:
return super_get_response( # type: ignore[no-any-return]
@@ -1175,12 +1199,14 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
options=options,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
**merged_client_kwargs,
)
opts: dict[str, Any] = options or {} # type: ignore[assignment]
provider_name = str(getattr(self, "otel_provider_name", "unknown"))
model_id = kwargs.get("model_id") or opts.get("model_id") or getattr(self, "model_id", None) or "unknown"
model_id = (
merged_client_kwargs.get("model_id") or opts.get("model_id") or getattr(self, "model_id", None) or "unknown"
)
service_url_func = getattr(self, "service_url", None)
service_url = str(service_url_func() if callable(service_url_func) else "unknown")
attributes = _get_span_attributes(
@@ -1188,7 +1214,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
provider_name=provider_name,
model=model_id,
service_url=service_url,
**kwargs,
**merged_client_kwargs,
)
if stream:
@@ -1200,7 +1226,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
options=opts,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
**merged_client_kwargs,
),
)
@@ -1291,7 +1317,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
options=opts,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
**kwargs,
**merged_client_kwargs,
),
)
except Exception as exception:
@@ -1420,6 +1446,8 @@ class AgentTelemetryLayer:
session: AgentSession | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]]: ...
@@ -1432,6 +1460,8 @@ class AgentTelemetryLayer:
session: AgentSession | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
@@ -1443,6 +1473,8 @@ class AgentTelemetryLayer:
session: AgentSession | None = None,
compaction_strategy: CompactionStrategy | None = None,
tokenizer: TokenizerProtocol | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
"""Trace agent runs with OpenTelemetry spans and metrics."""
@@ -1463,11 +1495,15 @@ class AgentTelemetryLayer:
session=session,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
**kwargs,
)
default_options = getattr(self, "default_options", {})
options = kwargs.get("options")
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
merged_client_kwargs.update(kwargs)
merged_options: dict[str, Any] = merge_chat_options(default_options, options or {})
attributes = _get_span_attributes(
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
@@ -1477,7 +1513,7 @@ class AgentTelemetryLayer:
agent_description=getattr(self, "description", None),
thread_id=session.service_session_id if session else None,
all_options=merged_options,
**kwargs,
**merged_client_kwargs,
)
if stream:
@@ -1487,6 +1523,8 @@ class AgentTelemetryLayer:
session=session,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
**kwargs,
)
if isinstance(run_result, ResponseStream):
@@ -1578,6 +1616,8 @@ class AgentTelemetryLayer:
session=session,
compaction_strategy=compaction_strategy,
tokenizer=tokenizer,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
**kwargs,
)
except Exception as exception:
@@ -15,7 +15,7 @@ from collections.abc import (
)
from datetime import datetime, timezone
from itertools import chain
from typing import Any, Generic, Literal, cast
from typing import Any, Generic, Literal, cast, overload
from openai import AsyncOpenAI, BadRequestError
from openai.lib._parsing._completions import type_to_response_format_param
@@ -30,7 +30,8 @@ from openai.types.chat.completion_create_params import WebSearchOptions
from pydantic import BaseModel
from .._clients import BaseChatClient
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
from .._docstrings import apply_layered_docstring
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, FunctionMiddlewareTypes
from .._settings import load_settings
from .._tools import (
FunctionInvocationConfiguration,
@@ -72,6 +73,7 @@ else:
logger = logging.getLogger("agent_framework.openai")
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
@@ -213,6 +215,57 @@ class RawOpenAIChatClient( # type: ignore[misc]
# endregion
@overload
def get_response(
self,
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
**kwargs: Any,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@overload
def get_response(
self,
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OpenAIChatOptionsT | ChatOptions[None] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@overload
def get_response(
self,
messages: Sequence[Message],
*,
stream: Literal[True],
options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@override
def get_response(
self,
messages: Sequence[Message],
*,
stream: bool = False,
options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Get a response from the raw OpenAI chat client."""
super_get_response = cast(
"Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]",
super().get_response, # type: ignore[misc]
)
return super_get_response( # type: ignore[no-any-return]
messages=messages,
stream=stream,
options=options,
**kwargs,
)
@override
def _inner_get_response(
self,
@@ -727,6 +780,77 @@ class OpenAIChatClient( # type: ignore[misc]
):
"""OpenAI Chat completion class with middleware, telemetry, and function invocation support."""
@overload
def get_response(
self,
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
@overload
def get_response(
self,
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OpenAIChatOptionsT | ChatOptions[None] | None = None,
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]]: ...
@overload
def get_response(
self,
messages: Sequence[Message],
*,
stream: Literal[True],
options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
**kwargs: Any,
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
@override
def get_response(
self,
messages: Sequence[Message],
*,
stream: bool = False,
options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
function_invocation_kwargs: Mapping[str, Any] | None = None,
client_kwargs: Mapping[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
**kwargs: Any,
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
"""Get a response from the OpenAI chat client with all standard layers enabled."""
super_get_response = cast(
"Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]",
super().get_response, # type: ignore[misc]
)
return super_get_response( # type: ignore[no-any-return]
messages=messages,
stream=stream,
options=options,
function_middleware=function_middleware,
function_invocation_kwargs=function_invocation_kwargs,
client_kwargs=client_kwargs,
middleware=middleware,
**kwargs,
)
def __init__(
self,
*,
@@ -830,3 +954,25 @@ class OpenAIChatClient( # type: ignore[misc]
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
)
def _apply_openai_chat_client_docstrings() -> None:
"""Align OpenAI chat-client docstrings with the raw implementation."""
apply_layered_docstring(RawOpenAIChatClient.get_response, BaseChatClient.get_response)
apply_layered_docstring(
OpenAIChatClient.get_response,
RawOpenAIChatClient.get_response,
extra_keyword_args={
"function_middleware": """
Optional per-call function middleware.
When omitted, middleware configured on the client or forwarded from higher layers is used.
""",
"middleware": """
Optional per-call chat and function middleware.
This is merged with any middleware configured on the client for the current request.
""",
},
)
_apply_openai_chat_client_docstrings()
+157 -51
View File
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
import contextlib
import inspect
from collections.abc import AsyncIterable, MutableSequence
from typing import Any
from unittest.mock import AsyncMock, MagicMock
@@ -31,6 +32,7 @@ from agent_framework import (
)
from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name
from agent_framework._mcp import MCPTool, _build_prefixed_mcp_name, _normalize_mcp_name
from agent_framework._middleware import FunctionInvocationContext
class _FixedTokenizer:
@@ -101,6 +103,30 @@ def test_chat_client_agent_type(client: SupportsChatGetResponse) -> None:
assert isinstance(chat_client_agent, SupportsAgentRun)
def test_agent_init_docstring_surfaces_raw_agent_constructor_docs() -> None:
docstring = inspect.getdoc(Agent.__init__)
assert docstring is not None
assert "client: The chat client to use for the agent." in docstring
assert "middleware: List of middleware to intercept agent and function invocations." in docstring
def test_agent_run_docstring_surfaces_raw_agent_runtime_docs() -> None:
docstring = inspect.getdoc(Agent.run)
assert docstring is not None
assert "Run the agent with the given messages and options." in docstring
assert "function_invocation_kwargs: Keyword arguments forwarded to tool invocation." in docstring
assert "middleware: Optional per-run agent, chat, and function middleware." in docstring
def test_agent_run_is_defined_on_agent_class() -> None:
signature = inspect.signature(Agent.run)
assert Agent.run.__qualname__ == "Agent.run"
assert "middleware" in signature.parameters
async def test_chat_client_agent_init(client: SupportsChatGetResponse) -> None:
agent_id = str(uuid4())
agent = Agent(client=client, id=agent_id, description="Test")
@@ -121,6 +147,13 @@ async def test_chat_client_agent_init_with_name(
assert agent.description == "Test"
def test_agent_init_warns_for_direct_additional_properties(client: SupportsChatGetResponse) -> None:
with pytest.warns(DeprecationWarning, match="additional_properties"):
agent = Agent(client=client, legacy_key="legacy-value")
assert agent.additional_properties["legacy_key"] == "legacy-value"
async def test_chat_client_agent_run(client: SupportsChatGetResponse) -> None:
agent = Agent(client=client)
@@ -253,33 +286,38 @@ async def test_prepare_session_does_not_mutate_agent_chat_options(
assert len(agent.default_options["tools"]) == 1
async def test_prepare_run_context_keeps_compaction_overrides_out_of_kwargs(
async def test_prepare_run_context_handles_function_kwargs(
chat_client_base: SupportsChatGetResponse,
) -> None:
strategy = SlidingWindowStrategy(keep_last_groups=2)
tokenizer = _FixedTokenizer(13)
agent = Agent(client=chat_client_base)
session = agent.create_session()
ctx = await agent._prepare_run_context( # type: ignore[reportPrivateUsage]
messages=[Message(role="user", text="Hello")],
session=None,
messages="Hello",
session=session,
tools=None,
options=None,
compaction_strategy=strategy,
tokenizer=tokenizer,
kwargs={"custom_flag": True},
options={
"temperature": 0.4,
"additional_function_arguments": {"from_options": "options-value"},
},
compaction_strategy=None,
tokenizer=None,
legacy_kwargs={"legacy_key": "legacy-value"},
function_invocation_kwargs={"runtime_key": "runtime-value"},
client_kwargs={"client_key": "client-value"},
)
assert ctx["compaction_strategy"] is strategy
assert ctx["tokenizer"] is tokenizer
assert ctx["filtered_kwargs"].get("custom_flag") is True
assert "compaction_strategy" not in ctx["filtered_kwargs"]
assert "tokenizer" not in ctx["filtered_kwargs"]
assert ctx["chat_options"]["temperature"] == 0.4
assert "additional_function_arguments" not in ctx["chat_options"]
assert ctx["function_invocation_kwargs"]["from_options"] == "options-value"
assert ctx["function_invocation_kwargs"]["legacy_key"] == "legacy-value"
assert ctx["function_invocation_kwargs"]["runtime_key"] == "runtime-value"
assert "session" not in ctx["function_invocation_kwargs"]
assert ctx["client_kwargs"]["client_key"] == "client-value"
assert ctx["client_kwargs"]["session"] is session
async def test_chat_client_agent_run_with_session(
chat_client_base: SupportsChatGetResponse,
) -> None:
async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChatGetResponse) -> None:
mock_response = ChatResponse(
messages=[Message(role="assistant", contents=[Content.from_text("test response")])],
conversation_id="123",
@@ -720,8 +758,9 @@ async def test_chat_agent_as_tool_basic(client: SupportsChatGetResponse) -> None
assert tool.name == "TestAgent"
assert tool.description == "Test agent for as_tool"
assert tool.approval_mode == "never_require"
assert hasattr(tool, "func")
assert hasattr(tool, "input_model")
assert tool.input_model is None
async def test_chat_agent_as_tool_custom_parameters(
@@ -735,13 +774,15 @@ async def test_chat_agent_as_tool_custom_parameters(
description="Custom description",
arg_name="query",
arg_description="Custom input description",
approval_mode="always_require",
)
assert tool.name == "CustomTool"
assert tool.description == "Custom description"
assert tool.approval_mode == "always_require"
# Check that the input model has the custom field name
schema = tool.input_model.model_json_schema()
schema = tool.parameters()
assert "query" in schema["properties"]
assert schema["properties"]["query"]["description"] == "Custom input description"
@@ -760,7 +801,7 @@ async def test_chat_agent_as_tool_defaults(client: SupportsChatGetResponse) -> N
assert tool.description == "" # Should default to empty string
# Check default input field
schema = tool.input_model.model_json_schema()
schema = tool.parameters()
assert "task" in schema["properties"]
assert "Task for TestAgent" in schema["properties"]["task"]["description"]
@@ -783,12 +824,12 @@ async def test_chat_agent_as_tool_function_execution(
tool = agent.as_tool()
# Test function execution
result = await tool.invoke(arguments=tool.input_model(task="Hello"))
result = await tool.invoke(arguments={"task": "Hello"})
# Should return the agent's response text as a list of Content items
assert isinstance(result, list)
assert len(result) == 1
assert result[0].text == "test response" # From mock chat client
assert result[0].text == "test streaming response another update" # From mock streaming client
async def test_chat_agent_as_tool_with_stream_callback(
@@ -806,7 +847,7 @@ async def test_chat_agent_as_tool_with_stream_callback(
tool = agent.as_tool(stream_callback=stream_callback)
# Execute the tool
result = await tool.invoke(arguments=tool.input_model(task="Hello"))
result = await tool.invoke(arguments={"task": "Hello"})
# Should have collected streaming updates
assert len(collected_updates) > 0
@@ -826,9 +867,9 @@ async def test_chat_agent_as_tool_with_custom_arg_name(
tool = agent.as_tool(arg_name="prompt", arg_description="Custom prompt input")
# Test that the custom argument name works
result = await tool.invoke(arguments=tool.input_model(prompt="Test prompt"))
result = await tool.invoke(arguments={"prompt": "Test prompt"})
assert isinstance(result, list)
assert result[0].text == "test response"
assert result[0].text == "test streaming response another update"
async def test_chat_agent_as_tool_with_async_stream_callback(
@@ -846,7 +887,7 @@ async def test_chat_agent_as_tool_with_async_stream_callback(
tool = agent.as_tool(stream_callback=async_stream_callback)
# Execute the tool
result = await tool.invoke(arguments=tool.input_model(task="Hello"))
result = await tool.invoke(arguments={"task": "Hello"})
# Should have collected streaming updates
assert len(collected_updates) > 0
@@ -877,17 +918,14 @@ async def test_chat_agent_as_tool_name_sanitization(
assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}"
async def test_chat_agent_as_tool_propagate_session_true(
client: SupportsChatGetResponse,
) -> None:
"""Test that propagate_session=True forwards the parent's session to the sub-agent."""
async def test_chat_agent_as_tool_propagate_session_true(client: SupportsChatGetResponse) -> None:
"""Test that propagate_session=True forwards the session to the sub-agent."""
agent = Agent(client=client, name="SubAgent", description="Sub agent")
tool = agent.as_tool(propagate_session=True)
parent_session = AgentSession(session_id="parent-session-123")
parent_session.state["shared_key"] = "shared_value"
# Spy on the agent's run method to capture the session argument
original_run = agent.run
captured_session = None
@@ -898,16 +936,20 @@ async def test_chat_agent_as_tool_propagate_session_true(
agent.run = capturing_run # type: ignore[assignment, method-assign]
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
await tool.invoke(
context=FunctionInvocationContext(
function=tool,
arguments={"task": "Hello"},
session=parent_session,
)
)
assert captured_session is parent_session
assert captured_session.session_id == "parent-session-123"
assert captured_session.state["shared_key"] == "shared_value"
async def test_chat_agent_as_tool_propagate_session_false_by_default(
client: SupportsChatGetResponse,
) -> None:
async def test_chat_agent_as_tool_propagate_session_false_by_default(client: SupportsChatGetResponse) -> None:
"""Test that propagate_session defaults to False and does not forward the session."""
agent = Agent(client=client, name="SubAgent", description="Sub agent")
tool = agent.as_tool() # default: propagate_session=False
@@ -924,22 +966,25 @@ async def test_chat_agent_as_tool_propagate_session_false_by_default(
agent.run = capturing_run # type: ignore[assignment, method-assign]
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
await tool.invoke(
context=FunctionInvocationContext(
function=tool,
arguments={"task": "Hello"},
session=parent_session,
)
)
assert captured_session is None
async def test_chat_agent_as_tool_propagate_session_shares_state(
client: SupportsChatGetResponse,
) -> None:
"""Test that shared session allows the sub-agent to read and write parent's state."""
async def test_chat_agent_as_tool_propagate_session_shares_state(client: SupportsChatGetResponse) -> None:
"""Test that a propagated session allows the sub-agent to read and write parent state."""
agent = Agent(client=client, name="SubAgent", description="Sub agent")
tool = agent.as_tool(propagate_session=True)
parent_session = AgentSession(session_id="shared-session")
parent_session.state["counter"] = 0
# The sub-agent receives the same session object, so mutations are shared
original_run = agent.run
captured_session = None
@@ -952,9 +997,14 @@ async def test_chat_agent_as_tool_propagate_session_shares_state(
agent.run = capturing_run # type: ignore[assignment, method-assign]
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
await tool.invoke(
context=FunctionInvocationContext(
function=tool,
arguments={"task": "Hello"},
session=parent_session,
)
)
# The parent's state should reflect the sub-agent's mutation
assert parent_session.state["counter"] == 1
@@ -1131,7 +1181,7 @@ async def test_agent_run_accepts_prefixed_mcp_tools(chat_client_base: Any) -> No
async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> None:
"""Verify tool execution receives 'session' inside **kwargs when function is called by client."""
"""Verify legacy **kwargs tools receive the session when agent.run() is called with one."""
captured: dict[str, Any] = {}
@@ -1142,7 +1192,6 @@ async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> N
captured["has_state"] = session.state is not None if isinstance(session, AgentSession) else False
return f"echo: {text}"
# Make the base client emit a function call for our tool
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
@@ -1162,17 +1211,52 @@ async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> N
agent = Agent(client=chat_client_base, tools=[echo_session_info])
session = agent.create_session()
result = await agent.run(
"hello",
session=session,
options={"additional_function_arguments": {"session": session}},
)
result = await agent.run("hello", session=session)
assert result.text == "done"
assert captured.get("has_session") is True
assert captured.get("has_state") is True
async def test_agent_tool_receives_explicit_session_via_function_invocation_context_kwargs(
chat_client_base: Any,
) -> None:
"""Verify ctx-based tools receive the session via FunctionInvocationContext.session."""
captured: dict[str, Any] = {}
@tool(name="capture_session_context", approval_mode="never_require")
def capture_session_context(text: str, ctx: FunctionInvocationContext) -> str:
captured["session"] = ctx.session
captured["has_state"] = ctx.session.state is not None if isinstance(ctx.session, AgentSession) else False
return f"echo: {text}"
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="1",
name="capture_session_context",
arguments='{"text": "hello"}',
)
],
)
),
ChatResponse(messages=Message(role="assistant", text="done")),
]
agent = Agent(client=chat_client_base, tools=[capture_session_context])
session = agent.create_session()
result = await agent.run("hello", session=session)
assert result.text == "done"
assert captured["session"] is session
assert captured["has_state"] is True
async def test_chat_agent_tool_choice_run_level_overrides_agent_level(chat_client_base: Any, tool_tool: Any) -> None:
"""Verify that tool_choice passed to run() overrides agent-level tool_choice."""
@@ -1859,4 +1943,26 @@ async def test_stores_by_default_with_store_false_in_default_options_injects_inm
assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers)
# endregion
# region as_tool user_input_request propagation
async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetResponse) -> None:
"""Test that as_tool raises when the wrapped sub-agent requests user input."""
from agent_framework.exceptions import UserInputRequiredException
consent_content = Content.from_oauth_consent_request(
consent_link="https://login.microsoftonline.com/consent",
)
client.streaming_responses = [ # type: ignore[attr-defined]
[ChatResponseUpdate(contents=[consent_content], role="assistant")],
]
agent = Agent(client=client, name="OAuthAgent", description="Agent requiring consent")
agent_tool = agent.as_tool()
with raises(UserInputRequiredException) as exc_info:
await agent_tool.invoke(arguments={"task": "Do something"})
assert len(exc_info.value.contents) == 1
assert exc_info.value.contents[0].type == "oauth_consent_request"
assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent"
@@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable
from typing import Any
from agent_framework import Agent, ChatResponse, Content, Message, agent_middleware
from agent_framework._middleware import AgentContext
from agent_framework._middleware import AgentContext, FunctionInvocationContext
from .conftest import MockChatClient
@@ -14,14 +14,28 @@ from .conftest import MockChatClient
class TestAsToolKwargsPropagation:
"""Test cases for kwargs propagation through as_tool() delegation."""
@staticmethod
def _build_context(
tool: Any,
*,
task: str,
runtime_kwargs: dict[str, Any] | None = None,
) -> FunctionInvocationContext:
return FunctionInvocationContext(
function=tool,
arguments={"task": task},
kwargs=runtime_kwargs,
)
async def test_as_tool_forwards_runtime_kwargs(self, client: MockChatClient) -> None:
"""Test that runtime kwargs are forwarded through as_tool() to sub-agent."""
"""Test that runtime kwargs are forwarded through as_tool() to sub-agent tools."""
captured_kwargs: dict[str, Any] = {}
captured_function_invocation_kwargs: dict[str, Any] = {}
@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture kwargs passed to the sub-agent
captured_kwargs.update(context.kwargs)
captured_function_invocation_kwargs.update(context.function_invocation_kwargs)
await call_next()
# Setup mock response
@@ -39,29 +53,31 @@ class TestAsToolKwargsPropagation:
# Create tool from sub-agent
tool = sub_agent.as_tool(name="delegate", arg_name="task")
# Directly invoke the tool with kwargs (simulating what happens during agent execution)
# Directly invoke the tool with explicit runtime context (simulating agent execution).
_ = await tool.invoke(
arguments=tool.input_model(task="Test delegation"),
api_token="secret-xyz-123",
user_id="user-456",
session_id="session-789",
context=self._build_context(
tool,
task="Test delegation",
runtime_kwargs={
"api_token": "secret-xyz-123",
"user_id": "user-456",
"session_id": "session-789",
},
),
)
# Verify kwargs were forwarded to sub-agent
assert "api_token" in captured_kwargs, f"Expected 'api_token' in {captured_kwargs}"
assert captured_kwargs["api_token"] == "secret-xyz-123"
assert "user_id" in captured_kwargs
assert captured_kwargs["user_id"] == "user-456"
assert "session_id" in captured_kwargs
assert captured_kwargs["session_id"] == "session-789"
assert captured_kwargs == {}
assert captured_function_invocation_kwargs["api_token"] == "secret-xyz-123"
assert captured_function_invocation_kwargs["user_id"] == "user-456"
assert captured_function_invocation_kwargs["session_id"] == "session-789"
async def test_as_tool_excludes_arg_name_from_forwarded_kwargs(self, client: MockChatClient) -> None:
"""Test that the arg_name parameter is not forwarded as a kwarg."""
captured_kwargs: dict[str, Any] = {}
async def test_as_tool_forwards_context_kwargs_verbatim(self, client: MockChatClient) -> None:
"""Test that runtime kwargs are forwarded exactly from FunctionInvocationContext.kwargs."""
captured_function_invocation_kwargs: dict[str, Any] = {}
@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
captured_function_invocation_kwargs.update(context.function_invocation_kwargs)
await call_next()
# Setup mock response
@@ -79,25 +95,26 @@ class TestAsToolKwargsPropagation:
# Invoke tool with both the arg_name field and additional kwargs
await tool.invoke(
arguments=tool.input_model(custom_task="Test task"),
api_token="token-123",
custom_task="should_be_excluded", # This should be filtered out
context=FunctionInvocationContext(
function=tool,
arguments={"custom_task": "Test task"},
kwargs={
"api_token": "token-123",
"custom_task": "should_be_excluded",
},
)
)
# The arg_name ("custom_task") should NOT be in the forwarded kwargs
assert "custom_task" not in captured_kwargs
# But other kwargs should be present
assert "api_token" in captured_kwargs
assert captured_kwargs["api_token"] == "token-123"
assert captured_function_invocation_kwargs["custom_task"] == "should_be_excluded"
assert captured_function_invocation_kwargs["api_token"] == "token-123"
async def test_as_tool_nested_delegation_propagates_kwargs(self, client: MockChatClient) -> None:
"""Test that kwargs propagate through multiple levels of delegation (A B C)."""
captured_kwargs_list: list[dict[str, Any]] = []
"""Test that runtime kwargs propagate through multiple levels of delegation (A -> B -> C)."""
captured_function_invocation_kwargs_list: list[dict[str, Any]] = []
@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
# Capture kwargs at each level
captured_kwargs_list.append(dict(context.kwargs))
captured_function_invocation_kwargs_list.append(dict(context.function_invocation_kwargs))
await call_next()
# Setup mock responses to trigger nested tool invocation: B calls tool C, then completes.
@@ -140,24 +157,29 @@ class TestAsToolKwargsPropagation:
# Invoke tool B with kwargs - should propagate to both B and C
await tool_b.invoke(
arguments=tool_b.input_model(task="Test cascade"),
trace_id="trace-abc-123",
tenant_id="tenant-xyz",
options={"additional_function_arguments": {"trace_id": "trace-abc-123", "tenant_id": "tenant-xyz"}},
context=self._build_context(
tool_b,
task="Test cascade",
runtime_kwargs={
"trace_id": "trace-abc-123",
"tenant_id": "tenant-xyz",
},
),
)
# Verify kwargs were forwarded to the first agent invocation.
assert len(captured_kwargs_list) >= 1
assert captured_kwargs_list[0].get("trace_id") == "trace-abc-123"
assert captured_kwargs_list[0].get("tenant_id") == "tenant-xyz"
assert len(captured_function_invocation_kwargs_list) >= 1
assert captured_function_invocation_kwargs_list[0].get("trace_id") == "trace-abc-123"
assert captured_function_invocation_kwargs_list[0].get("tenant_id") == "tenant-xyz"
async def test_as_tool_streaming_mode_forwards_kwargs(self, client: MockChatClient) -> None:
"""Test that kwargs are forwarded in streaming mode."""
"""Test that runtime kwargs are forwarded in streaming mode."""
captured_kwargs: dict[str, Any] = {}
captured_function_invocation_kwargs: dict[str, Any] = {}
@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
captured_function_invocation_kwargs.update(context.function_invocation_kwargs)
await call_next()
# Setup mock streaming responses
@@ -182,13 +204,15 @@ class TestAsToolKwargsPropagation:
# Invoke tool with kwargs while streaming callback is active
await tool.invoke(
arguments=tool.input_model(task="Test streaming"),
api_key="streaming-key-999",
context=self._build_context(
tool,
task="Test streaming",
runtime_kwargs={"api_key": "streaming-key-999"},
),
)
# Verify kwargs were forwarded even in streaming mode
assert "api_key" in captured_kwargs
assert captured_kwargs["api_key"] == "streaming-key-999"
assert captured_kwargs == {}
assert captured_function_invocation_kwargs["api_key"] == "streaming-key-999"
assert len(captured_updates) == 1
async def test_as_tool_empty_kwargs_still_works(self, client: MockChatClient) -> None:
@@ -206,18 +230,20 @@ class TestAsToolKwargsPropagation:
tool = sub_agent.as_tool()
# Invoke without any extra kwargs - should work without errors
result = await tool.invoke(arguments=tool.input_model(task="Simple task"))
result = await tool.invoke(arguments={"task": "Simple task"})
# Verify tool executed successfully
assert result is not None
async def test_as_tool_kwargs_with_chat_options(self, client: MockChatClient) -> None:
"""Test that kwargs including chat_options are properly forwarded."""
"""Test that runtime kwargs are forwarded only via function_invocation_kwargs."""
captured_kwargs: dict[str, Any] = {}
captured_function_invocation_kwargs: dict[str, Any] = {}
@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
captured_function_invocation_kwargs.update(context.function_invocation_kwargs)
await call_next()
# Setup mock response
@@ -235,24 +261,26 @@ class TestAsToolKwargsPropagation:
# Invoke with various kwargs
await tool.invoke(
arguments=tool.input_model(task="Test with options"),
temperature=0.8,
max_tokens=500,
custom_param="custom_value",
context=self._build_context(
tool,
task="Test with options",
runtime_kwargs={
"temperature": 0.8,
"max_tokens": 500,
"custom_param": "custom_value",
},
),
)
# Verify all kwargs were forwarded
assert "temperature" in captured_kwargs
assert captured_kwargs["temperature"] == 0.8
assert "max_tokens" in captured_kwargs
assert captured_kwargs["max_tokens"] == 500
assert "custom_param" in captured_kwargs
assert captured_kwargs["custom_param"] == "custom_value"
assert captured_kwargs == {}
assert captured_function_invocation_kwargs["temperature"] == 0.8
assert captured_function_invocation_kwargs["max_tokens"] == 500
assert captured_function_invocation_kwargs["custom_param"] == "custom_value"
async def test_as_tool_kwargs_isolated_per_invocation(self, client: MockChatClient) -> None:
"""Test that kwargs are isolated per invocation and don't leak between calls."""
first_call_kwargs: dict[str, Any] = {}
second_call_kwargs: dict[str, Any] = {}
"""Test that runtime kwargs are isolated per invocation and don't leak between calls."""
first_call_function_invocation_kwargs: dict[str, Any] = {}
second_call_function_invocation_kwargs: dict[str, Any] = {}
call_count = 0
@agent_middleware
@@ -260,9 +288,9 @@ class TestAsToolKwargsPropagation:
nonlocal call_count
call_count += 1
if call_count == 1:
first_call_kwargs.update(context.kwargs)
first_call_function_invocation_kwargs.update(context.function_invocation_kwargs)
elif call_count == 2:
second_call_kwargs.update(context.kwargs)
second_call_function_invocation_kwargs.update(context.function_invocation_kwargs)
await call_next()
# Setup mock responses for both calls
@@ -281,33 +309,35 @@ class TestAsToolKwargsPropagation:
# First call with specific kwargs
await tool.invoke(
arguments=tool.input_model(task="First task"),
session_id="session-1",
api_token="token-1",
context=self._build_context(
tool,
task="First task",
runtime_kwargs={"session_id": "session-1", "api_token": "token-1"},
),
)
# Second call with different kwargs
await tool.invoke(
arguments=tool.input_model(task="Second task"),
session_id="session-2",
api_token="token-2",
context=self._build_context(
tool,
task="Second task",
runtime_kwargs={"session_id": "session-2", "api_token": "token-2"},
),
)
# Verify first call had its own kwargs
assert first_call_kwargs.get("session_id") == "session-1"
assert first_call_kwargs.get("api_token") == "token-1"
assert first_call_function_invocation_kwargs.get("session_id") == "session-1"
assert first_call_function_invocation_kwargs.get("api_token") == "token-1"
# Verify second call had its own kwargs (not leaked from first)
assert second_call_kwargs.get("session_id") == "session-2"
assert second_call_kwargs.get("api_token") == "token-2"
assert second_call_function_invocation_kwargs.get("session_id") == "session-2"
assert second_call_function_invocation_kwargs.get("api_token") == "token-2"
async def test_as_tool_excludes_conversation_id_from_forwarded_kwargs(self, client: MockChatClient) -> None:
"""Test that conversation_id is not forwarded to sub-agent."""
captured_kwargs: dict[str, Any] = {}
async def test_as_tool_forwards_conversation_id_from_context_kwargs(self, client: MockChatClient) -> None:
"""Test that conversation_id is forwarded when explicitly present in runtime context kwargs."""
captured_function_invocation_kwargs: dict[str, Any] = {}
@agent_middleware
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
captured_kwargs.update(context.kwargs)
captured_function_invocation_kwargs.update(context.function_invocation_kwargs)
await call_next()
# Setup mock response
@@ -325,17 +355,17 @@ class TestAsToolKwargsPropagation:
# Invoke tool with conversation_id in kwargs (simulating parent's conversation state)
await tool.invoke(
arguments=tool.input_model(task="Test delegation"),
conversation_id="conv-parent-456",
api_token="secret-xyz-123",
user_id="user-456",
context=self._build_context(
tool,
task="Test delegation",
runtime_kwargs={
"conversation_id": "conv-parent-456",
"api_token": "secret-xyz-123",
"user_id": "user-456",
},
),
)
# Verify conversation_id was NOT forwarded to sub-agent
assert "conversation_id" not in captured_kwargs, (
f"conversation_id should not be forwarded, but got: {captured_kwargs}"
)
# Verify other kwargs were still forwarded
assert captured_kwargs.get("api_token") == "secret-xyz-123"
assert captured_kwargs.get("user_id") == "user-456"
assert captured_function_invocation_kwargs.get("conversation_id") == "conv-parent-456"
assert captured_function_invocation_kwargs.get("api_token") == "secret-xyz-123"
assert captured_function_invocation_kwargs.get("user_id") == "user-456"
@@ -1,9 +1,12 @@
# Copyright (c) Microsoft. All rights reserved.
import inspect
from typing import Any
from unittest.mock import patch
import pytest
from agent_framework import (
GROUP_ANNOTATION_KEY,
GROUP_TOKEN_COUNT_KEY,
@@ -50,6 +53,60 @@ def test_base_client(chat_client_base: SupportsChatGetResponse):
assert isinstance(chat_client_base, SupportsChatGetResponse)
def test_base_client_warns_for_direct_additional_properties(chat_client_base: SupportsChatGetResponse) -> None:
with pytest.warns(DeprecationWarning, match="additional_properties"):
client = type(chat_client_base)(legacy_key="legacy-value")
assert client.additional_properties["legacy_key"] == "legacy-value"
def test_base_client_as_agent_uses_explicit_additional_properties(chat_client_base: SupportsChatGetResponse) -> None:
agent = chat_client_base.as_agent(additional_properties={"team": "core"})
assert agent.additional_properties == {"team": "core"}
def test_openai_chat_client_get_response_docstring_surfaces_layered_runtime_docs() -> None:
from agent_framework.openai import OpenAIChatClient
docstring = inspect.getdoc(OpenAIChatClient.get_response)
assert docstring is not None
assert "Get a response from a chat client." in docstring
assert "function_invocation_kwargs" in docstring
assert "function_middleware: Optional per-call function middleware." in docstring
assert "middleware: Optional per-call chat and function middleware." in docstring
def test_openai_chat_client_get_response_is_defined_on_openai_class() -> None:
from agent_framework.openai import OpenAIChatClient
signature = inspect.signature(OpenAIChatClient.get_response)
assert OpenAIChatClient.get_response.__qualname__ == "OpenAIChatClient.get_response"
assert "function_middleware" in signature.parameters
assert "middleware" in signature.parameters
async def test_base_client_get_response_uses_explicit_client_kwargs(chat_client_base: SupportsChatGetResponse) -> None:
async def fake_inner_get_response(**kwargs):
assert kwargs["trace_id"] == "trace-123"
assert "function_invocation_kwargs" not in kwargs
return ChatResponse(messages=[Message(role="assistant", text="ok")])
with patch.object(
chat_client_base,
"_inner_get_response",
side_effect=fake_inner_get_response,
) as mock_inner_get_response:
await chat_client_base.get_response(
[Message(role="user", text="hello")],
function_invocation_kwargs={"tool_request_id": "tool-123"},
client_kwargs={"trace_id": "trace-123"},
)
mock_inner_get_response.assert_called_once()
async def test_base_client_get_response(chat_client_base: SupportsChatGetResponse):
response = await chat_client_base.get_response([Message(role="user", text="Hello")])
assert response.messages[0].role == "assistant"
@@ -0,0 +1,175 @@
# Copyright (c) Microsoft. All rights reserved.
from agent_framework._docstrings import apply_layered_docstring, build_layered_docstring
# -- Helpers: stub functions with various docstring shapes --
def _source_with_full_docstring(x: int) -> int:
"""Do something useful.
Args:
x: The input value.
Keyword Args:
timeout: Max seconds to wait.
Returns:
The computed result.
"""
return x
def _source_with_args_only(x: int) -> int:
"""Do something useful.
Args:
x: The input value.
Returns:
The computed result.
"""
return x
def _source_no_sections() -> None:
"""A plain summary with no Google-style sections."""
def _source_no_docstring() -> None:
pass
def _target_stub() -> None:
pass
# -- build_layered_docstring tests --
def test_build_returns_none_when_source_has_no_docstring() -> None:
result = build_layered_docstring(_source_no_docstring)
assert result is None
def test_build_returns_original_when_no_extra_kwargs() -> None:
result = build_layered_docstring(_source_with_full_docstring)
assert result is not None
assert "Do something useful." in result
assert "Keyword Args:" in result
def test_build_returns_original_when_extra_kwargs_empty() -> None:
result = build_layered_docstring(_source_with_full_docstring, extra_keyword_args={})
assert result is not None
assert result == build_layered_docstring(_source_with_full_docstring)
def test_build_appends_to_existing_keyword_args_section() -> None:
result = build_layered_docstring(
_source_with_full_docstring,
extra_keyword_args={"retries": "Number of retries."},
)
assert result is not None
assert "timeout: Max seconds to wait." in result
assert "retries: Number of retries." in result
# Both should be under Keyword Args
lines = result.splitlines()
kw_index = next(i for i, line in enumerate(lines) if line == "Keyword Args:")
ret_index = next(i for i, line in enumerate(lines) if line == "Returns:")
retries_index = next(i for i, line in enumerate(lines) if "retries:" in line)
assert kw_index < retries_index < ret_index
def test_build_inserts_keyword_args_after_args_section() -> None:
result = build_layered_docstring(
_source_with_args_only,
extra_keyword_args={"verbose": "Enable verbose output."},
)
assert result is not None
assert "Keyword Args:" in result
assert "verbose: Enable verbose output." in result
lines = result.splitlines()
args_index = next(i for i, line in enumerate(lines) if line == "Args:")
kw_index = next(i for i, line in enumerate(lines) if line == "Keyword Args:")
ret_index = next(i for i, line in enumerate(lines) if line == "Returns:")
assert args_index < kw_index < ret_index
def test_build_inserts_keyword_args_in_docstring_with_no_sections() -> None:
result = build_layered_docstring(
_source_no_sections,
extra_keyword_args={"debug": "Enable debug mode."},
)
assert result is not None
assert "A plain summary" in result
assert "Keyword Args:" in result
assert "debug: Enable debug mode." in result
def test_build_handles_multiline_descriptions() -> None:
result = build_layered_docstring(
_source_with_args_only,
extra_keyword_args={
"config": "The configuration object.\nMust be a valid mapping.\nDefaults to empty.",
},
)
assert result is not None
lines = result.splitlines()
config_line = next(line for line in lines if "config:" in line)
assert "The configuration object." in config_line
# Continuation lines should be indented
config_idx = lines.index(config_line)
assert "Must be a valid mapping." in lines[config_idx + 1]
assert "Defaults to empty." in lines[config_idx + 2]
def test_build_preserves_multiple_extra_kwargs_order() -> None:
result = build_layered_docstring(
_source_with_args_only,
extra_keyword_args={
"alpha": "First.",
"beta": "Second.",
"gamma": "Third.",
},
)
assert result is not None
lines = result.splitlines()
alpha_idx = next(i for i, line in enumerate(lines) if "alpha:" in line)
beta_idx = next(i for i, line in enumerate(lines) if "beta:" in line)
gamma_idx = next(i for i, line in enumerate(lines) if "gamma:" in line)
assert alpha_idx < beta_idx < gamma_idx
# -- apply_layered_docstring tests --
def test_apply_sets_docstring_on_target() -> None:
def target() -> None:
pass
apply_layered_docstring(target, _source_with_full_docstring)
assert target.__doc__ is not None
assert "Do something useful." in target.__doc__
def test_apply_with_extra_kwargs() -> None:
def target() -> None:
pass
apply_layered_docstring(
target,
_source_with_args_only,
extra_keyword_args={"flag": "A boolean flag."},
)
assert target.__doc__ is not None
assert "flag: A boolean flag." in target.__doc__
assert "Keyword Args:" in target.__doc__
def test_apply_sets_none_when_source_has_no_docstring() -> None:
def target() -> None:
"""Original."""
apply_layered_docstring(target, _source_no_docstring)
assert target.__doc__ is None
@@ -4,6 +4,8 @@ from __future__ import annotations
from collections.abc import Sequence
import pytest
from agent_framework import (
BaseEmbeddingClient,
Embedding,
@@ -63,6 +65,11 @@ def test_base_additional_properties_custom() -> None:
assert client.additional_properties == {"key": "value"}
def test_base_embedding_client_rejects_unknown_kwargs() -> None:
with pytest.raises(TypeError):
MockEmbeddingClient(legacy_key="value") # type: ignore[call-arg]
# --- SupportsGetEmbeddings protocol tests ---
@@ -3651,3 +3651,131 @@ class TestUpdateConversationId:
# endregion
async def test_user_input_request_propagates_through_as_tool(chat_client_base: SupportsChatGetResponse):
"""Test that user_input_request content from a sub-agent wrapped as a tool propagates to the parent response."""
from agent_framework.exceptions import UserInputRequiredException
@tool(name="delegate_agent", approval_mode="never_require")
def delegate_tool(task: str) -> str:
del task
raise UserInputRequiredException(
contents=[
Content.from_oauth_consent_request(
consent_link="https://login.microsoftonline.com/consent",
)
]
)
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="1", name="delegate_agent", arguments='{"task": "do it"}'),
],
)
)
]
response = await chat_client_base.get_response(
[Message(role="user", text="delegate this")],
options={"tool_choice": "auto", "tools": [delegate_tool]},
)
user_requests = [
content
for msg in response.messages
for content in msg.contents
if isinstance(content, Content) and content.user_input_request
]
assert len(user_requests) == 1
assert user_requests[0].type == "oauth_consent_request"
assert user_requests[0].consent_link == "https://login.microsoftonline.com/consent"
assert user_requests[0].user_input_request is True
async def test_user_input_request_multiple_contents_propagate(chat_client_base: SupportsChatGetResponse):
"""Test that multiple user_input_request items in a single exception all propagate to the parent response."""
from agent_framework.exceptions import UserInputRequiredException
@tool(name="multi_request_tool", approval_mode="never_require")
def multi_request(task: str) -> str:
del task
raise UserInputRequiredException(
contents=[
Content.from_oauth_consent_request(
consent_link="https://example.com/consent1",
),
Content.from_oauth_consent_request(
consent_link="https://example.com/consent2",
),
Content.from_oauth_consent_request(
consent_link="https://example.com/consent3",
),
]
)
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="1", name="multi_request_tool", arguments='{"task": "do it"}'),
],
)
)
]
response = await chat_client_base.get_response(
[Message(role="user", text="do something")],
options={"tool_choice": "auto", "tools": [multi_request]},
)
user_requests = [
content
for msg in response.messages
for content in msg.contents
if isinstance(content, Content) and content.user_input_request
]
assert len(user_requests) == 3
consent_links = {r.consent_link for r in user_requests}
assert consent_links == {
"https://example.com/consent1",
"https://example.com/consent2",
"https://example.com/consent3",
}
async def test_user_input_request_empty_contents_returns_fallback(chat_client_base: SupportsChatGetResponse):
"""Test that UserInputRequiredException with empty contents produces a fallback function_result."""
from agent_framework.exceptions import UserInputRequiredException
@tool(name="empty_request_tool", approval_mode="never_require")
def empty_request(task: str) -> str:
del task
raise UserInputRequiredException(contents=[])
chat_client_base.run_responses = [
ChatResponse(
messages=Message(
role="assistant",
contents=[
Content.from_function_call(call_id="1", name="empty_request_tool", arguments='{"task": "do it"}'),
],
)
),
ChatResponse(messages=Message(role="assistant", text="handled")),
]
response = await chat_client_base.get_response(
[Message(role="user", text="do something")],
options={"tool_choice": "auto", "tools": [empty_request]},
)
# With empty contents, the handler returns a function_result with an error message
# and the loop continues to the next chat response.
function_results = [
content for msg in response.messages for content in msg.contents if content.type == "function_result"
]
assert len(function_results) >= 1
assert any("user input" in (fr.result or "").lower() for fr in function_results)
@@ -6,11 +6,13 @@ from collections.abc import AsyncIterable, Awaitable, MutableSequence, Sequence
from typing import Any
from agent_framework import (
Agent,
BaseChatClient,
ChatMiddlewareLayer,
ChatResponse,
ChatResponseUpdate,
Content,
FunctionInvocationContext,
FunctionInvocationLayer,
Message,
ResponseStream,
@@ -97,6 +99,7 @@ class TestKwargsPropagationToFunctionTool:
async def test_kwargs_propagate_to_tool_with_kwargs(self) -> None:
"""Test that kwargs passed to get_response() are available in @tool **kwargs."""
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
captured_kwargs: dict[str, Any] = {}
@tool(approval_mode="never_require")
@@ -149,6 +152,7 @@ class TestKwargsPropagationToFunctionTool:
async def test_kwargs_not_forwarded_to_tool_without_kwargs(self) -> None:
"""Test that kwargs are NOT forwarded to @tool that doesn't accept **kwargs."""
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
@tool(approval_mode="never_require")
def simple_tool(x: int) -> str:
@@ -185,6 +189,7 @@ class TestKwargsPropagationToFunctionTool:
async def test_kwargs_isolated_between_function_calls(self) -> None:
"""Test that kwargs are consistent across multiple function call invocations."""
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
invocation_kwargs: list[dict[str, Any]] = []
@tool(approval_mode="never_require")
@@ -235,6 +240,7 @@ class TestKwargsPropagationToFunctionTool:
async def test_streaming_response_kwargs_propagation(self) -> None:
"""Test that kwargs propagate to @tool in streaming mode."""
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
captured_kwargs: dict[str, Any] = {}
@tool(approval_mode="never_require")
@@ -287,3 +293,59 @@ class TestKwargsPropagationToFunctionTool:
assert "streaming_session" in captured_kwargs, f"Expected 'streaming_session' in {captured_kwargs}"
assert captured_kwargs["streaming_session"] == "session-xyz"
assert captured_kwargs["correlation_id"] == "corr-123"
async def test_agent_run_injects_function_invocation_context(self) -> None:
"""Test that Agent.run injects FunctionInvocationContext for ctx-based tools."""
captured_context_kwargs: dict[str, Any] = {}
captured_client_kwargs: dict[str, Any] = {}
captured_options: dict[str, Any] = {}
@tool(approval_mode="never_require")
def capture_context_tool(x: int, ctx: FunctionInvocationContext) -> str:
captured_context_kwargs.update(ctx.kwargs)
return f"result: x={x}"
class CapturingFunctionInvokingMockClient(FunctionInvokingMockClient):
async def _get_non_streaming_response(
self,
*,
messages: MutableSequence[Message],
options: dict[str, Any],
**kwargs: Any,
) -> ChatResponse:
captured_options.update(options)
captured_client_kwargs.update(kwargs)
return await super()._get_non_streaming_response(messages=messages, options=options, **kwargs)
client = CapturingFunctionInvokingMockClient()
client.run_responses = [
ChatResponse(
messages=[
Message(
role="assistant",
contents=[
Content.from_function_call(
call_id="call_1",
name="capture_context_tool",
arguments='{"x": 42}',
)
],
)
]
),
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
]
agent = Agent(client=client, tools=[capture_context_tool])
result = await agent.run(
[Message(role="user", text="Test")],
function_invocation_kwargs={"tool_request_id": "tool-123"},
client_kwargs={"client_request_id": "client-456"},
)
assert captured_context_kwargs["tool_request_id"] == "tool-123"
assert "client_request_id" not in captured_context_kwargs
assert captured_client_kwargs["client_request_id"] == "client-456"
assert "tool_request_id" not in captured_client_kwargs
assert "additional_function_arguments" not in captured_options
assert result.messages[-1].text == "Done!"
@@ -192,10 +192,10 @@ class ConcreteHistoryProvider(BaseHistoryProvider):
self.stored: list[Message] = []
self._stored_messages = stored_messages or []
async def get_messages(self, session_id: str | None, **kwargs) -> list[Message]:
async def get_messages(self, session_id: str | None, *, state=None, **kwargs) -> list[Message]:
return list(self._stored_messages)
async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs) -> None:
async def save_messages(self, session_id: str | None, messages: Sequence[Message], *, state=None, **kwargs) -> None:
self.stored.extend(messages)
@@ -12,6 +12,7 @@ from agent_framework import (
FunctionTool,
tool,
)
from agent_framework._middleware import FunctionInvocationContext
from agent_framework._tools import (
_parse_annotation,
_parse_inputs,
@@ -952,6 +953,128 @@ async def test_ai_function_with_kwargs_injection():
assert result_default[0].text == "x=10, user=unknown"
async def test_ai_function_with_explicit_invocation_context():
"""Test that invoke() can receive runtime kwargs via FunctionInvocationContext."""
@tool
def tool_with_context(x: int, ctx: FunctionInvocationContext) -> str:
"""A tool that accepts runtime context injection."""
user_id = ctx.kwargs.get("user_id", "unknown")
return f"x={x}, user={user_id}"
assert tool_with_context.parameters() == {
"properties": {"x": {"title": "X", "type": "integer"}},
"required": ["x"],
"title": "tool_with_context_input",
"type": "object",
}
context = FunctionInvocationContext(
function=tool_with_context,
arguments=tool_with_context.input_model(x=7),
kwargs={"user_id": "ctx-user"},
)
result = await tool_with_context.invoke(context=context)
assert result[0].text == "x=7, user=ctx-user"
async def test_ai_function_with_typed_context_parameter_using_custom_name():
"""Test that typed context injection works for names other than ctx."""
@tool
def tool_with_runtime_context(x: int, runtime: FunctionInvocationContext) -> str:
"""A tool that uses a custom context parameter name."""
user_id = runtime.kwargs.get("user_id", "unknown")
return f"x={x}, user={user_id}"
assert tool_with_runtime_context.parameters() == {
"properties": {"x": {"title": "X", "type": "integer"}},
"required": ["x"],
"title": "tool_with_runtime_context_input",
"type": "object",
}
context = FunctionInvocationContext(
function=tool_with_runtime_context,
arguments=tool_with_runtime_context.input_model(x=8),
kwargs={"user_id": "runtime-user"},
)
result = await tool_with_runtime_context.invoke(context=context)
assert result[0].text == "x=8, user=runtime-user"
async def test_ai_function_with_explicit_schema_and_untyped_ctx():
"""Test that explicit schemas allow an untyped ctx parameter."""
class ToolInput(BaseModel):
x: int
@tool(schema=ToolInput)
def tool_with_schema(x, ctx) -> str:
"""A tool with explicit schema and implicit ctx injection."""
return f"x={x}, user={ctx.kwargs.get('user_id', 'unknown')}"
context = FunctionInvocationContext(
function=tool_with_schema,
arguments=ToolInput(x=9),
kwargs={"user_id": "schema-user"},
)
result = await tool_with_schema.invoke(context=context)
assert result[0].text == "x=9, user=schema-user"
async def test_ai_function_with_explicit_schema_and_typed_ctx():
"""Test that explicit schemas also work with typed context injection."""
class ToolInput(BaseModel):
x: int
@tool(schema=ToolInput)
def tool_with_schema(x: int, runtime: FunctionInvocationContext) -> str:
"""A tool with explicit schema and typed context injection."""
return f"x={x}, user={runtime.kwargs.get('user_id', 'unknown')}"
context = FunctionInvocationContext(
function=tool_with_schema,
arguments=ToolInput(x=11),
kwargs={"user_id": "typed-schema-user"},
)
result = await tool_with_schema.invoke(context=context)
assert tool_with_schema.parameters() == ToolInput.model_json_schema()
assert result[0].text == "x=11, user=typed-schema-user"
def test_ai_function_with_multiple_typed_context_parameters_fails():
"""Test that tools reject multiple typed FunctionInvocationContext parameters."""
with pytest.raises(ValueError, match="multiple FunctionInvocationContext parameters"):
@tool
def invalid_tool(ctx_one: FunctionInvocationContext, ctx_two: FunctionInvocationContext) -> str:
return f"{ctx_one.kwargs}-{ctx_two.kwargs}"
def test_ai_function_with_ctx_and_typed_context_parameter_fails():
"""Test that explicit-schema tools reject both implicit ctx and typed context parameters."""
class ToolInput(BaseModel):
x: int
with pytest.raises(ValueError, match="multiple FunctionInvocationContext parameters"):
@tool(schema=ToolInput)
def invalid_tool(x, ctx, runtime: FunctionInvocationContext) -> str:
return f"{x}-{ctx.kwargs}-{runtime.kwargs}"
# region _parse_annotation tests
@@ -124,10 +124,20 @@ class DurableAgentExecutor(ABC, Generic[TaskT]):
"""
raise NotImplementedError
def get_new_session(self, agent_name: str, **kwargs: Any) -> DurableAgentSession:
def get_new_session(
self,
agent_name: str,
*,
session_id: str | None = None,
service_session_id: str | None = None,
) -> DurableAgentSession:
"""Create a new DurableAgentSession with random session ID."""
session_id = self._create_session_id(agent_name)
return DurableAgentSession.from_session_id(session_id, **kwargs)
durable_session_id = self._create_session_id(agent_name)
return DurableAgentSession(
durable_session_id=durable_session_id,
session_id=session_id,
service_session_id=service_session_id,
)
def _create_session_id(
self,
@@ -284,46 +284,48 @@ class DurableAgentSession(AgentSession):
durable_session_id: AgentSessionId | None = None,
session_id: str | None = None,
service_session_id: str | None = None,
**kwargs: Any,
) -> None:
super().__init__(session_id=session_id, service_session_id=service_session_id, **kwargs)
self._session_id_value: AgentSessionId | None = durable_session_id
super().__init__(session_id=session_id, service_session_id=service_session_id)
self.durable_session_id: AgentSessionId | None = durable_session_id
@property
def durable_session_id(self) -> AgentSessionId | None:
return self._session_id_value
@durable_session_id.setter
def durable_session_id(self, value: AgentSessionId | None) -> None:
self._session_id_value = value
def to_dict(self) -> dict[str, Any]:
state = super().to_dict()
if self.durable_session_id is not None:
state[self._SERIALIZED_SESSION_ID_KEY] = str(self.durable_session_id)
return state
@classmethod
def from_session_id(
cls,
session_id: AgentSessionId,
**kwargs: Any,
durable_session_id: AgentSessionId,
*,
session_id: str | None = None,
service_session_id: str | None = None,
) -> DurableAgentSession:
return cls(durable_session_id=session_id, **kwargs)
def to_dict(self) -> dict[str, Any]:
state = super().to_dict()
if self._session_id_value is not None:
state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id_value)
return state
"""Create a DurableAgentSession from an AgentSessionId."""
return cls(
durable_session_id=durable_session_id,
session_id=session_id,
service_session_id=service_session_id,
)
@classmethod
def from_dict(cls, data: dict[str, Any]) -> DurableAgentSession:
state_payload = dict(data)
session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None)
session = super().from_dict(state_payload)
"""Create a DurableAgentSession from a state dict."""
data = dict(data) # defensive copy — avoid mutating caller's dict
session_id_value = data.pop(cls._SERIALIZED_SESSION_ID_KEY, None)
session = super().from_dict(data)
durable_session_id: AgentSessionId | None = None
# We need to create a DurableAgentSession from the base AgentSession
if session_id_value is not None:
if not isinstance(session_id_value, str):
raise ValueError("durable_session_id must be a string when present in serialized state")
durable_session_id = AgentSessionId.parse(session_id_value)
durable_session = cls(
durable_session_id=durable_session_id,
session_id=session.session_id,
service_session_id=session.service_session_id,
)
durable_session.state.update(session.state)
if session_id_value is not None:
if not isinstance(session_id_value, str):
raise ValueError("durable_session_id must be a string when present in serialized state")
durable_session._session_id_value = AgentSessionId.parse(session_id_value)
return durable_session
@@ -133,16 +133,13 @@ class DurableAIAgent(SupportsAgentRun, Generic[TaskT]):
session=session,
)
def create_session(self, **kwargs: Any) -> DurableAgentSession:
def create_session(self, *, session_id: str | None = None) -> DurableAgentSession:
"""Create a new agent session via the provider."""
return self._executor.get_new_session(self.name, **kwargs)
return self._executor.get_new_session(self.name)
def get_session(self, **kwargs: Any) -> AgentSession:
"""Retrieve an existing session via the provider.
For durable agents, sessions do not use `service_session_id` so this is not used.
"""
return self._executor.get_new_session(self.name, **kwargs)
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
"""Retrieve an existing session via the provider."""
return self._executor.get_new_session(self.name, service_session_id=service_session_id, session_id=session_id)
def _normalize_messages(self, messages: AgentRunInputs | None) -> str:
"""Convert supported message inputs to a single string.
@@ -2,6 +2,8 @@
"""Unit tests for AgentSessionId and DurableAgentSession."""
from typing import Any
import pytest
from agent_framework import AgentSession
@@ -153,7 +155,7 @@ class TestDurableAgentSession:
def test_from_session_id(self) -> None:
"""Test creating DurableAgentSession from session ID."""
session_id = AgentSessionId(name="TestAgent", key="test-key")
session = DurableAgentSession.from_session_id(session_id)
session = DurableAgentSession(durable_session_id=session_id)
assert isinstance(session, DurableAgentSession)
assert session.durable_session_id is not None
@@ -161,10 +163,10 @@ class TestDurableAgentSession:
assert session.durable_session_id.name == "TestAgent"
assert session.durable_session_id.key == "test-key"
def test_from_session_id_with_service_session_id(self) -> None:
"""Test creating DurableAgentSession with service session ID."""
def test_init_with_service_session_id(self) -> None:
"""Test creating DurableAgentSession with explicit service session ID."""
session_id = AgentSessionId(name="TestAgent", key="test-key")
session = DurableAgentSession.from_session_id(session_id, service_session_id="service-123")
session = DurableAgentSession(durable_session_id=session_id, service_session_id="service-123")
assert session.durable_session_id is not None
assert session.durable_session_id == session_id
@@ -192,7 +194,7 @@ class TestDurableAgentSession:
def test_from_dict_with_durable_session_id(self) -> None:
"""Test deserialization restores durable session ID."""
serialized = {
serialized: dict[str, Any] = {
"type": "session",
"session_id": "session-123",
"service_session_id": "service-123",
@@ -210,7 +212,7 @@ class TestDurableAgentSession:
def test_from_dict_without_durable_session_id(self) -> None:
"""Test deserialization without durable session ID."""
serialized = {
serialized: dict[str, Any] = {
"type": "session",
"session_id": "session-456",
"service_session_id": "service-456",
@@ -88,15 +88,6 @@ class TestDurableAIAgentClientIntegration:
assert isinstance(session, DurableAgentSession)
def test_client_agent_session_with_parameters(self, agent_client: DurableAIAgentClient) -> None:
"""Verify agent can create sessions with custom parameters."""
agent = agent_client.get_agent("assistant")
session = agent.create_session(service_session_id="client-session-123")
assert isinstance(session, DurableAgentSession)
assert session.service_session_id == "client-session-123"
class TestDurableAIAgentClientPollingConfiguration:
"""Test polling configuration parameters for DurableAIAgentClient."""
@@ -82,17 +82,6 @@ class TestDurableAIAgentOrchestrationContextIntegration:
assert isinstance(session, DurableAgentSession)
def test_orchestration_agent_session_with_parameters(
self, agent_context: DurableAIAgentOrchestrationContext
) -> None:
"""Verify agent can create sessions with custom parameters."""
agent = agent_context.get_agent("assistant")
session = agent.create_session(service_session_id="orch-session-456")
assert isinstance(session, DurableAgentSession)
assert session.service_session_id == "orch-session-456"
if __name__ == "__main__":
pytest.main([__file__, "-v", "--tb=short"])
+22 -7
View File
@@ -184,16 +184,31 @@ class TestDurableAIAgentSessionManagement:
mock_executor.get_new_session.assert_called_once_with("test_agent")
assert session == mock_session
def test_create_session_forwards_kwargs(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify create_session forwards kwargs to executor."""
mock_session = DurableAgentSession(service_session_id="session-123")
def test_get_session_forwards_service_session_id(
self, test_agent: DurableAIAgent[Any], mock_executor: Mock
) -> None:
"""Verify get_session forwards service_session_id and session_id to executor."""
mock_session = DurableAgentSession(service_session_id="svc-123")
mock_executor.get_new_session.return_value = mock_session
test_agent.create_session(service_session_id="session-123")
session = test_agent.get_session("svc-123", session_id="local-456")
mock_executor.get_new_session.assert_called_once()
_, kwargs = mock_executor.get_new_session.call_args
assert kwargs["service_session_id"] == "session-123"
mock_executor.get_new_session.assert_called_once_with(
"test_agent", service_session_id="svc-123", session_id="local-456"
)
assert session.service_session_id == "svc-123"
def test_get_session_without_session_id(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
"""Verify get_session works with only service_session_id (session_id defaults to None)."""
mock_session = DurableAgentSession(service_session_id="svc-789")
mock_executor.get_new_session.return_value = mock_session
session = test_agent.get_session("svc-789")
mock_executor.get_new_session.assert_called_once_with(
"test_agent", service_session_id="svc-789", session_id=None
)
assert session.service_session_id == "svc-789"
class TestDurableAgentProviderInterface:
@@ -146,11 +146,11 @@ class FoundryLocalClient(
timeout: float | None = None,
prepare_model: bool = True,
device: DeviceType | None = None,
additional_properties: dict[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
env_file_path: str | None = None,
env_file_encoding: str = "utf-8",
**kwargs: Any,
) -> None:
"""Initialize a FoundryLocalClient.
@@ -169,12 +169,11 @@ class FoundryLocalClient(
The device is used to select the appropriate model variant.
If not provided, the default device for your system will be used.
The values are in the foundry_local.models.DeviceType enum.
additional_properties: Additional properties stored on the client instance.
middleware: Optional sequence of ChatAndFunctionMiddlewareTypes to apply to requests.
function_invocation_configuration: Optional configuration for function invocation support.
env_file_path: If provided, the .env settings are read from this file path location.
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
kwargs: Additional keyword arguments, are passed to the RawOpenAIChatClient.
This can include middleware and additional properties.
Examples:
@@ -271,8 +270,8 @@ class FoundryLocalClient(
super().__init__(
model_id=model_info.id,
client=AsyncOpenAI(base_url=manager.endpoint, api_key=manager.api_key),
additional_properties=additional_properties,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
**kwargs,
)
self.manager = manager
@@ -303,7 +303,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
stream: Literal[False] = False,
session: AgentSession | None = None,
options: OptionsT | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse]: ...
@overload
@@ -314,7 +313,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
stream: Literal[True],
session: AgentSession | None = None,
options: OptionsT | None = None,
**kwargs: Any,
) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ...
def run(
@@ -324,7 +322,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
stream: bool = False,
session: AgentSession | None = None,
options: OptionsT | None = None,
**kwargs: Any,
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
"""Get a response from the agent.
@@ -339,7 +336,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
stream: Whether to stream the response. Defaults to False.
session: The conversation session associated with the message(s).
options: Runtime options (model, timeout, etc.).
kwargs: Additional keyword arguments.
Returns:
When stream=False: An Awaitable[AgentResponse].
@@ -354,10 +350,10 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
return AgentResponse.from_updates(updates)
return ResponseStream(
self._stream_updates(messages=messages, session=session, options=options, **kwargs),
self._stream_updates(messages=messages, session=session, options=options),
finalizer=_finalize,
)
return self._run_impl(messages=messages, session=session, options=options, **kwargs)
return self._run_impl(messages=messages, session=session, options=options)
async def _run_impl(
self,
@@ -365,7 +361,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
*,
session: AgentSession | None = None,
options: OptionsT | None = None,
**kwargs: Any,
) -> AgentResponse:
"""Non-streaming implementation of run."""
if not self._started:
@@ -414,7 +409,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
*,
session: AgentSession | None = None,
options: OptionsT | None = None,
**kwargs: Any,
) -> AsyncIterable[AgentResponseUpdate]:
"""Internal method to stream updates from GitHub Copilot.
@@ -424,7 +418,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
Keyword Args:
session: The conversation session associated with the message(s).
options: Runtime options (model, timeout, etc.).
kwargs: Additional keyword arguments.
Yields:
AgentResponseUpdate items.
@@ -300,11 +300,11 @@ class OllamaChatClient(
host: str | None = None,
client: AsyncClient | None = None,
model_id: str | None = None,
additional_properties: dict[str, Any] | None = None,
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize an Ollama Chat client.
@@ -313,11 +313,11 @@ class OllamaChatClient(
Can be set via the OLLAMA_HOST env variable.
client: An optional Ollama Client instance. If not provided, a new instance will be created.
model_id: The Ollama chat model ID to use. Can be set via the OLLAMA_MODEL_ID env variable.
additional_properties: Additional properties stored on the client instance.
middleware: Optional middleware to apply to the client.
function_invocation_configuration: Optional function invocation configuration override.
env_file_path: An optional path to a dotenv (.env) file to load environment variables from.
env_file_encoding: The encoding to use when reading the dotenv (.env) file. Defaults to 'utf-8'.
**kwargs: Additional keyword arguments passed to BaseChatClient.
"""
ollama_settings = load_settings(
OllamaSettings,
@@ -336,9 +336,9 @@ class OllamaChatClient(
self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType]
super().__init__(
additional_properties=additional_properties,
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
**kwargs,
)
self.middleware = list(self.chat_middleware)
@@ -92,9 +92,9 @@ class RawOllamaEmbeddingClient(
model_id: str | None = None,
host: str | None = None,
client: AsyncClient | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize a raw Ollama embedding client."""
ollama_settings = load_settings(
@@ -110,7 +110,7 @@ class RawOllamaEmbeddingClient(
self.model_id = ollama_settings["embedding_model_id"] # type: ignore[assignment,reportTypedDictNotRequiredAccess]
self.client = client or AsyncClient(host=ollama_settings.get("host"))
self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType]
super().__init__(**kwargs)
super().__init__(additional_properties=additional_properties)
def service_url(self) -> str:
"""Get the URL of the service."""
@@ -214,17 +214,17 @@ class OllamaEmbeddingClient(
host: str | None = None,
client: AsyncClient | None = None,
otel_provider_name: str | None = None,
additional_properties: dict[str, Any] | None = None,
env_file_path: str | None = None,
env_file_encoding: str | None = None,
**kwargs: Any,
) -> None:
"""Initialize an Ollama embedding client."""
super().__init__(
model_id=model_id,
host=host,
client=client,
additional_properties=additional_properties,
otel_provider_name=otel_provider_name,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
**kwargs,
)
@@ -107,11 +107,18 @@ class RedisHistoryProvider(BaseHistoryProvider):
"""Get the Redis key for a given session's messages."""
return f"{self.key_prefix}:{session_id or 'default'}"
async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]:
async def get_messages(
self,
session_id: str | None,
*,
state: dict[str, Any] | None = None,
**kwargs: Any,
) -> list[Message]:
"""Retrieve stored messages for this session from Redis.
Args:
session_id: The session ID to retrieve messages for.
state: Optional session state. Unused for Redis-backed history.
**kwargs: Additional arguments (unused).
Returns:
@@ -125,12 +132,20 @@ class RedisHistoryProvider(BaseHistoryProvider):
messages.append(Message.from_dict(self._deserialize_json(serialized))) # type: ignore[union-attr]
return messages
async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None:
async def save_messages(
self,
session_id: str | None,
messages: Sequence[Message],
*,
state: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
"""Persist messages for this session to Redis.
Args:
session_id: The session ID to store messages for.
messages: The messages to persist.
state: Optional session state. Unused for Redis-backed history.
**kwargs: Additional arguments (unused).
"""
if not messages:
@@ -3,7 +3,7 @@
import asyncio
from collections.abc import Awaitable, Callable
from agent_framework import AgentContext, AgentSession
from agent_framework import AgentContext, AgentSession, FunctionInvocationContext, tool
from agent_framework.openai import OpenAIResponsesClient
from dotenv import load_dotenv
@@ -18,9 +18,6 @@ sub-agent invoked as a tool using ``propagate_session=True``.
When session propagation is enabled, both agents share the same session object,
including session_id and the mutable state dict. This allows correlated
conversation tracking and shared state across the agent hierarchy.
The middleware functions below are purely for observability they are NOT
required for session propagation to work.
"""
@@ -28,65 +25,83 @@ async def log_session(
context: AgentContext,
call_next: Callable[[], Awaitable[None]],
) -> None:
"""Agent middleware that logs the session received by each agent.
NOT required for session propagation only used to observe the flow.
If propagation is working, both agents will show the same session_id.
"""
"""Agent middleware that logs the session received by each agent."""
session: AgentSession | None = context.session
if not session:
print("No session found.")
await call_next()
return
agent_name = context.agent.name or "unknown"
session_id = session.session_id if session else None
state = dict(session.state) if session else {}
print(f" [{agent_name}] session_id={session_id}, state={state}")
print(
f" [{agent_name}] session_id={session.session_id}, "
f"service_session_id={session.service_session_id} state={session.state}"
)
await call_next()
@tool(description="Use this tool to store the findings so that other agents can reason over them.")
def store_findings(findings: str, ctx: FunctionInvocationContext) -> None:
if ctx.session is None:
return
current_findings = ctx.session.state.get("findings")
if current_findings is None:
ctx.session.state["findings"] = findings
else:
ctx.session.state["findings"] = f"{current_findings}\n{findings}"
@tool(description="Use this tool to gather the current findings from other agents.")
def recall_findings(ctx: FunctionInvocationContext) -> str:
if ctx.session is None:
return "No session available"
current_findings = ctx.session.state.get("findings")
if current_findings is None:
return "Nothing yet"
return current_findings
async def main() -> None:
print("=== Agent-as-Tool: Session Propagation ===\n")
client = OpenAIResponsesClient()
# --- Sub-agent: a research specialist ---
# The sub-agent has the same log_session middleware to prove it receives the session.
research_agent = client.as_agent(
name="ResearchAgent",
instructions="You are a research assistant. Provide concise answers.",
instructions="You are a research assistant. Provide concise answers and store your findings.",
middleware=[log_session],
tools=[store_findings, recall_findings],
)
# propagate_session=True: the coordinator's session will be forwarded
research_tool = research_agent.as_tool(
name="research",
description="Research a topic and return findings",
description="Research a topic and store your findings.",
arg_name="query",
arg_description="The research query",
propagate_session=True,
)
# --- Coordinator agent ---
coordinator = client.as_agent(
name="CoordinatorAgent",
instructions="You coordinate research. Use the 'research' tool to look up information.",
tools=[research_tool],
instructions=(
"You coordinate research. Use the 'research' tool to start research "
"and then use the recall findings tool to gather up everything."
),
tools=[research_tool, store_findings, recall_findings],
middleware=[log_session],
)
# Create a shared session and put some state in it
session = coordinator.create_session()
session.state["request_source"] = "demo"
session.state["findings"] = None
print(f"Session ID: {session.session_id}")
print(f"Session state before run: {session.state}\n")
query = "What are the latest developments in quantum computing?"
query = "What are the latest developments in quantum computing and in AI?"
print(f"User: {query}\n")
result = await coordinator.run(query, session=session)
print(f"\nCoordinator: {result}\n")
print(f"Session state after run: {session.state}")
print(
"\nIf both agents show the same session_id above, session propagation is working."
)
if __name__ == "__main__":
@@ -1,9 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from typing import Annotated, Any
from typing import Annotated
from agent_framework import tool
from agent_framework import FunctionInvocationContext, tool
from agent_framework.openai import OpenAIResponsesClient
from dotenv import load_dotenv
from pydantic import Field
@@ -14,27 +14,27 @@ load_dotenv()
"""
AI Function with kwargs Example
This example demonstrates how to inject custom keyword arguments (kwargs) into an AI function
from the agent's run method, without exposing them to the AI model.
This example demonstrates how to inject runtime context into an AI function
from the agent's run method, without exposing it to the AI model.
This is useful for passing runtime information like access tokens, user IDs, or
request-specific context that the tool needs but the model shouldn't know about
or provide.
or provide. The injected context parameter can be typed as
``FunctionInvocationContext`` as shown here, or left untyped as ``ctx`` when you
prefer a lighter-weight sample setup.
"""
# Define the function tool with **kwargs to accept injected arguments
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
# Define the function tool with explicit invocation context.
# The context parameter can also be declared as an untyped ``ctx`` parameter.
@tool(approval_mode="never_require")
def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
**kwargs: Any,
ctx: FunctionInvocationContext,
) -> str:
"""Get the weather for a given location."""
# Extract the injected argument from kwargs
user_id = kwargs.get("user_id", "unknown")
# Extract the injected argument from the explicit context
user_id = ctx.kwargs.get("user_id", "unknown")
# Simulate using the user_id for logging or personalization
print(f"Getting weather for user: {user_id}")
@@ -49,9 +49,11 @@ async def main() -> None:
tools=[get_weather],
)
# Pass the injected argument when running the agent
# The 'user_id' kwarg will be passed down to the tool execution via **kwargs
response = await agent.run("What is the weather like in Amsterdam?", user_id="user_123")
# Pass the runtime context explicitly when running the agent.
response = await agent.run(
"What is the weather like in Amsterdam?",
function_invocation_kwargs={"user_id": "user_123"},
)
print(f"Agent: {response.text}")
@@ -1,9 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
from typing import Annotated, Any
from typing import Annotated
from agent_framework import AgentSession, tool
from agent_framework import AgentSession, FunctionInvocationContext, tool
from agent_framework.openai import OpenAIResponsesClient
from dotenv import load_dotenv
from pydantic import Field
@@ -14,23 +14,21 @@ load_dotenv()
"""
AI Function with Session Injection Example
This example demonstrates the behavior when passing 'session' to agent.run()
and accessing that session in AI function.
This example demonstrates accessing the agent session inside a tool function
via ``FunctionInvocationContext.session``. The session is automatically
available when the agent is invoked with a session.
"""
# Define the function tool with **kwargs
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
# see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
# Define the function tool with explicit invocation context.
# The context parameter can also be declared as an untyped parameter with the name: ``ctx``.
@tool(approval_mode="never_require")
async def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
**kwargs: Any,
ctx: FunctionInvocationContext,
) -> str:
"""Get the weather for a given location."""
# Get session object from kwargs
session = kwargs.get("session")
session = ctx.session
if session and isinstance(session, AgentSession) and session.service_session_id:
print(f"Session ID: {session.service_session_id}.")
@@ -42,17 +40,19 @@ async def main() -> None:
name="WeatherAgent",
instructions="You are a helpful weather assistant.",
tools=[get_weather],
options={"store": True},
default_options={"store": True},
)
# Create a session
session = agent.create_session()
# Run the agent with the session
# Pass session via additional_function_arguments so tools can access it via **kwargs
opts = {"additional_function_arguments": {"session": session}}
print(f"Agent: {await agent.run('What is the weather in London?', session=session, options=opts)}")
print(f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session, options=opts)}")
# Run the agent with the session; tools receive it via ctx.session.
print(
f"Agent: {await agent.run('What is the weather in London?', session=session)}"
)
print(
f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session)}"
)
print(f"Agent: {await agent.run('What cities did I ask about?', session=session)}")