mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
* Fix #3613 message typing across chat and agents * Address #3613 review feedback and sample input style * refactor: use shared AgentRunMessages aliases (#3613) * refactor: rename agent run input aliases for #3613 * samples: inline image content in run calls * core: export AgentRunInputs from package init * core: use explicit init re-exports without __all__ * updated logging and inits * Fix core mypy export and samples XML note * Remove AgentRunInputsOrNone and dedupe loggers * Remove prepare_messages helper * fix integration tests
This commit is contained in:
committed by
GitHub
Unverified
parent
503eb10fdd
commit
dc9439a75a
@@ -19,7 +19,6 @@ from ._clients import (
|
||||
SupportsMCPTool,
|
||||
SupportsWebSearchTool,
|
||||
)
|
||||
from ._logging import get_logger, setup_logging
|
||||
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPWebsocketTool
|
||||
from ._middleware import (
|
||||
AgentContext,
|
||||
@@ -34,7 +33,6 @@ from ._middleware import (
|
||||
FunctionInvocationContext,
|
||||
FunctionMiddleware,
|
||||
FunctionMiddlewareTypes,
|
||||
MiddlewareException,
|
||||
MiddlewareTermination,
|
||||
MiddlewareType,
|
||||
MiddlewareTypes,
|
||||
@@ -67,6 +65,7 @@ from ._tools import (
|
||||
from ._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
Annotation,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
@@ -152,6 +151,7 @@ from ._workflows import (
|
||||
response_handler,
|
||||
validate_workflow_graph,
|
||||
)
|
||||
from .exceptions import MiddlewareException
|
||||
|
||||
__all__ = [
|
||||
"AGENT_FRAMEWORK_USER_AGENT",
|
||||
@@ -169,6 +169,7 @@ __all__ = [
|
||||
"AgentMiddlewareTypes",
|
||||
"AgentResponse",
|
||||
"AgentResponseUpdate",
|
||||
"AgentRunInputs",
|
||||
"AgentSession",
|
||||
"Annotation",
|
||||
"BaseAgent",
|
||||
@@ -264,6 +265,7 @@ __all__ = [
|
||||
"WorkflowRunnerException",
|
||||
"WorkflowValidationError",
|
||||
"WorkflowViz",
|
||||
"__version__",
|
||||
"add_usage_details",
|
||||
"agent_middleware",
|
||||
"chat_middleware",
|
||||
@@ -271,7 +273,6 @@ __all__ = [
|
||||
"detect_media_type_from_base64",
|
||||
"executor",
|
||||
"function_middleware",
|
||||
"get_logger",
|
||||
"handler",
|
||||
"map_chat_to_agent_update",
|
||||
"merge_chat_options",
|
||||
@@ -283,7 +284,6 @@ __all__ = [
|
||||
"register_state_type",
|
||||
"resolve_agent_id",
|
||||
"response_handler",
|
||||
"setup_logging",
|
||||
"tool",
|
||||
"validate_chat_options",
|
||||
"validate_tool_mode",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
|
||||
@@ -29,7 +30,6 @@ from mcp.shared.exceptions import McpError
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from ._clients import BaseChatClient, SupportsChatGetResponse
|
||||
from ._logging import get_logger
|
||||
from ._mcp import LOG_LEVEL_MAPPING, MCPTool
|
||||
from ._middleware import AgentMiddlewareLayer, MiddlewareTypes
|
||||
from ._serialization import SerializationMixin
|
||||
@@ -41,6 +41,7 @@ from ._tools import (
|
||||
from ._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Message,
|
||||
@@ -67,7 +68,7 @@ else:
|
||||
if TYPE_CHECKING:
|
||||
from ._types import ChatOptions
|
||||
|
||||
logger = get_logger("agent_framework")
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
OptionsCoT = TypeVar(
|
||||
@@ -159,9 +160,6 @@ class _RunContext(TypedDict):
|
||||
finalize_kwargs: dict[str, Any]
|
||||
|
||||
|
||||
__all__ = ["Agent", "BaseAgent", "RawAgent", "SupportsAgentRun"]
|
||||
|
||||
|
||||
# region Agent Protocol
|
||||
|
||||
|
||||
@@ -230,7 +228,7 @@ class SupportsAgentRun(Protocol):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -242,7 +240,7 @@ class SupportsAgentRun(Protocol):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -253,7 +251,7 @@ class SupportsAgentRun(Protocol):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -763,7 +761,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -780,7 +778,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -797,7 +795,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -813,7 +811,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1000,7 +998,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
async def _prepare_run_context(
|
||||
self,
|
||||
*,
|
||||
messages: str | Message | Sequence[str | Message] | None,
|
||||
messages: AgentRunInputs | None,
|
||||
session: AgentSession | None,
|
||||
tools: FunctionTool
|
||||
| Callable[..., Any]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import (
|
||||
@@ -27,7 +28,6 @@ from typing import (
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._logging import get_logger
|
||||
from ._serialization import SerializationMixin
|
||||
from ._tools import (
|
||||
FunctionInvocationConfiguration,
|
||||
@@ -38,7 +38,6 @@ from ._types import (
|
||||
ChatResponseUpdate,
|
||||
Message,
|
||||
ResponseStream,
|
||||
prepare_messages,
|
||||
validate_chat_options,
|
||||
)
|
||||
|
||||
@@ -61,17 +60,7 @@ InputT = TypeVar("InputT", contravariant=True)
|
||||
EmbeddingT = TypeVar("EmbeddingT")
|
||||
BaseChatClientT = TypeVar("BaseChatClientT", bound="BaseChatClient")
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = [
|
||||
"BaseChatClient",
|
||||
"SupportsChatGetResponse",
|
||||
"SupportsCodeInterpreterTool",
|
||||
"SupportsFileSearchTool",
|
||||
"SupportsImageGenerationTool",
|
||||
"SupportsMCPTool",
|
||||
"SupportsWebSearchTool",
|
||||
]
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
|
||||
# region SupportsChatGetResponse Protocol
|
||||
@@ -139,7 +128,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
@@ -149,7 +138,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsContraT | ChatOptions[None] | None = None,
|
||||
@@ -159,7 +148,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsContraT | ChatOptions[Any] | None = None,
|
||||
@@ -168,7 +157,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsContraT | ChatOptions[Any] | None = None,
|
||||
@@ -254,9 +243,9 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
client = CustomChatClient()
|
||||
|
||||
# Use the client to get responses
|
||||
response = await client.get_response("Hello, how are you?")
|
||||
response = await client.get_response([Message(role="user", text="Hello, how are you?")])
|
||||
# Or stream responses
|
||||
async for update in client.get_response("Hello!", stream=True):
|
||||
async for update in client.get_response([Message(role="user", text="Hello!")], stream=True):
|
||||
print(update)
|
||||
"""
|
||||
|
||||
@@ -376,7 +365,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
@@ -386,7 +375,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
@@ -396,7 +385,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -405,7 +394,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -422,9 +411,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
Returns:
|
||||
When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse.
|
||||
"""
|
||||
prepared_messages = prepare_messages(messages)
|
||||
return self._inner_get_response(
|
||||
messages=prepared_messages,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options or {}, # type: ignore[arg-type]
|
||||
**kwargs,
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
|
||||
from .exceptions import AgentFrameworkException
|
||||
|
||||
__all__ = ["get_logger", "setup_logging"]
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
"""Setup the logging configuration for the agent framework."""
|
||||
logging.basicConfig(
|
||||
format="[%(asctime)s - %(pathname)s:%(lineno)d - %(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
def get_logger(name: str = "agent_framework") -> logging.Logger:
|
||||
"""Get a logger with the specified name, defaulting to 'agent_framework'.
|
||||
|
||||
Args:
|
||||
name (str): The name of the logger. Defaults to 'agent_framework'.
|
||||
|
||||
Returns:
|
||||
logging.Logger: The configured logger instance.
|
||||
"""
|
||||
if not name.startswith("agent_framework"):
|
||||
raise AgentFrameworkException("Logger name must start with 'agent_framework'.")
|
||||
return logging.getLogger(name)
|
||||
@@ -72,12 +72,6 @@ LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = {
|
||||
"emergency": logging.CRITICAL,
|
||||
}
|
||||
|
||||
__all__ = [
|
||||
"MCPStdioTool",
|
||||
"MCPStreamableHTTPTool",
|
||||
"MCPWebsocketTool",
|
||||
]
|
||||
|
||||
|
||||
def _parse_prompt_result_from_mcp(
|
||||
mcp_type: types.GetPromptResult,
|
||||
|
||||
@@ -14,11 +14,12 @@ from ._clients import SupportsChatGetResponse
|
||||
from ._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Message,
|
||||
ResponseStream,
|
||||
prepare_messages,
|
||||
normalize_messages,
|
||||
)
|
||||
from .exceptions import MiddlewareException
|
||||
|
||||
@@ -42,27 +43,6 @@ if TYPE_CHECKING:
|
||||
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
|
||||
__all__ = [
|
||||
"AgentContext",
|
||||
"AgentMiddleware",
|
||||
"AgentMiddlewareLayer",
|
||||
"AgentMiddlewareTypes",
|
||||
"ChatAndFunctionMiddlewareTypes",
|
||||
"ChatContext",
|
||||
"ChatMiddleware",
|
||||
"ChatMiddlewareLayer",
|
||||
"ChatMiddlewareTypes",
|
||||
"FunctionInvocationContext",
|
||||
"FunctionMiddleware",
|
||||
"FunctionMiddlewareTypes",
|
||||
"MiddlewareException",
|
||||
"MiddlewareTermination",
|
||||
"MiddlewareType",
|
||||
"MiddlewareTypes",
|
||||
"agent_middleware",
|
||||
"chat_middleware",
|
||||
"function_middleware",
|
||||
]
|
||||
|
||||
AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
|
||||
ContextT = TypeVar("ContextT")
|
||||
@@ -978,7 +958,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
@@ -988,7 +968,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
@@ -998,7 +978,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1007,7 +987,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1034,7 +1014,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
|
||||
context = ChatContext(
|
||||
client=self, # type: ignore[arg-type]
|
||||
messages=prepare_messages(messages),
|
||||
messages=list(messages),
|
||||
options=options,
|
||||
stream=stream,
|
||||
kwargs=kwargs,
|
||||
@@ -1095,7 +1075,7 @@ class AgentMiddlewareLayer:
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1107,7 +1087,7 @@ class AgentMiddlewareLayer:
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1119,7 +1099,7 @@ class AgentMiddlewareLayer:
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -1130,7 +1110,7 @@ class AgentMiddlewareLayer:
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1161,7 +1141,7 @@ class AgentMiddlewareLayer:
|
||||
|
||||
context = AgentContext(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
messages=prepare_messages(messages), # type: ignore[arg-type]
|
||||
messages=normalize_messages(messages),
|
||||
session=session,
|
||||
options=options,
|
||||
stream=stream,
|
||||
|
||||
@@ -3,13 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from typing import Any, ClassVar, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
from ._logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
ClassT = TypeVar("ClassT", bound="SerializationMixin")
|
||||
ProtocolT = TypeVar("ProtocolT", bound="SerializationProtocol")
|
||||
|
||||
@@ -24,16 +24,6 @@ if TYPE_CHECKING:
|
||||
from ._agents import SupportsAgentRun
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentSession",
|
||||
"BaseContextProvider",
|
||||
"BaseHistoryProvider",
|
||||
"InMemoryHistoryProvider",
|
||||
"SessionContext",
|
||||
"register_state_type",
|
||||
]
|
||||
|
||||
|
||||
# Registry of known types for state deserialization
|
||||
_STATE_TYPE_REGISTRY: dict[str, type] = {}
|
||||
|
||||
|
||||
@@ -45,7 +45,6 @@ if sys.version_info >= (3, 13):
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = ["SecretString", "load_settings"]
|
||||
|
||||
SettingsT = TypeVar("SettingsT", default=dict[str, Any])
|
||||
|
||||
|
||||
@@ -2,21 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Final
|
||||
|
||||
from . import __version__ as version_info
|
||||
from ._logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
__all__ = [
|
||||
"AGENT_FRAMEWORK_USER_AGENT",
|
||||
"APP_INFO",
|
||||
"USER_AGENT_KEY",
|
||||
"USER_AGENT_TELEMETRY_DISABLED_ENV_VAR",
|
||||
"prepend_agent_framework_to_user_agent",
|
||||
]
|
||||
|
||||
# Note that if this environment variable does not exist, user agent telemetry is enabled.
|
||||
USER_AGENT_TELEMETRY_DISABLED_ENV_VAR = "AGENT_FRAMEWORK_USER_AGENT_DISABLED"
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
@@ -34,7 +35,6 @@ from typing import (
|
||||
from opentelemetry.metrics import Histogram, NoOpHistogram
|
||||
from pydantic import BaseModel, Field, ValidationError, create_model
|
||||
|
||||
from ._logging import get_logger
|
||||
from ._serialization import SerializationMixin
|
||||
from .exceptions import ToolException
|
||||
from .observability import (
|
||||
@@ -71,18 +71,8 @@ if TYPE_CHECKING:
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
__all__ = [
|
||||
"FunctionInvocationConfiguration",
|
||||
"FunctionInvocationLayer",
|
||||
"FunctionTool",
|
||||
"normalize_function_invocation_configuration",
|
||||
"tool",
|
||||
]
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
DEFAULT_MAX_ITERATIONS: Final[int] = 40
|
||||
DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
|
||||
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
|
||||
@@ -1941,7 +1931,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
@@ -1951,7 +1941,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
@@ -1961,7 +1951,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1970,7 +1960,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1982,7 +1972,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
ResponseStream,
|
||||
prepare_messages,
|
||||
)
|
||||
|
||||
super_get_response = super().get_response # type: ignore[misc]
|
||||
@@ -2014,7 +2003,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
nonlocal mutable_options
|
||||
nonlocal filtered_kwargs
|
||||
errors_in_a_row: int = 0
|
||||
prepped_messages = prepare_messages(messages)
|
||||
prepped_messages = list(messages)
|
||||
fcc_messages: list[Message] = []
|
||||
response: ChatResponse | None = None
|
||||
|
||||
@@ -2108,7 +2097,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
nonlocal mutable_options
|
||||
nonlocal stream_result_hooks
|
||||
errors_in_a_row: int = 0
|
||||
prepped_messages = prepare_messages(messages)
|
||||
prepped_messages = list(messages)
|
||||
fcc_messages: list[Message] = []
|
||||
response: ChatResponse | None = None
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from asyncio import iscoroutine
|
||||
@@ -13,7 +14,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, NewTyp
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._logging import get_logger
|
||||
from ._serialization import SerializationMixin
|
||||
from ._tools import FunctionTool, tool
|
||||
from .exceptions import AdditionItemMismatch, ContentError
|
||||
@@ -27,41 +27,7 @@ if sys.version_info >= (3, 11):
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = [
|
||||
"AgentResponse",
|
||||
"AgentResponseUpdate",
|
||||
"Annotation",
|
||||
"ChatOptions",
|
||||
"ChatResponse",
|
||||
"ChatResponseUpdate",
|
||||
"Content",
|
||||
"ContinuationToken",
|
||||
"FinalT",
|
||||
"FinishReason",
|
||||
"FinishReasonLiteral",
|
||||
"Message",
|
||||
"OuterFinalT",
|
||||
"OuterUpdateT",
|
||||
"ResponseStream",
|
||||
"Role",
|
||||
"RoleLiteral",
|
||||
"TextSpanRegion",
|
||||
"ToolMode",
|
||||
"UpdateT",
|
||||
"UsageDetails",
|
||||
"add_usage_details",
|
||||
"detect_media_type_from_base64",
|
||||
"map_chat_to_agent_update",
|
||||
"merge_chat_options",
|
||||
"normalize_messages",
|
||||
"normalize_tools",
|
||||
"prepend_instructions_to_messages",
|
||||
"validate_chat_options",
|
||||
"validate_tool_mode",
|
||||
"validate_tools",
|
||||
]
|
||||
|
||||
logger = get_logger("agent_framework")
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
|
||||
# region Content Parsing Utilities
|
||||
@@ -1536,47 +1502,11 @@ class Message(SerializationMixin):
|
||||
return " ".join(content.text for content in self.contents if content.type == "text") # type: ignore[misc]
|
||||
|
||||
|
||||
def prepare_messages(
|
||||
messages: str | Content | Message | Sequence[str | Content | Message],
|
||||
system_instructions: str | Sequence[str] | None = None,
|
||||
) -> list[Message]:
|
||||
"""Convert various message input formats into a list of Message objects.
|
||||
|
||||
Args:
|
||||
messages: The input messages in various supported formats. Can be:
|
||||
- A string (converted to a user message)
|
||||
- A Content object (wrapped in a user Message)
|
||||
- A Message object
|
||||
- A sequence containing any mix of the above
|
||||
system_instructions: The system instructions. They will be inserted to the start of the messages list.
|
||||
|
||||
Returns:
|
||||
A list of Message objects.
|
||||
"""
|
||||
if system_instructions is not None:
|
||||
if isinstance(system_instructions, str):
|
||||
system_instructions = [system_instructions]
|
||||
system_instruction_messages = [Message("system", [instr]) for instr in system_instructions]
|
||||
else:
|
||||
system_instruction_messages = []
|
||||
|
||||
if isinstance(messages, str):
|
||||
return [*system_instruction_messages, Message("user", [messages])]
|
||||
if isinstance(messages, Content):
|
||||
return [*system_instruction_messages, Message("user", [messages])]
|
||||
if isinstance(messages, Message):
|
||||
return [*system_instruction_messages, messages]
|
||||
|
||||
return_messages: list[Message] = system_instruction_messages
|
||||
for msg in messages:
|
||||
if isinstance(msg, (str, Content)):
|
||||
msg = Message("user", [msg])
|
||||
return_messages.append(msg)
|
||||
return return_messages
|
||||
AgentRunInputs = str | Content | Message | Sequence[str | Content | Message]
|
||||
|
||||
|
||||
def normalize_messages(
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
) -> list[Message]:
|
||||
"""Normalize message inputs to a list of Message objects.
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from .._sessions import (
|
||||
from .._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
Content,
|
||||
Message,
|
||||
ResponseStream,
|
||||
@@ -145,7 +146,7 @@ class WorkflowAgent(BaseAgent):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -157,7 +158,7 @@ class WorkflowAgent(BaseAgent):
|
||||
@overload
|
||||
async def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -168,7 +169,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -214,7 +215,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
async def _run_impl(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: AgentRunInputs,
|
||||
response_id: str,
|
||||
session: AgentSession | None,
|
||||
checkpoint_id: str | None = None,
|
||||
@@ -270,7 +271,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
async def _run_stream_impl(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: AgentRunInputs,
|
||||
response_id: str,
|
||||
session: AgentSession | None,
|
||||
checkpoint_id: str | None = None,
|
||||
|
||||
@@ -3,11 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import pickle # nosec # noqa: S403
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import get_logger
|
||||
|
||||
"""Checkpoint encoding using JSON structure with pickle+base64 for arbitrary data.
|
||||
|
||||
This hybrid approach provides:
|
||||
@@ -20,7 +19,7 @@ from trusted sources. Loading a malicious checkpoint file can execute arbitrary
|
||||
"""
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
# Marker to identify pickled values in serialized JSON
|
||||
_PICKLE_MARKER = "__pickled__"
|
||||
|
||||
@@ -2,18 +2,17 @@
|
||||
|
||||
"""Shared helpers for normalizing workflow message inputs."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from agent_framework import Message
|
||||
from agent_framework import Content, Message
|
||||
from agent_framework._types import AgentRunInputs
|
||||
|
||||
|
||||
def normalize_messages_input(
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
) -> list[Message]:
|
||||
"""Normalize heterogeneous message inputs to a list of Message objects.
|
||||
|
||||
Args:
|
||||
messages: String, Message, or sequence of either. None yields empty list.
|
||||
messages: String, Content, Message, or sequence of those values. None yields empty list.
|
||||
|
||||
Returns:
|
||||
List of Message instances suitable for workflow consumption.
|
||||
@@ -24,6 +23,9 @@ def normalize_messages_input(
|
||||
if isinstance(messages, str):
|
||||
return [Message(role="user", text=messages)]
|
||||
|
||||
if isinstance(messages, Content):
|
||||
return [Message(role="user", contents=[messages])]
|
||||
|
||||
if isinstance(messages, Message):
|
||||
return [messages]
|
||||
|
||||
@@ -31,13 +33,12 @@ def normalize_messages_input(
|
||||
for item in messages:
|
||||
if isinstance(item, str):
|
||||
normalized.append(Message(role="user", text=item))
|
||||
elif isinstance(item, Content):
|
||||
normalized.append(Message(role="user", contents=[item]))
|
||||
elif isinstance(item, Message):
|
||||
normalized.append(item)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Messages sequence must contain only str or Message instances; found {type(item).__name__}."
|
||||
f"Messages sequence must contain only str, Content, or Message instances; found {type(item).__name__}."
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = ["normalize_messages_input"]
|
||||
|
||||
@@ -27,8 +27,6 @@ if sys.version_info >= (3, 11):
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = ["AzureOpenAIAssistantsClient"]
|
||||
|
||||
|
||||
# region Azure OpenAI Assistants Options TypedDict
|
||||
|
||||
|
||||
@@ -53,7 +53,6 @@ if TYPE_CHECKING:
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["AzureOpenAIChatClient", "AzureOpenAIChatOptions", "AzureUserSecurityContext"]
|
||||
|
||||
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
@@ -42,8 +42,6 @@ if TYPE_CHECKING:
|
||||
from .._middleware import MiddlewareTypes
|
||||
from ..openai._responses_client import OpenAIResponsesOptions
|
||||
|
||||
__all__ = ["AzureOpenAIResponsesClient"]
|
||||
|
||||
|
||||
AzureOpenAIResponsesOptionsT = TypeVar(
|
||||
"AzureOpenAIResponsesOptionsT",
|
||||
|
||||
@@ -20,7 +20,6 @@ from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.semconv_ai import Meters, SpanAttributes
|
||||
|
||||
from . import __version__ as version_info
|
||||
from ._logging import get_logger
|
||||
from ._settings import load_settings
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
@@ -44,6 +43,7 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
from ._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
@@ -73,7 +73,7 @@ AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
|
||||
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
|
||||
OTEL_METRICS: Final[str] = "__otel_metrics__"
|
||||
@@ -747,7 +747,6 @@ class ObservabilitySettings:
|
||||
for log_exporter in log_exporters:
|
||||
logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter))
|
||||
# Attach a handler with the provider to the root logger
|
||||
logger = logging.getLogger()
|
||||
handler = LoggingHandler(logger_provider=logger_provider)
|
||||
logger.addHandler(handler)
|
||||
set_logger_provider(logger_provider)
|
||||
@@ -1084,7 +1083,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
@@ -1094,7 +1093,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
@@ -1104,7 +1103,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1113,7 +1112,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1277,7 +1276,7 @@ class AgentTelemetryLayer:
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1287,7 +1286,7 @@ class AgentTelemetryLayer:
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -1296,7 +1295,7 @@ class AgentTelemetryLayer:
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1614,15 +1613,15 @@ def capture_exception(span: trace.Span, exception: Exception, timestamp: int | N
|
||||
def _capture_messages(
|
||||
span: trace.Span,
|
||||
provider_name: str,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: AgentRunInputs,
|
||||
system_instructions: str | list[str] | None = None,
|
||||
output: bool = False,
|
||||
finish_reason: FinishReason | None = None,
|
||||
) -> None:
|
||||
"""Log messages with extra information."""
|
||||
from ._types import prepare_messages
|
||||
from ._types import normalize_messages, prepend_instructions_to_messages
|
||||
|
||||
prepped = prepare_messages(messages, system_instructions=system_instructions)
|
||||
prepped = prepend_instructions_to_messages(normalize_messages(messages), system_instructions)
|
||||
otel_messages: list[dict[str, Any]] = []
|
||||
for index, message in enumerate(prepped):
|
||||
# Reuse the otel message representation for logging instead of calling to_dict()
|
||||
|
||||
@@ -33,7 +33,6 @@ if sys.version_info >= (3, 11):
|
||||
else:
|
||||
from typing_extensions import Self, TypedDict # type:ignore # pragma: no cover
|
||||
|
||||
__all__ = ["OpenAIAssistantProvider"]
|
||||
|
||||
# Type variable for options - allows typed OpenAIAssistantProvider[OptionsCoT] returns
|
||||
# Default matches OpenAIAssistantsClient's default options type
|
||||
|
||||
@@ -68,12 +68,6 @@ else:
|
||||
if TYPE_CHECKING:
|
||||
from .._middleware import MiddlewareTypes
|
||||
|
||||
__all__ = [
|
||||
"AssistantToolResources",
|
||||
"OpenAIAssistantsClient",
|
||||
"OpenAIAssistantsOptions",
|
||||
]
|
||||
|
||||
|
||||
# region OpenAI Assistants Options TypedDict
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
@@ -20,7 +21,6 @@ from openai.types.chat.completion_create_params import WebSearchOptions
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
|
||||
from .._settings import load_settings
|
||||
from .._tools import (
|
||||
@@ -60,9 +60,7 @@ if sys.version_info >= (3, 11):
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = ["OpenAIChatClient", "OpenAIChatOptions"]
|
||||
|
||||
logger = get_logger("agent_framework.openai")
|
||||
logger = logging.getLogger("agent_framework.openai")
|
||||
|
||||
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
@@ -10,8 +10,6 @@ from openai import BadRequestError
|
||||
|
||||
from ..exceptions import ServiceContentFilterException
|
||||
|
||||
__all__ = ["ContentFilterResultSeverity", "OpenAIContentFilterException"]
|
||||
|
||||
|
||||
class ContentFilterResultSeverity(Enum):
|
||||
"""The severity of the content filter result."""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
@@ -36,7 +37,6 @@ from openai.types.responses.web_search_tool_param import WebSearchToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import ChatMiddlewareLayer
|
||||
from .._settings import load_settings
|
||||
from .._tools import (
|
||||
@@ -90,10 +90,7 @@ if TYPE_CHECKING:
|
||||
FunctionMiddlewareCallable,
|
||||
)
|
||||
|
||||
logger = get_logger("agent_framework.openai")
|
||||
|
||||
|
||||
__all__ = ["OpenAIContinuationToken", "OpenAIResponsesClient", "OpenAIResponsesOptions", "RawOpenAIResponsesClient"]
|
||||
logger = logging.getLogger("agent_framework.openai")
|
||||
|
||||
|
||||
class OpenAIContinuationToken(ContinuationToken):
|
||||
|
||||
@@ -22,14 +22,13 @@ from openai.types.responses.response import Response
|
||||
from openai.types.responses.response_stream_event import ResponseStreamEvent
|
||||
from packaging.version import parse
|
||||
|
||||
from .._logging import get_logger
|
||||
from .._serialization import SerializationMixin
|
||||
from .._settings import SecretString
|
||||
from .._telemetry import APP_INFO, USER_AGENT_KEY, prepend_agent_framework_to_user_agent
|
||||
from .._tools import FunctionTool
|
||||
from ..exceptions import ServiceInitializationError
|
||||
|
||||
logger: logging.Logger = get_logger("agent_framework.openai")
|
||||
logger: logging.Logger = logging.getLogger("agent_framework.openai")
|
||||
|
||||
|
||||
RESPONSE_TYPE = Union[
|
||||
@@ -53,9 +52,6 @@ else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
__all__ = ["OpenAISettings"]
|
||||
|
||||
|
||||
def _check_openai_version_for_callable_api_key() -> None:
|
||||
"""Check if OpenAI version supports callable API keys.
|
||||
|
||||
|
||||
@@ -334,8 +334,10 @@ async def test_integration_options(
|
||||
messages = [Message(role="user", text="What is the weather in Seattle?")]
|
||||
elif option_name == "response_format":
|
||||
# Use prompt that works well with structured output
|
||||
messages = [Message(role="user", text="The weather in Seattle is sunny")]
|
||||
messages.append(Message(role="user", text="What is the weather in Seattle?"))
|
||||
messages = [
|
||||
Message(role="user", text="The weather in Seattle is sunny"),
|
||||
Message(role="user", text="What is the weather in Seattle?"),
|
||||
]
|
||||
else:
|
||||
# Generic prompt for simple options
|
||||
messages = [Message(role="user", text="Say 'Hello World' briefly.")]
|
||||
@@ -396,7 +398,12 @@ async def test_integration_web_search() -> None:
|
||||
|
||||
for streaming in [False, True]:
|
||||
content = {
|
||||
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
"messages": [
|
||||
Message(
|
||||
role="user",
|
||||
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
)
|
||||
],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [AzureOpenAIResponsesClient.get_web_search_tool()],
|
||||
@@ -416,7 +423,7 @@ async def test_integration_web_search() -> None:
|
||||
|
||||
# Test that the client will use the web search tool with location
|
||||
content = {
|
||||
"messages": "What is the current weather? Do not ask for my current location.",
|
||||
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [
|
||||
@@ -498,7 +505,7 @@ async def test_integration_client_agent_hosted_mcp_tool() -> None:
|
||||
"""Integration test for MCP tool with Azure Response Agent using Microsoft Learn MCP."""
|
||||
client = AzureOpenAIResponsesClient(credential=AzureCliCredential())
|
||||
response = await client.get_response(
|
||||
"How to create an Azure storage account using az cli?",
|
||||
messages=[Message(role="user", text="How to create an Azure storage account using az cli?")],
|
||||
options={
|
||||
# this needs to be high enough to handle the full MCP tool response.
|
||||
"max_tokens": 5000,
|
||||
@@ -523,7 +530,7 @@ async def test_integration_client_agent_hosted_code_interpreter_tool():
|
||||
client = AzureOpenAIResponsesClient(credential=AzureCliCredential())
|
||||
|
||||
response = await client.get_response(
|
||||
"Calculate the sum of numbers from 1 to 10 using Python code.",
|
||||
messages=[Message(role="user", text="Calculate the sum of numbers from 1 to 10 using Python code.")],
|
||||
options={
|
||||
"tools": [AzureOpenAIResponsesClient.get_code_interpreter_tool()],
|
||||
},
|
||||
|
||||
@@ -43,6 +43,12 @@ async def test_agent_run(agent: SupportsAgentRun) -> None:
|
||||
assert response.messages[0].text == "Response"
|
||||
|
||||
|
||||
async def test_agent_run_with_content(agent: SupportsAgentRun) -> None:
|
||||
response = await agent.run(Content.from_text("test"))
|
||||
assert response.messages[0].role == "assistant"
|
||||
assert response.messages[0].text == "Response"
|
||||
|
||||
|
||||
async def test_agent_run_streaming(agent: SupportsAgentRun) -> None:
|
||||
async def collect_updates(updates: AsyncIterable[AgentResponseUpdate]) -> list[AgentResponseUpdate]:
|
||||
return [u async for u in updates]
|
||||
|
||||
@@ -21,13 +21,13 @@ def test_chat_client_type(client: SupportsChatGetResponse):
|
||||
|
||||
|
||||
async def test_chat_client_get_response(client: SupportsChatGetResponse):
|
||||
response = await client.get_response(Message(role="user", text="Hello"))
|
||||
response = await client.get_response([Message(role="user", text="Hello")])
|
||||
assert response.text == "test response"
|
||||
assert response.messages[0].role == "assistant"
|
||||
|
||||
|
||||
async def test_chat_client_get_response_streaming(client: SupportsChatGetResponse):
|
||||
async for update in client.get_response(Message(role="user", text="Hello"), stream=True):
|
||||
async for update in client.get_response([Message(role="user", text="Hello")], stream=True):
|
||||
assert update.text == "test streaming response " or update.text == "another update"
|
||||
assert update.role == "assistant"
|
||||
|
||||
@@ -38,13 +38,13 @@ def test_base_client(chat_client_base: SupportsChatGetResponse):
|
||||
|
||||
|
||||
async def test_base_client_get_response(chat_client_base: SupportsChatGetResponse):
|
||||
response = await chat_client_base.get_response(Message(role="user", text="Hello"))
|
||||
response = await chat_client_base.get_response([Message(role="user", text="Hello")])
|
||||
assert response.messages[0].role == "assistant"
|
||||
assert response.messages[0].text == "test response - Hello"
|
||||
|
||||
|
||||
async def test_base_client_get_response_streaming(chat_client_base: SupportsChatGetResponse):
|
||||
async for update in chat_client_base.get_response(Message(role="user", text="Hello"), stream=True):
|
||||
async for update in chat_client_base.get_response([Message(role="user", text="Hello")], stream=True):
|
||||
assert update.text == "update - Hello" or update.text == "another update"
|
||||
|
||||
|
||||
@@ -59,7 +59,9 @@ async def test_chat_client_instructions_handling(chat_client_base: SupportsChatG
|
||||
"_inner_get_response",
|
||||
side_effect=fake_inner_get_response,
|
||||
) as mock_inner_get_response:
|
||||
await chat_client_base.get_response("hello", options={"instructions": instructions})
|
||||
await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"instructions": instructions}
|
||||
)
|
||||
mock_inner_get_response.assert_called_once()
|
||||
_, kwargs = mock_inner_get_response.call_args
|
||||
messages = kwargs.get("messages", [])
|
||||
|
||||
@@ -38,7 +38,9 @@ async def test_base_client_with_function_calling(chat_client_base: SupportsChatG
|
||||
),
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
|
||||
)
|
||||
assert exec_counter == 1
|
||||
assert len(response.messages) == 3
|
||||
assert response.messages[0].role == "assistant"
|
||||
@@ -83,7 +85,9 @@ async def test_base_client_with_function_calling_resets(chat_client_base: Suppor
|
||||
),
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
|
||||
)
|
||||
assert exec_counter == 2
|
||||
assert len(response.messages) == 5
|
||||
assert response.messages[0].role == "assistant"
|
||||
@@ -388,11 +392,13 @@ async def test_function_invocation_scenarios(
|
||||
options["conversation_id"] = conversation_id
|
||||
|
||||
if not streaming:
|
||||
response = await chat_client_base.get_response("hello", options=options)
|
||||
response = await chat_client_base.get_response([Message(role="user", text="hello")], options=options)
|
||||
messages = response.messages
|
||||
else:
|
||||
updates = []
|
||||
async for update in chat_client_base.get_response("hello", options=options, stream=True):
|
||||
async for update in chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options=options, stream=True
|
||||
):
|
||||
updates.append(update)
|
||||
messages = updates
|
||||
|
||||
@@ -776,7 +782,9 @@ async def test_max_iterations_limit(chat_client_base: SupportsChatGetResponse):
|
||||
# Set max_iterations to 1 in additional_properties
|
||||
chat_client_base.function_invocation_configuration["max_iterations"] = 1
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
|
||||
)
|
||||
|
||||
# With max_iterations=1, we should:
|
||||
# 1. Execute first function call (exec_counter=1)
|
||||
@@ -803,7 +811,9 @@ async def test_function_invocation_config_enabled_false(chat_client_base: Suppor
|
||||
# Disable function invocation
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
|
||||
)
|
||||
|
||||
# Function should not be executed - when enabled=False, the loop doesn't run
|
||||
assert exec_counter == 0
|
||||
@@ -859,7 +869,9 @@ async def test_function_invocation_config_max_consecutive_errors(chat_client_bas
|
||||
# Set max_consecutive_errors to 2
|
||||
chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 2
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
|
||||
)
|
||||
|
||||
# Should stop after 2 consecutive errors and force a non-tool response
|
||||
error_results = [
|
||||
@@ -904,7 +916,9 @@ async def test_function_invocation_config_terminate_on_unknown_calls_false(chat_
|
||||
# Set terminate_on_unknown_calls to False (default)
|
||||
chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = False
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
|
||||
)
|
||||
|
||||
# Should have a result message indicating the tool wasn't found
|
||||
assert len(response.messages) == 3
|
||||
@@ -940,7 +954,9 @@ async def test_function_invocation_config_terminate_on_unknown_calls_true(chat_c
|
||||
|
||||
# Should raise an exception when encountering an unknown function
|
||||
with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'):
|
||||
await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]})
|
||||
await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
|
||||
)
|
||||
|
||||
assert exec_counter == 0
|
||||
|
||||
@@ -978,7 +994,9 @@ async def test_function_invocation_config_additional_tools(chat_client_base: Sup
|
||||
chat_client_base.function_invocation_configuration["additional_tools"] = [hidden_func]
|
||||
|
||||
# Only pass visible_func in the tools parameter
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [visible_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [visible_func]}
|
||||
)
|
||||
|
||||
# Additional tools are treated as declaration_only, so not executed
|
||||
# The function call should be in the messages but not executed
|
||||
@@ -1016,7 +1034,9 @@ async def test_function_invocation_config_include_detailed_errors_false(chat_cli
|
||||
# Set include_detailed_errors to False (default)
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
|
||||
)
|
||||
|
||||
# Should have a generic error message
|
||||
error_result = next(
|
||||
@@ -1050,7 +1070,9 @@ async def test_function_invocation_config_include_detailed_errors_true(chat_clie
|
||||
# Set include_detailed_errors to True
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
|
||||
)
|
||||
|
||||
# Should have detailed error message
|
||||
error_result = next(
|
||||
@@ -1120,7 +1142,9 @@ async def test_argument_validation_error_with_detailed_errors(chat_client_base:
|
||||
# Set include_detailed_errors to True
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
|
||||
)
|
||||
|
||||
# Should have detailed validation error
|
||||
error_result = next(
|
||||
@@ -1154,7 +1178,9 @@ async def test_argument_validation_error_without_detailed_errors(chat_client_bas
|
||||
# Set include_detailed_errors to False (default)
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
|
||||
)
|
||||
|
||||
# Should have generic validation error
|
||||
error_result = next(
|
||||
@@ -1219,7 +1245,9 @@ async def test_unapproved_tool_execution_raises_exception(chat_client_base: Supp
|
||||
]
|
||||
|
||||
# Get approval request
|
||||
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
|
||||
response1 = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
|
||||
)
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
|
||||
|
||||
@@ -1277,7 +1305,9 @@ async def test_approved_function_call_with_error_without_detailed_errors(chat_cl
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
|
||||
|
||||
# Get approval request
|
||||
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
|
||||
response1 = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
|
||||
)
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
|
||||
|
||||
@@ -1340,7 +1370,9 @@ async def test_approved_function_call_with_error_with_detailed_errors(chat_clien
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
|
||||
|
||||
# Get approval request
|
||||
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
|
||||
response1 = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
|
||||
)
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
|
||||
|
||||
@@ -1403,7 +1435,9 @@ async def test_approved_function_call_with_validation_error(chat_client_base: Su
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
|
||||
|
||||
# Get approval request
|
||||
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
|
||||
response1 = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
|
||||
)
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
|
||||
|
||||
@@ -1459,7 +1493,9 @@ async def test_approved_function_call_successful_execution(chat_client_base: Sup
|
||||
]
|
||||
|
||||
# Get approval request
|
||||
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [success_func]})
|
||||
response1 = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [success_func]}
|
||||
)
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
|
||||
|
||||
@@ -1575,7 +1611,9 @@ async def test_multiple_function_calls_parallel_execution(chat_client_base: Supp
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [func1, func2]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [func1, func2]}
|
||||
)
|
||||
|
||||
# Both functions should have been executed
|
||||
assert "func1_start" in exec_order
|
||||
@@ -1612,7 +1650,9 @@ async def test_callable_function_converted_to_tool(chat_client_base: SupportsCha
|
||||
]
|
||||
|
||||
# Pass plain function (will be auto-converted)
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [plain_function]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [plain_function]}
|
||||
)
|
||||
|
||||
# Function should be executed
|
||||
assert exec_counter == 1
|
||||
@@ -1644,7 +1684,9 @@ async def test_conversation_id_handling(chat_client_base: SupportsChatGetRespons
|
||||
),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
|
||||
)
|
||||
|
||||
# Should have executed the function
|
||||
results = [content for msg in response.messages for content in msg.contents if content.type == "function_result"]
|
||||
@@ -1671,7 +1713,9 @@ async def test_function_result_appended_to_existing_assistant_message(chat_clien
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
|
||||
)
|
||||
|
||||
# Should have messages with both function call and function result
|
||||
assert len(response.messages) >= 2
|
||||
@@ -1716,7 +1760,9 @@ async def test_error_recovery_resets_counter(chat_client_base: SupportsChatGetRe
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [sometimes_fails]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [sometimes_fails]}
|
||||
)
|
||||
|
||||
# Should have both an error and a success
|
||||
error_results = [
|
||||
@@ -1990,7 +2036,9 @@ async def test_streaming_function_invocation_config_terminate_on_unknown_calls_t
|
||||
|
||||
# Should raise an exception when encountering an unknown function
|
||||
with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'):
|
||||
async for _ in chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]}):
|
||||
async for _ in chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
|
||||
):
|
||||
pass
|
||||
|
||||
assert exec_counter == 0
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import get_logger
|
||||
from agent_framework.exceptions import AgentFrameworkException
|
||||
|
||||
|
||||
def test_get_logger():
|
||||
"""Test that the logger is created with the correct name."""
|
||||
logger = get_logger()
|
||||
assert logger.name == "agent_framework"
|
||||
|
||||
|
||||
def test_get_logger_custom_name():
|
||||
"""Test that the logger can be created with a custom name."""
|
||||
custom_name = "agent_framework.custom"
|
||||
logger = get_logger(custom_name)
|
||||
assert logger.name == custom_name
|
||||
|
||||
|
||||
def test_get_logger_invalid_name():
|
||||
"""Test that an exception is raised for an invalid logger name."""
|
||||
with pytest.raises(AgentFrameworkException):
|
||||
get_logger("invalid_name")
|
||||
|
||||
|
||||
def test_log(caplog):
|
||||
"""Test that the logger can log messages and adheres to the expected format."""
|
||||
logger = get_logger()
|
||||
with caplog.at_level("DEBUG"):
|
||||
logger.debug("This is a debug message")
|
||||
assert len(caplog.records) == 1
|
||||
record = caplog.records[0]
|
||||
assert record.levelname == "DEBUG"
|
||||
assert record.message == "This is a debug message"
|
||||
assert record.name == "agent_framework"
|
||||
assert record.pathname.endswith("test_logging.py")
|
||||
@@ -1083,7 +1083,12 @@ async def test_integration_web_search() -> None:
|
||||
# Use static method for web search tool
|
||||
web_search_tool = OpenAIChatClient.get_web_search_tool()
|
||||
content = {
|
||||
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
"messages": [
|
||||
Message(
|
||||
role="user",
|
||||
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
)
|
||||
],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [web_search_tool],
|
||||
@@ -1110,7 +1115,7 @@ async def test_integration_web_search() -> None:
|
||||
}
|
||||
)
|
||||
content = {
|
||||
"messages": "What is the current weather? Do not ask for my current location.",
|
||||
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [web_search_tool_with_location],
|
||||
|
||||
@@ -2416,7 +2416,12 @@ async def test_integration_web_search() -> None:
|
||||
# Use static method for web search tool
|
||||
web_search_tool = OpenAIResponsesClient.get_web_search_tool()
|
||||
content = {
|
||||
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
"messages": [
|
||||
Message(
|
||||
role="user",
|
||||
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
)
|
||||
],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [web_search_tool],
|
||||
@@ -2438,7 +2443,7 @@ async def test_integration_web_search() -> None:
|
||||
user_location={"country": "US", "city": "Seattle"},
|
||||
)
|
||||
content = {
|
||||
"messages": "What is the current weather? Do not ask for my current location.",
|
||||
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [web_search_tool_with_location],
|
||||
|
||||
@@ -39,7 +39,7 @@ class _ToolCallingAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
|
||||
@@ -35,7 +35,7 @@ class _SimpleAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -105,7 +105,7 @@ class _CaptureAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
|
||||
@@ -835,7 +835,7 @@ class _StreamingTestAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
|
||||
@@ -52,7 +52,7 @@ class _KwargsCapturingAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -85,7 +85,7 @@ class _OptionsAwareAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
|
||||
Reference in New Issue
Block a user