mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
[BREAKING] Python: clean up kwargs across agents, chat clients, tools, and sessions (#4581)
* Python: clean up kwargs across agents, chat clients, tools, and sessions (#3642) Audit and refactor public **kwargs usage across core agents, chat clients, tools, sessions, and provider packages per the migration strategy codified in CODING_STANDARD.md. Key changes: - Add explicit runtime buckets: function_invocation_kwargs and client_kwargs on RawAgent.run() and chat client get_response() layers. - Refactor FunctionTool to prefer explicit ctx: FunctionInvocationContext injection; legacy **kwargs tools still work via _forward_runtime_kwargs. - Refactor Agent.as_tool() to use direct JSON schema, always-streaming wrapper, approval_mode parameter, and UserInputRequiredException propagation (integrates PR #4568 behavior). - Remove implicit session bleeding into FunctionInvocationContext; tools that need a session must receive it via function_invocation_kwargs. - Lower chat-client layers after FunctionInvocationLayer accept only compatibility **kwargs (client_kwargs flattened, function_invocation_kwargs ignored). - Add layered docstring composition from Raw... implementations via _docstrings.py helper. - Clean up provider constructors to use explicit additional_properties. - Deprecation warnings on legacy direct kwargs paths. - Update samples, tests, and typing across all 23 packages. Resolves #3642 Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * clarified docstring * feedback fixes * Add unit tests for _docstrings.py build/apply helpers Tests cover: no docstring source, no extra kwargs, appending to existing Keyword Args section, inserting after Args, inserting in plain docstrings, multiline descriptions, ordering, and apply_layered_docstring. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add test for propagate_session TypeError on non-AgentSession values Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add tests for multi-content and empty UserInputRequiredException propagation Cover the branching logic in _try_execute_function_calls for: - Multiple user_input_request items in a single exception (extra_user_input_contents path) - Empty contents list (fallback function_result path) Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Add tests for DurableAIAgent.get_session forwarding service_session_id Verifies get_session correctly forwards service_session_id and session_id to the executor's get_new_session, replacing the removed kwargs test. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * Simplify ag-ui test stub to read session from client_kwargs only Remove dual-mode detection (client_kwargs vs raw kwargs fallback) from the test mock. Session is now read exclusively from client_kwargs, matching the settled public calling convention. Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com> * updated create and get sessions in durable * fixed docstrings * fix test * updated session handling * updated from main * updated tests --------- Co-authored-by: Copilot <223556219+Copilot@users.noreply.github.com>
This commit is contained in:
committed by
GitHub
Unverified
parent
b7990908fe
commit
a4b9539b62
@@ -127,7 +127,12 @@ def create_agent(name: str, tool_mode: Literal['auto', 'required', 'none'] | Cha
|
||||
Avoid `**kwargs` unless absolutely necessary. It should only be used as an escape route, not for well-known flows of data:
|
||||
|
||||
- **Prefer named parameters**: If there are known extra arguments being passed, use explicit named parameters instead of kwargs
|
||||
- **Prefer purpose-specific buckets over generic kwargs**: If a flexible payload is still needed, use an explicit named parameter such as `additional_properties`, `function_invocation_kwargs`, or `client_kwargs` rather than a blanket `**kwargs`
|
||||
- **Subclassing support**: kwargs is acceptable in methods that are part of classes designed for subclassing, allowing subclass-defined kwargs to pass through without issues. In this case, clearly document that kwargs exists for subclass extensibility and not for passing arbitrary data
|
||||
- **Make known flows explicit first**: For abstract hooks, move known data flows into explicit parameters before leaving `**kwargs` behind for subclass extensibility (for example, prefer `state=` explicitly instead of passing it through kwargs)
|
||||
- **Prefer explicit metadata containers**: For constructors that expose metadata, prefer an explicit `additional_properties` parameter.
|
||||
- **Keep SDK passthroughs narrow and documented**: A kwargs escape hatch may be acceptable for provider helper APIs that pass through to a large or unstable external SDK surface, but it should be documented as SDK passthrough and revisited regularly
|
||||
- **Do not keep passthrough kwargs on wrappers that do not use them**: Convenience wrappers and session helpers should not accept generic kwargs merely to forward or ignore them
|
||||
- **Remove when possible**: In other cases, removing kwargs is likely better than keeping it
|
||||
- **Separate kwargs by purpose**: When combining kwargs for multiple purposes, use specific parameters like `client_kwargs: dict[str, Any]` instead of mixing everything in `**kwargs`
|
||||
- **Always document**: If kwargs must be used, always document how it's used, either by referencing external documentation or explaining its purpose
|
||||
|
||||
@@ -6,7 +6,7 @@ import base64
|
||||
import json
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable, Awaitable, Sequence
|
||||
from collections.abc import AsyncIterable, Awaitable, Mapping, Sequence
|
||||
from typing import Any, Final, Literal, TypeAlias, overload
|
||||
|
||||
import httpx
|
||||
@@ -226,6 +226,8 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
continuation_token: A2AContinuationToken | None = None,
|
||||
background: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -238,17 +240,21 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
continuation_token: A2AContinuationToken | None = None,
|
||||
background: bool = False,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
def run( # pyright: ignore[reportIncompatibleMethodOverride]
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
continuation_token: A2AContinuationToken | None = None,
|
||||
background: bool = False,
|
||||
**kwargs: Any,
|
||||
@@ -261,17 +267,23 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
Keyword Args:
|
||||
stream: Whether to stream the response. Defaults to False.
|
||||
session: The conversation session associated with the message(s).
|
||||
function_invocation_kwargs: Present for compatibility with the shared agent interface.
|
||||
A2AAgent does not use these values directly.
|
||||
client_kwargs: Present for compatibility with the shared agent interface.
|
||||
A2AAgent does not use these values directly.
|
||||
kwargs: Additional compatibility keyword arguments.
|
||||
A2AAgent does not use these values directly.
|
||||
continuation_token: Optional token to resume a long-running task
|
||||
instead of starting a new one.
|
||||
background: When True, in-progress task updates surface continuation
|
||||
tokens so the caller can poll or resubscribe later. When False
|
||||
(default), the agent internally waits for the task to complete.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
When stream=False: An Awaitable[AgentResponse].
|
||||
When stream=True: A ResponseStream of AgentResponseUpdate items.
|
||||
"""
|
||||
del function_invocation_kwargs, client_kwargs, kwargs
|
||||
if continuation_token is not None:
|
||||
a2a_stream: AsyncIterable[A2AStreamItem] = self.client.resubscribe(
|
||||
TaskIdParams(id=continuation_token["task_id"])
|
||||
|
||||
@@ -220,7 +220,6 @@ class AGUIChatClient(
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize the AG-UI chat client.
|
||||
|
||||
@@ -231,13 +230,11 @@ class AGUIChatClient(
|
||||
additional_properties: Additional properties to store
|
||||
middleware: Optional middleware to apply to the client.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
**kwargs: Additional arguments passed to BaseChatClient
|
||||
"""
|
||||
super().__init__(
|
||||
additional_properties=additional_properties,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
**kwargs,
|
||||
)
|
||||
self._http_service = AGUIHttpService(
|
||||
endpoint=endpoint,
|
||||
|
||||
@@ -98,7 +98,11 @@ class StreamingChatClientStub(
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
self.last_session = kwargs.get("session")
|
||||
client_kwargs = kwargs.get("client_kwargs")
|
||||
if isinstance(client_kwargs, Mapping):
|
||||
self.last_session = cast(AgentSession | None, client_kwargs.get("session"))
|
||||
else:
|
||||
self.last_session = None
|
||||
self.last_service_session_id = self.last_session.service_session_id if self.last_session else None
|
||||
return cast(
|
||||
Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]],
|
||||
|
||||
@@ -702,14 +702,9 @@ async def test_agent_with_use_service_session_is_true(streaming_chat_client_stub
|
||||
"""Test that when use_service_session is True, the AgentSession used to run the agent is set to the service session ID."""
|
||||
from agent_framework.ag_ui import AgentFrameworkAgent
|
||||
|
||||
request_service_session_id: str | None = None
|
||||
|
||||
async def stream_fn(
|
||||
messages: MutableSequence[Message], chat_options: ChatOptions, **kwargs: Any
|
||||
) -> AsyncIterator[ChatResponseUpdate]:
|
||||
nonlocal request_service_session_id
|
||||
session = kwargs.get("session")
|
||||
request_service_session_id = session.service_session_id if session else None
|
||||
yield ChatResponseUpdate(
|
||||
contents=[Content.from_text(text="Response")], response_id="resp_67890", conversation_id="conv_12345"
|
||||
)
|
||||
@@ -719,11 +714,22 @@ async def test_agent_with_use_service_session_is_true(streaming_chat_client_stub
|
||||
|
||||
input_data = {"messages": [{"role": "user", "content": "Hi"}], "thread_id": "conv_123456"}
|
||||
|
||||
# Spy on agent.run to capture the session kwarg at call time (before streaming mutates it)
|
||||
captured_service_session_id: str | None = None
|
||||
original_run = agent.run
|
||||
|
||||
def capturing_run(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal captured_service_session_id
|
||||
session = kwargs.get("session")
|
||||
captured_service_session_id = session.service_session_id if session else None
|
||||
return original_run(*args, **kwargs)
|
||||
|
||||
agent.run = capturing_run # type: ignore[assignment, method-assign]
|
||||
|
||||
events: list[Any] = []
|
||||
async for event in wrapper.run(input_data):
|
||||
events.append(event)
|
||||
request_service_session_id = agent.client.last_service_session_id
|
||||
assert request_service_session_id == "conv_123456" # type: ignore[attr-defined] (service_session_id should be set)
|
||||
assert captured_service_session_id == "conv_123456"
|
||||
|
||||
|
||||
async def test_function_approval_mode_executes_tool(streaming_chat_client_stub):
|
||||
|
||||
@@ -228,11 +228,11 @@ class AnthropicClient(
|
||||
model_id: str | None = None,
|
||||
anthropic_client: AsyncAnthropic | None = None,
|
||||
additional_beta_flags: list[str] | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize an Anthropic Agent client.
|
||||
|
||||
@@ -244,11 +244,11 @@ class AnthropicClient(
|
||||
For instance if you need to set a different base_url for testing or private deployments.
|
||||
additional_beta_flags: Additional beta flags to enable on the client.
|
||||
Default flags are: "mcp-client-2025-04-04", "code-execution-2025-08-25".
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
middleware: Optional middleware to apply to the client.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
env_file_path: Path to environment file for loading settings.
|
||||
env_file_encoding: Encoding of the environment file.
|
||||
kwargs: Additional keyword arguments passed to the parent class.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -319,9 +319,9 @@ class AnthropicClient(
|
||||
|
||||
# Initialize parent
|
||||
super().__init__(
|
||||
additional_properties=additional_properties,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Initialize instance variables
|
||||
|
||||
@@ -17,10 +17,15 @@ from agent_framework_azure_ai_search._context_provider import AzureAISearchConte
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_azure_search_environment(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
for key in tuple(os.environ):
|
||||
if key.startswith("AZURE_SEARCH_"):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
def clear_azure_search_env(monkeypatch: pytest.MonkeyPatch) -> None:
|
||||
"""Keep tests isolated from ambient Azure Search environment variables."""
|
||||
for key in (
|
||||
"AZURE_SEARCH_ENDPOINT",
|
||||
"AZURE_SEARCH_INDEX_NAME",
|
||||
"AZURE_SEARCH_KNOWLEDGE_BASE_NAME",
|
||||
"AZURE_SEARCH_API_KEY",
|
||||
):
|
||||
monkeypatch.delenv(key, raising=False)
|
||||
|
||||
|
||||
class MockSearchResults:
|
||||
|
||||
@@ -444,11 +444,11 @@ class AzureAIAgentClient(
|
||||
model_deployment_name: str | None = None,
|
||||
credential: AzureCredentialTypes | None = None,
|
||||
should_cleanup_agent: bool = True,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize an Azure AI Agent client.
|
||||
|
||||
@@ -471,11 +471,11 @@ class AzureAIAgentClient(
|
||||
should_cleanup_agent: Whether to cleanup (delete) agents created by this client when
|
||||
the client is closed or context is exited. Defaults to True. Only affects agents
|
||||
created by this client instance; existing agents passed via agent_id are never deleted.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
middleware: Optional sequence of middlewares to include.
|
||||
function_invocation_configuration: Optional function invocation configuration.
|
||||
env_file_path: Path to environment file for loading settings.
|
||||
env_file_encoding: Encoding of the environment file.
|
||||
kwargs: Additional keyword arguments passed to the parent class.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -548,9 +548,9 @@ class AzureAIAgentClient(
|
||||
|
||||
# Initialize parent
|
||||
super().__init__(
|
||||
additional_properties=additional_properties,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
# Initialize instance variables
|
||||
|
||||
@@ -119,9 +119,9 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
credential: AzureCredentialTypes | None = None,
|
||||
use_latest_version: bool | None = None,
|
||||
allow_preview: bool | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a bare Azure AI client.
|
||||
|
||||
@@ -145,9 +145,9 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
use_latest_version: Boolean flag that indicates whether to use latest agent version
|
||||
if it exists in the service.
|
||||
allow_preview: Enables preview opt-in on internally-created ``AIProjectClient``.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Path to environment file for loading settings.
|
||||
env_file_encoding: Encoding of the environment file.
|
||||
kwargs: Additional keyword arguments passed to the parent class.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -217,7 +217,7 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
|
||||
|
||||
# Initialize parent
|
||||
super().__init__(
|
||||
**kwargs,
|
||||
additional_properties=additional_properties,
|
||||
)
|
||||
|
||||
# Initialize instance variables
|
||||
@@ -1243,11 +1243,11 @@ class AzureAIClient(
|
||||
credential: AzureCredentialTypes | None = None,
|
||||
use_latest_version: bool | None = None,
|
||||
allow_preview: bool | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize an Azure AI client with full layer support.
|
||||
|
||||
@@ -1268,11 +1268,11 @@ class AzureAIClient(
|
||||
use_latest_version: Boolean flag that indicates whether to use latest agent version
|
||||
if it exists in the service.
|
||||
allow_preview: Enables preview opt-in on internally-created ``AIProjectClient``
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
middleware: Optional sequence of chat middlewares to include.
|
||||
function_invocation_configuration: Optional function invocation configuration.
|
||||
env_file_path: Path to environment file for loading settings.
|
||||
env_file_encoding: Encoding of the environment file.
|
||||
kwargs: Additional keyword arguments passed to the parent class.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -1319,9 +1319,9 @@ class AzureAIClient(
|
||||
credential=credential,
|
||||
use_latest_version=use_latest_version,
|
||||
allow_preview=allow_preview,
|
||||
additional_properties=additional_properties,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -124,9 +124,9 @@ class RawAzureAIInferenceEmbeddingClient(
|
||||
text_client: EmbeddingsClient | None = None,
|
||||
image_client: ImageEmbeddingsClient | None = None,
|
||||
credential: AzureKeyCredential | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a raw Azure AI Inference embedding client."""
|
||||
settings = load_settings(
|
||||
@@ -160,7 +160,7 @@ class RawAzureAIInferenceEmbeddingClient(
|
||||
credential=credential, # type: ignore[arg-type]
|
||||
)
|
||||
self._endpoint = resolved_endpoint
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(additional_properties=additional_properties)
|
||||
|
||||
async def close(self) -> None:
|
||||
"""Close the underlying SDK clients and release resources."""
|
||||
@@ -376,9 +376,9 @@ class AzureAIInferenceEmbeddingClient(
|
||||
image_client: ImageEmbeddingsClient | None = None,
|
||||
credential: AzureKeyCredential | None = None,
|
||||
otel_provider_name: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize an Azure AI Inference embedding client."""
|
||||
super().__init__(
|
||||
@@ -389,8 +389,8 @@ class AzureAIInferenceEmbeddingClient(
|
||||
text_client=text_client,
|
||||
image_client=image_client,
|
||||
credential=credential,
|
||||
additional_properties=additional_properties,
|
||||
otel_provider_name=otel_provider_name,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -124,7 +124,13 @@ class CosmosHistoryProvider(BaseHistoryProvider):
|
||||
|
||||
self._database_client = self._cosmos_client.get_database_client(self.database_name)
|
||||
|
||||
async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]:
|
||||
async def get_messages(
|
||||
self,
|
||||
session_id: str | None,
|
||||
*,
|
||||
state: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Message]:
|
||||
"""Retrieve stored messages for this session from Azure Cosmos DB."""
|
||||
await self._ensure_container_proxy()
|
||||
session_key = self._session_partition_key(session_id)
|
||||
@@ -157,7 +163,14 @@ class CosmosHistoryProvider(BaseHistoryProvider):
|
||||
|
||||
return messages
|
||||
|
||||
async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None:
|
||||
async def save_messages(
|
||||
self,
|
||||
session_id: str | None,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
state: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Persist messages for this session to Azure Cosmos DB."""
|
||||
if not messages:
|
||||
return
|
||||
|
||||
@@ -236,11 +236,11 @@ class BedrockChatClient(
|
||||
session_token: str | None = None,
|
||||
client: BaseClient | None = None,
|
||||
boto3_session: Boto3Session | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Create a Bedrock chat client and load AWS credentials.
|
||||
|
||||
@@ -252,11 +252,11 @@ class BedrockChatClient(
|
||||
session_token: Optional AWS session token for temporary credentials.
|
||||
client: Preconfigured Bedrock runtime client; when omitted a boto3 session is created.
|
||||
boto3_session: Custom boto3 session used to build the runtime client if provided.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
middleware: Optional sequence of middlewares to include.
|
||||
function_invocation_configuration: Optional function invocation configuration
|
||||
env_file_path: Optional .env file path used by ``BedrockSettings`` to load defaults.
|
||||
env_file_encoding: Encoding for the optional .env file.
|
||||
kwargs: Additional arguments forwarded to ``BaseChatClient``.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -303,9 +303,9 @@ class BedrockChatClient(
|
||||
)
|
||||
|
||||
super().__init__(
|
||||
additional_properties=additional_properties,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
**kwargs,
|
||||
)
|
||||
self.model_id = chat_model_id
|
||||
self.region = region
|
||||
|
||||
@@ -104,9 +104,9 @@ class RawBedrockEmbeddingClient(
|
||||
session_token: str | None = None,
|
||||
client: BaseClient | None = None,
|
||||
boto3_session: Boto3Session | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a raw Bedrock embedding client."""
|
||||
settings = load_settings(
|
||||
@@ -145,7 +145,7 @@ class RawBedrockEmbeddingClient(
|
||||
|
||||
self.model_id: str = settings["embedding_model_id"] # type: ignore[assignment] # pyright: ignore[reportTypedDictNotRequiredAccess]
|
||||
self.region = resolved_region
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(additional_properties=additional_properties)
|
||||
|
||||
def service_url(self) -> str:
|
||||
"""Get the URL of the service."""
|
||||
@@ -274,9 +274,9 @@ class BedrockEmbeddingClient(
|
||||
client: BaseClient | None = None,
|
||||
boto3_session: Boto3Session | None = None,
|
||||
otel_provider_name: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a Bedrock embedding client."""
|
||||
super().__init__(
|
||||
@@ -287,8 +287,8 @@ class BedrockEmbeddingClient(
|
||||
session_token=session_token,
|
||||
client=client,
|
||||
boto3_session=boto3_session,
|
||||
additional_properties=additional_properties,
|
||||
otel_provider_name=otel_provider_name,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -590,6 +590,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@@ -600,6 +601,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
@@ -609,7 +611,8 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
options: OptionsT | None = None,
|
||||
**kwargs: Any, # type: ignore
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Run the agent with the given messages.
|
||||
|
||||
@@ -621,16 +624,16 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
returns an awaitable AgentResponse.
|
||||
session: The conversation session. If session has service_session_id set,
|
||||
the agent will resume that session.
|
||||
kwargs: Additional keyword arguments including 'options' for runtime options
|
||||
(model, permission_mode can be changed per-request).
|
||||
options: Runtime options. Model and permission_mode can be changed per request.
|
||||
kwargs: Additional keyword arguments for compatibility with the shared agent
|
||||
interface (e.g. compaction_strategy, tokenizer). Not used by ClaudeAgent.
|
||||
|
||||
Returns:
|
||||
When stream=True: An ResponseStream for streaming updates.
|
||||
When stream=False: An Awaitable[AgentResponse] with the complete response.
|
||||
"""
|
||||
options = kwargs.pop("options", None)
|
||||
response = ResponseStream(
|
||||
self._get_stream(messages, session=session, options=options, **kwargs),
|
||||
self._get_stream(messages, session=session, options=options),
|
||||
finalizer=self._finalize_response,
|
||||
)
|
||||
|
||||
@@ -643,8 +646,7 @@ class RawClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | MutableMapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
options: OptionsT | None = None,
|
||||
) -> AsyncIterable[AgentResponseUpdate]:
|
||||
"""Internal streaming implementation."""
|
||||
session = session or self.create_session()
|
||||
|
||||
@@ -196,7 +196,6 @@ class CopilotStudioAgent(BaseAgent):
|
||||
*,
|
||||
stream: Literal[False] = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse]: ...
|
||||
|
||||
@overload
|
||||
@@ -206,7 +205,6 @@ class CopilotStudioAgent(BaseAgent):
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ...
|
||||
|
||||
def run(
|
||||
@@ -215,7 +213,6 @@ class CopilotStudioAgent(BaseAgent):
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
"""Get a response from the agent.
|
||||
|
||||
@@ -229,22 +226,20 @@ class CopilotStudioAgent(BaseAgent):
|
||||
Keyword Args:
|
||||
stream: Whether to stream the response. Defaults to False.
|
||||
session: The conversation session associated with the message(s).
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
When stream=False: An Awaitable[AgentResponse].
|
||||
When stream=True: A ResponseStream of AgentResponseUpdate items.
|
||||
"""
|
||||
if stream:
|
||||
return self._run_stream_impl(messages=messages, session=session, **kwargs)
|
||||
return self._run_impl(messages=messages, session=session, **kwargs)
|
||||
return self._run_stream_impl(messages=messages, session=session)
|
||||
return self._run_impl(messages=messages, session=session)
|
||||
|
||||
async def _run_impl(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentResponse:
|
||||
"""Non-streaming implementation of run."""
|
||||
if not session:
|
||||
@@ -269,7 +264,6 @@ class CopilotStudioAgent(BaseAgent):
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
"""Streaming implementation of run."""
|
||||
|
||||
|
||||
@@ -215,6 +215,7 @@ from ._workflows._workflow_executor import (
|
||||
)
|
||||
from .exceptions import (
|
||||
MiddlewareException,
|
||||
UserInputRequiredException,
|
||||
WorkflowCheckpointException,
|
||||
WorkflowConvergenceException,
|
||||
WorkflowException,
|
||||
@@ -349,6 +350,7 @@ __all__ = [
|
||||
"TypeCompatibilityError",
|
||||
"UpdateT",
|
||||
"UsageDetails",
|
||||
"UserInputRequiredException",
|
||||
"ValidationTypeEnum",
|
||||
"Workflow",
|
||||
"WorkflowAgent",
|
||||
|
||||
@@ -2,10 +2,10 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
import warnings
|
||||
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||
from copy import deepcopy
|
||||
@@ -27,12 +27,13 @@ from uuid import uuid4
|
||||
from mcp import types
|
||||
from mcp.server.lowlevel import Server
|
||||
from mcp.shared.exceptions import McpError
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
from pydantic import BaseModel
|
||||
|
||||
from . import _tools as _tool_utils # pyright: ignore[reportPrivateUsage]
|
||||
from ._clients import BaseChatClient, SupportsChatGetResponse
|
||||
from ._docstrings import apply_layered_docstring
|
||||
from ._mcp import LOG_LEVEL_MAPPING, MCPTool
|
||||
from ._middleware import AgentMiddlewareLayer, MiddlewareTypes
|
||||
from ._middleware import AgentMiddlewareLayer, FunctionInvocationContext, MiddlewareTypes
|
||||
from ._serialization import SerializationMixin
|
||||
from ._sessions import (
|
||||
AgentSession,
|
||||
@@ -53,7 +54,7 @@ from ._types import (
|
||||
map_chat_to_agent_update,
|
||||
normalize_messages,
|
||||
)
|
||||
from .exceptions import AgentInvalidResponseException
|
||||
from .exceptions import AgentInvalidResponseException, UserInputRequiredException
|
||||
from .observability import AgentTelemetryLayer
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
@@ -169,8 +170,8 @@ class _RunContext(TypedDict):
|
||||
chat_options: MutableMapping[str, Any]
|
||||
compaction_strategy: CompactionStrategy | None
|
||||
tokenizer: TokenizerProtocol | None
|
||||
filtered_kwargs: Mapping[str, Any]
|
||||
finalize_kwargs: Mapping[str, Any]
|
||||
client_kwargs: Mapping[str, Any]
|
||||
function_invocation_kwargs: Mapping[str, Any]
|
||||
|
||||
|
||||
# region Agent Protocol
|
||||
@@ -218,15 +219,15 @@ class SupportsAgentRun(Protocol):
|
||||
|
||||
return AgentResponse(messages=[], response_id="custom-response")
|
||||
|
||||
def create_session(self, **kwargs):
|
||||
def create_session(self, *, session_id: str | None = None):
|
||||
from agent_framework import AgentSession
|
||||
|
||||
return AgentSession(**kwargs)
|
||||
return AgentSession(session_id=session_id)
|
||||
|
||||
def get_session(self, *, service_session_id, **kwargs):
|
||||
def get_session(self, service_session_id: str, *, session_id: str | None = None):
|
||||
from agent_framework import AgentSession
|
||||
|
||||
return AgentSession(service_session_id=service_session_id, **kwargs)
|
||||
return AgentSession(service_session_id=service_session_id, session_id=session_id)
|
||||
|
||||
|
||||
# Verify the instance satisfies the protocol
|
||||
@@ -245,6 +246,8 @@ class SupportsAgentRun(Protocol):
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]:
|
||||
"""Get a response from the agent (non-streaming)."""
|
||||
@@ -257,6 +260,8 @@ class SupportsAgentRun(Protocol):
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Get a streaming response from the agent."""
|
||||
@@ -268,6 +273,8 @@ class SupportsAgentRun(Protocol):
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Get a response from the agent.
|
||||
@@ -282,6 +289,8 @@ class SupportsAgentRun(Protocol):
|
||||
Keyword Args:
|
||||
stream: Whether to stream the response. Defaults to False.
|
||||
session: The conversation session associated with the message(s).
|
||||
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
|
||||
client_kwargs: Additional client-specific keyword arguments.
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
@@ -291,11 +300,11 @@ class SupportsAgentRun(Protocol):
|
||||
"""
|
||||
...
|
||||
|
||||
def create_session(self, **kwargs: Any) -> AgentSession:
|
||||
def create_session(self, *, session_id: str | None = None) -> AgentSession:
|
||||
"""Creates a new conversation session."""
|
||||
...
|
||||
|
||||
def get_session(self, *, service_session_id: str, **kwargs: Any) -> AgentSession:
|
||||
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
|
||||
"""Gets or creates a session for a service-managed session ID."""
|
||||
...
|
||||
|
||||
@@ -378,6 +387,13 @@ class BaseAgent(SerializationMixin):
|
||||
additional_properties: Additional properties set on the agent.
|
||||
kwargs: Additional keyword arguments (merged into additional_properties).
|
||||
"""
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
"Passing additional properties as direct keyword arguments to BaseAgent is deprecated; "
|
||||
"pass them via additional_properties instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
if id is None:
|
||||
id = str(uuid4())
|
||||
self.id = id
|
||||
@@ -392,27 +408,40 @@ class BaseAgent(SerializationMixin):
|
||||
self.additional_properties: dict[str, Any] = cast(dict[str, Any], additional_properties or {})
|
||||
self.additional_properties.update(kwargs)
|
||||
|
||||
def create_session(self, *, session_id: str | None = None, **kwargs: Any) -> AgentSession:
|
||||
def create_session(self, *, session_id: str | None = None) -> AgentSession:
|
||||
"""Create a new lightweight session.
|
||||
|
||||
This will be used by an agent to hold the persisted session.
|
||||
This depends on the service used, in some cases, or with store=True
|
||||
this will add the ``service_session_id`` based on the response,
|
||||
which is then fed back to the API on the next call.
|
||||
|
||||
In other cases, if there is a HistoryProvider setup in the agent,
|
||||
that is used and it can store state in the session.
|
||||
|
||||
If there is no HistoryProvider and store=False or the default of a service is False.
|
||||
Then a ``InMemoryHistoryProvider`` instance is added to the agent and used with the session automatically.
|
||||
The ``InMemoryHistoryProvider`` stores the messages as `state` in the session by default.
|
||||
|
||||
Keyword Args:
|
||||
session_id: Optional session ID (generated if not provided).
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
A new AgentSession instance.
|
||||
"""
|
||||
return AgentSession(session_id=session_id)
|
||||
|
||||
def get_session(self, *, service_session_id: str, session_id: str | None = None, **kwargs: Any) -> AgentSession:
|
||||
"""Get or create a session for a service-managed session ID.
|
||||
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
|
||||
"""Get a session for a service-managed session ID.
|
||||
|
||||
Only use this to create a session continuing that session id from a service.
|
||||
Otherwise use ``create_session``.
|
||||
|
||||
Args:
|
||||
service_session_id: The service-managed session ID.
|
||||
|
||||
Keyword Args:
|
||||
session_id: Optional local session ID (generated if not provided).
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
A new AgentSession instance with service_session_id set.
|
||||
@@ -452,9 +481,8 @@ class BaseAgent(SerializationMixin):
|
||||
description: str | None = None,
|
||||
arg_name: str = "task",
|
||||
arg_description: str | None = None,
|
||||
stream_callback: Callable[[AgentResponseUpdate], None]
|
||||
| Callable[[AgentResponseUpdate], Awaitable[None]]
|
||||
| None = None,
|
||||
approval_mode: Literal["always_require", "never_require"] = "never_require",
|
||||
stream_callback: Callable[[AgentResponseUpdate], Awaitable[None] | None] | None = None,
|
||||
propagate_session: bool = False,
|
||||
) -> FunctionTool:
|
||||
"""Create a FunctionTool that wraps this agent.
|
||||
@@ -465,21 +493,15 @@ class BaseAgent(SerializationMixin):
|
||||
arg_name: The name of the function argument (default: "task").
|
||||
arg_description: The description for the function argument.
|
||||
If None, defaults to "Task for {tool_name}".
|
||||
approval_mode: Whether this delegated tool requires approval before execution.
|
||||
stream_callback: Optional callback for streaming responses. If provided, uses run(..., stream=True).
|
||||
propagate_session: If True, the parent agent's ``AgentSession`` is
|
||||
forwarded to this sub-agent's ``run()`` call, so both agents
|
||||
operate within the same logical session (sharing the same
|
||||
``session_id`` and provider-managed state, such as any stored
|
||||
conversation history or metadata). Defaults to False, meaning
|
||||
the sub-agent runs with a new, independent session.
|
||||
propagate_session: If True, the parent agent's session is forwarded
|
||||
to this sub-agent's ``run()`` call so both agents share the
|
||||
same session. Defaults to False.
|
||||
|
||||
Returns:
|
||||
A FunctionTool that can be used as a tool by other agents.
|
||||
|
||||
Raises:
|
||||
TypeError: If the agent does not implement SupportsAgentRun.
|
||||
ValueError: If the agent tool name cannot be determined.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
|
||||
@@ -507,59 +529,46 @@ class BaseAgent(SerializationMixin):
|
||||
tool_description = description or self.description or ""
|
||||
argument_description = arg_description or f"Task for {tool_name}"
|
||||
|
||||
# Create dynamic input model with the specified argument name
|
||||
field_info = Field(..., description=argument_description)
|
||||
model_name = f"{name or _sanitize_agent_name(self.name) or 'agent'}_task"
|
||||
input_model = create_model(model_name, **{arg_name: (str, field_info)}) # type: ignore[call-overload]
|
||||
input_schema = {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
arg_name: {
|
||||
"type": "string",
|
||||
"description": argument_description,
|
||||
}
|
||||
},
|
||||
"required": [arg_name],
|
||||
"additionalProperties": False,
|
||||
}
|
||||
|
||||
# Check if callback is async once, outside the wrapper
|
||||
is_async_callback = stream_callback is not None and inspect.iscoroutinefunction(stream_callback)
|
||||
async def _agent_wrapper(ctx: FunctionInvocationContext, **kwargs: Any) -> str:
|
||||
"""Wrapper function that calls the agent.
|
||||
|
||||
async def agent_wrapper(**kwargs: Any) -> str:
|
||||
"""Wrapper function that calls the agent."""
|
||||
# Extract the input from kwargs using the specified arg_name
|
||||
input_text = kwargs.get(arg_name, "")
|
||||
Args:
|
||||
ctx: the function invocation context used
|
||||
**kwargs: only used to dynamically load the argument that is defined for this tool.
|
||||
"""
|
||||
stream = self.run(
|
||||
str(kwargs.get(arg_name, "")),
|
||||
stream=True,
|
||||
session=ctx.session if propagate_session else None,
|
||||
function_invocation_kwargs=dict(ctx.kwargs),
|
||||
)
|
||||
if stream_callback is not None:
|
||||
stream.with_transform_hook(stream_callback)
|
||||
final_response = await stream.get_final_response()
|
||||
if final_response.user_input_requests:
|
||||
raise UserInputRequiredException(contents=final_response.user_input_requests)
|
||||
# TODO(Copilot): update once #4331 merges
|
||||
return final_response.text
|
||||
|
||||
# Extract parent session when propagate_session is enabled
|
||||
parent_session = kwargs.get("session") if propagate_session else None
|
||||
|
||||
# Forward runtime context kwargs, excluding framework-internal keys.
|
||||
forwarded_kwargs = {
|
||||
k: v for k, v in kwargs.items() if k not in (arg_name, "conversation_id", "options", "session")
|
||||
}
|
||||
|
||||
if stream_callback is None:
|
||||
# Use non-streaming mode
|
||||
return (
|
||||
await self.run(
|
||||
input_text,
|
||||
stream=False,
|
||||
session=parent_session,
|
||||
**forwarded_kwargs,
|
||||
)
|
||||
).text
|
||||
|
||||
# Use streaming mode - accumulate updates and create final response
|
||||
response_updates: list[AgentResponseUpdate] = []
|
||||
async for update in self.run(input_text, stream=True, session=parent_session, **forwarded_kwargs):
|
||||
response_updates.append(update)
|
||||
if is_async_callback:
|
||||
await stream_callback(update) # type: ignore[misc]
|
||||
else:
|
||||
stream_callback(update)
|
||||
|
||||
# Create final text from accumulated updates
|
||||
return AgentResponse.from_updates(response_updates).text
|
||||
|
||||
agent_tool: FunctionTool = FunctionTool(
|
||||
return FunctionTool(
|
||||
name=tool_name,
|
||||
description=tool_description,
|
||||
func=agent_wrapper,
|
||||
input_model=input_model, # type: ignore
|
||||
approval_mode="never_require",
|
||||
func=_agent_wrapper,
|
||||
input_model=input_schema,
|
||||
approval_mode=approval_mode,
|
||||
)
|
||||
agent_tool._forward_runtime_kwargs = True # type: ignore
|
||||
return agent_tool
|
||||
|
||||
|
||||
# region Agent
|
||||
@@ -801,6 +810,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@@ -815,6 +826,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@@ -829,6 +842,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
@@ -842,6 +857,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Run the agent with the given messages and options.
|
||||
@@ -871,14 +888,23 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
tokenizer: Optional per-run tokenizer override passed to
|
||||
``client.get_response()``. When omitted, the agent-level override
|
||||
is used, falling back to the client default.
|
||||
kwargs: Additional keyword arguments for the agent. These are only
|
||||
passed to functions that are called.
|
||||
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
|
||||
client_kwargs: Additional client-specific keyword arguments for the chat client.
|
||||
kwargs: Deprecated additional keyword arguments for the agent.
|
||||
They are forwarded to both tool invocation and the chat client for compatibility.
|
||||
|
||||
Returns:
|
||||
When stream=False: An Awaitable[AgentResponse] containing the agent's response.
|
||||
When stream=True: A ResponseStream of AgentResponseUpdate items with
|
||||
``get_final_response()`` for the final AgentResponse.
|
||||
"""
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
"Passing runtime keyword arguments directly to run() is deprecated; pass tool values via "
|
||||
"function_invocation_kwargs and client-specific values via client_kwargs instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
if not stream:
|
||||
|
||||
async def _run_non_streaming() -> AgentResponse[Any]:
|
||||
@@ -889,7 +915,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
kwargs=kwargs,
|
||||
legacy_kwargs=kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
response = cast(
|
||||
ChatResponse[Any],
|
||||
@@ -899,7 +927,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options=ctx["chat_options"], # type: ignore[reportArgumentType]
|
||||
compaction_strategy=ctx["compaction_strategy"],
|
||||
tokenizer=ctx["tokenizer"],
|
||||
**ctx["filtered_kwargs"],
|
||||
function_invocation_kwargs=ctx["function_invocation_kwargs"],
|
||||
client_kwargs=ctx["client_kwargs"],
|
||||
),
|
||||
)
|
||||
|
||||
@@ -974,7 +1003,9 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
kwargs=kwargs,
|
||||
legacy_kwargs=kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
ctx: _RunContext = ctx_holder["ctx"] # type: ignore[assignment] # Safe: we just assigned it
|
||||
return self.client.get_response( # type: ignore[call-overload, no-any-return]
|
||||
@@ -983,7 +1014,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options=ctx["chat_options"], # type: ignore[reportArgumentType]
|
||||
compaction_strategy=ctx["compaction_strategy"],
|
||||
tokenizer=ctx["tokenizer"],
|
||||
**ctx["filtered_kwargs"],
|
||||
function_invocation_kwargs=ctx["function_invocation_kwargs"],
|
||||
client_kwargs=ctx["client_kwargs"],
|
||||
)
|
||||
|
||||
def _propagate_conversation_id(
|
||||
@@ -1071,9 +1103,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
options: Mapping[str, Any] | None,
|
||||
compaction_strategy: CompactionStrategy | None,
|
||||
tokenizer: TokenizerProtocol | None,
|
||||
kwargs: dict[str, Any],
|
||||
legacy_kwargs: Mapping[str, Any],
|
||||
function_invocation_kwargs: Mapping[str, Any] | None,
|
||||
client_kwargs: Mapping[str, Any] | None,
|
||||
) -> _RunContext:
|
||||
opts = dict(options) if options else {}
|
||||
existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {}
|
||||
|
||||
# Get tools from options or named parameter (named param takes precedence)
|
||||
tools_ = tools if tools is not None else opts.pop("tools", None)
|
||||
@@ -1104,6 +1139,12 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
input_messages=input_messages,
|
||||
options=opts,
|
||||
)
|
||||
default_additional_args = chat_options.pop("additional_function_arguments", None)
|
||||
if isinstance(default_additional_args, Mapping):
|
||||
existing_additional_args = {
|
||||
**dict(cast(Mapping[str, Any], default_additional_args)),
|
||||
**existing_additional_args,
|
||||
}
|
||||
|
||||
agent_name = self._get_agent_name()
|
||||
base_tools = normalize_tools(chat_options.pop("tools", None))
|
||||
@@ -1135,13 +1176,13 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
duplicate_error_message=mcp_duplicate_message,
|
||||
)
|
||||
|
||||
# Merge runtime kwargs into additional_function_arguments so they're available
|
||||
# in function middleware context and tool invocation.
|
||||
existing_additional_args: dict[str, Any] = opts.pop("additional_function_arguments", None) or {}
|
||||
additional_function_arguments = {**kwargs, **existing_additional_args}
|
||||
# Include session so as_tool() wrappers with propagate_session=True can access it.
|
||||
if active_session is not None:
|
||||
additional_function_arguments["session"] = active_session
|
||||
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
|
||||
# Legacy compatibility still fans out direct run kwargs into tool runtime kwargs.
|
||||
effective_function_invocation_kwargs = {
|
||||
**dict(legacy_kwargs),
|
||||
**(dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}),
|
||||
}
|
||||
additional_function_arguments = {**effective_function_invocation_kwargs, **existing_additional_args}
|
||||
|
||||
# Build options dict from run() options merged with provided options
|
||||
run_opts: dict[str, Any] = {
|
||||
@@ -1150,7 +1191,6 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
if active_session
|
||||
else opts.pop("conversation_id", None),
|
||||
"allow_multiple_tool_calls": opts.pop("allow_multiple_tool_calls", None),
|
||||
"additional_function_arguments": additional_function_arguments or None,
|
||||
"frequency_penalty": opts.pop("frequency_penalty", None),
|
||||
"logit_bias": opts.pop("logit_bias", None),
|
||||
"max_tokens": opts.pop("max_tokens", None),
|
||||
@@ -1174,11 +1214,14 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
# Build session_messages from session context: context messages + input messages
|
||||
session_messages: list[Message] = session_context.get_messages(include_input=True)
|
||||
|
||||
# Ensure session is forwarded in kwargs for tool invocation
|
||||
finalize_kwargs = dict(kwargs)
|
||||
finalize_kwargs["session"] = active_session
|
||||
# Filter chat_options from kwargs to prevent duplicate keyword argument
|
||||
filtered_kwargs = {k: v for k, v in finalize_kwargs.items() if k != "chat_options"}
|
||||
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
|
||||
# Legacy compatibility still fans out direct run kwargs into client kwargs.
|
||||
effective_client_kwargs = {
|
||||
**dict(legacy_kwargs),
|
||||
**(dict(client_kwargs) if client_kwargs is not None else {}),
|
||||
}
|
||||
if active_session is not None:
|
||||
effective_client_kwargs["session"] = active_session
|
||||
|
||||
return {
|
||||
"session": active_session,
|
||||
@@ -1189,8 +1232,8 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
"chat_options": co,
|
||||
"compaction_strategy": compaction_strategy or self.compaction_strategy,
|
||||
"tokenizer": tokenizer or self.tokenizer,
|
||||
"filtered_kwargs": filtered_kwargs,
|
||||
"finalize_kwargs": finalize_kwargs,
|
||||
"client_kwargs": effective_client_kwargs,
|
||||
"function_invocation_kwargs": additional_function_arguments,
|
||||
}
|
||||
|
||||
async def _finalize_response(
|
||||
@@ -1440,6 +1483,58 @@ class Agent(
|
||||
For a minimal implementation without these features, use :class:`RawAgent`.
|
||||
"""
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Run the agent."""
|
||||
super_run = cast(
|
||||
"Callable[..., Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]]",
|
||||
super().run, # type: ignore[misc]
|
||||
)
|
||||
return super_run( # type: ignore[no-any-return]
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
session=session,
|
||||
middleware=middleware,
|
||||
options=options,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
client: SupportsChatGetResponse[OptionsCoT],
|
||||
@@ -1471,3 +1566,34 @@ class Agent(
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
def _apply_agent_docstrings() -> None:
|
||||
"""Align public agent docstrings with the raw implementation."""
|
||||
apply_layered_docstring(
|
||||
AgentMiddlewareLayer.run,
|
||||
RawAgent.run,
|
||||
extra_keyword_args={
|
||||
"middleware": """
|
||||
Optional per-run agent, chat, and function middleware.
|
||||
Agent middleware wraps the run itself, while chat and function middleware are forwarded to the
|
||||
underlying chat-client stack for this call.
|
||||
""",
|
||||
},
|
||||
)
|
||||
apply_layered_docstring(AgentTelemetryLayer.run, AgentMiddlewareLayer.run)
|
||||
apply_layered_docstring(
|
||||
Agent.run,
|
||||
RawAgent.run,
|
||||
extra_keyword_args={
|
||||
"middleware": """
|
||||
Optional per-run agent, chat, and function middleware.
|
||||
Agent middleware wraps the run itself, while chat and function middleware are forwarded to the
|
||||
underlying chat-client stack for this call.
|
||||
""",
|
||||
},
|
||||
)
|
||||
apply_layered_docstring(Agent.__init__, RawAgent.__init__)
|
||||
|
||||
|
||||
_apply_agent_docstrings()
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
@@ -27,6 +28,7 @@ from typing import (
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._docstrings import apply_layered_docstring
|
||||
from ._serialization import SerializationMixin
|
||||
from ._tools import (
|
||||
FunctionInvocationConfiguration,
|
||||
@@ -105,7 +107,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
class CustomChatClient:
|
||||
additional_properties: dict = {}
|
||||
|
||||
def get_response(self, messages, *, stream=False, **kwargs):
|
||||
def get_response(self, messages, *, stream=False, client_kwargs=None, **kwargs):
|
||||
if stream:
|
||||
from agent_framework import ChatResponseUpdate, ResponseStream
|
||||
|
||||
@@ -149,6 +151,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
options: OptionsContraT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@@ -161,6 +165,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
options: OptionsContraT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@@ -172,6 +178,8 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
options: OptionsContraT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Send input and return the response.
|
||||
@@ -182,7 +190,9 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
options: Chat options as a TypedDict.
|
||||
compaction_strategy: Optional per-call compaction override.
|
||||
tokenizer: Optional per-call tokenizer override.
|
||||
**kwargs: Additional chat options.
|
||||
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
|
||||
client_kwargs: Additional client-specific keyword arguments.
|
||||
**kwargs: Deprecated additional client-specific keyword arguments.
|
||||
|
||||
Returns:
|
||||
When stream=False: An awaitable ChatResponse from the client.
|
||||
@@ -283,23 +293,31 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a BaseChatClient instance.
|
||||
|
||||
Keyword Args:
|
||||
additional_properties: Additional properties for the client.
|
||||
compaction_strategy: Optional compaction strategy to apply before model calls.
|
||||
tokenizer: Optional tokenizer used by token-aware compaction strategies.
|
||||
kwargs: Additional keyword arguments (merged into additional_properties).
|
||||
additional_properties: Additional properties for the client.
|
||||
kwargs: Additional keyword arguments (merged into additional_properties for now).
|
||||
"""
|
||||
self.additional_properties = additional_properties or {}
|
||||
self.compaction_strategy = compaction_strategy
|
||||
self.tokenizer = tokenizer
|
||||
super().__init__(**kwargs)
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
"Passing additional properties as direct keyword arguments to BaseChatClient is deprecated; "
|
||||
"pass them via additional_properties instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=3,
|
||||
)
|
||||
self.additional_properties.update(kwargs)
|
||||
super().__init__()
|
||||
|
||||
def to_dict(self, *, exclude: set[str] | None = None, exclude_none: bool = True) -> dict[str, Any]:
|
||||
"""Convert the instance to a dictionary.
|
||||
@@ -486,7 +504,13 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
When omitted, the client-level default is used.
|
||||
tokenizer: Optional per-call tokenizer override. When omitted, the
|
||||
client-level default is used.
|
||||
**kwargs: Other keyword arguments, can be used to pass function specific parameters.
|
||||
**kwargs: Additional compatibility keyword arguments. Lower chat-client layers do not
|
||||
consume ``function_invocation_kwargs`` directly; if present, it is ignored here
|
||||
because function invocation has already been handled by upper layers. If a
|
||||
``client_kwargs`` mapping is present, it is flattened into standard keyword
|
||||
arguments before forwarding to ``_inner_get_response()`` so client implementations
|
||||
can leverage those values, while implementations that ignore
|
||||
extra kwargs remain compatible.
|
||||
|
||||
Returns:
|
||||
When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse.
|
||||
@@ -495,12 +519,21 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
)
|
||||
compatibility_client_kwargs = kwargs.pop("client_kwargs", None)
|
||||
kwargs.pop("function_invocation_kwargs", None)
|
||||
merged_client_kwargs = (
|
||||
dict(cast(Mapping[str, Any], compatibility_client_kwargs))
|
||||
if isinstance(compatibility_client_kwargs, Mapping)
|
||||
else {}
|
||||
)
|
||||
merged_client_kwargs.update(kwargs)
|
||||
|
||||
if not compaction_overrides:
|
||||
return self._inner_get_response(
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options or {},
|
||||
**kwargs,
|
||||
options=options or {}, # type: ignore[arg-type]
|
||||
**merged_client_kwargs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
@@ -514,7 +547,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
messages=prepared_messages,
|
||||
stream=True,
|
||||
options=options or {},
|
||||
**kwargs,
|
||||
**merged_client_kwargs,
|
||||
)
|
||||
if isinstance(stream_response, ResponseStream):
|
||||
return stream_response # type: ignore[reportUnknownVariableType]
|
||||
@@ -534,7 +567,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
messages=prepared_messages,
|
||||
stream=False,
|
||||
options=options or {},
|
||||
**kwargs,
|
||||
**merged_client_kwargs,
|
||||
)
|
||||
|
||||
return _get_response()
|
||||
@@ -564,7 +597,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
additional_properties: Mapping[str, Any] | None = None,
|
||||
) -> Agent[OptionsCoT]:
|
||||
"""Create a Agent with this client.
|
||||
|
||||
@@ -590,7 +623,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
client-level compaction defaults remain in effect for each call.
|
||||
tokenizer: Optional agent-level tokenizer override. When omitted,
|
||||
client-level tokenizer defaults remain in effect for each call.
|
||||
kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``.
|
||||
additional_properties: Additional properties stored on the created agent.
|
||||
|
||||
Returns:
|
||||
A Agent instance configured with this chat client.
|
||||
@@ -615,21 +648,24 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
"""
|
||||
from ._agents import Agent
|
||||
|
||||
return Agent(
|
||||
client=self,
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
tools=tools,
|
||||
default_options=cast(Any, default_options),
|
||||
context_providers=context_providers,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
)
|
||||
agent_kwargs: dict[str, Any] = {
|
||||
"client": self,
|
||||
"id": id,
|
||||
"name": name,
|
||||
"description": description,
|
||||
"instructions": instructions,
|
||||
"tools": tools,
|
||||
"default_options": cast(Any, default_options),
|
||||
"context_providers": context_providers,
|
||||
"middleware": middleware,
|
||||
"compaction_strategy": compaction_strategy,
|
||||
"tokenizer": tokenizer,
|
||||
"additional_properties": dict(additional_properties) if additional_properties is not None else None,
|
||||
}
|
||||
if function_invocation_configuration is not None:
|
||||
agent_kwargs["function_invocation_configuration"] = function_invocation_configuration
|
||||
|
||||
return Agent(**agent_kwargs)
|
||||
|
||||
|
||||
# endregion
|
||||
@@ -892,16 +928,14 @@ class BaseEmbeddingClient(SerializationMixin, ABC, Generic[EmbeddingInputT, Embe
|
||||
self,
|
||||
*,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a BaseEmbeddingClient instance.
|
||||
|
||||
Args:
|
||||
additional_properties: Additional properties to pass to the client.
|
||||
**kwargs: Additional keyword arguments passed to parent classes (for MRO).
|
||||
"""
|
||||
self.additional_properties = additional_properties or {}
|
||||
super().__init__(**kwargs)
|
||||
super().__init__()
|
||||
|
||||
@abstractmethod
|
||||
async def get_embeddings(
|
||||
@@ -923,3 +957,36 @@ class BaseEmbeddingClient(SerializationMixin, ABC, Generic[EmbeddingInputT, Embe
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
|
||||
def _apply_get_response_docstrings() -> None:
|
||||
"""Align layered chat-client docstrings with the lowest public implementation."""
|
||||
from ._middleware import ChatMiddlewareLayer
|
||||
from ._tools import FunctionInvocationLayer
|
||||
from .observability import ChatTelemetryLayer
|
||||
|
||||
apply_layered_docstring(ChatTelemetryLayer.get_response, BaseChatClient.get_response)
|
||||
apply_layered_docstring(
|
||||
FunctionInvocationLayer.get_response,
|
||||
ChatTelemetryLayer.get_response,
|
||||
extra_keyword_args={
|
||||
"function_middleware": """
|
||||
Optional per-call function middleware.
|
||||
When omitted, middleware configured on the client or forwarded from higher layers is used.
|
||||
""",
|
||||
},
|
||||
)
|
||||
apply_layered_docstring(
|
||||
ChatMiddlewareLayer.get_response,
|
||||
FunctionInvocationLayer.get_response,
|
||||
extra_keyword_args={
|
||||
"middleware": """
|
||||
Optional per-call chat and function middleware.
|
||||
This compatibility keyword argument is merged with any ``client_kwargs["middleware"]`` value
|
||||
before the request is executed.
|
||||
""",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
_apply_get_response_docstrings()
|
||||
|
||||
@@ -0,0 +1,85 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
from collections.abc import Callable, Mapping
|
||||
from typing import Any
|
||||
|
||||
_GOOGLE_SECTION_HEADERS = (
|
||||
"Args:",
|
||||
"Keyword Args:",
|
||||
"Returns:",
|
||||
"Raises:",
|
||||
"Examples:",
|
||||
"Note:",
|
||||
"Notes:",
|
||||
"Warning:",
|
||||
"Warnings:",
|
||||
)
|
||||
|
||||
|
||||
def _find_section_index(lines: list[str], header: str) -> int | None:
|
||||
for index, line in enumerate(lines):
|
||||
if line == header:
|
||||
return index
|
||||
return None
|
||||
|
||||
|
||||
def _find_next_section_index(lines: list[str], start: int) -> int:
|
||||
for index in range(start, len(lines)):
|
||||
if lines[index] in _GOOGLE_SECTION_HEADERS:
|
||||
return index
|
||||
return len(lines)
|
||||
|
||||
|
||||
def _format_keyword_arg_lines(extra_keyword_args: Mapping[str, str]) -> list[str]:
|
||||
formatted_lines: list[str] = []
|
||||
for name, description in extra_keyword_args.items():
|
||||
description_lines = inspect.cleandoc(description).splitlines()
|
||||
if not description_lines:
|
||||
formatted_lines.append(f" {name}:")
|
||||
continue
|
||||
formatted_lines.append(f" {name}: {description_lines[0]}")
|
||||
formatted_lines.extend(f" {line}" for line in description_lines[1:])
|
||||
return formatted_lines
|
||||
|
||||
|
||||
def build_layered_docstring(
|
||||
source: Callable[..., Any],
|
||||
*,
|
||||
extra_keyword_args: Mapping[str, str] | None = None,
|
||||
) -> str | None:
|
||||
"""Build a Google-style docstring from a lower-layer implementation."""
|
||||
docstring = inspect.getdoc(source)
|
||||
if not docstring:
|
||||
return None
|
||||
if not extra_keyword_args:
|
||||
return docstring
|
||||
|
||||
lines = docstring.splitlines()
|
||||
formatted_keyword_arg_lines = _format_keyword_arg_lines(extra_keyword_args)
|
||||
keyword_args_index = _find_section_index(lines, "Keyword Args:")
|
||||
|
||||
if keyword_args_index is None:
|
||||
args_index = _find_section_index(lines, "Args:")
|
||||
if args_index is not None:
|
||||
insert_index = _find_next_section_index(lines, args_index + 1)
|
||||
else:
|
||||
insert_index = _find_next_section_index(lines, 0)
|
||||
lines[insert_index:insert_index] = ["", "Keyword Args:", *formatted_keyword_arg_lines]
|
||||
return "\n".join(lines).rstrip()
|
||||
|
||||
insert_index = _find_next_section_index(lines, keyword_args_index + 1)
|
||||
lines[insert_index:insert_index] = formatted_keyword_arg_lines
|
||||
return "\n".join(lines).rstrip()
|
||||
|
||||
|
||||
def apply_layered_docstring(
|
||||
target: Callable[..., Any],
|
||||
source: Callable[..., Any],
|
||||
*,
|
||||
extra_keyword_args: Mapping[str, str] | None = None,
|
||||
) -> None:
|
||||
"""Copy a lower-layer docstring onto a wrapper and extend it when needed."""
|
||||
target.__doc__ = build_layered_docstring(source, extra_keyword_args=extra_keyword_args)
|
||||
@@ -109,7 +109,9 @@ class AgentContext:
|
||||
to see the actual execution result or can be set to override the execution result.
|
||||
For non-streaming: should be AgentResponse.
|
||||
For streaming: should be ResponseStream[AgentResponseUpdate, AgentResponse].
|
||||
kwargs: Additional keyword arguments passed to the agent run method.
|
||||
kwargs: Legacy runtime keyword arguments visible to agent middleware.
|
||||
client_kwargs: Client-specific keyword arguments for downstream chat clients.
|
||||
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -147,6 +149,8 @@ class AgentContext:
|
||||
metadata: Mapping[str, Any] | None = None,
|
||||
result: AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None = None,
|
||||
kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
stream_transform_hooks: Sequence[
|
||||
Callable[[AgentResponseUpdate], AgentResponseUpdate | Awaitable[AgentResponseUpdate]]
|
||||
]
|
||||
@@ -167,7 +171,9 @@ class AgentContext:
|
||||
tokenizer: Optional per-run tokenizer override.
|
||||
metadata: Metadata dictionary for sharing data between agent middleware.
|
||||
result: Agent execution result.
|
||||
kwargs: Additional keyword arguments passed to the agent run method.
|
||||
kwargs: Legacy runtime keyword arguments visible to agent middleware.
|
||||
client_kwargs: Client-specific keyword arguments for downstream chat clients.
|
||||
function_invocation_kwargs: Keyword arguments forwarded to tool invocation.
|
||||
stream_transform_hooks: Hooks to transform streamed updates.
|
||||
stream_result_hooks: Hooks to process the final result after streaming.
|
||||
stream_cleanup_hooks: Hooks to run after streaming completes.
|
||||
@@ -182,6 +188,10 @@ class AgentContext:
|
||||
self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {}
|
||||
self.result = result
|
||||
self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {}
|
||||
self.client_kwargs: dict[str, Any] = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
self.function_invocation_kwargs: dict[str, Any] = (
|
||||
dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}
|
||||
)
|
||||
self.stream_transform_hooks = list(stream_transform_hooks or [])
|
||||
self.stream_result_hooks = list(stream_result_hooks or [])
|
||||
self.stream_cleanup_hooks = list(stream_cleanup_hooks or [])
|
||||
@@ -196,11 +206,11 @@ class FunctionInvocationContext:
|
||||
Attributes:
|
||||
function: The function being invoked.
|
||||
arguments: The validated arguments for the function.
|
||||
session: The agent session for this invocation, if any.
|
||||
metadata: Metadata dictionary for sharing data between function middleware.
|
||||
result: Function execution result. Can be observed after calling ``call_next()``
|
||||
to see the actual execution result or can be set to override the execution result.
|
||||
|
||||
kwargs: Additional keyword arguments passed to the chat method that invoked this function.
|
||||
kwargs: Additional runtime keyword arguments forwarded to the function invocation.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -225,6 +235,7 @@ class FunctionInvocationContext:
|
||||
self,
|
||||
function: FunctionTool,
|
||||
arguments: BaseModel | Mapping[str, Any],
|
||||
session: AgentSession | None = None,
|
||||
metadata: Mapping[str, Any] | None = None,
|
||||
result: Any = None,
|
||||
kwargs: Mapping[str, Any] | None = None,
|
||||
@@ -234,12 +245,14 @@ class FunctionInvocationContext:
|
||||
Args:
|
||||
function: The function being invoked.
|
||||
arguments: The validated arguments for the function.
|
||||
session: The agent session for this invocation, if any.
|
||||
metadata: Metadata dictionary for sharing data between function middleware.
|
||||
result: Function execution result.
|
||||
kwargs: Additional keyword arguments passed to the chat method that invoked this function.
|
||||
kwargs: Additional runtime keyword arguments forwarded to the function invocation.
|
||||
"""
|
||||
self.function = function
|
||||
self.arguments = arguments
|
||||
self.session = session
|
||||
self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {}
|
||||
self.result = result
|
||||
self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {}
|
||||
@@ -262,6 +275,7 @@ class ChatContext:
|
||||
For non-streaming: should be ChatResponse.
|
||||
For streaming: should be ResponseStream[ChatResponseUpdate, ChatResponse].
|
||||
kwargs: Additional keyword arguments passed to the chat client.
|
||||
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
|
||||
stream_transform_hooks: Hooks applied to transform each streamed update.
|
||||
stream_result_hooks: Hooks applied to the finalized response (after finalizer).
|
||||
stream_cleanup_hooks: Hooks executed after stream consumption (before finalizer).
|
||||
@@ -298,6 +312,7 @@ class ChatContext:
|
||||
metadata: Mapping[str, Any] | None = None,
|
||||
result: ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None = None,
|
||||
kwargs: Mapping[str, Any] | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
stream_transform_hooks: Sequence[
|
||||
Callable[[ChatResponseUpdate], ChatResponseUpdate | Awaitable[ChatResponseUpdate]]
|
||||
]
|
||||
@@ -315,6 +330,7 @@ class ChatContext:
|
||||
metadata: Metadata dictionary for sharing data between chat middleware.
|
||||
result: Chat execution result.
|
||||
kwargs: Additional keyword arguments passed to the chat client.
|
||||
function_invocation_kwargs: Keyword arguments forwarded only to tool invocation layers.
|
||||
stream_transform_hooks: Transform hooks to apply to each streamed update.
|
||||
stream_result_hooks: Result hooks to apply to the finalized streaming response.
|
||||
stream_cleanup_hooks: Cleanup hooks to run after streaming completes.
|
||||
@@ -326,6 +342,9 @@ class ChatContext:
|
||||
self.metadata: dict[str, Any] = dict(metadata) if metadata is not None else {}
|
||||
self.result = result
|
||||
self.kwargs: dict[str, Any] = dict(kwargs) if kwargs is not None else {}
|
||||
self.function_invocation_kwargs: dict[str, Any] = (
|
||||
dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}
|
||||
)
|
||||
self.stream_transform_hooks = list(stream_transform_hooks or [])
|
||||
self.stream_result_hooks = list(stream_result_hooks or [])
|
||||
self.stream_cleanup_hooks = list(stream_cleanup_hooks or [])
|
||||
@@ -980,6 +999,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@@ -992,6 +1012,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@@ -1004,6 +1026,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@@ -1015,6 +1039,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Execute the chat pipeline if middleware is configured."""
|
||||
@@ -1025,9 +1051,10 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
if tokenizer is not None:
|
||||
kwargs["tokenizer"] = tokenizer
|
||||
|
||||
call_middleware = kwargs.pop("middleware", [])
|
||||
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
call_middleware = kwargs.pop("middleware", effective_client_kwargs.pop("middleware", []))
|
||||
middleware = categorize_middleware(call_middleware)
|
||||
kwargs["function_middleware"] = middleware["function"]
|
||||
effective_client_kwargs["function_middleware"] = middleware["function"]
|
||||
|
||||
pipeline = ChatMiddlewarePipeline(
|
||||
*self.chat_middleware,
|
||||
@@ -1038,6 +1065,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=effective_client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -1046,7 +1075,8 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
messages=list(messages),
|
||||
options=options,
|
||||
stream=stream,
|
||||
kwargs=kwargs,
|
||||
kwargs={**effective_client_kwargs, **kwargs},
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
)
|
||||
|
||||
async def _execute() -> ChatResponse | ResponseStream[ChatResponseUpdate, ChatResponse] | None:
|
||||
@@ -1079,11 +1109,17 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
self, context: ChatContext
|
||||
) -> Awaitable[ChatResponse] | ResponseStream[ChatResponseUpdate, ChatResponse]:
|
||||
"""Internal middleware handler to adapt to pipeline."""
|
||||
handler_kwargs = dict(context.kwargs)
|
||||
compaction_strategy = handler_kwargs.pop("compaction_strategy", None)
|
||||
tokenizer = handler_kwargs.pop("tokenizer", None)
|
||||
return super().get_response( # type: ignore[misc, no-any-return]
|
||||
messages=context.messages,
|
||||
stream=context.stream,
|
||||
options=context.options or {},
|
||||
**context.kwargs,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=context.function_invocation_kwargs,
|
||||
client_kwargs=handler_kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -1115,6 +1151,8 @@ class AgentMiddlewareLayer:
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@@ -1129,6 +1167,8 @@ class AgentMiddlewareLayer:
|
||||
options: ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@@ -1143,6 +1183,8 @@ class AgentMiddlewareLayer:
|
||||
options: ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
@@ -1156,6 +1198,8 @@ class AgentMiddlewareLayer:
|
||||
options: ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""MiddlewareTypes-enabled unified run method."""
|
||||
@@ -1175,9 +1219,12 @@ class AgentMiddlewareLayer:
|
||||
+ run_middleware_list["function"]
|
||||
+ run_middleware_list["chat"]
|
||||
)
|
||||
combined_kwargs = dict(kwargs)
|
||||
combined_kwargs["middleware"] = combined_function_chat_middleware if combined_function_chat_middleware else None
|
||||
|
||||
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
if combined_function_chat_middleware:
|
||||
effective_client_kwargs["middleware"] = combined_function_chat_middleware
|
||||
effective_function_invocation_kwargs = (
|
||||
dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}
|
||||
)
|
||||
# Execute with middleware if available
|
||||
if not pipeline.has_middlewares:
|
||||
return super().run( # type: ignore[misc, no-any-return]
|
||||
@@ -1187,7 +1234,9 @@ class AgentMiddlewareLayer:
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**combined_kwargs,
|
||||
function_invocation_kwargs=effective_function_invocation_kwargs,
|
||||
client_kwargs=effective_client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
context = AgentContext(
|
||||
@@ -1198,7 +1247,9 @@ class AgentMiddlewareLayer:
|
||||
stream=stream,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
kwargs=combined_kwargs,
|
||||
kwargs=kwargs,
|
||||
client_kwargs=effective_client_kwargs,
|
||||
function_invocation_kwargs=effective_function_invocation_kwargs,
|
||||
)
|
||||
|
||||
async def _execute() -> AgentResponse | ResponseStream[AgentResponseUpdate, AgentResponse] | None:
|
||||
@@ -1230,6 +1281,13 @@ class AgentMiddlewareLayer:
|
||||
def _middleware_handler(
|
||||
self, context: AgentContext
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
|
||||
client_kwargs = {**context.client_kwargs, **context.kwargs}
|
||||
# TODO(Copilot): Delete once direct ``run(**kwargs)`` compatibility is removed.
|
||||
function_invocation_kwargs = {
|
||||
**context.function_invocation_kwargs,
|
||||
**{k: v for k, v in context.kwargs.items() if k != "middleware"},
|
||||
}
|
||||
return super().run( # type: ignore[misc, no-any-return]
|
||||
context.messages,
|
||||
stream=context.stream,
|
||||
@@ -1237,7 +1295,8 @@ class AgentMiddlewareLayer:
|
||||
options=context.options,
|
||||
compaction_strategy=context.compaction_strategy,
|
||||
tokenizer=context.tokenizer,
|
||||
**context.kwargs,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -392,12 +392,16 @@ class BaseHistoryProvider(BaseContextProvider):
|
||||
self.store_outputs = store_outputs
|
||||
|
||||
@abstractmethod
|
||||
async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]:
|
||||
async def get_messages(
|
||||
self, session_id: str | None, *, state: dict[str, Any] | None = None, **kwargs: Any
|
||||
) -> list[Message]:
|
||||
"""Retrieve stored messages for this session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to retrieve messages for.
|
||||
**kwargs: Additional arguments (e.g., ``state`` for in-memory providers).
|
||||
state: Optional session state for providers that persist in session state.
|
||||
Not used by all providers.
|
||||
**kwargs: Additional subclass-specific extensibility arguments.
|
||||
|
||||
Returns:
|
||||
List of stored messages.
|
||||
@@ -405,13 +409,22 @@ class BaseHistoryProvider(BaseContextProvider):
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None:
|
||||
async def save_messages(
|
||||
self,
|
||||
session_id: str | None,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
state: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Persist messages for this session.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to store messages for.
|
||||
messages: The messages to persist.
|
||||
**kwargs: Additional arguments (e.g., ``state`` for in-memory providers).
|
||||
state: Optional session state for providers that persist in session state.
|
||||
Not used by all providers.
|
||||
**kwargs: Additional subclass-specific extensibility arguments.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
@@ -7,6 +7,8 @@ import inspect
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import typing
|
||||
import warnings
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
Awaitable,
|
||||
@@ -37,7 +39,7 @@ from opentelemetry.metrics import Histogram, NoOpHistogram
|
||||
from pydantic import BaseModel, Field, ValidationError, create_model
|
||||
|
||||
from ._serialization import SerializationMixin
|
||||
from .exceptions import ToolException
|
||||
from .exceptions import ToolException, UserInputRequiredException
|
||||
from .observability import (
|
||||
OPERATION_DURATION_BUCKET_BOUNDARIES,
|
||||
OtelAttr,
|
||||
@@ -61,7 +63,8 @@ if TYPE_CHECKING:
|
||||
from ._clients import SupportsChatGetResponse
|
||||
from ._compaction import CompactionStrategy, TokenizerProtocol
|
||||
from ._mcp import MCPTool
|
||||
from ._middleware import FunctionMiddlewarePipeline, FunctionMiddlewareTypes
|
||||
from ._middleware import FunctionInvocationContext, FunctionMiddlewarePipeline, FunctionMiddlewareTypes
|
||||
from ._sessions import AgentSession
|
||||
from ._types import (
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
@@ -187,6 +190,16 @@ def _default_histogram() -> Histogram:
|
||||
)
|
||||
|
||||
|
||||
def _annotation_includes_function_invocation_context(annotation: Any) -> bool:
|
||||
"""Check whether an annotation resolves to FunctionInvocationContext."""
|
||||
from ._middleware import FunctionInvocationContext
|
||||
|
||||
candidates = get_args(annotation) or (annotation,)
|
||||
return any(
|
||||
candidate is FunctionInvocationContext or candidate == "FunctionInvocationContext" for candidate in candidates
|
||||
)
|
||||
|
||||
|
||||
ClassT = TypeVar("ClassT", bound="SerializationMixin")
|
||||
|
||||
|
||||
@@ -323,6 +336,12 @@ class FunctionTool(SerializationMixin):
|
||||
# FunctionTool-specific attributes
|
||||
self.func = func
|
||||
self._instance = None # Store the instance for bound methods
|
||||
self._context_parameter_name: str | None = None
|
||||
self._input_model_explicitly_provided = input_model is not None
|
||||
# TODO(Copilot): Delete once legacy ``**kwargs`` runtime injection is removed.
|
||||
self._forward_runtime_kwargs: bool = False
|
||||
if self.func:
|
||||
self._discover_injected_parameters()
|
||||
|
||||
# Initialize schema cache (will be lazily populated)
|
||||
self._input_schema_cached: dict[str, Any] | None = None
|
||||
@@ -349,13 +368,37 @@ class FunctionTool(SerializationMixin):
|
||||
self._invocation_duration_histogram = _default_histogram()
|
||||
self.type: Literal["function_tool"] = "function_tool"
|
||||
self.result_parser = result_parser
|
||||
self._forward_runtime_kwargs: bool = False
|
||||
if self.func:
|
||||
sig = inspect.signature(self.func)
|
||||
for param in sig.parameters.values():
|
||||
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
self._forward_runtime_kwargs = True
|
||||
break
|
||||
|
||||
def _discover_injected_parameters(self) -> None:
|
||||
"""Inspect the wrapped function for runtime injection parameters."""
|
||||
func = self.func.func if isinstance(self.func, FunctionTool) else self.func
|
||||
if func is None:
|
||||
return
|
||||
|
||||
signature = inspect.signature(func)
|
||||
try:
|
||||
type_hints = typing.get_type_hints(func)
|
||||
except Exception:
|
||||
type_hints = {name: param.annotation for name, param in signature.parameters.items()}
|
||||
|
||||
for name, param in signature.parameters.items():
|
||||
if name in {"self", "cls"}:
|
||||
continue
|
||||
if param.kind == inspect.Parameter.VAR_KEYWORD:
|
||||
self._forward_runtime_kwargs = True
|
||||
continue
|
||||
|
||||
annotation = type_hints.get(name, param.annotation)
|
||||
if self._is_context_parameter(name, annotation):
|
||||
if self._context_parameter_name is not None:
|
||||
raise ValueError(f"Function '{self.name}' defines multiple FunctionInvocationContext parameters.")
|
||||
self._context_parameter_name = name
|
||||
|
||||
def _is_context_parameter(self, name: str, annotation: Any) -> bool:
|
||||
"""Check whether a callable parameter should receive FunctionInvocationContext injection."""
|
||||
if _annotation_includes_function_invocation_context(annotation):
|
||||
return True
|
||||
return self._input_model_explicitly_provided and name == "ctx" and annotation is inspect.Parameter.empty
|
||||
|
||||
def __str__(self) -> str:
|
||||
"""Return a string representation of the tool."""
|
||||
@@ -424,6 +467,7 @@ class FunctionTool(SerializationMixin):
|
||||
)
|
||||
for pname, param in sig.parameters.items()
|
||||
if pname not in {"self", "cls"}
|
||||
and pname != self._context_parameter_name
|
||||
and param.kind not in {inspect.Parameter.VAR_POSITIONAL, inspect.Parameter.VAR_KEYWORD}
|
||||
}
|
||||
return create_model(f"{self.name}_input", **fields)
|
||||
@@ -461,6 +505,7 @@ class FunctionTool(SerializationMixin):
|
||||
self,
|
||||
*,
|
||||
arguments: BaseModel | Mapping[str, Any] | None = None,
|
||||
context: FunctionInvocationContext | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Content]:
|
||||
"""Run the AI function with the provided arguments as a Pydantic model.
|
||||
@@ -472,7 +517,8 @@ class FunctionTool(SerializationMixin):
|
||||
|
||||
Keyword Args:
|
||||
arguments: A mapping or model instance containing the arguments for the function.
|
||||
kwargs: Keyword arguments to pass to the function, will not be used if ``arguments`` is provided.
|
||||
context: Explicit function invocation context carrying runtime kwargs.
|
||||
kwargs: Deprecated keyword arguments to pass to the function. Use ``context`` instead.
|
||||
|
||||
Returns:
|
||||
A list of Content items representing the tool output.
|
||||
@@ -483,14 +529,37 @@ class FunctionTool(SerializationMixin):
|
||||
if self.declaration_only:
|
||||
raise ToolException(f"Function '{self.name}' is declaration only and cannot be invoked.")
|
||||
global OBSERVABILITY_SETTINGS
|
||||
from ._middleware import FunctionInvocationContext
|
||||
from ._types import Content
|
||||
from .observability import OBSERVABILITY_SETTINGS
|
||||
|
||||
parser = self.result_parser or FunctionTool.parse_result
|
||||
|
||||
original_kwargs = dict(kwargs)
|
||||
tool_call_id = original_kwargs.pop("tool_call_id", None)
|
||||
if arguments is not None:
|
||||
parameter_names = set(self.parameters().get("properties", {}).keys())
|
||||
direct_argument_kwargs = (
|
||||
{key: value for key, value in kwargs.items() if key in parameter_names} if arguments is None else {}
|
||||
)
|
||||
runtime_kwargs = dict(context.kwargs) if context is not None else {}
|
||||
deprecated_runtime_kwargs = {
|
||||
key: value for key, value in kwargs.items() if key not in direct_argument_kwargs and key != "tool_call_id"
|
||||
}
|
||||
if deprecated_runtime_kwargs:
|
||||
warnings.warn(
|
||||
"Passing runtime keyword arguments directly to FunctionTool.invoke() is deprecated; "
|
||||
"pass them via FunctionInvocationContext instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
runtime_kwargs.update(deprecated_runtime_kwargs)
|
||||
tool_call_id = kwargs.get("tool_call_id", runtime_kwargs.pop("tool_call_id", None))
|
||||
if arguments is None and direct_argument_kwargs:
|
||||
arguments = direct_argument_kwargs
|
||||
if arguments is None and context is not None:
|
||||
arguments = context.arguments
|
||||
|
||||
if arguments is None:
|
||||
validated_arguments: dict[str, Any] = {}
|
||||
else:
|
||||
try:
|
||||
if isinstance(arguments, Mapping):
|
||||
parsed_arguments = dict(arguments)
|
||||
@@ -512,19 +581,45 @@ class FunctionTool(SerializationMixin):
|
||||
)
|
||||
except ValidationError as exc:
|
||||
raise TypeError(f"Invalid arguments for '{self.name}': {exc}") from exc
|
||||
kwargs = _validate_arguments_against_schema(
|
||||
|
||||
validated_arguments = _validate_arguments_against_schema(
|
||||
arguments=parsed_arguments,
|
||||
schema=self.parameters(),
|
||||
tool_name=self.name,
|
||||
)
|
||||
if getattr(self, "_forward_runtime_kwargs", False) and original_kwargs:
|
||||
kwargs.update(original_kwargs)
|
||||
else:
|
||||
kwargs = original_kwargs
|
||||
|
||||
effective_context = context
|
||||
if effective_context is None and self._context_parameter_name is not None:
|
||||
effective_context = FunctionInvocationContext(
|
||||
function=self,
|
||||
arguments=validated_arguments,
|
||||
kwargs=runtime_kwargs,
|
||||
)
|
||||
if effective_context is not None:
|
||||
effective_context.function = self
|
||||
effective_context.arguments = validated_arguments
|
||||
effective_context.kwargs = dict(runtime_kwargs)
|
||||
|
||||
call_kwargs = dict(validated_arguments)
|
||||
observable_kwargs = dict(validated_arguments)
|
||||
|
||||
# Legacy runtime kwargs injection path retained for backwards compatibility with tools
|
||||
# that still declare ``**kwargs``. New tools should consume runtime data via ``ctx``.
|
||||
legacy_runtime_kwargs = dict(runtime_kwargs)
|
||||
if self._forward_runtime_kwargs and legacy_runtime_kwargs:
|
||||
for key, value in legacy_runtime_kwargs.items():
|
||||
if key not in call_kwargs:
|
||||
call_kwargs[key] = value
|
||||
if key not in observable_kwargs:
|
||||
observable_kwargs[key] = value
|
||||
|
||||
if self._context_parameter_name is not None and effective_context is not None:
|
||||
call_kwargs[self._context_parameter_name] = effective_context
|
||||
|
||||
if not OBSERVABILITY_SETTINGS.ENABLED: # type: ignore[name-defined]
|
||||
logger.info(f"Function name: {self.name}")
|
||||
logger.debug(f"Function arguments: {kwargs}")
|
||||
res = self.__call__(**kwargs)
|
||||
logger.debug(f"Function arguments: {observable_kwargs}")
|
||||
res = self.__call__(**call_kwargs)
|
||||
result = await res if inspect.isawaitable(res) else res
|
||||
try:
|
||||
parsed = parser(result)
|
||||
@@ -545,7 +640,7 @@ class FunctionTool(SerializationMixin):
|
||||
# Filter out framework kwargs that are not JSON serializable.
|
||||
serializable_kwargs = {
|
||||
k: v
|
||||
for k, v in kwargs.items()
|
||||
for k, v in observable_kwargs.items()
|
||||
if k
|
||||
not in {
|
||||
"chat_options",
|
||||
@@ -571,7 +666,7 @@ class FunctionTool(SerializationMixin):
|
||||
start_time_stamp = perf_counter()
|
||||
end_time_stamp: float | None = None
|
||||
try:
|
||||
res = self.__call__(**kwargs)
|
||||
res = self.__call__(**call_kwargs)
|
||||
result = await res if inspect.isawaitable(res) else res
|
||||
end_time_stamp = perf_counter()
|
||||
except Exception as exception:
|
||||
@@ -1218,9 +1313,10 @@ async def _auto_invoke_function(
|
||||
*,
|
||||
config: FunctionInvocationConfiguration,
|
||||
tool_map: dict[str, FunctionTool],
|
||||
invocation_session: AgentSession | None = None,
|
||||
sequence_index: int | None = None,
|
||||
request_index: int | None = None,
|
||||
middleware_pipeline: FunctionMiddlewarePipeline | None = None, # Optional MiddlewarePipeline
|
||||
middleware_pipeline: FunctionMiddlewarePipeline | None = None,
|
||||
) -> Content:
|
||||
"""Invoke a function call requested by the agent, applying middleware that is defined.
|
||||
|
||||
@@ -1231,6 +1327,7 @@ async def _auto_invoke_function(
|
||||
Keyword Args:
|
||||
config: The function invocation configuration.
|
||||
tool_map: A mapping of tool names to FunctionTool instances.
|
||||
invocation_session: The agent session for this invocation, if any.
|
||||
sequence_index: The index of the function call in the sequence.
|
||||
request_index: The index of the request iteration.
|
||||
middleware_pipeline: Optional middleware pipeline to apply during execution.
|
||||
@@ -1282,6 +1379,8 @@ async def _auto_invoke_function(
|
||||
for key, value in (custom_args or {}).items()
|
||||
if key not in {"_function_middleware_pipeline", "middleware", "conversation_id"}
|
||||
}
|
||||
if invocation_session is not None:
|
||||
runtime_kwargs["session"] = invocation_session
|
||||
try:
|
||||
if not cast(bool, getattr(tool, "_schema_supplied", False)) and tool.input_model is not None:
|
||||
args = tool.input_model.model_validate(parsed_args).model_dump(exclude_none=True)
|
||||
@@ -1303,19 +1402,31 @@ async def _auto_invoke_function(
|
||||
additional_properties=function_call_content.additional_properties,
|
||||
)
|
||||
|
||||
from ._middleware import FunctionInvocationContext
|
||||
|
||||
if middleware_pipeline is None or not middleware_pipeline.has_middlewares:
|
||||
# No middleware - execute directly
|
||||
try:
|
||||
direct_context = None
|
||||
if getattr(tool, "_forward_runtime_kwargs", False) or getattr(tool, "_context_parameter_name", None):
|
||||
direct_context = FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments=args,
|
||||
session=invocation_session,
|
||||
kwargs=runtime_kwargs.copy(),
|
||||
)
|
||||
function_result = await tool.invoke(
|
||||
arguments=args,
|
||||
context=direct_context,
|
||||
tool_call_id=function_call_content.call_id,
|
||||
**runtime_kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {},
|
||||
)
|
||||
return Content.from_function_result(
|
||||
call_id=function_call_content.call_id, # type: ignore[arg-type]
|
||||
result=function_result,
|
||||
additional_properties=function_call_content.additional_properties,
|
||||
)
|
||||
except UserInputRequiredException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
message = "Error: Function failed."
|
||||
if config.get("include_detailed_errors", False):
|
||||
@@ -1327,19 +1438,18 @@ async def _auto_invoke_function(
|
||||
additional_properties=function_call_content.additional_properties,
|
||||
)
|
||||
# Execute through middleware pipeline if available
|
||||
from ._middleware import FunctionInvocationContext
|
||||
|
||||
middleware_context = FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments=args,
|
||||
session=invocation_session,
|
||||
kwargs=runtime_kwargs.copy(),
|
||||
)
|
||||
|
||||
async def final_function_handler(context_obj: Any) -> Any:
|
||||
return await tool.invoke(
|
||||
arguments=context_obj.arguments,
|
||||
context=context_obj,
|
||||
tool_call_id=function_call_content.call_id,
|
||||
**context_obj.kwargs if getattr(tool, "_forward_runtime_kwargs", False) else {},
|
||||
)
|
||||
|
||||
from ._middleware import MiddlewareTermination
|
||||
@@ -1362,6 +1472,8 @@ async def _auto_invoke_function(
|
||||
additional_properties=function_call_content.additional_properties,
|
||||
)
|
||||
raise
|
||||
except UserInputRequiredException:
|
||||
raise
|
||||
except Exception as exc:
|
||||
message = "Error: Function failed."
|
||||
if config.get("include_detailed_errors", False):
|
||||
@@ -1390,7 +1502,8 @@ async def _try_execute_function_calls(
|
||||
function_calls: Sequence[Content],
|
||||
tools: ToolTypes | Callable[..., Any] | Sequence[ToolTypes | Callable[..., Any]],
|
||||
config: FunctionInvocationConfiguration,
|
||||
middleware_pipeline: Any = None, # Optional MiddlewarePipeline to avoid circular imports
|
||||
invocation_session: AgentSession | None = None,
|
||||
middleware_pipeline: Any = None,
|
||||
) -> tuple[Sequence[Content], bool]:
|
||||
"""Execute multiple function calls concurrently.
|
||||
|
||||
@@ -1400,6 +1513,7 @@ async def _try_execute_function_calls(
|
||||
function_calls: A sequence of FunctionCallContent to execute.
|
||||
tools: The tools available for execution.
|
||||
config: Configuration for function invocation.
|
||||
invocation_session: The agent session for this invocation, if any.
|
||||
middleware_pipeline: Optional middleware pipeline to apply during execution.
|
||||
|
||||
Returns:
|
||||
@@ -1469,6 +1583,8 @@ async def _try_execute_function_calls(
|
||||
# Run all function calls concurrently, handling MiddlewareTermination
|
||||
from ._middleware import MiddlewareTermination
|
||||
|
||||
extra_user_input_contents: list[Content] = []
|
||||
|
||||
async def invoke_with_termination_handling(
|
||||
function_call: Content,
|
||||
seq_idx: int,
|
||||
@@ -1479,6 +1595,7 @@ async def _try_execute_function_calls(
|
||||
function_call_content=function_call, # type: ignore[arg-type]
|
||||
custom_args=custom_args,
|
||||
tool_map=tool_map,
|
||||
invocation_session=invocation_session,
|
||||
sequence_index=seq_idx,
|
||||
request_index=attempt_idx,
|
||||
middleware_pipeline=middleware_pipeline,
|
||||
@@ -1495,6 +1612,26 @@ async def _try_execute_function_calls(
|
||||
result=exc.result,
|
||||
)
|
||||
return (result_content, True)
|
||||
except UserInputRequiredException as exc:
|
||||
if exc.contents:
|
||||
propagated: list[Content] = []
|
||||
for item in exc.contents:
|
||||
if isinstance(item, Content):
|
||||
item.call_id = function_call.call_id # type: ignore[attr-defined]
|
||||
if not item.id: # type: ignore[attr-defined]
|
||||
item.id = function_call.call_id # type: ignore[attr-defined]
|
||||
propagated.append(item)
|
||||
if propagated:
|
||||
extra_user_input_contents.extend(propagated[1:])
|
||||
return (propagated[0], False)
|
||||
return (
|
||||
Content.from_function_result(
|
||||
call_id=function_call.call_id, # type: ignore[arg-type]
|
||||
result="Tool requires user input but no request details were provided.",
|
||||
exception="UserInputRequiredException",
|
||||
),
|
||||
False,
|
||||
)
|
||||
|
||||
execution_results = await asyncio.gather(*[
|
||||
invoke_with_termination_handling(function_call, seq_idx) for seq_idx, function_call in enumerate(function_calls)
|
||||
@@ -1502,6 +1639,7 @@ async def _try_execute_function_calls(
|
||||
|
||||
# Unpack results - each is (Content, terminate_flag)
|
||||
contents: list[Content] = [result[0] for result in execution_results]
|
||||
contents.extend(extra_user_input_contents)
|
||||
# If any function requested termination, terminate the loop
|
||||
should_terminate = any(result[1] for result in execution_results)
|
||||
return (contents, should_terminate)
|
||||
@@ -1514,6 +1652,7 @@ async def _execute_function_calls(
|
||||
function_calls: list[Content],
|
||||
tool_options: dict[str, Any] | None,
|
||||
config: FunctionInvocationConfiguration,
|
||||
invocation_session: AgentSession | None = None,
|
||||
middleware_pipeline: Any = None,
|
||||
) -> tuple[list[Content], bool, bool]:
|
||||
tools = _extract_tools(tool_options)
|
||||
@@ -1524,6 +1663,7 @@ async def _execute_function_calls(
|
||||
attempt_idx=attempt_idx,
|
||||
function_calls=function_calls,
|
||||
tools=tools, # type: ignore
|
||||
invocation_session=invocation_session,
|
||||
middleware_pipeline=middleware_pipeline,
|
||||
config=config,
|
||||
)
|
||||
@@ -1733,7 +1873,10 @@ def _handle_function_call_results(
|
||||
) -> FunctionRequestResult:
|
||||
from ._types import Message
|
||||
|
||||
if any(fccr.type in {"function_approval_request", "function_call"} for fccr in function_call_results):
|
||||
if any(
|
||||
fccr.type in {"function_approval_request", "function_call"} or fccr.user_input_request
|
||||
for fccr in function_call_results
|
||||
):
|
||||
# Only add items that aren't already in the message (e.g. function_approval_request wrappers).
|
||||
# Declaration-only function_call items are already present from the LLM response.
|
||||
new_items = [fccr for fccr in function_call_results if fccr.type != "function_call"]
|
||||
@@ -1901,6 +2044,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@@ -1913,6 +2058,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@@ -1925,6 +2072,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@@ -1937,6 +2086,8 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
from ._middleware import FunctionMiddlewarePipeline
|
||||
@@ -1947,28 +2098,45 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
)
|
||||
|
||||
super_get_response = super().get_response # type: ignore[misc]
|
||||
if kwargs:
|
||||
warnings.warn(
|
||||
"Passing client-specific keyword arguments directly to get_response() is deprecated; "
|
||||
"pass them via client_kwargs instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
effective_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
effective_function_middleware = function_middleware
|
||||
if effective_function_middleware is None:
|
||||
middleware_from_client_kwargs = effective_client_kwargs.pop("function_middleware", None)
|
||||
if middleware_from_client_kwargs is not None:
|
||||
effective_function_middleware = cast(Sequence[Any], middleware_from_client_kwargs)
|
||||
|
||||
# ChatMiddleware adds this kwarg
|
||||
function_middleware_pipeline = FunctionMiddlewarePipeline(
|
||||
*(self.function_middleware), *(function_middleware or [])
|
||||
*(self.function_middleware), *(effective_function_middleware or [])
|
||||
)
|
||||
max_errors = self.function_invocation_configuration.get(
|
||||
"max_consecutive_errors_per_request", DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST
|
||||
)
|
||||
additional_function_arguments: dict[str, Any] = {}
|
||||
additional_function_arguments = (
|
||||
dict(function_invocation_kwargs) if function_invocation_kwargs is not None else {}
|
||||
)
|
||||
if options and (additional_opts := options.get("additional_function_arguments")): # type: ignore[attr-defined]
|
||||
additional_function_arguments = additional_opts # type: ignore
|
||||
additional_function_arguments.update(cast(Mapping[str, Any], additional_opts))
|
||||
from ._sessions import AgentSession as _AgentSession
|
||||
|
||||
raw_session = effective_client_kwargs.get("session")
|
||||
invocation_session = raw_session if isinstance(raw_session, _AgentSession) else None
|
||||
execute_function_calls = partial(
|
||||
_execute_function_calls,
|
||||
custom_args=additional_function_arguments,
|
||||
config=self.function_invocation_configuration,
|
||||
invocation_session=invocation_session,
|
||||
middleware_pipeline=function_middleware_pipeline,
|
||||
)
|
||||
filtered_kwargs = {k: v for k, v in kwargs.items() if k != "session"}
|
||||
if compaction_strategy is not None:
|
||||
filtered_kwargs["compaction_strategy"] = compaction_strategy
|
||||
if tokenizer is not None:
|
||||
filtered_kwargs["tokenizer"] = tokenizer
|
||||
filtered_kwargs = {k: v for k, v in {**effective_client_kwargs, **kwargs}.items() if k != "session"}
|
||||
|
||||
# Make options mutable so we can update conversation_id during function invocation loop
|
||||
mutable_options: dict[str, Any] = dict(options) if options else {}
|
||||
@@ -2018,7 +2186,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
messages=prepped_messages,
|
||||
stream=False,
|
||||
options=mutable_options,
|
||||
**filtered_kwargs,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
client_kwargs=filtered_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -2087,7 +2257,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
messages=prepped_messages,
|
||||
stream=False,
|
||||
options=mutable_options,
|
||||
**filtered_kwargs,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
client_kwargs=filtered_kwargs,
|
||||
),
|
||||
)
|
||||
if fcc_messages:
|
||||
@@ -2137,7 +2309,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
messages=prepped_messages,
|
||||
stream=True,
|
||||
options=mutable_options,
|
||||
**filtered_kwargs,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
client_kwargs=filtered_kwargs,
|
||||
),
|
||||
)
|
||||
await inner_stream
|
||||
@@ -2229,7 +2403,9 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
messages=prepped_messages,
|
||||
stream=True,
|
||||
options=mutable_options,
|
||||
**filtered_kwargs,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
client_kwargs=filtered_kwargs,
|
||||
),
|
||||
)
|
||||
await final_inner_stream
|
||||
|
||||
@@ -2698,7 +2698,7 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]):
|
||||
stream: AsyncIterable[UpdateT] | Awaitable[AsyncIterable[UpdateT]],
|
||||
*,
|
||||
finalizer: Callable[[Sequence[UpdateT]], FinalT | Awaitable[FinalT]] | None = None,
|
||||
transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None]] | None = None,
|
||||
transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None]] | None = None,
|
||||
cleanup_hooks: list[Callable[[], Awaitable[None] | None]] | None = None,
|
||||
result_hooks: list[Callable[[FinalT], FinalT | Awaitable[FinalT | None] | None]] | None = None,
|
||||
) -> None:
|
||||
@@ -2722,7 +2722,7 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]):
|
||||
self._consumed: bool = False
|
||||
self._finalized: bool = False
|
||||
self._final_result: FinalT | None = None
|
||||
self._transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None]] = (
|
||||
self._transform_hooks: list[Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None]] = (
|
||||
transform_hooks if transform_hooks is not None else []
|
||||
)
|
||||
self._result_hooks: list[Callable[[FinalT], FinalT | Awaitable[FinalT | None] | None]] = (
|
||||
@@ -2995,7 +2995,7 @@ class ResponseStream(AsyncIterable[UpdateT], Generic[UpdateT, FinalT]):
|
||||
|
||||
def with_transform_hook(
|
||||
self,
|
||||
hook: Callable[[UpdateT], UpdateT | Awaitable[UpdateT] | None],
|
||||
hook: Callable[[UpdateT], UpdateT | Awaitable[UpdateT | None] | None],
|
||||
) -> ResponseStream[UpdateT, FinalT]:
|
||||
"""Register a transform hook executed for each update during iteration."""
|
||||
self._transform_hooks.append(hook)
|
||||
|
||||
@@ -172,12 +172,12 @@ class AzureOpenAIChatClient( # type: ignore[misc]
|
||||
credential: AzureCredentialTypes | AzureTokenProvider | None = None,
|
||||
default_headers: Mapping[str, str] | None = None,
|
||||
async_client: AsyncAzureOpenAI | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
instruction_role: str | None = None,
|
||||
middleware: Sequence[MiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize an Azure OpenAI Chat completion client.
|
||||
|
||||
@@ -205,13 +205,13 @@ class AzureOpenAIChatClient( # type: ignore[misc]
|
||||
default_headers: The default headers mapping of string keys to
|
||||
string values for HTTP requests.
|
||||
async_client: An existing client to use.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
env_file_path: Use the environment settings file as a fallback to using env vars.
|
||||
env_file_encoding: The encoding of the environment settings file, defaults to 'utf-8'.
|
||||
instruction_role: The role to use for 'instruction' messages, for example, summarization
|
||||
prompts could use `developer` or `system`.
|
||||
middleware: Optional sequence of middleware to apply to requests.
|
||||
function_invocation_configuration: Optional configuration for function invocation behavior.
|
||||
kwargs: Other keyword parameters.
|
||||
|
||||
Examples:
|
||||
.. code-block:: python
|
||||
@@ -283,10 +283,10 @@ class AzureOpenAIChatClient( # type: ignore[misc]
|
||||
credential=credential,
|
||||
default_headers=default_headers,
|
||||
client=async_client,
|
||||
additional_properties=additional_properties,
|
||||
instruction_role=instruction_role,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@override
|
||||
|
||||
@@ -180,6 +180,34 @@ class ToolExecutionException(ToolException):
|
||||
pass
|
||||
|
||||
|
||||
class UserInputRequiredException(ToolException):
|
||||
"""Raised when a tool wrapping a sub-agent requires user input to proceed.
|
||||
|
||||
This exception carries the ``user_input_request`` Content items emitted by
|
||||
the sub-agent (e.g., ``oauth_consent_request``, ``function_approval_request``)
|
||||
so the tool invocation layer can propagate them to the parent agent's response
|
||||
instead of swallowing them as a generic tool error.
|
||||
|
||||
Args:
|
||||
contents: The user-input-request Content items from the sub-agent response.
|
||||
message: Human-readable description of why user input is needed.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
contents: list[Any],
|
||||
message: str = "Tool requires user input to proceed.",
|
||||
) -> None:
|
||||
"""Create a UserInputRequiredException.
|
||||
|
||||
Args:
|
||||
contents: The user-input-request Content items from the sub-agent response.
|
||||
message: Human-readable description of why user input is needed.
|
||||
"""
|
||||
super().__init__(message, log_level=None)
|
||||
self.contents = contents
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
# region Middleware Exceptions
|
||||
|
||||
@@ -1162,11 +1162,35 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Trace chat responses with OpenTelemetry spans and metrics."""
|
||||
"""Trace chat responses with OpenTelemetry spans and metrics.
|
||||
|
||||
Args:
|
||||
messages: The message or messages to send to the model.
|
||||
stream: Whether to stream the response. Defaults to False.
|
||||
options: Chat options as a TypedDict.
|
||||
compaction_strategy: Optional compaction strategy to apply before model calls.
|
||||
tokenizer: Optional tokenizer used by token-aware compaction strategies.
|
||||
|
||||
Keyword Args:
|
||||
kwargs: Compatibility keyword arguments from higher client layers. This layer does
|
||||
not consume ``function_invocation_kwargs`` directly; if present, it is ignored
|
||||
because function invocation has already been processed above. If a ``client_kwargs``
|
||||
mapping is present, it is flattened into ordinary keyword arguments for tracing and
|
||||
forwarding so clients that use those values continue to work while clients that
|
||||
ignore extra kwargs remain compatible.
|
||||
"""
|
||||
from ._types import ChatResponse, ChatResponseUpdate, ResponseStream # type: ignore[reportUnusedImport]
|
||||
|
||||
global OBSERVABILITY_SETTINGS
|
||||
super_get_response = super().get_response # type: ignore[misc]
|
||||
compatibility_client_kwargs = kwargs.pop("client_kwargs", None)
|
||||
kwargs.pop("function_invocation_kwargs", None)
|
||||
merged_client_kwargs = (
|
||||
dict(cast(Mapping[str, Any], compatibility_client_kwargs))
|
||||
if isinstance(compatibility_client_kwargs, Mapping)
|
||||
else {}
|
||||
)
|
||||
merged_client_kwargs.update(kwargs)
|
||||
|
||||
if not OBSERVABILITY_SETTINGS.ENABLED:
|
||||
return super_get_response( # type: ignore[no-any-return]
|
||||
@@ -1175,12 +1199,14 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
options=options,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
**merged_client_kwargs,
|
||||
)
|
||||
|
||||
opts: dict[str, Any] = options or {} # type: ignore[assignment]
|
||||
provider_name = str(getattr(self, "otel_provider_name", "unknown"))
|
||||
model_id = kwargs.get("model_id") or opts.get("model_id") or getattr(self, "model_id", None) or "unknown"
|
||||
model_id = (
|
||||
merged_client_kwargs.get("model_id") or opts.get("model_id") or getattr(self, "model_id", None) or "unknown"
|
||||
)
|
||||
service_url_func = getattr(self, "service_url", None)
|
||||
service_url = str(service_url_func() if callable(service_url_func) else "unknown")
|
||||
attributes = _get_span_attributes(
|
||||
@@ -1188,7 +1214,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
provider_name=provider_name,
|
||||
model=model_id,
|
||||
service_url=service_url,
|
||||
**kwargs,
|
||||
**merged_client_kwargs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
@@ -1200,7 +1226,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
options=opts,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
**merged_client_kwargs,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -1291,7 +1317,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
options=opts,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
**kwargs,
|
||||
**merged_client_kwargs,
|
||||
),
|
||||
)
|
||||
except Exception as exception:
|
||||
@@ -1420,6 +1446,8 @@ class AgentTelemetryLayer:
|
||||
session: AgentSession | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]]: ...
|
||||
|
||||
@@ -1432,6 +1460,8 @@ class AgentTelemetryLayer:
|
||||
session: AgentSession | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse[Any]]: ...
|
||||
|
||||
@@ -1443,6 +1473,8 @@ class AgentTelemetryLayer:
|
||||
session: AgentSession | None = None,
|
||||
compaction_strategy: CompactionStrategy | None = None,
|
||||
tokenizer: TokenizerProtocol | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse[Any]] | ResponseStream[AgentResponseUpdate, AgentResponse[Any]]:
|
||||
"""Trace agent runs with OpenTelemetry spans and metrics."""
|
||||
@@ -1463,11 +1495,15 @@ class AgentTelemetryLayer:
|
||||
session=session,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
default_options = getattr(self, "default_options", {})
|
||||
options = kwargs.get("options")
|
||||
merged_client_kwargs = dict(client_kwargs) if client_kwargs is not None else {}
|
||||
merged_client_kwargs.update(kwargs)
|
||||
merged_options: dict[str, Any] = merge_chat_options(default_options, options or {})
|
||||
attributes = _get_span_attributes(
|
||||
operation_name=OtelAttr.AGENT_INVOKE_OPERATION,
|
||||
@@ -1477,7 +1513,7 @@ class AgentTelemetryLayer:
|
||||
agent_description=getattr(self, "description", None),
|
||||
thread_id=session.service_session_id if session else None,
|
||||
all_options=merged_options,
|
||||
**kwargs,
|
||||
**merged_client_kwargs,
|
||||
)
|
||||
|
||||
if stream:
|
||||
@@ -1487,6 +1523,8 @@ class AgentTelemetryLayer:
|
||||
session=session,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
if isinstance(run_result, ResponseStream):
|
||||
@@ -1578,6 +1616,8 @@ class AgentTelemetryLayer:
|
||||
session=session,
|
||||
compaction_strategy=compaction_strategy,
|
||||
tokenizer=tokenizer,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
**kwargs,
|
||||
)
|
||||
except Exception as exception:
|
||||
|
||||
@@ -15,7 +15,7 @@ from collections.abc import (
|
||||
)
|
||||
from datetime import datetime, timezone
|
||||
from itertools import chain
|
||||
from typing import Any, Generic, Literal, cast
|
||||
from typing import Any, Generic, Literal, cast, overload
|
||||
|
||||
from openai import AsyncOpenAI, BadRequestError
|
||||
from openai.lib._parsing._completions import type_to_response_format_param
|
||||
@@ -30,7 +30,8 @@ from openai.types.chat.completion_create_params import WebSearchOptions
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
|
||||
from .._docstrings import apply_layered_docstring
|
||||
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer, FunctionMiddlewareTypes
|
||||
from .._settings import load_settings
|
||||
from .._tools import (
|
||||
FunctionInvocationConfiguration,
|
||||
@@ -72,6 +73,7 @@ else:
|
||||
|
||||
logger = logging.getLogger("agent_framework.openai")
|
||||
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
@@ -213,6 +215,57 @@ class RawOpenAIChatClient( # type: ignore[misc]
|
||||
|
||||
# endregion
|
||||
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OpenAIChatOptionsT | ChatOptions[None] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@override
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Get a response from the raw OpenAI chat client."""
|
||||
super_get_response = cast(
|
||||
"Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]",
|
||||
super().get_response, # type: ignore[misc]
|
||||
)
|
||||
return super_get_response( # type: ignore[no-any-return]
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@override
|
||||
def _inner_get_response(
|
||||
self,
|
||||
@@ -727,6 +780,77 @@ class OpenAIChatClient( # type: ignore[misc]
|
||||
):
|
||||
"""OpenAI Chat completion class with middleware, telemetry, and function invocation support."""
|
||||
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[ResponseModelBoundT]]: ...
|
||||
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OpenAIChatOptionsT | ChatOptions[None] | None = None,
|
||||
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]]: ...
|
||||
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
|
||||
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[ChatResponseUpdate, ChatResponse[Any]]: ...
|
||||
|
||||
@override
|
||||
def get_response(
|
||||
self,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OpenAIChatOptionsT | ChatOptions[Any] | None = None,
|
||||
function_middleware: Sequence[FunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_kwargs: Mapping[str, Any] | None = None,
|
||||
client_kwargs: Mapping[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]:
|
||||
"""Get a response from the OpenAI chat client with all standard layers enabled."""
|
||||
super_get_response = cast(
|
||||
"Callable[..., Awaitable[ChatResponse[Any]] | ResponseStream[ChatResponseUpdate, ChatResponse[Any]]]",
|
||||
super().get_response, # type: ignore[misc]
|
||||
)
|
||||
return super_get_response( # type: ignore[no-any-return]
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options,
|
||||
function_middleware=function_middleware,
|
||||
function_invocation_kwargs=function_invocation_kwargs,
|
||||
client_kwargs=client_kwargs,
|
||||
middleware=middleware,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
*,
|
||||
@@ -830,3 +954,25 @@ class OpenAIChatClient( # type: ignore[misc]
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
)
|
||||
|
||||
|
||||
def _apply_openai_chat_client_docstrings() -> None:
|
||||
"""Align OpenAI chat-client docstrings with the raw implementation."""
|
||||
apply_layered_docstring(RawOpenAIChatClient.get_response, BaseChatClient.get_response)
|
||||
apply_layered_docstring(
|
||||
OpenAIChatClient.get_response,
|
||||
RawOpenAIChatClient.get_response,
|
||||
extra_keyword_args={
|
||||
"function_middleware": """
|
||||
Optional per-call function middleware.
|
||||
When omitted, middleware configured on the client or forwarded from higher layers is used.
|
||||
""",
|
||||
"middleware": """
|
||||
Optional per-call chat and function middleware.
|
||||
This is merged with any middleware configured on the client for the current request.
|
||||
""",
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
_apply_openai_chat_client_docstrings()
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import contextlib
|
||||
import inspect
|
||||
from collections.abc import AsyncIterable, MutableSequence
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock
|
||||
@@ -31,6 +32,7 @@ from agent_framework import (
|
||||
)
|
||||
from agent_framework._agents import _get_tool_name, _merge_options, _sanitize_agent_name
|
||||
from agent_framework._mcp import MCPTool, _build_prefixed_mcp_name, _normalize_mcp_name
|
||||
from agent_framework._middleware import FunctionInvocationContext
|
||||
|
||||
|
||||
class _FixedTokenizer:
|
||||
@@ -101,6 +103,30 @@ def test_chat_client_agent_type(client: SupportsChatGetResponse) -> None:
|
||||
assert isinstance(chat_client_agent, SupportsAgentRun)
|
||||
|
||||
|
||||
def test_agent_init_docstring_surfaces_raw_agent_constructor_docs() -> None:
|
||||
docstring = inspect.getdoc(Agent.__init__)
|
||||
|
||||
assert docstring is not None
|
||||
assert "client: The chat client to use for the agent." in docstring
|
||||
assert "middleware: List of middleware to intercept agent and function invocations." in docstring
|
||||
|
||||
|
||||
def test_agent_run_docstring_surfaces_raw_agent_runtime_docs() -> None:
|
||||
docstring = inspect.getdoc(Agent.run)
|
||||
|
||||
assert docstring is not None
|
||||
assert "Run the agent with the given messages and options." in docstring
|
||||
assert "function_invocation_kwargs: Keyword arguments forwarded to tool invocation." in docstring
|
||||
assert "middleware: Optional per-run agent, chat, and function middleware." in docstring
|
||||
|
||||
|
||||
def test_agent_run_is_defined_on_agent_class() -> None:
|
||||
signature = inspect.signature(Agent.run)
|
||||
|
||||
assert Agent.run.__qualname__ == "Agent.run"
|
||||
assert "middleware" in signature.parameters
|
||||
|
||||
|
||||
async def test_chat_client_agent_init(client: SupportsChatGetResponse) -> None:
|
||||
agent_id = str(uuid4())
|
||||
agent = Agent(client=client, id=agent_id, description="Test")
|
||||
@@ -121,6 +147,13 @@ async def test_chat_client_agent_init_with_name(
|
||||
assert agent.description == "Test"
|
||||
|
||||
|
||||
def test_agent_init_warns_for_direct_additional_properties(client: SupportsChatGetResponse) -> None:
|
||||
with pytest.warns(DeprecationWarning, match="additional_properties"):
|
||||
agent = Agent(client=client, legacy_key="legacy-value")
|
||||
|
||||
assert agent.additional_properties["legacy_key"] == "legacy-value"
|
||||
|
||||
|
||||
async def test_chat_client_agent_run(client: SupportsChatGetResponse) -> None:
|
||||
agent = Agent(client=client)
|
||||
|
||||
@@ -253,33 +286,38 @@ async def test_prepare_session_does_not_mutate_agent_chat_options(
|
||||
assert len(agent.default_options["tools"]) == 1
|
||||
|
||||
|
||||
async def test_prepare_run_context_keeps_compaction_overrides_out_of_kwargs(
|
||||
async def test_prepare_run_context_handles_function_kwargs(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
strategy = SlidingWindowStrategy(keep_last_groups=2)
|
||||
tokenizer = _FixedTokenizer(13)
|
||||
agent = Agent(client=chat_client_base)
|
||||
session = agent.create_session()
|
||||
|
||||
ctx = await agent._prepare_run_context( # type: ignore[reportPrivateUsage]
|
||||
messages=[Message(role="user", text="Hello")],
|
||||
session=None,
|
||||
messages="Hello",
|
||||
session=session,
|
||||
tools=None,
|
||||
options=None,
|
||||
compaction_strategy=strategy,
|
||||
tokenizer=tokenizer,
|
||||
kwargs={"custom_flag": True},
|
||||
options={
|
||||
"temperature": 0.4,
|
||||
"additional_function_arguments": {"from_options": "options-value"},
|
||||
},
|
||||
compaction_strategy=None,
|
||||
tokenizer=None,
|
||||
legacy_kwargs={"legacy_key": "legacy-value"},
|
||||
function_invocation_kwargs={"runtime_key": "runtime-value"},
|
||||
client_kwargs={"client_key": "client-value"},
|
||||
)
|
||||
|
||||
assert ctx["compaction_strategy"] is strategy
|
||||
assert ctx["tokenizer"] is tokenizer
|
||||
assert ctx["filtered_kwargs"].get("custom_flag") is True
|
||||
assert "compaction_strategy" not in ctx["filtered_kwargs"]
|
||||
assert "tokenizer" not in ctx["filtered_kwargs"]
|
||||
assert ctx["chat_options"]["temperature"] == 0.4
|
||||
assert "additional_function_arguments" not in ctx["chat_options"]
|
||||
assert ctx["function_invocation_kwargs"]["from_options"] == "options-value"
|
||||
assert ctx["function_invocation_kwargs"]["legacy_key"] == "legacy-value"
|
||||
assert ctx["function_invocation_kwargs"]["runtime_key"] == "runtime-value"
|
||||
assert "session" not in ctx["function_invocation_kwargs"]
|
||||
assert ctx["client_kwargs"]["client_key"] == "client-value"
|
||||
assert ctx["client_kwargs"]["session"] is session
|
||||
|
||||
|
||||
async def test_chat_client_agent_run_with_session(
|
||||
chat_client_base: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
async def test_chat_client_agent_run_with_session(chat_client_base: SupportsChatGetResponse) -> None:
|
||||
mock_response = ChatResponse(
|
||||
messages=[Message(role="assistant", contents=[Content.from_text("test response")])],
|
||||
conversation_id="123",
|
||||
@@ -720,8 +758,9 @@ async def test_chat_agent_as_tool_basic(client: SupportsChatGetResponse) -> None
|
||||
|
||||
assert tool.name == "TestAgent"
|
||||
assert tool.description == "Test agent for as_tool"
|
||||
assert tool.approval_mode == "never_require"
|
||||
assert hasattr(tool, "func")
|
||||
assert hasattr(tool, "input_model")
|
||||
assert tool.input_model is None
|
||||
|
||||
|
||||
async def test_chat_agent_as_tool_custom_parameters(
|
||||
@@ -735,13 +774,15 @@ async def test_chat_agent_as_tool_custom_parameters(
|
||||
description="Custom description",
|
||||
arg_name="query",
|
||||
arg_description="Custom input description",
|
||||
approval_mode="always_require",
|
||||
)
|
||||
|
||||
assert tool.name == "CustomTool"
|
||||
assert tool.description == "Custom description"
|
||||
assert tool.approval_mode == "always_require"
|
||||
|
||||
# Check that the input model has the custom field name
|
||||
schema = tool.input_model.model_json_schema()
|
||||
schema = tool.parameters()
|
||||
assert "query" in schema["properties"]
|
||||
assert schema["properties"]["query"]["description"] == "Custom input description"
|
||||
|
||||
@@ -760,7 +801,7 @@ async def test_chat_agent_as_tool_defaults(client: SupportsChatGetResponse) -> N
|
||||
assert tool.description == "" # Should default to empty string
|
||||
|
||||
# Check default input field
|
||||
schema = tool.input_model.model_json_schema()
|
||||
schema = tool.parameters()
|
||||
assert "task" in schema["properties"]
|
||||
assert "Task for TestAgent" in schema["properties"]["task"]["description"]
|
||||
|
||||
@@ -783,12 +824,12 @@ async def test_chat_agent_as_tool_function_execution(
|
||||
tool = agent.as_tool()
|
||||
|
||||
# Test function execution
|
||||
result = await tool.invoke(arguments=tool.input_model(task="Hello"))
|
||||
result = await tool.invoke(arguments={"task": "Hello"})
|
||||
|
||||
# Should return the agent's response text as a list of Content items
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1
|
||||
assert result[0].text == "test response" # From mock chat client
|
||||
assert result[0].text == "test streaming response another update" # From mock streaming client
|
||||
|
||||
|
||||
async def test_chat_agent_as_tool_with_stream_callback(
|
||||
@@ -806,7 +847,7 @@ async def test_chat_agent_as_tool_with_stream_callback(
|
||||
tool = agent.as_tool(stream_callback=stream_callback)
|
||||
|
||||
# Execute the tool
|
||||
result = await tool.invoke(arguments=tool.input_model(task="Hello"))
|
||||
result = await tool.invoke(arguments={"task": "Hello"})
|
||||
|
||||
# Should have collected streaming updates
|
||||
assert len(collected_updates) > 0
|
||||
@@ -826,9 +867,9 @@ async def test_chat_agent_as_tool_with_custom_arg_name(
|
||||
tool = agent.as_tool(arg_name="prompt", arg_description="Custom prompt input")
|
||||
|
||||
# Test that the custom argument name works
|
||||
result = await tool.invoke(arguments=tool.input_model(prompt="Test prompt"))
|
||||
result = await tool.invoke(arguments={"prompt": "Test prompt"})
|
||||
assert isinstance(result, list)
|
||||
assert result[0].text == "test response"
|
||||
assert result[0].text == "test streaming response another update"
|
||||
|
||||
|
||||
async def test_chat_agent_as_tool_with_async_stream_callback(
|
||||
@@ -846,7 +887,7 @@ async def test_chat_agent_as_tool_with_async_stream_callback(
|
||||
tool = agent.as_tool(stream_callback=async_stream_callback)
|
||||
|
||||
# Execute the tool
|
||||
result = await tool.invoke(arguments=tool.input_model(task="Hello"))
|
||||
result = await tool.invoke(arguments={"task": "Hello"})
|
||||
|
||||
# Should have collected streaming updates
|
||||
assert len(collected_updates) > 0
|
||||
@@ -877,17 +918,14 @@ async def test_chat_agent_as_tool_name_sanitization(
|
||||
assert tool.name == expected_tool_name, f"Expected {expected_tool_name}, got {tool.name} for input {agent_name}"
|
||||
|
||||
|
||||
async def test_chat_agent_as_tool_propagate_session_true(
|
||||
client: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
"""Test that propagate_session=True forwards the parent's session to the sub-agent."""
|
||||
async def test_chat_agent_as_tool_propagate_session_true(client: SupportsChatGetResponse) -> None:
|
||||
"""Test that propagate_session=True forwards the session to the sub-agent."""
|
||||
agent = Agent(client=client, name="SubAgent", description="Sub agent")
|
||||
tool = agent.as_tool(propagate_session=True)
|
||||
|
||||
parent_session = AgentSession(session_id="parent-session-123")
|
||||
parent_session.state["shared_key"] = "shared_value"
|
||||
|
||||
# Spy on the agent's run method to capture the session argument
|
||||
original_run = agent.run
|
||||
captured_session = None
|
||||
|
||||
@@ -898,16 +936,20 @@ async def test_chat_agent_as_tool_propagate_session_true(
|
||||
|
||||
agent.run = capturing_run # type: ignore[assignment, method-assign]
|
||||
|
||||
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
|
||||
await tool.invoke(
|
||||
context=FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments={"task": "Hello"},
|
||||
session=parent_session,
|
||||
)
|
||||
)
|
||||
|
||||
assert captured_session is parent_session
|
||||
assert captured_session.session_id == "parent-session-123"
|
||||
assert captured_session.state["shared_key"] == "shared_value"
|
||||
|
||||
|
||||
async def test_chat_agent_as_tool_propagate_session_false_by_default(
|
||||
client: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
async def test_chat_agent_as_tool_propagate_session_false_by_default(client: SupportsChatGetResponse) -> None:
|
||||
"""Test that propagate_session defaults to False and does not forward the session."""
|
||||
agent = Agent(client=client, name="SubAgent", description="Sub agent")
|
||||
tool = agent.as_tool() # default: propagate_session=False
|
||||
@@ -924,22 +966,25 @@ async def test_chat_agent_as_tool_propagate_session_false_by_default(
|
||||
|
||||
agent.run = capturing_run # type: ignore[assignment, method-assign]
|
||||
|
||||
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
|
||||
await tool.invoke(
|
||||
context=FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments={"task": "Hello"},
|
||||
session=parent_session,
|
||||
)
|
||||
)
|
||||
|
||||
assert captured_session is None
|
||||
|
||||
|
||||
async def test_chat_agent_as_tool_propagate_session_shares_state(
|
||||
client: SupportsChatGetResponse,
|
||||
) -> None:
|
||||
"""Test that shared session allows the sub-agent to read and write parent's state."""
|
||||
async def test_chat_agent_as_tool_propagate_session_shares_state(client: SupportsChatGetResponse) -> None:
|
||||
"""Test that a propagated session allows the sub-agent to read and write parent state."""
|
||||
agent = Agent(client=client, name="SubAgent", description="Sub agent")
|
||||
tool = agent.as_tool(propagate_session=True)
|
||||
|
||||
parent_session = AgentSession(session_id="shared-session")
|
||||
parent_session.state["counter"] = 0
|
||||
|
||||
# The sub-agent receives the same session object, so mutations are shared
|
||||
original_run = agent.run
|
||||
captured_session = None
|
||||
|
||||
@@ -952,9 +997,14 @@ async def test_chat_agent_as_tool_propagate_session_shares_state(
|
||||
|
||||
agent.run = capturing_run # type: ignore[assignment, method-assign]
|
||||
|
||||
await tool.invoke(arguments=tool.input_model(task="Hello"), session=parent_session)
|
||||
await tool.invoke(
|
||||
context=FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments={"task": "Hello"},
|
||||
session=parent_session,
|
||||
)
|
||||
)
|
||||
|
||||
# The parent's state should reflect the sub-agent's mutation
|
||||
assert parent_session.state["counter"] == 1
|
||||
|
||||
|
||||
@@ -1131,7 +1181,7 @@ async def test_agent_run_accepts_prefixed_mcp_tools(chat_client_base: Any) -> No
|
||||
|
||||
|
||||
async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> None:
|
||||
"""Verify tool execution receives 'session' inside **kwargs when function is called by client."""
|
||||
"""Verify legacy **kwargs tools receive the session when agent.run() is called with one."""
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
@@ -1142,7 +1192,6 @@ async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> N
|
||||
captured["has_state"] = session.state is not None if isinstance(session, AgentSession) else False
|
||||
return f"echo: {text}"
|
||||
|
||||
# Make the base client emit a function call for our tool
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
@@ -1162,17 +1211,52 @@ async def test_agent_tool_receives_session_in_kwargs(chat_client_base: Any) -> N
|
||||
agent = Agent(client=chat_client_base, tools=[echo_session_info])
|
||||
session = agent.create_session()
|
||||
|
||||
result = await agent.run(
|
||||
"hello",
|
||||
session=session,
|
||||
options={"additional_function_arguments": {"session": session}},
|
||||
)
|
||||
result = await agent.run("hello", session=session)
|
||||
|
||||
assert result.text == "done"
|
||||
assert captured.get("has_session") is True
|
||||
assert captured.get("has_state") is True
|
||||
|
||||
|
||||
async def test_agent_tool_receives_explicit_session_via_function_invocation_context_kwargs(
|
||||
chat_client_base: Any,
|
||||
) -> None:
|
||||
"""Verify ctx-based tools receive the session via FunctionInvocationContext.session."""
|
||||
|
||||
captured: dict[str, Any] = {}
|
||||
|
||||
@tool(name="capture_session_context", approval_mode="never_require")
|
||||
def capture_session_context(text: str, ctx: FunctionInvocationContext) -> str:
|
||||
captured["session"] = ctx.session
|
||||
captured["has_state"] = ctx.session.state is not None if isinstance(ctx.session, AgentSession) else False
|
||||
return f"echo: {text}"
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="1",
|
||||
name="capture_session_context",
|
||||
arguments='{"text": "hello"}',
|
||||
)
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
agent = Agent(client=chat_client_base, tools=[capture_session_context])
|
||||
session = agent.create_session()
|
||||
|
||||
result = await agent.run("hello", session=session)
|
||||
|
||||
assert result.text == "done"
|
||||
assert captured["session"] is session
|
||||
assert captured["has_state"] is True
|
||||
|
||||
|
||||
async def test_chat_agent_tool_choice_run_level_overrides_agent_level(chat_client_base: Any, tool_tool: Any) -> None:
|
||||
"""Verify that tool_choice passed to run() overrides agent-level tool_choice."""
|
||||
|
||||
@@ -1859,4 +1943,26 @@ async def test_stores_by_default_with_store_false_in_default_options_injects_inm
|
||||
assert any(isinstance(p, InMemoryHistoryProvider) for p in agent.context_providers)
|
||||
|
||||
|
||||
# endregion
|
||||
# region as_tool user_input_request propagation
|
||||
|
||||
|
||||
async def test_as_tool_raises_on_user_input_request(client: SupportsChatGetResponse) -> None:
|
||||
"""Test that as_tool raises when the wrapped sub-agent requests user input."""
|
||||
from agent_framework.exceptions import UserInputRequiredException
|
||||
|
||||
consent_content = Content.from_oauth_consent_request(
|
||||
consent_link="https://login.microsoftonline.com/consent",
|
||||
)
|
||||
client.streaming_responses = [ # type: ignore[attr-defined]
|
||||
[ChatResponseUpdate(contents=[consent_content], role="assistant")],
|
||||
]
|
||||
|
||||
agent = Agent(client=client, name="OAuthAgent", description="Agent requiring consent")
|
||||
agent_tool = agent.as_tool()
|
||||
|
||||
with raises(UserInputRequiredException) as exc_info:
|
||||
await agent_tool.invoke(arguments={"task": "Do something"})
|
||||
|
||||
assert len(exc_info.value.contents) == 1
|
||||
assert exc_info.value.contents[0].type == "oauth_consent_request"
|
||||
assert exc_info.value.contents[0].consent_link == "https://login.microsoftonline.com/consent"
|
||||
|
||||
@@ -6,7 +6,7 @@ from collections.abc import Awaitable, Callable
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Agent, ChatResponse, Content, Message, agent_middleware
|
||||
from agent_framework._middleware import AgentContext
|
||||
from agent_framework._middleware import AgentContext, FunctionInvocationContext
|
||||
|
||||
from .conftest import MockChatClient
|
||||
|
||||
@@ -14,14 +14,28 @@ from .conftest import MockChatClient
|
||||
class TestAsToolKwargsPropagation:
|
||||
"""Test cases for kwargs propagation through as_tool() delegation."""
|
||||
|
||||
@staticmethod
|
||||
def _build_context(
|
||||
tool: Any,
|
||||
*,
|
||||
task: str,
|
||||
runtime_kwargs: dict[str, Any] | None = None,
|
||||
) -> FunctionInvocationContext:
|
||||
return FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments={"task": task},
|
||||
kwargs=runtime_kwargs,
|
||||
)
|
||||
|
||||
async def test_as_tool_forwards_runtime_kwargs(self, client: MockChatClient) -> None:
|
||||
"""Test that runtime kwargs are forwarded through as_tool() to sub-agent."""
|
||||
"""Test that runtime kwargs are forwarded through as_tool() to sub-agent tools."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
captured_function_invocation_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Capture kwargs passed to the sub-agent
|
||||
captured_kwargs.update(context.kwargs)
|
||||
captured_function_invocation_kwargs.update(context.function_invocation_kwargs)
|
||||
await call_next()
|
||||
|
||||
# Setup mock response
|
||||
@@ -39,29 +53,31 @@ class TestAsToolKwargsPropagation:
|
||||
# Create tool from sub-agent
|
||||
tool = sub_agent.as_tool(name="delegate", arg_name="task")
|
||||
|
||||
# Directly invoke the tool with kwargs (simulating what happens during agent execution)
|
||||
# Directly invoke the tool with explicit runtime context (simulating agent execution).
|
||||
_ = await tool.invoke(
|
||||
arguments=tool.input_model(task="Test delegation"),
|
||||
api_token="secret-xyz-123",
|
||||
user_id="user-456",
|
||||
session_id="session-789",
|
||||
context=self._build_context(
|
||||
tool,
|
||||
task="Test delegation",
|
||||
runtime_kwargs={
|
||||
"api_token": "secret-xyz-123",
|
||||
"user_id": "user-456",
|
||||
"session_id": "session-789",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Verify kwargs were forwarded to sub-agent
|
||||
assert "api_token" in captured_kwargs, f"Expected 'api_token' in {captured_kwargs}"
|
||||
assert captured_kwargs["api_token"] == "secret-xyz-123"
|
||||
assert "user_id" in captured_kwargs
|
||||
assert captured_kwargs["user_id"] == "user-456"
|
||||
assert "session_id" in captured_kwargs
|
||||
assert captured_kwargs["session_id"] == "session-789"
|
||||
assert captured_kwargs == {}
|
||||
assert captured_function_invocation_kwargs["api_token"] == "secret-xyz-123"
|
||||
assert captured_function_invocation_kwargs["user_id"] == "user-456"
|
||||
assert captured_function_invocation_kwargs["session_id"] == "session-789"
|
||||
|
||||
async def test_as_tool_excludes_arg_name_from_forwarded_kwargs(self, client: MockChatClient) -> None:
|
||||
"""Test that the arg_name parameter is not forwarded as a kwarg."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
async def test_as_tool_forwards_context_kwargs_verbatim(self, client: MockChatClient) -> None:
|
||||
"""Test that runtime kwargs are forwarded exactly from FunctionInvocationContext.kwargs."""
|
||||
captured_function_invocation_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
captured_kwargs.update(context.kwargs)
|
||||
captured_function_invocation_kwargs.update(context.function_invocation_kwargs)
|
||||
await call_next()
|
||||
|
||||
# Setup mock response
|
||||
@@ -79,25 +95,26 @@ class TestAsToolKwargsPropagation:
|
||||
|
||||
# Invoke tool with both the arg_name field and additional kwargs
|
||||
await tool.invoke(
|
||||
arguments=tool.input_model(custom_task="Test task"),
|
||||
api_token="token-123",
|
||||
custom_task="should_be_excluded", # This should be filtered out
|
||||
context=FunctionInvocationContext(
|
||||
function=tool,
|
||||
arguments={"custom_task": "Test task"},
|
||||
kwargs={
|
||||
"api_token": "token-123",
|
||||
"custom_task": "should_be_excluded",
|
||||
},
|
||||
)
|
||||
)
|
||||
|
||||
# The arg_name ("custom_task") should NOT be in the forwarded kwargs
|
||||
assert "custom_task" not in captured_kwargs
|
||||
# But other kwargs should be present
|
||||
assert "api_token" in captured_kwargs
|
||||
assert captured_kwargs["api_token"] == "token-123"
|
||||
assert captured_function_invocation_kwargs["custom_task"] == "should_be_excluded"
|
||||
assert captured_function_invocation_kwargs["api_token"] == "token-123"
|
||||
|
||||
async def test_as_tool_nested_delegation_propagates_kwargs(self, client: MockChatClient) -> None:
|
||||
"""Test that kwargs propagate through multiple levels of delegation (A → B → C)."""
|
||||
captured_kwargs_list: list[dict[str, Any]] = []
|
||||
"""Test that runtime kwargs propagate through multiple levels of delegation (A -> B -> C)."""
|
||||
captured_function_invocation_kwargs_list: list[dict[str, Any]] = []
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
# Capture kwargs at each level
|
||||
captured_kwargs_list.append(dict(context.kwargs))
|
||||
captured_function_invocation_kwargs_list.append(dict(context.function_invocation_kwargs))
|
||||
await call_next()
|
||||
|
||||
# Setup mock responses to trigger nested tool invocation: B calls tool C, then completes.
|
||||
@@ -140,24 +157,29 @@ class TestAsToolKwargsPropagation:
|
||||
|
||||
# Invoke tool B with kwargs - should propagate to both B and C
|
||||
await tool_b.invoke(
|
||||
arguments=tool_b.input_model(task="Test cascade"),
|
||||
trace_id="trace-abc-123",
|
||||
tenant_id="tenant-xyz",
|
||||
options={"additional_function_arguments": {"trace_id": "trace-abc-123", "tenant_id": "tenant-xyz"}},
|
||||
context=self._build_context(
|
||||
tool_b,
|
||||
task="Test cascade",
|
||||
runtime_kwargs={
|
||||
"trace_id": "trace-abc-123",
|
||||
"tenant_id": "tenant-xyz",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Verify kwargs were forwarded to the first agent invocation.
|
||||
assert len(captured_kwargs_list) >= 1
|
||||
assert captured_kwargs_list[0].get("trace_id") == "trace-abc-123"
|
||||
assert captured_kwargs_list[0].get("tenant_id") == "tenant-xyz"
|
||||
assert len(captured_function_invocation_kwargs_list) >= 1
|
||||
assert captured_function_invocation_kwargs_list[0].get("trace_id") == "trace-abc-123"
|
||||
assert captured_function_invocation_kwargs_list[0].get("tenant_id") == "tenant-xyz"
|
||||
|
||||
async def test_as_tool_streaming_mode_forwards_kwargs(self, client: MockChatClient) -> None:
|
||||
"""Test that kwargs are forwarded in streaming mode."""
|
||||
"""Test that runtime kwargs are forwarded in streaming mode."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
captured_function_invocation_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
captured_kwargs.update(context.kwargs)
|
||||
captured_function_invocation_kwargs.update(context.function_invocation_kwargs)
|
||||
await call_next()
|
||||
|
||||
# Setup mock streaming responses
|
||||
@@ -182,13 +204,15 @@ class TestAsToolKwargsPropagation:
|
||||
|
||||
# Invoke tool with kwargs while streaming callback is active
|
||||
await tool.invoke(
|
||||
arguments=tool.input_model(task="Test streaming"),
|
||||
api_key="streaming-key-999",
|
||||
context=self._build_context(
|
||||
tool,
|
||||
task="Test streaming",
|
||||
runtime_kwargs={"api_key": "streaming-key-999"},
|
||||
),
|
||||
)
|
||||
|
||||
# Verify kwargs were forwarded even in streaming mode
|
||||
assert "api_key" in captured_kwargs
|
||||
assert captured_kwargs["api_key"] == "streaming-key-999"
|
||||
assert captured_kwargs == {}
|
||||
assert captured_function_invocation_kwargs["api_key"] == "streaming-key-999"
|
||||
assert len(captured_updates) == 1
|
||||
|
||||
async def test_as_tool_empty_kwargs_still_works(self, client: MockChatClient) -> None:
|
||||
@@ -206,18 +230,20 @@ class TestAsToolKwargsPropagation:
|
||||
tool = sub_agent.as_tool()
|
||||
|
||||
# Invoke without any extra kwargs - should work without errors
|
||||
result = await tool.invoke(arguments=tool.input_model(task="Simple task"))
|
||||
result = await tool.invoke(arguments={"task": "Simple task"})
|
||||
|
||||
# Verify tool executed successfully
|
||||
assert result is not None
|
||||
|
||||
async def test_as_tool_kwargs_with_chat_options(self, client: MockChatClient) -> None:
|
||||
"""Test that kwargs including chat_options are properly forwarded."""
|
||||
"""Test that runtime kwargs are forwarded only via function_invocation_kwargs."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
captured_function_invocation_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
captured_kwargs.update(context.kwargs)
|
||||
captured_function_invocation_kwargs.update(context.function_invocation_kwargs)
|
||||
await call_next()
|
||||
|
||||
# Setup mock response
|
||||
@@ -235,24 +261,26 @@ class TestAsToolKwargsPropagation:
|
||||
|
||||
# Invoke with various kwargs
|
||||
await tool.invoke(
|
||||
arguments=tool.input_model(task="Test with options"),
|
||||
temperature=0.8,
|
||||
max_tokens=500,
|
||||
custom_param="custom_value",
|
||||
context=self._build_context(
|
||||
tool,
|
||||
task="Test with options",
|
||||
runtime_kwargs={
|
||||
"temperature": 0.8,
|
||||
"max_tokens": 500,
|
||||
"custom_param": "custom_value",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Verify all kwargs were forwarded
|
||||
assert "temperature" in captured_kwargs
|
||||
assert captured_kwargs["temperature"] == 0.8
|
||||
assert "max_tokens" in captured_kwargs
|
||||
assert captured_kwargs["max_tokens"] == 500
|
||||
assert "custom_param" in captured_kwargs
|
||||
assert captured_kwargs["custom_param"] == "custom_value"
|
||||
assert captured_kwargs == {}
|
||||
assert captured_function_invocation_kwargs["temperature"] == 0.8
|
||||
assert captured_function_invocation_kwargs["max_tokens"] == 500
|
||||
assert captured_function_invocation_kwargs["custom_param"] == "custom_value"
|
||||
|
||||
async def test_as_tool_kwargs_isolated_per_invocation(self, client: MockChatClient) -> None:
|
||||
"""Test that kwargs are isolated per invocation and don't leak between calls."""
|
||||
first_call_kwargs: dict[str, Any] = {}
|
||||
second_call_kwargs: dict[str, Any] = {}
|
||||
"""Test that runtime kwargs are isolated per invocation and don't leak between calls."""
|
||||
first_call_function_invocation_kwargs: dict[str, Any] = {}
|
||||
second_call_function_invocation_kwargs: dict[str, Any] = {}
|
||||
call_count = 0
|
||||
|
||||
@agent_middleware
|
||||
@@ -260,9 +288,9 @@ class TestAsToolKwargsPropagation:
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
first_call_kwargs.update(context.kwargs)
|
||||
first_call_function_invocation_kwargs.update(context.function_invocation_kwargs)
|
||||
elif call_count == 2:
|
||||
second_call_kwargs.update(context.kwargs)
|
||||
second_call_function_invocation_kwargs.update(context.function_invocation_kwargs)
|
||||
await call_next()
|
||||
|
||||
# Setup mock responses for both calls
|
||||
@@ -281,33 +309,35 @@ class TestAsToolKwargsPropagation:
|
||||
|
||||
# First call with specific kwargs
|
||||
await tool.invoke(
|
||||
arguments=tool.input_model(task="First task"),
|
||||
session_id="session-1",
|
||||
api_token="token-1",
|
||||
context=self._build_context(
|
||||
tool,
|
||||
task="First task",
|
||||
runtime_kwargs={"session_id": "session-1", "api_token": "token-1"},
|
||||
),
|
||||
)
|
||||
|
||||
# Second call with different kwargs
|
||||
await tool.invoke(
|
||||
arguments=tool.input_model(task="Second task"),
|
||||
session_id="session-2",
|
||||
api_token="token-2",
|
||||
context=self._build_context(
|
||||
tool,
|
||||
task="Second task",
|
||||
runtime_kwargs={"session_id": "session-2", "api_token": "token-2"},
|
||||
),
|
||||
)
|
||||
|
||||
# Verify first call had its own kwargs
|
||||
assert first_call_kwargs.get("session_id") == "session-1"
|
||||
assert first_call_kwargs.get("api_token") == "token-1"
|
||||
assert first_call_function_invocation_kwargs.get("session_id") == "session-1"
|
||||
assert first_call_function_invocation_kwargs.get("api_token") == "token-1"
|
||||
|
||||
# Verify second call had its own kwargs (not leaked from first)
|
||||
assert second_call_kwargs.get("session_id") == "session-2"
|
||||
assert second_call_kwargs.get("api_token") == "token-2"
|
||||
assert second_call_function_invocation_kwargs.get("session_id") == "session-2"
|
||||
assert second_call_function_invocation_kwargs.get("api_token") == "token-2"
|
||||
|
||||
async def test_as_tool_excludes_conversation_id_from_forwarded_kwargs(self, client: MockChatClient) -> None:
|
||||
"""Test that conversation_id is not forwarded to sub-agent."""
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
async def test_as_tool_forwards_conversation_id_from_context_kwargs(self, client: MockChatClient) -> None:
|
||||
"""Test that conversation_id is forwarded when explicitly present in runtime context kwargs."""
|
||||
captured_function_invocation_kwargs: dict[str, Any] = {}
|
||||
|
||||
@agent_middleware
|
||||
async def capture_middleware(context: AgentContext, call_next: Callable[[], Awaitable[None]]) -> None:
|
||||
captured_kwargs.update(context.kwargs)
|
||||
captured_function_invocation_kwargs.update(context.function_invocation_kwargs)
|
||||
await call_next()
|
||||
|
||||
# Setup mock response
|
||||
@@ -325,17 +355,17 @@ class TestAsToolKwargsPropagation:
|
||||
|
||||
# Invoke tool with conversation_id in kwargs (simulating parent's conversation state)
|
||||
await tool.invoke(
|
||||
arguments=tool.input_model(task="Test delegation"),
|
||||
conversation_id="conv-parent-456",
|
||||
api_token="secret-xyz-123",
|
||||
user_id="user-456",
|
||||
context=self._build_context(
|
||||
tool,
|
||||
task="Test delegation",
|
||||
runtime_kwargs={
|
||||
"conversation_id": "conv-parent-456",
|
||||
"api_token": "secret-xyz-123",
|
||||
"user_id": "user-456",
|
||||
},
|
||||
),
|
||||
)
|
||||
|
||||
# Verify conversation_id was NOT forwarded to sub-agent
|
||||
assert "conversation_id" not in captured_kwargs, (
|
||||
f"conversation_id should not be forwarded, but got: {captured_kwargs}"
|
||||
)
|
||||
|
||||
# Verify other kwargs were still forwarded
|
||||
assert captured_kwargs.get("api_token") == "secret-xyz-123"
|
||||
assert captured_kwargs.get("user_id") == "user-456"
|
||||
assert captured_function_invocation_kwargs.get("conversation_id") == "conv-parent-456"
|
||||
assert captured_function_invocation_kwargs.get("api_token") == "secret-xyz-123"
|
||||
assert captured_function_invocation_kwargs.get("user_id") == "user-456"
|
||||
|
||||
@@ -1,9 +1,12 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
|
||||
import inspect
|
||||
from typing import Any
|
||||
from unittest.mock import patch
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
GROUP_ANNOTATION_KEY,
|
||||
GROUP_TOKEN_COUNT_KEY,
|
||||
@@ -50,6 +53,60 @@ def test_base_client(chat_client_base: SupportsChatGetResponse):
|
||||
assert isinstance(chat_client_base, SupportsChatGetResponse)
|
||||
|
||||
|
||||
def test_base_client_warns_for_direct_additional_properties(chat_client_base: SupportsChatGetResponse) -> None:
|
||||
with pytest.warns(DeprecationWarning, match="additional_properties"):
|
||||
client = type(chat_client_base)(legacy_key="legacy-value")
|
||||
|
||||
assert client.additional_properties["legacy_key"] == "legacy-value"
|
||||
|
||||
|
||||
def test_base_client_as_agent_uses_explicit_additional_properties(chat_client_base: SupportsChatGetResponse) -> None:
|
||||
agent = chat_client_base.as_agent(additional_properties={"team": "core"})
|
||||
|
||||
assert agent.additional_properties == {"team": "core"}
|
||||
|
||||
|
||||
def test_openai_chat_client_get_response_docstring_surfaces_layered_runtime_docs() -> None:
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
|
||||
docstring = inspect.getdoc(OpenAIChatClient.get_response)
|
||||
|
||||
assert docstring is not None
|
||||
assert "Get a response from a chat client." in docstring
|
||||
assert "function_invocation_kwargs" in docstring
|
||||
assert "function_middleware: Optional per-call function middleware." in docstring
|
||||
assert "middleware: Optional per-call chat and function middleware." in docstring
|
||||
|
||||
|
||||
def test_openai_chat_client_get_response_is_defined_on_openai_class() -> None:
|
||||
from agent_framework.openai import OpenAIChatClient
|
||||
|
||||
signature = inspect.signature(OpenAIChatClient.get_response)
|
||||
|
||||
assert OpenAIChatClient.get_response.__qualname__ == "OpenAIChatClient.get_response"
|
||||
assert "function_middleware" in signature.parameters
|
||||
assert "middleware" in signature.parameters
|
||||
|
||||
|
||||
async def test_base_client_get_response_uses_explicit_client_kwargs(chat_client_base: SupportsChatGetResponse) -> None:
|
||||
async def fake_inner_get_response(**kwargs):
|
||||
assert kwargs["trace_id"] == "trace-123"
|
||||
assert "function_invocation_kwargs" not in kwargs
|
||||
return ChatResponse(messages=[Message(role="assistant", text="ok")])
|
||||
|
||||
with patch.object(
|
||||
chat_client_base,
|
||||
"_inner_get_response",
|
||||
side_effect=fake_inner_get_response,
|
||||
) as mock_inner_get_response:
|
||||
await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")],
|
||||
function_invocation_kwargs={"tool_request_id": "tool-123"},
|
||||
client_kwargs={"trace_id": "trace-123"},
|
||||
)
|
||||
mock_inner_get_response.assert_called_once()
|
||||
|
||||
|
||||
async def test_base_client_get_response(chat_client_base: SupportsChatGetResponse):
|
||||
response = await chat_client_base.get_response([Message(role="user", text="Hello")])
|
||||
assert response.messages[0].role == "assistant"
|
||||
|
||||
@@ -0,0 +1,175 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
from agent_framework._docstrings import apply_layered_docstring, build_layered_docstring
|
||||
|
||||
# -- Helpers: stub functions with various docstring shapes --
|
||||
|
||||
|
||||
def _source_with_full_docstring(x: int) -> int:
|
||||
"""Do something useful.
|
||||
|
||||
Args:
|
||||
x: The input value.
|
||||
|
||||
Keyword Args:
|
||||
timeout: Max seconds to wait.
|
||||
|
||||
Returns:
|
||||
The computed result.
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
def _source_with_args_only(x: int) -> int:
|
||||
"""Do something useful.
|
||||
|
||||
Args:
|
||||
x: The input value.
|
||||
|
||||
Returns:
|
||||
The computed result.
|
||||
"""
|
||||
return x
|
||||
|
||||
|
||||
def _source_no_sections() -> None:
|
||||
"""A plain summary with no Google-style sections."""
|
||||
|
||||
|
||||
def _source_no_docstring() -> None:
|
||||
pass
|
||||
|
||||
|
||||
def _target_stub() -> None:
|
||||
pass
|
||||
|
||||
|
||||
# -- build_layered_docstring tests --
|
||||
|
||||
|
||||
def test_build_returns_none_when_source_has_no_docstring() -> None:
|
||||
result = build_layered_docstring(_source_no_docstring)
|
||||
assert result is None
|
||||
|
||||
|
||||
def test_build_returns_original_when_no_extra_kwargs() -> None:
|
||||
result = build_layered_docstring(_source_with_full_docstring)
|
||||
assert result is not None
|
||||
assert "Do something useful." in result
|
||||
assert "Keyword Args:" in result
|
||||
|
||||
|
||||
def test_build_returns_original_when_extra_kwargs_empty() -> None:
|
||||
result = build_layered_docstring(_source_with_full_docstring, extra_keyword_args={})
|
||||
assert result is not None
|
||||
assert result == build_layered_docstring(_source_with_full_docstring)
|
||||
|
||||
|
||||
def test_build_appends_to_existing_keyword_args_section() -> None:
|
||||
result = build_layered_docstring(
|
||||
_source_with_full_docstring,
|
||||
extra_keyword_args={"retries": "Number of retries."},
|
||||
)
|
||||
assert result is not None
|
||||
assert "timeout: Max seconds to wait." in result
|
||||
assert "retries: Number of retries." in result
|
||||
# Both should be under Keyword Args
|
||||
lines = result.splitlines()
|
||||
kw_index = next(i for i, line in enumerate(lines) if line == "Keyword Args:")
|
||||
ret_index = next(i for i, line in enumerate(lines) if line == "Returns:")
|
||||
retries_index = next(i for i, line in enumerate(lines) if "retries:" in line)
|
||||
assert kw_index < retries_index < ret_index
|
||||
|
||||
|
||||
def test_build_inserts_keyword_args_after_args_section() -> None:
|
||||
result = build_layered_docstring(
|
||||
_source_with_args_only,
|
||||
extra_keyword_args={"verbose": "Enable verbose output."},
|
||||
)
|
||||
assert result is not None
|
||||
assert "Keyword Args:" in result
|
||||
assert "verbose: Enable verbose output." in result
|
||||
lines = result.splitlines()
|
||||
args_index = next(i for i, line in enumerate(lines) if line == "Args:")
|
||||
kw_index = next(i for i, line in enumerate(lines) if line == "Keyword Args:")
|
||||
ret_index = next(i for i, line in enumerate(lines) if line == "Returns:")
|
||||
assert args_index < kw_index < ret_index
|
||||
|
||||
|
||||
def test_build_inserts_keyword_args_in_docstring_with_no_sections() -> None:
|
||||
result = build_layered_docstring(
|
||||
_source_no_sections,
|
||||
extra_keyword_args={"debug": "Enable debug mode."},
|
||||
)
|
||||
assert result is not None
|
||||
assert "A plain summary" in result
|
||||
assert "Keyword Args:" in result
|
||||
assert "debug: Enable debug mode." in result
|
||||
|
||||
|
||||
def test_build_handles_multiline_descriptions() -> None:
|
||||
result = build_layered_docstring(
|
||||
_source_with_args_only,
|
||||
extra_keyword_args={
|
||||
"config": "The configuration object.\nMust be a valid mapping.\nDefaults to empty.",
|
||||
},
|
||||
)
|
||||
assert result is not None
|
||||
lines = result.splitlines()
|
||||
config_line = next(line for line in lines if "config:" in line)
|
||||
assert "The configuration object." in config_line
|
||||
# Continuation lines should be indented
|
||||
config_idx = lines.index(config_line)
|
||||
assert "Must be a valid mapping." in lines[config_idx + 1]
|
||||
assert "Defaults to empty." in lines[config_idx + 2]
|
||||
|
||||
|
||||
def test_build_preserves_multiple_extra_kwargs_order() -> None:
|
||||
result = build_layered_docstring(
|
||||
_source_with_args_only,
|
||||
extra_keyword_args={
|
||||
"alpha": "First.",
|
||||
"beta": "Second.",
|
||||
"gamma": "Third.",
|
||||
},
|
||||
)
|
||||
assert result is not None
|
||||
lines = result.splitlines()
|
||||
alpha_idx = next(i for i, line in enumerate(lines) if "alpha:" in line)
|
||||
beta_idx = next(i for i, line in enumerate(lines) if "beta:" in line)
|
||||
gamma_idx = next(i for i, line in enumerate(lines) if "gamma:" in line)
|
||||
assert alpha_idx < beta_idx < gamma_idx
|
||||
|
||||
|
||||
# -- apply_layered_docstring tests --
|
||||
|
||||
|
||||
def test_apply_sets_docstring_on_target() -> None:
|
||||
def target() -> None:
|
||||
pass
|
||||
|
||||
apply_layered_docstring(target, _source_with_full_docstring)
|
||||
assert target.__doc__ is not None
|
||||
assert "Do something useful." in target.__doc__
|
||||
|
||||
|
||||
def test_apply_with_extra_kwargs() -> None:
|
||||
def target() -> None:
|
||||
pass
|
||||
|
||||
apply_layered_docstring(
|
||||
target,
|
||||
_source_with_args_only,
|
||||
extra_keyword_args={"flag": "A boolean flag."},
|
||||
)
|
||||
assert target.__doc__ is not None
|
||||
assert "flag: A boolean flag." in target.__doc__
|
||||
assert "Keyword Args:" in target.__doc__
|
||||
|
||||
|
||||
def test_apply_sets_none_when_source_has_no_docstring() -> None:
|
||||
def target() -> None:
|
||||
"""Original."""
|
||||
|
||||
apply_layered_docstring(target, _source_no_docstring)
|
||||
assert target.__doc__ is None
|
||||
@@ -4,6 +4,8 @@ from __future__ import annotations
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import (
|
||||
BaseEmbeddingClient,
|
||||
Embedding,
|
||||
@@ -63,6 +65,11 @@ def test_base_additional_properties_custom() -> None:
|
||||
assert client.additional_properties == {"key": "value"}
|
||||
|
||||
|
||||
def test_base_embedding_client_rejects_unknown_kwargs() -> None:
|
||||
with pytest.raises(TypeError):
|
||||
MockEmbeddingClient(legacy_key="value") # type: ignore[call-arg]
|
||||
|
||||
|
||||
# --- SupportsGetEmbeddings protocol tests ---
|
||||
|
||||
|
||||
|
||||
@@ -3651,3 +3651,131 @@ class TestUpdateConversationId:
|
||||
|
||||
|
||||
# endregion
|
||||
async def test_user_input_request_propagates_through_as_tool(chat_client_base: SupportsChatGetResponse):
|
||||
"""Test that user_input_request content from a sub-agent wrapped as a tool propagates to the parent response."""
|
||||
from agent_framework.exceptions import UserInputRequiredException
|
||||
|
||||
@tool(name="delegate_agent", approval_mode="never_require")
|
||||
def delegate_tool(task: str) -> str:
|
||||
del task
|
||||
raise UserInputRequiredException(
|
||||
contents=[
|
||||
Content.from_oauth_consent_request(
|
||||
consent_link="https://login.microsoftonline.com/consent",
|
||||
)
|
||||
]
|
||||
)
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="1", name="delegate_agent", arguments='{"task": "do it"}'),
|
||||
],
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="delegate this")],
|
||||
options={"tool_choice": "auto", "tools": [delegate_tool]},
|
||||
)
|
||||
|
||||
user_requests = [
|
||||
content
|
||||
for msg in response.messages
|
||||
for content in msg.contents
|
||||
if isinstance(content, Content) and content.user_input_request
|
||||
]
|
||||
assert len(user_requests) == 1
|
||||
assert user_requests[0].type == "oauth_consent_request"
|
||||
assert user_requests[0].consent_link == "https://login.microsoftonline.com/consent"
|
||||
assert user_requests[0].user_input_request is True
|
||||
|
||||
|
||||
async def test_user_input_request_multiple_contents_propagate(chat_client_base: SupportsChatGetResponse):
|
||||
"""Test that multiple user_input_request items in a single exception all propagate to the parent response."""
|
||||
from agent_framework.exceptions import UserInputRequiredException
|
||||
|
||||
@tool(name="multi_request_tool", approval_mode="never_require")
|
||||
def multi_request(task: str) -> str:
|
||||
del task
|
||||
raise UserInputRequiredException(
|
||||
contents=[
|
||||
Content.from_oauth_consent_request(
|
||||
consent_link="https://example.com/consent1",
|
||||
),
|
||||
Content.from_oauth_consent_request(
|
||||
consent_link="https://example.com/consent2",
|
||||
),
|
||||
Content.from_oauth_consent_request(
|
||||
consent_link="https://example.com/consent3",
|
||||
),
|
||||
]
|
||||
)
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="1", name="multi_request_tool", arguments='{"task": "do it"}'),
|
||||
],
|
||||
)
|
||||
)
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="do something")],
|
||||
options={"tool_choice": "auto", "tools": [multi_request]},
|
||||
)
|
||||
|
||||
user_requests = [
|
||||
content
|
||||
for msg in response.messages
|
||||
for content in msg.contents
|
||||
if isinstance(content, Content) and content.user_input_request
|
||||
]
|
||||
assert len(user_requests) == 3
|
||||
consent_links = {r.consent_link for r in user_requests}
|
||||
assert consent_links == {
|
||||
"https://example.com/consent1",
|
||||
"https://example.com/consent2",
|
||||
"https://example.com/consent3",
|
||||
}
|
||||
|
||||
|
||||
async def test_user_input_request_empty_contents_returns_fallback(chat_client_base: SupportsChatGetResponse):
|
||||
"""Test that UserInputRequiredException with empty contents produces a fallback function_result."""
|
||||
from agent_framework.exceptions import UserInputRequiredException
|
||||
|
||||
@tool(name="empty_request_tool", approval_mode="never_require")
|
||||
def empty_request(task: str) -> str:
|
||||
del task
|
||||
raise UserInputRequiredException(contents=[])
|
||||
|
||||
chat_client_base.run_responses = [
|
||||
ChatResponse(
|
||||
messages=Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(call_id="1", name="empty_request_tool", arguments='{"task": "do it"}'),
|
||||
],
|
||||
)
|
||||
),
|
||||
ChatResponse(messages=Message(role="assistant", text="handled")),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="do something")],
|
||||
options={"tool_choice": "auto", "tools": [empty_request]},
|
||||
)
|
||||
|
||||
# With empty contents, the handler returns a function_result with an error message
|
||||
# and the loop continues to the next chat response.
|
||||
function_results = [
|
||||
content for msg in response.messages for content in msg.contents if content.type == "function_result"
|
||||
]
|
||||
assert len(function_results) >= 1
|
||||
assert any("user input" in (fr.result or "").lower() for fr in function_results)
|
||||
|
||||
@@ -6,11 +6,13 @@ from collections.abc import AsyncIterable, Awaitable, MutableSequence, Sequence
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import (
|
||||
Agent,
|
||||
BaseChatClient,
|
||||
ChatMiddlewareLayer,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Content,
|
||||
FunctionInvocationContext,
|
||||
FunctionInvocationLayer,
|
||||
Message,
|
||||
ResponseStream,
|
||||
@@ -97,6 +99,7 @@ class TestKwargsPropagationToFunctionTool:
|
||||
|
||||
async def test_kwargs_propagate_to_tool_with_kwargs(self) -> None:
|
||||
"""Test that kwargs passed to get_response() are available in @tool **kwargs."""
|
||||
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
@@ -149,6 +152,7 @@ class TestKwargsPropagationToFunctionTool:
|
||||
|
||||
async def test_kwargs_not_forwarded_to_tool_without_kwargs(self) -> None:
|
||||
"""Test that kwargs are NOT forwarded to @tool that doesn't accept **kwargs."""
|
||||
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
def simple_tool(x: int) -> str:
|
||||
@@ -185,6 +189,7 @@ class TestKwargsPropagationToFunctionTool:
|
||||
|
||||
async def test_kwargs_isolated_between_function_calls(self) -> None:
|
||||
"""Test that kwargs are consistent across multiple function call invocations."""
|
||||
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
|
||||
invocation_kwargs: list[dict[str, Any]] = []
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
@@ -235,6 +240,7 @@ class TestKwargsPropagationToFunctionTool:
|
||||
|
||||
async def test_streaming_response_kwargs_propagation(self) -> None:
|
||||
"""Test that kwargs propagate to @tool in streaming mode."""
|
||||
# TODO(Copilot): Remove this legacy coverage once runtime ``**kwargs`` tool injection is removed.
|
||||
captured_kwargs: dict[str, Any] = {}
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
@@ -287,3 +293,59 @@ class TestKwargsPropagationToFunctionTool:
|
||||
assert "streaming_session" in captured_kwargs, f"Expected 'streaming_session' in {captured_kwargs}"
|
||||
assert captured_kwargs["streaming_session"] == "session-xyz"
|
||||
assert captured_kwargs["correlation_id"] == "corr-123"
|
||||
|
||||
async def test_agent_run_injects_function_invocation_context(self) -> None:
|
||||
"""Test that Agent.run injects FunctionInvocationContext for ctx-based tools."""
|
||||
captured_context_kwargs: dict[str, Any] = {}
|
||||
captured_client_kwargs: dict[str, Any] = {}
|
||||
captured_options: dict[str, Any] = {}
|
||||
|
||||
@tool(approval_mode="never_require")
|
||||
def capture_context_tool(x: int, ctx: FunctionInvocationContext) -> str:
|
||||
captured_context_kwargs.update(ctx.kwargs)
|
||||
return f"result: x={x}"
|
||||
|
||||
class CapturingFunctionInvokingMockClient(FunctionInvokingMockClient):
|
||||
async def _get_non_streaming_response(
|
||||
self,
|
||||
*,
|
||||
messages: MutableSequence[Message],
|
||||
options: dict[str, Any],
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
captured_options.update(options)
|
||||
captured_client_kwargs.update(kwargs)
|
||||
return await super()._get_non_streaming_response(messages=messages, options=options, **kwargs)
|
||||
|
||||
client = CapturingFunctionInvokingMockClient()
|
||||
client.run_responses = [
|
||||
ChatResponse(
|
||||
messages=[
|
||||
Message(
|
||||
role="assistant",
|
||||
contents=[
|
||||
Content.from_function_call(
|
||||
call_id="call_1",
|
||||
name="capture_context_tool",
|
||||
arguments='{"x": 42}',
|
||||
)
|
||||
],
|
||||
)
|
||||
]
|
||||
),
|
||||
ChatResponse(messages=[Message(role="assistant", text="Done!")]),
|
||||
]
|
||||
|
||||
agent = Agent(client=client, tools=[capture_context_tool])
|
||||
result = await agent.run(
|
||||
[Message(role="user", text="Test")],
|
||||
function_invocation_kwargs={"tool_request_id": "tool-123"},
|
||||
client_kwargs={"client_request_id": "client-456"},
|
||||
)
|
||||
|
||||
assert captured_context_kwargs["tool_request_id"] == "tool-123"
|
||||
assert "client_request_id" not in captured_context_kwargs
|
||||
assert captured_client_kwargs["client_request_id"] == "client-456"
|
||||
assert "tool_request_id" not in captured_client_kwargs
|
||||
assert "additional_function_arguments" not in captured_options
|
||||
assert result.messages[-1].text == "Done!"
|
||||
|
||||
@@ -192,10 +192,10 @@ class ConcreteHistoryProvider(BaseHistoryProvider):
|
||||
self.stored: list[Message] = []
|
||||
self._stored_messages = stored_messages or []
|
||||
|
||||
async def get_messages(self, session_id: str | None, **kwargs) -> list[Message]:
|
||||
async def get_messages(self, session_id: str | None, *, state=None, **kwargs) -> list[Message]:
|
||||
return list(self._stored_messages)
|
||||
|
||||
async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs) -> None:
|
||||
async def save_messages(self, session_id: str | None, messages: Sequence[Message], *, state=None, **kwargs) -> None:
|
||||
self.stored.extend(messages)
|
||||
|
||||
|
||||
|
||||
@@ -12,6 +12,7 @@ from agent_framework import (
|
||||
FunctionTool,
|
||||
tool,
|
||||
)
|
||||
from agent_framework._middleware import FunctionInvocationContext
|
||||
from agent_framework._tools import (
|
||||
_parse_annotation,
|
||||
_parse_inputs,
|
||||
@@ -952,6 +953,128 @@ async def test_ai_function_with_kwargs_injection():
|
||||
assert result_default[0].text == "x=10, user=unknown"
|
||||
|
||||
|
||||
async def test_ai_function_with_explicit_invocation_context():
|
||||
"""Test that invoke() can receive runtime kwargs via FunctionInvocationContext."""
|
||||
|
||||
@tool
|
||||
def tool_with_context(x: int, ctx: FunctionInvocationContext) -> str:
|
||||
"""A tool that accepts runtime context injection."""
|
||||
user_id = ctx.kwargs.get("user_id", "unknown")
|
||||
return f"x={x}, user={user_id}"
|
||||
|
||||
assert tool_with_context.parameters() == {
|
||||
"properties": {"x": {"title": "X", "type": "integer"}},
|
||||
"required": ["x"],
|
||||
"title": "tool_with_context_input",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
context = FunctionInvocationContext(
|
||||
function=tool_with_context,
|
||||
arguments=tool_with_context.input_model(x=7),
|
||||
kwargs={"user_id": "ctx-user"},
|
||||
)
|
||||
|
||||
result = await tool_with_context.invoke(context=context)
|
||||
|
||||
assert result[0].text == "x=7, user=ctx-user"
|
||||
|
||||
|
||||
async def test_ai_function_with_typed_context_parameter_using_custom_name():
|
||||
"""Test that typed context injection works for names other than ctx."""
|
||||
|
||||
@tool
|
||||
def tool_with_runtime_context(x: int, runtime: FunctionInvocationContext) -> str:
|
||||
"""A tool that uses a custom context parameter name."""
|
||||
user_id = runtime.kwargs.get("user_id", "unknown")
|
||||
return f"x={x}, user={user_id}"
|
||||
|
||||
assert tool_with_runtime_context.parameters() == {
|
||||
"properties": {"x": {"title": "X", "type": "integer"}},
|
||||
"required": ["x"],
|
||||
"title": "tool_with_runtime_context_input",
|
||||
"type": "object",
|
||||
}
|
||||
|
||||
context = FunctionInvocationContext(
|
||||
function=tool_with_runtime_context,
|
||||
arguments=tool_with_runtime_context.input_model(x=8),
|
||||
kwargs={"user_id": "runtime-user"},
|
||||
)
|
||||
|
||||
result = await tool_with_runtime_context.invoke(context=context)
|
||||
|
||||
assert result[0].text == "x=8, user=runtime-user"
|
||||
|
||||
|
||||
async def test_ai_function_with_explicit_schema_and_untyped_ctx():
|
||||
"""Test that explicit schemas allow an untyped ctx parameter."""
|
||||
|
||||
class ToolInput(BaseModel):
|
||||
x: int
|
||||
|
||||
@tool(schema=ToolInput)
|
||||
def tool_with_schema(x, ctx) -> str:
|
||||
"""A tool with explicit schema and implicit ctx injection."""
|
||||
return f"x={x}, user={ctx.kwargs.get('user_id', 'unknown')}"
|
||||
|
||||
context = FunctionInvocationContext(
|
||||
function=tool_with_schema,
|
||||
arguments=ToolInput(x=9),
|
||||
kwargs={"user_id": "schema-user"},
|
||||
)
|
||||
|
||||
result = await tool_with_schema.invoke(context=context)
|
||||
|
||||
assert result[0].text == "x=9, user=schema-user"
|
||||
|
||||
|
||||
async def test_ai_function_with_explicit_schema_and_typed_ctx():
|
||||
"""Test that explicit schemas also work with typed context injection."""
|
||||
|
||||
class ToolInput(BaseModel):
|
||||
x: int
|
||||
|
||||
@tool(schema=ToolInput)
|
||||
def tool_with_schema(x: int, runtime: FunctionInvocationContext) -> str:
|
||||
"""A tool with explicit schema and typed context injection."""
|
||||
return f"x={x}, user={runtime.kwargs.get('user_id', 'unknown')}"
|
||||
|
||||
context = FunctionInvocationContext(
|
||||
function=tool_with_schema,
|
||||
arguments=ToolInput(x=11),
|
||||
kwargs={"user_id": "typed-schema-user"},
|
||||
)
|
||||
|
||||
result = await tool_with_schema.invoke(context=context)
|
||||
|
||||
assert tool_with_schema.parameters() == ToolInput.model_json_schema()
|
||||
assert result[0].text == "x=11, user=typed-schema-user"
|
||||
|
||||
|
||||
def test_ai_function_with_multiple_typed_context_parameters_fails():
|
||||
"""Test that tools reject multiple typed FunctionInvocationContext parameters."""
|
||||
|
||||
with pytest.raises(ValueError, match="multiple FunctionInvocationContext parameters"):
|
||||
|
||||
@tool
|
||||
def invalid_tool(ctx_one: FunctionInvocationContext, ctx_two: FunctionInvocationContext) -> str:
|
||||
return f"{ctx_one.kwargs}-{ctx_two.kwargs}"
|
||||
|
||||
|
||||
def test_ai_function_with_ctx_and_typed_context_parameter_fails():
|
||||
"""Test that explicit-schema tools reject both implicit ctx and typed context parameters."""
|
||||
|
||||
class ToolInput(BaseModel):
|
||||
x: int
|
||||
|
||||
with pytest.raises(ValueError, match="multiple FunctionInvocationContext parameters"):
|
||||
|
||||
@tool(schema=ToolInput)
|
||||
def invalid_tool(x, ctx, runtime: FunctionInvocationContext) -> str:
|
||||
return f"{x}-{ctx.kwargs}-{runtime.kwargs}"
|
||||
|
||||
|
||||
# region _parse_annotation tests
|
||||
|
||||
|
||||
|
||||
@@ -124,10 +124,20 @@ class DurableAgentExecutor(ABC, Generic[TaskT]):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def get_new_session(self, agent_name: str, **kwargs: Any) -> DurableAgentSession:
|
||||
def get_new_session(
|
||||
self,
|
||||
agent_name: str,
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
service_session_id: str | None = None,
|
||||
) -> DurableAgentSession:
|
||||
"""Create a new DurableAgentSession with random session ID."""
|
||||
session_id = self._create_session_id(agent_name)
|
||||
return DurableAgentSession.from_session_id(session_id, **kwargs)
|
||||
durable_session_id = self._create_session_id(agent_name)
|
||||
return DurableAgentSession(
|
||||
durable_session_id=durable_session_id,
|
||||
session_id=session_id,
|
||||
service_session_id=service_session_id,
|
||||
)
|
||||
|
||||
def _create_session_id(
|
||||
self,
|
||||
|
||||
@@ -284,46 +284,48 @@ class DurableAgentSession(AgentSession):
|
||||
durable_session_id: AgentSessionId | None = None,
|
||||
session_id: str | None = None,
|
||||
service_session_id: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
super().__init__(session_id=session_id, service_session_id=service_session_id, **kwargs)
|
||||
self._session_id_value: AgentSessionId | None = durable_session_id
|
||||
super().__init__(session_id=session_id, service_session_id=service_session_id)
|
||||
self.durable_session_id: AgentSessionId | None = durable_session_id
|
||||
|
||||
@property
|
||||
def durable_session_id(self) -> AgentSessionId | None:
|
||||
return self._session_id_value
|
||||
|
||||
@durable_session_id.setter
|
||||
def durable_session_id(self, value: AgentSessionId | None) -> None:
|
||||
self._session_id_value = value
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
state = super().to_dict()
|
||||
if self.durable_session_id is not None:
|
||||
state[self._SERIALIZED_SESSION_ID_KEY] = str(self.durable_session_id)
|
||||
return state
|
||||
|
||||
@classmethod
|
||||
def from_session_id(
|
||||
cls,
|
||||
session_id: AgentSessionId,
|
||||
**kwargs: Any,
|
||||
durable_session_id: AgentSessionId,
|
||||
*,
|
||||
session_id: str | None = None,
|
||||
service_session_id: str | None = None,
|
||||
) -> DurableAgentSession:
|
||||
return cls(durable_session_id=session_id, **kwargs)
|
||||
|
||||
def to_dict(self) -> dict[str, Any]:
|
||||
state = super().to_dict()
|
||||
if self._session_id_value is not None:
|
||||
state[self._SERIALIZED_SESSION_ID_KEY] = str(self._session_id_value)
|
||||
return state
|
||||
"""Create a DurableAgentSession from an AgentSessionId."""
|
||||
return cls(
|
||||
durable_session_id=durable_session_id,
|
||||
session_id=session_id,
|
||||
service_session_id=service_session_id,
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def from_dict(cls, data: dict[str, Any]) -> DurableAgentSession:
|
||||
state_payload = dict(data)
|
||||
session_id_value = state_payload.pop(cls._SERIALIZED_SESSION_ID_KEY, None)
|
||||
session = super().from_dict(state_payload)
|
||||
"""Create a DurableAgentSession from a state dict."""
|
||||
data = dict(data) # defensive copy — avoid mutating caller's dict
|
||||
session_id_value = data.pop(cls._SERIALIZED_SESSION_ID_KEY, None)
|
||||
session = super().from_dict(data)
|
||||
durable_session_id: AgentSessionId | None = None
|
||||
# We need to create a DurableAgentSession from the base AgentSession
|
||||
if session_id_value is not None:
|
||||
if not isinstance(session_id_value, str):
|
||||
raise ValueError("durable_session_id must be a string when present in serialized state")
|
||||
durable_session_id = AgentSessionId.parse(session_id_value)
|
||||
|
||||
durable_session = cls(
|
||||
durable_session_id=durable_session_id,
|
||||
session_id=session.session_id,
|
||||
service_session_id=session.service_session_id,
|
||||
)
|
||||
durable_session.state.update(session.state)
|
||||
if session_id_value is not None:
|
||||
if not isinstance(session_id_value, str):
|
||||
raise ValueError("durable_session_id must be a string when present in serialized state")
|
||||
durable_session._session_id_value = AgentSessionId.parse(session_id_value)
|
||||
return durable_session
|
||||
|
||||
@@ -133,16 +133,13 @@ class DurableAIAgent(SupportsAgentRun, Generic[TaskT]):
|
||||
session=session,
|
||||
)
|
||||
|
||||
def create_session(self, **kwargs: Any) -> DurableAgentSession:
|
||||
def create_session(self, *, session_id: str | None = None) -> DurableAgentSession:
|
||||
"""Create a new agent session via the provider."""
|
||||
return self._executor.get_new_session(self.name, **kwargs)
|
||||
return self._executor.get_new_session(self.name)
|
||||
|
||||
def get_session(self, **kwargs: Any) -> AgentSession:
|
||||
"""Retrieve an existing session via the provider.
|
||||
|
||||
For durable agents, sessions do not use `service_session_id` so this is not used.
|
||||
"""
|
||||
return self._executor.get_new_session(self.name, **kwargs)
|
||||
def get_session(self, service_session_id: str, *, session_id: str | None = None) -> AgentSession:
|
||||
"""Retrieve an existing session via the provider."""
|
||||
return self._executor.get_new_session(self.name, service_session_id=service_session_id, session_id=session_id)
|
||||
|
||||
def _normalize_messages(self, messages: AgentRunInputs | None) -> str:
|
||||
"""Convert supported message inputs to a single string.
|
||||
|
||||
@@ -2,6 +2,8 @@
|
||||
|
||||
"""Unit tests for AgentSessionId and DurableAgentSession."""
|
||||
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from agent_framework import AgentSession
|
||||
|
||||
@@ -153,7 +155,7 @@ class TestDurableAgentSession:
|
||||
def test_from_session_id(self) -> None:
|
||||
"""Test creating DurableAgentSession from session ID."""
|
||||
session_id = AgentSessionId(name="TestAgent", key="test-key")
|
||||
session = DurableAgentSession.from_session_id(session_id)
|
||||
session = DurableAgentSession(durable_session_id=session_id)
|
||||
|
||||
assert isinstance(session, DurableAgentSession)
|
||||
assert session.durable_session_id is not None
|
||||
@@ -161,10 +163,10 @@ class TestDurableAgentSession:
|
||||
assert session.durable_session_id.name == "TestAgent"
|
||||
assert session.durable_session_id.key == "test-key"
|
||||
|
||||
def test_from_session_id_with_service_session_id(self) -> None:
|
||||
"""Test creating DurableAgentSession with service session ID."""
|
||||
def test_init_with_service_session_id(self) -> None:
|
||||
"""Test creating DurableAgentSession with explicit service session ID."""
|
||||
session_id = AgentSessionId(name="TestAgent", key="test-key")
|
||||
session = DurableAgentSession.from_session_id(session_id, service_session_id="service-123")
|
||||
session = DurableAgentSession(durable_session_id=session_id, service_session_id="service-123")
|
||||
|
||||
assert session.durable_session_id is not None
|
||||
assert session.durable_session_id == session_id
|
||||
@@ -192,7 +194,7 @@ class TestDurableAgentSession:
|
||||
|
||||
def test_from_dict_with_durable_session_id(self) -> None:
|
||||
"""Test deserialization restores durable session ID."""
|
||||
serialized = {
|
||||
serialized: dict[str, Any] = {
|
||||
"type": "session",
|
||||
"session_id": "session-123",
|
||||
"service_session_id": "service-123",
|
||||
@@ -210,7 +212,7 @@ class TestDurableAgentSession:
|
||||
|
||||
def test_from_dict_without_durable_session_id(self) -> None:
|
||||
"""Test deserialization without durable session ID."""
|
||||
serialized = {
|
||||
serialized: dict[str, Any] = {
|
||||
"type": "session",
|
||||
"session_id": "session-456",
|
||||
"service_session_id": "service-456",
|
||||
|
||||
@@ -88,15 +88,6 @@ class TestDurableAIAgentClientIntegration:
|
||||
|
||||
assert isinstance(session, DurableAgentSession)
|
||||
|
||||
def test_client_agent_session_with_parameters(self, agent_client: DurableAIAgentClient) -> None:
|
||||
"""Verify agent can create sessions with custom parameters."""
|
||||
agent = agent_client.get_agent("assistant")
|
||||
|
||||
session = agent.create_session(service_session_id="client-session-123")
|
||||
|
||||
assert isinstance(session, DurableAgentSession)
|
||||
assert session.service_session_id == "client-session-123"
|
||||
|
||||
|
||||
class TestDurableAIAgentClientPollingConfiguration:
|
||||
"""Test polling configuration parameters for DurableAIAgentClient."""
|
||||
|
||||
@@ -82,17 +82,6 @@ class TestDurableAIAgentOrchestrationContextIntegration:
|
||||
|
||||
assert isinstance(session, DurableAgentSession)
|
||||
|
||||
def test_orchestration_agent_session_with_parameters(
|
||||
self, agent_context: DurableAIAgentOrchestrationContext
|
||||
) -> None:
|
||||
"""Verify agent can create sessions with custom parameters."""
|
||||
agent = agent_context.get_agent("assistant")
|
||||
|
||||
session = agent.create_session(service_session_id="orch-session-456")
|
||||
|
||||
assert isinstance(session, DurableAgentSession)
|
||||
assert session.service_session_id == "orch-session-456"
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main([__file__, "-v", "--tb=short"])
|
||||
|
||||
@@ -184,16 +184,31 @@ class TestDurableAIAgentSessionManagement:
|
||||
mock_executor.get_new_session.assert_called_once_with("test_agent")
|
||||
assert session == mock_session
|
||||
|
||||
def test_create_session_forwards_kwargs(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
|
||||
"""Verify create_session forwards kwargs to executor."""
|
||||
mock_session = DurableAgentSession(service_session_id="session-123")
|
||||
def test_get_session_forwards_service_session_id(
|
||||
self, test_agent: DurableAIAgent[Any], mock_executor: Mock
|
||||
) -> None:
|
||||
"""Verify get_session forwards service_session_id and session_id to executor."""
|
||||
mock_session = DurableAgentSession(service_session_id="svc-123")
|
||||
mock_executor.get_new_session.return_value = mock_session
|
||||
|
||||
test_agent.create_session(service_session_id="session-123")
|
||||
session = test_agent.get_session("svc-123", session_id="local-456")
|
||||
|
||||
mock_executor.get_new_session.assert_called_once()
|
||||
_, kwargs = mock_executor.get_new_session.call_args
|
||||
assert kwargs["service_session_id"] == "session-123"
|
||||
mock_executor.get_new_session.assert_called_once_with(
|
||||
"test_agent", service_session_id="svc-123", session_id="local-456"
|
||||
)
|
||||
assert session.service_session_id == "svc-123"
|
||||
|
||||
def test_get_session_without_session_id(self, test_agent: DurableAIAgent[Any], mock_executor: Mock) -> None:
|
||||
"""Verify get_session works with only service_session_id (session_id defaults to None)."""
|
||||
mock_session = DurableAgentSession(service_session_id="svc-789")
|
||||
mock_executor.get_new_session.return_value = mock_session
|
||||
|
||||
session = test_agent.get_session("svc-789")
|
||||
|
||||
mock_executor.get_new_session.assert_called_once_with(
|
||||
"test_agent", service_session_id="svc-789", session_id=None
|
||||
)
|
||||
assert session.service_session_id == "svc-789"
|
||||
|
||||
|
||||
class TestDurableAgentProviderInterface:
|
||||
|
||||
+3
-4
@@ -146,11 +146,11 @@ class FoundryLocalClient(
|
||||
timeout: float | None = None,
|
||||
prepare_model: bool = True,
|
||||
device: DeviceType | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str = "utf-8",
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a FoundryLocalClient.
|
||||
|
||||
@@ -169,12 +169,11 @@ class FoundryLocalClient(
|
||||
The device is used to select the appropriate model variant.
|
||||
If not provided, the default device for your system will be used.
|
||||
The values are in the foundry_local.models.DeviceType enum.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
middleware: Optional sequence of ChatAndFunctionMiddlewareTypes to apply to requests.
|
||||
function_invocation_configuration: Optional configuration for function invocation support.
|
||||
env_file_path: If provided, the .env settings are read from this file path location.
|
||||
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
|
||||
kwargs: Additional keyword arguments, are passed to the RawOpenAIChatClient.
|
||||
This can include middleware and additional properties.
|
||||
|
||||
Examples:
|
||||
|
||||
@@ -271,8 +270,8 @@ class FoundryLocalClient(
|
||||
super().__init__(
|
||||
model_id=model_info.id,
|
||||
client=AsyncOpenAI(base_url=manager.endpoint, api_key=manager.api_key),
|
||||
additional_properties=additional_properties,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
**kwargs,
|
||||
)
|
||||
self.manager = manager
|
||||
|
||||
@@ -303,7 +303,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
stream: Literal[False] = False,
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse]: ...
|
||||
|
||||
@overload
|
||||
@@ -314,7 +313,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ResponseStream[AgentResponseUpdate, AgentResponse]: ...
|
||||
|
||||
def run(
|
||||
@@ -324,7 +322,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | None = None,
|
||||
**kwargs: Any,
|
||||
) -> Awaitable[AgentResponse] | ResponseStream[AgentResponseUpdate, AgentResponse]:
|
||||
"""Get a response from the agent.
|
||||
|
||||
@@ -339,7 +336,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
stream: Whether to stream the response. Defaults to False.
|
||||
session: The conversation session associated with the message(s).
|
||||
options: Runtime options (model, timeout, etc.).
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
When stream=False: An Awaitable[AgentResponse].
|
||||
@@ -354,10 +350,10 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
return AgentResponse.from_updates(updates)
|
||||
|
||||
return ResponseStream(
|
||||
self._stream_updates(messages=messages, session=session, options=options, **kwargs),
|
||||
self._stream_updates(messages=messages, session=session, options=options),
|
||||
finalizer=_finalize,
|
||||
)
|
||||
return self._run_impl(messages=messages, session=session, options=options, **kwargs)
|
||||
return self._run_impl(messages=messages, session=session, options=options)
|
||||
|
||||
async def _run_impl(
|
||||
self,
|
||||
@@ -365,7 +361,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentResponse:
|
||||
"""Non-streaming implementation of run."""
|
||||
if not self._started:
|
||||
@@ -414,7 +409,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[AgentResponseUpdate]:
|
||||
"""Internal method to stream updates from GitHub Copilot.
|
||||
|
||||
@@ -424,7 +418,6 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
Keyword Args:
|
||||
session: The conversation session associated with the message(s).
|
||||
options: Runtime options (model, timeout, etc.).
|
||||
kwargs: Additional keyword arguments.
|
||||
|
||||
Yields:
|
||||
AgentResponseUpdate items.
|
||||
|
||||
@@ -300,11 +300,11 @@ class OllamaChatClient(
|
||||
host: str | None = None,
|
||||
client: AsyncClient | None = None,
|
||||
model_id: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
middleware: Sequence[ChatAndFunctionMiddlewareTypes] | None = None,
|
||||
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize an Ollama Chat client.
|
||||
|
||||
@@ -313,11 +313,11 @@ class OllamaChatClient(
|
||||
Can be set via the OLLAMA_HOST env variable.
|
||||
client: An optional Ollama Client instance. If not provided, a new instance will be created.
|
||||
model_id: The Ollama chat model ID to use. Can be set via the OLLAMA_MODEL_ID env variable.
|
||||
additional_properties: Additional properties stored on the client instance.
|
||||
middleware: Optional middleware to apply to the client.
|
||||
function_invocation_configuration: Optional function invocation configuration override.
|
||||
env_file_path: An optional path to a dotenv (.env) file to load environment variables from.
|
||||
env_file_encoding: The encoding to use when reading the dotenv (.env) file. Defaults to 'utf-8'.
|
||||
**kwargs: Additional keyword arguments passed to BaseChatClient.
|
||||
"""
|
||||
ollama_settings = load_settings(
|
||||
OllamaSettings,
|
||||
@@ -336,9 +336,9 @@ class OllamaChatClient(
|
||||
self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType]
|
||||
|
||||
super().__init__(
|
||||
additional_properties=additional_properties,
|
||||
middleware=middleware,
|
||||
function_invocation_configuration=function_invocation_configuration,
|
||||
**kwargs,
|
||||
)
|
||||
self.middleware = list(self.chat_middleware)
|
||||
|
||||
|
||||
@@ -92,9 +92,9 @@ class RawOllamaEmbeddingClient(
|
||||
model_id: str | None = None,
|
||||
host: str | None = None,
|
||||
client: AsyncClient | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize a raw Ollama embedding client."""
|
||||
ollama_settings = load_settings(
|
||||
@@ -110,7 +110,7 @@ class RawOllamaEmbeddingClient(
|
||||
self.model_id = ollama_settings["embedding_model_id"] # type: ignore[assignment,reportTypedDictNotRequiredAccess]
|
||||
self.client = client or AsyncClient(host=ollama_settings.get("host"))
|
||||
self.host = str(self.client._client.base_url) # type: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType]
|
||||
super().__init__(**kwargs)
|
||||
super().__init__(additional_properties=additional_properties)
|
||||
|
||||
def service_url(self) -> str:
|
||||
"""Get the URL of the service."""
|
||||
@@ -214,17 +214,17 @@ class OllamaEmbeddingClient(
|
||||
host: str | None = None,
|
||||
client: AsyncClient | None = None,
|
||||
otel_provider_name: str | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
env_file_path: str | None = None,
|
||||
env_file_encoding: str | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Initialize an Ollama embedding client."""
|
||||
super().__init__(
|
||||
model_id=model_id,
|
||||
host=host,
|
||||
client=client,
|
||||
additional_properties=additional_properties,
|
||||
otel_provider_name=otel_provider_name,
|
||||
env_file_path=env_file_path,
|
||||
env_file_encoding=env_file_encoding,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -107,11 +107,18 @@ class RedisHistoryProvider(BaseHistoryProvider):
|
||||
"""Get the Redis key for a given session's messages."""
|
||||
return f"{self.key_prefix}:{session_id or 'default'}"
|
||||
|
||||
async def get_messages(self, session_id: str | None, **kwargs: Any) -> list[Message]:
|
||||
async def get_messages(
|
||||
self,
|
||||
session_id: str | None,
|
||||
*,
|
||||
state: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> list[Message]:
|
||||
"""Retrieve stored messages for this session from Redis.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to retrieve messages for.
|
||||
state: Optional session state. Unused for Redis-backed history.
|
||||
**kwargs: Additional arguments (unused).
|
||||
|
||||
Returns:
|
||||
@@ -125,12 +132,20 @@ class RedisHistoryProvider(BaseHistoryProvider):
|
||||
messages.append(Message.from_dict(self._deserialize_json(serialized))) # type: ignore[union-attr]
|
||||
return messages
|
||||
|
||||
async def save_messages(self, session_id: str | None, messages: Sequence[Message], **kwargs: Any) -> None:
|
||||
async def save_messages(
|
||||
self,
|
||||
session_id: str | None,
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
state: dict[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> None:
|
||||
"""Persist messages for this session to Redis.
|
||||
|
||||
Args:
|
||||
session_id: The session ID to store messages for.
|
||||
messages: The messages to persist.
|
||||
state: Optional session state. Unused for Redis-backed history.
|
||||
**kwargs: Additional arguments (unused).
|
||||
"""
|
||||
if not messages:
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import asyncio
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from agent_framework import AgentContext, AgentSession
|
||||
from agent_framework import AgentContext, AgentSession, FunctionInvocationContext, tool
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
from dotenv import load_dotenv
|
||||
|
||||
@@ -18,9 +18,6 @@ sub-agent invoked as a tool using ``propagate_session=True``.
|
||||
When session propagation is enabled, both agents share the same session object,
|
||||
including session_id and the mutable state dict. This allows correlated
|
||||
conversation tracking and shared state across the agent hierarchy.
|
||||
|
||||
The middleware functions below are purely for observability — they are NOT
|
||||
required for session propagation to work.
|
||||
"""
|
||||
|
||||
|
||||
@@ -28,65 +25,83 @@ async def log_session(
|
||||
context: AgentContext,
|
||||
call_next: Callable[[], Awaitable[None]],
|
||||
) -> None:
|
||||
"""Agent middleware that logs the session received by each agent.
|
||||
|
||||
NOT required for session propagation — only used to observe the flow.
|
||||
If propagation is working, both agents will show the same session_id.
|
||||
"""
|
||||
"""Agent middleware that logs the session received by each agent."""
|
||||
session: AgentSession | None = context.session
|
||||
if not session:
|
||||
print("No session found.")
|
||||
await call_next()
|
||||
return
|
||||
agent_name = context.agent.name or "unknown"
|
||||
session_id = session.session_id if session else None
|
||||
state = dict(session.state) if session else {}
|
||||
print(f" [{agent_name}] session_id={session_id}, state={state}")
|
||||
print(
|
||||
f" [{agent_name}] session_id={session.session_id}, "
|
||||
f"service_session_id={session.service_session_id} state={session.state}"
|
||||
)
|
||||
await call_next()
|
||||
|
||||
|
||||
@tool(description="Use this tool to store the findings so that other agents can reason over them.")
|
||||
def store_findings(findings: str, ctx: FunctionInvocationContext) -> None:
|
||||
if ctx.session is None:
|
||||
return
|
||||
current_findings = ctx.session.state.get("findings")
|
||||
if current_findings is None:
|
||||
ctx.session.state["findings"] = findings
|
||||
else:
|
||||
ctx.session.state["findings"] = f"{current_findings}\n{findings}"
|
||||
|
||||
|
||||
@tool(description="Use this tool to gather the current findings from other agents.")
|
||||
def recall_findings(ctx: FunctionInvocationContext) -> str:
|
||||
if ctx.session is None:
|
||||
return "No session available"
|
||||
current_findings = ctx.session.state.get("findings")
|
||||
if current_findings is None:
|
||||
return "Nothing yet"
|
||||
return current_findings
|
||||
|
||||
|
||||
async def main() -> None:
|
||||
print("=== Agent-as-Tool: Session Propagation ===\n")
|
||||
|
||||
client = OpenAIResponsesClient()
|
||||
|
||||
# --- Sub-agent: a research specialist ---
|
||||
# The sub-agent has the same log_session middleware to prove it receives the session.
|
||||
research_agent = client.as_agent(
|
||||
name="ResearchAgent",
|
||||
instructions="You are a research assistant. Provide concise answers.",
|
||||
instructions="You are a research assistant. Provide concise answers and store your findings.",
|
||||
middleware=[log_session],
|
||||
tools=[store_findings, recall_findings],
|
||||
)
|
||||
|
||||
# propagate_session=True: the coordinator's session will be forwarded
|
||||
research_tool = research_agent.as_tool(
|
||||
name="research",
|
||||
description="Research a topic and return findings",
|
||||
description="Research a topic and store your findings.",
|
||||
arg_name="query",
|
||||
arg_description="The research query",
|
||||
propagate_session=True,
|
||||
)
|
||||
|
||||
# --- Coordinator agent ---
|
||||
coordinator = client.as_agent(
|
||||
name="CoordinatorAgent",
|
||||
instructions="You coordinate research. Use the 'research' tool to look up information.",
|
||||
tools=[research_tool],
|
||||
instructions=(
|
||||
"You coordinate research. Use the 'research' tool to start research "
|
||||
"and then use the recall findings tool to gather up everything."
|
||||
),
|
||||
tools=[research_tool, store_findings, recall_findings],
|
||||
middleware=[log_session],
|
||||
)
|
||||
|
||||
# Create a shared session and put some state in it
|
||||
session = coordinator.create_session()
|
||||
session.state["request_source"] = "demo"
|
||||
session.state["findings"] = None
|
||||
print(f"Session ID: {session.session_id}")
|
||||
print(f"Session state before run: {session.state}\n")
|
||||
|
||||
query = "What are the latest developments in quantum computing?"
|
||||
query = "What are the latest developments in quantum computing and in AI?"
|
||||
print(f"User: {query}\n")
|
||||
|
||||
result = await coordinator.run(query, session=session)
|
||||
|
||||
print(f"\nCoordinator: {result}\n")
|
||||
print(f"Session state after run: {session.state}")
|
||||
print(
|
||||
"\nIf both agents show the same session_id above, session propagation is working."
|
||||
)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import tool
|
||||
from agent_framework import FunctionInvocationContext, tool
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
@@ -14,27 +14,27 @@ load_dotenv()
|
||||
"""
|
||||
AI Function with kwargs Example
|
||||
|
||||
This example demonstrates how to inject custom keyword arguments (kwargs) into an AI function
|
||||
from the agent's run method, without exposing them to the AI model.
|
||||
This example demonstrates how to inject runtime context into an AI function
|
||||
from the agent's run method, without exposing it to the AI model.
|
||||
|
||||
This is useful for passing runtime information like access tokens, user IDs, or
|
||||
request-specific context that the tool needs but the model shouldn't know about
|
||||
or provide.
|
||||
or provide. The injected context parameter can be typed as
|
||||
``FunctionInvocationContext`` as shown here, or left untyped as ``ctx`` when you
|
||||
prefer a lighter-weight sample setup.
|
||||
"""
|
||||
|
||||
|
||||
# Define the function tool with **kwargs to accept injected arguments
|
||||
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
|
||||
# see samples/02-agents/tools/function_tool_with_approval.py
|
||||
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
|
||||
# Define the function tool with explicit invocation context.
|
||||
# The context parameter can also be declared as an untyped ``ctx`` parameter.
|
||||
@tool(approval_mode="never_require")
|
||||
def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
**kwargs: Any,
|
||||
ctx: FunctionInvocationContext,
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
# Extract the injected argument from kwargs
|
||||
user_id = kwargs.get("user_id", "unknown")
|
||||
# Extract the injected argument from the explicit context
|
||||
user_id = ctx.kwargs.get("user_id", "unknown")
|
||||
|
||||
# Simulate using the user_id for logging or personalization
|
||||
print(f"Getting weather for user: {user_id}")
|
||||
@@ -49,9 +49,11 @@ async def main() -> None:
|
||||
tools=[get_weather],
|
||||
)
|
||||
|
||||
# Pass the injected argument when running the agent
|
||||
# The 'user_id' kwarg will be passed down to the tool execution via **kwargs
|
||||
response = await agent.run("What is the weather like in Amsterdam?", user_id="user_123")
|
||||
# Pass the runtime context explicitly when running the agent.
|
||||
response = await agent.run(
|
||||
"What is the weather like in Amsterdam?",
|
||||
function_invocation_kwargs={"user_id": "user_123"},
|
||||
)
|
||||
|
||||
print(f"Agent: {response.text}")
|
||||
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
from typing import Annotated, Any
|
||||
from typing import Annotated
|
||||
|
||||
from agent_framework import AgentSession, tool
|
||||
from agent_framework import AgentSession, FunctionInvocationContext, tool
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
from dotenv import load_dotenv
|
||||
from pydantic import Field
|
||||
@@ -14,23 +14,21 @@ load_dotenv()
|
||||
"""
|
||||
AI Function with Session Injection Example
|
||||
|
||||
This example demonstrates the behavior when passing 'session' to agent.run()
|
||||
and accessing that session in AI function.
|
||||
This example demonstrates accessing the agent session inside a tool function
|
||||
via ``FunctionInvocationContext.session``. The session is automatically
|
||||
available when the agent is invoked with a session.
|
||||
"""
|
||||
|
||||
|
||||
# Define the function tool with **kwargs
|
||||
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production;
|
||||
# see samples/02-agents/tools/function_tool_with_approval.py
|
||||
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
|
||||
# Define the function tool with explicit invocation context.
|
||||
# The context parameter can also be declared as an untyped parameter with the name: ``ctx``.
|
||||
@tool(approval_mode="never_require")
|
||||
async def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
**kwargs: Any,
|
||||
ctx: FunctionInvocationContext,
|
||||
) -> str:
|
||||
"""Get the weather for a given location."""
|
||||
# Get session object from kwargs
|
||||
session = kwargs.get("session")
|
||||
session = ctx.session
|
||||
if session and isinstance(session, AgentSession) and session.service_session_id:
|
||||
print(f"Session ID: {session.service_session_id}.")
|
||||
|
||||
@@ -42,17 +40,19 @@ async def main() -> None:
|
||||
name="WeatherAgent",
|
||||
instructions="You are a helpful weather assistant.",
|
||||
tools=[get_weather],
|
||||
options={"store": True},
|
||||
default_options={"store": True},
|
||||
)
|
||||
|
||||
# Create a session
|
||||
session = agent.create_session()
|
||||
|
||||
# Run the agent with the session
|
||||
# Pass session via additional_function_arguments so tools can access it via **kwargs
|
||||
opts = {"additional_function_arguments": {"session": session}}
|
||||
print(f"Agent: {await agent.run('What is the weather in London?', session=session, options=opts)}")
|
||||
print(f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session, options=opts)}")
|
||||
# Run the agent with the session; tools receive it via ctx.session.
|
||||
print(
|
||||
f"Agent: {await agent.run('What is the weather in London?', session=session)}"
|
||||
)
|
||||
print(
|
||||
f"Agent: {await agent.run('What is the weather in Amsterdam?', session=session)}"
|
||||
)
|
||||
print(f"Agent: {await agent.run('What cities did I ask about?', session=session)}")
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user