Python: [BREAKING] Fix #3613 chat/agent message typing alignment (#3920)

* Fix #3613 message typing across chat and agents

* Address #3613 review feedback and sample input style

* refactor: use shared AgentRunMessages aliases (#3613)

* refactor: rename agent run input aliases for #3613

* samples: inline image content in run calls

* core: export AgentRunInputs from package init

* core: use explicit init re-exports without __all__

* updated logging and inits

* Fix core mypy export and samples XML note

* Remove AgentRunInputsOrNone and dedupe loggers

* Remove prepare_messages helper

* fix integration tests
This commit is contained in:
Eduard van Valkenburg
2026-02-16 16:27:25 +01:00
committed by GitHub
Unverified
parent 503eb10fdd
commit dc9439a75a
87 changed files with 422 additions and 578 deletions
+6 -1
View File
@@ -81,7 +81,12 @@ from agent_framework.azure import AzureOpenAIChatClient
## Public API and Exports
Define `__all__` in each module. Avoid `from module import *` in `__init__.py` files:
In `__init__.py` files that define package-level public APIs, use direct re-export imports plus an explicit
`__all__`. Avoid identity aliases like `from ._agents import ChatAgent as ChatAgent`, and avoid
`from module import *`.
Do not define `__all__` in internal non-`__init__.py` modules. Exception: modules intentionally exposed as a
public import surface (for example, `agent_framework.observability`) should define `__all__`.
```python
__all__ = ["ChatAgent", "Message", "ChatResponse"]
+19 -28
View File
@@ -130,27 +130,6 @@ user_msg = UserMessage(content="Hello, world!")
asst_msg = AssistantMessage(content="Hello, world!")
```
### Logging
Use the centralized logging system:
```python
from agent_framework import get_logger
# For main package
logger = get_logger()
# For subpackages
logger = get_logger('agent_framework.azure')
```
**Do not use** direct logging module imports:
```python
# ❌ Avoid this
import logging
logger = logging.getLogger(__name__)
```
### Import Structure
The package follows a flat import structure:
@@ -189,8 +168,6 @@ python/
│ │ ├── _clients.py # Chat client protocols and base classes
│ │ ├── _tools.py # Tool definitions
│ │ ├── _types.py # Type definitions
│ │ ├── _logging.py # Logging utilities
│ │ │
│ │ │ # Provider folders - lazy load from connector packages
│ │ ├── openai/ # OpenAI clients (built into core)
│ │ ├── azure/ # Lazy loads from azure-ai, azure-ai-search, azurefunctions
@@ -405,12 +382,15 @@ If in doubt, use the link above to read much more considerations of what to do a
**All wildcard imports (`from ... import *`) are prohibited** in production code, including both `.py` and `.pyi` files. Always use explicit import lists to maintain clarity and avoid namespace pollution.
Define `__all__` in each module to explicitly declare the public API, then import specific symbols by name:
Do not use ``__all__`` in internal modules. Define it in the ``__init__`` file of the level you want to expose.
If a non-``__init__`` module is intentionally part of the public API surface (for example, ``observability.py``),
it should define ``__all__`` as well.
Also avoid identity alias imports in ``__init__`` files. Use ``from ._module import Symbol`` instead of
``from ._module import Symbol as Symbol``.
```python
# ✅ Preferred - explicit __all__ and named imports
__all__ = ["Agent", "Message", "ChatResponse"]
from ._agents import Agent
from ._types import Message, ChatResponse
@@ -422,9 +402,20 @@ from ._types import (
ResponseStream,
)
__all__ = [
"Agent",
"AgentResponse",
"ChatResponse",
"Message",
"ResponseStream",
]
# ❌ Prohibited pattern: wildcard/star imports (do not use)
# from ._agents import <all public symbols>
# from ._types import <all public symbols>
# from ._agents import *
# from ._types import *
# ❌ Prohibited pattern: identity alias imports (do not use)
# from ._agents import Agent as Agent
```
**Rationale:**
@@ -40,6 +40,7 @@ from agent_framework import (
normalize_messages,
prepend_agent_framework_to_user_agent,
)
from agent_framework._types import AgentRunInputs
from agent_framework.observability import AgentTelemetryLayer
__all__ = ["A2AAgent", "A2AContinuationToken"]
@@ -208,7 +209,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -220,7 +221,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -231,7 +232,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -49,7 +49,7 @@ if TYPE_CHECKING:
from ._types import AGUIChatOptions
logger: logging.Logger = logging.getLogger(__name__)
logger: logging.Logger = logging.getLogger("agent_framework.ag_ui")
def _unwrap_server_function_call_contents(contents: MutableSequence[Content | dict[str, Any]]) -> None:
@@ -277,9 +277,6 @@ class AGUIChatClient(
registered: set[str] = getattr(self, "_registered_server_tools", set())
registered.add(tool_name)
self._registered_server_tools = registered # type: ignore[attr-defined]
from agent_framework._logging import get_logger
logger = get_logger()
logger.debug(f"[AGUIChatClient] Registered server placeholder: {tool_name}")
def _extract_state_from_messages(self, messages: Sequence[Message]) -> tuple[list[Message], dict[str, Any] | None]:
@@ -310,9 +307,6 @@ class AGUIChatClient(
messages_without_state = list(messages[:-1]) if len(messages) > 1 else []
return messages_without_state, state
except (json.JSONDecodeError, ValueError, KeyError) as e:
from agent_framework._logging import get_logger
logger = get_logger()
logger.warning(f"Failed to extract state from message: {e}")
return list(messages), None
@@ -408,9 +408,6 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Mes
approval_payload_text = result_content if isinstance(result_content, str) else json.dumps(parsed)
# Log the full approval payload to debug modified arguments
import logging
logger = logging.getLogger(__name__)
logger.info(f"Approval payload received: {parsed}")
approval_call_id = tool_call_id
@@ -11,7 +11,7 @@ import asyncio
import os
from typing import cast
from agent_framework import ChatResponse, ChatResponseUpdate, ResponseStream
from agent_framework import ChatResponse, ChatResponseUpdate, Message, ResponseStream
from agent_framework.ag_ui import AGUIChatClient
@@ -44,7 +44,7 @@ async def main():
metadata = {"thread_id": thread_id} if thread_id else None
stream = client.get_response(
message,
[Message(role="user", text=message)],
stream=True,
options={"metadata": metadata} if metadata else None,
)
@@ -15,7 +15,7 @@ import asyncio
import os
from typing import cast
from agent_framework import ChatResponse, ChatResponseUpdate, ResponseStream, tool
from agent_framework import ChatResponse, ChatResponseUpdate, Message, ResponseStream, tool
from agent_framework.ag_ui import AGUIChatClient
@@ -73,7 +73,7 @@ async def streaming_example(client: AGUIChatClient, thread_id: str | None = None
print("Assistant: ", end="", flush=True)
stream = client.get_response(
"Tell me a short joke",
[Message(role="user", text="Tell me a short joke")],
stream=True,
options={"metadata": metadata} if metadata else None,
)
@@ -100,7 +100,7 @@ async def non_streaming_example(client: AGUIChatClient, thread_id: str | None =
print("\nUser: What is 2 + 2?\n")
response = await client.get_response("What is 2 + 2?", metadata=metadata)
response = await client.get_response([Message(role="user", text="What is 2 + 2?")], metadata=metadata)
print(f"Assistant: {response.text}")
@@ -139,7 +139,7 @@ async def tool_example(client: AGUIChatClient, thread_id: str | None = None):
print("(Server must be configured with matching tools to execute them)\n")
response = await client.get_response(
"What's the weather in Seattle?", tools=[get_weather, calculate], metadata=metadata
[Message(role="user", text="What's the weather in Seattle?")], tools=[get_weather, calculate], metadata=metadata
)
print(f"Assistant: {response.text}")
@@ -174,14 +174,16 @@ async def conversation_example(client: AGUIChatClient):
# First turn
print("User: My name is Alice\n")
response1 = await client.get_response("My name is Alice")
response1 = await client.get_response([Message(role="user", text="My name is Alice")])
print(f"Assistant: {response1.text}")
thread_id = response1.additional_properties.get("thread_id")
print(f"\n[Thread: {thread_id}]")
# Second turn - using same thread
print("\nUser: What's my name?\n")
response2 = await client.get_response("What's my name?", options={"metadata": {"thread_id": thread_id}})
response2 = await client.get_response(
[Message(role="user", text="What's my name?")], options={"metadata": {"thread_id": thread_id}}
)
print(f"Assistant: {response2.text}")
# Check if context was maintained
@@ -191,7 +193,9 @@ async def conversation_example(client: AGUIChatClient):
# Third turn
print("\nUser: Can you also tell me what 10 * 5 is?\n")
response3 = await client.get_response(
"Can you also tell me what 10 * 5 is?", options={"metadata": {"thread_id": thread_id}}, tools=[calculate]
[Message(role="user", text="Can you also tell me what 10 * 5 is?")],
options={"metadata": {"thread_id": thread_id}},
tools=[calculate],
)
print(f"Assistant: {response3.text}")
@@ -55,7 +55,7 @@ class StreamingChatClientStub(
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[Any],
@@ -65,7 +65,7 @@ class StreamingChatClientStub(
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = ...,
@@ -75,7 +75,7 @@ class StreamingChatClientStub(
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = ...,
@@ -84,7 +84,7 @@ class StreamingChatClientStub(
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -175,7 +175,7 @@ class StubAgent(SupportsAgentRun):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -185,7 +185,7 @@ class StubAgent(SupportsAgentRun):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -194,7 +194,7 @@ class StubAgent(SupportsAgentRun):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -2,6 +2,7 @@
from __future__ import annotations
import logging
import sys
from collections.abc import AsyncIterable, Awaitable, Mapping, MutableMapping, Sequence
from typing import Any, ClassVar, Final, Generic, Literal, TypedDict
@@ -24,7 +25,6 @@ from agent_framework import (
ResponseStream,
TextSpanRegion,
UsageDetails,
get_logger,
)
from agent_framework._settings import SecretString, load_settings
from agent_framework._types import _get_data_bytes_as_str # type: ignore
@@ -68,7 +68,7 @@ __all__ = [
"ThinkingConfig",
]
logger = get_logger("agent_framework.anthropic")
logger = logging.getLogger("agent_framework.anthropic")
ANTHROPIC_DEFAULT_MAX_TOKENS: Final[int] = 1024
BETA_FLAGS: Final[list[str]] = ["mcp-client-2025-04-04", "code-execution-2025-08-25"]
@@ -8,12 +8,12 @@ This module provides ``AzureAISearchContextProvider``, built on the new
from __future__ import annotations
import logging
import sys
from collections.abc import Awaitable, Callable
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
from agent_framework._logging import get_logger
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
from agent_framework._settings import SecretString, load_settings
from agent_framework.exceptions import ServiceInitializationError
@@ -103,7 +103,7 @@ try:
except ImportError:
_agentic_retrieval_available = False
logger = get_logger(__name__)
logger = logging.getLogger("agent_framework.azure_ai_search")
_DEFAULT_AGENTIC_MESSAGE_HISTORY_COUNT = 10
@@ -4,6 +4,7 @@ from __future__ import annotations
import ast
import json
import logging
import os
import re
import sys
@@ -31,7 +32,6 @@ from agent_framework import (
Role,
TextSpanRegion,
UsageDetails,
get_logger,
)
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError, ServiceResponseException
@@ -102,7 +102,7 @@ else:
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
logger = get_logger("agent_framework.azure")
logger = logging.getLogger("agent_framework.azure")
__all__ = ["AzureAIAgentClient", "AzureAIAgentOptions"]
@@ -3,6 +3,7 @@
from __future__ import annotations
import json
import logging
import sys
from collections.abc import Callable, Mapping, MutableMapping, Sequence
from contextlib import suppress
@@ -19,7 +20,6 @@ from agent_framework import (
FunctionTool,
Message,
MiddlewareTypes,
get_logger,
)
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError
@@ -59,7 +59,7 @@ else:
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
logger = get_logger("agent_framework.azure")
logger = logging.getLogger("agent_framework.azure")
class AzureAIProjectAgentOptions(OpenAIResponsesOptions, total=False):
@@ -2,6 +2,7 @@
from __future__ import annotations
import logging
import sys
from collections.abc import Callable, MutableMapping, Sequence
from typing import Any, Generic
@@ -12,7 +13,6 @@ from agent_framework import (
BaseContextProvider,
FunctionTool,
MiddlewareTypes,
get_logger,
normalize_tools,
)
from agent_framework._mcp import MCPTool
@@ -43,7 +43,7 @@ else:
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
logger = get_logger("agent_framework.azure")
logger = logging.getLogger("agent_framework.azure")
# Type variable for options - allows typed Agent[OptionsT] returns
@@ -2,13 +2,13 @@
from __future__ import annotations
import logging
import sys
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, cast
from agent_framework import (
FunctionTool,
get_logger,
)
from agent_framework.exceptions import ServiceInvalidRequestError
from azure.ai.agents.models import (
@@ -37,7 +37,7 @@ if sys.version_info >= (3, 11):
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
logger = get_logger("agent_framework.azure")
logger = logging.getLogger("agent_framework.azure")
class AzureAISettings(TypedDict, total=False):
@@ -134,7 +134,7 @@ def create_test_azure_ai_client(
client._created_agent_tool_names = set() # type: ignore
client._created_agent_structured_output_signature = None # type: ignore
client.additional_properties = {}
client.middleware = None
client.chat_middleware = []
# Mock the OpenAI client attribute
mock_openai_client = MagicMock()
@@ -1546,7 +1546,12 @@ async def test_integration_web_search() -> None:
async with temporary_chat_client(agent_name="af-int-test-web-search") as client:
for streaming in [False, True]:
content = {
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
"messages": [
Message(
role="user",
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
)
],
"options": {
"tool_choice": "auto",
"tools": [client.get_web_search_tool()],
@@ -1565,7 +1570,9 @@ async def test_integration_web_search() -> None:
# Test that the client will use the web search tool with location
content = {
"messages": "What is the current weather? Do not ask for my current location.",
"messages": [
Message(role="user", text="What is the current weather? Do not ask for my current location.")
],
"options": {
"tool_choice": "auto",
"tools": [client.get_web_search_tool(user_location={"country": "US", "city": "Seattle"})],
@@ -1584,7 +1591,7 @@ async def test_integration_agent_hosted_mcp_tool() -> None:
"""Integration test for MCP tool with Azure Response Agent using Microsoft Learn MCP."""
async with temporary_chat_client(agent_name="af-int-test-mcp") as client:
response = await client.get_response(
"How to create an Azure storage account using az cli?",
messages=[Message(role="user", text="How to create an Azure storage account using az cli?")],
options={
# this needs to be high enough to handle the full MCP tool response.
"max_tokens": 5000,
@@ -1608,7 +1615,7 @@ async def test_integration_agent_hosted_code_interpreter_tool():
"""Test Azure Responses Client agent with code interpreter tool through AzureAIClient."""
async with temporary_chat_client(agent_name="af-int-test-code-interpreter") as client:
response = await client.get_response(
"Calculate the sum of numbers from 1 to 10 using Python code.",
messages=[Message(role="user", text="Calculate the sum of numbers from 1 to 10 using Python code.")],
options={
"tools": [client.get_code_interpreter_tool()],
},
@@ -9,6 +9,7 @@ with Azure Durable Entities, enabling stateful and durable AI agent execution.
from __future__ import annotations
import json
import logging
import re
import uuid
from collections.abc import Callable, Mapping
@@ -18,7 +19,7 @@ from typing import TYPE_CHECKING, Any, TypeVar, cast
import azure.durable_functions as df
import azure.functions as func
from agent_framework import SupportsAgentRun, get_logger
from agent_framework import SupportsAgentRun
from agent_framework_durabletask import (
DEFAULT_MAX_POLL_RETRIES,
DEFAULT_POLL_INTERVAL_SECONDS,
@@ -42,7 +43,7 @@ from ._entities import create_agent_entity
from ._errors import IncomingRequestError
from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor
logger = get_logger("agent_framework.azurefunctions")
logger = logging.getLogger("agent_framework.azurefunctions")
EntityHandler = Callable[[df.DurableEntityContext], None]
HandlerT = TypeVar("HandlerT", bound=Callable[..., Any])
@@ -10,18 +10,19 @@ allows for long-running agent conversations.
from __future__ import annotations
import asyncio
import logging
from collections.abc import Callable
from typing import Any, cast
import azure.durable_functions as df
from agent_framework import SupportsAgentRun, get_logger
from agent_framework import SupportsAgentRun
from agent_framework_durabletask import (
AgentEntity,
AgentEntityStateProviderMixin,
AgentResponseCallbackProtocol,
)
logger = get_logger("agent_framework.azurefunctions.entities")
logger = logging.getLogger("agent_framework.azurefunctions")
class AzureFunctionEntityStateProvider(AgentEntityStateProviderMixin):
@@ -5,11 +5,12 @@
This module provides support for using agents inside Durable Function orchestrations.
"""
import logging
from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeAlias
import azure.durable_functions as df
from agent_framework import AgentSession, get_logger
from agent_framework import AgentSession
from agent_framework_durabletask import (
DurableAgentExecutor,
RunRequest,
@@ -21,7 +22,7 @@ from azure.durable_functions.models.actions.NoOpAction import NoOpAction
from azure.durable_functions.models.Task import CompoundTask, TaskState
from pydantic import BaseModel
logger = get_logger("agent_framework.azurefunctions.orchestration")
logger = logging.getLogger("agent_framework.azurefunctions")
CompoundActionConstructor: TypeAlias = Callable[[list[Any]], Any] | None
@@ -4,6 +4,7 @@ from __future__ import annotations
import asyncio
import json
import logging
import sys
from collections import deque
from collections.abc import AsyncIterable, Awaitable, Mapping, MutableMapping, Sequence
@@ -26,7 +27,6 @@ from agent_framework import (
Message,
ResponseStream,
UsageDetails,
get_logger,
validate_tool_mode,
)
from agent_framework._settings import SecretString, load_settings
@@ -50,7 +50,7 @@ if sys.version_info >= (3, 11):
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
logger = get_logger("agent_framework.bedrock")
logger = logging.getLogger("agent_framework.bedrock")
__all__ = [
@@ -3,6 +3,7 @@
from __future__ import annotations
import contextlib
import logging
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
from pathlib import Path
@@ -19,11 +20,10 @@ from agent_framework import (
FunctionTool,
Message,
ResponseStream,
get_logger,
normalize_messages,
)
from agent_framework._settings import load_settings
from agent_framework._types import normalize_tools
from agent_framework._types import AgentRunInputs, normalize_tools
from agent_framework.exceptions import ServiceException
from claude_agent_sdk import (
AssistantMessage,
@@ -61,7 +61,7 @@ if TYPE_CHECKING:
__all__ = ["ClaudeAgent", "ClaudeAgentOptions"]
logger = get_logger("agent_framework.claude")
logger = logging.getLogger("agent_framework.claude")
# Name of the in-process MCP server that hosts Agent Framework tools.
# FunctionTool instances are converted to SDK MCP tools and served
@@ -557,7 +557,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -568,7 +568,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
@overload
async def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -578,7 +578,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -612,7 +612,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
async def _get_stream(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
session: AgentSession | None = None,
options: OptionsT | MutableMapping[str, Any] | None = None,
@@ -18,6 +18,7 @@ from agent_framework import (
normalize_messages,
)
from agent_framework._settings import load_settings
from agent_framework._types import AgentRunInputs
from agent_framework.exceptions import ServiceException, ServiceInitializationError
from microsoft_agents.copilotstudio.client import AgentType, ConnectionSettings, CopilotClient, PowerPlatformCloud
@@ -187,7 +188,7 @@ class CopilotStudioAgent(BaseAgent):
@overload
def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = False,
session: AgentSession | None = None,
@@ -197,7 +198,7 @@ class CopilotStudioAgent(BaseAgent):
@overload
def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -206,7 +207,7 @@ class CopilotStudioAgent(BaseAgent):
def run(
self,
messages: str | Message | list[str] | list[Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -236,7 +237,7 @@ class CopilotStudioAgent(BaseAgent):
async def _run_impl(
self,
messages: str | Message | list[str] | list[Message] | None = None,
messages: AgentRunInputs | None = None,
*,
session: AgentSession | None = None,
**kwargs: Any,
@@ -261,7 +262,7 @@ class CopilotStudioAgent(BaseAgent):
def _run_stream_impl(
self,
messages: str | Message | list[str] | list[Message] | None = None,
messages: AgentRunInputs | None = None,
*,
session: AgentSession | None = None,
**kwargs: Any,
@@ -19,7 +19,6 @@ from ._clients import (
SupportsMCPTool,
SupportsWebSearchTool,
)
from ._logging import get_logger, setup_logging
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPWebsocketTool
from ._middleware import (
AgentContext,
@@ -34,7 +33,6 @@ from ._middleware import (
FunctionInvocationContext,
FunctionMiddleware,
FunctionMiddlewareTypes,
MiddlewareException,
MiddlewareTermination,
MiddlewareType,
MiddlewareTypes,
@@ -67,6 +65,7 @@ from ._tools import (
from ._types import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
Annotation,
ChatOptions,
ChatResponse,
@@ -152,6 +151,7 @@ from ._workflows import (
response_handler,
validate_workflow_graph,
)
from .exceptions import MiddlewareException
__all__ = [
"AGENT_FRAMEWORK_USER_AGENT",
@@ -169,6 +169,7 @@ __all__ = [
"AgentMiddlewareTypes",
"AgentResponse",
"AgentResponseUpdate",
"AgentRunInputs",
"AgentSession",
"Annotation",
"BaseAgent",
@@ -264,6 +265,7 @@ __all__ = [
"WorkflowRunnerException",
"WorkflowValidationError",
"WorkflowViz",
"__version__",
"add_usage_details",
"agent_middleware",
"chat_middleware",
@@ -271,7 +273,6 @@ __all__ = [
"detect_media_type_from_base64",
"executor",
"function_middleware",
"get_logger",
"handler",
"map_chat_to_agent_update",
"merge_chat_options",
@@ -283,7 +284,6 @@ __all__ = [
"register_state_type",
"resolve_agent_id",
"response_handler",
"setup_logging",
"tool",
"validate_chat_options",
"validate_tool_mode",
+11 -13
View File
@@ -3,6 +3,7 @@
from __future__ import annotations
import inspect
import logging
import re
import sys
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
@@ -29,7 +30,6 @@ from mcp.shared.exceptions import McpError
from pydantic import BaseModel, Field, create_model
from ._clients import BaseChatClient, SupportsChatGetResponse
from ._logging import get_logger
from ._mcp import LOG_LEVEL_MAPPING, MCPTool
from ._middleware import AgentMiddlewareLayer, MiddlewareTypes
from ._serialization import SerializationMixin
@@ -41,6 +41,7 @@ from ._tools import (
from ._types import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
ChatResponse,
ChatResponseUpdate,
Message,
@@ -67,7 +68,7 @@ else:
if TYPE_CHECKING:
from ._types import ChatOptions
logger = get_logger("agent_framework")
logger = logging.getLogger("agent_framework")
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
OptionsCoT = TypeVar(
@@ -159,9 +160,6 @@ class _RunContext(TypedDict):
finalize_kwargs: dict[str, Any]
__all__ = ["Agent", "BaseAgent", "RawAgent", "SupportsAgentRun"]
# region Agent Protocol
@@ -230,7 +228,7 @@ class SupportsAgentRun(Protocol):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -242,7 +240,7 @@ class SupportsAgentRun(Protocol):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -253,7 +251,7 @@ class SupportsAgentRun(Protocol):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -763,7 +761,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -780,7 +778,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -797,7 +795,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -813,7 +811,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -1000,7 +998,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
async def _prepare_run_context(
self,
*,
messages: str | Message | Sequence[str | Message] | None,
messages: AgentRunInputs | None,
session: AgentSession | None,
tools: FunctionTool
| Callable[..., Any]
@@ -2,6 +2,7 @@
from __future__ import annotations
import logging
import sys
from abc import ABC, abstractmethod
from collections.abc import (
@@ -27,7 +28,6 @@ from typing import (
from pydantic import BaseModel
from ._logging import get_logger
from ._serialization import SerializationMixin
from ._tools import (
FunctionInvocationConfiguration,
@@ -38,7 +38,6 @@ from ._types import (
ChatResponseUpdate,
Message,
ResponseStream,
prepare_messages,
validate_chat_options,
)
@@ -61,17 +60,7 @@ InputT = TypeVar("InputT", contravariant=True)
EmbeddingT = TypeVar("EmbeddingT")
BaseChatClientT = TypeVar("BaseChatClientT", bound="BaseChatClient")
logger = get_logger()
__all__ = [
"BaseChatClient",
"SupportsChatGetResponse",
"SupportsCodeInterpreterTool",
"SupportsFileSearchTool",
"SupportsImageGenerationTool",
"SupportsMCPTool",
"SupportsWebSearchTool",
]
logger = logging.getLogger("agent_framework")
# region SupportsChatGetResponse Protocol
@@ -139,7 +128,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
@@ -149,7 +138,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OptionsContraT | ChatOptions[None] | None = None,
@@ -159,7 +148,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[True],
options: OptionsContraT | ChatOptions[Any] | None = None,
@@ -168,7 +157,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: bool = False,
options: OptionsContraT | ChatOptions[Any] | None = None,
@@ -254,9 +243,9 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
client = CustomChatClient()
# Use the client to get responses
response = await client.get_response("Hello, how are you?")
response = await client.get_response([Message(role="user", text="Hello, how are you?")])
# Or stream responses
async for update in client.get_response("Hello!", stream=True):
async for update in client.get_response([Message(role="user", text="Hello!")], stream=True):
print(update)
"""
@@ -376,7 +365,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
@@ -386,7 +375,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
@@ -396,7 +385,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -405,7 +394,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -422,9 +411,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
Returns:
When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse.
"""
prepared_messages = prepare_messages(messages)
return self._inner_get_response(
messages=prepared_messages,
messages=messages,
stream=stream,
options=options or {}, # type: ignore[arg-type]
**kwargs,
@@ -1,29 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
from .exceptions import AgentFrameworkException
__all__ = ["get_logger", "setup_logging"]
def setup_logging() -> None:
"""Setup the logging configuration for the agent framework."""
logging.basicConfig(
format="[%(asctime)s - %(pathname)s:%(lineno)d - %(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
def get_logger(name: str = "agent_framework") -> logging.Logger:
"""Get a logger with the specified name, defaulting to 'agent_framework'.
Args:
name (str): The name of the logger. Defaults to 'agent_framework'.
Returns:
logging.Logger: The configured logger instance.
"""
if not name.startswith("agent_framework"):
raise AgentFrameworkException("Logger name must start with 'agent_framework'.")
return logging.getLogger(name)
@@ -72,12 +72,6 @@ LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = {
"emergency": logging.CRITICAL,
}
__all__ = [
"MCPStdioTool",
"MCPStreamableHTTPTool",
"MCPWebsocketTool",
]
def _parse_prompt_result_from_mcp(
mcp_type: types.GetPromptResult,
@@ -14,11 +14,12 @@ from ._clients import SupportsChatGetResponse
from ._types import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
ChatResponse,
ChatResponseUpdate,
Message,
ResponseStream,
prepare_messages,
normalize_messages,
)
from .exceptions import MiddlewareException
@@ -42,27 +43,6 @@ if TYPE_CHECKING:
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
__all__ = [
"AgentContext",
"AgentMiddleware",
"AgentMiddlewareLayer",
"AgentMiddlewareTypes",
"ChatAndFunctionMiddlewareTypes",
"ChatContext",
"ChatMiddleware",
"ChatMiddlewareLayer",
"ChatMiddlewareTypes",
"FunctionInvocationContext",
"FunctionMiddleware",
"FunctionMiddlewareTypes",
"MiddlewareException",
"MiddlewareTermination",
"MiddlewareType",
"MiddlewareTypes",
"agent_middleware",
"chat_middleware",
"function_middleware",
]
AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
ContextT = TypeVar("ContextT")
@@ -978,7 +958,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
@@ -988,7 +968,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
@@ -998,7 +978,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1007,7 +987,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1034,7 +1014,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
context = ChatContext(
client=self, # type: ignore[arg-type]
messages=prepare_messages(messages),
messages=list(messages),
options=options,
stream=stream,
kwargs=kwargs,
@@ -1095,7 +1075,7 @@ class AgentMiddlewareLayer:
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -1107,7 +1087,7 @@ class AgentMiddlewareLayer:
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -1119,7 +1099,7 @@ class AgentMiddlewareLayer:
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -1130,7 +1110,7 @@ class AgentMiddlewareLayer:
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -1161,7 +1141,7 @@ class AgentMiddlewareLayer:
context = AgentContext(
agent=self, # type: ignore[arg-type]
messages=prepare_messages(messages), # type: ignore[arg-type]
messages=normalize_messages(messages),
session=session,
options=options,
stream=stream,
@@ -3,13 +3,12 @@
from __future__ import annotations
import json
import logging
import re
from collections.abc import Mapping, MutableMapping
from typing import Any, ClassVar, Protocol, TypeVar, runtime_checkable
from ._logging import get_logger
logger = get_logger()
logger = logging.getLogger("agent_framework")
ClassT = TypeVar("ClassT", bound="SerializationMixin")
ProtocolT = TypeVar("ProtocolT", bound="SerializationProtocol")
@@ -24,16 +24,6 @@ if TYPE_CHECKING:
from ._agents import SupportsAgentRun
__all__ = [
"AgentSession",
"BaseContextProvider",
"BaseHistoryProvider",
"InMemoryHistoryProvider",
"SessionContext",
"register_state_type",
]
# Registry of known types for state deserialization
_STATE_TYPE_REGISTRY: dict[str, type] = {}
@@ -45,7 +45,6 @@ if sys.version_info >= (3, 13):
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
__all__ = ["SecretString", "load_settings"]
SettingsT = TypeVar("SettingsT", default=dict[str, Any])
@@ -2,21 +2,14 @@
from __future__ import annotations
import logging
import os
from typing import Any, Final
from . import __version__ as version_info
from ._logging import get_logger
logger = get_logger()
logger = logging.getLogger("agent_framework")
__all__ = [
"AGENT_FRAMEWORK_USER_AGENT",
"APP_INFO",
"USER_AGENT_KEY",
"USER_AGENT_TELEMETRY_DISABLED_ENV_VAR",
"prepend_agent_framework_to_user_agent",
]
# Note that if this environment variable does not exist, user agent telemetry is enabled.
USER_AGENT_TELEMETRY_DISABLED_ENV_VAR = "AGENT_FRAMEWORK_USER_AGENT_DISABLED"
+8 -19
View File
@@ -5,6 +5,7 @@ from __future__ import annotations
import asyncio
import inspect
import json
import logging
import sys
from collections.abc import (
AsyncIterable,
@@ -34,7 +35,6 @@ from typing import (
from opentelemetry.metrics import Histogram, NoOpHistogram
from pydantic import BaseModel, Field, ValidationError, create_model
from ._logging import get_logger
from ._serialization import SerializationMixin
from .exceptions import ToolException
from .observability import (
@@ -71,18 +71,8 @@ if TYPE_CHECKING:
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
logger = get_logger()
logger = logging.getLogger("agent_framework")
__all__ = [
"FunctionInvocationConfiguration",
"FunctionInvocationLayer",
"FunctionTool",
"normalize_function_invocation_configuration",
"tool",
]
logger = get_logger()
DEFAULT_MAX_ITERATIONS: Final[int] = 40
DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
@@ -1941,7 +1931,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
@@ -1951,7 +1941,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
@@ -1961,7 +1951,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1970,7 +1960,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1982,7 +1972,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
ChatResponse,
ChatResponseUpdate,
ResponseStream,
prepare_messages,
)
super_get_response = super().get_response # type: ignore[misc]
@@ -2014,7 +2003,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
nonlocal mutable_options
nonlocal filtered_kwargs
errors_in_a_row: int = 0
prepped_messages = prepare_messages(messages)
prepped_messages = list(messages)
fcc_messages: list[Message] = []
response: ChatResponse | None = None
@@ -2108,7 +2097,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
nonlocal mutable_options
nonlocal stream_result_hooks
errors_in_a_row: int = 0
prepped_messages = prepare_messages(messages)
prepped_messages = list(messages)
fcc_messages: list[Message] = []
response: ChatResponse | None = None
+4 -74
View File
@@ -4,6 +4,7 @@ from __future__ import annotations
import base64
import json
import logging
import re
import sys
from asyncio import iscoroutine
@@ -13,7 +14,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, NewTyp
from pydantic import BaseModel
from ._logging import get_logger
from ._serialization import SerializationMixin
from ._tools import FunctionTool, tool
from .exceptions import AdditionItemMismatch, ContentError
@@ -27,41 +27,7 @@ if sys.version_info >= (3, 11):
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = [
"AgentResponse",
"AgentResponseUpdate",
"Annotation",
"ChatOptions",
"ChatResponse",
"ChatResponseUpdate",
"Content",
"ContinuationToken",
"FinalT",
"FinishReason",
"FinishReasonLiteral",
"Message",
"OuterFinalT",
"OuterUpdateT",
"ResponseStream",
"Role",
"RoleLiteral",
"TextSpanRegion",
"ToolMode",
"UpdateT",
"UsageDetails",
"add_usage_details",
"detect_media_type_from_base64",
"map_chat_to_agent_update",
"merge_chat_options",
"normalize_messages",
"normalize_tools",
"prepend_instructions_to_messages",
"validate_chat_options",
"validate_tool_mode",
"validate_tools",
]
logger = get_logger("agent_framework")
logger = logging.getLogger("agent_framework")
# region Content Parsing Utilities
@@ -1536,47 +1502,11 @@ class Message(SerializationMixin):
return " ".join(content.text for content in self.contents if content.type == "text") # type: ignore[misc]
def prepare_messages(
messages: str | Content | Message | Sequence[str | Content | Message],
system_instructions: str | Sequence[str] | None = None,
) -> list[Message]:
"""Convert various message input formats into a list of Message objects.
Args:
messages: The input messages in various supported formats. Can be:
- A string (converted to a user message)
- A Content object (wrapped in a user Message)
- A Message object
- A sequence containing any mix of the above
system_instructions: The system instructions. They will be inserted to the start of the messages list.
Returns:
A list of Message objects.
"""
if system_instructions is not None:
if isinstance(system_instructions, str):
system_instructions = [system_instructions]
system_instruction_messages = [Message("system", [instr]) for instr in system_instructions]
else:
system_instruction_messages = []
if isinstance(messages, str):
return [*system_instruction_messages, Message("user", [messages])]
if isinstance(messages, Content):
return [*system_instruction_messages, Message("user", [messages])]
if isinstance(messages, Message):
return [*system_instruction_messages, messages]
return_messages: list[Message] = system_instruction_messages
for msg in messages:
if isinstance(msg, (str, Content)):
msg = Message("user", [msg])
return_messages.append(msg)
return return_messages
AgentRunInputs = str | Content | Message | Sequence[str | Content | Message]
def normalize_messages(
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
messages: AgentRunInputs | None = None,
) -> list[Message]:
"""Normalize message inputs to a list of Message objects.
@@ -22,6 +22,7 @@ from .._sessions import (
from .._types import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
Content,
Message,
ResponseStream,
@@ -145,7 +146,7 @@ class WorkflowAgent(BaseAgent):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -157,7 +158,7 @@ class WorkflowAgent(BaseAgent):
@overload
async def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -168,7 +169,7 @@ class WorkflowAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -214,7 +215,7 @@ class WorkflowAgent(BaseAgent):
async def _run_impl(
self,
messages: str | Message | Sequence[str | Message],
messages: AgentRunInputs,
response_id: str,
session: AgentSession | None,
checkpoint_id: str | None = None,
@@ -270,7 +271,7 @@ class WorkflowAgent(BaseAgent):
async def _run_stream_impl(
self,
messages: str | Message | Sequence[str | Message],
messages: AgentRunInputs,
response_id: str,
session: AgentSession | None,
checkpoint_id: str | None = None,
@@ -3,11 +3,10 @@
from __future__ import annotations
import base64
import logging
import pickle # nosec # noqa: S403
from typing import Any
from agent_framework import get_logger
"""Checkpoint encoding using JSON structure with pickle+base64 for arbitrary data.
This hybrid approach provides:
@@ -20,7 +19,7 @@ from trusted sources. Loading a malicious checkpoint file can execute arbitrary
"""
logger = get_logger(__name__)
logger = logging.getLogger("agent_framework")
# Marker to identify pickled values in serialized JSON
_PICKLE_MARKER = "__pickled__"
@@ -2,18 +2,17 @@
"""Shared helpers for normalizing workflow message inputs."""
from collections.abc import Sequence
from agent_framework import Message
from agent_framework import Content, Message
from agent_framework._types import AgentRunInputs
def normalize_messages_input(
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
) -> list[Message]:
"""Normalize heterogeneous message inputs to a list of Message objects.
Args:
messages: String, Message, or sequence of either. None yields empty list.
messages: String, Content, Message, or sequence of those values. None yields empty list.
Returns:
List of Message instances suitable for workflow consumption.
@@ -24,6 +23,9 @@ def normalize_messages_input(
if isinstance(messages, str):
return [Message(role="user", text=messages)]
if isinstance(messages, Content):
return [Message(role="user", contents=[messages])]
if isinstance(messages, Message):
return [messages]
@@ -31,13 +33,12 @@ def normalize_messages_input(
for item in messages:
if isinstance(item, str):
normalized.append(Message(role="user", text=item))
elif isinstance(item, Content):
normalized.append(Message(role="user", contents=[item]))
elif isinstance(item, Message):
normalized.append(item)
else:
raise TypeError(
f"Messages sequence must contain only str or Message instances; found {type(item).__name__}."
f"Messages sequence must contain only str, Content, or Message instances; found {type(item).__name__}."
)
return normalized
__all__ = ["normalize_messages_input"]
@@ -27,8 +27,6 @@ if sys.version_info >= (3, 11):
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = ["AzureOpenAIAssistantsClient"]
# region Azure OpenAI Assistants Options TypedDict
@@ -53,7 +53,6 @@ if TYPE_CHECKING:
logger: logging.Logger = logging.getLogger(__name__)
__all__ = ["AzureOpenAIChatClient", "AzureOpenAIChatOptions", "AzureUserSecurityContext"]
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
@@ -42,8 +42,6 @@ if TYPE_CHECKING:
from .._middleware import MiddlewareTypes
from ..openai._responses_client import OpenAIResponsesOptions
__all__ = ["AzureOpenAIResponsesClient"]
AzureOpenAIResponsesOptionsT = TypeVar(
"AzureOpenAIResponsesOptionsT",
@@ -20,7 +20,6 @@ from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.semconv_ai import Meters, SpanAttributes
from . import __version__ as version_info
from ._logging import get_logger
from ._settings import load_settings
if sys.version_info >= (3, 13):
@@ -44,6 +43,7 @@ if TYPE_CHECKING: # pragma: no cover
from ._types import (
AgentResponse,
AgentResponseUpdate,
AgentRunInputs,
ChatOptions,
ChatResponse,
ChatResponseUpdate,
@@ -73,7 +73,7 @@ AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
logger = get_logger()
logger = logging.getLogger("agent_framework")
OTEL_METRICS: Final[str] = "__otel_metrics__"
@@ -747,7 +747,6 @@ class ObservabilitySettings:
for log_exporter in log_exporters:
logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter))
# Attach a handler with the provider to the root logger
logger = logging.getLogger()
handler = LoggingHandler(logger_provider=logger_provider)
logger.addHandler(handler)
set_logger_provider(logger_provider)
@@ -1084,7 +1083,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: ChatOptions[ResponseModelBoundT],
@@ -1094,7 +1093,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[False] = ...,
options: OptionsCoT | ChatOptions[None] | None = None,
@@ -1104,7 +1103,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
@overload
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: Literal[True],
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1113,7 +1112,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
def get_response(
self,
messages: str | Message | Sequence[str | Message],
messages: Sequence[Message],
*,
stream: bool = False,
options: OptionsCoT | ChatOptions[Any] | None = None,
@@ -1277,7 +1276,7 @@ class AgentTelemetryLayer:
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = ...,
session: AgentSession | None = None,
@@ -1287,7 +1286,7 @@ class AgentTelemetryLayer:
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -1296,7 +1295,7 @@ class AgentTelemetryLayer:
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -1614,15 +1613,15 @@ def capture_exception(span: trace.Span, exception: Exception, timestamp: int | N
def _capture_messages(
span: trace.Span,
provider_name: str,
messages: str | Message | Sequence[str | Message],
messages: AgentRunInputs,
system_instructions: str | list[str] | None = None,
output: bool = False,
finish_reason: FinishReason | None = None,
) -> None:
"""Log messages with extra information."""
from ._types import prepare_messages
from ._types import normalize_messages, prepend_instructions_to_messages
prepped = prepare_messages(messages, system_instructions=system_instructions)
prepped = prepend_instructions_to_messages(normalize_messages(messages), system_instructions)
otel_messages: list[dict[str, Any]] = []
for index, message in enumerate(prepped):
# Reuse the otel message representation for logging instead of calling to_dict()
@@ -33,7 +33,6 @@ if sys.version_info >= (3, 11):
else:
from typing_extensions import Self, TypedDict # type:ignore # pragma: no cover
__all__ = ["OpenAIAssistantProvider"]
# Type variable for options - allows typed OpenAIAssistantProvider[OptionsCoT] returns
# Default matches OpenAIAssistantsClient's default options type
@@ -68,12 +68,6 @@ else:
if TYPE_CHECKING:
from .._middleware import MiddlewareTypes
__all__ = [
"AssistantToolResources",
"OpenAIAssistantsClient",
"OpenAIAssistantsOptions",
]
# region OpenAI Assistants Options TypedDict
@@ -3,6 +3,7 @@
from __future__ import annotations
import json
import logging
import sys
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence
from datetime import datetime, timezone
@@ -20,7 +21,6 @@ from openai.types.chat.completion_create_params import WebSearchOptions
from pydantic import BaseModel
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
from .._settings import load_settings
from .._tools import (
@@ -60,9 +60,7 @@ if sys.version_info >= (3, 11):
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = ["OpenAIChatClient", "OpenAIChatOptions"]
logger = get_logger("agent_framework.openai")
logger = logging.getLogger("agent_framework.openai")
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
@@ -10,8 +10,6 @@ from openai import BadRequestError
from ..exceptions import ServiceContentFilterException
__all__ = ["ContentFilterResultSeverity", "OpenAIContentFilterException"]
class ContentFilterResultSeverity(Enum):
"""The severity of the content filter result."""
@@ -2,6 +2,7 @@
from __future__ import annotations
import logging
import sys
from collections.abc import (
AsyncIterable,
@@ -36,7 +37,6 @@ from openai.types.responses.web_search_tool_param import WebSearchToolParam
from pydantic import BaseModel
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import ChatMiddlewareLayer
from .._settings import load_settings
from .._tools import (
@@ -90,10 +90,7 @@ if TYPE_CHECKING:
FunctionMiddlewareCallable,
)
logger = get_logger("agent_framework.openai")
__all__ = ["OpenAIContinuationToken", "OpenAIResponsesClient", "OpenAIResponsesOptions", "RawOpenAIResponsesClient"]
logger = logging.getLogger("agent_framework.openai")
class OpenAIContinuationToken(ContinuationToken):
@@ -22,14 +22,13 @@ from openai.types.responses.response import Response
from openai.types.responses.response_stream_event import ResponseStreamEvent
from packaging.version import parse
from .._logging import get_logger
from .._serialization import SerializationMixin
from .._settings import SecretString
from .._telemetry import APP_INFO, USER_AGENT_KEY, prepend_agent_framework_to_user_agent
from .._tools import FunctionTool
from ..exceptions import ServiceInitializationError
logger: logging.Logger = get_logger("agent_framework.openai")
logger: logging.Logger = logging.getLogger("agent_framework.openai")
RESPONSE_TYPE = Union[
@@ -53,9 +52,6 @@ else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = ["OpenAISettings"]
def _check_openai_version_for_callable_api_key() -> None:
"""Check if OpenAI version supports callable API keys.
@@ -334,8 +334,10 @@ async def test_integration_options(
messages = [Message(role="user", text="What is the weather in Seattle?")]
elif option_name == "response_format":
# Use prompt that works well with structured output
messages = [Message(role="user", text="The weather in Seattle is sunny")]
messages.append(Message(role="user", text="What is the weather in Seattle?"))
messages = [
Message(role="user", text="The weather in Seattle is sunny"),
Message(role="user", text="What is the weather in Seattle?"),
]
else:
# Generic prompt for simple options
messages = [Message(role="user", text="Say 'Hello World' briefly.")]
@@ -396,7 +398,12 @@ async def test_integration_web_search() -> None:
for streaming in [False, True]:
content = {
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
"messages": [
Message(
role="user",
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
)
],
"options": {
"tool_choice": "auto",
"tools": [AzureOpenAIResponsesClient.get_web_search_tool()],
@@ -416,7 +423,7 @@ async def test_integration_web_search() -> None:
# Test that the client will use the web search tool with location
content = {
"messages": "What is the current weather? Do not ask for my current location.",
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
"options": {
"tool_choice": "auto",
"tools": [
@@ -498,7 +505,7 @@ async def test_integration_client_agent_hosted_mcp_tool() -> None:
"""Integration test for MCP tool with Azure Response Agent using Microsoft Learn MCP."""
client = AzureOpenAIResponsesClient(credential=AzureCliCredential())
response = await client.get_response(
"How to create an Azure storage account using az cli?",
messages=[Message(role="user", text="How to create an Azure storage account using az cli?")],
options={
# this needs to be high enough to handle the full MCP tool response.
"max_tokens": 5000,
@@ -523,7 +530,7 @@ async def test_integration_client_agent_hosted_code_interpreter_tool():
client = AzureOpenAIResponsesClient(credential=AzureCliCredential())
response = await client.get_response(
"Calculate the sum of numbers from 1 to 10 using Python code.",
messages=[Message(role="user", text="Calculate the sum of numbers from 1 to 10 using Python code.")],
options={
"tools": [AzureOpenAIResponsesClient.get_code_interpreter_tool()],
},
@@ -43,6 +43,12 @@ async def test_agent_run(agent: SupportsAgentRun) -> None:
assert response.messages[0].text == "Response"
async def test_agent_run_with_content(agent: SupportsAgentRun) -> None:
response = await agent.run(Content.from_text("test"))
assert response.messages[0].role == "assistant"
assert response.messages[0].text == "Response"
async def test_agent_run_streaming(agent: SupportsAgentRun) -> None:
async def collect_updates(updates: AsyncIterable[AgentResponseUpdate]) -> list[AgentResponseUpdate]:
return [u async for u in updates]
@@ -21,13 +21,13 @@ def test_chat_client_type(client: SupportsChatGetResponse):
async def test_chat_client_get_response(client: SupportsChatGetResponse):
response = await client.get_response(Message(role="user", text="Hello"))
response = await client.get_response([Message(role="user", text="Hello")])
assert response.text == "test response"
assert response.messages[0].role == "assistant"
async def test_chat_client_get_response_streaming(client: SupportsChatGetResponse):
async for update in client.get_response(Message(role="user", text="Hello"), stream=True):
async for update in client.get_response([Message(role="user", text="Hello")], stream=True):
assert update.text == "test streaming response " or update.text == "another update"
assert update.role == "assistant"
@@ -38,13 +38,13 @@ def test_base_client(chat_client_base: SupportsChatGetResponse):
async def test_base_client_get_response(chat_client_base: SupportsChatGetResponse):
response = await chat_client_base.get_response(Message(role="user", text="Hello"))
response = await chat_client_base.get_response([Message(role="user", text="Hello")])
assert response.messages[0].role == "assistant"
assert response.messages[0].text == "test response - Hello"
async def test_base_client_get_response_streaming(chat_client_base: SupportsChatGetResponse):
async for update in chat_client_base.get_response(Message(role="user", text="Hello"), stream=True):
async for update in chat_client_base.get_response([Message(role="user", text="Hello")], stream=True):
assert update.text == "update - Hello" or update.text == "another update"
@@ -59,7 +59,9 @@ async def test_chat_client_instructions_handling(chat_client_base: SupportsChatG
"_inner_get_response",
side_effect=fake_inner_get_response,
) as mock_inner_get_response:
await chat_client_base.get_response("hello", options={"instructions": instructions})
await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"instructions": instructions}
)
mock_inner_get_response.assert_called_once()
_, kwargs = mock_inner_get_response.call_args
messages = kwargs.get("messages", [])
@@ -38,7 +38,9 @@ async def test_base_client_with_function_calling(chat_client_base: SupportsChatG
),
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
)
assert exec_counter == 1
assert len(response.messages) == 3
assert response.messages[0].role == "assistant"
@@ -83,7 +85,9 @@ async def test_base_client_with_function_calling_resets(chat_client_base: Suppor
),
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
)
assert exec_counter == 2
assert len(response.messages) == 5
assert response.messages[0].role == "assistant"
@@ -388,11 +392,13 @@ async def test_function_invocation_scenarios(
options["conversation_id"] = conversation_id
if not streaming:
response = await chat_client_base.get_response("hello", options=options)
response = await chat_client_base.get_response([Message(role="user", text="hello")], options=options)
messages = response.messages
else:
updates = []
async for update in chat_client_base.get_response("hello", options=options, stream=True):
async for update in chat_client_base.get_response(
[Message(role="user", text="hello")], options=options, stream=True
):
updates.append(update)
messages = updates
@@ -776,7 +782,9 @@ async def test_max_iterations_limit(chat_client_base: SupportsChatGetResponse):
# Set max_iterations to 1 in additional_properties
chat_client_base.function_invocation_configuration["max_iterations"] = 1
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
)
# With max_iterations=1, we should:
# 1. Execute first function call (exec_counter=1)
@@ -803,7 +811,9 @@ async def test_function_invocation_config_enabled_false(chat_client_base: Suppor
# Disable function invocation
chat_client_base.function_invocation_configuration["enabled"] = False
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
)
# Function should not be executed - when enabled=False, the loop doesn't run
assert exec_counter == 0
@@ -859,7 +869,9 @@ async def test_function_invocation_config_max_consecutive_errors(chat_client_bas
# Set max_consecutive_errors to 2
chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 2
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
)
# Should stop after 2 consecutive errors and force a non-tool response
error_results = [
@@ -904,7 +916,9 @@ async def test_function_invocation_config_terminate_on_unknown_calls_false(chat_
# Set terminate_on_unknown_calls to False (default)
chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = False
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
)
# Should have a result message indicating the tool wasn't found
assert len(response.messages) == 3
@@ -940,7 +954,9 @@ async def test_function_invocation_config_terminate_on_unknown_calls_true(chat_c
# Should raise an exception when encountering an unknown function
with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'):
await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]})
await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
)
assert exec_counter == 0
@@ -978,7 +994,9 @@ async def test_function_invocation_config_additional_tools(chat_client_base: Sup
chat_client_base.function_invocation_configuration["additional_tools"] = [hidden_func]
# Only pass visible_func in the tools parameter
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [visible_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [visible_func]}
)
# Additional tools are treated as declaration_only, so not executed
# The function call should be in the messages but not executed
@@ -1016,7 +1034,9 @@ async def test_function_invocation_config_include_detailed_errors_false(chat_cli
# Set include_detailed_errors to False (default)
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
)
# Should have a generic error message
error_result = next(
@@ -1050,7 +1070,9 @@ async def test_function_invocation_config_include_detailed_errors_true(chat_clie
# Set include_detailed_errors to True
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
)
# Should have detailed error message
error_result = next(
@@ -1120,7 +1142,9 @@ async def test_argument_validation_error_with_detailed_errors(chat_client_base:
# Set include_detailed_errors to True
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
)
# Should have detailed validation error
error_result = next(
@@ -1154,7 +1178,9 @@ async def test_argument_validation_error_without_detailed_errors(chat_client_bas
# Set include_detailed_errors to False (default)
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
)
# Should have generic validation error
error_result = next(
@@ -1219,7 +1245,9 @@ async def test_unapproved_tool_execution_raises_exception(chat_client_base: Supp
]
# Get approval request
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
response1 = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
)
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
@@ -1277,7 +1305,9 @@ async def test_approved_function_call_with_error_without_detailed_errors(chat_cl
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
# Get approval request
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
response1 = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
)
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
@@ -1340,7 +1370,9 @@ async def test_approved_function_call_with_error_with_detailed_errors(chat_clien
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
# Get approval request
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
response1 = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
)
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
@@ -1403,7 +1435,9 @@ async def test_approved_function_call_with_validation_error(chat_client_base: Su
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
# Get approval request
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
response1 = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
)
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
@@ -1459,7 +1493,9 @@ async def test_approved_function_call_successful_execution(chat_client_base: Sup
]
# Get approval request
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [success_func]})
response1 = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [success_func]}
)
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
@@ -1575,7 +1611,9 @@ async def test_multiple_function_calls_parallel_execution(chat_client_base: Supp
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [func1, func2]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [func1, func2]}
)
# Both functions should have been executed
assert "func1_start" in exec_order
@@ -1612,7 +1650,9 @@ async def test_callable_function_converted_to_tool(chat_client_base: SupportsCha
]
# Pass plain function (will be auto-converted)
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [plain_function]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [plain_function]}
)
# Function should be executed
assert exec_counter == 1
@@ -1644,7 +1684,9 @@ async def test_conversation_id_handling(chat_client_base: SupportsChatGetRespons
),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
)
# Should have executed the function
results = [content for msg in response.messages for content in msg.contents if content.type == "function_result"]
@@ -1671,7 +1713,9 @@ async def test_function_result_appended_to_existing_assistant_message(chat_clien
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
)
# Should have messages with both function call and function result
assert len(response.messages) >= 2
@@ -1716,7 +1760,9 @@ async def test_error_recovery_resets_counter(chat_client_base: SupportsChatGetRe
ChatResponse(messages=Message(role="assistant", text="done")),
]
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [sometimes_fails]})
response = await chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [sometimes_fails]}
)
# Should have both an error and a success
error_results = [
@@ -1990,7 +2036,9 @@ async def test_streaming_function_invocation_config_terminate_on_unknown_calls_t
# Should raise an exception when encountering an unknown function
with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'):
async for _ in chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]}):
async for _ in chat_client_base.get_response(
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
):
pass
assert exec_counter == 0
@@ -1,39 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
import pytest
from agent_framework import get_logger
from agent_framework.exceptions import AgentFrameworkException
def test_get_logger():
"""Test that the logger is created with the correct name."""
logger = get_logger()
assert logger.name == "agent_framework"
def test_get_logger_custom_name():
"""Test that the logger can be created with a custom name."""
custom_name = "agent_framework.custom"
logger = get_logger(custom_name)
assert logger.name == custom_name
def test_get_logger_invalid_name():
"""Test that an exception is raised for an invalid logger name."""
with pytest.raises(AgentFrameworkException):
get_logger("invalid_name")
def test_log(caplog):
"""Test that the logger can log messages and adheres to the expected format."""
logger = get_logger()
with caplog.at_level("DEBUG"):
logger.debug("This is a debug message")
assert len(caplog.records) == 1
record = caplog.records[0]
assert record.levelname == "DEBUG"
assert record.message == "This is a debug message"
assert record.name == "agent_framework"
assert record.pathname.endswith("test_logging.py")
@@ -1083,7 +1083,12 @@ async def test_integration_web_search() -> None:
# Use static method for web search tool
web_search_tool = OpenAIChatClient.get_web_search_tool()
content = {
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
"messages": [
Message(
role="user",
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
)
],
"options": {
"tool_choice": "auto",
"tools": [web_search_tool],
@@ -1110,7 +1115,7 @@ async def test_integration_web_search() -> None:
}
)
content = {
"messages": "What is the current weather? Do not ask for my current location.",
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
"options": {
"tool_choice": "auto",
"tools": [web_search_tool_with_location],
@@ -2416,7 +2416,12 @@ async def test_integration_web_search() -> None:
# Use static method for web search tool
web_search_tool = OpenAIResponsesClient.get_web_search_tool()
content = {
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
"messages": [
Message(
role="user",
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
)
],
"options": {
"tool_choice": "auto",
"tools": [web_search_tool],
@@ -2438,7 +2443,7 @@ async def test_integration_web_search() -> None:
user_location={"country": "US", "city": "Seattle"},
)
content = {
"messages": "What is the current weather? Do not ask for my current location.",
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
"options": {
"tool_choice": "auto",
"tools": [web_search_tool_with_location],
@@ -39,7 +39,7 @@ class _ToolCallingAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -35,7 +35,7 @@ class _SimpleAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -105,7 +105,7 @@ class _CaptureAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -835,7 +835,7 @@ class _StreamingTestAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -52,7 +52,7 @@ class _KwargsCapturingAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -85,7 +85,7 @@ class _OptionsAwareAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -1,12 +1,12 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
import logging
import os
from collections.abc import MutableMapping
from contextvars import ContextVar
from typing import Any, Literal, TypeVar, Union
from agent_framework import get_logger
from agent_framework._serialization import SerializationMixin
try:
@@ -20,7 +20,7 @@ except (ImportError, RuntimeError):
from typing import overload
logger = get_logger("agent_framework.declarative")
logger = logging.getLogger("agent_framework.declarative")
# Context variable for safe_mode setting.
# When True (default), environment variables are NOT accessible in PowerFx expressions.
@@ -10,10 +10,10 @@ This module implements handlers for:
from __future__ import annotations
import json
import logging
from collections.abc import AsyncGenerator
from typing import Any, cast
from agent_framework import get_logger
from agent_framework._types import AgentResponse, Message
from ._handlers import (
@@ -25,7 +25,7 @@ from ._handlers import (
)
from ._human_input import ExternalLoopEvent, QuestionRequest
logger = get_logger("agent_framework.declarative.workflows.actions")
logger = logging.getLogger("agent_framework.declarative")
def _extract_json_from_response(text: str) -> Any:
@@ -18,11 +18,10 @@ actually yielding any events.
from __future__ import annotations
import logging
from collections.abc import AsyncGenerator
from typing import TYPE_CHECKING, Any, cast
from agent_framework import get_logger
from ._handlers import (
ActionContext,
AttachmentOutputEvent,
@@ -35,7 +34,7 @@ from ._handlers import (
if TYPE_CHECKING:
from ._state import WorkflowState
logger = get_logger("agent_framework.declarative.workflows.actions")
logger = logging.getLogger("agent_framework.declarative")
@action_handler("SetValue")
@@ -11,10 +11,9 @@ This module implements handlers for:
- ContinueLoop: Skip to the next iteration
"""
import logging
from collections.abc import AsyncGenerator
from agent_framework import get_logger
from ._handlers import (
ActionContext,
LoopControlSignal,
@@ -22,7 +21,7 @@ from ._handlers import (
action_handler,
)
logger = get_logger("agent_framework.declarative.workflows.actions")
logger = logging.getLogger("agent_framework.declarative")
@action_handler("Foreach")
@@ -9,18 +9,17 @@ This module implements handlers for:
from __future__ import annotations
import logging
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from agent_framework import get_logger
from ._handlers import (
ActionContext,
WorkflowEvent,
action_handler,
)
logger = get_logger("agent_framework.declarative.workflows.actions")
logger = logging.getLogger("agent_framework.declarative")
class WorkflowActionError(Exception):
@@ -12,6 +12,7 @@ enabling checkpointing, visualization, and pause/resume capabilities.
from __future__ import annotations
import logging
from collections.abc import Mapping
from pathlib import Path
from typing import Any, cast
@@ -22,13 +23,12 @@ from agent_framework import (
CheckpointStorage,
SupportsAgentRun,
Workflow,
get_logger,
)
from .._loader import AgentFactory
from ._declarative_builder import DeclarativeWorkflowBuilder
logger = get_logger("agent_framework.declarative.workflows")
logger = logging.getLogger("agent_framework.declarative")
class DeclarativeWorkflowError(Exception):
@@ -9,16 +9,15 @@ has a corresponding handler registered via the @action_handler decorator.
from __future__ import annotations
import logging
from collections.abc import AsyncGenerator, Callable
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
from agent_framework import get_logger
if TYPE_CHECKING:
from ._state import WorkflowState
logger = get_logger("agent_framework.declarative.workflows")
logger = logging.getLogger("agent_framework.declarative")
@dataclass
@@ -10,12 +10,11 @@ This module implements handlers for human input patterns:
from __future__ import annotations
import logging
from collections.abc import AsyncGenerator
from dataclasses import dataclass
from typing import TYPE_CHECKING, Any, cast
from agent_framework import get_logger
from ._handlers import (
ActionContext,
WorkflowEvent,
@@ -25,7 +24,7 @@ from ._handlers import (
if TYPE_CHECKING:
from ._state import WorkflowState
logger = get_logger("agent_framework.declarative.workflows.human_input")
logger = logging.getLogger("agent_framework.declarative")
@dataclass
@@ -11,11 +11,10 @@ This module provides state management for declarative workflows, handling:
from __future__ import annotations
import logging
from collections.abc import Mapping
from typing import Any, cast
from agent_framework import get_logger
try:
from powerfx import Engine
@@ -25,7 +24,7 @@ except (ImportError, RuntimeError):
# RuntimeError: .NET runtime not available or misconfigured
_powerfx_engine = None
logger = get_logger("agent_framework.declarative.workflows")
logger = logging.getLogger("agent_framework.declarative")
class WorkflowState:
@@ -8,14 +8,16 @@ with durable agents via gRPC.
from __future__ import annotations
from agent_framework import AgentResponse, get_logger
import logging
from agent_framework import AgentResponse
from durabletask.client import TaskHubGrpcClient
from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS
from ._executors import ClientAgentExecutor
from ._shim import DurableAgentProvider, DurableAIAgent
logger = get_logger("agent_framework.durabletask.client")
logger = logging.getLogger("agent_framework.durabletask")
class DurableAIAgentClient(DurableAgentProvider[AgentResponse]):
@@ -30,6 +30,7 @@ All classes support bidirectional conversion between:
from __future__ import annotations
import json
import logging
from collections.abc import MutableMapping
from datetime import datetime, timezone
from enum import Enum
@@ -40,14 +41,13 @@ from agent_framework import (
Content,
Message,
UsageDetails,
get_logger,
)
from dateutil import parser as date_parser
from ._constants import ContentTypes, DurableStateFields
from ._models import RunRequest, serialize_response_format
logger = get_logger("agent_framework.durabletask.durable_agent_state")
logger = logging.getLogger("agent_framework.durabletask")
class DurableAgentStateEntryJsonType(str, Enum):
@@ -5,6 +5,7 @@
from __future__ import annotations
import inspect
import logging
from datetime import datetime, timezone
from typing import Any, cast
@@ -15,7 +16,6 @@ from agent_framework import (
Message,
ResponseStream,
SupportsAgentRun,
get_logger,
)
from durabletask.entities import DurableEntity
@@ -28,7 +28,7 @@ from ._durable_agent_state import (
)
from ._models import RunRequest
logger = get_logger("agent_framework.durabletask.entities")
logger = logging.getLogger("agent_framework.durabletask")
class AgentEntityStateProviderMixin:
@@ -10,13 +10,14 @@ and are injected into the shim.
from __future__ import annotations
import logging
import time
import uuid
from abc import ABC, abstractmethod
from datetime import datetime, timezone
from typing import Any, Generic, TypeVar
from agent_framework import AgentResponse, AgentSession, Content, Message, get_logger
from agent_framework import AgentResponse, AgentSession, Content, Message
from durabletask.client import TaskHubGrpcClient
from durabletask.entities import EntityInstanceId
from durabletask.task import CompletableTask, CompositeTask, OrchestrationContext, Task
@@ -27,7 +28,7 @@ from ._durable_agent_state import DurableAgentState
from ._models import AgentSessionId, DurableAgentSession, RunRequest
from ._response_utils import ensure_response_format, load_agent_response
logger = get_logger("agent_framework.durabletask.executors")
logger = logging.getLogger("agent_framework.durabletask")
# TypeVar for the task type returned by executors
TaskT = TypeVar("TaskT")
@@ -8,13 +8,14 @@ orchestration functions to interact with durable agents.
from __future__ import annotations
from agent_framework import get_logger
import logging
from durabletask.task import OrchestrationContext
from ._executors import DurableAgentTask, OrchestrationAgentExecutor
from ._shim import DurableAgentProvider, DurableAIAgent
logger = get_logger("agent_framework.durabletask.orchestration_context")
logger = logging.getLogger("agent_framework.durabletask")
class DurableAIAgentOrchestrationContext(DurableAgentProvider[DurableAgentTask]):
@@ -2,12 +2,13 @@
"""Shared utilities for handling AgentResponse parsing and validation."""
import logging
from typing import Any
from agent_framework import AgentResponse, get_logger
from agent_framework import AgentResponse
from pydantic import BaseModel
logger = get_logger("agent_framework.durabletask.response_utils")
logger = logging.getLogger("agent_framework.durabletask")
def load_agent_response(agent_response: AgentResponse | dict[str, Any] | None) -> AgentResponse:
@@ -12,7 +12,8 @@ from __future__ import annotations
from abc import ABC, abstractmethod
from typing import Any, Generic, Literal, TypeVar
from agent_framework import AgentSession, Message, SupportsAgentRun
from agent_framework import AgentSession, SupportsAgentRun, normalize_messages
from agent_framework._types import AgentRunInputs
from ._executors import DurableAgentExecutor
from ._models import DurableAgentSession
@@ -86,7 +87,7 @@ class DurableAIAgent(SupportsAgentRun, Generic[TaskT]):
def run( # type: ignore[override]
self,
messages: str | Message | list[str] | list[Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = False,
session: AgentSession | None = None,
@@ -143,7 +144,7 @@ class DurableAIAgent(SupportsAgentRun, Generic[TaskT]):
"""
return self._executor.get_new_session(self.name, **kwargs)
def _normalize_messages(self, messages: str | Message | list[str] | list[Message] | None) -> str:
def _normalize_messages(self, messages: AgentRunInputs | None) -> str:
"""Convert supported message inputs to a single string.
Args:
@@ -151,19 +152,18 @@ class DurableAIAgent(SupportsAgentRun, Generic[TaskT]):
Returns:
A single string representation of the messages
Raises:
ValueError: If normalized messages contain non-text content only.
"""
if messages is None:
normalized_messages = normalize_messages(messages)
if not normalized_messages:
return ""
if isinstance(messages, str):
return messages
if isinstance(messages, Message):
return messages.text or ""
if isinstance(messages, list):
if not messages:
return ""
first_item = messages[0]
if isinstance(first_item, str):
return "\n".join(messages) # type: ignore[arg-type]
# List of Message
return "\n".join([msg.text or "" for msg in messages]) # type: ignore[union-attr]
return ""
message_texts: list[str] = []
for message in normalized_messages:
if not message.text:
raise ValueError("DurableAIAgent only supports text message inputs.")
message_texts.append(message.text)
return "\n".join(message_texts)
@@ -9,15 +9,16 @@ and enables registration of agents as durable entities.
from __future__ import annotations
import asyncio
import logging
from typing import Any
from agent_framework import SupportsAgentRun, get_logger
from agent_framework import SupportsAgentRun
from durabletask.worker import TaskHubGrpcWorker
from ._callbacks import AgentResponseCallbackProtocol
from ._entities import AgentEntity, DurableTaskEntityStateProvider
logger = get_logger("agent_framework.durabletask.worker")
logger = logging.getLogger("agent_framework.durabletask")
class DurableAIAgentWorker:
@@ -23,7 +23,7 @@ from agent_framework import (
)
from agent_framework._settings import load_settings
from agent_framework._tools import FunctionTool
from agent_framework._types import normalize_tools
from agent_framework._types import AgentRunInputs, normalize_tools
from agent_framework.exceptions import ServiceException
from copilot import CopilotClient, CopilotSession
from copilot.generated.session_events import SessionEvent, SessionEventType
@@ -277,7 +277,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[False] = False,
session: AgentSession | None = None,
@@ -288,7 +288,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
@overload
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: Literal[True],
session: AgentSession | None = None,
@@ -298,7 +298,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -340,7 +340,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
async def _run_impl(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
session: AgentSession | None = None,
options: OptionsT | None = None,
@@ -388,7 +388,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
async def _stream_updates(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: AgentRunInputs | None = None,
*,
session: AgentSession | None = None,
options: OptionsT | None = None,
@@ -3,6 +3,7 @@
from __future__ import annotations
import json
import logging
import sys
from collections.abc import (
AsyncIterable,
@@ -28,7 +29,6 @@ from agent_framework import (
Message,
ResponseStream,
UsageDetails,
get_logger,
)
from agent_framework._settings import load_settings
from agent_framework.exceptions import (
@@ -281,7 +281,7 @@ class OllamaSettings(TypedDict, total=False):
model_id: str | None
logger = get_logger("agent_framework.ollama")
logger = logging.getLogger("agent_framework.ollama")
class OllamaChatClient(
@@ -38,7 +38,7 @@ class StubAgent(BaseAgent):
def run( # type: ignore[override]
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -76,7 +76,7 @@ class StubManagerAgent(Agent):
async def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
session: AgentSession | None = None,
**kwargs: Any,
@@ -130,7 +130,7 @@ class ConcatenatedJsonManagerAgent(Agent):
async def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
session: AgentSession | None = None,
**kwargs: Any,
@@ -896,7 +896,7 @@ async def test_group_chat_with_orchestrator_factory_returning_chat_agent():
async def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
session: AgentSession | None = None,
**kwargs: Any,
@@ -150,7 +150,7 @@ class StubAgent(BaseAgent):
def run( # type: ignore[override]
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: AgentSession | None = None,
@@ -411,7 +411,7 @@ class StubManagerAgent(BaseAgent):
def run(
self,
messages: str | Message | Sequence[str | Message] | None = None,
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
*,
stream: bool = False,
session: Any = None,
@@ -5,12 +5,12 @@ from __future__ import annotations
import base64
import inspect
import json
import logging
from typing import Any, cast
from uuid import uuid4
import httpx
from agent_framework import AGENT_FRAMEWORK_USER_AGENT
from agent_framework._logging import get_logger
from agent_framework.observability import get_tracer
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
@@ -33,7 +33,7 @@ from ._models import (
)
from ._settings import PurviewSettings, get_purview_scopes
logger = get_logger("agent_framework.purview")
logger = logging.getLogger("agent_framework.purview")
class PurviewClient:
@@ -1,9 +1,9 @@
# Copyright (c) Microsoft. All rights reserved.
import logging
from collections.abc import Awaitable, Callable
from agent_framework import AgentContext, AgentMiddleware, ChatContext, ChatMiddleware, MiddlewareTermination
from agent_framework._logging import get_logger
from azure.core.credentials import TokenCredential
from azure.core.credentials_async import AsyncTokenCredential
@@ -14,7 +14,7 @@ from ._models import Activity
from ._processor import ScopedContentProcessor
from ._settings import PurviewSettings
logger = get_logger("agent_framework.purview")
logger = logging.getLogger("agent_framework.purview")
class PurviewPolicyMiddleware(AgentMiddleware):
@@ -2,16 +2,16 @@
from __future__ import annotations
import logging
from collections.abc import Mapping, MutableMapping, Sequence
from datetime import datetime
from enum import Enum, Flag, auto
from typing import Any, ClassVar, TypeVar, cast
from uuid import uuid4
from agent_framework._logging import get_logger
from agent_framework._serialization import SerializationMixin
logger = get_logger("agent_framework.purview")
logger = logging.getLogger("agent_framework.purview")
# --------------------------------------------------------------------------------------
# Enums & flag helpers
@@ -1,13 +1,13 @@
# Copyright (c) Microsoft. All rights reserved.
import asyncio
import logging
import time
import uuid
from collections.abc import Iterable, MutableMapping
from typing import Any
from agent_framework import Message
from agent_framework._logging import get_logger
from ._cache import CacheProvider, InMemoryCacheProvider, create_protection_scopes_cache_key
from ._client import PurviewClient
@@ -37,7 +37,7 @@ from ._models import (
)
from ._settings import PurviewSettings
logger = get_logger("agent_framework.purview")
logger = logging.getLogger("agent_framework.purview")
def _is_valid_guid(value: str | None) -> bool:
@@ -178,12 +178,15 @@ The `configure_otel_providers()` function automatically reads **standard OpenTel
> **Note**: These are standard OpenTelemetry environment variables. See the [OpenTelemetry spec](https://opentelemetry.io/docs/specs/otel/configuration/sdk-environment-variables/) for more details.
#### Logging
Agent Framework has a built-in logging configuration that works well with telemetry. It sets the format to a standard format that includes timestamp, pathname, line number, and log level. You can use that by calling the `setup_logging()` function from the `agent_framework` module.
Use standard Python logging configuration to align logs with telemetry output.
```python
from agent_framework import setup_logging
import logging
setup_logging()
logging.basicConfig(
format="[%(asctime)s - %(pathname)s:%(lineno)d - %(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
```
You can control at what level logging happens and thus what logs get exported, you can do this, by adding this:
@@ -2,11 +2,12 @@
import argparse
import asyncio
import logging
from contextlib import suppress
from random import randint
from typing import TYPE_CHECKING, Annotated, Literal
from agent_framework import setup_logging, tool
from agent_framework import tool
from agent_framework.observability import configure_otel_providers, get_tracer
from agent_framework.openai import OpenAIResponsesClient
from opentelemetry import trace
@@ -31,7 +32,9 @@ Use this approach when you need custom exporter configuration beyond what enviro
SCENARIOS = ["client", "client_stream", "tool", "all"]
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
# NOTE: approval_mode="never_require" is for sample brevity.
# Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
@tool(approval_mode="never_require")
async def get_weather(
location: Annotated[str, Field(description="The location to get the weather for.")],
@@ -101,7 +104,10 @@ async def main(scenario: Literal["client", "client_stream", "tool", "all"] = "al
"""Run the selected scenario(s)."""
# Setup the logging with the more complete format
setup_logging()
logging.basicConfig(
format="[%(asctime)s - %(pathname)s:%(lineno)d - %(levelname)s] %(message)s",
datefmt="%Y-%m-%d %H:%M:%S",
)
# Create custom OTLP exporters with specific configuration
# Note: You need to install opentelemetry-exporter-otlp-proto-grpc or -http separately
@@ -2,7 +2,7 @@
import asyncio
from agent_framework import Content, Message
from agent_framework import Content
from agent_framework.azure import AzureOpenAIResponsesClient
from azure.identity import AzureCliCredential
@@ -20,24 +20,17 @@ async def main():
# 1. Create an Azure Responses agent with vision capabilities
agent = AzureOpenAIResponsesClient(credential=AzureCliCredential()).as_agent(
name="VisionAgent",
instructions="You are a helpful agent that can analyze images.",
instructions="You are a image analysist, you get a image and need to respond with what you see in the picture.",
)
# 2. Create a simple message with both text and image content
user_message = Message(
role="user",
contents=[
Content.from_text("What do you see in this image?"),
Content.from_uri(
uri="https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=800",
media_type="image/jpeg",
),
],
)
# 3. Get the agent's response
# 2. Get the agent's response
print("User: What do you see in this image? [Image provided]")
result = await agent.run(user_message)
result = await agent.run(
Content.from_uri(
uri="https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=800",
media_type="image/jpeg",
)
)
print(f"Agent: {result.text}")
print()
@@ -2,7 +2,7 @@
import asyncio
from agent_framework import Content, Message
from agent_framework import Content
from agent_framework.openai import OpenAIResponsesClient
"""
@@ -19,24 +19,17 @@ async def main():
# 1. Create an OpenAI Responses agent with vision capabilities
agent = OpenAIResponsesClient().as_agent(
name="VisionAgent",
instructions="You are a helpful agent that can analyze images.",
instructions="You are a image analysist, you get a image and need to respond with what you see in the picture.",
)
# 2. Create a simple message with both text and image content
user_message = Message(
role="user",
contents=[
Content.from_text(text="What do you see in this image?"),
Content.from_uri(
uri="https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=800",
media_type="image/jpeg",
),
],
)
# 3. Get the agent's response
# 2. Get the agent's response
print("User: What do you see in this image? [Image provided]")
result = await agent.run(user_message)
result = await agent.run(
Content.from_uri(
uri="https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=800",
media_type="image/jpeg",
)
)
print(f"Agent: {result.text}")
print()
+4
View File
@@ -38,6 +38,10 @@ export AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME="gpt-4o"
For Azure authentication, run `az login` before running samples.
## Note on XML tags
Some sample files include XML-style snippet tags (for example `<snippet_name>` and `</snippet_name>`). These are used by our documentation tooling and can be ignored or removed when you use the samples outside this repository.
## Additional Resources
- [Agent Framework Documentation](https://learn.microsoft.com/agent-framework/)