mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
Python: added generic types to ChatOptions and ChatResponse/AgentResponse for Response Format (#3305)
* added generic types to ChatOptions and ChatResponse/AgentResponse for response format * fix typevar import * fix for older python versions * fix missing import * fixed imports * fixed mypy * mypy fix
This commit is contained in:
committed by
GitHub
Unverified
parent
1f8463f9bb
commit
1226828ec2
@@ -28,25 +28,21 @@ from ._http_service import AGUIHttpService
|
||||
from ._message_adapters import agent_framework_messages_to_agui
|
||||
from ._utils import convert_tools_to_agui_format
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._types import AGUIChatOptions
|
||||
|
||||
from typing import TypedDict
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
from typing import Self, TypedDict # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
from typing_extensions import Self, TypedDict # pragma: no cover
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._types import AGUIChatOptions
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -85,7 +81,7 @@ def _apply_server_function_call_unwrap(chat_client: TBaseChatClient) -> TBaseCha
|
||||
|
||||
@wraps(original_get_response)
|
||||
async def response_wrapper(self: Any, *args: Any, **kwargs: Any) -> ChatResponse:
|
||||
response = await original_get_response(self, *args, **kwargs)
|
||||
response: ChatResponse[Any] = await original_get_response(self, *args, **kwargs) # type: ignore[var-annotated]
|
||||
if response.messages:
|
||||
for message in response.messages:
|
||||
_unwrap_server_function_call_contents(cast(MutableSequence[Content | dict[str, Any]], message.contents))
|
||||
|
||||
@@ -3,15 +3,23 @@
|
||||
"""Type definitions for AG-UI integration."""
|
||||
|
||||
import sys
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any, Generic
|
||||
|
||||
from agent_framework import ChatOptions
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
TAGUIChatOptions = TypeVar("TAGUIChatOptions", bound=TypedDict, default="AGUIChatOptions", covariant=True) # type: ignore[valid-type]
|
||||
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
class PredictStateConfig(TypedDict):
|
||||
@@ -76,7 +84,7 @@ class AGUIRequest(BaseModel):
|
||||
# region AG-UI Chat Options TypedDict
|
||||
|
||||
|
||||
class AGUIChatOptions(ChatOptions, total=False):
|
||||
class AGUIChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
|
||||
"""AG-UI protocol-specific chat options dict.
|
||||
|
||||
Extends base ChatOptions for the AG-UI (Agent-UI) protocol.
|
||||
@@ -140,7 +148,5 @@ class AGUIChatOptions(ChatOptions, total=False):
|
||||
AGUI_OPTION_TRANSLATIONS: dict[str, str] = {}
|
||||
"""Maps ChatOptions keys to AG-UI parameter names (protocol uses standard names)."""
|
||||
|
||||
TAGUIChatOptions = TypeVar("TAGUIChatOptions", bound=TypedDict, default="AGUIChatOptions", covariant=True) # type: ignore[valid-type]
|
||||
|
||||
|
||||
# endregion
|
||||
|
||||
@@ -12,6 +12,10 @@ if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework import ChatOptions
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence
|
||||
from typing import Any, ClassVar, Final, Generic, Literal, TypedDict
|
||||
from typing import Any, ClassVar, Final, Generic, Literal
|
||||
|
||||
from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
@@ -47,15 +47,18 @@ from anthropic.types.beta.beta_code_execution_tool_result_error import (
|
||||
)
|
||||
from pydantic import BaseModel, SecretStr, ValidationError
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = [
|
||||
"AnthropicChatOptions",
|
||||
@@ -69,6 +72,8 @@ ANTHROPIC_DEFAULT_MAX_TOKENS: Final[int] = 1024
|
||||
BETA_FLAGS: Final[list[str]] = ["mcp-client-2025-04-04", "code-execution-2025-08-25"]
|
||||
STRUCTURED_OUTPUTS_BETA_FLAG: Final[str] = "structured-outputs-2025-11-13"
|
||||
|
||||
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
# region Anthropic Chat Options TypedDict
|
||||
|
||||
@@ -91,7 +96,7 @@ class ThinkingConfig(TypedDict, total=False):
|
||||
budget_tokens: int
|
||||
|
||||
|
||||
class AnthropicChatOptions(ChatOptions, total=False):
|
||||
class AnthropicChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
|
||||
"""Anthropic-specific chat options.
|
||||
|
||||
Extends ChatOptions with options specific to Anthropic's Messages API.
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import sys
|
||||
from collections.abc import Callable, MutableMapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypedDict, cast
|
||||
from typing import TYPE_CHECKING, Any, Generic, cast
|
||||
|
||||
from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
@@ -27,9 +27,13 @@ if TYPE_CHECKING:
|
||||
from ._chat_client import AzureAIAgentOptions
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import Self, TypeVar # pragma: no cover
|
||||
from typing import Self, TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self, TypeVar # pragma: no cover
|
||||
from typing_extensions import Self, TypeVar # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
# Type variable for options - allows typed ChatAgent[TOptions] returns
|
||||
|
||||
@@ -6,7 +6,7 @@ import os
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Callable, Mapping, MutableMapping, MutableSequence, Sequence
|
||||
from typing import Any, ClassVar, Generic, TypedDict
|
||||
from typing import Any, ClassVar, Generic
|
||||
|
||||
from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
@@ -96,9 +96,9 @@ if sys.version_info >= (3, 12):
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
from typing import Self, TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = get_logger("agent_framework.azure")
|
||||
@@ -1265,7 +1265,7 @@ class AzureAIAgentClient(BaseChatClient[TAzureAIAgentOptions], Generic[TAzureAIA
|
||||
| MutableMapping[str, Any]
|
||||
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
default_options: TAzureAIAgentOptions | None = None,
|
||||
default_options: TAzureAIAgentOptions | Mapping[str, Any] | None = None,
|
||||
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
|
||||
context_provider: ContextProvider | None = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import sys
|
||||
from collections.abc import Callable, Mapping, MutableMapping, MutableSequence, Sequence
|
||||
from typing import Any, ClassVar, Generic, TypedDict, TypeVar, cast
|
||||
from typing import Any, ClassVar, Generic, TypeVar, cast
|
||||
|
||||
from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
@@ -38,9 +38,9 @@ if sys.version_info >= (3, 12):
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
from typing import Self, TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = get_logger("agent_framework.azure")
|
||||
@@ -551,7 +551,7 @@ class AzureAIClient(OpenAIBaseResponsesClient[TAzureAIClientOptions], Generic[TA
|
||||
| MutableMapping[str, Any]
|
||||
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
default_options: TAzureAIClientOptions | None = None,
|
||||
default_options: TAzureAIClientOptions | Mapping[str, Any] | None = None,
|
||||
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
|
||||
context_provider: ContextProvider | None = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import sys
|
||||
from collections.abc import Callable, MutableMapping, Sequence
|
||||
from typing import Any, Generic, TypedDict
|
||||
from typing import Any, Generic
|
||||
|
||||
from agent_framework import (
|
||||
AGENT_FRAMEWORK_USER_AGENT,
|
||||
@@ -33,9 +33,13 @@ from ._client import AzureAIClient, AzureAIProjectAgentOptions
|
||||
from ._shared import AzureAISettings, create_text_format_config, from_azure_ai_tools, to_azure_ai_tools
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import Self, TypeVar # pragma: no cover
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self, TypeVar # pragma: no cover
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self, TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = get_logger("agent_framework.azure")
|
||||
|
||||
@@ -5,7 +5,7 @@ import json
|
||||
import sys
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterable, MutableMapping, MutableSequence, Sequence
|
||||
from typing import Any, ClassVar, Generic, Literal, TypedDict
|
||||
from typing import Any, ClassVar, Generic, Literal
|
||||
from uuid import uuid4
|
||||
|
||||
from agent_framework import (
|
||||
@@ -33,17 +33,20 @@ from agent_framework.observability import use_instrumentation
|
||||
from boto3.session import Session as Boto3Session
|
||||
from botocore.client import BaseClient
|
||||
from botocore.config import Config as BotoConfig
|
||||
from pydantic import SecretStr, ValidationError
|
||||
from pydantic import BaseModel, SecretStr, ValidationError
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
logger = get_logger("agent_framework.bedrock")
|
||||
|
||||
@@ -55,6 +58,8 @@ __all__ = [
|
||||
"BedrockSettings",
|
||||
]
|
||||
|
||||
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
# region Bedrock Chat Options TypedDict
|
||||
|
||||
@@ -82,7 +87,7 @@ class BedrockGuardrailConfig(TypedDict, total=False):
|
||||
"""How to process guardrails during streaming (sync blocks, async does not)."""
|
||||
|
||||
|
||||
class BedrockChatOptions(ChatOptions, total=False):
|
||||
class BedrockChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
|
||||
"""Amazon Bedrock Converse API-specific chat options dict.
|
||||
|
||||
Extends base ChatOptions with Bedrock-specific parameters.
|
||||
|
||||
@@ -30,9 +30,9 @@ from chatkit.types import (
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import assert_never
|
||||
from typing import assert_never # type:ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import assert_never
|
||||
from typing_extensions import assert_never # type:ignore # pragma: no cover
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import inspect
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence
|
||||
from contextlib import AbstractAsyncContextManager, AsyncExitStack
|
||||
from copy import deepcopy
|
||||
from itertools import chain
|
||||
@@ -13,8 +13,8 @@ from typing import (
|
||||
ClassVar,
|
||||
Generic,
|
||||
Protocol,
|
||||
TypedDict,
|
||||
cast,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
from uuid import uuid4
|
||||
@@ -43,24 +43,25 @@ from ._types import (
|
||||
from .exceptions import AgentExecutionException, AgentInitializationError
|
||||
from .observability import use_agent_instrumentation
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._types import ChatOptions
|
||||
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
from typing import Self, TypedDict # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
from typing_extensions import Self, TypedDict # pragma: no cover
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._types import ChatOptions
|
||||
|
||||
|
||||
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None, covariant=True)
|
||||
TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel)
|
||||
|
||||
|
||||
logger = get_logger("agent_framework")
|
||||
@@ -622,6 +623,7 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
|
||||
provider-specific options including temperature, max_tokens, model_id,
|
||||
tool_choice, and provider-specific options like reasoning_effort.
|
||||
You can also create your own TypedDict for custom chat clients.
|
||||
Note: response_format typing does not flow into run outputs when set via default_options.
|
||||
These can be overridden at runtime via the ``options`` parameter of ``run()`` and ``run_stream()``.
|
||||
tools: The tools to use for the request.
|
||||
kwargs: Any additional keyword arguments. Will be stored as ``additional_properties``.
|
||||
@@ -657,6 +659,14 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
|
||||
|
||||
# Get tools from options or named parameter (named param takes precedence)
|
||||
tools_ = tools if tools is not None else opts.pop("tools", None)
|
||||
tools_ = cast(
|
||||
ToolProtocol
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None,
|
||||
tools_,
|
||||
)
|
||||
|
||||
# Handle instructions - named parameter takes precedence over options
|
||||
instructions_ = instructions if instructions is not None else opts.pop("instructions", None)
|
||||
@@ -742,6 +752,7 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
|
||||
): # type: ignore[reportAttributeAccessIssue, attr-defined]
|
||||
self.chat_client._update_agent_name_and_description(self.name, self.description) # type: ignore[reportAttributeAccessIssue, attr-defined]
|
||||
|
||||
@overload
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
|
||||
@@ -752,9 +763,38 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
|
||||
| MutableMapping[str, Any]
|
||||
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
options: TOptions_co | None = None,
|
||||
options: "ChatOptions[TResponseModelT]",
|
||||
**kwargs: Any,
|
||||
) -> AgentResponse:
|
||||
) -> AgentResponse[TResponseModelT]: ...
|
||||
|
||||
@overload
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
tools: ToolProtocol
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
options: TOptions_co | Mapping[str, Any] | "ChatOptions[Any]" | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentResponse[Any]: ...
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[str | ChatMessage] | None = None,
|
||||
*,
|
||||
thread: AgentThread | None = None,
|
||||
tools: ToolProtocol
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
options: TOptions_co | Mapping[str, Any] | "ChatOptions[Any]" | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AgentResponse[Any]:
|
||||
"""Run the agent with the given messages and options.
|
||||
|
||||
Note:
|
||||
@@ -784,6 +824,14 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
|
||||
|
||||
# Get tools from options or named parameter (named param takes precedence)
|
||||
tools_ = tools if tools is not None else opts.pop("tools", None)
|
||||
tools_ = cast(
|
||||
ToolProtocol
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None,
|
||||
tools_,
|
||||
)
|
||||
|
||||
input_messages = normalize_messages(messages)
|
||||
thread, run_chat_options, thread_messages = await self._prepare_thread_and_messages(
|
||||
@@ -860,12 +908,19 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
|
||||
response.messages,
|
||||
**{k: v for k, v in kwargs.items() if k != "thread"},
|
||||
)
|
||||
response_format = co.get("response_format")
|
||||
if not (
|
||||
response_format is not None and isinstance(response_format, type) and issubclass(response_format, BaseModel)
|
||||
):
|
||||
response_format = None
|
||||
|
||||
return AgentResponse(
|
||||
messages=response.messages,
|
||||
response_id=response.response_id,
|
||||
created_at=response.created_at,
|
||||
usage_details=response.usage_details,
|
||||
value=response.value,
|
||||
response_format=response_format,
|
||||
raw_representation=response,
|
||||
additional_properties=response.additional_properties,
|
||||
)
|
||||
@@ -880,7 +935,7 @@ class ChatAgent(BaseAgent, Generic[TOptions_co]): # type: ignore[misc]
|
||||
| MutableMapping[str, Any]
|
||||
| list[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
options: TOptions_co | None = None,
|
||||
options: TOptions_co | Mapping[str, Any] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> AsyncIterable[AgentResponseUpdate]:
|
||||
"""Stream the agent with the given messages and options.
|
||||
|
||||
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
Callable,
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
MutableSequence,
|
||||
Sequence,
|
||||
@@ -17,9 +18,13 @@ from typing import (
|
||||
Generic,
|
||||
Protocol,
|
||||
TypedDict,
|
||||
cast,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._logging import get_logger
|
||||
from ._memory import ContextProvider
|
||||
from ._middleware import (
|
||||
@@ -45,9 +50,10 @@ from ._types import (
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._agents import ChatAgent
|
||||
@@ -120,6 +126,16 @@ class ChatClientProtocol(Protocol[TOptions_contra]): #
|
||||
|
||||
additional_properties: dict[str, Any]
|
||||
|
||||
@overload
|
||||
async def get_response(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[str | ChatMessage],
|
||||
*,
|
||||
options: "ChatOptions[TResponseModelT]",
|
||||
**kwargs: Any,
|
||||
) -> "ChatResponse[TResponseModelT]": ...
|
||||
|
||||
@overload
|
||||
async def get_response(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[str | ChatMessage],
|
||||
@@ -175,6 +191,9 @@ TOptions_co = TypeVar(
|
||||
covariant=True,
|
||||
)
|
||||
|
||||
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None, covariant=True)
|
||||
TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel)
|
||||
|
||||
|
||||
class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]):
|
||||
"""Base class for chat clients.
|
||||
@@ -319,13 +338,31 @@ class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]):
|
||||
|
||||
# region Public method
|
||||
|
||||
@overload
|
||||
async def get_response(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[str | ChatMessage],
|
||||
*,
|
||||
options: "ChatOptions[TResponseModelT]",
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse[TResponseModelT]: ...
|
||||
|
||||
@overload
|
||||
async def get_response(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[str | ChatMessage],
|
||||
*,
|
||||
options: TOptions_co | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse:
|
||||
) -> ChatResponse: ...
|
||||
|
||||
async def get_response(
|
||||
self,
|
||||
messages: str | ChatMessage | Sequence[str | ChatMessage],
|
||||
*,
|
||||
options: TOptions_co | "ChatOptions[Any]" | None = None,
|
||||
**kwargs: Any,
|
||||
) -> ChatResponse[Any]:
|
||||
"""Get a response from a chat client.
|
||||
|
||||
Args:
|
||||
@@ -389,7 +426,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]):
|
||||
| MutableMapping[str, Any]
|
||||
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
default_options: TOptions_co | None = None,
|
||||
default_options: TOptions_co | Mapping[str, Any] | None = None,
|
||||
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
|
||||
context_provider: ContextProvider | None = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
@@ -410,6 +447,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]):
|
||||
default_options: A TypedDict containing chat options. When using a typed client like
|
||||
``OpenAIChatClient``, this enables IDE autocomplete for provider-specific options
|
||||
including temperature, max_tokens, model_id, tool_choice, and more.
|
||||
Note: response_format typing does not flow into run outputs when set via default_options,
|
||||
and dict literals are accepted without specialized option typing.
|
||||
chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol.
|
||||
If not provided, the default in-memory store will be used.
|
||||
context_provider: Context providers to include during agent invocation.
|
||||
@@ -446,7 +485,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[TOptions_co]):
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
tools=tools,
|
||||
default_options=default_options,
|
||||
default_options=cast(Any, default_options),
|
||||
chat_message_store_factory=chat_message_store_factory,
|
||||
context_provider=context_provider,
|
||||
middleware=middleware,
|
||||
|
||||
@@ -1,11 +1,12 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import inspect
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableSequence, Sequence
|
||||
from enum import Enum
|
||||
from functools import update_wrapper
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypedDict, TypeVar
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Generic, TypeAlias, TypeVar
|
||||
|
||||
from ._serialization import SerializationMixin
|
||||
from ._types import AgentResponse, AgentResponseUpdate, ChatMessage, normalize_messages, prepare_messages
|
||||
@@ -20,6 +21,10 @@ if TYPE_CHECKING:
|
||||
from ._tools import FunctionTool
|
||||
from ._types import ChatResponse, ChatResponseUpdate
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = [
|
||||
"AgentMiddleware",
|
||||
|
||||
@@ -24,12 +24,11 @@ from typing import (
|
||||
Generic,
|
||||
Literal,
|
||||
Protocol,
|
||||
TypedDict,
|
||||
TypeVar,
|
||||
Union,
|
||||
cast,
|
||||
get_args,
|
||||
get_origin,
|
||||
overload,
|
||||
runtime_checkable,
|
||||
)
|
||||
|
||||
@@ -42,7 +41,7 @@ from .exceptions import ChatClientInitializationError, ToolException
|
||||
from .observability import (
|
||||
OPERATION_DURATION_BUCKET_BOUNDARIES,
|
||||
OtelAttr,
|
||||
capture_exception, # type: ignore
|
||||
capture_exception,
|
||||
get_function_span,
|
||||
get_function_span_attributes,
|
||||
get_meter,
|
||||
@@ -57,20 +56,21 @@ if TYPE_CHECKING:
|
||||
Content,
|
||||
)
|
||||
|
||||
from typing import overload
|
||||
|
||||
# TypeVar with defaults support for Python < 3.13
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar as TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar as TypeVar # type: ignore[import] # pragma: no cover
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
|
||||
# TypeVar with defaults support for Python < 3.13
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar as TypeVarWithDefaults # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import (
|
||||
TypeVar as TypeVarWithDefaults, # type: ignore[import] # pragma: no cover
|
||||
)
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
@@ -97,8 +97,8 @@ DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
|
||||
TChatClient = TypeVar("TChatClient", bound="ChatClientProtocol[Any]")
|
||||
# region Helpers
|
||||
|
||||
ArgsT = TypeVarWithDefaults("ArgsT", bound=BaseModel, default=BaseModel)
|
||||
ReturnT = TypeVarWithDefaults("ReturnT", default=Any)
|
||||
ArgsT = TypeVar("ArgsT", bound=BaseModel, default=BaseModel)
|
||||
ReturnT = TypeVar("ReturnT", default=Any)
|
||||
|
||||
|
||||
def _parse_inputs(
|
||||
|
||||
@@ -1,8 +1,8 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import base64
|
||||
import json
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
Callable,
|
||||
@@ -12,7 +12,7 @@ from collections.abc import (
|
||||
Sequence,
|
||||
)
|
||||
from copy import deepcopy
|
||||
from typing import Any, ClassVar, Final, Literal, TypedDict, TypeVar, overload
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, cast, overload
|
||||
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
@@ -21,6 +21,15 @@ from ._serialization import SerializationMixin
|
||||
from ._tools import ToolProtocol, tool
|
||||
from .exceptions import AdditionItemMismatch, ContentError
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = [
|
||||
"AgentResponse",
|
||||
"AgentResponseUpdate",
|
||||
@@ -312,6 +321,8 @@ TEmbedding = TypeVar("TEmbedding")
|
||||
TChatResponse = TypeVar("TChatResponse", bound="ChatResponse")
|
||||
TToolMode = TypeVar("TToolMode", bound="ToolMode")
|
||||
TAgentRunResponse = TypeVar("TAgentRunResponse", bound="AgentResponse")
|
||||
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None, covariant=True)
|
||||
TResponseModelT = TypeVar("TResponseModelT", bound=BaseModel)
|
||||
|
||||
CreatedAtT = str # Use a datetimeoffset type? Or a more specific type like datetime.datetime?
|
||||
|
||||
@@ -1911,7 +1922,7 @@ def _finalize_response(response: "ChatResponse | AgentResponse") -> None:
|
||||
_coalesce_text_content(msg.contents, "text_reasoning")
|
||||
|
||||
|
||||
class ChatResponse(SerializationMixin):
|
||||
class ChatResponse(SerializationMixin, Generic[TResponseModel]):
|
||||
"""Represents the response to a chat request.
|
||||
|
||||
Attributes:
|
||||
@@ -1974,7 +1985,7 @@ class ChatResponse(SerializationMixin):
|
||||
created_at: CreatedAtT | None = None,
|
||||
finish_reason: FinishReason | None = None,
|
||||
usage_details: UsageDetails | None = None,
|
||||
value: Any | None = None,
|
||||
value: TResponseModel | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
raw_representation: Any | None = None,
|
||||
@@ -2009,7 +2020,7 @@ class ChatResponse(SerializationMixin):
|
||||
created_at: CreatedAtT | None = None,
|
||||
finish_reason: FinishReason | None = None,
|
||||
usage_details: UsageDetails | None = None,
|
||||
value: Any | None = None,
|
||||
value: TResponseModel | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
raw_representation: Any | None = None,
|
||||
@@ -2044,7 +2055,7 @@ class ChatResponse(SerializationMixin):
|
||||
created_at: CreatedAtT | None = None,
|
||||
finish_reason: FinishReason | dict[str, Any] | None = None,
|
||||
usage_details: UsageDetails | dict[str, Any] | None = None,
|
||||
value: Any | None = None,
|
||||
value: TResponseModel | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
raw_representation: Any | None = None,
|
||||
@@ -2101,13 +2112,31 @@ class ChatResponse(SerializationMixin):
|
||||
self.created_at = created_at
|
||||
self.finish_reason = finish_reason
|
||||
self.usage_details = usage_details
|
||||
self._value: Any | None = value
|
||||
self._value: TResponseModel | None = value
|
||||
self._response_format: type[BaseModel] | None = response_format
|
||||
self._value_parsed: bool = value is not None
|
||||
self.additional_properties = additional_properties or {}
|
||||
self.additional_properties.update(kwargs or {})
|
||||
self.raw_representation: Any | list[Any] | None = raw_representation
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_chat_response_updates(
|
||||
cls: type["ChatResponse[Any]"],
|
||||
updates: Sequence["ChatResponseUpdate"],
|
||||
*,
|
||||
output_format_type: type[TResponseModelT],
|
||||
) -> "ChatResponse[TResponseModelT]": ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_chat_response_updates(
|
||||
cls: type["ChatResponse[Any]"],
|
||||
updates: Sequence["ChatResponseUpdate"],
|
||||
*,
|
||||
output_format_type: None = None,
|
||||
) -> "ChatResponse[Any]": ...
|
||||
|
||||
@classmethod
|
||||
def from_chat_response_updates(
|
||||
cls: type[TChatResponse],
|
||||
@@ -2146,12 +2175,30 @@ class ChatResponse(SerializationMixin):
|
||||
msg.try_parse_value(output_format_type)
|
||||
return msg
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def from_chat_response_generator(
|
||||
cls: type["ChatResponse[Any]"],
|
||||
updates: AsyncIterable["ChatResponseUpdate"],
|
||||
*,
|
||||
output_format_type: type[TResponseModelT],
|
||||
) -> "ChatResponse[TResponseModelT]": ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def from_chat_response_generator(
|
||||
cls: type["ChatResponse[Any]"],
|
||||
updates: AsyncIterable["ChatResponseUpdate"],
|
||||
*,
|
||||
output_format_type: None = None,
|
||||
) -> "ChatResponse[Any]": ...
|
||||
|
||||
@classmethod
|
||||
async def from_chat_response_generator(
|
||||
cls: type[TChatResponse],
|
||||
updates: AsyncIterable["ChatResponseUpdate"],
|
||||
*,
|
||||
output_format_type: type[BaseModel] | Mapping[str, Any] | None = None,
|
||||
output_format_type: type[BaseModel] | None = None,
|
||||
) -> TChatResponse:
|
||||
"""Joins multiple updates into a single ChatResponse.
|
||||
|
||||
@@ -2187,7 +2234,7 @@ class ChatResponse(SerializationMixin):
|
||||
return ("\n".join(message.text for message in self.messages if isinstance(message, ChatMessage))).strip()
|
||||
|
||||
@property
|
||||
def value(self) -> Any | None:
|
||||
def value(self) -> TResponseModel | None:
|
||||
"""Get the parsed structured output value.
|
||||
|
||||
If a response_format was provided and parsing hasn't been attempted yet,
|
||||
@@ -2203,14 +2250,20 @@ class ChatResponse(SerializationMixin):
|
||||
and isinstance(self._response_format, type)
|
||||
and issubclass(self._response_format, BaseModel)
|
||||
):
|
||||
self._value = self._response_format.model_validate_json(self.text)
|
||||
self._value = cast(TResponseModel, self._response_format.model_validate_json(self.text))
|
||||
self._value_parsed = True
|
||||
return self._value
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.text
|
||||
|
||||
def try_parse_value(self, output_format_type: type[_T] | None = None) -> _T | None:
|
||||
@overload
|
||||
def try_parse_value(self, output_format_type: type[TResponseModelT]) -> TResponseModelT | None: ...
|
||||
|
||||
@overload
|
||||
def try_parse_value(self, output_format_type: None = None) -> TResponseModel | None: ...
|
||||
|
||||
def try_parse_value(self, output_format_type: type[BaseModel] | None = None) -> BaseModel | None:
|
||||
"""Try to parse the text into a typed value.
|
||||
|
||||
This is the safe alternative to accessing the value property directly.
|
||||
@@ -2238,7 +2291,7 @@ class ChatResponse(SerializationMixin):
|
||||
try:
|
||||
parsed_value = format_type.model_validate_json(self.text) # type: ignore[reportUnknownMemberType]
|
||||
if use_cache:
|
||||
self._value = parsed_value
|
||||
self._value = cast(TResponseModel, parsed_value)
|
||||
self._value_parsed = True
|
||||
return parsed_value # type: ignore[return-value]
|
||||
except ValidationError as ex:
|
||||
@@ -2376,7 +2429,7 @@ class ChatResponseUpdate(SerializationMixin):
|
||||
# region AgentResponse
|
||||
|
||||
|
||||
class AgentResponse(SerializationMixin):
|
||||
class AgentResponse(SerializationMixin, Generic[TResponseModel]):
|
||||
"""Represents the response to an Agent run request.
|
||||
|
||||
Provides one or more response messages and metadata about the response.
|
||||
@@ -2428,7 +2481,7 @@ class AgentResponse(SerializationMixin):
|
||||
response_id: str | None = None,
|
||||
created_at: CreatedAtT | None = None,
|
||||
usage_details: UsageDetails | MutableMapping[str, Any] | None = None,
|
||||
value: Any | None = None,
|
||||
value: TResponseModel | None = None,
|
||||
response_format: type[BaseModel] | None = None,
|
||||
raw_representation: Any | None = None,
|
||||
additional_properties: dict[str, Any] | None = None,
|
||||
@@ -2469,7 +2522,7 @@ class AgentResponse(SerializationMixin):
|
||||
self.response_id = response_id
|
||||
self.created_at = created_at
|
||||
self.usage_details = usage_details
|
||||
self._value: Any | None = value
|
||||
self._value: TResponseModel | None = value
|
||||
self._response_format: type[BaseModel] | None = response_format
|
||||
self._value_parsed: bool = value is not None
|
||||
self.additional_properties = additional_properties or {}
|
||||
@@ -2482,7 +2535,7 @@ class AgentResponse(SerializationMixin):
|
||||
return "".join(msg.text for msg in self.messages) if self.messages else ""
|
||||
|
||||
@property
|
||||
def value(self) -> Any | None:
|
||||
def value(self) -> TResponseModel | None:
|
||||
"""Get the parsed structured output value.
|
||||
|
||||
If a response_format was provided and parsing hasn't been attempted yet,
|
||||
@@ -2498,7 +2551,7 @@ class AgentResponse(SerializationMixin):
|
||||
and isinstance(self._response_format, type)
|
||||
and issubclass(self._response_format, BaseModel)
|
||||
):
|
||||
self._value = self._response_format.model_validate_json(self.text)
|
||||
self._value = cast(TResponseModel, self._response_format.model_validate_json(self.text))
|
||||
self._value_parsed = True
|
||||
return self._value
|
||||
|
||||
@@ -2512,6 +2565,24 @@ class AgentResponse(SerializationMixin):
|
||||
if isinstance(content, Content) and content.user_input_request
|
||||
]
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_agent_run_response_updates(
|
||||
cls: type["AgentResponse[Any]"],
|
||||
updates: Sequence["AgentResponseUpdate"],
|
||||
*,
|
||||
output_format_type: type[TResponseModelT],
|
||||
) -> "AgentResponse[TResponseModelT]": ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
def from_agent_run_response_updates(
|
||||
cls: type["AgentResponse[Any]"],
|
||||
updates: Sequence["AgentResponseUpdate"],
|
||||
*,
|
||||
output_format_type: None = None,
|
||||
) -> "AgentResponse[Any]": ...
|
||||
|
||||
@classmethod
|
||||
def from_agent_run_response_updates(
|
||||
cls: type[TAgentRunResponse],
|
||||
@@ -2535,6 +2606,24 @@ class AgentResponse(SerializationMixin):
|
||||
msg.try_parse_value(output_format_type)
|
||||
return msg
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def from_agent_response_generator(
|
||||
cls: type["AgentResponse[Any]"],
|
||||
updates: AsyncIterable["AgentResponseUpdate"],
|
||||
*,
|
||||
output_format_type: type[TResponseModelT],
|
||||
) -> "AgentResponse[TResponseModelT]": ...
|
||||
|
||||
@overload
|
||||
@classmethod
|
||||
async def from_agent_response_generator(
|
||||
cls: type["AgentResponse[Any]"],
|
||||
updates: AsyncIterable["AgentResponseUpdate"],
|
||||
*,
|
||||
output_format_type: None = None,
|
||||
) -> "AgentResponse[Any]": ...
|
||||
|
||||
@classmethod
|
||||
async def from_agent_response_generator(
|
||||
cls: type[TAgentRunResponse],
|
||||
@@ -2561,7 +2650,13 @@ class AgentResponse(SerializationMixin):
|
||||
def __str__(self) -> str:
|
||||
return self.text
|
||||
|
||||
def try_parse_value(self, output_format_type: type[_T] | None = None) -> _T | None:
|
||||
@overload
|
||||
def try_parse_value(self, output_format_type: type[TResponseModelT]) -> TResponseModelT | None: ...
|
||||
|
||||
@overload
|
||||
def try_parse_value(self, output_format_type: None = None) -> TResponseModel | None: ...
|
||||
|
||||
def try_parse_value(self, output_format_type: type[BaseModel] | None = None) -> BaseModel | None:
|
||||
"""Try to parse the text into a typed value.
|
||||
|
||||
This is the safe alternative when you need to parse the response text into a typed value.
|
||||
@@ -2589,7 +2684,7 @@ class AgentResponse(SerializationMixin):
|
||||
try:
|
||||
parsed_value = format_type.model_validate_json(self.text) # type: ignore[reportUnknownMemberType]
|
||||
if use_cache:
|
||||
self._value = parsed_value
|
||||
self._value = cast(TResponseModel, parsed_value)
|
||||
self._value_parsed = True
|
||||
return parsed_value # type: ignore[return-value]
|
||||
except ValidationError as ex:
|
||||
@@ -2718,7 +2813,7 @@ class ToolMode(TypedDict, total=False):
|
||||
# region TypedDict-based Chat Options
|
||||
|
||||
|
||||
class ChatOptions(TypedDict, total=False):
|
||||
class _ChatOptionsBase(TypedDict, total=False):
|
||||
"""Common request settings for AI services as a TypedDict.
|
||||
|
||||
All fields are optional (total=False) to allow partial specification.
|
||||
@@ -2771,7 +2866,7 @@ class ChatOptions(TypedDict, total=False):
|
||||
allow_multiple_tool_calls: bool
|
||||
|
||||
# Response configuration
|
||||
response_format: type[BaseModel] | dict[str, Any]
|
||||
response_format: type[BaseModel] | Mapping[str, Any] | None
|
||||
|
||||
# Metadata
|
||||
metadata: dict[str, Any]
|
||||
@@ -2783,6 +2878,15 @@ class ChatOptions(TypedDict, total=False):
|
||||
instructions: str
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
|
||||
class ChatOptions(_ChatOptionsBase, Generic[TResponseModel], total=False):
|
||||
response_format: type[TResponseModel] | Mapping[str, Any] | None # type: ignore[misc]
|
||||
|
||||
else:
|
||||
ChatOptions = _ChatOptionsBase
|
||||
|
||||
|
||||
# region Chat Options Utility Functions
|
||||
|
||||
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from collections.abc import AsyncIterable
|
||||
from dataclasses import dataclass
|
||||
from datetime import datetime, timezone
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, TypedDict, cast
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, cast
|
||||
|
||||
from agent_framework import (
|
||||
AgentResponse,
|
||||
@@ -32,6 +33,11 @@ from ._events import (
|
||||
from ._message_utils import normalize_messages_input
|
||||
from ._typing_utils import is_type_compatible
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._workflow import Workflow
|
||||
|
||||
|
||||
@@ -16,7 +16,7 @@ from ._const import WORKFLOW_RUN_KWARGS_KEY
|
||||
from ._conversation_state import encode_chat_messages
|
||||
from ._events import (
|
||||
AgentRunEvent,
|
||||
AgentRunUpdateEvent, # type: ignore[reportPrivateUsage]
|
||||
AgentRunUpdateEvent,
|
||||
)
|
||||
from ._executor import Executor, handler
|
||||
from ._message_utils import normalize_messages_input
|
||||
@@ -24,9 +24,9 @@ from ._request_info_mixin import response_handler
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -22,9 +22,9 @@ from ._orchestration_request_info import AgentApprovalExecutor
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -52,9 +52,9 @@ from ._workflow_builder import WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
@@ -55,9 +55,9 @@ from ._workflow_builder import WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -39,15 +39,14 @@ from ._workflow import Workflow
|
||||
from ._workflow_builder import WorkflowBuilder
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self
|
||||
else:
|
||||
from typing_extensions import Self
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import sys
|
||||
import uuid
|
||||
from copy import copy
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Any, Protocol, TypedDict, TypeVar, runtime_checkable
|
||||
from typing import Any, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
from ._checkpoint import CheckpointStorage, WorkflowCheckpoint
|
||||
from ._checkpoint_encoding import decode_checkpoint_value, encode_checkpoint_value
|
||||
@@ -15,6 +16,11 @@ from ._events import RequestInfoEvent, WorkflowEvent
|
||||
from ._shared_state import SharedState
|
||||
from ._typing_utils import is_instance_of
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
T = TypeVar("T")
|
||||
|
||||
@@ -8,9 +8,8 @@ from typing import Any
|
||||
|
||||
from typing_extensions import deprecated
|
||||
|
||||
from agent_framework import AgentThread
|
||||
|
||||
from .._agents import AgentProtocol
|
||||
from .._threads import AgentThread
|
||||
from ..observability import OtelAttr, capture_exception, create_workflow_span
|
||||
from ._agent_executor import AgentExecutor
|
||||
from ._checkpoint import CheckpointStorage
|
||||
@@ -34,9 +33,9 @@ from ._validation import validate_workflow_graph
|
||||
from ._workflow import Workflow
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
from typing import Self # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
from typing_extensions import Self # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -26,9 +26,9 @@ from ._workflow import WorkflowRunResult
|
||||
from ._workflow_context import WorkflowContext
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
@@ -19,8 +19,10 @@ if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
from typing import TypedDict
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = ["AzureOpenAIAssistantsClient"]
|
||||
|
||||
|
||||
@@ -4,13 +4,13 @@ import json
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, Generic, TypedDict
|
||||
from typing import Any, Generic
|
||||
|
||||
from azure.core.credentials import TokenCredential
|
||||
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
|
||||
from openai.types.chat.chat_completion import Choice
|
||||
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from agent_framework import (
|
||||
Annotation,
|
||||
@@ -36,12 +36,18 @@ else:
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["AzureOpenAIChatClient", "AzureOpenAIChatOptions", "AzureUserSecurityContext"]
|
||||
|
||||
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
# region Azure OpenAI Chat Options TypedDict
|
||||
|
||||
@@ -68,7 +74,7 @@ class AzureUserSecurityContext(TypedDict, total=False):
|
||||
"""The original client's IP address."""
|
||||
|
||||
|
||||
class AzureOpenAIChatOptions(OpenAIChatOptions, total=False):
|
||||
class AzureOpenAIChatOptions(OpenAIChatOptions[TResponseModel], Generic[TResponseModel], total=False):
|
||||
"""Azure OpenAI-specific chat options dict.
|
||||
|
||||
Extends OpenAIChatOptions with Azure-specific options including
|
||||
|
||||
@@ -2,34 +2,38 @@
|
||||
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypedDict
|
||||
from typing import TYPE_CHECKING, Any, Generic
|
||||
from urllib.parse import urljoin
|
||||
|
||||
from azure.core.credentials import TokenCredential
|
||||
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
|
||||
from pydantic import ValidationError
|
||||
|
||||
from agent_framework import use_chat_middleware, use_function_invocation
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
from agent_framework.observability import use_instrumentation
|
||||
from agent_framework.openai._responses_client import OpenAIBaseResponsesClient
|
||||
|
||||
from .._middleware import use_chat_middleware
|
||||
from .._tools import use_function_invocation
|
||||
from ..exceptions import ServiceInitializationError
|
||||
from ..observability import use_instrumentation
|
||||
from ..openai._responses_client import OpenAIBaseResponsesClient
|
||||
from ._shared import (
|
||||
AzureOpenAIConfigMixin,
|
||||
AzureOpenAISettings,
|
||||
)
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from agent_framework.openai._responses_client import OpenAIResponsesOptions
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ..openai._responses_client import OpenAIResponsesOptions
|
||||
|
||||
__all__ = ["AzureOpenAIResponsesClient"]
|
||||
|
||||
|
||||
@@ -2,7 +2,7 @@
|
||||
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, MutableMapping, Sequence
|
||||
from typing import TYPE_CHECKING, Any, Generic, TypedDict, cast
|
||||
from typing import TYPE_CHECKING, Any, Generic, cast
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.beta.assistant import Assistant
|
||||
@@ -21,10 +21,13 @@ if TYPE_CHECKING:
|
||||
from ._assistants_client import OpenAIAssistantsOptions
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import Self, TypeVar # pragma: no cover
|
||||
from typing import TypeVar # type:ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self, TypeVar # pragma: no cover
|
||||
|
||||
from typing_extensions import TypeVar # type:ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self, TypedDict # type:ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self, TypedDict # type:ignore # pragma: no cover
|
||||
|
||||
__all__ = ["OpenAIAssistantProvider"]
|
||||
|
||||
|
||||
@@ -9,12 +9,8 @@ from collections.abc import (
|
||||
Mapping,
|
||||
MutableMapping,
|
||||
MutableSequence,
|
||||
Sequence,
|
||||
)
|
||||
from typing import TYPE_CHECKING, Any, Generic, Literal, TypedDict, cast
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .._agents import ChatAgent
|
||||
from typing import Any, Generic, Literal, cast
|
||||
|
||||
from openai import AsyncOpenAI
|
||||
from openai.types.beta.threads import (
|
||||
@@ -29,17 +25,14 @@ from openai.types.beta.threads import (
|
||||
from openai.types.beta.threads.run_create_params import AdditionalMessage
|
||||
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
|
||||
from openai.types.beta.threads.runs import RunStep
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._memory import ContextProvider
|
||||
from .._middleware import Middleware, use_chat_middleware
|
||||
from .._threads import ChatMessageStoreProtocol
|
||||
from .._middleware import use_chat_middleware
|
||||
from .._tools import (
|
||||
FunctionTool,
|
||||
HostedCodeInterpreterTool,
|
||||
HostedFileSearchTool,
|
||||
ToolProtocol,
|
||||
use_function_invocation,
|
||||
)
|
||||
from .._types import (
|
||||
@@ -57,19 +50,19 @@ from ..observability import use_instrumentation
|
||||
from ._shared import OpenAIConfigMixin, OpenAISettings
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import Self # pragma: no cover
|
||||
from typing import Self, TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import Self # pragma: no cover
|
||||
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
__all__ = [
|
||||
@@ -81,6 +74,8 @@ __all__ = [
|
||||
|
||||
# region OpenAI Assistants Options TypedDict
|
||||
|
||||
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
class VectorStoreToolResource(TypedDict, total=False):
|
||||
"""Vector store configuration for file search tool resources."""
|
||||
@@ -109,7 +104,7 @@ class AssistantToolResources(TypedDict, total=False):
|
||||
"""Resources for file search tool, including vector store IDs."""
|
||||
|
||||
|
||||
class OpenAIAssistantsOptions(ChatOptions, total=False):
|
||||
class OpenAIAssistantsOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
|
||||
"""OpenAI Assistants API-specific options dict.
|
||||
|
||||
Extends base ChatOptions with Assistants API-specific parameters
|
||||
@@ -765,59 +760,3 @@ class OpenAIAssistantsClient(
|
||||
self.assistant_name = agent_name
|
||||
if description and not self.assistant_description:
|
||||
self.assistant_description = description
|
||||
|
||||
@override
|
||||
def as_agent(
|
||||
self,
|
||||
*,
|
||||
id: str | None = None,
|
||||
name: str | None = None,
|
||||
description: str | None = None,
|
||||
instructions: str | None = None,
|
||||
tools: ToolProtocol
|
||||
| Callable[..., Any]
|
||||
| MutableMapping[str, Any]
|
||||
| Sequence[ToolProtocol | Callable[..., Any] | MutableMapping[str, Any]]
|
||||
| None = None,
|
||||
default_options: TOpenAIAssistantsOptions | None = None,
|
||||
chat_message_store_factory: Callable[[], ChatMessageStoreProtocol] | None = None,
|
||||
context_provider: ContextProvider | None = None,
|
||||
middleware: Sequence[Middleware] | None = None,
|
||||
**kwargs: Any,
|
||||
) -> "ChatAgent[TOpenAIAssistantsOptions]":
|
||||
"""Convert this chat client to a ChatAgent.
|
||||
|
||||
This method creates a ChatAgent instance with this client pre-configured.
|
||||
It does NOT create an assistant on the OpenAI service - the actual assistant
|
||||
will be created on the server during the first invocation (run).
|
||||
|
||||
For creating and managing persistent assistants on the server, use
|
||||
:class:`~agent_framework.openai.OpenAIAssistantProvider` instead.
|
||||
|
||||
Keyword Args:
|
||||
id: The unique identifier for the agent. Will be created automatically if not provided.
|
||||
name: The name of the agent.
|
||||
description: A brief description of the agent's purpose.
|
||||
instructions: Optional instructions for the agent.
|
||||
tools: The tools to use for the request.
|
||||
default_options: A TypedDict containing chat options.
|
||||
chat_message_store_factory: Factory function to create an instance of ChatMessageStoreProtocol.
|
||||
context_provider: Context providers to include during agent invocation.
|
||||
middleware: List of middleware to intercept agent and function invocations.
|
||||
kwargs: Any additional keyword arguments.
|
||||
|
||||
Returns:
|
||||
A ChatAgent instance configured with this chat client.
|
||||
"""
|
||||
return super().as_agent(
|
||||
id=id,
|
||||
name=name,
|
||||
description=description,
|
||||
instructions=instructions,
|
||||
tools=tools,
|
||||
default_options=default_options,
|
||||
chat_message_store_factory=chat_message_store_factory,
|
||||
context_provider=context_provider,
|
||||
middleware=middleware,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ import sys
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, MutableSequence, Sequence
|
||||
from datetime import datetime, timezone
|
||||
from itertools import chain
|
||||
from typing import Any, Generic, Literal, TypedDict
|
||||
from typing import Any, Generic, Literal
|
||||
|
||||
from openai import AsyncOpenAI, BadRequestError
|
||||
from openai.lib._parsing._completions import type_to_response_format_param
|
||||
@@ -14,7 +14,7 @@ from openai.types.chat.chat_completion import ChatCompletion, Choice
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
|
||||
from openai.types.chat.chat_completion_message_custom_tool_call import ChatCompletionMessageCustomToolCall
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
@@ -41,19 +41,24 @@ from ._exceptions import OpenAIContentFilterException
|
||||
from ._shared import OpenAIBase, OpenAIConfigMixin, OpenAISettings
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar
|
||||
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = ["OpenAIChatClient", "OpenAIChatOptions"]
|
||||
|
||||
logger = get_logger("agent_framework.openai")
|
||||
|
||||
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
# region OpenAI Chat Options TypedDict
|
||||
|
||||
@@ -72,7 +77,7 @@ class Prediction(TypedDict, total=False):
|
||||
content: str | list[PredictionTextContent]
|
||||
|
||||
|
||||
class OpenAIChatOptions(ChatOptions, total=False):
|
||||
class OpenAIChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
|
||||
"""OpenAI-specific chat options dict.
|
||||
|
||||
Extends ChatOptions with options specific to OpenAI's Chat Completions API.
|
||||
|
||||
@@ -12,7 +12,7 @@ from collections.abc import (
|
||||
)
|
||||
from datetime import datetime, timezone
|
||||
from itertools import chain
|
||||
from typing import Any, Generic, Literal, TypedDict, cast
|
||||
from typing import Any, Generic, Literal, cast
|
||||
|
||||
from openai import AsyncOpenAI, BadRequestError
|
||||
from openai.types.responses.file_search_tool_param import FileSearchToolParam
|
||||
@@ -79,9 +79,14 @@ if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
logger = get_logger("agent_framework.openai")
|
||||
|
||||
|
||||
__all__ = ["OpenAIResponsesClient", "OpenAIResponsesOptions"]
|
||||
|
||||
|
||||
@@ -108,7 +113,10 @@ class StreamOptions(TypedDict, total=False):
|
||||
"""Whether to include usage statistics in stream events."""
|
||||
|
||||
|
||||
class OpenAIResponsesOptions(ChatOptions, total=False):
|
||||
TResponseFormat = TypeVar("TResponseFormat", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
class OpenAIResponsesOptions(ChatOptions[TResponseFormat], Generic[TResponseFormat], total=False):
|
||||
"""OpenAI Responses API-specific chat options.
|
||||
|
||||
Extends ChatOptions with options specific to OpenAI's Responses API.
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
from typing import Annotated, Any, Literal
|
||||
from typing import Annotated, Any, Literal, get_args, get_origin
|
||||
from unittest.mock import Mock
|
||||
|
||||
import pytest
|
||||
@@ -1493,8 +1493,6 @@ async def test_tool_with_kwargs_injection():
|
||||
|
||||
def test_parse_annotation_with_literal_type():
|
||||
"""Test that _parse_annotation returns Literal types unchanged (issue #2891)."""
|
||||
from typing import get_args, get_origin
|
||||
|
||||
# Literal with string values
|
||||
literal_annotation = Literal["Data", "Security", "Network"]
|
||||
result = _parse_annotation(literal_annotation)
|
||||
@@ -1505,7 +1503,6 @@ def test_parse_annotation_with_literal_type():
|
||||
|
||||
def test_parse_annotation_with_literal_int_type():
|
||||
"""Test that _parse_annotation returns Literal int types unchanged."""
|
||||
from typing import get_args, get_origin
|
||||
|
||||
literal_annotation = Literal[1, 2, 3]
|
||||
result = _parse_annotation(literal_annotation)
|
||||
@@ -1516,7 +1513,6 @@ def test_parse_annotation_with_literal_int_type():
|
||||
|
||||
def test_parse_annotation_with_literal_bool_type():
|
||||
"""Test that _parse_annotation returns Literal bool types unchanged."""
|
||||
from typing import get_args, get_origin
|
||||
|
||||
literal_annotation = Literal[True, False]
|
||||
result = _parse_annotation(literal_annotation)
|
||||
@@ -1535,7 +1531,6 @@ def test_parse_annotation_with_simple_types():
|
||||
|
||||
def test_parse_annotation_with_annotated_and_literal():
|
||||
"""Test that Annotated[Literal[...], description] works correctly."""
|
||||
from typing import get_args, get_origin
|
||||
|
||||
# When Literal is inside Annotated, it should still be preserved
|
||||
annotated_literal = Annotated[Literal["A", "B", "C"], "The category"]
|
||||
|
||||
@@ -3,10 +3,10 @@
|
||||
import base64
|
||||
from collections.abc import AsyncIterable
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any
|
||||
from typing import Any, Literal
|
||||
|
||||
import pytest
|
||||
from pydantic import BaseModel
|
||||
from pydantic import BaseModel, Field, ValidationError
|
||||
from pytest import fixture, mark, raises
|
||||
|
||||
from agent_framework import (
|
||||
@@ -665,9 +665,6 @@ def test_chat_response_with_format_init():
|
||||
|
||||
def test_chat_response_value_raises_on_invalid_schema():
|
||||
"""Test that value property raises ValidationError with field constraint details."""
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, ValidationError
|
||||
|
||||
class StrictSchema(BaseModel):
|
||||
id: Literal[5]
|
||||
@@ -689,9 +686,6 @@ def test_chat_response_value_raises_on_invalid_schema():
|
||||
|
||||
def test_chat_response_try_parse_value_returns_none_on_invalid():
|
||||
"""Test that try_parse_value returns None on validation failure with Field constraints."""
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
class StrictSchema(BaseModel):
|
||||
id: Literal[5]
|
||||
@@ -707,7 +701,6 @@ def test_chat_response_try_parse_value_returns_none_on_invalid():
|
||||
|
||||
def test_chat_response_try_parse_value_returns_value_on_success():
|
||||
"""Test that try_parse_value returns parsed value when all constraints pass."""
|
||||
from pydantic import Field
|
||||
|
||||
class MySchema(BaseModel):
|
||||
name: str = Field(min_length=3)
|
||||
@@ -724,9 +717,6 @@ def test_chat_response_try_parse_value_returns_value_on_success():
|
||||
|
||||
def test_agent_response_value_raises_on_invalid_schema():
|
||||
"""Test that AgentResponse.value property raises ValidationError with field constraint details."""
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field, ValidationError
|
||||
|
||||
class StrictSchema(BaseModel):
|
||||
id: Literal[5]
|
||||
@@ -748,9 +738,6 @@ def test_agent_response_value_raises_on_invalid_schema():
|
||||
|
||||
def test_agent_response_try_parse_value_returns_none_on_invalid():
|
||||
"""Test that AgentResponse.try_parse_value returns None on Field constraint failure."""
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
class StrictSchema(BaseModel):
|
||||
id: Literal[5]
|
||||
@@ -766,7 +753,6 @@ def test_agent_response_try_parse_value_returns_none_on_invalid():
|
||||
|
||||
def test_agent_response_try_parse_value_returns_value_on_success():
|
||||
"""Test that AgentResponse.try_parse_value returns parsed value when all constraints pass."""
|
||||
from pydantic import Field
|
||||
|
||||
class MySchema(BaseModel):
|
||||
name: str = Field(min_length=3)
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import pytest
|
||||
from typing_extensions import Never
|
||||
|
||||
from agent_framework import (
|
||||
ChatMessage,
|
||||
@@ -187,7 +188,6 @@ async def test_executor_completed_event_contains_sent_messages():
|
||||
|
||||
async def test_executor_completed_event_includes_yielded_outputs():
|
||||
"""Test that ExecutorCompletedEvent.data includes yielded outputs."""
|
||||
from typing_extensions import Never
|
||||
|
||||
from agent_framework import WorkflowOutputEvent
|
||||
|
||||
@@ -318,7 +318,6 @@ def test_executor_output_types_property():
|
||||
|
||||
def test_executor_workflow_output_types_property():
|
||||
"""Test that the workflow_output_types property correctly identifies workflow output types."""
|
||||
from typing_extensions import Never
|
||||
|
||||
# Test executor with no workflow output types
|
||||
class NoWorkflowOutputExecutor(Executor):
|
||||
|
||||
@@ -42,9 +42,9 @@ from agent_framework import (
|
||||
from agent_framework._workflows._checkpoint import InMemoryCheckpointStorage
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
def test_magentic_context_reset_behavior():
|
||||
|
||||
@@ -1,8 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import sys
|
||||
from collections.abc import Callable, Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
import yaml
|
||||
from agent_framework import (
|
||||
@@ -42,6 +43,11 @@ from ._models import (
|
||||
agent_schema_dispatch,
|
||||
)
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
class ProviderTypeMapping(TypedDict, total=True):
|
||||
package: str
|
||||
|
||||
+8
-1
@@ -24,10 +24,11 @@ See: dotnet/src/Microsoft.Agents.AI.Workflows.Declarative/PowerFx/
|
||||
"""
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Mapping
|
||||
from dataclasses import dataclass
|
||||
from decimal import Decimal as _Decimal
|
||||
from typing import Any, Literal, TypedDict, cast
|
||||
from typing import Any, Literal, cast
|
||||
|
||||
from agent_framework._workflows import (
|
||||
Executor,
|
||||
@@ -36,6 +37,12 @@ from agent_framework._workflows import (
|
||||
)
|
||||
from powerfx import Engine
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
||||
@@ -7,7 +7,7 @@ import json
|
||||
import logging
|
||||
from dataclasses import fields, is_dataclass
|
||||
from types import UnionType
|
||||
from typing import Any, Union, get_args, get_origin
|
||||
from typing import Any, Union, get_args, get_origin, get_type_hints
|
||||
|
||||
from agent_framework import ChatMessage
|
||||
|
||||
@@ -270,8 +270,6 @@ def generate_schema_from_serialization_mixin(cls: type[Any]) -> dict[str, Any]:
|
||||
|
||||
# Get type hints
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
|
||||
type_hints = get_type_hints(cls)
|
||||
except Exception:
|
||||
type_hints = {}
|
||||
@@ -348,8 +346,6 @@ def extract_response_type_from_executor(executor: Any, request_type: type) -> ty
|
||||
The response type class, or None if not found
|
||||
"""
|
||||
try:
|
||||
from typing import get_type_hints
|
||||
|
||||
# Introspect handler methods for @response_handler pattern
|
||||
for attr_name in dir(executor):
|
||||
if attr_name.startswith("_"):
|
||||
|
||||
+9
-3
@@ -1,7 +1,7 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import sys
|
||||
from typing import Any, ClassVar, Generic, TypedDict
|
||||
from typing import Any, ClassVar, Generic
|
||||
|
||||
from agent_framework import ChatOptions, use_chat_middleware, use_function_invocation
|
||||
from agent_framework._pydantic import AFBaseSettings
|
||||
@@ -11,12 +11,16 @@ from agent_framework.openai._chat_client import OpenAIBaseChatClient
|
||||
from foundry_local import FoundryLocalManager
|
||||
from foundry_local.models import DeviceType
|
||||
from openai import AsyncOpenAI
|
||||
from pydantic import BaseModel
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = [
|
||||
"FoundryLocalChatOptions",
|
||||
@@ -24,11 +28,13 @@ __all__ = [
|
||||
"FoundryLocalSettings",
|
||||
]
|
||||
|
||||
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
# region Foundry Local Chat Options TypedDict
|
||||
|
||||
|
||||
class FoundryLocalChatOptions(ChatOptions, total=False):
|
||||
class FoundryLocalChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
|
||||
"""Azure Foundry Local (local model deployment) chat options dict.
|
||||
|
||||
Extends base ChatOptions for local model inference via Foundry Local.
|
||||
|
||||
@@ -3,7 +3,7 @@
|
||||
import sys
|
||||
from collections.abc import MutableSequence, Sequence
|
||||
from contextlib import AbstractAsyncContextManager
|
||||
from typing import Any, TypedDict
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import ChatMessage, Context, ContextProvider
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
@@ -15,9 +15,9 @@ else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import NotRequired, Self # pragma: no cover
|
||||
from typing import NotRequired, Self, TypedDict # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import NotRequired, Self # pragma: no cover
|
||||
from typing_extensions import NotRequired, Self, TypedDict # pragma: no cover
|
||||
|
||||
|
||||
# Type aliases for Mem0 search response formats (v1.1 and v2; v1 is deprecated, but matches the type definition for v2)
|
||||
|
||||
@@ -11,7 +11,7 @@ from collections.abc import (
|
||||
Sequence,
|
||||
)
|
||||
from itertools import chain
|
||||
from typing import Any, ClassVar, Generic, TypedDict
|
||||
from typing import Any, ClassVar, Generic
|
||||
|
||||
from agent_framework import (
|
||||
BaseChatClient,
|
||||
@@ -40,26 +40,32 @@ from ollama import AsyncClient
|
||||
# Rename imported types to avoid naming conflicts with Agent Framework types
|
||||
from ollama._types import ChatResponse as OllamaChatResponse
|
||||
from ollama._types import Message as OllamaMessage
|
||||
from pydantic import ValidationError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
from typing import TypeVar
|
||||
from typing import TypeVar # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypeVar
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
if sys.version_info >= (3, 12):
|
||||
from typing import override # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import override # type: ignore[import] # pragma: no cover
|
||||
from typing_extensions import override # type: ignore # pragma: no cover
|
||||
|
||||
if sys.version_info >= (3, 11):
|
||||
from typing import TypedDict # type: ignore # pragma: no cover
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = ["OllamaChatClient", "OllamaChatOptions"]
|
||||
|
||||
TResponseModel = TypeVar("TResponseModel", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
# region Ollama Chat Options TypedDict
|
||||
|
||||
|
||||
class OllamaChatOptions(ChatOptions, total=False):
|
||||
class OllamaChatOptions(ChatOptions[TResponseModel], Generic[TResponseModel], total=False):
|
||||
"""Ollama-specific chat options dict.
|
||||
|
||||
Extends base ChatOptions with Ollama-specific parameters.
|
||||
|
||||
+5
-5
@@ -37,9 +37,9 @@ async def non_streaming_example() -> None:
|
||||
# Get structured response from the agent using response_format parameter
|
||||
result = await agent.run(query, options={"response_format": OutputStruct})
|
||||
|
||||
# Access the structured output using try_parse_value for safe parsing
|
||||
if structured_data := result.try_parse_value(OutputStruct):
|
||||
print("Structured Output Agent (from result.try_parse_value):")
|
||||
# Access the structured output using the parsed value
|
||||
if structured_data := result.value:
|
||||
print("Structured Output Agent:")
|
||||
print(f"City: {structured_data.city}")
|
||||
print(f"Description: {structured_data.description}")
|
||||
else:
|
||||
@@ -66,8 +66,8 @@ async def streaming_example() -> None:
|
||||
output_format_type=OutputStruct,
|
||||
)
|
||||
|
||||
# Access the structured output using try_parse_value for safe parsing
|
||||
if structured_data := result.try_parse_value(OutputStruct):
|
||||
# Access the structured output using the parsed value
|
||||
if structured_data := result.value:
|
||||
print("Structured Output (from streaming with AgentResponse.from_agent_response_generator):")
|
||||
print(f"City: {structured_data.city}")
|
||||
print(f"Description: {structured_data.description}")
|
||||
|
||||
Reference in New Issue
Block a user