Python: added generic types to ChatOptions and ChatResponse/AgentResponse for Response Format (#3305)

* added generic types to ChatOptions and ChatResponse/AgentResponse for response format

* fix typevar import

* fix for older python versions

* fix missing import

* fixed imports

* fixed mypy

* mypy fix
This commit is contained in:
Eduard van Valkenburg
2026-01-28 22:23:02 +01:00
committed by GitHub
Unverified
parent 1f8463f9bb
commit 1226828ec2
42 changed files with 486 additions and 281 deletions
@@ -28,25 +28,21 @@ from ._http_service import AGUIHttpService
from ._message_adapters import agent_framework_messages_to_agui
from ._utils import convert_tools_to_agui_format
if TYPE_CHECKING:
from ._types import AGUIChatOptions
from typing import TypedDict
if sys.version_info >= (3, 13):
from typing import TypeVar
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
from typing import Self, TypedDict # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
from typing_extensions import Self, TypedDict # pragma: no cover
if TYPE_CHECKING:
from ._types import AGUIChatOptions
logger: logging.Logger = logging.getLogger(__name__)
@@ -85,7 +81,7 @@ def _apply_server_function_call_unwrap(chat_client: TBaseChatClient) -> TBaseCha
@wraps(original_get_response)
async def response_wrapper(self: Any, *args: Any, **kwargs: Any) -> ChatResponse:
response = await original_get_response(self, *args, **kwargs)
response: ChatResponse[Any] = await original_get_response(self, *args, **kwargs) # type: ignore[var-annotated]
if response.messages:
for message in response.messages:
_unwrap_server_function_call_contents(cast(MutableSequence[Content | dict[str, Any]], message.contents))
@@ -3,15 +3,23 @@
"""Type definitions for AG-UI integration."""
import sys
from typing import Any, TypedDict
from typing import Any, Generic
from agent_framework import ChatOptions
from pydantic import BaseModel, Field
if sys.version_info >= (3, 13):
from typing import TypeVar
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
TAGUIChatOptions = TypeVar("TAGUIChatOptions", bound=TypedDict, default="AGUIChatOptions", covariant=True) # type: ignore[valid-type]
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
class PredictStateConfig(TypedDict):
@@ -76,7 +84,7 @@ class AGUIRequest(BaseModel):
# region AG-UI Chat Options TypedDict
class AGUIChatOptions(ChatOptions, total=False):
class AGUIChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
"""AG-UI protocol-specific chat options dict.
Extends base ChatOptions for the AG-UI (Agent-UI) protocol.
@@ -140,7 +148,5 @@ class AGUIChatOptions(ChatOptions, total=False):
AGUI_OPTION_TRANSLATIONS: dict[str, str] = {}
"""Maps ChatOptions keys to AG-UI parameter names (protocol uses standard names)."""
TAGUIChatOptions = TypeVar("TAGUIChatOptions", bound=TypedDict, default="AGUIChatOptions", covariant=True) # type: ignore[valid-type]
# endregion
@@ -12,6 +12,10 @@ if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
if TYPE_CHECKING:
from agent_framework import ChatOptions
@@ -2,7 +2,7 @@
import sys
from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence
from typing import Any, ClassVar, Final, Generic, Literal, TypedDict
from typing import Any, ClassVar, Final, Generic, Literal
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
@@ -47,15 +47,18 @@ from anthropic.types.beta.beta_code_execution_tool_result_error import (
)
from pydantic import BaseModel, SecretStr, ValidationError
if sys.version_info >= (3, 13):
from typing import TypeVar
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar
from typing_extensions import TypedDict # type: ignore # pragma: no cover
if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
from typing_extensions import override # type: ignore # pragma: no cover
__all__ = [
"AnthropicChatOptions",
@@ -69,6 +72,8 @@ ANTHROPIC_DEFAULT_MAX_TOKENS: Final[int] = 1024
BETA_FLAGS: Final[list[str]] = ["mcp-client-2025-04-04", "code-execution-2025-08-25"]
STRUCTURED_OUTPUTS_BETA_FLAG: Final[str] = "structured-outputs-2025-11-13"
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
# region Anthropic Chat Options TypedDict
@@ -91,7 +96,7 @@ class ThinkingConfig(TypedDict, total=False):
budget_tokens: int
class AnthropicChatOptions(ChatOptions, total=False):
class AnthropicChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
"""Anthropic-specific chat options.
Extends ChatOptions with options specific to Anthropic's Messages API.
@@ -2,7 +2,7 @@
import sys
from collections.abc import Callable, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, TypedDict, cast
from typing import TYPE_CHECKING, Any, Generic, cast
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
@@ -27,9 +27,13 @@ if TYPE_CHECKING:
from ._chat_client import AzureAIAgentOptions
if sys.version_info >= (3, 13):
from typing import Self, TypeVar # pragma: no cover
from typing import Self, TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import Self, TypeVar # pragma: no cover
from typing_extensions import Self, TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
# Type variable for options - allows typed ChatAgent[TOptions] returns
@@ -6,7 +6,7 @@ import os
import re
import sys
from collections.abc import AsyncIterable, Callable, Mapping, MutableMapping, MutableSequence, Sequence
from typing import Any, ClassVar, Generic, TypedDict
from typing import Any, ClassVar, Generic
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
@@ -96,9 +96,9 @@ if sys.version_info >= (3, 12):
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
from typing import Self, TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
logger = get_logger("agent_framework.azure")
@@ -1265,7 +1265,7 @@ class AzureAIAgentClient(BaseChatClient[TAzureAIAgentOptions], Generic[TAzureAIA
| MutableMapping[str, Any]
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None = None,
default_options: TAzureAIAgentOptions | None = None,
default_options: TAzureAIAgentOptions | Mapping[str, Any] | None = None,
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
context_provider: ContextProvider | None = None,
middleware: Sequence[Middleware] | None = None,
@@ -2,7 +2,7 @@
import sys
from collections.abc import Callable, Mapping, MutableMapping, MutableSequence, Sequence
from typing import Any, ClassVar, Generic, TypedDict, TypeVar, cast
from typing import Any, ClassVar, Generic, TypeVar, cast
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
@@ -38,9 +38,9 @@ if sys.version_info >= (3, 12):
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
from typing import Self, TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
logger = get_logger("agent_framework.azure")
@@ -551,7 +551,7 @@ class AzureAIClient(OpenAIBaseResponsesClient[TAzureAIClientOptions], Generic[TA
| MutableMapping[str, Any]
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None = None,
default_options: TAzureAIClientOptions | None = None,
default_options: TAzureAIClientOptions | Mapping[str, Any] | None = None,
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
context_provider: ContextProvider | None = None,
middleware: Sequence[Middleware] | None = None,
@@ -2,7 +2,7 @@
import sys
from collections.abc import Callable, MutableMapping, Sequence
from typing import Any, Generic, TypedDict
from typing import Any, Generic
from agent_framework import (
AGENT_FRAMEWORK_USER_AGENT,
@@ -33,9 +33,13 @@ from ._client import AzureAIClient, AzureAIProjectAgentOptions
from ._shared import AzureAISettings, create_text_format_config, from_azure_ai_tools, to_azure_ai_tools
if sys.version_info >= (3, 13):
from typing import Self, TypeVar # pragma: no cover
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import Self, TypeVar # pragma: no cover
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self, TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
logger = get_logger("agent_framework.azure")
@@ -5,7 +5,7 @@ import json
import sys
from collections import deque
from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence
from typing import Any, ClassVar, Generic, Literal, TypedDict
from typing import Any, ClassVar, Generic, Literal
from uuid import uuid4
from agent_framework import (
@@ -33,17 +33,20 @@ from agent_framework.observability import use_instrumentation
from boto3.session import Session as Boto3Session
from botocore.client import BaseClient
from botocore.config import Config as BotoConfig
from pydantic import SecretStr, ValidationError
from pydantic import BaseModel, SecretStr, ValidationError
if sys.version_info >= (3, 13):
from typing import TypeVar
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
from typing_extensions import override # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
logger = get_logger("agent_framework.bedrock")
@@ -55,6 +58,8 @@ __all__ = [
"BedrockSettings",
]
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
# region Bedrock Chat Options TypedDict
@@ -82,7 +87,7 @@ class BedrockGuardrailConfig(TypedDict, total=False):
"""How to process guardrails during streaming (sync blocks, async does not)."""
class BedrockChatOptions(ChatOptions, total=False):
class BedrockChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
"""Amazon Bedrock Converse API-specific chat options dict.
Extends base ChatOptions with Bedrock-specific parameters.
@@ -30,9 +30,9 @@ from chatkit.types import (
)
if sys.version_info >= (3, 11):
from typing import assert_never
from typing import assert_never # type:ignore # pragma: no cover
else:
from typing_extensions import assert_never
from typing_extensions import assert_never # type:ignore # pragma: no cover
logger = logging.getLogger(__name__)
+70 -15
View File
@@ -3,7 +3,7 @@
import inspect
import re
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence
from contextlib import AbstractAsyncContextManager, AsyncExitStack
from copy import deepcopy
from itertools import chain
@@ -13,8 +13,8 @@ from typing import (
ClassVar,
Generic,
Protocol,
TypedDict,
cast,
overload,
runtime_checkable,
)
from uuid import uuid4
@@ -43,24 +43,25 @@ from ._types import (
from .exceptions import AgentExecutionException, AgentInitializationError
from .observability import use_agent_instrumentation
if TYPE_CHECKING:
from ._types import ChatOptions
if sys.version_info >= (3, 13):
from typing import TypeVar
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
from typing import Self, TypedDict # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
from typing_extensions import Self, TypedDict # pragma: no cover
if TYPE_CHECKING:
from ._types import ChatOptions
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None, covariant=True)
TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel)
logger = get_logger("agent_framework")
@@ -622,6 +623,7 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
provider-specific options including temperature, max_tokens, model_id,
tool_choice, and provider-specific options like reasoning_effort.
You can also create your own TypedDict for custom chat clients.
Note: response_format typing does not flow into run outputs when set via default_options.
These can be overridden at runtime via the ``options`` parameter of ``run()`` and ``run_stream()``.
tools: The tools to use for the request.
kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``.
@@ -657,6 +659,14 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
# Get tools from options or named parameter (named param takes precedence)
tools_ = tools if tools is not None else opts.pop("tools", None)
tools_ = cast(
ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None,
tools_,
)
# Handle instructions - named parameter takes precedence over options
instructions_ = instructions if instructions is not None else opts.pop("instructions", None)
@@ -742,6 +752,7 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
): # type: ignore[reportAttributeAccessIssue, attr-defined]
self.chat_client._update_agent_name_and_description(self.name, self.description) # type: ignore[reportAttributeAccessIssue, attr-defined]
@overload
async def run(
self,
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
@@ -752,9 +763,38 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
| MutableMapping[str, Any]
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None = None,
options: TOptions_co | None = None,
options: "ChatOptions[TResponseModelT]",
**kwargs: Any,
) -> AgentResponse:
) -> AgentResponse[TResponseModelT]: ...
@overload
async def run(
self,
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
tools: ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None = None,
options: TOptions_co | Mapping[str, Any] | "ChatOptions[Any]" | None = None,
**kwargs: Any,
) -> AgentResponse[Any]: ...
async def run(
self,
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
*,
thread: AgentThread | None = None,
tools: ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None = None,
options: TOptions_co | Mapping[str, Any] | "ChatOptions[Any]" | None = None,
**kwargs: Any,
) -> AgentResponse[Any]:
"""Run the agent with the given messages and options.
Note:
@@ -784,6 +824,14 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
# Get tools from options or named parameter (named param takes precedence)
tools_ = tools if tools is not None else opts.pop("tools", None)
tools_ = cast(
ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None,
tools_,
)
input_messages = normalize_messages(messages)
thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages(
@@ -860,12 +908,19 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
response.messages,
**{k: v for k, v in kwargs.items() if k != "thread"},
)
response_format = co.get("response_format")
if not (
response_format is not None and isinstance(response_format, type) and issubclass(response_format, BaseModel)
):
response_format = None
return AgentResponse(
messages=response.messages,
response_id=response.response_id,
created_at=response.created_at,
usage_details=response.usage_details,
value=response.value,
response_format=response_format,
raw_representation=response,
additional_properties=response.additional_properties,
)
@@ -880,7 +935,7 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
| MutableMapping[str, Any]
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None = None,
options: TOptions_co | None = None,
options: TOptions_co | Mapping[str, Any] | None = None,
**kwargs: Any,
) -> AsyncIterable[AgentResponseUpdate]:
"""Stream the agent with the given messages and options.
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
from collections.abc import (
AsyncIterable,
Callable,
Mapping,
MutableMapping,
MutableSequence,
Sequence,
@@ -17,9 +18,13 @@ from typing import (
Generic,
Protocol,
TypedDict,
cast,
overload,
runtime_checkable,
)
from pydantic import BaseModel
from ._logging import get_logger
from ._memory import ContextProvider
from ._middleware import (
@@ -45,9 +50,10 @@ from ._types import (
)
if sys.version_info >= (3, 13):
from typing import TypeVar
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if TYPE_CHECKING:
from ._agents import ChatAgent
@@ -120,6 +126,16 @@ class ChatClientProtocol(Protocol[TOptions_contra]): #
additional_properties: dict[str, Any]
@overload
async def get_response(
self,
messages: str | ChatMessage | Sequence[str | ChatMessage],
*,
options: "ChatOptions[TResponseModelT]",
**kwargs: Any,
) -> "ChatResponse[TResponseModelT]": ...
@overload
async def get_response(
self,
messages: str | ChatMessage | Sequence[str | ChatMessage],
@@ -175,6 +191,9 @@ TOptions_co = TypeVar(
covariant=True,
)
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None, covariant=True)
TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel)
class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]):
"""Base class for chat clients.
@@ -319,13 +338,31 @@ class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]):
# region Public method
@overload
async def get_response(
self,
messages: str | ChatMessage | Sequence[str | ChatMessage],
*,
options: "ChatOptions[TResponseModelT]",
**kwargs: Any,
) -> ChatResponse[TResponseModelT]: ...
@overload
async def get_response(
self,
messages: str | ChatMessage | Sequence[str | ChatMessage],
*,
options: TOptions_co | None = None,
**kwargs: Any,
) -> ChatResponse:
) -> ChatResponse: ...
async def get_response(
self,
messages: str | ChatMessage | Sequence[str | ChatMessage],
*,
options: TOptions_co | "ChatOptions[Any]" | None = None,
**kwargs: Any,
) -> ChatResponse[Any]:
"""Get a response from a chat client.
Args:
@@ -389,7 +426,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]):
| MutableMapping[str, Any]
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None = None,
default_options: TOptions_co | None = None,
default_options: TOptions_co | Mapping[str, Any] | None = None,
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
context_provider: ContextProvider | None = None,
middleware: Sequence[Middleware] | None = None,
@@ -410,6 +447,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]):
default_options: A TypedDict containing chat options. When using a typed client like
``OpenAIChatClient``, this enables IDE autocomplete for provider-specific options
including temperature, max_tokens, model_id, tool_choice, and more.
Note: response_format typing does not flow into run outputs when set via default_options,
and dict literals are accepted without specialized option typing.
chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol.
If not provided, the default in-memory store will be used.
context_provider: Context providers to include during agent invocation.
@@ -446,7 +485,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]):
description=description,
instructions=instructions,
tools=tools,
default_options=default_options,
default_options=cast(Any, default_options),
chat_message_store_factory=chat_message_store_factory,
context_provider=context_provider,
middleware=middleware,
@@ -1,11 +1,12 @@
# Copyright (c) Microsoft. All rights reserved.
import inspect
import sys
from abc import ABC, abstractmethod
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableSequence, Sequence
from enum import Enum
from functools import update_wrapper
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypedDict, TypeVar
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar
from ._serialization import SerializationMixin
from ._types import AgentResponse, AgentResponseUpdate, ChatMessage, normalize_messages, prepare_messages
@@ -20,6 +21,10 @@ if TYPE_CHECKING:
from ._tools import FunctionTool
from ._types import ChatResponse, ChatResponseUpdate
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = [
"AgentMiddleware",
+13 -13
View File
@@ -24,12 +24,11 @@ from typing import (
Generic,
Literal,
Protocol,
TypedDict,
TypeVar,
Union,
cast,
get_args,
get_origin,
overload,
runtime_checkable,
)
@@ -42,7 +41,7 @@ from .exceptions import ChatClientInitializationError, ToolException
from .observability import (
OPERATION_DURATION_BUCKET_BOUNDARIES,
OtelAttr,
capture_exception, # type: ignore
capture_exception,
get_function_span,
get_function_span_attributes,
get_meter,
@@ -57,20 +56,21 @@ if TYPE_CHECKING:
Content,
)
from typing import overload
# TypeVar with defaults support for Python < 3.13
if sys.version_info >= (3, 13):
from typing import TypeVar as TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar as TypeVar # type: ignore[import] # pragma: no cover
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
# TypeVar with defaults support for Python < 3.13
if sys.version_info >= (3, 13):
from typing import TypeVar as TypeVarWithDefaults # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import (
TypeVar as TypeVarWithDefaults, # type: ignore[import] # pragma: no cover
)
from typing_extensions import TypedDict # type: ignore # pragma: no cover
logger = get_logger()
@@ -97,8 +97,8 @@ DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol[Any]")
# region Helpers
ArgsT = TypeVarWithDefaults("ArgsT", bound=BaseModel, default=BaseModel)
ReturnT = TypeVarWithDefaults("ReturnT", default=Any)
ArgsT = TypeVar("ArgsT", bound=BaseModel, default=BaseModel)
ReturnT = TypeVar("ReturnT", default=Any)
def _parse_inputs(
+125 -21
View File
@@ -1,8 +1,8 @@
# Copyright (c) Microsoft. All rights reserved.
import base64
import json
import re
import sys
from collections.abc import (
AsyncIterable,
Callable,
@@ -12,7 +12,7 @@ from collections.abc import (
Sequence,
)
from copy import deepcopy
from typing import Any, ClassVar, Final, Literal, TypedDict, TypeVar, overload
from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, cast, overload
from pydantic import BaseModel, ValidationError
@@ -21,6 +21,15 @@ from ._serialization import SerializationMixin
from ._tools import ToolProtocol, tool
from .exceptions import AdditionItemMismatch, ContentError
if sys.version_info >= (3, 13):
from typing import TypeVar # pragma: no cover
else:
from typing_extensions import TypeVar # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = [
"AgentResponse",
"AgentResponseUpdate",
@@ -312,6 +321,8 @@ TEmbedding = TypeVar("TEmbedding")
TChatResponse = TypeVar("TChatResponse", bound="ChatResponse")
TToolMode = TypeVar("TToolMode", bound="ToolMode")
TAgentRunResponse = TypeVar("TAgentRunResponse", bound="AgentResponse")
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None, covariant=True)
TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel)
CreatedAtT = str # Use a datetimeoffset type? Or a more specific type like datetime.datetime?
@@ -1911,7 +1922,7 @@ def _finalize_response(response: "ChatResponse | AgentResponse") -> None:
_coalesce_text_content(msg.contents, "text_reasoning")
class ChatResponse(SerializationMixin):
class ChatResponse(SerializationMixin, Generic[TResponseModel]):
"""Represents the response to a chat request.
Attributes:
@@ -1974,7 +1985,7 @@ class ChatResponse(SerializationMixin):
created_at: CreatedAtT | None = None,
finish_reason: FinishReason | None = None,
usage_details: UsageDetails | None = None,
value: Any | None = None,
value: TResponseModel | None = None,
response_format: type[BaseModel] | None = None,
additional_properties: dict[str, Any] | None = None,
raw_representation: Any | None = None,
@@ -2009,7 +2020,7 @@ class ChatResponse(SerializationMixin):
created_at: CreatedAtT | None = None,
finish_reason: FinishReason | None = None,
usage_details: UsageDetails | None = None,
value: Any | None = None,
value: TResponseModel | None = None,
response_format: type[BaseModel] | None = None,
additional_properties: dict[str, Any] | None = None,
raw_representation: Any | None = None,
@@ -2044,7 +2055,7 @@ class ChatResponse(SerializationMixin):
created_at: CreatedAtT | None = None,
finish_reason: FinishReason | dict[str, Any] | None = None,
usage_details: UsageDetails | dict[str, Any] | None = None,
value: Any | None = None,
value: TResponseModel | None = None,
response_format: type[BaseModel] | None = None,
additional_properties: dict[str, Any] | None = None,
raw_representation: Any | None = None,
@@ -2101,13 +2112,31 @@ class ChatResponse(SerializationMixin):
self.created_at = created_at
self.finish_reason = finish_reason
self.usage_details = usage_details
self._value: Any | None = value
self._value: TResponseModel | None = value
self._response_format: type[BaseModel] | None = response_format
self._value_parsed: bool = value is not None
self.additional_properties = additional_properties or {}
self.additional_properties.update(kwargs or {})
self.raw_representation: Any | list[Any] | None = raw_representation
@overload
@classmethod
def from_chat_response_updates(
cls: type["ChatResponse[Any]"],
updates: Sequence["ChatResponseUpdate"],
*,
output_format_type: type[TResponseModelT],
) -> "ChatResponse[TResponseModelT]": ...
@overload
@classmethod
def from_chat_response_updates(
cls: type["ChatResponse[Any]"],
updates: Sequence["ChatResponseUpdate"],
*,
output_format_type: None = None,
) -> "ChatResponse[Any]": ...
@classmethod
def from_chat_response_updates(
cls: type[TChatResponse],
@@ -2146,12 +2175,30 @@ class ChatResponse(SerializationMixin):
msg.try_parse_value(output_format_type)
return msg
@overload
@classmethod
async def from_chat_response_generator(
cls: type["ChatResponse[Any]"],
updates: AsyncIterable["ChatResponseUpdate"],
*,
output_format_type: type[TResponseModelT],
) -> "ChatResponse[TResponseModelT]": ...
@overload
@classmethod
async def from_chat_response_generator(
cls: type["ChatResponse[Any]"],
updates: AsyncIterable["ChatResponseUpdate"],
*,
output_format_type: None = None,
) -> "ChatResponse[Any]": ...
@classmethod
async def from_chat_response_generator(
cls: type[TChatResponse],
updates: AsyncIterable["ChatResponseUpdate"],
*,
output_format_type: type[BaseModel] | Mapping[str, Any] | None = None,
output_format_type: type[BaseModel] | None = None,
) -> TChatResponse:
"""Joins multiple updates into a single ChatResponse.
@@ -2187,7 +2234,7 @@ class ChatResponse(SerializationMixin):
return ("\n".join(message.text for message in self.messages if isinstance(message, ChatMessage))).strip()
@property
def value(self) -> Any | None:
def value(self) -> TResponseModel | None:
"""Get the parsed structured output value.
If a response_format was provided and parsing hasn't been attempted yet,
@@ -2203,14 +2250,20 @@ class ChatResponse(SerializationMixin):
and isinstance(self._response_format, type)
and issubclass(self._response_format, BaseModel)
):
self._value = self._response_format.model_validate_json(self.text)
self._value = cast(TResponseModel, self._response_format.model_validate_json(self.text))
self._value_parsed = True
return self._value
def __str__(self) -> str:
return self.text
def try_parse_value(self, output_format_type: type[_T] | None = None) -> _T | None:
@overload
def try_parse_value(self, output_format_type: type[TResponseModelT]) -> TResponseModelT | None: ...
@overload
def try_parse_value(self, output_format_type: None = None) -> TResponseModel | None: ...
def try_parse_value(self, output_format_type: type[BaseModel] | None = None) -> BaseModel | None:
"""Try to parse the text into a typed value.
This is the safe alternative to accessing the value property directly.
@@ -2238,7 +2291,7 @@ class ChatResponse(SerializationMixin):
try:
parsed_value = format_type.model_validate_json(self.text) # type: ignore[reportUnknownMemberType]
if use_cache:
self._value = parsed_value
self._value = cast(TResponseModel, parsed_value)
self._value_parsed = True
return parsed_value # type: ignore[return-value]
except ValidationError as ex:
@@ -2376,7 +2429,7 @@ class ChatResponseUpdate(SerializationMixin):
# region AgentResponse
class AgentResponse(SerializationMixin):
class AgentResponse(SerializationMixin, Generic[TResponseModel]):
"""Represents the response to an Agent run request.
Provides one or more response messages and metadata about the response.
@@ -2428,7 +2481,7 @@ class AgentResponse(SerializationMixin):
response_id: str | None = None,
created_at: CreatedAtT | None = None,
usage_details: UsageDetails | MutableMapping[str, Any] | None = None,
value: Any | None = None,
value: TResponseModel | None = None,
response_format: type[BaseModel] | None = None,
raw_representation: Any | None = None,
additional_properties: dict[str, Any] | None = None,
@@ -2469,7 +2522,7 @@ class AgentResponse(SerializationMixin):
self.response_id = response_id
self.created_at = created_at
self.usage_details = usage_details
self._value: Any | None = value
self._value: TResponseModel | None = value
self._response_format: type[BaseModel] | None = response_format
self._value_parsed: bool = value is not None
self.additional_properties = additional_properties or {}
@@ -2482,7 +2535,7 @@ class AgentResponse(SerializationMixin):
return "".join(msg.text for msg in self.messages) if self.messages else ""
@property
def value(self) -> Any | None:
def value(self) -> TResponseModel | None:
"""Get the parsed structured output value.
If a response_format was provided and parsing hasn't been attempted yet,
@@ -2498,7 +2551,7 @@ class AgentResponse(SerializationMixin):
and isinstance(self._response_format, type)
and issubclass(self._response_format, BaseModel)
):
self._value = self._response_format.model_validate_json(self.text)
self._value = cast(TResponseModel, self._response_format.model_validate_json(self.text))
self._value_parsed = True
return self._value
@@ -2512,6 +2565,24 @@ class AgentResponse(SerializationMixin):
if isinstance(content, Content) and content.user_input_request
]
@overload
@classmethod
def from_agent_run_response_updates(
cls: type["AgentResponse[Any]"],
updates: Sequence["AgentResponseUpdate"],
*,
output_format_type: type[TResponseModelT],
) -> "AgentResponse[TResponseModelT]": ...
@overload
@classmethod
def from_agent_run_response_updates(
cls: type["AgentResponse[Any]"],
updates: Sequence["AgentResponseUpdate"],
*,
output_format_type: None = None,
) -> "AgentResponse[Any]": ...
@classmethod
def from_agent_run_response_updates(
cls: type[TAgentRunResponse],
@@ -2535,6 +2606,24 @@ class AgentResponse(SerializationMixin):
msg.try_parse_value(output_format_type)
return msg
@overload
@classmethod
async def from_agent_response_generator(
cls: type["AgentResponse[Any]"],
updates: AsyncIterable["AgentResponseUpdate"],
*,
output_format_type: type[TResponseModelT],
) -> "AgentResponse[TResponseModelT]": ...
@overload
@classmethod
async def from_agent_response_generator(
cls: type["AgentResponse[Any]"],
updates: AsyncIterable["AgentResponseUpdate"],
*,
output_format_type: None = None,
) -> "AgentResponse[Any]": ...
@classmethod
async def from_agent_response_generator(
cls: type[TAgentRunResponse],
@@ -2561,7 +2650,13 @@ class AgentResponse(SerializationMixin):
def __str__(self) -> str:
return self.text
def try_parse_value(self, output_format_type: type[_T] | None = None) -> _T | None:
@overload
def try_parse_value(self, output_format_type: type[TResponseModelT]) -> TResponseModelT | None: ...
@overload
def try_parse_value(self, output_format_type: None = None) -> TResponseModel | None: ...
def try_parse_value(self, output_format_type: type[BaseModel] | None = None) -> BaseModel | None:
"""Try to parse the text into a typed value.
This is the safe alternative when you need to parse the response text into a typed value.
@@ -2589,7 +2684,7 @@ class AgentResponse(SerializationMixin):
try:
parsed_value = format_type.model_validate_json(self.text) # type: ignore[reportUnknownMemberType]
if use_cache:
self._value = parsed_value
self._value = cast(TResponseModel, parsed_value)
self._value_parsed = True
return parsed_value # type: ignore[return-value]
except ValidationError as ex:
@@ -2718,7 +2813,7 @@ class ToolMode(TypedDict, total=False):
# region TypedDict-based Chat Options
class ChatOptions(TypedDict, total=False):
class _ChatOptionsBase(TypedDict, total=False):
"""Common request settings for AI services as a TypedDict.
All fields are optional (total=False) to allow partial specification.
@@ -2771,7 +2866,7 @@ class ChatOptions(TypedDict, total=False):
allow_multiple_tool_calls: bool
# Response configuration
response_format: type[BaseModel] | dict[str, Any]
response_format: type[BaseModel] | Mapping[str, Any] | None
# Metadata
metadata: dict[str, Any]
@@ -2783,6 +2878,15 @@ class ChatOptions(TypedDict, total=False):
instructions: str
if TYPE_CHECKING:
class ChatOptions(_ChatOptionsBase, Generic[TResponseModel], total=False):
response_format: type[TResponseModel] | Mapping[str, Any] | None # type: ignore[misc]
else:
ChatOptions = _ChatOptionsBase
# region Chat Options Utility Functions
@@ -2,11 +2,12 @@
import json
import logging
import sys
import uuid
from collections.abc import AsyncIterable
from dataclasses import dataclass
from datetime import datetime, timezone
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
from typing import TYPE_CHECKING, Any, ClassVar, cast
from agent_framework import (
AgentResponse,
@@ -32,6 +33,11 @@ from ._events import (
from ._message_utils import normalize_messages_input
from ._typing_utils import is_type_compatible
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
if TYPE_CHECKING:
from ._workflow import Workflow
@@ -16,7 +16,7 @@ from ._const import WORKFLOW_RUN_KWARGS_KEY
from ._conversation_state import encode_chat_messages
from ._events import (
AgentRunEvent,
AgentRunUpdateEvent, # type: ignore[reportPrivateUsage]
AgentRunUpdateEvent,
)
from ._executor import Executor, handler
from ._message_utils import normalize_messages_input
@@ -24,9 +24,9 @@ from ._request_info_mixin import response_handler
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override
from typing_extensions import override # type: ignore # pragma: no cover
logger = logging.getLogger(__name__)
@@ -22,9 +22,9 @@ from ._orchestration_request_info import AgentApprovalExecutor
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override
from typing_extensions import override # type: ignore # pragma: no cover
logger = logging.getLogger(__name__)
@@ -52,9 +52,9 @@ from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override
from typing_extensions import override # type: ignore # pragma: no cover
logger = logging.getLogger(__name__)
@@ -55,9 +55,9 @@ from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override
from typing_extensions import override # type: ignore # pragma: no cover
logger = logging.getLogger(__name__)
@@ -39,15 +39,14 @@ from ._workflow import Workflow
from ._workflow_builder import WorkflowBuilder
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 11):
from typing import Self
else:
from typing_extensions import Self
if sys.version_info >= (3, 12):
from typing import override
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override
from typing_extensions import override # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self # type: ignore # pragma: no cover
else:
from typing_extensions import Self # type: ignore # pragma: no cover
logger = logging.getLogger(__name__)
@@ -2,11 +2,12 @@
import asyncio
import logging
import sys
import uuid
from copy import copy
from dataclasses import dataclass
from enum import Enum
from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable
from typing import Any, Protocol, TypeVar, runtime_checkable
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
@@ -15,6 +16,11 @@ from ._events import RequestInfoEvent, WorkflowEvent
from ._shared_state import SharedState
from ._typing_utils import is_instance_of
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
logger = logging.getLogger(__name__)
T = TypeVar("T")
@@ -8,9 +8,8 @@ from typing import Any
from typing_extensions import deprecated
from agent_framework import AgentThread
from .._agents import AgentProtocol
from .._threads import AgentThread
from ..observability import OtelAttr, capture_exception, create_workflow_span
from ._agent_executor import AgentExecutor
from ._checkpoint import CheckpointStorage
@@ -34,9 +33,9 @@ from ._validation import validate_workflow_graph
from ._workflow import Workflow
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
from typing import Self # type: ignore # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
from typing_extensions import Self # type: ignore # pragma: no cover
logger = logging.getLogger(__name__)
@@ -26,9 +26,9 @@ from ._workflow import WorkflowRunResult
from ._workflow_context import WorkflowContext
if sys.version_info >= (3, 12):
from typing import override
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override
from typing_extensions import override # type: ignore # pragma: no cover
logger = logging.getLogger(__name__)
@@ -19,8 +19,10 @@ if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
from typing import TypedDict
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = ["AzureOpenAIAssistantsClient"]
@@ -4,13 +4,13 @@ import json
import logging
import sys
from collections.abc import Mapping
from typing import Any, Generic, TypedDict
from typing import Any, Generic
from azure.core.credentials import TokenCredential
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from pydantic import ValidationError
from pydantic import BaseModel, ValidationError
from agent_framework import (
Annotation,
@@ -36,12 +36,18 @@ else:
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
from typing_extensions import override # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
logger: logging.Logger = logging.getLogger(__name__)
__all__ = ["AzureOpenAIChatClient", "AzureOpenAIChatOptions", "AzureUserSecurityContext"]
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
# region Azure OpenAI Chat Options TypedDict
@@ -68,7 +74,7 @@ class AzureUserSecurityContext(TypedDict, total=False):
"""The original client's IP address."""
class AzureOpenAIChatOptions(OpenAIChatOptions, total=False):
class AzureOpenAIChatOptions(OpenAIChatOptions[TResponseModel], Generic[TResponseModel], total=False):
"""Azure OpenAI-specific chat options dict.
Extends OpenAIChatOptions with Azure-specific options including
@@ -2,34 +2,38 @@
import sys
from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, Generic, TypedDict
from typing import TYPE_CHECKING, Any, Generic
from urllib.parse import urljoin
from azure.core.credentials import TokenCredential
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
from pydantic import ValidationError
from agent_framework import use_chat_middleware, use_function_invocation
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.observability import use_instrumentation
from agent_framework.openai._responses_client import OpenAIBaseResponsesClient
from .._middleware import use_chat_middleware
from .._tools import use_function_invocation
from ..exceptions import ServiceInitializationError
from ..observability import use_instrumentation
from ..openai._responses_client import OpenAIBaseResponsesClient
from ._shared import (
AzureOpenAIConfigMixin,
AzureOpenAISettings,
)
if TYPE_CHECKING:
from agent_framework.openai._responses_client import OpenAIResponsesOptions
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
if TYPE_CHECKING:
from ..openai._responses_client import OpenAIResponsesOptions
__all__ = ["AzureOpenAIResponsesClient"]
@@ -2,7 +2,7 @@
import sys
from collections.abc import Awaitable, Callable, MutableMapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, TypedDict, cast
from typing import TYPE_CHECKING, Any, Generic, cast
from openai import AsyncOpenAI
from openai.types.beta.assistant import Assistant
@@ -21,10 +21,13 @@ if TYPE_CHECKING:
from ._assistants_client import OpenAIAssistantsOptions
if sys.version_info >= (3, 13):
from typing import Self, TypeVar # pragma: no cover
from typing import TypeVar # type:ignore # pragma: no cover
else:
from typing_extensions import Self, TypeVar # pragma: no cover
from typing_extensions import TypeVar # type:ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self, TypedDict # type:ignore # pragma: no cover
else:
from typing_extensions import Self, TypedDict # type:ignore # pragma: no cover
__all__ = ["OpenAIAssistantProvider"]
@@ -9,12 +9,8 @@ from collections.abc import (
Mapping,
MutableMapping,
MutableSequence,
Sequence,
)
from typing import TYPE_CHECKING, Any, Generic, Literal, TypedDict, cast
if TYPE_CHECKING:
from .._agents import ChatAgent
from typing import Any, Generic, Literal, cast
from openai import AsyncOpenAI
from openai.types.beta.threads import (
@@ -29,17 +25,14 @@ from openai.types.beta.threads import (
from openai.types.beta.threads.run_create_params import AdditionalMessage
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
from openai.types.beta.threads.runs import RunStep
from pydantic import ValidationError
from pydantic import BaseModel, ValidationError
from .._clients import BaseChatClient
from .._memory import ContextProvider
from .._middleware import Middleware, use_chat_middleware
from .._threads import ChatMessageStoreProtocol
from .._middleware import use_chat_middleware
from .._tools import (
FunctionTool,
HostedCodeInterpreterTool,
HostedFileSearchTool,
ToolProtocol,
use_function_invocation,
)
from .._types import (
@@ -57,19 +50,19 @@ from ..observability import use_instrumentation
from ._shared import OpenAIConfigMixin, OpenAISettings
if sys.version_info >= (3, 13):
from typing import TypeVar
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
from typing_extensions import override # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
from typing import Self, TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
__all__ = [
@@ -81,6 +74,8 @@ __all__ = [
# region OpenAI Assistants Options TypedDict
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
class VectorStoreToolResource(TypedDict, total=False):
"""Vector store configuration for file search tool resources."""
@@ -109,7 +104,7 @@ class AssistantToolResources(TypedDict, total=False):
"""Resources for file search tool, including vector store IDs."""
class OpenAIAssistantsOptions(ChatOptions, total=False):
class OpenAIAssistantsOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
"""OpenAI Assistants API-specific options dict.
Extends base ChatOptions with Assistants API-specific parameters
@@ -765,59 +760,3 @@ class OpenAIAssistantsClient(
self.assistant_name = agent_name
if description and not self.assistant_description:
self.assistant_description = description
@override
def as_agent(
self,
*,
id: str | None = None,
name: str | None = None,
description: str | None = None,
instructions: str | None = None,
tools: ToolProtocol
| Callable[..., Any]
| MutableMapping[str, Any]
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
| None = None,
default_options: TOpenAIAssistantsOptions | None = None,
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
context_provider: ContextProvider | None = None,
middleware: Sequence[Middleware] | None = None,
**kwargs: Any,
) -> "ChatAgent[TOpenAIAssistantsOptions]":
"""Convert this chat client to a ChatAgent.
This method creates a ChatAgent instance with this client pre-configured.
It does NOT create an assistant on the OpenAI service - the actual assistant
will be created on the server during the first invocation (run).
For creating and managing persistent assistants on the server, use
:class:`~agent_framework.openai.OpenAIAssistantProvider` instead.
Keyword Args:
id: The unique identifier for the agent. Will be created automatically if not provided.
name: The name of the agent.
description: A brief description of the agent's purpose.
instructions: Optional instructions for the agent.
tools: The tools to use for the request.
default_options: A TypedDict containing chat options.
chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol.
context_provider: Context providers to include during agent invocation.
middleware: List of middleware to intercept agent and function invocations.
kwargs: Any additional keyword arguments.
Returns:
A ChatAgent instance configured with this chat client.
"""
return super().as_agent(
id=id,
name=name,
description=description,
instructions=instructions,
tools=tools,
default_options=default_options,
chat_message_store_factory=chat_message_store_factory,
context_provider=context_provider,
middleware=middleware,
**kwargs,
)
@@ -5,7 +5,7 @@ import sys
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence, Sequence
from datetime import datetime, timezone
from itertools import chain
from typing import Any, Generic, Literal, TypedDict
from typing import Any, Generic, Literal
from openai import AsyncOpenAI, BadRequestError
from openai.lib._parsing._completions import type_to_response_format_param
@@ -14,7 +14,7 @@ from openai.types.chat.chat_completion import ChatCompletion, Choice
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_message_custom_tool_call import ChatCompletionMessageCustomToolCall
from pydantic import ValidationError
from pydantic import BaseModel, ValidationError
from .._clients import BaseChatClient
from .._logging import get_logger
@@ -41,19 +41,24 @@ from ._exceptions import OpenAIContentFilterException
from ._shared import OpenAIBase, OpenAIConfigMixin, OpenAISettings
if sys.version_info >= (3, 13):
from typing import TypeVar
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
from typing_extensions import override # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = ["OpenAIChatClient", "OpenAIChatOptions"]
logger = get_logger("agent_framework.openai")
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
# region OpenAI Chat Options TypedDict
@@ -72,7 +77,7 @@ class Prediction(TypedDict, total=False):
content: str | list[PredictionTextContent]
class OpenAIChatOptions(ChatOptions, total=False):
class OpenAIChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
"""OpenAI-specific chat options dict.
Extends ChatOptions with options specific to OpenAI's Chat Completions API.
@@ -12,7 +12,7 @@ from collections.abc import (
)
from datetime import datetime, timezone
from itertools import chain
from typing import Any, Generic, Literal, TypedDict, cast
from typing import Any, Generic, Literal, cast
from openai import AsyncOpenAI, BadRequestError
from openai.types.responses.file_search_tool_param import FileSearchToolParam
@@ -79,9 +79,14 @@ if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
logger = get_logger("agent_framework.openai")
__all__ = ["OpenAIResponsesClient", "OpenAIResponsesOptions"]
@@ -108,7 +113,10 @@ class StreamOptions(TypedDict, total=False):
"""Whether to include usage statistics in stream events."""
class OpenAIResponsesOptions(ChatOptions, total=False):
TResponseFormat = TypeVar("TResponseFormat", bound=BaseModel | None, default=None)
class OpenAIResponsesOptions(ChatOptions[TResponseFormat], Generic[TResponseFormat], total=False):
"""OpenAI Responses API-specific chat options.
Extends ChatOptions with options specific to OpenAI's Responses API.
@@ -1,5 +1,5 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import Annotated, Any, Literal
from typing import Annotated, Any, Literal, get_args, get_origin
from unittest.mock import Mock
import pytest
@@ -1493,8 +1493,6 @@ async def test_tool_with_kwargs_injection():
def test_parse_annotation_with_literal_type():
"""Test that _parse_annotation returns Literal types unchanged (issue #2891)."""
from typing import get_args, get_origin
# Literal with string values
literal_annotation = Literal["Data", "Security", "Network"]
result = _parse_annotation(literal_annotation)
@@ -1505,7 +1503,6 @@ def test_parse_annotation_with_literal_type():
def test_parse_annotation_with_literal_int_type():
"""Test that _parse_annotation returns Literal int types unchanged."""
from typing import get_args, get_origin
literal_annotation = Literal[1, 2, 3]
result = _parse_annotation(literal_annotation)
@@ -1516,7 +1513,6 @@ def test_parse_annotation_with_literal_int_type():
def test_parse_annotation_with_literal_bool_type():
"""Test that _parse_annotation returns Literal bool types unchanged."""
from typing import get_args, get_origin
literal_annotation = Literal[True, False]
result = _parse_annotation(literal_annotation)
@@ -1535,7 +1531,6 @@ def test_parse_annotation_with_simple_types():
def test_parse_annotation_with_annotated_and_literal():
"""Test that Annotated[Literal[...], description] works correctly."""
from typing import get_args, get_origin
# When Literal is inside Annotated, it should still be preserved
annotated_literal = Annotated[Literal["A", "B", "C"], "The category"]
+2 -16
View File
@@ -3,10 +3,10 @@
import base64
from collections.abc import AsyncIterable
from datetime import datetime, timezone
from typing import Any
from typing import Any, Literal
import pytest
from pydantic import BaseModel
from pydantic import BaseModel, Field, ValidationError
from pytest import fixture, mark, raises
from agent_framework import (
@@ -665,9 +665,6 @@ def test_chat_response_with_format_init():
def test_chat_response_value_raises_on_invalid_schema():
"""Test that value property raises ValidationError with field constraint details."""
from typing import Literal
from pydantic import Field, ValidationError
class StrictSchema(BaseModel):
id: Literal[5]
@@ -689,9 +686,6 @@ def test_chat_response_value_raises_on_invalid_schema():
def test_chat_response_try_parse_value_returns_none_on_invalid():
"""Test that try_parse_value returns None on validation failure with Field constraints."""
from typing import Literal
from pydantic import Field
class StrictSchema(BaseModel):
id: Literal[5]
@@ -707,7 +701,6 @@ def test_chat_response_try_parse_value_returns_none_on_invalid():
def test_chat_response_try_parse_value_returns_value_on_success():
"""Test that try_parse_value returns parsed value when all constraints pass."""
from pydantic import Field
class MySchema(BaseModel):
name: str = Field(min_length=3)
@@ -724,9 +717,6 @@ def test_chat_response_try_parse_value_returns_value_on_success():
def test_agent_response_value_raises_on_invalid_schema():
"""Test that AgentResponse.value property raises ValidationError with field constraint details."""
from typing import Literal
from pydantic import Field, ValidationError
class StrictSchema(BaseModel):
id: Literal[5]
@@ -748,9 +738,6 @@ def test_agent_response_value_raises_on_invalid_schema():
def test_agent_response_try_parse_value_returns_none_on_invalid():
"""Test that AgentResponse.try_parse_value returns None on Field constraint failure."""
from typing import Literal
from pydantic import Field
class StrictSchema(BaseModel):
id: Literal[5]
@@ -766,7 +753,6 @@ def test_agent_response_try_parse_value_returns_none_on_invalid():
def test_agent_response_try_parse_value_returns_value_on_success():
"""Test that AgentResponse.try_parse_value returns parsed value when all constraints pass."""
from pydantic import Field
class MySchema(BaseModel):
name: str = Field(min_length=3)
@@ -1,6 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
import pytest
from typing_extensions import Never
from agent_framework import (
ChatMessage,
@@ -187,7 +188,6 @@ async def test_executor_completed_event_contains_sent_messages():
async def test_executor_completed_event_includes_yielded_outputs():
"""Test that ExecutorCompletedEvent.data includes yielded outputs."""
from typing_extensions import Never
from agent_framework import WorkflowOutputEvent
@@ -318,7 +318,6 @@ def test_executor_output_types_property():
def test_executor_workflow_output_types_property():
"""Test that the workflow_output_types property correctly identifies workflow output types."""
from typing_extensions import Never
# Test executor with no workflow output types
class NoWorkflowOutputExecutor(Executor):
@@ -42,9 +42,9 @@ from agent_framework import (
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
if sys.version_info >= (3, 12):
from typing import override
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override
from typing_extensions import override # type: ignore # pragma: no cover
def test_magentic_context_reset_behavior():
@@ -1,8 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
import sys
from collections.abc import Callable, Mapping
from pathlib import Path
from typing import Any, Literal, TypedDict, cast
from typing import Any, Literal, cast
import yaml
from agent_framework import (
@@ -42,6 +43,11 @@ from ._models import (
agent_schema_dispatch,
)
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
class ProviderTypeMapping(TypedDict, total=True):
package: str
@@ -24,10 +24,11 @@ See: dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/PowerFx/
"""
import logging
import sys
from collections.abc import Mapping
from dataclasses import dataclass
from decimal import Decimal as _Decimal
from typing import Any, Literal, TypedDict, cast
from typing import Any, Literal, cast
from agent_framework._workflows import (
Executor,
@@ -36,6 +37,12 @@ from agent_framework._workflows import (
)
from powerfx import Engine
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
logger = logging.getLogger(__name__)
@@ -7,7 +7,7 @@ import json
import logging
from dataclasses import fields, is_dataclass
from types import UnionType
from typing import Any, Union, get_args, get_origin
from typing import Any, Union, get_args, get_origin, get_type_hints
from agent_framework import ChatMessage
@@ -270,8 +270,6 @@ def generate_schema_from_serialization_mixin(cls: type[Any]) -> dict[str, Any]:
# Get type hints
try:
from typing import get_type_hints
type_hints = get_type_hints(cls)
except Exception:
type_hints = {}
@@ -348,8 +346,6 @@ def extract_response_type_from_executor(executor: Any, request_type: type) -> ty
The response type class, or None if not found
"""
try:
from typing import get_type_hints
# Introspect handler methods for @response_handler pattern
for attr_name in dir(executor):
if attr_name.startswith("_"):
@@ -1,7 +1,7 @@
# Copyright (c) Microsoft. All rights reserved.
import sys
from typing import Any, ClassVar, Generic, TypedDict
from typing import Any, ClassVar, Generic
from agent_framework import ChatOptions, use_chat_middleware, use_function_invocation
from agent_framework._pydantic import AFBaseSettings
@@ -11,12 +11,16 @@ from agent_framework.openai._chat_client import OpenAIBaseChatClient
from foundry_local import FoundryLocalManager
from foundry_local.models import DeviceType
from openai import AsyncOpenAI
from pydantic import BaseModel
if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = [
"FoundryLocalChatOptions",
@@ -24,11 +28,13 @@ __all__ = [
"FoundryLocalSettings",
]
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
# region Foundry Local Chat Options TypedDict
class FoundryLocalChatOptions(ChatOptions, total=False):
class FoundryLocalChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
"""Azure Foundry Local (local model deployment) chat options dict.
Extends base ChatOptions for local model inference via Foundry Local.
@@ -3,7 +3,7 @@
import sys
from collections.abc import MutableSequence, Sequence
from contextlib import AbstractAsyncContextManager
from typing import Any, TypedDict
from typing import Any
from agent_framework import ChatMessage, Context, ContextProvider
from agent_framework.exceptions import ServiceInitializationError
@@ -15,9 +15,9 @@ else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
if sys.version_info >= (3, 11):
from typing import NotRequired, Self # pragma: no cover
from typing import NotRequired, Self, TypedDict # pragma: no cover
else:
from typing_extensions import NotRequired, Self # pragma: no cover
from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover
# Type aliases for Mem0 search response formats (v1.1 and v2; v1 is deprecated, but matches the type definition for v2)
@@ -11,7 +11,7 @@ from collections.abc import (
Sequence,
)
from itertools import chain
from typing import Any, ClassVar, Generic, TypedDict
from typing import Any, ClassVar, Generic
from agent_framework import (
BaseChatClient,
@@ -40,26 +40,32 @@ from ollama import AsyncClient
# Rename imported types to avoid naming conflicts with Agent Framework types
from ollama._types import ChatResponse as OllamaChatResponse
from ollama._types import Message as OllamaMessage
from pydantic import ValidationError
from pydantic import BaseModel, ValidationError
if sys.version_info >= (3, 13):
from typing import TypeVar
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar
from typing_extensions import TypeVar # type: ignore # pragma: no cover
if sys.version_info >= (3, 12):
from typing import override # type: ignore # pragma: no cover
else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
from typing_extensions import override # type: ignore # pragma: no cover
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = ["OllamaChatClient", "OllamaChatOptions"]
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
# region Ollama Chat Options TypedDict
class OllamaChatOptions(ChatOptions, total=False):
class OllamaChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
"""Ollama-specific chat options dict.
Extends base ChatOptions with Ollama-specific parameters.
@@ -37,9 +37,9 @@ async def non_streaming_example() -> None:
# Get structured response from the agent using response_format parameter
result = await agent.run(query, options={"response_format": OutputStruct})
# Access the structured output using try_parse_value for safe parsing
if structured_data := result.try_parse_value(OutputStruct):
print("Structured Output Agent (from result.try_parse_value):")
# Access the structured output using the parsed value
if structured_data := result.value:
print("Structured Output Agent:")
print(f"City: {structured_data.city}")
print(f"Description: {structured_data.description}")
else:
@@ -66,8 +66,8 @@ async def streaming_example() -> None:
output_format_type=OutputStruct,
)
# Access the structured output using try_parse_value for safe parsing
if structured_data := result.try_parse_value(OutputStruct):
# Access the structured output using the parsed value
if structured_data := result.value:
print("Structured Output (from streaming with AgentResponse.from_agent_response_generator):")
print(f"City: {structured_data.city}")
print(f"Description: {structured_data.description}")