Python: [BREAKING] Fix #3613 chat/agent message typing alignment (#3920)

* Fix #3613 message typing across chat and agents

* Address #3613 review feedback and sample input style

* refactor: use shared AgentRunMessages aliases (#3613)

* refactor: rename agent run input aliases for #3613

* samples: inline image content in run calls

* core: export AgentRunInputs from package init

* core: use explicit init re-exports without __all__

* updated logging and inits

* Fix core mypy export and samples XML note

* Remove AgentRunInputsOrNone and dedupe loggers

* Remove prepare_messages helper

* fix integration tests
This commit is contained in:
Eduard van Valkenburg
2026-02-16 16:27:25 +01:00
committed by GitHub
Unverified
parent 503eb10fdd
commit dc9439a75a
87 changed files with 422 additions and 578 deletions
@@ -19,7 +19,6 @@ from ._clients import (
SupportsMCPTool,
SupportsWebSearchTool,
)
from ._logging import get_logger, setup_logging
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPWebsocketTool
from ._middleware import (
AgentContext,
@@ -34,7 +33,6 @@ from ._middleware import (
FunctionInvocationContext,
FunctionMiddleware,
FunctionMiddlewareTypes,
MiddlewareException,
MiddlewareTermination,
MiddlewareType,
MiddlewareTypes,
@@ -67,6 +65,7 @@ from ._tools import (
from ._types import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
Annotation,
ChatOptions,
ChatResponse,
@@ -152,6 +151,7 @@ from ._workflows import (
response_handler,
validate_workflow_graph,
)
from .exceptions import MiddlewareException
__all__ = [
"AGENT_FRAMEWORK_USER_AGENT",
@@ -169,6 +169,7 @@ __all__ = [
"AgentMiddlewareTypes",
"AgentResponse",
"AgentResponseUpdate",
"AgentRunInputs",
"AgentSession",
"Annotation",
"BaseAgent",
@@ -264,6 +265,7 @@ __all__ = [
"WorkflowRunnerException",
"WorkflowValidationError",
"WorkflowViz",
"__version__",
"add_usage_details",
"agent_middleware",
"chat_middleware",
@@ -271,7 +273,6 @@ __all__ = [
"detect_media_type_from_base64",
"executor",
"function_middleware",
"get_logger",
"handler",
"map_chat_to_agent_update",
"merge_chat_options",
@@ -283,7 +284,6 @@ __all__ = [
"register_state_type",
"resolve_agent_id",
"response_handler",
"setup_logging",
"tool",
"validate_chat_options",
"validate_tool_mode",
+11 -13
View File
@@ -3,6 +3,7 @@
from __future__ import annotations
import inspect
import logging
import re
import sys
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
@@ -29,7 +30,6 @@ from mcp.shared.exceptions import McpError
from pydantic import BaseModel, Field, create_model
from ._clients import BaseChatClient, SupportsChatGetResponse
from ._logging import get_logger
from ._mcp import LOG_LEVEL_MAPPING, MCPTool
from ._middleware import AgentMiddlewareLayer, MiddlewareTypes
from ._serialization import SerializationMixin
@@ -41,6 +41,7 @@ from ._tools import (
from ._types import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
ChatResponse,
ChatResponseUpdate,
Message,
@@ -67,7 +68,7 @@ else:
if TYPE_CHECKING:
from ._types import ChatOptions
logger = get_logger("agent_framework")
logger = logging.getLogger("agent_framework")
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
OptionsCoT = TypeVar(
@@ -159,9 +160,6 @@ class _RunContext(TypedDict):
finalize_kwargs: dict[str, Any]
__all__ = ["Agent", "BaseAgent", "RawAgent", "SupportsAgentRun"]
# region Agent Protocol
@@ -230,7 +228,7 @@ class SupportsAgentRun(Protocol):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -242,7 +240,7 @@ class SupportsAgentRun(Protocol):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -253,7 +251,7 @@ class SupportsAgentRun(Protocol):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -763,7 +761,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -780,7 +778,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -797,7 +795,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -813,7 +811,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -1000,7 +998,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
async def _prepare_run_context(
self,
*,
messages: str | Message | Sequence[str | Message] | None,
messages: AgentRunInputs | None,
session: AgentSession | None,
tools: FunctionTool
| Callable[..., Any]
@@ -2,6 +2,7 @@
from __future__ import annotations
import logging
import sys
from abc import ABC, abstractmethod
from collections.abc import (
@@ -27,7 +28,6 @@ from typing import (
from pydantic import BaseModel
from ._logging import get_logger
from ._serialization import SerializationMixin
from ._tools import (
FunctionInvocationConfiguration,
@@ -38,7 +38,6 @@ from ._types import (
ChatResponseUpdate,
Message,
ResponseStream,
prepare_messages,
validate_chat_options,
)
@@ -61,17 +60,7 @@ InputT = TypeVar("InputT", contravariant=True)
EmbeddingT = TypeVar("EmbeddingT")
BaseChatClientT = TypeVar("BaseChatClientT", bound="BaseChatClient")
logger = get_logger()
__all__ = [
"BaseChatClient",
"SupportsChatGetResponse",
"SupportsCodeInterpreterTool",
"SupportsFileSearchTool",
"SupportsImageGenerationTool",
"SupportsMCPTool",
"SupportsWebSearchTool",
]
logger = logging.getLogger("agent_framework")
# region SupportsChatGetResponse Protocol
@@ -139,7 +128,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
@@ -149,7 +138,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OptionsContraT | ChatOptions[None] | None = None,
@@ -159,7 +148,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[True],
options: OptionsContraT | ChatOptions[Any] | None = None,
@@ -168,7 +157,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: bool = False,
options: OptionsContraT | ChatOptions[Any] | None = None,
@@ -254,9 +243,9 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
client = CustomChatClient()
# Use the client to get responses
response = await client.get_response("Hello, how are you?")
response = await client.get_response([Message(role="user", text="Hello, how are you?")])
# Or stream responses
async for update in client.get_response("Hello!", stream=True):
async for update in client.get_response([Message(role="user", text="Hello!")], stream=True):
print(update)
"""
@@ -376,7 +365,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
@@ -386,7 +375,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
@@ -396,7 +385,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -405,7 +394,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -422,9 +411,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
Returns:
When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse.
"""
prepared_messages = prepare_messages(messages)
return self._inner_get_response(
messages=prepared_messages,
messages=messages,
stream=stream,
options=options or {}, # type: ignore[arg-type]
**kwargs,
@@ -1,29 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
from .exceptions import AgentFrameworkException
__all__ = ["get_logger", "setup_logging"]
def setup_logging() -> None:
"""Setup the logging configuration for the agent framework."""
logging.basicConfig(
format="[%(asctime)s - %(pathname)s:%(lineno)d - %(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
def get_logger(name: str = "agent_framework") -> logging.Logger:
"""Get a logger with the specified name, defaulting to 'agent_framework'.
Args:
name (str): The name of the logger. Defaults to 'agent_framework'.
Returns:
logging.Logger: The configured logger instance.
"""
if not name.startswith("agent_framework"):
raise AgentFrameworkException("Logger name must start with 'agent_framework'.")
return logging.getLogger(name)
@@ -72,12 +72,6 @@ LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = {
"emergency": logging.CRITICAL,
}
__all__ = [
"MCPStdioTool",
"MCPStreamableHTTPTool",
"MCPWebsocketTool",
]
def _parse_prompt_result_from_mcp(
mcp_type: types.GetPromptResult,
@@ -14,11 +14,12 @@ from ._clients import SupportsChatGetResponse
from ._types import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
ChatResponse,
ChatResponseUpdate,
Message,
ResponseStream,
prepare_messages,
normalize_messages,
)
from .exceptions import MiddlewareException
@@ -42,27 +43,6 @@ if TYPE_CHECKING:
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
__all__ = [
"AgentContext",
"AgentMiddleware",
"AgentMiddlewareLayer",
"AgentMiddlewareTypes",
"ChatAndFunctionMiddlewareTypes",
"ChatContext",
"ChatMiddleware",
"ChatMiddlewareLayer",
"ChatMiddlewareTypes",
"FunctionInvocationContext",
"FunctionMiddleware",
"FunctionMiddlewareTypes",
"MiddlewareException",
"MiddlewareTermination",
"MiddlewareType",
"MiddlewareTypes",
"agent_middleware",
"chat_middleware",
"function_middleware",
]
AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
ContextT = TypeVar("ContextT")
@@ -978,7 +958,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
@@ -988,7 +968,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
@@ -998,7 +978,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1007,7 +987,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1034,7 +1014,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
context = ChatContext(
client=self, # type: ignore[arg-type]
messages=prepare_messages(messages),
messages=list(messages),
options=options,
stream=stream,
kwargs=kwargs,
@@ -1095,7 +1075,7 @@ class AgentMiddlewareLayer:
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -1107,7 +1087,7 @@ class AgentMiddlewareLayer:
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -1119,7 +1099,7 @@ class AgentMiddlewareLayer:
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -1130,7 +1110,7 @@ class AgentMiddlewareLayer:
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -1161,7 +1141,7 @@ class AgentMiddlewareLayer:
context = AgentContext(
agent=self, # type: ignore[arg-type]
messages=prepare_messages(messages), # type: ignore[arg-type]
messages=normalize_messages(messages),
session=session,
options=options,
stream=stream,
@@ -3,13 +3,12 @@
from __future__ import annotations
import json
import logging
import re
from collections.abc import Mapping, MutableMapping
from typing import Any, ClassVar, Protocol, TypeVar, runtime_checkable
from ._logging import get_logger
logger = get_logger()
logger = logging.getLogger("agent_framework")
ClassT = TypeVar("ClassT", bound="SerializationMixin")
ProtocolT = TypeVar("ProtocolT", bound="SerializationProtocol")
@@ -24,16 +24,6 @@ if TYPE_CHECKING:
from ._agents import SupportsAgentRun
__all__ = [
"AgentSession",
"BaseContextProvider",
"BaseHistoryProvider",
"InMemoryHistoryProvider",
"SessionContext",
"register_state_type",
]
# Registry of known types for state deserialization
_STATE_TYPE_REGISTRY: dict[str, type] = {}
@@ -45,7 +45,6 @@ if sys.version_info >= (3, 13):
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
__all__ = ["SecretString", "load_settings"]
SettingsT = TypeVar("SettingsT", default=dict[str, Any])
@@ -2,21 +2,14 @@
from __future__ import annotations
import logging
import os
from typing import Any, Final
from . import __version__ as version_info
from ._logging import get_logger
logger = get_logger()
logger = logging.getLogger("agent_framework")
__all__ = [
"AGENT_FRAMEWORK_USER_AGENT",
"APP_INFO",
"USER_AGENT_KEY",
"USER_AGENT_TELEMETRY_DISABLED_ENV_VAR",
"prepend_agent_framework_to_user_agent",
]
# Note that if this environment variable does not exist, user agent telemetry is enabled.
USER_AGENT_TELEMETRY_DISABLED_ENV_VAR = "AGENT_FRAMEWORK_USER_AGENT_DISABLED"
+8 -19
View File
@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import inspect
import json
import logging
import sys
from collections.abc import (
AsyncIterable,
@@ -34,7 +35,6 @@ from typing import (
from opentelemetry.metrics import Histogram, NoOpHistogram
from pydantic import BaseModel, Field, ValidationError, create_model
from ._logging import get_logger
from ._serialization import SerializationMixin
from .exceptions import ToolException
from .observability import (
@@ -71,18 +71,8 @@ if TYPE_CHECKING:
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
logger = get_logger()
logger = logging.getLogger("agent_framework")
__all__ = [
"FunctionInvocationConfiguration",
"FunctionInvocationLayer",
"FunctionTool",
"normalize_function_invocation_configuration",
"tool",
]
logger = get_logger()
DEFAULT_MAX_ITERATIONS: Final[int] = 40
DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
@@ -1941,7 +1931,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
@@ -1951,7 +1941,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
@@ -1961,7 +1951,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1970,7 +1960,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1982,7 +1972,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
ChatResponse,
ChatResponseUpdate,
ResponseStream,
prepare_messages,
)
super_get_response = super().get_response # type: ignore[misc]
@@ -2014,7 +2003,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
nonlocal mutable_options
nonlocal filtered_kwargs
errors_in_a_row: int = 0
prepped_messages = prepare_messages(messages)
prepped_messages = list(messages)
fcc_messages: list[Message] = []
response: ChatResponse | None = None
@@ -2108,7 +2097,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
nonlocal mutable_options
nonlocal stream_result_hooks
errors_in_a_row: int = 0
prepped_messages = prepare_messages(messages)
prepped_messages = list(messages)
fcc_messages: list[Message] = []
response: ChatResponse | None = None
+4 -74
View File
@@ -4,6 +4,7 @@ from __future__ import annotations
import base64
import json
import logging
import re
import sys
from asyncio import iscoroutine
@@ -13,7 +14,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, NewTyp
from pydantic import BaseModel
from ._logging import get_logger
from ._serialization import SerializationMixin
from ._tools import FunctionTool, tool
from .exceptions import AdditionItemMismatch, ContentError
@@ -27,41 +27,7 @@ if sys.version_info >= (3, 11):
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = [
"AgentResponse",
"AgentResponseUpdate",
"Annotation",
"ChatOptions",
"ChatResponse",
"ChatResponseUpdate",
"Content",
"ContinuationToken",
"FinalT",
"FinishReason",
"FinishReasonLiteral",
"Message",
"OuterFinalT",
"OuterUpdateT",
"ResponseStream",
"Role",
"RoleLiteral",
"TextSpanRegion",
"ToolMode",
"UpdateT",
"UsageDetails",
"add_usage_details",
"detect_media_type_from_base64",
"map_chat_to_agent_update",
"merge_chat_options",
"normalize_messages",
"normalize_tools",
"prepend_instructions_to_messages",
"validate_chat_options",
"validate_tool_mode",
"validate_tools",
]
logger = get_logger("agent_framework")
logger = logging.getLogger("agent_framework")
# region Content Parsing Utilities
@@ -1536,47 +1502,11 @@ class Message(SerializationMixin):
return " ".join(content.text for content in self.contents if content.type == "text") # type: ignore[misc]
def prepare_messages(
messages: str | Content | Message | Sequence[str | Content | Message],
system_instructions: str | Sequence[str] | None = None,
) -> list[Message]:
"""Convert various message input formats into a list of Message objects.
Args:
messages: The input messages in various supported formats. Can be:
- A string (converted to a user message)
- A Content object (wrapped in a user Message)
- A Message object
- A sequence containing any mix of the above
system_instructions: The system instructions. They will be inserted to the start of the messages list.
Returns:
A list of Message objects.
"""
if system_instructions is not None:
if isinstance(system_instructions, str):
system_instructions = [system_instructions]
system_instruction_messages = [Message("system", [instr]) for instr in system_instructions]
else:
system_instruction_messages = []
if isinstance(messages, str):
return [*system_instruction_messages, Message("user", [messages])]
if isinstance(messages, Content):
return [*system_instruction_messages, Message("user", [messages])]
if isinstance(messages, Message):
return [*system_instruction_messages, messages]
return_messages: list[Message] = system_instruction_messages
for msg in messages:
if isinstance(msg, (str, Content)):
msg = Message("user", [msg])
return_messages.append(msg)
return return_messages
AgentRunInputs = str | Content | Message | Sequence[str | Content | Message]
def normalize_messages(
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = None,
) -> list[Message]:
"""Normalize message inputs to a list of Message objects.
@@ -22,6 +22,7 @@ from .._sessions import (
from .._types import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
Content,
Message,
ResponseStream,
@@ -145,7 +146,7 @@ class WorkflowAgent(BaseAgent):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -157,7 +158,7 @@ class WorkflowAgent(BaseAgent):
@overload
async def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -168,7 +169,7 @@ class WorkflowAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -214,7 +215,7 @@ class WorkflowAgent(BaseAgent):
async def _run_impl(
self,
messages: str | Message | Sequence[str | Message],
messages: AgentRunInputs,
response_id: str,
session: AgentSession | None,
checkpoint_id: str | None = None,
@@ -270,7 +271,7 @@ class WorkflowAgent(BaseAgent):
async def _run_stream_impl(
self,
messages: str | Message | Sequence[str | Message],
messages: AgentRunInputs,
response_id: str,
session: AgentSession | None,
checkpoint_id: str | None = None,
@@ -3,11 +3,10 @@
from __future__ import annotations
import base64
import logging
import pickle # nosec # noqa: S403
from typing import Any
from agent_framework import get_logger
"""Checkpoint encoding using JSON structure with pickle+base64 for arbitrary data.
This hybrid approach provides:
@@ -20,7 +19,7 @@ from trusted sources. Loading a malicious checkpoint file can execute arbitrary
"""
logger = get_logger(__name__)
logger = logging.getLogger("agent_framework")
# Marker to identify pickled values in serialized JSON
_PICKLE_MARKER = "__pickled__"
@@ -2,18 +2,17 @@
"""Shared helpers for normalizing workflow message inputs."""
from collections.abc import Sequence
from agent_framework import Message
from agent_framework import Content, Message
from agent_framework._types import AgentRunInputs
def normalize_messages_input(
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
) -> list[Message]:
"""Normalize heterogeneous message inputs to a list of Message objects.
Args:
messages: String, Message, or sequence of either. None yields empty list.
messages: String, Content, Message, or sequence of those values. None yields empty list.
Returns:
List of Message instances suitable for workflow consumption.
@@ -24,6 +23,9 @@ def normalize_messages_input(
if isinstance(messages, str):
return [Message(role="user", text=messages)]
if isinstance(messages, Content):
return [Message(role="user", contents=[messages])]
if isinstance(messages, Message):
return [messages]
@@ -31,13 +33,12 @@ def normalize_messages_input(
for item in messages:
if isinstance(item, str):
normalized.append(Message(role="user", text=item))
elif isinstance(item, Content):
normalized.append(Message(role="user", contents=[item]))
elif isinstance(item, Message):
normalized.append(item)
else:
raise TypeError(
f"Messages sequence must contain only str or Message instances; found {type(item).__name__}."
f"Messages sequence must contain only str, Content, or Message instances; found {type(item).__name__}."
)
return normalized
__all__ = ["normalize_messages_input"]
@@ -27,8 +27,6 @@ if sys.version_info >= (3, 11):
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = ["AzureOpenAIAssistantsClient"]
# region Azure OpenAI Assistants Options TypedDict
@@ -53,7 +53,6 @@ if TYPE_CHECKING:
logger: logging.Logger = logging.getLogger(__name__)
__all__ = ["AzureOpenAIChatClient", "AzureOpenAIChatOptions", "AzureUserSecurityContext"]
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
@@ -42,8 +42,6 @@ if TYPE_CHECKING:
from .._middleware import MiddlewareTypes
from ..openai._responses_client import OpenAIResponsesOptions
__all__ = ["AzureOpenAIResponsesClient"]
AzureOpenAIResponsesOptionsT = TypeVar(
"AzureOpenAIResponsesOptionsT",
@@ -20,7 +20,6 @@ from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.semconv_ai import Meters, SpanAttributes
from . import __version__ as version_info
from ._logging import get_logger
from ._settings import load_settings
if sys.version_info >= (3, 13):
@@ -44,6 +43,7 @@ if TYPE_CHECKING: # pragma: no cover
from ._types import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
@@ -73,7 +73,7 @@ AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
logger = get_logger()
logger = logging.getLogger("agent_framework")
OTEL_METRICS: Final[str] = "__otel_metrics__"
@@ -747,7 +747,6 @@ class ObservabilitySettings:
for log_exporter in log_exporters:
logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter))
# Attach a handler with the provider to the root logger
logger = logging.getLogger()
handler = LoggingHandler(logger_provider=logger_provider)
logger.addHandler(handler)
set_logger_provider(logger_provider)
@@ -1084,7 +1083,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
@@ -1094,7 +1093,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
@@ -1104,7 +1103,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1113,7 +1112,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1277,7 +1276,7 @@ class AgentTelemetryLayer:
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -1287,7 +1286,7 @@ class AgentTelemetryLayer:
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -1296,7 +1295,7 @@ class AgentTelemetryLayer:
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -1614,15 +1613,15 @@ def capture_exception(span: trace.Span, exception: Exception, timestamp: int | N
def _capture_messages(
span: trace.Span,
provider_name: str,
messages: str | Message | Sequence[str | Message],
messages: AgentRunInputs,
system_instructions: str | list[str] | None = None,
output: bool = False,
finish_reason: FinishReason | None = None,
) -> None:
"""Log messages with extra information."""
from ._types import prepare_messages
from ._types import normalize_messages, prepend_instructions_to_messages
prepped = prepare_messages(messages, system_instructions=system_instructions)
prepped = prepend_instructions_to_messages(normalize_messages(messages), system_instructions)
otel_messages: list[dict[str, Any]] = []
for index, message in enumerate(prepped):
# Reuse the otel message representation for logging instead of calling to_dict()
@@ -33,7 +33,6 @@ if sys.version_info >= (3, 11):
else:
from typing_extensions import Self, TypedDict # type:ignore # pragma: no cover
__all__ = ["OpenAIAssistantProvider"]
# Type variable for options - allows typed OpenAIAssistantProvider[OptionsCoT] returns
# Default matches OpenAIAssistantsClient's default options type
@@ -68,12 +68,6 @@ else:
if TYPE_CHECKING:
from .._middleware import MiddlewareTypes
__all__ = [
"AssistantToolResources",
"OpenAIAssistantsClient",
"OpenAIAssistantsOptions",
]
# region OpenAI Assistants Options TypedDict
@@ -3,6 +3,7 @@
from __future__ import annotations
import json
import logging
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence
from datetime import datetime, timezone
@@ -20,7 +21,6 @@ from openai.types.chat.completion_create_params import WebSearchOptions
from pydantic import BaseModel
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
from .._settings import load_settings
from .._tools import (
@@ -60,9 +60,7 @@ if sys.version_info >= (3, 11):
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = ["OpenAIChatClient", "OpenAIChatOptions"]
logger = get_logger("agent_framework.openai")
logger = logging.getLogger("agent_framework.openai")
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
@@ -10,8 +10,6 @@ from openai import BadRequestError
from ..exceptions import ServiceContentFilterException
__all__ = ["ContentFilterResultSeverity", "OpenAIContentFilterException"]
class ContentFilterResultSeverity(Enum):
"""The severity of the content filter result."""
@@ -2,6 +2,7 @@
from __future__ import annotations
import logging
import sys
from collections.abc import (
AsyncIterable,
@@ -36,7 +37,6 @@ from openai.types.responses.web_search_tool_param import WebSearchToolParam
from pydantic import BaseModel
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import ChatMiddlewareLayer
from .._settings import load_settings
from .._tools import (
@@ -90,10 +90,7 @@ if TYPE_CHECKING:
FunctionMiddlewareCallable,
)
logger = get_logger("agent_framework.openai")
__all__ = ["OpenAIContinuationToken", "OpenAIResponsesClient", "OpenAIResponsesOptions", "RawOpenAIResponsesClient"]
logger = logging.getLogger("agent_framework.openai")
class OpenAIContinuationToken(ContinuationToken):
@@ -22,14 +22,13 @@ from openai.types.responses.response import Response
from openai.types.responses.response_stream_event import ResponseStreamEvent
from packaging.version import parse
from .._logging import get_logger
from .._serialization import SerializationMixin
from .._settings import SecretString
from .._telemetry import APP_INFO, USER_AGENT_KEY, prepend_agent_framework_to_user_agent
from .._tools import FunctionTool
from ..exceptions import ServiceInitializationError
logger: logging.Logger = get_logger("agent_framework.openai")
logger: logging.Logger = logging.getLogger("agent_framework.openai")
RESPONSE_TYPE = Union[
@@ -53,9 +52,6 @@ else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = ["OpenAISettings"]
def _check_openai_version_for_callable_api_key() -> None:
"""Check if OpenAI version supports callable API keys.
@@ -334,8 +334,10 @@ async def test_integration_options(
messages = [Message(role="user", text="What is the weather in Seattle?")]
elif option_name == "response_format":
# Use prompt that works well with structured output
messages = [Message(role="user", text="The weather in Seattle is sunny")]
messages.append(Message(role="user", text="What is the weather in Seattle?"))
messages = [
Message(role="user", text="The weather in Seattle is sunny"),
Message(role="user", text="What is the weather in Seattle?"),
]
else:
# Generic prompt for simple options
messages = [Message(role="user", text="Say 'Hello World' briefly.")]
@@ -396,7 +398,12 @@ async def test_integration_web_search() -> None:
for streaming in [False, True]:
content = {
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
"messages": [
Message(
role="user",
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
)
],
"options": {
"tool_choice": "auto",
"tools": [AzureOpenAIResponsesClient.get_web_search_tool()],
@@ -416,7 +423,7 @@ async def test_integration_web_search() -> None:
# Test that the client will use the web search tool with location
content = {
"messages": "What is the current weather? Do not ask for my current location.",
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
"options": {
"tool_choice": "auto",
"tools": [
@@ -498,7 +505,7 @@ async def test_integration_client_agent_hosted_mcp_tool() -> None:
"""Integration test for MCP tool with Azure Response Agent using Microsoft Learn MCP."""
client = AzureOpenAIResponsesClient(credential=AzureCliCredential())
response = await client.get_response(
"How to create an Azure storage account using az cli?",
messages=[Message(role="user", text="How to create an Azure storage account using az cli?")],
options={
# this needs to be high enough to handle the full MCP tool response.
"max_tokens": 5000,
@@ -523,7 +530,7 @@ async def test_integration_client_agent_hosted_code_interpreter_tool():
client = AzureOpenAIResponsesClient(credential=AzureCliCredential())
response = await client.get_response(
"Calculate the sum of numbers from 1 to 10 using Python code.",
messages=[Message(role="user", text="Calculate the sum of numbers from 1 to 10 using Python code.")],
options={
"tools": [AzureOpenAIResponsesClient.get_code_interpreter_tool()],
},
@@ -43,6 +43,12 @@ async def test_agent_run(agent: SupportsAgentRun) -> None:
assert response.messages[0].text == "Response"
async def test_agent_run_with_content(agent: SupportsAgentRun) -> None:
response = await agent.run(Content.from_text("test"))
assert response.messages[0].role == "assistant"
assert response.messages[0].text == "Response"
async def test_agent_run_streaming(agent: SupportsAgentRun) -> None:
async def collect_updates(updates: AsyncIterable[AgentResponseUpdate]) -> list[AgentResponseUpdate]:
return [u async for u in updates]
@@ -21,13 +21,13 @@ def test_chat_client_type(client: SupportsChatGetResponse):
async def test_chat_client_get_response(client: SupportsChatGetResponse):
response = await client.get_response(Message(role="user", text="Hello"))
response = await client.get_response([Message(role="user", text="Hello")])
assert response.text == "test response"
assert response.messages[0].role == "assistant"
async def test_chat_client_get_response_streaming(client: SupportsChatGetResponse):
async for update in client.get_response(Message(role="user", text="Hello"), stream=True):
async for update in client.get_response([Message(role="user", text="Hello")], stream=True):
assert update.text == "test streaming response " or update.text == "another update"
assert update.role == "assistant"
@@ -38,13 +38,13 @@ def test_base_client(chat_client_base: SupportsChatGetResponse):
async def test_base_client_get_response(chat_client_base: SupportsChatGetResponse):
response = await chat_client_base.get_response(Message(role="user", text="Hello"))
response = await chat_client_base.get_response([Message(role="user", text="Hello")])
assert response.messages[0].role == "assistant"
assert response.messages[0].text == "test response - Hello"
async def test_base_client_get_response_streaming(chat_client_base: SupportsChatGetResponse):
async for update in chat_client_base.get_response(Message(role="user", text="Hello"), stream=True):
async for update in chat_client_base.get_response([Message(role="user", text="Hello")], stream=True):
assert update.text == "update - Hello" or update.text == "another update"
@@ -59,7 +59,9 @@ async def test_chat_client_instructions_handling(chat_client_base: SupportsChatG
"_inner_get_response",
side_effect=fake_inner_get_response,
) as mock_inner_get_response:
await chat_client_base.get_response("hello", options={"instructions": instructions})
await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"instructions": instructions}
)
mock_inner_get_response.assert_called_once()
_, kwargs = mock_inner_get_response.call_args
messages = kwargs.get("messages", [])
@@ -38,7 +38,9 @@ async def test_base_client_with_function_calling(chat_client_base: SupportsChatG
),
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
)
assert exec_counter == 1
assert len(response.messages) == 3
assert response.messages[0].role == "assistant"
@@ -83,7 +85,9 @@ async def test_base_client_with_function_calling_resets(chat_client_base: Suppor
),
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
)
assert exec_counter == 2
assert len(response.messages) == 5
assert response.messages[0].role == "assistant"
@@ -388,11 +392,13 @@ async def test_function_invocation_scenarios(
options["conversation_id"] = conversation_id
if not streaming:
response = await chat_client_base.get_response("hello", options=options)
response = await chat_client_base.get_response([Message(role="user", text="hello")], options=options)
messages = response.messages
else:
updates = []
async for update in chat_client_base.get_response("hello", options=options, stream=True):
async for update in chat_client_base.get_response(
[Message(role="user", text="hello")], options=options, stream=True
):
updates.append(update)
messages = updates
@@ -776,7 +782,9 @@ async def test_max_iterations_limit(chat_client_base: SupportsChatGetResponse):
# Set max_iterations to 1 in additional_properties
chat_client_base.function_invocation_configuration["max_iterations"] = 1
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
)
# With max_iterations=1, we should:
# 1. Execute first function call (exec_counter=1)
@@ -803,7 +811,9 @@ async def test_function_invocation_config_enabled_false(chat_client_base: Suppor
# Disable function invocation
chat_client_base.function_invocation_configuration["enabled"] = False
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
)
# Function should not be executed - when enabled=False, the loop doesn't run
assert exec_counter == 0
@@ -859,7 +869,9 @@ async def test_function_invocation_config_max_consecutive_errors(chat_client_bas
# Set max_consecutive_errors to 2
chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 2
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
)
# Should stop after 2 consecutive errors and force a non-tool response
error_results = [
@@ -904,7 +916,9 @@ async def test_function_invocation_config_terminate_on_unknown_calls_false(chat_
# Set terminate_on_unknown_calls to False (default)
chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = False
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
)
# Should have a result message indicating the tool wasn't found
assert len(response.messages) == 3
@@ -940,7 +954,9 @@ async def test_function_invocation_config_terminate_on_unknown_calls_true(chat_c
# Should raise an exception when encountering an unknown function
with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'):
await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]})
await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
)
assert exec_counter == 0
@@ -978,7 +994,9 @@ async def test_function_invocation_config_additional_tools(chat_client_base: Sup
chat_client_base.function_invocation_configuration["additional_tools"] = [hidden_func]
# Only pass visible_func in the tools parameter
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [visible_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [visible_func]}
)
# Additional tools are treated as declaration_only, so not executed
# The function call should be in the messages but not executed
@@ -1016,7 +1034,9 @@ async def test_function_invocation_config_include_detailed_errors_false(chat_cli
# Set include_detailed_errors to False (default)
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
)
# Should have a generic error message
error_result = next(
@@ -1050,7 +1070,9 @@ async def test_function_invocation_config_include_detailed_errors_true(chat_clie
# Set include_detailed_errors to True
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
)
# Should have detailed error message
error_result = next(
@@ -1120,7 +1142,9 @@ async def test_argument_validation_error_with_detailed_errors(chat_client_base:
# Set include_detailed_errors to True
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
)
# Should have detailed validation error
error_result = next(
@@ -1154,7 +1178,9 @@ async def test_argument_validation_error_without_detailed_errors(chat_client_bas
# Set include_detailed_errors to False (default)
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
)
# Should have generic validation error
error_result = next(
@@ -1219,7 +1245,9 @@ async def test_unapproved_tool_execution_raises_exception(chat_client_base: Supp
]
# Get approval request
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
response1 = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
)
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
@@ -1277,7 +1305,9 @@ async def test_approved_function_call_with_error_without_detailed_errors(chat_cl
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
# Get approval request
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
response1 = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
)
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
@@ -1340,7 +1370,9 @@ async def test_approved_function_call_with_error_with_detailed_errors(chat_clien
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
# Get approval request
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
response1 = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
)
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
@@ -1403,7 +1435,9 @@ async def test_approved_function_call_with_validation_error(chat_client_base: Su
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
# Get approval request
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
response1 = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
)
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
@@ -1459,7 +1493,9 @@ async def test_approved_function_call_successful_execution(chat_client_base: Sup
]
# Get approval request
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [success_func]})
response1 = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [success_func]}
)
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
@@ -1575,7 +1611,9 @@ async def test_multiple_function_calls_parallel_execution(chat_client_base: Supp
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [func1, func2]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [func1, func2]}
)
# Both functions should have been executed
assert "func1_start" in exec_order
@@ -1612,7 +1650,9 @@ async def test_callable_function_converted_to_tool(chat_client_base: SupportsCha
]
# Pass plain function (will be auto-converted)
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [plain_function]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [plain_function]}
)
# Function should be executed
assert exec_counter == 1
@@ -1644,7 +1684,9 @@ async def test_conversation_id_handling(chat_client_base: SupportsChatGetRespons
),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
)
# Should have executed the function
results = [content for msg in response.messages for content in msg.contents if content.type == "function_result"]
@@ -1671,7 +1713,9 @@ async def test_function_result_appended_to_existing_assistant_message(chat_clien
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
)
# Should have messages with both function call and function result
assert len(response.messages) >= 2
@@ -1716,7 +1760,9 @@ async def test_error_recovery_resets_counter(chat_client_base: SupportsChatGetRe
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [sometimes_fails]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [sometimes_fails]}
)
# Should have both an error and a success
error_results = [
@@ -1990,7 +2036,9 @@ async def test_streaming_function_invocation_config_terminate_on_unknown_calls_t
# Should raise an exception when encountering an unknown function
with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'):
async for _ in chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]}):
async for _ in chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
):
pass
assert exec_counter == 0
@@ -1,39 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
import pytest
from agent_framework import get_logger
from agent_framework.exceptions import AgentFrameworkException
def test_get_logger():
"""Test that the logger is created with the correct name."""
logger = get_logger()
assert logger.name == "agent_framework"
def test_get_logger_custom_name():
"""Test that the logger can be created with a custom name."""
custom_name = "agent_framework.custom"
logger = get_logger(custom_name)
assert logger.name == custom_name
def test_get_logger_invalid_name():
"""Test that an exception is raised for an invalid logger name."""
with pytest.raises(AgentFrameworkException):
get_logger("invalid_name")
def test_log(caplog):
"""Test that the logger can log messages and adheres to the expected format."""
logger = get_logger()
with caplog.at_level("DEBUG"):
logger.debug("This is a debug message")
assert len(caplog.records) == 1
record = caplog.records[0]
assert record.levelname == "DEBUG"
assert record.message == "This is a debug message"
assert record.name == "agent_framework"
assert record.pathname.endswith("test_logging.py")
@@ -1083,7 +1083,12 @@ async def test_integration_web_search() -> None:
# Use static method for web search tool
web_search_tool = OpenAIChatClient.get_web_search_tool()
content = {
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
"messages": [
Message(
role="user",
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
)
],
"options": {
"tool_choice": "auto",
"tools": [web_search_tool],
@@ -1110,7 +1115,7 @@ async def test_integration_web_search() -> None:
}
)
content = {
"messages": "What is the current weather? Do not ask for my current location.",
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
"options": {
"tool_choice": "auto",
"tools": [web_search_tool_with_location],
@@ -2416,7 +2416,12 @@ async def test_integration_web_search() -> None:
# Use static method for web search tool
web_search_tool = OpenAIResponsesClient.get_web_search_tool()
content = {
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
"messages": [
Message(
role="user",
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
)
],
"options": {
"tool_choice": "auto",
"tools": [web_search_tool],
@@ -2438,7 +2443,7 @@ async def test_integration_web_search() -> None:
user_location={"country": "US", "city": "Seattle"},
)
content = {
"messages": "What is the current weather? Do not ask for my current location.",
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
"options": {
"tool_choice": "auto",
"tools": [web_search_tool_with_location],
@@ -39,7 +39,7 @@ class _ToolCallingAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -35,7 +35,7 @@ class _SimpleAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -105,7 +105,7 @@ class _CaptureAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -835,7 +835,7 @@ class _StreamingTestAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -52,7 +52,7 @@ class _KwargsCapturingAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -85,7 +85,7 @@ class _OptionsAwareAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,