mirror of
https://github.com/microsoft/agent-framework.git
synced 2026-06-16 21:04:09 +08:00
* Fix #3613 message typing across chat and agents * Address #3613 review feedback and sample input style * refactor: use shared AgentRunMessages aliases (#3613) * refactor: rename agent run input aliases for #3613 * samples: inline image content in run calls * core: export AgentRunInputs from package init * core: use explicit init re-exports without __all__ * updated logging and inits * Fix core mypy export and samples XML note * Remove AgentRunInputsOrNone and dedupe loggers * Remove prepare_messages helper * fix integration tests
This commit is contained in:
committed by
GitHub
Unverified
parent
503eb10fdd
commit
dc9439a75a
+6
-1
@@ -81,7 +81,12 @@ from agent_framework.azure import AzureOpenAIChatClient
|
||||
|
||||
## Public API and Exports
|
||||
|
||||
Define `__all__` in each module. Avoid `from module import *` in `__init__.py` files:
|
||||
In `__init__.py` files that define package-level public APIs, use direct re-export imports plus an explicit
|
||||
`__all__`. Avoid identity aliases like `from ._agents import ChatAgent as ChatAgent`, and avoid
|
||||
`from module import *`.
|
||||
|
||||
Do not define `__all__` in internal non-`__init__.py` modules. Exception: modules intentionally exposed as a
|
||||
public import surface (for example, `agent_framework.observability`) should define `__all__`.
|
||||
|
||||
```python
|
||||
__all__ = ["ChatAgent", "Message", "ChatResponse"]
|
||||
|
||||
+19
-28
@@ -130,27 +130,6 @@ user_msg = UserMessage(content="Hello, world!")
|
||||
asst_msg = AssistantMessage(content="Hello, world!")
|
||||
```
|
||||
|
||||
### Logging
|
||||
|
||||
Use the centralized logging system:
|
||||
|
||||
```python
|
||||
from agent_framework import get_logger
|
||||
|
||||
# For main package
|
||||
logger = get_logger()
|
||||
|
||||
# For subpackages
|
||||
logger = get_logger('agent_framework.azure')
|
||||
```
|
||||
|
||||
**Do not use** direct logging module imports:
|
||||
```python
|
||||
# ❌ Avoid this
|
||||
import logging
|
||||
logger = logging.getLogger(__name__)
|
||||
```
|
||||
|
||||
### Import Structure
|
||||
|
||||
The package follows a flat import structure:
|
||||
@@ -189,8 +168,6 @@ python/
|
||||
│ │ ├── _clients.py # Chat client protocols and base classes
|
||||
│ │ ├── _tools.py # Tool definitions
|
||||
│ │ ├── _types.py # Type definitions
|
||||
│ │ ├── _logging.py # Logging utilities
|
||||
│ │ │
|
||||
│ │ │ # Provider folders - lazy load from connector packages
|
||||
│ │ ├── openai/ # OpenAI clients (built into core)
|
||||
│ │ ├── azure/ # Lazy loads from azure-ai, azure-ai-search, azurefunctions
|
||||
@@ -405,12 +382,15 @@ If in doubt, use the link above to read much more considerations of what to do a
|
||||
|
||||
**All wildcard imports (`from ... import *`) are prohibited** in production code, including both `.py` and `.pyi` files. Always use explicit import lists to maintain clarity and avoid namespace pollution.
|
||||
|
||||
Define `__all__` in each module to explicitly declare the public API, then import specific symbols by name:
|
||||
Do not use ``__all__`` in internal modules. Define it in the ``__init__`` file of the level you want to expose.
|
||||
If a non-``__init__`` module is intentionally part of the public API surface (for example, ``observability.py``),
|
||||
it should define ``__all__`` as well.
|
||||
|
||||
Also avoid identity alias imports in ``__init__`` files. Use ``from ._module import Symbol`` instead of
|
||||
``from ._module import Symbol as Symbol``.
|
||||
|
||||
```python
|
||||
# ✅ Preferred - explicit __all__ and named imports
|
||||
__all__ = ["Agent", "Message", "ChatResponse"]
|
||||
|
||||
from ._agents import Agent
|
||||
from ._types import Message, ChatResponse
|
||||
|
||||
@@ -422,9 +402,20 @@ from ._types import (
|
||||
ResponseStream,
|
||||
)
|
||||
|
||||
__all__ = [
|
||||
"Agent",
|
||||
"AgentResponse",
|
||||
"ChatResponse",
|
||||
"Message",
|
||||
"ResponseStream",
|
||||
]
|
||||
|
||||
# ❌ Prohibited pattern: wildcard/star imports (do not use)
|
||||
# from ._agents import <all public symbols>
|
||||
# from ._types import <all public symbols>
|
||||
# from ._agents import *
|
||||
# from ._types import *
|
||||
|
||||
# ❌ Prohibited pattern: identity alias imports (do not use)
|
||||
# from ._agents import Agent as Agent
|
||||
```
|
||||
|
||||
**Rationale:**
|
||||
|
||||
@@ -40,6 +40,7 @@ from agent_framework import (
|
||||
normalize_messages,
|
||||
prepend_agent_framework_to_user_agent,
|
||||
)
|
||||
from agent_framework._types import AgentRunInputs
|
||||
from agent_framework.observability import AgentTelemetryLayer
|
||||
|
||||
__all__ = ["A2AAgent", "A2AContinuationToken"]
|
||||
@@ -208,7 +209,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -220,7 +221,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -231,7 +232,7 @@ class A2AAgent(AgentTelemetryLayer, BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
|
||||
@@ -49,7 +49,7 @@ if TYPE_CHECKING:
|
||||
|
||||
from ._types import AGUIChatOptions
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
logger: logging.Logger = logging.getLogger("agent_framework.ag_ui")
|
||||
|
||||
|
||||
def _unwrap_server_function_call_contents(contents: MutableSequence[Content | dict[str, Any]]) -> None:
|
||||
@@ -277,9 +277,6 @@ class AGUIChatClient(
|
||||
registered: set[str] = getattr(self, "_registered_server_tools", set())
|
||||
registered.add(tool_name)
|
||||
self._registered_server_tools = registered # type: ignore[attr-defined]
|
||||
from agent_framework._logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger.debug(f"[AGUIChatClient] Registered server placeholder: {tool_name}")
|
||||
|
||||
def _extract_state_from_messages(self, messages: Sequence[Message]) -> tuple[list[Message], dict[str, Any] | None]:
|
||||
@@ -310,9 +307,6 @@ class AGUIChatClient(
|
||||
messages_without_state = list(messages[:-1]) if len(messages) > 1 else []
|
||||
return messages_without_state, state
|
||||
except (json.JSONDecodeError, ValueError, KeyError) as e:
|
||||
from agent_framework._logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger.warning(f"Failed to extract state from message: {e}")
|
||||
|
||||
return list(messages), None
|
||||
|
||||
@@ -408,9 +408,6 @@ def agui_messages_to_agent_framework(messages: list[dict[str, Any]]) -> list[Mes
|
||||
approval_payload_text = result_content if isinstance(result_content, str) else json.dumps(parsed)
|
||||
|
||||
# Log the full approval payload to debug modified arguments
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
logger.info(f"Approval payload received: {parsed}")
|
||||
|
||||
approval_call_id = tool_call_id
|
||||
|
||||
@@ -11,7 +11,7 @@ import asyncio
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from agent_framework import ChatResponse, ChatResponseUpdate, ResponseStream
|
||||
from agent_framework import ChatResponse, ChatResponseUpdate, Message, ResponseStream
|
||||
from agent_framework.ag_ui import AGUIChatClient
|
||||
|
||||
|
||||
@@ -44,7 +44,7 @@ async def main():
|
||||
metadata = {"thread_id": thread_id} if thread_id else None
|
||||
|
||||
stream = client.get_response(
|
||||
message,
|
||||
[Message(role="user", text=message)],
|
||||
stream=True,
|
||||
options={"metadata": metadata} if metadata else None,
|
||||
)
|
||||
|
||||
@@ -15,7 +15,7 @@ import asyncio
|
||||
import os
|
||||
from typing import cast
|
||||
|
||||
from agent_framework import ChatResponse, ChatResponseUpdate, ResponseStream, tool
|
||||
from agent_framework import ChatResponse, ChatResponseUpdate, Message, ResponseStream, tool
|
||||
from agent_framework.ag_ui import AGUIChatClient
|
||||
|
||||
|
||||
@@ -73,7 +73,7 @@ async def streaming_example(client: AGUIChatClient, thread_id: str | None = None
|
||||
print("Assistant: ", end="", flush=True)
|
||||
|
||||
stream = client.get_response(
|
||||
"Tell me a short joke",
|
||||
[Message(role="user", text="Tell me a short joke")],
|
||||
stream=True,
|
||||
options={"metadata": metadata} if metadata else None,
|
||||
)
|
||||
@@ -100,7 +100,7 @@ async def non_streaming_example(client: AGUIChatClient, thread_id: str | None =
|
||||
|
||||
print("\nUser: What is 2 + 2?\n")
|
||||
|
||||
response = await client.get_response("What is 2 + 2?", metadata=metadata)
|
||||
response = await client.get_response([Message(role="user", text="What is 2 + 2?")], metadata=metadata)
|
||||
|
||||
print(f"Assistant: {response.text}")
|
||||
|
||||
@@ -139,7 +139,7 @@ async def tool_example(client: AGUIChatClient, thread_id: str | None = None):
|
||||
print("(Server must be configured with matching tools to execute them)\n")
|
||||
|
||||
response = await client.get_response(
|
||||
"What's the weather in Seattle?", tools=[get_weather, calculate], metadata=metadata
|
||||
[Message(role="user", text="What's the weather in Seattle?")], tools=[get_weather, calculate], metadata=metadata
|
||||
)
|
||||
|
||||
print(f"Assistant: {response.text}")
|
||||
@@ -174,14 +174,16 @@ async def conversation_example(client: AGUIChatClient):
|
||||
|
||||
# First turn
|
||||
print("User: My name is Alice\n")
|
||||
response1 = await client.get_response("My name is Alice")
|
||||
response1 = await client.get_response([Message(role="user", text="My name is Alice")])
|
||||
print(f"Assistant: {response1.text}")
|
||||
thread_id = response1.additional_properties.get("thread_id")
|
||||
print(f"\n[Thread: {thread_id}]")
|
||||
|
||||
# Second turn - using same thread
|
||||
print("\nUser: What's my name?\n")
|
||||
response2 = await client.get_response("What's my name?", options={"metadata": {"thread_id": thread_id}})
|
||||
response2 = await client.get_response(
|
||||
[Message(role="user", text="What's my name?")], options={"metadata": {"thread_id": thread_id}}
|
||||
)
|
||||
print(f"Assistant: {response2.text}")
|
||||
|
||||
# Check if context was maintained
|
||||
@@ -191,7 +193,9 @@ async def conversation_example(client: AGUIChatClient):
|
||||
# Third turn
|
||||
print("\nUser: Can you also tell me what 10 * 5 is?\n")
|
||||
response3 = await client.get_response(
|
||||
"Can you also tell me what 10 * 5 is?", options={"metadata": {"thread_id": thread_id}}, tools=[calculate]
|
||||
[Message(role="user", text="Can you also tell me what 10 * 5 is?")],
|
||||
options={"metadata": {"thread_id": thread_id}},
|
||||
tools=[calculate],
|
||||
)
|
||||
print(f"Assistant: {response3.text}")
|
||||
|
||||
|
||||
@@ -55,7 +55,7 @@ class StreamingChatClientStub(
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[Any],
|
||||
@@ -65,7 +65,7 @@ class StreamingChatClientStub(
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = ...,
|
||||
@@ -75,7 +75,7 @@ class StreamingChatClientStub(
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = ...,
|
||||
@@ -84,7 +84,7 @@ class StreamingChatClientStub(
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -175,7 +175,7 @@ class StubAgent(SupportsAgentRun):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -185,7 +185,7 @@ class StubAgent(SupportsAgentRun):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -194,7 +194,7 @@ class StubAgent(SupportsAgentRun):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Awaitable, Mapping, MutableMapping, Sequence
|
||||
from typing import Any, ClassVar, Final, Generic, Literal, TypedDict
|
||||
@@ -24,7 +25,6 @@ from agent_framework import (
|
||||
ResponseStream,
|
||||
TextSpanRegion,
|
||||
UsageDetails,
|
||||
get_logger,
|
||||
)
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework._types import _get_data_bytes_as_str # type: ignore
|
||||
@@ -68,7 +68,7 @@ __all__ = [
|
||||
"ThinkingConfig",
|
||||
]
|
||||
|
||||
logger = get_logger("agent_framework.anthropic")
|
||||
logger = logging.getLogger("agent_framework.anthropic")
|
||||
|
||||
ANTHROPIC_DEFAULT_MAX_TOKENS: Final[int] = 1024
|
||||
BETA_FLAGS: Final[list[str]] = ["mcp-client-2025-04-04", "code-execution-2025-08-25"]
|
||||
|
||||
+2
-2
@@ -8,12 +8,12 @@ This module provides ``AzureAISearchContextProvider``, built on the new
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable
|
||||
from typing import TYPE_CHECKING, Any, ClassVar, Literal, TypedDict
|
||||
|
||||
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
|
||||
from agent_framework._logging import get_logger
|
||||
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
@@ -103,7 +103,7 @@ try:
|
||||
except ImportError:
|
||||
_agentic_retrieval_available = False
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.getLogger("agent_framework.azure_ai_search")
|
||||
|
||||
_DEFAULT_AGENTIC_MESSAGE_HISTORY_COUNT = 10
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import ast
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import re
|
||||
import sys
|
||||
@@ -31,7 +32,6 @@ from agent_framework import (
|
||||
Role,
|
||||
TextSpanRegion,
|
||||
UsageDetails,
|
||||
get_logger,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError, ServiceResponseException
|
||||
@@ -102,7 +102,7 @@ else:
|
||||
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = get_logger("agent_framework.azure")
|
||||
logger = logging.getLogger("agent_framework.azure")
|
||||
|
||||
__all__ = ["AzureAIAgentClient", "AzureAIAgentOptions"]
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Callable, Mapping, MutableMapping, Sequence
|
||||
from contextlib import suppress
|
||||
@@ -19,7 +20,6 @@ from agent_framework import (
|
||||
FunctionTool,
|
||||
Message,
|
||||
MiddlewareTypes,
|
||||
get_logger,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import ServiceInitializationError
|
||||
@@ -59,7 +59,7 @@ else:
|
||||
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = get_logger("agent_framework.azure")
|
||||
logger = logging.getLogger("agent_framework.azure")
|
||||
|
||||
|
||||
class AzureAIProjectAgentOptions(OpenAIResponsesOptions, total=False):
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Callable, MutableMapping, Sequence
|
||||
from typing import Any, Generic
|
||||
@@ -12,7 +13,6 @@ from agent_framework import (
|
||||
BaseContextProvider,
|
||||
FunctionTool,
|
||||
MiddlewareTypes,
|
||||
get_logger,
|
||||
normalize_tools,
|
||||
)
|
||||
from agent_framework._mcp import MCPTool
|
||||
@@ -43,7 +43,7 @@ else:
|
||||
from typing_extensions import Self, TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
logger = get_logger("agent_framework.azure")
|
||||
logger = logging.getLogger("agent_framework.azure")
|
||||
|
||||
|
||||
# Type variable for options - allows typed Agent[OptionsT] returns
|
||||
|
||||
@@ -2,13 +2,13 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import (
|
||||
FunctionTool,
|
||||
get_logger,
|
||||
)
|
||||
from agent_framework.exceptions import ServiceInvalidRequestError
|
||||
from azure.ai.agents.models import (
|
||||
@@ -37,7 +37,7 @@ if sys.version_info >= (3, 11):
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
logger = get_logger("agent_framework.azure")
|
||||
logger = logging.getLogger("agent_framework.azure")
|
||||
|
||||
|
||||
class AzureAISettings(TypedDict, total=False):
|
||||
|
||||
@@ -134,7 +134,7 @@ def create_test_azure_ai_client(
|
||||
client._created_agent_tool_names = set() # type: ignore
|
||||
client._created_agent_structured_output_signature = None # type: ignore
|
||||
client.additional_properties = {}
|
||||
client.middleware = None
|
||||
client.chat_middleware = []
|
||||
|
||||
# Mock the OpenAI client attribute
|
||||
mock_openai_client = MagicMock()
|
||||
@@ -1546,7 +1546,12 @@ async def test_integration_web_search() -> None:
|
||||
async with temporary_chat_client(agent_name="af-int-test-web-search") as client:
|
||||
for streaming in [False, True]:
|
||||
content = {
|
||||
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
"messages": [
|
||||
Message(
|
||||
role="user",
|
||||
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
)
|
||||
],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [client.get_web_search_tool()],
|
||||
@@ -1565,7 +1570,9 @@ async def test_integration_web_search() -> None:
|
||||
|
||||
# Test that the client will use the web search tool with location
|
||||
content = {
|
||||
"messages": "What is the current weather? Do not ask for my current location.",
|
||||
"messages": [
|
||||
Message(role="user", text="What is the current weather? Do not ask for my current location.")
|
||||
],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [client.get_web_search_tool(user_location={"country": "US", "city": "Seattle"})],
|
||||
@@ -1584,7 +1591,7 @@ async def test_integration_agent_hosted_mcp_tool() -> None:
|
||||
"""Integration test for MCP tool with Azure Response Agent using Microsoft Learn MCP."""
|
||||
async with temporary_chat_client(agent_name="af-int-test-mcp") as client:
|
||||
response = await client.get_response(
|
||||
"How to create an Azure storage account using az cli?",
|
||||
messages=[Message(role="user", text="How to create an Azure storage account using az cli?")],
|
||||
options={
|
||||
# this needs to be high enough to handle the full MCP tool response.
|
||||
"max_tokens": 5000,
|
||||
@@ -1608,7 +1615,7 @@ async def test_integration_agent_hosted_code_interpreter_tool():
|
||||
"""Test Azure Responses Client agent with code interpreter tool through AzureAIClient."""
|
||||
async with temporary_chat_client(agent_name="af-int-test-code-interpreter") as client:
|
||||
response = await client.get_response(
|
||||
"Calculate the sum of numbers from 1 to 10 using Python code.",
|
||||
messages=[Message(role="user", text="Calculate the sum of numbers from 1 to 10 using Python code.")],
|
||||
options={
|
||||
"tools": [client.get_code_interpreter_tool()],
|
||||
},
|
||||
|
||||
@@ -9,6 +9,7 @@ with Azure Durable Entities, enabling stateful and durable AI agent execution.
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import uuid
|
||||
from collections.abc import Callable, Mapping
|
||||
@@ -18,7 +19,7 @@ from typing import TYPE_CHECKING, Any, TypeVar, cast
|
||||
|
||||
import azure.durable_functions as df
|
||||
import azure.functions as func
|
||||
from agent_framework import SupportsAgentRun, get_logger
|
||||
from agent_framework import SupportsAgentRun
|
||||
from agent_framework_durabletask import (
|
||||
DEFAULT_MAX_POLL_RETRIES,
|
||||
DEFAULT_POLL_INTERVAL_SECONDS,
|
||||
@@ -42,7 +43,7 @@ from ._entities import create_agent_entity
|
||||
from ._errors import IncomingRequestError
|
||||
from ._orchestration import AgentOrchestrationContextType, AgentTask, AzureFunctionsAgentExecutor
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions")
|
||||
logger = logging.getLogger("agent_framework.azurefunctions")
|
||||
|
||||
EntityHandler = Callable[[df.DurableEntityContext], None]
|
||||
HandlerT = TypeVar("HandlerT", bound=Callable[..., Any])
|
||||
|
||||
@@ -10,18 +10,19 @@ allows for long-running agent conversations.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import Any, cast
|
||||
|
||||
import azure.durable_functions as df
|
||||
from agent_framework import SupportsAgentRun, get_logger
|
||||
from agent_framework import SupportsAgentRun
|
||||
from agent_framework_durabletask import (
|
||||
AgentEntity,
|
||||
AgentEntityStateProviderMixin,
|
||||
AgentResponseCallbackProtocol,
|
||||
)
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions.entities")
|
||||
logger = logging.getLogger("agent_framework.azurefunctions")
|
||||
|
||||
|
||||
class AzureFunctionEntityStateProvider(AgentEntityStateProviderMixin):
|
||||
|
||||
@@ -5,11 +5,12 @@
|
||||
This module provides support for using agents inside Durable Function orchestrations.
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import Callable
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
|
||||
import azure.durable_functions as df
|
||||
from agent_framework import AgentSession, get_logger
|
||||
from agent_framework import AgentSession
|
||||
from agent_framework_durabletask import (
|
||||
DurableAgentExecutor,
|
||||
RunRequest,
|
||||
@@ -21,7 +22,7 @@ from azure.durable_functions.models.actions.NoOpAction import NoOpAction
|
||||
from azure.durable_functions.models.Task import CompoundTask, TaskState
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = get_logger("agent_framework.azurefunctions.orchestration")
|
||||
logger = logging.getLogger("agent_framework.azurefunctions")
|
||||
|
||||
CompoundActionConstructor: TypeAlias = Callable[[list[Any]], Any] | None
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections import deque
|
||||
from collections.abc import AsyncIterable, Awaitable, Mapping, MutableMapping, Sequence
|
||||
@@ -26,7 +27,6 @@ from agent_framework import (
|
||||
Message,
|
||||
ResponseStream,
|
||||
UsageDetails,
|
||||
get_logger,
|
||||
validate_tool_mode,
|
||||
)
|
||||
from agent_framework._settings import SecretString, load_settings
|
||||
@@ -50,7 +50,7 @@ if sys.version_info >= (3, 11):
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
logger = get_logger("agent_framework.bedrock")
|
||||
logger = logging.getLogger("agent_framework.bedrock")
|
||||
|
||||
|
||||
__all__ = [
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import contextlib
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, MutableMapping, Sequence
|
||||
from pathlib import Path
|
||||
@@ -19,11 +20,10 @@ from agent_framework import (
|
||||
FunctionTool,
|
||||
Message,
|
||||
ResponseStream,
|
||||
get_logger,
|
||||
normalize_messages,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework._types import normalize_tools
|
||||
from agent_framework._types import AgentRunInputs, normalize_tools
|
||||
from agent_framework.exceptions import ServiceException
|
||||
from claude_agent_sdk import (
|
||||
AssistantMessage,
|
||||
@@ -61,7 +61,7 @@ if TYPE_CHECKING:
|
||||
|
||||
__all__ = ["ClaudeAgent", "ClaudeAgentOptions"]
|
||||
|
||||
logger = get_logger("agent_framework.claude")
|
||||
logger = logging.getLogger("agent_framework.claude")
|
||||
|
||||
# Name of the in-process MCP server that hosts Agent Framework tools.
|
||||
# FunctionTool instances are converted to SDK MCP tools and served
|
||||
@@ -557,7 +557,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -568,7 +568,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
@overload
|
||||
async def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -578,7 +578,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -612,7 +612,7 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
|
||||
|
||||
async def _get_stream(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | MutableMapping[str, Any] | None = None,
|
||||
|
||||
@@ -18,6 +18,7 @@ from agent_framework import (
|
||||
normalize_messages,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework._types import AgentRunInputs
|
||||
from agent_framework.exceptions import ServiceException, ServiceInitializationError
|
||||
from microsoft_agents.copilotstudio.client import AgentType, ConnectionSettings, CopilotClient, PowerPlatformCloud
|
||||
|
||||
@@ -187,7 +188,7 @@ class CopilotStudioAgent(BaseAgent):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | list[str] | list[Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -197,7 +198,7 @@ class CopilotStudioAgent(BaseAgent):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | list[str] | list[Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -206,7 +207,7 @@ class CopilotStudioAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | list[str] | list[Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -236,7 +237,7 @@ class CopilotStudioAgent(BaseAgent):
|
||||
|
||||
async def _run_impl(
|
||||
self,
|
||||
messages: str | Message | list[str] | list[Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -261,7 +262,7 @@ class CopilotStudioAgent(BaseAgent):
|
||||
|
||||
def _run_stream_impl(
|
||||
self,
|
||||
messages: str | Message | list[str] | list[Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
|
||||
@@ -19,7 +19,6 @@ from ._clients import (
|
||||
SupportsMCPTool,
|
||||
SupportsWebSearchTool,
|
||||
)
|
||||
from ._logging import get_logger, setup_logging
|
||||
from ._mcp import MCPStdioTool, MCPStreamableHTTPTool, MCPWebsocketTool
|
||||
from ._middleware import (
|
||||
AgentContext,
|
||||
@@ -34,7 +33,6 @@ from ._middleware import (
|
||||
FunctionInvocationContext,
|
||||
FunctionMiddleware,
|
||||
FunctionMiddlewareTypes,
|
||||
MiddlewareException,
|
||||
MiddlewareTermination,
|
||||
MiddlewareType,
|
||||
MiddlewareTypes,
|
||||
@@ -67,6 +65,7 @@ from ._tools import (
|
||||
from ._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
Annotation,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
@@ -152,6 +151,7 @@ from ._workflows import (
|
||||
response_handler,
|
||||
validate_workflow_graph,
|
||||
)
|
||||
from .exceptions import MiddlewareException
|
||||
|
||||
__all__ = [
|
||||
"AGENT_FRAMEWORK_USER_AGENT",
|
||||
@@ -169,6 +169,7 @@ __all__ = [
|
||||
"AgentMiddlewareTypes",
|
||||
"AgentResponse",
|
||||
"AgentResponseUpdate",
|
||||
"AgentRunInputs",
|
||||
"AgentSession",
|
||||
"Annotation",
|
||||
"BaseAgent",
|
||||
@@ -264,6 +265,7 @@ __all__ = [
|
||||
"WorkflowRunnerException",
|
||||
"WorkflowValidationError",
|
||||
"WorkflowViz",
|
||||
"__version__",
|
||||
"add_usage_details",
|
||||
"agent_middleware",
|
||||
"chat_middleware",
|
||||
@@ -271,7 +273,6 @@ __all__ = [
|
||||
"detect_media_type_from_base64",
|
||||
"executor",
|
||||
"function_middleware",
|
||||
"get_logger",
|
||||
"handler",
|
||||
"map_chat_to_agent_update",
|
||||
"merge_chat_options",
|
||||
@@ -283,7 +284,6 @@ __all__ = [
|
||||
"register_state_type",
|
||||
"resolve_agent_id",
|
||||
"response_handler",
|
||||
"setup_logging",
|
||||
"tool",
|
||||
"validate_chat_options",
|
||||
"validate_tool_mode",
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
|
||||
@@ -29,7 +30,6 @@ from mcp.shared.exceptions import McpError
|
||||
from pydantic import BaseModel, Field, create_model
|
||||
|
||||
from ._clients import BaseChatClient, SupportsChatGetResponse
|
||||
from ._logging import get_logger
|
||||
from ._mcp import LOG_LEVEL_MAPPING, MCPTool
|
||||
from ._middleware import AgentMiddlewareLayer, MiddlewareTypes
|
||||
from ._serialization import SerializationMixin
|
||||
@@ -41,6 +41,7 @@ from ._tools import (
|
||||
from ._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Message,
|
||||
@@ -67,7 +68,7 @@ else:
|
||||
if TYPE_CHECKING:
|
||||
from ._types import ChatOptions
|
||||
|
||||
logger = get_logger("agent_framework")
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
OptionsCoT = TypeVar(
|
||||
@@ -159,9 +160,6 @@ class _RunContext(TypedDict):
|
||||
finalize_kwargs: dict[str, Any]
|
||||
|
||||
|
||||
__all__ = ["Agent", "BaseAgent", "RawAgent", "SupportsAgentRun"]
|
||||
|
||||
|
||||
# region Agent Protocol
|
||||
|
||||
|
||||
@@ -230,7 +228,7 @@ class SupportsAgentRun(Protocol):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -242,7 +240,7 @@ class SupportsAgentRun(Protocol):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -253,7 +251,7 @@ class SupportsAgentRun(Protocol):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -763,7 +761,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -780,7 +778,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -797,7 +795,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -813,7 +811,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1000,7 +998,7 @@ class RawAgent(BaseAgent, Generic[OptionsCoT]): # type: ignore[misc]
|
||||
async def _prepare_run_context(
|
||||
self,
|
||||
*,
|
||||
messages: str | Message | Sequence[str | Message] | None,
|
||||
messages: AgentRunInputs | None,
|
||||
session: AgentSession | None,
|
||||
tools: FunctionTool
|
||||
| Callable[..., Any]
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from abc import ABC, abstractmethod
|
||||
from collections.abc import (
|
||||
@@ -27,7 +28,6 @@ from typing import (
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._logging import get_logger
|
||||
from ._serialization import SerializationMixin
|
||||
from ._tools import (
|
||||
FunctionInvocationConfiguration,
|
||||
@@ -38,7 +38,6 @@ from ._types import (
|
||||
ChatResponseUpdate,
|
||||
Message,
|
||||
ResponseStream,
|
||||
prepare_messages,
|
||||
validate_chat_options,
|
||||
)
|
||||
|
||||
@@ -61,17 +60,7 @@ InputT = TypeVar("InputT", contravariant=True)
|
||||
EmbeddingT = TypeVar("EmbeddingT")
|
||||
BaseChatClientT = TypeVar("BaseChatClientT", bound="BaseChatClient")
|
||||
|
||||
logger = get_logger()
|
||||
|
||||
__all__ = [
|
||||
"BaseChatClient",
|
||||
"SupportsChatGetResponse",
|
||||
"SupportsCodeInterpreterTool",
|
||||
"SupportsFileSearchTool",
|
||||
"SupportsImageGenerationTool",
|
||||
"SupportsMCPTool",
|
||||
"SupportsWebSearchTool",
|
||||
]
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
|
||||
# region SupportsChatGetResponse Protocol
|
||||
@@ -139,7 +128,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
@@ -149,7 +138,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsContraT | ChatOptions[None] | None = None,
|
||||
@@ -159,7 +148,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsContraT | ChatOptions[Any] | None = None,
|
||||
@@ -168,7 +157,7 @@ class SupportsChatGetResponse(Protocol[OptionsContraT]):
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsContraT | ChatOptions[Any] | None = None,
|
||||
@@ -254,9 +243,9 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
client = CustomChatClient()
|
||||
|
||||
# Use the client to get responses
|
||||
response = await client.get_response("Hello, how are you?")
|
||||
response = await client.get_response([Message(role="user", text="Hello, how are you?")])
|
||||
# Or stream responses
|
||||
async for update in client.get_response("Hello!", stream=True):
|
||||
async for update in client.get_response([Message(role="user", text="Hello!")], stream=True):
|
||||
print(update)
|
||||
"""
|
||||
|
||||
@@ -376,7 +365,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
@@ -386,7 +375,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
@@ -396,7 +385,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -405,7 +394,7 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -422,9 +411,8 @@ class BaseChatClient(SerializationMixin, ABC, Generic[OptionsCoT]):
|
||||
Returns:
|
||||
When streaming a response stream of ChatResponseUpdates, otherwise an Awaitable ChatResponse.
|
||||
"""
|
||||
prepared_messages = prepare_messages(messages)
|
||||
return self._inner_get_response(
|
||||
messages=prepared_messages,
|
||||
messages=messages,
|
||||
stream=stream,
|
||||
options=options or {}, # type: ignore[arg-type]
|
||||
**kwargs,
|
||||
|
||||
@@ -1,29 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
|
||||
from .exceptions import AgentFrameworkException
|
||||
|
||||
__all__ = ["get_logger", "setup_logging"]
|
||||
|
||||
|
||||
def setup_logging() -> None:
|
||||
"""Setup the logging configuration for the agent framework."""
|
||||
logging.basicConfig(
|
||||
format="[%(asctime)s - %(pathname)s:%(lineno)d - %(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
|
||||
def get_logger(name: str = "agent_framework") -> logging.Logger:
|
||||
"""Get a logger with the specified name, defaulting to 'agent_framework'.
|
||||
|
||||
Args:
|
||||
name (str): The name of the logger. Defaults to 'agent_framework'.
|
||||
|
||||
Returns:
|
||||
logging.Logger: The configured logger instance.
|
||||
"""
|
||||
if not name.startswith("agent_framework"):
|
||||
raise AgentFrameworkException("Logger name must start with 'agent_framework'.")
|
||||
return logging.getLogger(name)
|
||||
@@ -72,12 +72,6 @@ LOG_LEVEL_MAPPING: dict[types.LoggingLevel, int] = {
|
||||
"emergency": logging.CRITICAL,
|
||||
}
|
||||
|
||||
__all__ = [
|
||||
"MCPStdioTool",
|
||||
"MCPStreamableHTTPTool",
|
||||
"MCPWebsocketTool",
|
||||
]
|
||||
|
||||
|
||||
def _parse_prompt_result_from_mcp(
|
||||
mcp_type: types.GetPromptResult,
|
||||
|
||||
@@ -14,11 +14,12 @@ from ._clients import SupportsChatGetResponse
|
||||
from ._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
Message,
|
||||
ResponseStream,
|
||||
prepare_messages,
|
||||
normalize_messages,
|
||||
)
|
||||
from .exceptions import MiddlewareException
|
||||
|
||||
@@ -42,27 +43,6 @@ if TYPE_CHECKING:
|
||||
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
|
||||
__all__ = [
|
||||
"AgentContext",
|
||||
"AgentMiddleware",
|
||||
"AgentMiddlewareLayer",
|
||||
"AgentMiddlewareTypes",
|
||||
"ChatAndFunctionMiddlewareTypes",
|
||||
"ChatContext",
|
||||
"ChatMiddleware",
|
||||
"ChatMiddlewareLayer",
|
||||
"ChatMiddlewareTypes",
|
||||
"FunctionInvocationContext",
|
||||
"FunctionMiddleware",
|
||||
"FunctionMiddlewareTypes",
|
||||
"MiddlewareException",
|
||||
"MiddlewareTermination",
|
||||
"MiddlewareType",
|
||||
"MiddlewareTypes",
|
||||
"agent_middleware",
|
||||
"chat_middleware",
|
||||
"function_middleware",
|
||||
]
|
||||
|
||||
AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
|
||||
ContextT = TypeVar("ContextT")
|
||||
@@ -978,7 +958,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
@@ -988,7 +968,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
@@ -998,7 +978,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1007,7 +987,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1034,7 +1014,7 @@ class ChatMiddlewareLayer(Generic[OptionsCoT]):
|
||||
|
||||
context = ChatContext(
|
||||
client=self, # type: ignore[arg-type]
|
||||
messages=prepare_messages(messages),
|
||||
messages=list(messages),
|
||||
options=options,
|
||||
stream=stream,
|
||||
kwargs=kwargs,
|
||||
@@ -1095,7 +1075,7 @@ class AgentMiddlewareLayer:
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1107,7 +1087,7 @@ class AgentMiddlewareLayer:
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1119,7 +1099,7 @@ class AgentMiddlewareLayer:
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -1130,7 +1110,7 @@ class AgentMiddlewareLayer:
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1161,7 +1141,7 @@ class AgentMiddlewareLayer:
|
||||
|
||||
context = AgentContext(
|
||||
agent=self, # type: ignore[arg-type]
|
||||
messages=prepare_messages(messages), # type: ignore[arg-type]
|
||||
messages=normalize_messages(messages),
|
||||
session=session,
|
||||
options=options,
|
||||
stream=stream,
|
||||
|
||||
@@ -3,13 +3,12 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
from collections.abc import Mapping, MutableMapping
|
||||
from typing import Any, ClassVar, Protocol, TypeVar, runtime_checkable
|
||||
|
||||
from ._logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
ClassT = TypeVar("ClassT", bound="SerializationMixin")
|
||||
ProtocolT = TypeVar("ProtocolT", bound="SerializationProtocol")
|
||||
|
||||
@@ -24,16 +24,6 @@ if TYPE_CHECKING:
|
||||
from ._agents import SupportsAgentRun
|
||||
|
||||
|
||||
__all__ = [
|
||||
"AgentSession",
|
||||
"BaseContextProvider",
|
||||
"BaseHistoryProvider",
|
||||
"InMemoryHistoryProvider",
|
||||
"SessionContext",
|
||||
"register_state_type",
|
||||
]
|
||||
|
||||
|
||||
# Registry of known types for state deserialization
|
||||
_STATE_TYPE_REGISTRY: dict[str, type] = {}
|
||||
|
||||
|
||||
@@ -45,7 +45,6 @@ if sys.version_info >= (3, 13):
|
||||
else:
|
||||
from typing_extensions import TypeVar # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = ["SecretString", "load_settings"]
|
||||
|
||||
SettingsT = TypeVar("SettingsT", default=dict[str, Any])
|
||||
|
||||
|
||||
@@ -2,21 +2,14 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from typing import Any, Final
|
||||
|
||||
from . import __version__ as version_info
|
||||
from ._logging import get_logger
|
||||
|
||||
logger = get_logger()
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
__all__ = [
|
||||
"AGENT_FRAMEWORK_USER_AGENT",
|
||||
"APP_INFO",
|
||||
"USER_AGENT_KEY",
|
||||
"USER_AGENT_TELEMETRY_DISABLED_ENV_VAR",
|
||||
"prepend_agent_framework_to_user_agent",
|
||||
]
|
||||
|
||||
# Note that if this environment variable does not exist, user agent telemetry is enabled.
|
||||
USER_AGENT_TELEMETRY_DISABLED_ENV_VAR = "AGENT_FRAMEWORK_USER_AGENT_DISABLED"
|
||||
|
||||
@@ -5,6 +5,7 @@ from __future__ import annotations
|
||||
import asyncio
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
@@ -34,7 +35,6 @@ from typing import (
|
||||
from opentelemetry.metrics import Histogram, NoOpHistogram
|
||||
from pydantic import BaseModel, Field, ValidationError, create_model
|
||||
|
||||
from ._logging import get_logger
|
||||
from ._serialization import SerializationMixin
|
||||
from .exceptions import ToolException
|
||||
from .observability import (
|
||||
@@ -71,18 +71,8 @@ if TYPE_CHECKING:
|
||||
ResponseModelBoundT = TypeVar("ResponseModelBoundT", bound=BaseModel)
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
__all__ = [
|
||||
"FunctionInvocationConfiguration",
|
||||
"FunctionInvocationLayer",
|
||||
"FunctionTool",
|
||||
"normalize_function_invocation_configuration",
|
||||
"tool",
|
||||
]
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
DEFAULT_MAX_ITERATIONS: Final[int] = 40
|
||||
DEFAULT_MAX_CONSECUTIVE_ERRORS_PER_REQUEST: Final[int] = 3
|
||||
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
|
||||
@@ -1941,7 +1931,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
@@ -1951,7 +1941,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
@@ -1961,7 +1951,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1970,7 +1960,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1982,7 +1972,6 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
ResponseStream,
|
||||
prepare_messages,
|
||||
)
|
||||
|
||||
super_get_response = super().get_response # type: ignore[misc]
|
||||
@@ -2014,7 +2003,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
nonlocal mutable_options
|
||||
nonlocal filtered_kwargs
|
||||
errors_in_a_row: int = 0
|
||||
prepped_messages = prepare_messages(messages)
|
||||
prepped_messages = list(messages)
|
||||
fcc_messages: list[Message] = []
|
||||
response: ChatResponse | None = None
|
||||
|
||||
@@ -2108,7 +2097,7 @@ class FunctionInvocationLayer(Generic[OptionsCoT]):
|
||||
nonlocal mutable_options
|
||||
nonlocal stream_result_hooks
|
||||
errors_in_a_row: int = 0
|
||||
prepped_messages = prepare_messages(messages)
|
||||
prepped_messages = list(messages)
|
||||
fcc_messages: list[Message] = []
|
||||
response: ChatResponse | None = None
|
||||
|
||||
|
||||
@@ -4,6 +4,7 @@ from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import json
|
||||
import logging
|
||||
import re
|
||||
import sys
|
||||
from asyncio import iscoroutine
|
||||
@@ -13,7 +14,6 @@ from typing import TYPE_CHECKING, Any, ClassVar, Final, Generic, Literal, NewTyp
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from ._logging import get_logger
|
||||
from ._serialization import SerializationMixin
|
||||
from ._tools import FunctionTool, tool
|
||||
from .exceptions import AdditionItemMismatch, ContentError
|
||||
@@ -27,41 +27,7 @@ if sys.version_info >= (3, 11):
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = [
|
||||
"AgentResponse",
|
||||
"AgentResponseUpdate",
|
||||
"Annotation",
|
||||
"ChatOptions",
|
||||
"ChatResponse",
|
||||
"ChatResponseUpdate",
|
||||
"Content",
|
||||
"ContinuationToken",
|
||||
"FinalT",
|
||||
"FinishReason",
|
||||
"FinishReasonLiteral",
|
||||
"Message",
|
||||
"OuterFinalT",
|
||||
"OuterUpdateT",
|
||||
"ResponseStream",
|
||||
"Role",
|
||||
"RoleLiteral",
|
||||
"TextSpanRegion",
|
||||
"ToolMode",
|
||||
"UpdateT",
|
||||
"UsageDetails",
|
||||
"add_usage_details",
|
||||
"detect_media_type_from_base64",
|
||||
"map_chat_to_agent_update",
|
||||
"merge_chat_options",
|
||||
"normalize_messages",
|
||||
"normalize_tools",
|
||||
"prepend_instructions_to_messages",
|
||||
"validate_chat_options",
|
||||
"validate_tool_mode",
|
||||
"validate_tools",
|
||||
]
|
||||
|
||||
logger = get_logger("agent_framework")
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
|
||||
# region Content Parsing Utilities
|
||||
@@ -1536,47 +1502,11 @@ class Message(SerializationMixin):
|
||||
return " ".join(content.text for content in self.contents if content.type == "text") # type: ignore[misc]
|
||||
|
||||
|
||||
def prepare_messages(
|
||||
messages: str | Content | Message | Sequence[str | Content | Message],
|
||||
system_instructions: str | Sequence[str] | None = None,
|
||||
) -> list[Message]:
|
||||
"""Convert various message input formats into a list of Message objects.
|
||||
|
||||
Args:
|
||||
messages: The input messages in various supported formats. Can be:
|
||||
- A string (converted to a user message)
|
||||
- A Content object (wrapped in a user Message)
|
||||
- A Message object
|
||||
- A sequence containing any mix of the above
|
||||
system_instructions: The system instructions. They will be inserted to the start of the messages list.
|
||||
|
||||
Returns:
|
||||
A list of Message objects.
|
||||
"""
|
||||
if system_instructions is not None:
|
||||
if isinstance(system_instructions, str):
|
||||
system_instructions = [system_instructions]
|
||||
system_instruction_messages = [Message("system", [instr]) for instr in system_instructions]
|
||||
else:
|
||||
system_instruction_messages = []
|
||||
|
||||
if isinstance(messages, str):
|
||||
return [*system_instruction_messages, Message("user", [messages])]
|
||||
if isinstance(messages, Content):
|
||||
return [*system_instruction_messages, Message("user", [messages])]
|
||||
if isinstance(messages, Message):
|
||||
return [*system_instruction_messages, messages]
|
||||
|
||||
return_messages: list[Message] = system_instruction_messages
|
||||
for msg in messages:
|
||||
if isinstance(msg, (str, Content)):
|
||||
msg = Message("user", [msg])
|
||||
return_messages.append(msg)
|
||||
return return_messages
|
||||
AgentRunInputs = str | Content | Message | Sequence[str | Content | Message]
|
||||
|
||||
|
||||
def normalize_messages(
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
) -> list[Message]:
|
||||
"""Normalize message inputs to a list of Message objects.
|
||||
|
||||
|
||||
@@ -22,6 +22,7 @@ from .._sessions import (
|
||||
from .._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
Content,
|
||||
Message,
|
||||
ResponseStream,
|
||||
@@ -145,7 +146,7 @@ class WorkflowAgent(BaseAgent):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -157,7 +158,7 @@ class WorkflowAgent(BaseAgent):
|
||||
@overload
|
||||
async def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -168,7 +169,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -214,7 +215,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
async def _run_impl(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: AgentRunInputs,
|
||||
response_id: str,
|
||||
session: AgentSession | None,
|
||||
checkpoint_id: str | None = None,
|
||||
@@ -270,7 +271,7 @@ class WorkflowAgent(BaseAgent):
|
||||
|
||||
async def _run_stream_impl(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: AgentRunInputs,
|
||||
response_id: str,
|
||||
session: AgentSession | None,
|
||||
checkpoint_id: str | None = None,
|
||||
|
||||
@@ -3,11 +3,10 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import base64
|
||||
import logging
|
||||
import pickle # nosec # noqa: S403
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import get_logger
|
||||
|
||||
"""Checkpoint encoding using JSON structure with pickle+base64 for arbitrary data.
|
||||
|
||||
This hybrid approach provides:
|
||||
@@ -20,7 +19,7 @@ from trusted sources. Loading a malicious checkpoint file can execute arbitrary
|
||||
"""
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
# Marker to identify pickled values in serialized JSON
|
||||
_PICKLE_MARKER = "__pickled__"
|
||||
|
||||
@@ -2,18 +2,17 @@
|
||||
|
||||
"""Shared helpers for normalizing workflow message inputs."""
|
||||
|
||||
from collections.abc import Sequence
|
||||
|
||||
from agent_framework import Message
|
||||
from agent_framework import Content, Message
|
||||
from agent_framework._types import AgentRunInputs
|
||||
|
||||
|
||||
def normalize_messages_input(
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
) -> list[Message]:
|
||||
"""Normalize heterogeneous message inputs to a list of Message objects.
|
||||
|
||||
Args:
|
||||
messages: String, Message, or sequence of either. None yields empty list.
|
||||
messages: String, Content, Message, or sequence of those values. None yields empty list.
|
||||
|
||||
Returns:
|
||||
List of Message instances suitable for workflow consumption.
|
||||
@@ -24,6 +23,9 @@ def normalize_messages_input(
|
||||
if isinstance(messages, str):
|
||||
return [Message(role="user", text=messages)]
|
||||
|
||||
if isinstance(messages, Content):
|
||||
return [Message(role="user", contents=[messages])]
|
||||
|
||||
if isinstance(messages, Message):
|
||||
return [messages]
|
||||
|
||||
@@ -31,13 +33,12 @@ def normalize_messages_input(
|
||||
for item in messages:
|
||||
if isinstance(item, str):
|
||||
normalized.append(Message(role="user", text=item))
|
||||
elif isinstance(item, Content):
|
||||
normalized.append(Message(role="user", contents=[item]))
|
||||
elif isinstance(item, Message):
|
||||
normalized.append(item)
|
||||
else:
|
||||
raise TypeError(
|
||||
f"Messages sequence must contain only str or Message instances; found {type(item).__name__}."
|
||||
f"Messages sequence must contain only str, Content, or Message instances; found {type(item).__name__}."
|
||||
)
|
||||
return normalized
|
||||
|
||||
|
||||
__all__ = ["normalize_messages_input"]
|
||||
|
||||
@@ -27,8 +27,6 @@ if sys.version_info >= (3, 11):
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = ["AzureOpenAIAssistantsClient"]
|
||||
|
||||
|
||||
# region Azure OpenAI Assistants Options TypedDict
|
||||
|
||||
|
||||
@@ -53,7 +53,6 @@ if TYPE_CHECKING:
|
||||
|
||||
logger: logging.Logger = logging.getLogger(__name__)
|
||||
|
||||
__all__ = ["AzureOpenAIChatClient", "AzureOpenAIChatOptions", "AzureUserSecurityContext"]
|
||||
|
||||
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
@@ -42,8 +42,6 @@ if TYPE_CHECKING:
|
||||
from .._middleware import MiddlewareTypes
|
||||
from ..openai._responses_client import OpenAIResponsesOptions
|
||||
|
||||
__all__ = ["AzureOpenAIResponsesClient"]
|
||||
|
||||
|
||||
AzureOpenAIResponsesOptionsT = TypeVar(
|
||||
"AzureOpenAIResponsesOptionsT",
|
||||
|
||||
@@ -20,7 +20,6 @@ from opentelemetry.semconv.attributes import service_attributes
|
||||
from opentelemetry.semconv_ai import Meters, SpanAttributes
|
||||
|
||||
from . import __version__ as version_info
|
||||
from ._logging import get_logger
|
||||
from ._settings import load_settings
|
||||
|
||||
if sys.version_info >= (3, 13):
|
||||
@@ -44,6 +43,7 @@ if TYPE_CHECKING: # pragma: no cover
|
||||
from ._types import (
|
||||
AgentResponse,
|
||||
AgentResponseUpdate,
|
||||
AgentRunInputs,
|
||||
ChatOptions,
|
||||
ChatResponse,
|
||||
ChatResponseUpdate,
|
||||
@@ -73,7 +73,7 @@ AgentT = TypeVar("AgentT", bound="SupportsAgentRun")
|
||||
ChatClientT = TypeVar("ChatClientT", bound="SupportsChatGetResponse[Any]")
|
||||
|
||||
|
||||
logger = get_logger()
|
||||
logger = logging.getLogger("agent_framework")
|
||||
|
||||
|
||||
OTEL_METRICS: Final[str] = "__otel_metrics__"
|
||||
@@ -747,7 +747,6 @@ class ObservabilitySettings:
|
||||
for log_exporter in log_exporters:
|
||||
logger_provider.add_log_record_processor(BatchLogRecordProcessor(log_exporter))
|
||||
# Attach a handler with the provider to the root logger
|
||||
logger = logging.getLogger()
|
||||
handler = LoggingHandler(logger_provider=logger_provider)
|
||||
logger.addHandler(handler)
|
||||
set_logger_provider(logger_provider)
|
||||
@@ -1084,7 +1083,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: ChatOptions[ResponseModelBoundT],
|
||||
@@ -1094,7 +1093,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
options: OptionsCoT | ChatOptions[None] | None = None,
|
||||
@@ -1104,7 +1103,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
@overload
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: Literal[True],
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1113,7 +1112,7 @@ class ChatTelemetryLayer(Generic[OptionsCoT]):
|
||||
|
||||
def get_response(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: Sequence[Message],
|
||||
*,
|
||||
stream: bool = False,
|
||||
options: OptionsCoT | ChatOptions[Any] | None = None,
|
||||
@@ -1277,7 +1276,7 @@ class AgentTelemetryLayer:
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = ...,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1287,7 +1286,7 @@ class AgentTelemetryLayer:
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -1296,7 +1295,7 @@ class AgentTelemetryLayer:
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -1614,15 +1613,15 @@ def capture_exception(span: trace.Span, exception: Exception, timestamp: int | N
|
||||
def _capture_messages(
|
||||
span: trace.Span,
|
||||
provider_name: str,
|
||||
messages: str | Message | Sequence[str | Message],
|
||||
messages: AgentRunInputs,
|
||||
system_instructions: str | list[str] | None = None,
|
||||
output: bool = False,
|
||||
finish_reason: FinishReason | None = None,
|
||||
) -> None:
|
||||
"""Log messages with extra information."""
|
||||
from ._types import prepare_messages
|
||||
from ._types import normalize_messages, prepend_instructions_to_messages
|
||||
|
||||
prepped = prepare_messages(messages, system_instructions=system_instructions)
|
||||
prepped = prepend_instructions_to_messages(normalize_messages(messages), system_instructions)
|
||||
otel_messages: list[dict[str, Any]] = []
|
||||
for index, message in enumerate(prepped):
|
||||
# Reuse the otel message representation for logging instead of calling to_dict()
|
||||
|
||||
@@ -33,7 +33,6 @@ if sys.version_info >= (3, 11):
|
||||
else:
|
||||
from typing_extensions import Self, TypedDict # type:ignore # pragma: no cover
|
||||
|
||||
__all__ = ["OpenAIAssistantProvider"]
|
||||
|
||||
# Type variable for options - allows typed OpenAIAssistantProvider[OptionsCoT] returns
|
||||
# Default matches OpenAIAssistantsClient's default options type
|
||||
|
||||
@@ -68,12 +68,6 @@ else:
|
||||
if TYPE_CHECKING:
|
||||
from .._middleware import MiddlewareTypes
|
||||
|
||||
__all__ = [
|
||||
"AssistantToolResources",
|
||||
"OpenAIAssistantsClient",
|
||||
"OpenAIAssistantsOptions",
|
||||
]
|
||||
|
||||
|
||||
# region OpenAI Assistants Options TypedDict
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import AsyncIterable, Awaitable, Callable, Mapping, MutableMapping, Sequence
|
||||
from datetime import datetime, timezone
|
||||
@@ -20,7 +21,6 @@ from openai.types.chat.completion_create_params import WebSearchOptions
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
|
||||
from .._settings import load_settings
|
||||
from .._tools import (
|
||||
@@ -60,9 +60,7 @@ if sys.version_info >= (3, 11):
|
||||
else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
__all__ = ["OpenAIChatClient", "OpenAIChatOptions"]
|
||||
|
||||
logger = get_logger("agent_framework.openai")
|
||||
logger = logging.getLogger("agent_framework.openai")
|
||||
|
||||
ResponseModelT = TypeVar("ResponseModelT", bound=BaseModel | None, default=None)
|
||||
|
||||
|
||||
@@ -10,8 +10,6 @@ from openai import BadRequestError
|
||||
|
||||
from ..exceptions import ServiceContentFilterException
|
||||
|
||||
__all__ = ["ContentFilterResultSeverity", "OpenAIContentFilterException"]
|
||||
|
||||
|
||||
class ContentFilterResultSeverity(Enum):
|
||||
"""The severity of the content filter result."""
|
||||
|
||||
@@ -2,6 +2,7 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
@@ -36,7 +37,6 @@ from openai.types.responses.web_search_tool_param import WebSearchToolParam
|
||||
from pydantic import BaseModel
|
||||
|
||||
from .._clients import BaseChatClient
|
||||
from .._logging import get_logger
|
||||
from .._middleware import ChatMiddlewareLayer
|
||||
from .._settings import load_settings
|
||||
from .._tools import (
|
||||
@@ -90,10 +90,7 @@ if TYPE_CHECKING:
|
||||
FunctionMiddlewareCallable,
|
||||
)
|
||||
|
||||
logger = get_logger("agent_framework.openai")
|
||||
|
||||
|
||||
__all__ = ["OpenAIContinuationToken", "OpenAIResponsesClient", "OpenAIResponsesOptions", "RawOpenAIResponsesClient"]
|
||||
logger = logging.getLogger("agent_framework.openai")
|
||||
|
||||
|
||||
class OpenAIContinuationToken(ContinuationToken):
|
||||
|
||||
@@ -22,14 +22,13 @@ from openai.types.responses.response import Response
|
||||
from openai.types.responses.response_stream_event import ResponseStreamEvent
|
||||
from packaging.version import parse
|
||||
|
||||
from .._logging import get_logger
|
||||
from .._serialization import SerializationMixin
|
||||
from .._settings import SecretString
|
||||
from .._telemetry import APP_INFO, USER_AGENT_KEY, prepend_agent_framework_to_user_agent
|
||||
from .._tools import FunctionTool
|
||||
from ..exceptions import ServiceInitializationError
|
||||
|
||||
logger: logging.Logger = get_logger("agent_framework.openai")
|
||||
logger: logging.Logger = logging.getLogger("agent_framework.openai")
|
||||
|
||||
|
||||
RESPONSE_TYPE = Union[
|
||||
@@ -53,9 +52,6 @@ else:
|
||||
from typing_extensions import TypedDict # type: ignore # pragma: no cover
|
||||
|
||||
|
||||
__all__ = ["OpenAISettings"]
|
||||
|
||||
|
||||
def _check_openai_version_for_callable_api_key() -> None:
|
||||
"""Check if OpenAI version supports callable API keys.
|
||||
|
||||
|
||||
@@ -334,8 +334,10 @@ async def test_integration_options(
|
||||
messages = [Message(role="user", text="What is the weather in Seattle?")]
|
||||
elif option_name == "response_format":
|
||||
# Use prompt that works well with structured output
|
||||
messages = [Message(role="user", text="The weather in Seattle is sunny")]
|
||||
messages.append(Message(role="user", text="What is the weather in Seattle?"))
|
||||
messages = [
|
||||
Message(role="user", text="The weather in Seattle is sunny"),
|
||||
Message(role="user", text="What is the weather in Seattle?"),
|
||||
]
|
||||
else:
|
||||
# Generic prompt for simple options
|
||||
messages = [Message(role="user", text="Say 'Hello World' briefly.")]
|
||||
@@ -396,7 +398,12 @@ async def test_integration_web_search() -> None:
|
||||
|
||||
for streaming in [False, True]:
|
||||
content = {
|
||||
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
"messages": [
|
||||
Message(
|
||||
role="user",
|
||||
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
)
|
||||
],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [AzureOpenAIResponsesClient.get_web_search_tool()],
|
||||
@@ -416,7 +423,7 @@ async def test_integration_web_search() -> None:
|
||||
|
||||
# Test that the client will use the web search tool with location
|
||||
content = {
|
||||
"messages": "What is the current weather? Do not ask for my current location.",
|
||||
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [
|
||||
@@ -498,7 +505,7 @@ async def test_integration_client_agent_hosted_mcp_tool() -> None:
|
||||
"""Integration test for MCP tool with Azure Response Agent using Microsoft Learn MCP."""
|
||||
client = AzureOpenAIResponsesClient(credential=AzureCliCredential())
|
||||
response = await client.get_response(
|
||||
"How to create an Azure storage account using az cli?",
|
||||
messages=[Message(role="user", text="How to create an Azure storage account using az cli?")],
|
||||
options={
|
||||
# this needs to be high enough to handle the full MCP tool response.
|
||||
"max_tokens": 5000,
|
||||
@@ -523,7 +530,7 @@ async def test_integration_client_agent_hosted_code_interpreter_tool():
|
||||
client = AzureOpenAIResponsesClient(credential=AzureCliCredential())
|
||||
|
||||
response = await client.get_response(
|
||||
"Calculate the sum of numbers from 1 to 10 using Python code.",
|
||||
messages=[Message(role="user", text="Calculate the sum of numbers from 1 to 10 using Python code.")],
|
||||
options={
|
||||
"tools": [AzureOpenAIResponsesClient.get_code_interpreter_tool()],
|
||||
},
|
||||
|
||||
@@ -43,6 +43,12 @@ async def test_agent_run(agent: SupportsAgentRun) -> None:
|
||||
assert response.messages[0].text == "Response"
|
||||
|
||||
|
||||
async def test_agent_run_with_content(agent: SupportsAgentRun) -> None:
|
||||
response = await agent.run(Content.from_text("test"))
|
||||
assert response.messages[0].role == "assistant"
|
||||
assert response.messages[0].text == "Response"
|
||||
|
||||
|
||||
async def test_agent_run_streaming(agent: SupportsAgentRun) -> None:
|
||||
async def collect_updates(updates: AsyncIterable[AgentResponseUpdate]) -> list[AgentResponseUpdate]:
|
||||
return [u async for u in updates]
|
||||
|
||||
@@ -21,13 +21,13 @@ def test_chat_client_type(client: SupportsChatGetResponse):
|
||||
|
||||
|
||||
async def test_chat_client_get_response(client: SupportsChatGetResponse):
|
||||
response = await client.get_response(Message(role="user", text="Hello"))
|
||||
response = await client.get_response([Message(role="user", text="Hello")])
|
||||
assert response.text == "test response"
|
||||
assert response.messages[0].role == "assistant"
|
||||
|
||||
|
||||
async def test_chat_client_get_response_streaming(client: SupportsChatGetResponse):
|
||||
async for update in client.get_response(Message(role="user", text="Hello"), stream=True):
|
||||
async for update in client.get_response([Message(role="user", text="Hello")], stream=True):
|
||||
assert update.text == "test streaming response " or update.text == "another update"
|
||||
assert update.role == "assistant"
|
||||
|
||||
@@ -38,13 +38,13 @@ def test_base_client(chat_client_base: SupportsChatGetResponse):
|
||||
|
||||
|
||||
async def test_base_client_get_response(chat_client_base: SupportsChatGetResponse):
|
||||
response = await chat_client_base.get_response(Message(role="user", text="Hello"))
|
||||
response = await chat_client_base.get_response([Message(role="user", text="Hello")])
|
||||
assert response.messages[0].role == "assistant"
|
||||
assert response.messages[0].text == "test response - Hello"
|
||||
|
||||
|
||||
async def test_base_client_get_response_streaming(chat_client_base: SupportsChatGetResponse):
|
||||
async for update in chat_client_base.get_response(Message(role="user", text="Hello"), stream=True):
|
||||
async for update in chat_client_base.get_response([Message(role="user", text="Hello")], stream=True):
|
||||
assert update.text == "update - Hello" or update.text == "another update"
|
||||
|
||||
|
||||
@@ -59,7 +59,9 @@ async def test_chat_client_instructions_handling(chat_client_base: SupportsChatG
|
||||
"_inner_get_response",
|
||||
side_effect=fake_inner_get_response,
|
||||
) as mock_inner_get_response:
|
||||
await chat_client_base.get_response("hello", options={"instructions": instructions})
|
||||
await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"instructions": instructions}
|
||||
)
|
||||
mock_inner_get_response.assert_called_once()
|
||||
_, kwargs = mock_inner_get_response.call_args
|
||||
messages = kwargs.get("messages", [])
|
||||
|
||||
@@ -38,7 +38,9 @@ async def test_base_client_with_function_calling(chat_client_base: SupportsChatG
|
||||
),
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
|
||||
)
|
||||
assert exec_counter == 1
|
||||
assert len(response.messages) == 3
|
||||
assert response.messages[0].role == "assistant"
|
||||
@@ -83,7 +85,9 @@ async def test_base_client_with_function_calling_resets(chat_client_base: Suppor
|
||||
),
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
|
||||
)
|
||||
assert exec_counter == 2
|
||||
assert len(response.messages) == 5
|
||||
assert response.messages[0].role == "assistant"
|
||||
@@ -388,11 +392,13 @@ async def test_function_invocation_scenarios(
|
||||
options["conversation_id"] = conversation_id
|
||||
|
||||
if not streaming:
|
||||
response = await chat_client_base.get_response("hello", options=options)
|
||||
response = await chat_client_base.get_response([Message(role="user", text="hello")], options=options)
|
||||
messages = response.messages
|
||||
else:
|
||||
updates = []
|
||||
async for update in chat_client_base.get_response("hello", options=options, stream=True):
|
||||
async for update in chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options=options, stream=True
|
||||
):
|
||||
updates.append(update)
|
||||
messages = updates
|
||||
|
||||
@@ -776,7 +782,9 @@ async def test_max_iterations_limit(chat_client_base: SupportsChatGetResponse):
|
||||
# Set max_iterations to 1 in additional_properties
|
||||
chat_client_base.function_invocation_configuration["max_iterations"] = 1
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
|
||||
)
|
||||
|
||||
# With max_iterations=1, we should:
|
||||
# 1. Execute first function call (exec_counter=1)
|
||||
@@ -803,7 +811,9 @@ async def test_function_invocation_config_enabled_false(chat_client_base: Suppor
|
||||
# Disable function invocation
|
||||
chat_client_base.function_invocation_configuration["enabled"] = False
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [ai_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [ai_func]}
|
||||
)
|
||||
|
||||
# Function should not be executed - when enabled=False, the loop doesn't run
|
||||
assert exec_counter == 0
|
||||
@@ -859,7 +869,9 @@ async def test_function_invocation_config_max_consecutive_errors(chat_client_bas
|
||||
# Set max_consecutive_errors to 2
|
||||
chat_client_base.function_invocation_configuration["max_consecutive_errors_per_request"] = 2
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
|
||||
)
|
||||
|
||||
# Should stop after 2 consecutive errors and force a non-tool response
|
||||
error_results = [
|
||||
@@ -904,7 +916,9 @@ async def test_function_invocation_config_terminate_on_unknown_calls_false(chat_
|
||||
# Set terminate_on_unknown_calls to False (default)
|
||||
chat_client_base.function_invocation_configuration["terminate_on_unknown_calls"] = False
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
|
||||
)
|
||||
|
||||
# Should have a result message indicating the tool wasn't found
|
||||
assert len(response.messages) == 3
|
||||
@@ -940,7 +954,9 @@ async def test_function_invocation_config_terminate_on_unknown_calls_true(chat_c
|
||||
|
||||
# Should raise an exception when encountering an unknown function
|
||||
with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'):
|
||||
await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]})
|
||||
await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
|
||||
)
|
||||
|
||||
assert exec_counter == 0
|
||||
|
||||
@@ -978,7 +994,9 @@ async def test_function_invocation_config_additional_tools(chat_client_base: Sup
|
||||
chat_client_base.function_invocation_configuration["additional_tools"] = [hidden_func]
|
||||
|
||||
# Only pass visible_func in the tools parameter
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [visible_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [visible_func]}
|
||||
)
|
||||
|
||||
# Additional tools are treated as declaration_only, so not executed
|
||||
# The function call should be in the messages but not executed
|
||||
@@ -1016,7 +1034,9 @@ async def test_function_invocation_config_include_detailed_errors_false(chat_cli
|
||||
# Set include_detailed_errors to False (default)
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
|
||||
)
|
||||
|
||||
# Should have a generic error message
|
||||
error_result = next(
|
||||
@@ -1050,7 +1070,9 @@ async def test_function_invocation_config_include_detailed_errors_true(chat_clie
|
||||
# Set include_detailed_errors to True
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
|
||||
)
|
||||
|
||||
# Should have detailed error message
|
||||
error_result = next(
|
||||
@@ -1120,7 +1142,9 @@ async def test_argument_validation_error_with_detailed_errors(chat_client_base:
|
||||
# Set include_detailed_errors to True
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
|
||||
)
|
||||
|
||||
# Should have detailed validation error
|
||||
error_result = next(
|
||||
@@ -1154,7 +1178,9 @@ async def test_argument_validation_error_without_detailed_errors(chat_client_bas
|
||||
# Set include_detailed_errors to False (default)
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
|
||||
)
|
||||
|
||||
# Should have generic validation error
|
||||
error_result = next(
|
||||
@@ -1219,7 +1245,9 @@ async def test_unapproved_tool_execution_raises_exception(chat_client_base: Supp
|
||||
]
|
||||
|
||||
# Get approval request
|
||||
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
|
||||
response1 = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
|
||||
)
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
|
||||
|
||||
@@ -1277,7 +1305,9 @@ async def test_approved_function_call_with_error_without_detailed_errors(chat_cl
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = False
|
||||
|
||||
# Get approval request
|
||||
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
|
||||
response1 = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
|
||||
)
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
|
||||
|
||||
@@ -1340,7 +1370,9 @@ async def test_approved_function_call_with_error_with_detailed_errors(chat_clien
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
|
||||
|
||||
# Get approval request
|
||||
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [error_func]})
|
||||
response1 = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [error_func]}
|
||||
)
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
|
||||
|
||||
@@ -1403,7 +1435,9 @@ async def test_approved_function_call_with_validation_error(chat_client_base: Su
|
||||
chat_client_base.function_invocation_configuration["include_detailed_errors"] = True
|
||||
|
||||
# Get approval request
|
||||
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [typed_func]})
|
||||
response1 = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [typed_func]}
|
||||
)
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
|
||||
|
||||
@@ -1459,7 +1493,9 @@ async def test_approved_function_call_successful_execution(chat_client_base: Sup
|
||||
]
|
||||
|
||||
# Get approval request
|
||||
response1 = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [success_func]})
|
||||
response1 = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [success_func]}
|
||||
)
|
||||
|
||||
approval_req = [c for c in response1.messages[0].contents if c.type == "function_approval_request"][0]
|
||||
|
||||
@@ -1575,7 +1611,9 @@ async def test_multiple_function_calls_parallel_execution(chat_client_base: Supp
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [func1, func2]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [func1, func2]}
|
||||
)
|
||||
|
||||
# Both functions should have been executed
|
||||
assert "func1_start" in exec_order
|
||||
@@ -1612,7 +1650,9 @@ async def test_callable_function_converted_to_tool(chat_client_base: SupportsCha
|
||||
]
|
||||
|
||||
# Pass plain function (will be auto-converted)
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [plain_function]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [plain_function]}
|
||||
)
|
||||
|
||||
# Function should be executed
|
||||
assert exec_counter == 1
|
||||
@@ -1644,7 +1684,9 @@ async def test_conversation_id_handling(chat_client_base: SupportsChatGetRespons
|
||||
),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
|
||||
)
|
||||
|
||||
# Should have executed the function
|
||||
results = [content for msg in response.messages for content in msg.contents if content.type == "function_result"]
|
||||
@@ -1671,7 +1713,9 @@ async def test_function_result_appended_to_existing_assistant_message(chat_clien
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [test_func]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [test_func]}
|
||||
)
|
||||
|
||||
# Should have messages with both function call and function result
|
||||
assert len(response.messages) >= 2
|
||||
@@ -1716,7 +1760,9 @@ async def test_error_recovery_resets_counter(chat_client_base: SupportsChatGetRe
|
||||
ChatResponse(messages=Message(role="assistant", text="done")),
|
||||
]
|
||||
|
||||
response = await chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [sometimes_fails]})
|
||||
response = await chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [sometimes_fails]}
|
||||
)
|
||||
|
||||
# Should have both an error and a success
|
||||
error_results = [
|
||||
@@ -1990,7 +2036,9 @@ async def test_streaming_function_invocation_config_terminate_on_unknown_calls_t
|
||||
|
||||
# Should raise an exception when encountering an unknown function
|
||||
with pytest.raises(KeyError, match='Error: Requested function "unknown_function" not found'):
|
||||
async for _ in chat_client_base.get_response("hello", options={"tool_choice": "auto", "tools": [known_func]}):
|
||||
async for _ in chat_client_base.get_response(
|
||||
[Message(role="user", text="hello")], options={"tool_choice": "auto", "tools": [known_func]}
|
||||
):
|
||||
pass
|
||||
|
||||
assert exec_counter == 0
|
||||
|
||||
@@ -1,39 +0,0 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
|
||||
import pytest
|
||||
|
||||
from agent_framework import get_logger
|
||||
from agent_framework.exceptions import AgentFrameworkException
|
||||
|
||||
|
||||
def test_get_logger():
|
||||
"""Test that the logger is created with the correct name."""
|
||||
logger = get_logger()
|
||||
assert logger.name == "agent_framework"
|
||||
|
||||
|
||||
def test_get_logger_custom_name():
|
||||
"""Test that the logger can be created with a custom name."""
|
||||
custom_name = "agent_framework.custom"
|
||||
logger = get_logger(custom_name)
|
||||
assert logger.name == custom_name
|
||||
|
||||
|
||||
def test_get_logger_invalid_name():
|
||||
"""Test that an exception is raised for an invalid logger name."""
|
||||
with pytest.raises(AgentFrameworkException):
|
||||
get_logger("invalid_name")
|
||||
|
||||
|
||||
def test_log(caplog):
|
||||
"""Test that the logger can log messages and adheres to the expected format."""
|
||||
logger = get_logger()
|
||||
with caplog.at_level("DEBUG"):
|
||||
logger.debug("This is a debug message")
|
||||
assert len(caplog.records) == 1
|
||||
record = caplog.records[0]
|
||||
assert record.levelname == "DEBUG"
|
||||
assert record.message == "This is a debug message"
|
||||
assert record.name == "agent_framework"
|
||||
assert record.pathname.endswith("test_logging.py")
|
||||
@@ -1083,7 +1083,12 @@ async def test_integration_web_search() -> None:
|
||||
# Use static method for web search tool
|
||||
web_search_tool = OpenAIChatClient.get_web_search_tool()
|
||||
content = {
|
||||
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
"messages": [
|
||||
Message(
|
||||
role="user",
|
||||
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
)
|
||||
],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [web_search_tool],
|
||||
@@ -1110,7 +1115,7 @@ async def test_integration_web_search() -> None:
|
||||
}
|
||||
)
|
||||
content = {
|
||||
"messages": "What is the current weather? Do not ask for my current location.",
|
||||
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [web_search_tool_with_location],
|
||||
|
||||
@@ -2416,7 +2416,12 @@ async def test_integration_web_search() -> None:
|
||||
# Use static method for web search tool
|
||||
web_search_tool = OpenAIResponsesClient.get_web_search_tool()
|
||||
content = {
|
||||
"messages": "Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
"messages": [
|
||||
Message(
|
||||
role="user",
|
||||
text="Who are the main characters of Kpop Demon Hunters? Do a web search to find the answer.",
|
||||
)
|
||||
],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [web_search_tool],
|
||||
@@ -2438,7 +2443,7 @@ async def test_integration_web_search() -> None:
|
||||
user_location={"country": "US", "city": "Seattle"},
|
||||
)
|
||||
content = {
|
||||
"messages": "What is the current weather? Do not ask for my current location.",
|
||||
"messages": [Message(role="user", text="What is the current weather? Do not ask for my current location.")],
|
||||
"options": {
|
||||
"tool_choice": "auto",
|
||||
"tools": [web_search_tool_with_location],
|
||||
|
||||
@@ -39,7 +39,7 @@ class _ToolCallingAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
|
||||
@@ -35,7 +35,7 @@ class _SimpleAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -105,7 +105,7 @@ class _CaptureAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
|
||||
@@ -835,7 +835,7 @@ class _StreamingTestAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
|
||||
@@ -52,7 +52,7 @@ class _KwargsCapturingAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -85,7 +85,7 @@ class _OptionsAwareAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
|
||||
@@ -1,12 +1,12 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import os
|
||||
from collections.abc import MutableMapping
|
||||
from contextvars import ContextVar
|
||||
from typing import Any, Literal, TypeVar, Union
|
||||
|
||||
from agent_framework import get_logger
|
||||
from agent_framework._serialization import SerializationMixin
|
||||
|
||||
try:
|
||||
@@ -20,7 +20,7 @@ except (ImportError, RuntimeError):
|
||||
|
||||
from typing import overload
|
||||
|
||||
logger = get_logger("agent_framework.declarative")
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
# Context variable for safe_mode setting.
|
||||
# When True (default), environment variables are NOT accessible in PowerFx expressions.
|
||||
|
||||
+2
-2
@@ -10,10 +10,10 @@ This module implements handlers for:
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import get_logger
|
||||
from agent_framework._types import AgentResponse, Message
|
||||
|
||||
from ._handlers import (
|
||||
@@ -25,7 +25,7 @@ from ._handlers import (
|
||||
)
|
||||
from ._human_input import ExternalLoopEvent, QuestionRequest
|
||||
|
||||
logger = get_logger("agent_framework.declarative.workflows.actions")
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
def _extract_json_from_response(text: str) -> Any:
|
||||
|
||||
+2
-3
@@ -18,11 +18,10 @@ actually yielding any events.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from agent_framework import get_logger
|
||||
|
||||
from ._handlers import (
|
||||
ActionContext,
|
||||
AttachmentOutputEvent,
|
||||
@@ -35,7 +34,7 @@ from ._handlers import (
|
||||
if TYPE_CHECKING:
|
||||
from ._state import WorkflowState
|
||||
|
||||
logger = get_logger("agent_framework.declarative.workflows.actions")
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
@action_handler("SetValue")
|
||||
|
||||
+2
-3
@@ -11,10 +11,9 @@ This module implements handlers for:
|
||||
- ContinueLoop: Skip to the next iteration
|
||||
"""
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
|
||||
from agent_framework import get_logger
|
||||
|
||||
from ._handlers import (
|
||||
ActionContext,
|
||||
LoopControlSignal,
|
||||
@@ -22,7 +21,7 @@ from ._handlers import (
|
||||
action_handler,
|
||||
)
|
||||
|
||||
logger = get_logger("agent_framework.declarative.workflows.actions")
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
@action_handler("Foreach")
|
||||
|
||||
+2
-3
@@ -9,18 +9,17 @@ This module implements handlers for:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
|
||||
from agent_framework import get_logger
|
||||
|
||||
from ._handlers import (
|
||||
ActionContext,
|
||||
WorkflowEvent,
|
||||
action_handler,
|
||||
)
|
||||
|
||||
logger = get_logger("agent_framework.declarative.workflows.actions")
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
class WorkflowActionError(Exception):
|
||||
|
||||
@@ -12,6 +12,7 @@ enabling checkpointing, visualization, and pause/resume capabilities.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from pathlib import Path
|
||||
from typing import Any, cast
|
||||
@@ -22,13 +23,12 @@ from agent_framework import (
|
||||
CheckpointStorage,
|
||||
SupportsAgentRun,
|
||||
Workflow,
|
||||
get_logger,
|
||||
)
|
||||
|
||||
from .._loader import AgentFactory
|
||||
from ._declarative_builder import DeclarativeWorkflowBuilder
|
||||
|
||||
logger = get_logger("agent_framework.declarative.workflows")
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
class DeclarativeWorkflowError(Exception):
|
||||
|
||||
@@ -9,16 +9,15 @@ has a corresponding handler registered via the @action_handler decorator.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator, Callable
|
||||
from dataclasses import dataclass, field
|
||||
from typing import TYPE_CHECKING, Any, Protocol, runtime_checkable
|
||||
|
||||
from agent_framework import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from ._state import WorkflowState
|
||||
|
||||
logger = get_logger("agent_framework.declarative.workflows")
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -10,12 +10,11 @@ This module implements handlers for human input patterns:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import AsyncGenerator
|
||||
from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Any, cast
|
||||
|
||||
from agent_framework import get_logger
|
||||
|
||||
from ._handlers import (
|
||||
ActionContext,
|
||||
WorkflowEvent,
|
||||
@@ -25,7 +24,7 @@ from ._handlers import (
|
||||
if TYPE_CHECKING:
|
||||
from ._state import WorkflowState
|
||||
|
||||
logger = get_logger("agent_framework.declarative.workflows.human_input")
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
@dataclass
|
||||
|
||||
@@ -11,11 +11,10 @@ This module provides state management for declarative workflows, handling:
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping
|
||||
from typing import Any, cast
|
||||
|
||||
from agent_framework import get_logger
|
||||
|
||||
try:
|
||||
from powerfx import Engine
|
||||
|
||||
@@ -25,7 +24,7 @@ except (ImportError, RuntimeError):
|
||||
# RuntimeError: .NET runtime not available or misconfigured
|
||||
_powerfx_engine = None
|
||||
|
||||
logger = get_logger("agent_framework.declarative.workflows")
|
||||
logger = logging.getLogger("agent_framework.declarative")
|
||||
|
||||
|
||||
class WorkflowState:
|
||||
|
||||
@@ -8,14 +8,16 @@ with durable agents via gRPC.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from agent_framework import AgentResponse, get_logger
|
||||
import logging
|
||||
|
||||
from agent_framework import AgentResponse
|
||||
from durabletask.client import TaskHubGrpcClient
|
||||
|
||||
from ._constants import DEFAULT_MAX_POLL_RETRIES, DEFAULT_POLL_INTERVAL_SECONDS
|
||||
from ._executors import ClientAgentExecutor
|
||||
from ._shim import DurableAgentProvider, DurableAIAgent
|
||||
|
||||
logger = get_logger("agent_framework.durabletask.client")
|
||||
logger = logging.getLogger("agent_framework.durabletask")
|
||||
|
||||
|
||||
class DurableAIAgentClient(DurableAgentProvider[AgentResponse]):
|
||||
|
||||
@@ -30,6 +30,7 @@ All classes support bidirectional conversion between:
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
from collections.abc import MutableMapping
|
||||
from datetime import datetime, timezone
|
||||
from enum import Enum
|
||||
@@ -40,14 +41,13 @@ from agent_framework import (
|
||||
Content,
|
||||
Message,
|
||||
UsageDetails,
|
||||
get_logger,
|
||||
)
|
||||
from dateutil import parser as date_parser
|
||||
|
||||
from ._constants import ContentTypes, DurableStateFields
|
||||
from ._models import RunRequest, serialize_response_format
|
||||
|
||||
logger = get_logger("agent_framework.durabletask.durable_agent_state")
|
||||
logger = logging.getLogger("agent_framework.durabletask")
|
||||
|
||||
|
||||
class DurableAgentStateEntryJsonType(str, Enum):
|
||||
|
||||
@@ -5,6 +5,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import inspect
|
||||
import logging
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, cast
|
||||
|
||||
@@ -15,7 +16,6 @@ from agent_framework import (
|
||||
Message,
|
||||
ResponseStream,
|
||||
SupportsAgentRun,
|
||||
get_logger,
|
||||
)
|
||||
from durabletask.entities import DurableEntity
|
||||
|
||||
@@ -28,7 +28,7 @@ from ._durable_agent_state import (
|
||||
)
|
||||
from ._models import RunRequest
|
||||
|
||||
logger = get_logger("agent_framework.durabletask.entities")
|
||||
logger = logging.getLogger("agent_framework.durabletask")
|
||||
|
||||
|
||||
class AgentEntityStateProviderMixin:
|
||||
|
||||
@@ -10,13 +10,14 @@ and are injected into the shim.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, Generic, TypeVar
|
||||
|
||||
from agent_framework import AgentResponse, AgentSession, Content, Message, get_logger
|
||||
from agent_framework import AgentResponse, AgentSession, Content, Message
|
||||
from durabletask.client import TaskHubGrpcClient
|
||||
from durabletask.entities import EntityInstanceId
|
||||
from durabletask.task import CompletableTask, CompositeTask, OrchestrationContext, Task
|
||||
@@ -27,7 +28,7 @@ from ._durable_agent_state import DurableAgentState
|
||||
from ._models import AgentSessionId, DurableAgentSession, RunRequest
|
||||
from ._response_utils import ensure_response_format, load_agent_response
|
||||
|
||||
logger = get_logger("agent_framework.durabletask.executors")
|
||||
logger = logging.getLogger("agent_framework.durabletask")
|
||||
|
||||
# TypeVar for the task type returned by executors
|
||||
TaskT = TypeVar("TaskT")
|
||||
|
||||
@@ -8,13 +8,14 @@ orchestration functions to interact with durable agents.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
from agent_framework import get_logger
|
||||
import logging
|
||||
|
||||
from durabletask.task import OrchestrationContext
|
||||
|
||||
from ._executors import DurableAgentTask, OrchestrationAgentExecutor
|
||||
from ._shim import DurableAgentProvider, DurableAIAgent
|
||||
|
||||
logger = get_logger("agent_framework.durabletask.orchestration_context")
|
||||
logger = logging.getLogger("agent_framework.durabletask")
|
||||
|
||||
|
||||
class DurableAIAgentOrchestrationContext(DurableAgentProvider[DurableAgentTask]):
|
||||
|
||||
@@ -2,12 +2,13 @@
|
||||
|
||||
"""Shared utilities for handling AgentResponse parsing and validation."""
|
||||
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import AgentResponse, get_logger
|
||||
from agent_framework import AgentResponse
|
||||
from pydantic import BaseModel
|
||||
|
||||
logger = get_logger("agent_framework.durabletask.response_utils")
|
||||
logger = logging.getLogger("agent_framework.durabletask")
|
||||
|
||||
|
||||
def load_agent_response(agent_response: AgentResponse | dict[str, Any] | None) -> AgentResponse:
|
||||
|
||||
@@ -12,7 +12,8 @@ from __future__ import annotations
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Generic, Literal, TypeVar
|
||||
|
||||
from agent_framework import AgentSession, Message, SupportsAgentRun
|
||||
from agent_framework import AgentSession, SupportsAgentRun, normalize_messages
|
||||
from agent_framework._types import AgentRunInputs
|
||||
|
||||
from ._executors import DurableAgentExecutor
|
||||
from ._models import DurableAgentSession
|
||||
@@ -86,7 +87,7 @@ class DurableAIAgent(SupportsAgentRun, Generic[TaskT]):
|
||||
|
||||
def run( # type: ignore[override]
|
||||
self,
|
||||
messages: str | Message | list[str] | list[Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -143,7 +144,7 @@ class DurableAIAgent(SupportsAgentRun, Generic[TaskT]):
|
||||
"""
|
||||
return self._executor.get_new_session(self.name, **kwargs)
|
||||
|
||||
def _normalize_messages(self, messages: str | Message | list[str] | list[Message] | None) -> str:
|
||||
def _normalize_messages(self, messages: AgentRunInputs | None) -> str:
|
||||
"""Convert supported message inputs to a single string.
|
||||
|
||||
Args:
|
||||
@@ -151,19 +152,18 @@ class DurableAIAgent(SupportsAgentRun, Generic[TaskT]):
|
||||
|
||||
Returns:
|
||||
A single string representation of the messages
|
||||
|
||||
Raises:
|
||||
ValueError: If normalized messages contain non-text content only.
|
||||
"""
|
||||
if messages is None:
|
||||
normalized_messages = normalize_messages(messages)
|
||||
if not normalized_messages:
|
||||
return ""
|
||||
if isinstance(messages, str):
|
||||
return messages
|
||||
if isinstance(messages, Message):
|
||||
return messages.text or ""
|
||||
if isinstance(messages, list):
|
||||
if not messages:
|
||||
return ""
|
||||
first_item = messages[0]
|
||||
if isinstance(first_item, str):
|
||||
return "\n".join(messages) # type: ignore[arg-type]
|
||||
# List of Message
|
||||
return "\n".join([msg.text or "" for msg in messages]) # type: ignore[union-attr]
|
||||
return ""
|
||||
|
||||
message_texts: list[str] = []
|
||||
for message in normalized_messages:
|
||||
if not message.text:
|
||||
raise ValueError("DurableAIAgent only supports text message inputs.")
|
||||
message_texts.append(message.text)
|
||||
|
||||
return "\n".join(message_texts)
|
||||
|
||||
@@ -9,15 +9,16 @@ and enables registration of agents as durable entities.
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import SupportsAgentRun, get_logger
|
||||
from agent_framework import SupportsAgentRun
|
||||
from durabletask.worker import TaskHubGrpcWorker
|
||||
|
||||
from ._callbacks import AgentResponseCallbackProtocol
|
||||
from ._entities import AgentEntity, DurableTaskEntityStateProvider
|
||||
|
||||
logger = get_logger("agent_framework.durabletask.worker")
|
||||
logger = logging.getLogger("agent_framework.durabletask")
|
||||
|
||||
|
||||
class DurableAIAgentWorker:
|
||||
|
||||
@@ -23,7 +23,7 @@ from agent_framework import (
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework._tools import FunctionTool
|
||||
from agent_framework._types import normalize_tools
|
||||
from agent_framework._types import AgentRunInputs, normalize_tools
|
||||
from agent_framework.exceptions import ServiceException
|
||||
from copilot import CopilotClient, CopilotSession
|
||||
from copilot.generated.session_events import SessionEvent, SessionEventType
|
||||
@@ -277,7 +277,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[False] = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -288,7 +288,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
@overload
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: Literal[True],
|
||||
session: AgentSession | None = None,
|
||||
@@ -298,7 +298,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -340,7 +340,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
|
||||
async def _run_impl(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | None = None,
|
||||
@@ -388,7 +388,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
|
||||
|
||||
async def _stream_updates(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: AgentRunInputs | None = None,
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
options: OptionsT | None = None,
|
||||
|
||||
@@ -3,6 +3,7 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from collections.abc import (
|
||||
AsyncIterable,
|
||||
@@ -28,7 +29,6 @@ from agent_framework import (
|
||||
Message,
|
||||
ResponseStream,
|
||||
UsageDetails,
|
||||
get_logger,
|
||||
)
|
||||
from agent_framework._settings import load_settings
|
||||
from agent_framework.exceptions import (
|
||||
@@ -281,7 +281,7 @@ class OllamaSettings(TypedDict, total=False):
|
||||
model_id: str | None
|
||||
|
||||
|
||||
logger = get_logger("agent_framework.ollama")
|
||||
logger = logging.getLogger("agent_framework.ollama")
|
||||
|
||||
|
||||
class OllamaChatClient(
|
||||
|
||||
@@ -38,7 +38,7 @@ class StubAgent(BaseAgent):
|
||||
|
||||
def run( # type: ignore[override]
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -76,7 +76,7 @@ class StubManagerAgent(Agent):
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -130,7 +130,7 @@ class ConcatenatedJsonManagerAgent(Agent):
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
@@ -896,7 +896,7 @@ async def test_group_chat_with_orchestrator_factory_returning_chat_agent():
|
||||
|
||||
async def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
session: AgentSession | None = None,
|
||||
**kwargs: Any,
|
||||
|
||||
@@ -150,7 +150,7 @@ class StubAgent(BaseAgent):
|
||||
|
||||
def run( # type: ignore[override]
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: AgentSession | None = None,
|
||||
@@ -411,7 +411,7 @@ class StubManagerAgent(BaseAgent):
|
||||
|
||||
def run(
|
||||
self,
|
||||
messages: str | Message | Sequence[str | Message] | None = None,
|
||||
messages: str | Content | Message | Sequence[str | Content | Message] | None = None,
|
||||
*,
|
||||
stream: bool = False,
|
||||
session: Any = None,
|
||||
|
||||
@@ -5,12 +5,12 @@ from __future__ import annotations
|
||||
import base64
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
from typing import Any, cast
|
||||
from uuid import uuid4
|
||||
|
||||
import httpx
|
||||
from agent_framework import AGENT_FRAMEWORK_USER_AGENT
|
||||
from agent_framework._logging import get_logger
|
||||
from agent_framework.observability import get_tracer
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
@@ -33,7 +33,7 @@ from ._models import (
|
||||
)
|
||||
from ._settings import PurviewSettings, get_purview_scopes
|
||||
|
||||
logger = get_logger("agent_framework.purview")
|
||||
logger = logging.getLogger("agent_framework.purview")
|
||||
|
||||
|
||||
class PurviewClient:
|
||||
|
||||
@@ -1,9 +1,9 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import logging
|
||||
from collections.abc import Awaitable, Callable
|
||||
|
||||
from agent_framework import AgentContext, AgentMiddleware, ChatContext, ChatMiddleware, MiddlewareTermination
|
||||
from agent_framework._logging import get_logger
|
||||
from azure.core.credentials import TokenCredential
|
||||
from azure.core.credentials_async import AsyncTokenCredential
|
||||
|
||||
@@ -14,7 +14,7 @@ from ._models import Activity
|
||||
from ._processor import ScopedContentProcessor
|
||||
from ._settings import PurviewSettings
|
||||
|
||||
logger = get_logger("agent_framework.purview")
|
||||
logger = logging.getLogger("agent_framework.purview")
|
||||
|
||||
|
||||
class PurviewPolicyMiddleware(AgentMiddleware):
|
||||
|
||||
@@ -2,16 +2,16 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import logging
|
||||
from collections.abc import Mapping, MutableMapping, Sequence
|
||||
from datetime import datetime
|
||||
from enum import Enum, Flag, auto
|
||||
from typing import Any, ClassVar, TypeVar, cast
|
||||
from uuid import uuid4
|
||||
|
||||
from agent_framework._logging import get_logger
|
||||
from agent_framework._serialization import SerializationMixin
|
||||
|
||||
logger = get_logger("agent_framework.purview")
|
||||
logger = logging.getLogger("agent_framework.purview")
|
||||
|
||||
# --------------------------------------------------------------------------------------
|
||||
# Enums & flag helpers
|
||||
|
||||
@@ -1,13 +1,13 @@
|
||||
# Copyright (c) Microsoft. All rights reserved.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import time
|
||||
import uuid
|
||||
from collections.abc import Iterable, MutableMapping
|
||||
from typing import Any
|
||||
|
||||
from agent_framework import Message
|
||||
from agent_framework._logging import get_logger
|
||||
|
||||
from ._cache import CacheProvider, InMemoryCacheProvider, create_protection_scopes_cache_key
|
||||
from ._client import PurviewClient
|
||||
@@ -37,7 +37,7 @@ from ._models import (
|
||||
)
|
||||
from ._settings import PurviewSettings
|
||||
|
||||
logger = get_logger("agent_framework.purview")
|
||||
logger = logging.getLogger("agent_framework.purview")
|
||||
|
||||
|
||||
def _is_valid_guid(value: str | None) -> bool:
|
||||
|
||||
@@ -178,12 +178,15 @@ The `configure_otel_providers()` function automatically reads **standard OpenTel
|
||||
> **Note**: These are standard OpenTelemetry environment variables. See the [OpenTelemetry spec](https://opentelemetry.io/docs/specs/otel/configuration/sdk-environment-variables/) for more details.
|
||||
|
||||
#### Logging
|
||||
Agent Framework has a built-in logging configuration that works well with telemetry. It sets the format to a standard format that includes timestamp, pathname, line number, and log level. You can use that by calling the `setup_logging()` function from the `agent_framework` module.
|
||||
Use standard Python logging configuration to align logs with telemetry output.
|
||||
|
||||
```python
|
||||
from agent_framework import setup_logging
|
||||
import logging
|
||||
|
||||
setup_logging()
|
||||
logging.basicConfig(
|
||||
format="[%(asctime)s - %(pathname)s:%(lineno)d - %(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
```
|
||||
You can control at what level logging happens and thus what logs get exported, you can do this, by adding this:
|
||||
|
||||
|
||||
@@ -2,11 +2,12 @@
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import logging
|
||||
from contextlib import suppress
|
||||
from random import randint
|
||||
from typing import TYPE_CHECKING, Annotated, Literal
|
||||
|
||||
from agent_framework import setup_logging, tool
|
||||
from agent_framework import tool
|
||||
from agent_framework.observability import configure_otel_providers, get_tracer
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
from opentelemetry import trace
|
||||
@@ -31,7 +32,9 @@ Use this approach when you need custom exporter configuration beyond what enviro
|
||||
SCENARIOS = ["client", "client_stream", "tool", "all"]
|
||||
|
||||
|
||||
# NOTE: approval_mode="never_require" is for sample brevity. Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
|
||||
# NOTE: approval_mode="never_require" is for sample brevity.
|
||||
# Use "always_require" in production; see samples/02-agents/tools/function_tool_with_approval.py
|
||||
# and samples/02-agents/tools/function_tool_with_approval_and_sessions.py.
|
||||
@tool(approval_mode="never_require")
|
||||
async def get_weather(
|
||||
location: Annotated[str, Field(description="The location to get the weather for.")],
|
||||
@@ -101,7 +104,10 @@ async def main(scenario: Literal["client", "client_stream", "tool", "all"] = "al
|
||||
"""Run the selected scenario(s)."""
|
||||
|
||||
# Setup the logging with the more complete format
|
||||
setup_logging()
|
||||
logging.basicConfig(
|
||||
format="[%(asctime)s - %(pathname)s:%(lineno)d - %(levelname)s] %(message)s",
|
||||
datefmt="%Y-%m-%d %H:%M:%S",
|
||||
)
|
||||
|
||||
# Create custom OTLP exporters with specific configuration
|
||||
# Note: You need to install opentelemetry-exporter-otlp-proto-grpc or -http separately
|
||||
|
||||
+9
-16
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
from agent_framework import Content, Message
|
||||
from agent_framework import Content
|
||||
from agent_framework.azure import AzureOpenAIResponsesClient
|
||||
from azure.identity import AzureCliCredential
|
||||
|
||||
@@ -20,24 +20,17 @@ async def main():
|
||||
# 1. Create an Azure Responses agent with vision capabilities
|
||||
agent = AzureOpenAIResponsesClient(credential=AzureCliCredential()).as_agent(
|
||||
name="VisionAgent",
|
||||
instructions="You are a helpful agent that can analyze images.",
|
||||
instructions="You are a image analysist, you get a image and need to respond with what you see in the picture.",
|
||||
)
|
||||
|
||||
# 2. Create a simple message with both text and image content
|
||||
user_message = Message(
|
||||
role="user",
|
||||
contents=[
|
||||
Content.from_text("What do you see in this image?"),
|
||||
Content.from_uri(
|
||||
uri="https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=800",
|
||||
media_type="image/jpeg",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# 3. Get the agent's response
|
||||
# 2. Get the agent's response
|
||||
print("User: What do you see in this image? [Image provided]")
|
||||
result = await agent.run(user_message)
|
||||
result = await agent.run(
|
||||
Content.from_uri(
|
||||
uri="https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=800",
|
||||
media_type="image/jpeg",
|
||||
)
|
||||
)
|
||||
print(f"Agent: {result.text}")
|
||||
print()
|
||||
|
||||
|
||||
+9
-16
@@ -2,7 +2,7 @@
|
||||
|
||||
import asyncio
|
||||
|
||||
from agent_framework import Content, Message
|
||||
from agent_framework import Content
|
||||
from agent_framework.openai import OpenAIResponsesClient
|
||||
|
||||
"""
|
||||
@@ -19,24 +19,17 @@ async def main():
|
||||
# 1. Create an OpenAI Responses agent with vision capabilities
|
||||
agent = OpenAIResponsesClient().as_agent(
|
||||
name="VisionAgent",
|
||||
instructions="You are a helpful agent that can analyze images.",
|
||||
instructions="You are a image analysist, you get a image and need to respond with what you see in the picture.",
|
||||
)
|
||||
|
||||
# 2. Create a simple message with both text and image content
|
||||
user_message = Message(
|
||||
role="user",
|
||||
contents=[
|
||||
Content.from_text(text="What do you see in this image?"),
|
||||
Content.from_uri(
|
||||
uri="https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=800",
|
||||
media_type="image/jpeg",
|
||||
),
|
||||
],
|
||||
)
|
||||
|
||||
# 3. Get the agent's response
|
||||
# 2. Get the agent's response
|
||||
print("User: What do you see in this image? [Image provided]")
|
||||
result = await agent.run(user_message)
|
||||
result = await agent.run(
|
||||
Content.from_uri(
|
||||
uri="https://images.unsplash.com/photo-1506905925346-21bda4d32df4?w=800",
|
||||
media_type="image/jpeg",
|
||||
)
|
||||
)
|
||||
print(f"Agent: {result.text}")
|
||||
print()
|
||||
|
||||
|
||||
@@ -38,6 +38,10 @@ export AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME="gpt-4o"
|
||||
|
||||
For Azure authentication, run `az login` before running samples.
|
||||
|
||||
## Note on XML tags
|
||||
|
||||
Some sample files include XML-style snippet tags (for example `<snippet_name>` and `</snippet_name>`). These are used by our documentation tooling and can be ignored or removed when you use the samples outside this repository.
|
||||
|
||||
## Additional Resources
|
||||
|
||||
- [Agent Framework Documentation](https://learn.microsoft.com/agent-framework/)
|
||||
|
||||
Reference in New Issue
Block a user