Python: Replace Pydantic Settings with TypedDict + load_settings() (#3843)

* Replace Pydantic Settings with TypedDict + load_settings()

- Remove pydantic-settings dependency, add python-dotenv
- Delete _pydantic.py (AFBaseSettings, HTTPsUrl)
- Add _settings.py with generic load_settings() function, SecretString,
  type coercion, and Required field validation (SettingNotFoundError)
- Convert all 13 settings classes from AFBaseSettings subclasses to
  TypedDict definitions with load_settings() calls
- Update all consumers from attribute access to dict access
- Add 20 unit tests for load_settings() covering basic loading, dotenv,
  SecretString, type coercion, and required field validation
- Update all existing tests for new settings patterns

* Fix mypy type errors from settings conversion

- Fix str | None attribute access in responses_client (walrus operator)
- Fix SecretString | None narrowing in bedrock (type: ignore after guard)
- Convert _context_provider.py attribute access to dict access (missed file)
- Fix endpoint type narrowing in search_provider and context_provider
- Fix purview: str | None .rstrip(), int | None defaults, urlparse bytes

* Address PR review: required_fields param, type validation, fixes

- Move required field validation from TypedDict annotations (Required)
  to a required_fields parameter on load_settings(), enabling runtime
  decisions about which fields are required
- Remove Required imports and restore from __future__ import annotations
  in ollama and foundry_local
- Add _check_override_type() for deterministic ServiceInitializationError
  on invalid override types (e.g. dict passed for str field)
- Fix all multi-exception test catches back to single exception type
- Fix Ollama host=None: use .get() so None is passed through to SDK default
- Fix Purview processor: use explicit is-None checks instead of or operator
- Remove unused BaseModel import from openai/_shared.py
- Add 4 new tests (24 total): required_fields param, type validation

* Fix type validation: allow int for float fields

_check_override_type now permits int values for float-typed fields,
matching Python's standard numeric promotion behavior.

* fix: wrap urlparse arg with str() to fix mypy bytes endswith error
This commit is contained in:
Eduard van Valkenburg
2026-02-12 09:51:20 +01:00
committed by GitHub
Unverified
parent b488158abe
commit 8457533c69
58 changed files with 1526 additions and 1113 deletions
@@ -27,7 +27,7 @@ from agent_framework import (
get_logger,
prepare_function_call_results,
)
from agent_framework._pydantic import AFBaseSettings
from agent_framework._settings import SecretString, load_settings
from agent_framework._types import _get_data_bytes_as_str # type: ignore
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.observability import ChatTelemetryLayer
@@ -47,7 +47,7 @@ from anthropic.types.beta.beta_bash_code_execution_tool_result_error import (
from anthropic.types.beta.beta_code_execution_tool_result_error import (
BetaCodeExecutionToolResultError,
)
from pydantic import BaseModel, SecretStr, ValidationError
from pydantic import BaseModel
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
@@ -192,40 +192,20 @@ FINISH_REASON_MAP: dict[str, FinishReasonLiteral] = {
}
class AnthropicSettings(AFBaseSettings):
class AnthropicSettings(TypedDict, total=False):
"""Anthropic Project settings.
The settings are first loaded from environment variables with the prefix 'ANTHROPIC_'.
If the environment variables are not found, the settings can be loaded from a .env file
with the encoding 'utf-8'. If the settings are not found in the .env file, the settings
are ignored; however, validation will fail alerting that the settings are missing.
with the encoding 'utf-8'.
Keyword Args:
Keys:
api_key: The Anthropic API key.
chat_model_id: The Anthropic chat model ID.
env_file_path: If provided, the .env settings are read from this file path location.
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
Examples:
.. code-block:: python
from agent_framework.anthropic import AnthropicSettings
# Using environment variables
# Set ANTHROPIC_API_KEY=your_anthropic_api_key
# ANTHROPIC_CHAT_MODEL_ID=claude-sonnet-4-5-20250929
# Or passing parameters directly
settings = AnthropicSettings(chat_model_id="claude-sonnet-4-5-20250929")
# Or loading from a .env file
settings = AnthropicSettings(env_file_path="path/to/.env")
"""
env_prefix: ClassVar[str] = "ANTHROPIC_"
api_key: SecretStr | None = None
chat_model_id: str | None = None
api_key: SecretString | None
chat_model_id: str | None
class AnthropicClient(
@@ -311,25 +291,24 @@ class AnthropicClient(
response = await client.get_response("Hello", options={"my_custom_option": "value"})
"""
try:
anthropic_settings = AnthropicSettings(
api_key=api_key, # type: ignore[arg-type]
chat_model_id=model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Anthropic settings.", ex) from ex
anthropic_settings = load_settings(
AnthropicSettings,
env_prefix="ANTHROPIC_",
api_key=api_key,
chat_model_id=model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
if anthropic_client is None:
if not anthropic_settings.api_key:
if not anthropic_settings["api_key"]:
raise ServiceInitializationError(
"Anthropic API key is required. Set via 'api_key' parameter "
"or 'ANTHROPIC_API_KEY' environment variable."
)
anthropic_client = AsyncAnthropic(
api_key=anthropic_settings.api_key.get_secret_value(),
api_key=anthropic_settings["api_key"].get_secret_value(),
default_headers={"User-Agent": AGENT_FRAMEWORK_USER_AGENT},
)
@@ -343,7 +322,7 @@ class AnthropicClient(
# Initialize instance variables
self.anthropic_client = anthropic_client
self.additional_beta_flags = additional_beta_flags or []
self.model_id = anthropic_settings.chat_model_id
self.model_id = anthropic_settings["chat_model_id"]
# streaming requires tracking the last function call ID and name
self._last_call_id_name: tuple[str, str] | None = None
@@ -13,6 +13,7 @@ from agent_framework import (
SupportsChatGetResponse,
tool,
)
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError
from anthropic.types.beta import (
BetaMessage,
@@ -20,7 +21,7 @@ from anthropic.types.beta import (
BetaToolUseBlock,
BetaUsage,
)
from pydantic import Field, ValidationError
from pydantic import Field
from agent_framework_anthropic import AnthropicClient
from agent_framework_anthropic._chat_client import AnthropicSettings
@@ -41,8 +42,12 @@ def create_test_anthropic_client(
) -> AnthropicClient:
"""Helper function to create AnthropicClient instances for testing, bypassing normal validation."""
if anthropic_settings is None:
anthropic_settings = AnthropicSettings(
api_key="test-api-key-12345", chat_model_id="claude-3-5-sonnet-20241022", env_file_path="test.env"
anthropic_settings = load_settings(
AnthropicSettings,
env_prefix="ANTHROPIC_",
api_key="test-api-key-12345",
chat_model_id="claude-3-5-sonnet-20241022",
env_file_path="test.env",
)
# Create client instance directly
@@ -50,7 +55,7 @@ def create_test_anthropic_client(
# Set attributes directly
client.anthropic_client = mock_anthropic_client
client.model_id = model_id or anthropic_settings.chat_model_id
client.model_id = model_id or anthropic_settings["chat_model_id"]
client._last_call_id_name = None
client.additional_properties = {}
client.middleware = None
@@ -64,30 +69,34 @@ def create_test_anthropic_client(
def test_anthropic_settings_init(anthropic_unit_test_env: dict[str, str]) -> None:
"""Test AnthropicSettings initialization."""
settings = AnthropicSettings(env_file_path="test.env")
settings = load_settings(AnthropicSettings, env_prefix="ANTHROPIC_", env_file_path="test.env")
assert settings.api_key is not None
assert settings.api_key.get_secret_value() == anthropic_unit_test_env["ANTHROPIC_API_KEY"]
assert settings.chat_model_id == anthropic_unit_test_env["ANTHROPIC_CHAT_MODEL_ID"]
assert settings["api_key"] is not None
assert settings["api_key"].get_secret_value() == anthropic_unit_test_env["ANTHROPIC_API_KEY"]
assert settings["chat_model_id"] == anthropic_unit_test_env["ANTHROPIC_CHAT_MODEL_ID"]
def test_anthropic_settings_init_with_explicit_values() -> None:
"""Test AnthropicSettings initialization with explicit values."""
settings = AnthropicSettings(
api_key="custom-api-key", chat_model_id="claude-3-opus-20240229", env_file_path="test.env"
settings = load_settings(
AnthropicSettings,
env_prefix="ANTHROPIC_",
api_key="custom-api-key",
chat_model_id="claude-3-opus-20240229",
env_file_path="test.env",
)
assert settings.api_key is not None
assert settings.api_key.get_secret_value() == "custom-api-key"
assert settings.chat_model_id == "claude-3-opus-20240229"
assert settings["api_key"] is not None
assert settings["api_key"].get_secret_value() == "custom-api-key"
assert settings["chat_model_id"] == "claude-3-opus-20240229"
@pytest.mark.parametrize("exclude_list", [["ANTHROPIC_API_KEY"]], indirect=True)
def test_anthropic_settings_missing_api_key(anthropic_unit_test_env: dict[str, str]) -> None:
"""Test AnthropicSettings when API key is missing."""
settings = AnthropicSettings(env_file_path="test.env")
assert settings.api_key is None
assert settings.chat_model_id == anthropic_unit_test_env["ANTHROPIC_CHAT_MODEL_ID"]
settings = load_settings(AnthropicSettings, env_prefix="ANTHROPIC_", env_file_path="test.env")
assert settings["api_key"] is None
assert settings["chat_model_id"] == anthropic_unit_test_env["ANTHROPIC_CHAT_MODEL_ID"]
# Client Initialization Tests
@@ -116,23 +125,13 @@ def test_anthropic_client_init_auto_create_client(anthropic_unit_test_env: dict[
def test_anthropic_client_init_missing_api_key() -> None:
"""Test AnthropicClient initialization when API key is missing."""
with patch("agent_framework_anthropic._chat_client.AnthropicSettings") as mock_settings:
mock_settings.return_value.api_key = None
mock_settings.return_value.chat_model_id = "claude-3-5-sonnet-20241022"
with patch("agent_framework_anthropic._chat_client.load_settings") as mock_load:
mock_load.return_value = {"api_key": None, "chat_model_id": "claude-3-5-sonnet-20241022"}
with pytest.raises(ServiceInitializationError, match="Anthropic API key is required"):
AnthropicClient()
def test_anthropic_client_init_validation_error() -> None:
"""Test that ValidationError in AnthropicSettings is properly handled."""
with patch("agent_framework_anthropic._chat_client.AnthropicSettings") as mock_settings:
mock_settings.side_effect = ValidationError.from_exception_data("test", [])
with pytest.raises(ServiceInitializationError, match="Failed to create Anthropic settings"):
AnthropicClient()
def test_anthropic_client_service_url(mock_anthropic_client: MagicMock) -> None:
"""Test service_url method."""
client = create_test_anthropic_client(mock_anthropic_client)
@@ -16,6 +16,7 @@ from typing import TYPE_CHECKING, Any, ClassVar, Literal
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Message
from agent_framework._logging import get_logger
from agent_framework._sessions import AgentSession, BaseContextProvider, SessionContext
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
@@ -41,7 +42,6 @@ from azure.search.documents.models import (
VectorizableTextQuery,
VectorizedQuery,
)
from pydantic import ValidationError
from ._search_provider import AzureAISearchSettings
@@ -180,40 +180,39 @@ class _AzureAISearchContextProvider(BaseContextProvider):
super().__init__(source_id)
# Load settings from environment/file
try:
settings = AzureAISearchSettings(
endpoint=endpoint,
index_name=index_name,
knowledge_base_name=knowledge_base_name,
api_key=api_key if isinstance(api_key, str) else None,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Azure AI Search settings.", ex) from ex
settings = load_settings(
AzureAISearchSettings,
env_prefix="AZURE_SEARCH_",
endpoint=endpoint,
index_name=index_name,
knowledge_base_name=knowledge_base_name,
api_key=api_key if isinstance(api_key, str) else None,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
if not settings.endpoint:
if not settings.get("endpoint"):
raise ServiceInitializationError(
"Azure AI Search endpoint is required. Set via 'endpoint' parameter "
"or 'AZURE_SEARCH_ENDPOINT' environment variable."
)
if mode == "semantic":
if not settings.index_name:
if not settings.get("index_name"):
raise ServiceInitializationError(
"Azure AI Search index name is required for semantic mode. "
"Set via 'index_name' parameter or 'AZURE_SEARCH_INDEX_NAME' environment variable."
)
elif mode == "agentic":
if settings.index_name and settings.knowledge_base_name:
if settings.get("index_name") and settings.get("knowledge_base_name"):
raise ServiceInitializationError(
"For agentic mode, provide either 'index_name' OR 'knowledge_base_name', not both."
)
if not settings.index_name and not settings.knowledge_base_name:
if not settings.get("index_name") and not settings.get("knowledge_base_name"):
raise ServiceInitializationError(
"For agentic mode, provide either 'index_name' or 'knowledge_base_name'."
)
if settings.index_name and not model_deployment_name:
if settings.get("index_name") and not model_deployment_name:
raise ServiceInitializationError(
"model_deployment_name is required for agentic mode when creating Knowledge Base from index."
)
@@ -223,16 +222,16 @@ class _AzureAISearchContextProvider(BaseContextProvider):
resolved_credential = credential
elif isinstance(api_key, AzureKeyCredential):
resolved_credential = api_key
elif settings.api_key:
resolved_credential = AzureKeyCredential(settings.api_key.get_secret_value())
elif settings.get("api_key"):
resolved_credential = AzureKeyCredential(settings["api_key"].get_secret_value()) # type: ignore[union-attr]
else:
raise ServiceInitializationError(
"Azure credential is required. Provide 'api_key' or 'credential' parameter "
"or set 'AZURE_SEARCH_API_KEY' environment variable."
)
self.endpoint = settings.endpoint
self.index_name = settings.index_name
self.endpoint: str = settings["endpoint"] # type: ignore[assignment] # validated above
self.index_name = settings.get("index_name")
self.credential = resolved_credential
self.mode = mode
self.top_k = top_k
@@ -244,7 +243,7 @@ class _AzureAISearchContextProvider(BaseContextProvider):
self.azure_openai_resource_url = azure_openai_resource_url
self.azure_openai_deployment_name = model_deployment_name
self.model_name = model_name or model_deployment_name
self.knowledge_base_name = settings.knowledge_base_name
self.knowledge_base_name = settings.get("knowledge_base_name")
self.retrieval_instructions = retrieval_instructions
self.azure_openai_api_key = azure_openai_api_key
self.knowledge_base_output_mode = knowledge_base_output_mode
@@ -253,10 +252,10 @@ class _AzureAISearchContextProvider(BaseContextProvider):
self._use_existing_knowledge_base = False
if mode == "agentic":
if settings.knowledge_base_name:
if settings.get("knowledge_base_name"):
self._use_existing_knowledge_base = True
else:
self.knowledge_base_name = f"{settings.index_name}-kb"
self.knowledge_base_name = f"{settings.get('index_name', '')}-kb"
self._auto_discovered_vector_field = False
self._use_vectorizable_query = False
@@ -5,11 +5,11 @@ from __future__ import annotations
import sys
from collections.abc import Awaitable, Callable, MutableSequence
from typing import TYPE_CHECKING, Any, ClassVar, Literal
from typing import TYPE_CHECKING, Any, Literal
from agent_framework import AGENT_FRAMEWORK_USER_AGENT, Context, ContextProvider, Message
from agent_framework._logging import get_logger
from agent_framework._pydantic import AFBaseSettings
from agent_framework._settings import SecretString, load_settings
from agent_framework.exceptions import ServiceInitializationError
from azure.core.credentials import AzureKeyCredential
from azure.core.credentials_async import AsyncTokenCredential
@@ -35,7 +35,6 @@ from azure.search.documents.models import (
VectorizableTextQuery,
VectorizedQuery,
)
from pydantic import SecretStr, ValidationError
# Type checking imports for optional agentic mode dependencies
if TYPE_CHECKING:
@@ -99,9 +98,9 @@ else:
from typing_extensions import override # type: ignore[import] # pragma: no cover
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
from typing import Self, TypedDict # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
from typing_extensions import Self, TypedDict # pragma: no cover
"""Azure AI Search Context Provider for Agent Framework.
@@ -120,7 +119,7 @@ logger = get_logger("agent_framework.azure")
_DEFAULT_AGENTIC_MESSAGE_HISTORY_COUNT = 10
class AzureAISearchSettings(AFBaseSettings):
class AzureAISearchSettings(TypedDict, total=False):
"""Settings for Azure AI Search Context Provider with auto-loading from environment.
The settings are first loaded from environment variables with the prefix 'AZURE_SEARCH_'.
@@ -158,12 +157,10 @@ class AzureAISearchSettings(AFBaseSettings):
settings = AzureAISearchSettings(env_file_path="path/to/.env")
"""
env_prefix: ClassVar[str] = "AZURE_SEARCH_"
endpoint: str | None = None
index_name: str | None = None
knowledge_base_name: str | None = None
api_key: SecretStr | None = None
endpoint: str | None
index_name: str | None
knowledge_base_name: str | None
api_key: SecretString | None
class AzureAISearchContextProvider(ContextProvider):
@@ -336,42 +333,41 @@ class AzureAISearchContextProvider(ContextProvider):
provider = AzureAISearchContextProvider(credential=credential, env_file_path="path/to/.env")
"""
# Load settings from environment/file
try:
settings = AzureAISearchSettings(
endpoint=endpoint,
index_name=index_name,
knowledge_base_name=knowledge_base_name,
api_key=api_key if isinstance(api_key, str) else None,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Azure AI Search settings.", ex) from ex
settings = load_settings(
AzureAISearchSettings,
env_prefix="AZURE_SEARCH_",
endpoint=endpoint,
index_name=index_name,
knowledge_base_name=knowledge_base_name,
api_key=api_key if isinstance(api_key, str) else None,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
# Validate required parameters
if not settings.endpoint:
if not settings.get("endpoint"):
raise ServiceInitializationError(
"Azure AI Search endpoint is required. Set via 'endpoint' parameter "
"or 'AZURE_SEARCH_ENDPOINT' environment variable."
)
# Validate index_name and knowledge_base_name based on mode
# Note: settings.* contains the resolved value (explicit param OR env var)
# Note: settings["field"] / settings.get("field") contains the resolved value (explicit param OR env var)
if mode == "semantic":
# Semantic mode: always requires index_name
if not settings.index_name:
if not settings.get("index_name"):
raise ServiceInitializationError(
"Azure AI Search index name is required for semantic mode. "
"Set via 'index_name' parameter or 'AZURE_SEARCH_INDEX_NAME' environment variable."
)
elif mode == "agentic":
# Agentic mode: requires exactly ONE of index_name or knowledge_base_name
if settings.index_name and settings.knowledge_base_name:
if settings.get("index_name") and settings.get("knowledge_base_name"):
raise ServiceInitializationError(
"For agentic mode, provide either 'index_name' OR 'knowledge_base_name', not both. "
"Use 'index_name' to auto-create a Knowledge Base, or 'knowledge_base_name' to use an existing one."
)
if not settings.index_name and not settings.knowledge_base_name:
if not settings.get("index_name") and not settings.get("knowledge_base_name"):
raise ServiceInitializationError(
"For agentic mode, provide either 'index_name' (to auto-create Knowledge Base) "
"or 'knowledge_base_name' (to use existing Knowledge Base). "
@@ -379,7 +375,7 @@ class AzureAISearchContextProvider(ContextProvider):
"AZURE_SEARCH_INDEX_NAME / AZURE_SEARCH_KNOWLEDGE_BASE_NAME."
)
# If using index_name to create KB, model config is required
if settings.index_name and not model_deployment_name:
if settings.get("index_name") and not model_deployment_name:
raise ServiceInitializationError(
"model_deployment_name is required for agentic mode when creating Knowledge Base from index. "
"This is the Azure OpenAI deployment used by the Knowledge Base for query planning."
@@ -392,16 +388,16 @@ class AzureAISearchContextProvider(ContextProvider):
resolved_credential = credential
elif isinstance(api_key, AzureKeyCredential):
resolved_credential = api_key
elif settings.api_key:
resolved_credential = AzureKeyCredential(settings.api_key.get_secret_value())
elif resolved_api_key := settings.get("api_key"):
resolved_credential = AzureKeyCredential(resolved_api_key.get_secret_value())
else:
raise ServiceInitializationError(
"Azure credential is required. Provide 'api_key' or 'credential' parameter "
"or set 'AZURE_SEARCH_API_KEY' environment variable."
)
self.endpoint = settings.endpoint
self.index_name = settings.index_name
self.endpoint: str = settings["endpoint"] # type: ignore[assignment] # validated above
self.index_name = settings.get("index_name")
self.credential = resolved_credential
self.mode = mode
self.top_k = top_k
@@ -416,7 +412,7 @@ class AzureAISearchContextProvider(ContextProvider):
# If model_name not provided, default to deployment name
self.model_name = model_name or model_deployment_name
# Use resolved KB name (from explicit param or env var)
self.knowledge_base_name = settings.knowledge_base_name
self.knowledge_base_name = settings.get("knowledge_base_name")
self.retrieval_instructions = retrieval_instructions
self.azure_openai_api_key = azure_openai_api_key
self.knowledge_base_output_mode = knowledge_base_output_mode
@@ -429,12 +425,12 @@ class AzureAISearchContextProvider(ContextProvider):
# - index_name provided: auto-create KB from index
self._use_existing_knowledge_base = False
if mode == "agentic":
if settings.knowledge_base_name:
if settings.get("knowledge_base_name"):
# Use existing KB directly (supports any source type: web, blob, index, etc.)
self._use_existing_knowledge_base = True
else:
# Auto-generate KB name from index name
self.knowledge_base_name = f"{settings.index_name}-kb"
self.knowledge_base_name = f"{settings.get('index_name', '')}-kb"
# Auto-discover vector field if not specified
self._auto_discovered_vector_field = False
@@ -6,6 +6,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import Context, Message
from agent_framework._settings import load_settings
from agent_framework.azure import AzureAISearchContextProvider, AzureAISearchSettings
from agent_framework.exceptions import ServiceInitializationError
from azure.core.credentials import AzureKeyCredential
@@ -48,25 +49,27 @@ class TestAzureAISearchSettings:
def test_settings_with_direct_values(self) -> None:
"""Test settings with direct values."""
settings = AzureAISearchSettings(
settings = load_settings(
AzureAISearchSettings,
env_prefix="AZURE_SEARCH_",
endpoint="https://test.search.windows.net",
index_name="test-index",
api_key="test-key",
)
assert settings.endpoint == "https://test.search.windows.net"
assert settings.index_name == "test-index"
# api_key is now SecretStr
assert settings.api_key.get_secret_value() == "test-key"
assert settings["endpoint"] == "https://test.search.windows.net"
assert settings["index_name"] == "test-index"
assert settings["api_key"] == "test-key"
def test_settings_with_env_file_path(self) -> None:
"""Test settings with env_file_path parameter."""
settings = AzureAISearchSettings(
settings = load_settings(
AzureAISearchSettings,
env_prefix="AZURE_SEARCH_",
endpoint="https://test.search.windows.net",
index_name="test-index",
env_file_path="test.env",
)
assert settings.endpoint == "https://test.search.windows.net"
assert settings.index_name == "test-index"
assert settings["endpoint"] == "https://test.search.windows.net"
assert settings["index_name"] == "test-index"
def test_provider_uses_settings_from_env(self) -> None:
"""Test that provider creates settings internally from env."""
@@ -15,12 +15,13 @@ from agent_framework import (
normalize_tools,
)
from agent_framework._mcp import MCPTool
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError
from azure.ai.agents.aio import AgentsClient
from azure.ai.agents.models import Agent as AzureAgent
from azure.ai.agents.models import ResponseFormatJsonSchema, ResponseFormatJsonSchemaType
from azure.core.credentials_async import AsyncTokenCredential
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from ._chat_client import AzureAIAgentClient, AzureAIAgentOptions
from ._shared import AzureAISettings, from_azure_ai_agent_tools, to_azure_ai_agent_tools
@@ -112,21 +113,21 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
Raises:
ServiceInitializationError: If required parameters are missing or invalid.
"""
try:
self._settings = AzureAISettings(
project_endpoint=project_endpoint,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Azure AI settings.", ex) from ex
self._settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
project_endpoint=project_endpoint,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
self._should_close_client = False
if agents_client is not None:
self._agents_client = agents_client
else:
if not self._settings.project_endpoint:
resolved_endpoint = self._settings.get("project_endpoint")
if not resolved_endpoint:
raise ServiceInitializationError(
"Azure AI project endpoint is required. Provide 'project_endpoint' parameter "
"or set 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
@@ -134,7 +135,7 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
if not credential:
raise ServiceInitializationError("Azure credential is required when agents_client is not provided.")
self._agents_client = AgentsClient(
endpoint=self._settings.project_endpoint,
endpoint=resolved_endpoint,
credential=credential,
user_agent=AGENT_FRAMEWORK_USER_AGENT,
)
@@ -211,7 +212,7 @@ class AzureAIAgentsProvider(Generic[OptionsCoT]):
tools=get_weather,
)
"""
resolved_model = model or self._settings.model_deployment_name
resolved_model = model or self._settings.get("model_deployment_name")
if not resolved_model:
raise ServiceInitializationError(
"Model deployment name is required. Provide 'model' parameter "
@@ -35,6 +35,7 @@ from agent_framework import (
get_logger,
prepare_function_call_results,
)
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError, ServiceResponseException
from agent_framework.observability import ChatTelemetryLayer
from azure.ai.agents.aio import AgentsClient
@@ -85,7 +86,7 @@ from azure.ai.agents.models import (
ToolOutput,
)
from azure.core.credentials_async import AsyncTokenCredential
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from ._shared import AzureAISettings, to_azure_ai_agent_tools
@@ -482,26 +483,26 @@ class AzureAIAgentClient(
client: AzureAIAgentClient[MyOptions] = AzureAIAgentClient(credential=credential)
response = await client.get_response("Hello", options={"my_custom_option": "value"})
"""
try:
azure_ai_settings = AzureAISettings(
project_endpoint=project_endpoint,
model_deployment_name=model_deployment_name,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Azure AI settings.", ex) from ex
azure_ai_settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
project_endpoint=project_endpoint,
model_deployment_name=model_deployment_name,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
# If no agents_client is provided, create one
should_close_client = False
if agents_client is None:
if not azure_ai_settings.project_endpoint:
resolved_endpoint = azure_ai_settings.get("project_endpoint")
if not resolved_endpoint:
raise ServiceInitializationError(
"Azure AI project endpoint is required. Set via 'project_endpoint' parameter "
"or 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
)
if agent_id is None and not azure_ai_settings.model_deployment_name:
if agent_id is None and not azure_ai_settings.get("model_deployment_name"):
raise ServiceInitializationError(
"Azure AI model deployment name is required. Set via 'model_deployment_name' parameter "
"or 'AZURE_AI_MODEL_DEPLOYMENT_NAME' environment variable."
@@ -511,7 +512,7 @@ class AzureAIAgentClient(
if not credential:
raise ServiceInitializationError("Azure credential is required when agents_client is not provided.")
agents_client = AgentsClient(
endpoint=azure_ai_settings.project_endpoint,
endpoint=resolved_endpoint,
credential=credential,
user_agent=AGENT_FRAMEWORK_USER_AGENT,
)
@@ -530,7 +531,7 @@ class AzureAIAgentClient(
self.agent_id = agent_id
self.agent_name = agent_name
self.agent_description = agent_description
self.model_id = azure_ai_settings.model_deployment_name
self.model_id = azure_ai_settings.get("model_deployment_name")
self.thread_id = thread_id
self.should_cleanup_agent = should_cleanup_agent # Track whether we should delete the agent
self._agent_created = False # Track whether agent was created inside this class
@@ -20,6 +20,7 @@ from agent_framework import (
MiddlewareTypes,
get_logger,
)
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.observability import ChatTelemetryLayer
from agent_framework.openai import OpenAIResponsesOptions
@@ -40,7 +41,6 @@ from azure.ai.projects.models import (
from azure.ai.projects.models import FileSearchTool as ProjectsFileSearchTool
from azure.core.credentials_async import AsyncTokenCredential
from azure.core.exceptions import ResourceNotFoundError
from pydantic import ValidationError
from ._shared import AzureAISettings, create_text_format_config
@@ -171,20 +171,20 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
client: AzureAIClient[MyOptions] = AzureAIClient(credential=credential)
response = await client.get_response("Hello", options={"my_custom_option": "value"})
"""
try:
azure_ai_settings = AzureAISettings(
project_endpoint=project_endpoint,
model_deployment_name=model_deployment_name,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Azure AI settings.", ex) from ex
azure_ai_settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
project_endpoint=project_endpoint,
model_deployment_name=model_deployment_name,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
# If no project_client is provided, create one
should_close_client = False
if project_client is None:
if not azure_ai_settings.project_endpoint:
resolved_endpoint = azure_ai_settings.get("project_endpoint")
if not resolved_endpoint:
raise ServiceInitializationError(
"Azure AI project endpoint is required. Set via 'project_endpoint' parameter "
"or 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
@@ -194,7 +194,7 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
if not credential:
raise ServiceInitializationError("Azure credential is required when project_client is not provided.")
project_client = AIProjectClient(
endpoint=azure_ai_settings.project_endpoint,
endpoint=resolved_endpoint,
credential=credential,
user_agent=AGENT_FRAMEWORK_USER_AGENT,
)
@@ -212,7 +212,7 @@ class RawAzureAIClient(RawOpenAIResponsesClient[AzureAIClientOptionsT], Generic[
self.use_latest_version = use_latest_version
self.project_client = project_client
self.credential = credential
self.model_id = azure_ai_settings.model_deployment_name
self.model_id = azure_ai_settings.get("model_deployment_name")
self.conversation_id = conversation_id
# Track whether the application endpoint is used
@@ -16,6 +16,7 @@ from agent_framework import (
normalize_tools,
)
from agent_framework._mcp import MCPTool
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError
from azure.ai.projects.aio import AIProjectClient
from azure.ai.projects.models import (
@@ -28,7 +29,6 @@ from azure.ai.projects.models import (
FunctionTool as AzureFunctionTool,
)
from azure.core.credentials_async import AsyncTokenCredential
from pydantic import ValidationError
from ._client import AzureAIClient, AzureAIProjectAgentOptions
from ._shared import AzureAISettings, create_text_format_config, from_azure_ai_tools, to_azure_ai_tools
@@ -123,21 +123,21 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
Raises:
ServiceInitializationError: If required parameters are missing or invalid.
"""
try:
self._settings = AzureAISettings(
project_endpoint=project_endpoint,
model_deployment_name=model,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Azure AI settings.", ex) from ex
self._settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
project_endpoint=project_endpoint,
model_deployment_name=model,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
# Track whether we should close client connection
self._should_close_client = False
if project_client is None:
if not self._settings.project_endpoint:
resolved_endpoint = self._settings.get("project_endpoint")
if not resolved_endpoint:
raise ServiceInitializationError(
"Azure AI project endpoint is required. Set via 'project_endpoint' parameter "
"or 'AZURE_AI_PROJECT_ENDPOINT' environment variable."
@@ -147,7 +147,7 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
raise ServiceInitializationError("Azure credential is required when project_client is not provided.")
project_client = AIProjectClient(
endpoint=self._settings.project_endpoint,
endpoint=resolved_endpoint,
credential=credential,
user_agent=AGENT_FRAMEWORK_USER_AGENT,
)
@@ -191,7 +191,7 @@ class AzureAIProjectAgentProvider(Generic[OptionsCoT]):
ServiceInitializationError: If required parameters are missing.
"""
# Resolve model from parameter or environment variable
resolved_model = model or self._settings.model_deployment_name
resolved_model = model or self._settings.get("model_deployment_name")
if not resolved_model:
raise ServiceInitializationError(
"Model deployment name is required. Provide 'model' parameter "
@@ -2,14 +2,14 @@
from __future__ import annotations
import sys
from collections.abc import Mapping, MutableMapping, Sequence
from typing import Any, ClassVar, cast
from typing import Any, cast
from agent_framework import (
FunctionTool,
get_logger,
)
from agent_framework._pydantic import AFBaseSettings
from agent_framework.exceptions import ServiceInvalidRequestError
from azure.ai.agents.models import (
CodeInterpreterToolDefinition,
@@ -32,10 +32,15 @@ from azure.ai.projects.models import (
)
from pydantic import BaseModel
if sys.version_info >= (3, 11):
from typing import TypedDict # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
logger = get_logger("agent_framework.azure")
class AzureAISettings(AFBaseSettings):
class AzureAISettings(TypedDict, total=False):
"""Azure AI Project settings.
The settings are first loaded from environment variables with the prefix 'AZURE_AI_'.
@@ -70,10 +75,8 @@ class AzureAISettings(AFBaseSettings):
settings = AzureAISettings(env_file_path="path/to/.env")
"""
env_prefix: ClassVar[str] = "AZURE_AI_"
project_endpoint: str | None = None
model_deployment_name: str | None = None
project_endpoint: str | None
model_deployment_name: str | None
def _extract_project_connection_id(additional_properties: dict[str, Any] | None) -> str | None:
@@ -86,12 +86,9 @@ def test_provider_init_missing_endpoint_raises(
mock_azure_credential: MagicMock,
) -> None:
"""Test AzureAIAgentsProvider raises error when endpoint is missing."""
# Mock AzureAISettings to return None for project_endpoint
with patch("agent_framework_azure_ai._agent_provider.AzureAISettings") as mock_settings_class:
mock_settings = MagicMock()
mock_settings.project_endpoint = None
mock_settings.model_deployment_name = "test-model"
mock_settings_class.return_value = mock_settings
# Mock load_settings to return a dict with None for project_endpoint
with patch("agent_framework_azure_ai._agent_provider.load_settings") as mock_load_settings:
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
with pytest.raises(ServiceInitializationError) as exc_info:
AzureAIAgentsProvider(credential=mock_azure_credential)
@@ -270,11 +267,8 @@ async def test_create_agent_missing_model_raises(
) -> None:
"""Test that create_agent raises error when model is not specified."""
# Create provider with mocked settings that has no model
with patch("agent_framework_azure_ai._agent_provider.AzureAISettings") as mock_settings_class:
mock_settings = MagicMock()
mock_settings.project_endpoint = "https://test.com"
mock_settings.model_deployment_name = None # No model configured
mock_settings_class.return_value = mock_settings
with patch("agent_framework_azure_ai._agent_provider.load_settings") as mock_load_settings:
mock_load_settings.return_value = {"project_endpoint": "https://test.com", "model_deployment_name": None}
provider = AzureAIAgentsProvider(agents_client=mock_agents_client)
@@ -21,6 +21,7 @@ from agent_framework import (
tool,
)
from agent_framework._serialization import SerializationMixin
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidRequestError
from azure.ai.agents.models import (
AgentsNamedToolChoice,
@@ -44,7 +45,7 @@ from azure.ai.agents.models import (
)
from azure.core.credentials_async import AsyncTokenCredential
from azure.identity.aio import AzureCliCredential
from pydantic import BaseModel, Field, ValidationError
from pydantic import BaseModel, Field
from agent_framework_azure_ai import AzureAIAgentClient, AzureAISettings
@@ -67,7 +68,7 @@ def create_test_azure_ai_chat_client(
) -> AzureAIAgentClient:
"""Helper function to create AzureAIAgentClient instances for testing, bypassing normal validation."""
if azure_ai_settings is None:
azure_ai_settings = AzureAISettings(env_file_path="test.env")
azure_ai_settings = load_settings(AzureAISettings, env_prefix="AZURE_AI_", env_file_path="test.env")
# Create client instance directly
client = object.__new__(AzureAIAgentClient)
@@ -78,7 +79,7 @@ def create_test_azure_ai_chat_client(
client.agent_id = agent_id
client.agent_name = agent_name
client.agent_description = None
client.model_id = azure_ai_settings.model_deployment_name
client.model_id = azure_ai_settings.get("model_deployment_name")
client.thread_id = thread_id
client.should_cleanup_agent = should_cleanup_agent
client._agent_created = False
@@ -104,21 +105,23 @@ def create_test_azure_ai_chat_client(
def test_azure_ai_settings_init(azure_ai_unit_test_env: dict[str, str]) -> None:
"""Test AzureAISettings initialization."""
settings = AzureAISettings()
settings = load_settings(AzureAISettings, env_prefix="AZURE_AI_")
assert settings.project_endpoint == azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
assert settings.model_deployment_name == azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
assert settings["project_endpoint"] == azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
assert settings["model_deployment_name"] == azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
def test_azure_ai_settings_init_with_explicit_values() -> None:
"""Test AzureAISettings initialization with explicit values."""
settings = AzureAISettings(
settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
project_endpoint="https://custom-endpoint.com/",
model_deployment_name="custom-model",
)
assert settings.project_endpoint == "https://custom-endpoint.com/"
assert settings.model_deployment_name == "custom-model"
assert settings["project_endpoint"] == "https://custom-endpoint.com/"
assert settings["model_deployment_name"] == "custom-model"
def test_azure_ai_chat_client_init_with_client(mock_agents_client: MagicMock) -> None:
@@ -138,33 +141,29 @@ def test_azure_ai_chat_client_init_auto_create_client(
mock_agents_client: MagicMock,
) -> None:
"""Test AzureAIAgentClient initialization with auto-created agents_client."""
azure_ai_settings = AzureAISettings(**azure_ai_unit_test_env) # type: ignore
azure_ai_settings = load_settings(AzureAISettings, env_prefix="AZURE_AI_", **azure_ai_unit_test_env) # type: ignore
# Create client instance directly
client = object.__new__(AzureAIAgentClient)
client.agents_client = mock_agents_client
client.agent_id = None
client.thread_id = None
client._should_close_client = False # type: ignore
client.credential = None
client.model_id = azure_ai_settings.model_deployment_name
client.agent_name = None
client.additional_properties = {}
client.middleware = None
chat_client = object.__new__(AzureAIAgentClient)
chat_client.agents_client = mock_agents_client
chat_client.agent_id = None
chat_client.thread_id = None
chat_client._should_close_client = False # type: ignore
chat_client.credential = None
chat_client.model_id = azure_ai_settings.get("model_deployment_name")
chat_client.agent_name = None
chat_client.additional_properties = {}
chat_client.middleware = None
assert client.agents_client is mock_agents_client
assert client.agent_id is None
assert chat_client.agents_client is mock_agents_client
assert chat_client.agent_id is None
def test_azure_ai_chat_client_init_missing_project_endpoint() -> None:
"""Test AzureAIAgentClient initialization when project_endpoint is missing and no agents_client provided."""
# Mock AzureAISettings to return settings with None project_endpoint
with patch("agent_framework_azure_ai._chat_client.AzureAISettings") as mock_settings:
mock_settings_instance = MagicMock()
mock_settings_instance.project_endpoint = None # This should trigger the error
mock_settings_instance.model_deployment_name = "test-model"
mock_settings_instance.agent_name = "test-agent"
mock_settings.return_value = mock_settings_instance
with patch("agent_framework_azure_ai._chat_client.load_settings") as mock_load_settings:
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
with pytest.raises(ServiceInitializationError, match="project endpoint is required"):
AzureAIAgentClient(
@@ -179,12 +178,8 @@ def test_azure_ai_chat_client_init_missing_project_endpoint() -> None:
def test_azure_ai_chat_client_init_missing_model_deployment_for_agent_creation() -> None:
"""Test AzureAIAgentClient initialization when model deployment is missing for agent creation."""
# Mock AzureAISettings to return settings with None model_deployment_name
with patch("agent_framework_azure_ai._chat_client.AzureAISettings") as mock_settings:
mock_settings_instance = MagicMock()
mock_settings_instance.project_endpoint = "https://test.com"
mock_settings_instance.model_deployment_name = None # This should trigger the error
mock_settings_instance.agent_name = "test-agent"
mock_settings.return_value = mock_settings_instance
with patch("agent_framework_azure_ai._chat_client.load_settings") as mock_load_settings:
mock_load_settings.return_value = {"project_endpoint": "https://test.com", "model_deployment_name": None}
with pytest.raises(ServiceInitializationError, match="model deployment name is required"):
AzureAIAgentClient(
@@ -210,20 +205,6 @@ def test_azure_ai_chat_client_init_missing_credential(azure_ai_unit_test_env: di
)
def test_azure_ai_chat_client_init_validation_error(mock_azure_credential: MagicMock) -> None:
"""Test that ValidationError in AzureAISettings is properly handled."""
with patch("agent_framework_azure_ai._chat_client.AzureAISettings") as mock_settings:
# Create a proper ValidationError with empty errors list and model dict
mock_settings.side_effect = ValidationError.from_exception_data("AzureAISettings", [])
with pytest.raises(ServiceInitializationError, match="Failed to create Azure AI settings."):
AzureAIAgentClient(
project_endpoint="https://test.com",
model_deployment_name="test-model",
credential=mock_azure_credential,
)
def test_azure_ai_chat_client_from_dict() -> None:
"""Test from_settings class method."""
mock_agents_client = MagicMock()
@@ -248,11 +229,15 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_with_temperature_and_
mock_agents_client: MagicMock, azure_ai_unit_test_env: dict[str, str]
) -> None:
"""Test _get_agent_id_or_create with temperature and top_p in run_options."""
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
azure_ai_settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
)
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
run_options = {
"model": azure_ai_settings.model_deployment_name,
"model": azure_ai_settings.get("model_deployment_name"),
"temperature": 0.7,
"top_p": 0.9,
}
@@ -284,13 +269,19 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_create_new(
azure_ai_unit_test_env: dict[str, str],
) -> None:
"""Test _get_agent_id_or_create when creating a new agent."""
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
azure_ai_settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
)
chat_client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
agent_id = await client._get_agent_id_or_create(run_options={"model": azure_ai_settings.model_deployment_name}) # type: ignore
agent_id = await chat_client._get_agent_id_or_create(
run_options={"model": azure_ai_settings.get("model_deployment_name")}
) # type: ignore
assert agent_id == "test-agent-id"
assert client._agent_created
assert chat_client._agent_created
async def test_azure_ai_chat_client_thread_management_through_public_api(mock_agents_client: MagicMock) -> None:
@@ -547,14 +538,18 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_with_run_options(
mock_agents_client: MagicMock, azure_ai_unit_test_env: dict[str, str]
) -> None:
"""Test _get_agent_id_or_create with run_options containing tools and instructions."""
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
azure_ai_settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
)
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
run_options = {
"tools": [{"type": "function", "function": {"name": "test_tool"}}],
"instructions": "Test instructions",
"response_format": {"type": "json_object"},
"model": azure_ai_settings.model_deployment_name,
"model": azure_ai_settings.get("model_deployment_name"),
}
agent_id = await client._get_agent_id_or_create(run_options) # type: ignore
@@ -1134,13 +1129,19 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_with_agent_name(
mock_agents_client: MagicMock, azure_ai_unit_test_env: dict[str, str]
) -> None:
"""Test _get_agent_id_or_create uses default name when no agent_name set."""
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
azure_ai_settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
)
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
# Ensure agent_name is None to test the default
client.agent_name = None # type: ignore
agent_id = await client._get_agent_id_or_create(run_options={"model": azure_ai_settings.model_deployment_name}) # type: ignore
agent_id = await client._get_agent_id_or_create(
run_options={"model": azure_ai_settings.get("model_deployment_name")}
) # type: ignore
assert agent_id == "test-agent-id"
# Verify create_agent was called with default "UnnamedAgent"
@@ -1153,11 +1154,15 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_with_response_format(
mock_agents_client: MagicMock, azure_ai_unit_test_env: dict[str, str]
) -> None:
"""Test _get_agent_id_or_create with response_format in run_options."""
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
azure_ai_settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
)
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
# Test with response_format in run_options
run_options = {"response_format": {"type": "json_object"}, "model": azure_ai_settings.model_deployment_name}
run_options = {"response_format": {"type": "json_object"}, "model": azure_ai_settings.get("model_deployment_name")}
agent_id = await client._get_agent_id_or_create(run_options) # type: ignore
@@ -1172,13 +1177,17 @@ async def test_azure_ai_chat_client_get_agent_id_or_create_with_tool_resources(
mock_agents_client: MagicMock, azure_ai_unit_test_env: dict[str, str]
) -> None:
"""Test _get_agent_id_or_create with tool_resources in run_options."""
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
azure_ai_settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
)
client = create_test_azure_ai_chat_client(mock_agents_client, azure_ai_settings=azure_ai_settings)
# Test with tool_resources in run_options
run_options = {
"tool_resources": {"vector_store_ids": ["vs-123"]},
"model": azure_ai_settings.model_deployment_name,
"model": azure_ai_settings.get("model_deployment_name"),
}
agent_id = await client._get_agent_id_or_create(run_options) # type: ignore
@@ -20,6 +20,7 @@ from agent_framework import (
SupportsChatGetResponse,
tool,
)
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError
from azure.ai.projects.aio import AIProjectClient
from azure.ai.projects.models import (
@@ -36,7 +37,7 @@ from azure.core.exceptions import ResourceNotFoundError
from azure.identity.aio import AzureCliCredential
from openai.types.responses.parsed_response import ParsedResponse
from openai.types.responses.response import Response as OpenAIResponse
from pydantic import BaseModel, ConfigDict, Field, ValidationError
from pydantic import BaseModel, ConfigDict, Field
from pytest import fixture, param
from agent_framework_azure_ai import AzureAIClient, AzureAISettings
@@ -113,7 +114,7 @@ def create_test_azure_ai_client(
) -> AzureAIClient:
"""Helper function to create AzureAIClient instances for testing, bypassing normal validation."""
if azure_ai_settings is None:
azure_ai_settings = AzureAISettings(env_file_path="test.env")
azure_ai_settings = load_settings(AzureAISettings, env_prefix="AZURE_AI_", env_file_path="test.env")
# Create client instance directly
client = object.__new__(AzureAIClient)
@@ -125,7 +126,7 @@ def create_test_azure_ai_client(
client.agent_version = agent_version
client.agent_description = None
client.use_latest_version = use_latest_version
client.model_id = azure_ai_settings.model_deployment_name
client.model_id = azure_ai_settings.get("model_deployment_name")
client.conversation_id = conversation_id
client._is_application_endpoint = False # type: ignore
client._should_close_client = should_close_client # type: ignore
@@ -143,28 +144,29 @@ def create_test_azure_ai_client(
def test_azure_ai_settings_init(azure_ai_unit_test_env: dict[str, str]) -> None:
"""Test AzureAISettings initialization."""
settings = AzureAISettings()
settings = load_settings(AzureAISettings, env_prefix="AZURE_AI_")
assert settings.project_endpoint == azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
assert settings.model_deployment_name == azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
assert settings["project_endpoint"] == azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
assert settings["model_deployment_name"] == azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
def test_azure_ai_settings_init_with_explicit_values() -> None:
"""Test AzureAISettings initialization with explicit values."""
settings = AzureAISettings(
settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
project_endpoint="https://custom-endpoint.com/",
model_deployment_name="custom-model",
)
assert settings.project_endpoint == "https://custom-endpoint.com/"
assert settings.model_deployment_name == "custom-model"
assert settings["project_endpoint"] == "https://custom-endpoint.com/"
assert settings["model_deployment_name"] == "custom-model"
def test_init_with_project_client(mock_project_client: MagicMock) -> None:
"""Test AzureAIClient initialization with existing project_client."""
with patch("agent_framework_azure_ai._client.AzureAISettings") as mock_settings:
mock_settings.return_value.project_endpoint = None
mock_settings.return_value.model_deployment_name = "test-model"
with patch("agent_framework_azure_ai._client.load_settings") as mock_load_settings:
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
client = AzureAIClient(
project_client=mock_project_client,
@@ -205,9 +207,8 @@ def test_init_auto_create_client(
def test_init_missing_project_endpoint() -> None:
"""Test AzureAIClient initialization when project_endpoint is missing and no project_client provided."""
with patch("agent_framework_azure_ai._client.AzureAISettings") as mock_settings:
mock_settings.return_value.project_endpoint = None
mock_settings.return_value.model_deployment_name = "test-model"
with patch("agent_framework_azure_ai._client.load_settings") as mock_load_settings:
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
with pytest.raises(ServiceInitializationError, match="Azure AI project endpoint is required"):
AzureAIClient(credential=MagicMock())
@@ -224,15 +225,6 @@ def test_init_missing_credential(azure_ai_unit_test_env: dict[str, str]) -> None
)
def test_init_validation_error(mock_azure_credential: MagicMock) -> None:
"""Test that ValidationError in AzureAISettings is properly handled."""
with patch("agent_framework_azure_ai._client.AzureAISettings") as mock_settings:
mock_settings.side_effect = ValidationError.from_exception_data("test", [])
with pytest.raises(ServiceInitializationError, match="Failed to create Azure AI settings"):
AzureAIClient(credential=mock_azure_credential)
async def test_get_agent_reference_or_create_existing_version(
mock_project_client: MagicMock,
) -> None:
@@ -259,7 +251,11 @@ async def test_get_agent_reference_or_create_new_agent(
azure_ai_unit_test_env: dict[str, str],
) -> None:
"""Test _get_agent_reference_or_create when creating a new agent."""
azure_ai_settings = AzureAISettings(model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"])
azure_ai_settings = load_settings(
AzureAISettings,
env_prefix="AZURE_AI_",
model_deployment_name=azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
)
client = create_test_azure_ai_client(
mock_project_client, agent_name="new-agent", azure_ai_settings=azure_ai_settings
)
@@ -270,7 +266,7 @@ async def test_get_agent_reference_or_create_new_agent(
mock_agent.version = "1.0"
mock_project_client.agents.create_version = AsyncMock(return_value=mock_agent)
run_options = {"model": azure_ai_settings.model_deployment_name}
run_options = {"model": azure_ai_settings.get("model_deployment_name")}
agent_ref = await client._get_agent_reference_or_create(run_options, None) # type: ignore
assert agent_ref == {"name": "new-agent", "version": "1.0", "type": "agent_reference"}
+44 -30
View File
@@ -107,9 +107,8 @@ def test_provider_init_with_credential_and_endpoint(
def test_provider_init_missing_endpoint() -> None:
"""Test AzureAIProjectAgentProvider initialization when endpoint is missing."""
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
mock_settings.return_value.project_endpoint = None
mock_settings.return_value.model_deployment_name = "test-model"
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
mock_load_settings.return_value = {"project_endpoint": None, "model_deployment_name": "test-model"}
with pytest.raises(ServiceInitializationError, match="Azure AI project endpoint is required"):
AzureAIProjectAgentProvider(credential=MagicMock())
@@ -130,9 +129,11 @@ async def test_provider_create_agent(
azure_ai_unit_test_env: dict[str, str],
) -> None:
"""Test AzureAIProjectAgentProvider.create_agent method."""
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
mock_load_settings.return_value = {
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
}
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
@@ -168,9 +169,11 @@ async def test_provider_create_agent_with_env_model(
azure_ai_unit_test_env: dict[str, str],
) -> None:
"""Test AzureAIProjectAgentProvider.create_agent uses model from env var."""
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
mock_load_settings.return_value = {
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
}
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
@@ -200,9 +203,8 @@ async def test_provider_create_agent_with_env_model(
async def test_provider_create_agent_missing_model(mock_project_client: MagicMock) -> None:
"""Test AzureAIProjectAgentProvider.create_agent raises when model is missing."""
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
mock_settings.return_value.project_endpoint = "https://test.com"
mock_settings.return_value.model_deployment_name = None
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
mock_load_settings.return_value = {"project_endpoint": "https://test.com", "model_deployment_name": None}
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
@@ -215,9 +217,11 @@ async def test_provider_create_agent_with_rai_config(
azure_ai_unit_test_env: dict[str, str],
) -> None:
"""Test AzureAIProjectAgentProvider.create_agent passes rai_config from default_options."""
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
mock_load_settings.return_value = {
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
}
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
@@ -258,9 +262,11 @@ async def test_provider_create_agent_with_reasoning(
azure_ai_unit_test_env: dict[str, str],
) -> None:
"""Test AzureAIProjectAgentProvider.create_agent passes reasoning from default_options."""
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
mock_load_settings.return_value = {
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
}
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
@@ -465,9 +471,11 @@ async def test_provider_context_manager(mock_project_client: MagicMock) -> None:
mock_client.close = AsyncMock()
mock_ai_project_client.return_value = mock_client
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
mock_settings.return_value.project_endpoint = "https://test.com"
mock_settings.return_value.model_deployment_name = "test-model"
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
mock_load_settings.return_value = {
"project_endpoint": "https://test.com",
"model_deployment_name": "test-model",
}
async with AzureAIProjectAgentProvider(credential=MagicMock()) as provider:
assert provider._project_client is mock_client # type: ignore
@@ -494,9 +502,11 @@ async def test_provider_close_method(mock_project_client: MagicMock) -> None:
mock_client.close = AsyncMock()
mock_ai_project_client.return_value = mock_client
with patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings:
mock_settings.return_value.project_endpoint = "https://test.com"
mock_settings.return_value.model_deployment_name = "test-model"
with patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings:
mock_load_settings.return_value = {
"project_endpoint": "https://test.com",
"model_deployment_name": "test-model",
}
provider = AzureAIProjectAgentProvider(credential=MagicMock())
await provider.close()
@@ -581,12 +591,14 @@ async def test_provider_create_agent_with_mcp_tool(
return [tools]
with (
patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings,
patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings,
patch("agent_framework_azure_ai._project_provider.to_azure_ai_tools") as mock_to_azure_tools,
patch("agent_framework_azure_ai._project_provider.normalize_tools", side_effect=mock_normalize_tools),
):
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
mock_load_settings.return_value = {
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
}
mock_to_azure_tools.return_value = [{"type": "function", "name": "mcp_function_1"}]
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
@@ -642,12 +654,14 @@ async def test_provider_create_agent_with_mcp_and_regular_tools(
return [tools]
with (
patch("agent_framework_azure_ai._project_provider.AzureAISettings") as mock_settings,
patch("agent_framework_azure_ai._project_provider.load_settings") as mock_load_settings,
patch("agent_framework_azure_ai._project_provider.to_azure_ai_tools") as mock_to_azure_tools,
patch("agent_framework_azure_ai._project_provider.normalize_tools", side_effect=mock_normalize_tools),
):
mock_settings.return_value.project_endpoint = azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"]
mock_settings.return_value.model_deployment_name = azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"]
mock_load_settings.return_value = {
"project_endpoint": azure_ai_unit_test_env["AZURE_AI_PROJECT_ENDPOINT"],
"model_deployment_name": azure_ai_unit_test_env["AZURE_AI_MODEL_DEPLOYMENT_NAME"],
}
mock_to_azure_tools.return_value = []
provider = AzureAIProjectAgentProvider(project_client=mock_project_client)
@@ -30,13 +30,13 @@ from agent_framework import (
prepare_function_call_results,
validate_tool_mode,
)
from agent_framework._pydantic import AFBaseSettings
from agent_framework._settings import SecretString, load_settings
from agent_framework.exceptions import ServiceInitializationError, ServiceInvalidResponseError
from agent_framework.observability import ChatTelemetryLayer
from boto3.session import Session as Boto3Session
from botocore.client import BaseClient
from botocore.config import Config as BotoConfig
from pydantic import BaseModel, SecretStr, ValidationError
from pydantic import BaseModel
if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
@@ -205,16 +205,14 @@ FINISH_REASON_MAP: dict[str, FinishReasonLiteral] = {
}
class BedrockSettings(AFBaseSettings):
class BedrockSettings(TypedDict, total=False):
"""Bedrock configuration settings pulled from environment variables or .env files."""
env_prefix: ClassVar[str] = "BEDROCK_"
region: str = DEFAULT_REGION
chat_model_id: str | None = None
access_key: SecretStr | None = None
secret_key: SecretStr | None = None
session_token: SecretStr | None = None
region: str | None
chat_model_id: str | None
access_key: SecretString | None
secret_key: SecretString | None
session_token: SecretString | None
class BedrockChatClient(
@@ -280,24 +278,25 @@ class BedrockChatClient(
client = BedrockChatClient[MyOptions](model_id="<model name>")
response = await client.get_response("Hello", options={"my_custom_option": "value"})
"""
try:
settings = BedrockSettings(
region=region,
chat_model_id=model_id,
access_key=access_key, # type: ignore[arg-type]
secret_key=secret_key, # type: ignore[arg-type]
session_token=session_token, # type: ignore[arg-type]
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to initialize Bedrock settings.", ex) from ex
settings = load_settings(
BedrockSettings,
env_prefix="BEDROCK_",
region=region,
chat_model_id=model_id,
access_key=access_key,
secret_key=secret_key,
session_token=session_token,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
if not settings.get("region"):
settings["region"] = DEFAULT_REGION
if client is None:
session = boto3_session or self._create_session(settings)
client = session.client(
"bedrock-runtime",
region_name=settings.region,
region_name=settings["region"],
config=BotoConfig(user_agent_extra=AGENT_FRAMEWORK_USER_AGENT),
)
@@ -307,17 +306,17 @@ class BedrockChatClient(
**kwargs,
)
self._bedrock_client = client
self.model_id = settings.chat_model_id
self.region = settings.region
self.model_id = settings["chat_model_id"]
self.region = settings["region"]
@staticmethod
def _create_session(settings: BedrockSettings) -> Boto3Session:
session_kwargs: dict[str, Any] = {"region_name": settings.region or DEFAULT_REGION}
if settings.access_key and settings.secret_key:
session_kwargs["aws_access_key_id"] = settings.access_key.get_secret_value()
session_kwargs["aws_secret_access_key"] = settings.secret_key.get_secret_value()
if settings.session_token:
session_kwargs["aws_session_token"] = settings.session_token.get_secret_value()
session_kwargs: dict[str, Any] = {"region_name": settings.get("region") or DEFAULT_REGION}
if settings.get("access_key") and settings.get("secret_key"):
session_kwargs["aws_access_key_id"] = settings["access_key"].get_secret_value() # type: ignore[union-attr]
session_kwargs["aws_secret_access_key"] = settings["secret_key"].get_secret_value() # type: ignore[union-attr]
if settings.get("session_token"):
session_kwargs["aws_session_token"] = settings["session_token"].get_secret_value() # type: ignore[union-attr]
return Boto3Session(**session_kwargs)
@override
@@ -11,6 +11,7 @@ from agent_framework import (
FunctionTool,
Message,
)
from agent_framework._settings import load_settings
from pydantic import BaseModel
from agent_framework_bedrock._chat_client import BedrockChatClient, BedrockSettings
@@ -33,9 +34,9 @@ def _dummy_weather(location: str) -> str: # pragma: no cover - helper
def test_settings_load_from_environment(monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("BEDROCK_REGION", "us-west-2")
monkeypatch.setenv("BEDROCK_CHAT_MODEL_ID", "anthropic.claude-v2")
settings = BedrockSettings()
assert settings.region == "us-west-2"
assert settings.chat_model_id == "anthropic.claude-v2"
settings = load_settings(BedrockSettings, env_prefix="BEDROCK_")
assert settings["region"] == "us-west-2"
assert settings["chat_model_id"] == "anthropic.claude-v2"
def test_build_request_includes_tool_config() -> None:
@@ -21,8 +21,9 @@ from agent_framework import (
get_logger,
normalize_messages,
)
from agent_framework._settings import load_settings
from agent_framework._types import normalize_tools
from agent_framework.exceptions import ServiceException, ServiceInitializationError
from agent_framework.exceptions import ServiceException
from claude_agent_sdk import (
AssistantMessage,
ClaudeSDKClient,
@@ -34,7 +35,6 @@ from claude_agent_sdk import (
ClaudeAgentOptions as SDKOptions,
)
from claude_agent_sdk.types import StreamEvent, TextBlock
from pydantic import ValidationError
from ._settings import ClaudeAgentSettings
@@ -273,19 +273,18 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
self._mcp_servers: dict[str, Any] = opts.pop("mcp_servers", None) or {}
# Load settings from environment and options
try:
self._settings = ClaudeAgentSettings(
cli_path=cli_path,
model=model,
cwd=cwd,
permission_mode=permission_mode,
max_turns=max_turns,
max_budget_usd=max_budget_usd,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Claude Agent settings.", ex) from ex
self._settings = load_settings(
ClaudeAgentSettings,
env_prefix="CLAUDE_AGENT_",
cli_path=cli_path,
model=model,
cwd=cwd,
permission_mode=permission_mode,
max_turns=max_turns,
max_budget_usd=max_budget_usd,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
# Separate built-in tools (strings) from custom tools (callables/FunctionTool)
self._builtin_tools: list[str] = []
@@ -411,18 +410,18 @@ class ClaudeAgent(BaseAgent, Generic[OptionsT]):
opts["resume"] = resume_session_id
# Apply settings from environment
if self._settings.cli_path:
opts["cli_path"] = self._settings.cli_path
if self._settings.model:
opts["model"] = self._settings.model
if self._settings.cwd:
opts["cwd"] = self._settings.cwd
if self._settings.permission_mode:
opts["permission_mode"] = self._settings.permission_mode
if self._settings.max_turns:
opts["max_turns"] = self._settings.max_turns
if self._settings.max_budget_usd:
opts["max_budget_usd"] = self._settings.max_budget_usd
if self._settings["cli_path"]:
opts["cli_path"] = self._settings["cli_path"]
if self._settings["model"]:
opts["model"] = self._settings["model"]
if self._settings["cwd"]:
opts["cwd"] = self._settings["cwd"]
if self._settings["permission_mode"]:
opts["permission_mode"] = self._settings["permission_mode"]
if self._settings["max_turns"]:
opts["max_turns"] = self._settings["max_turns"]
if self._settings["max_budget_usd"]:
opts["max_budget_usd"] = self._settings["max_budget_usd"]
# Apply default options
for key, value in self._default_options.items():
@@ -1,51 +1,29 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import ClassVar
from agent_framework._pydantic import AFBaseSettings
from typing import TypedDict
__all__ = ["ClaudeAgentSettings"]
class ClaudeAgentSettings(AFBaseSettings):
class ClaudeAgentSettings(TypedDict, total=False):
"""Claude Agent settings.
The settings are first loaded from environment variables with the prefix 'CLAUDE_AGENT_'.
If the environment variables are not found, the settings can be loaded from a .env file
with the encoding 'utf-8'. If the settings are not found in the .env file, the settings
are ignored; however, validation will fail alerting that the settings are missing.
with the encoding 'utf-8'.
Keyword Args:
Keys:
cli_path: The path to Claude CLI executable.
model: The model to use (sonnet, opus, haiku).
cwd: The working directory for Claude CLI.
permission_mode: Permission mode (default, acceptEdits, plan, bypassPermissions).
max_turns: Maximum number of conversation turns.
max_budget_usd: Maximum budget in USD.
env_file_path: If provided, the .env settings are read from this file path location.
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
Examples:
.. code-block:: python
from agent_framework.anthropic import ClaudeAgentSettings
# Using environment variables
# Set CLAUDE_AGENT_MODEL=sonnet
# CLAUDE_AGENT_PERMISSION_MODE=default
# Or passing parameters directly
settings = ClaudeAgentSettings(model="sonnet")
# Or loading from a .env file
settings = ClaudeAgentSettings(env_file_path="path/to/.env")
"""
env_prefix: ClassVar[str] = "CLAUDE_AGENT_"
cli_path: str | None = None
model: str | None = None
cwd: str | None = None
permission_mode: str | None = None
max_turns: int | None = None
max_budget_usd: float | None = None
cli_path: str | None
model: str | None
cwd: str | None
permission_mode: str | None
max_turns: int | None
max_budget_usd: float | None
@@ -5,6 +5,7 @@ from unittest.mock import AsyncMock, MagicMock, patch
import pytest
from agent_framework import AgentResponseUpdate, AgentThread, Content, Message, tool
from agent_framework._settings import load_settings
from agent_framework_claude import ClaudeAgent, ClaudeAgentOptions, ClaudeAgentSettings
from agent_framework_claude._agent import TOOLS_MCP_SERVER_NAME
@@ -15,23 +16,21 @@ from agent_framework_claude._agent import TOOLS_MCP_SERVER_NAME
class TestClaudeAgentSettings:
"""Tests for ClaudeAgentSettings."""
def test_env_prefix(self) -> None:
"""Test that env_prefix is correctly set."""
assert ClaudeAgentSettings.env_prefix == "CLAUDE_AGENT_"
def test_default_values(self) -> None:
"""Test default values are None."""
settings = ClaudeAgentSettings()
assert settings.cli_path is None
assert settings.model is None
assert settings.cwd is None
assert settings.permission_mode is None
assert settings.max_turns is None
assert settings.max_budget_usd is None
settings = load_settings(ClaudeAgentSettings, env_prefix="CLAUDE_AGENT_")
assert settings["cli_path"] is None
assert settings["model"] is None
assert settings["cwd"] is None
assert settings["permission_mode"] is None
assert settings["max_turns"] is None
assert settings["max_budget_usd"] is None
def test_explicit_values(self) -> None:
"""Test explicit values override defaults."""
settings = ClaudeAgentSettings(
settings = load_settings(
ClaudeAgentSettings,
env_prefix="CLAUDE_AGENT_",
cli_path="/usr/local/bin/claude",
model="sonnet",
cwd="/home/user/project",
@@ -39,20 +38,20 @@ class TestClaudeAgentSettings:
max_turns=10,
max_budget_usd=5.0,
)
assert settings.cli_path == "/usr/local/bin/claude"
assert settings.model == "sonnet"
assert settings.cwd == "/home/user/project"
assert settings.permission_mode == "default"
assert settings.max_turns == 10
assert settings.max_budget_usd == 5.0
assert settings["cli_path"] == "/usr/local/bin/claude"
assert settings["model"] == "sonnet"
assert settings["cwd"] == "/home/user/project"
assert settings["permission_mode"] == "default"
assert settings["max_turns"] == 10
assert settings["max_budget_usd"] == 5.0
def test_env_variable_loading(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test loading from environment variables."""
monkeypatch.setenv("CLAUDE_AGENT_MODEL", "opus")
monkeypatch.setenv("CLAUDE_AGENT_MAX_TURNS", "20")
settings = ClaudeAgentSettings()
assert settings.model == "opus"
assert settings.max_turns == 20
settings = load_settings(ClaudeAgentSettings, env_prefix="CLAUDE_AGENT_")
assert settings["model"] == "opus"
assert settings["max_turns"] == 20
# region Test ClaudeAgent Initialization
@@ -95,9 +94,9 @@ class TestClaudeAgentInit:
"max_turns": 10,
}
agent = ClaudeAgent(default_options=options)
assert agent._settings.model == "sonnet" # type: ignore[reportPrivateUsage]
assert agent._settings.permission_mode == "default" # type: ignore[reportPrivateUsage]
assert agent._settings.max_turns == 10 # type: ignore[reportPrivateUsage]
assert agent._settings["model"] == "sonnet" # type: ignore[reportPrivateUsage]
assert agent._settings["permission_mode"] == "default" # type: ignore[reportPrivateUsage]
assert agent._settings["max_turns"] == 10 # type: ignore[reportPrivateUsage]
def test_with_function_tool(self) -> None:
"""Test agent with function tool."""
@@ -620,13 +619,13 @@ class TestClaudeAgentPermissions:
def test_default_permission_mode(self) -> None:
"""Test default permission mode."""
agent = ClaudeAgent()
assert agent._settings.permission_mode is None # type: ignore[reportPrivateUsage]
assert agent._settings["permission_mode"] is None # type: ignore[reportPrivateUsage]
def test_permission_mode_from_settings(self, monkeypatch: pytest.MonkeyPatch) -> None:
"""Test permission mode from environment settings."""
monkeypatch.setenv("CLAUDE_AGENT_PERMISSION_MODE", "acceptEdits")
settings = ClaudeAgentSettings()
assert settings.permission_mode == "acceptEdits"
settings = load_settings(ClaudeAgentSettings, env_prefix="CLAUDE_AGENT_")
assert settings["permission_mode"] == "acceptEdits"
def test_permission_mode_in_options(self) -> None:
"""Test permission mode in options."""
@@ -634,7 +633,7 @@ class TestClaudeAgentPermissions:
"permission_mode": "bypassPermissions",
}
agent = ClaudeAgent(default_options=options)
assert agent._settings.permission_mode == "bypassPermissions" # type: ignore[reportPrivateUsage]
assert agent._settings["permission_mode"] == "bypassPermissions" # type: ignore[reportPrivateUsage]
# region Test ClaudeAgent Error Handling
@@ -3,7 +3,7 @@
from __future__ import annotations
from collections.abc import AsyncIterable, Awaitable, Sequence
from typing import Any, ClassVar, Literal, overload
from typing import Any, Literal, TypedDict, overload
from agent_framework import (
AgentMiddlewareTypes,
@@ -17,23 +17,21 @@ from agent_framework import (
ResponseStream,
normalize_messages,
)
from agent_framework._pydantic import AFBaseSettings
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceException, ServiceInitializationError
from microsoft_agents.copilotstudio.client import AgentType, ConnectionSettings, CopilotClient, PowerPlatformCloud
from pydantic import ValidationError
from ._acquire_token import acquire_token
class CopilotStudioSettings(AFBaseSettings):
class CopilotStudioSettings(TypedDict, total=False):
"""Copilot Studio model settings.
The settings are first loaded from environment variables with the prefix 'COPILOTSTUDIOAGENT__'.
If the environment variables are not found, the settings can be loaded from a .env file
with the encoding 'utf-8'. If the settings are not found in the .env file, the settings
are ignored; however, validation will fail alerting that the settings are missing.
with the encoding 'utf-8'.
Keyword Args:
Keys:
environmentid: Environment ID of environment with the Copilot Studio App.
Can be set via environment variable COPILOTSTUDIOAGENT__ENVIRONMENTID.
schemaname: The agent identifier or schema name of the Copilot to use.
@@ -42,32 +40,12 @@ class CopilotStudioSettings(AFBaseSettings):
Can be set via environment variable COPILOTSTUDIOAGENT__AGENTAPPID.
tenantid: The tenant ID of the App Registration used to login.
Can be set via environment variable COPILOTSTUDIOAGENT__TENANTID.
env_file_path: If provided, the .env settings are read from this file path location.
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
Examples:
.. code-block:: python
from agent_framework_copilotstudio import CopilotStudioSettings
# Using environment variables
# Set COPILOTSTUDIOAGENT__ENVIRONMENTID=env-123
# Set COPILOTSTUDIOAGENT__SCHEMANAME=my-agent
settings = CopilotStudioSettings()
# Or passing parameters directly
settings = CopilotStudioSettings(environmentid="env-123", schemaname="my-agent")
# Or loading from a .env file
settings = CopilotStudioSettings(env_file_path="path/to/.env")
"""
env_prefix: ClassVar[str] = "COPILOTSTUDIOAGENT__"
environmentid: str | None = None
schemaname: str | None = None
agentappid: str | None = None
tenantid: str | None = None
environmentid: str | None
schemaname: str | None
agentappid: str | None
tenantid: str | None
class CopilotStudioAgent(BaseAgent):
@@ -144,54 +122,53 @@ class CopilotStudioAgent(BaseAgent):
middleware=middleware,
)
if not client:
try:
copilot_studio_settings = CopilotStudioSettings(
environmentid=environment_id,
schemaname=agent_identifier,
agentappid=client_id,
tenantid=tenant_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Copilot Studio settings.", ex) from ex
copilot_studio_settings = load_settings(
CopilotStudioSettings,
env_prefix="COPILOTSTUDIOAGENT__",
environmentid=environment_id,
schemaname=agent_identifier,
agentappid=client_id,
tenantid=tenant_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
if not settings:
if not copilot_studio_settings.environmentid:
if not copilot_studio_settings["environmentid"]:
raise ServiceInitializationError(
"Copilot Studio environment ID is required. Set via 'environment_id' parameter "
"or 'COPILOTSTUDIOAGENT__ENVIRONMENTID' environment variable."
)
if not copilot_studio_settings.schemaname:
if not copilot_studio_settings["schemaname"]:
raise ServiceInitializationError(
"Copilot Studio agent identifier/schema name is required. Set via 'agent_identifier' parameter "
"or 'COPILOTSTUDIOAGENT__SCHEMANAME' environment variable."
)
settings = ConnectionSettings(
environment_id=copilot_studio_settings.environmentid,
agent_identifier=copilot_studio_settings.schemaname,
environment_id=copilot_studio_settings["environmentid"],
agent_identifier=copilot_studio_settings["schemaname"],
cloud=cloud,
copilot_agent_type=agent_type,
custom_power_platform_cloud=custom_power_platform_cloud,
)
if not token:
if not copilot_studio_settings.agentappid:
if not copilot_studio_settings["agentappid"]:
raise ServiceInitializationError(
"Copilot Studio client ID is required. Set via 'client_id' parameter "
"or 'COPILOTSTUDIOAGENT__AGENTAPPID' environment variable."
)
if not copilot_studio_settings.tenantid:
if not copilot_studio_settings["tenantid"]:
raise ServiceInitializationError(
"Copilot Studio tenant ID is required. Set via 'tenant_id' parameter "
"or 'COPILOTSTUDIOAGENT__TENANTID' environment variable."
)
token = acquire_token(
client_id=copilot_studio_settings.agentappid,
tenant_id=copilot_studio_settings.tenantid,
client_id=copilot_studio_settings["agentappid"],
tenant_id=copilot_studio_settings["tenantid"],
username=username,
token_cache=token_cache,
scopes=scopes,
@@ -38,49 +38,57 @@ class TestCopilotStudioAgent:
return MagicMock(spec=CopilotClient)
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
@patch("agent_framework_copilotstudio._agent.CopilotStudioSettings")
def test_init_missing_environment_id(self, mock_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
@patch("agent_framework_copilotstudio._agent.load_settings")
def test_init_missing_environment_id(self, mock_load_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
mock_acquire_token.return_value = "fake-token"
mock_settings.return_value.environmentid = None
mock_settings.return_value.schemaname = "test-bot"
mock_settings.return_value.tenantid = "test-tenant"
mock_settings.return_value.agentappid = "test-client"
mock_load_settings.return_value = {
"environmentid": None,
"schemaname": "test-bot",
"tenantid": "test-tenant",
"agentappid": "test-client",
}
with pytest.raises(ServiceInitializationError, match="environment ID is required"):
CopilotStudioAgent()
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
@patch("agent_framework_copilotstudio._agent.CopilotStudioSettings")
def test_init_missing_bot_id(self, mock_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
@patch("agent_framework_copilotstudio._agent.load_settings")
def test_init_missing_bot_id(self, mock_load_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
mock_acquire_token.return_value = "fake-token"
mock_settings.return_value.environmentid = "test-env"
mock_settings.return_value.schemaname = None
mock_settings.return_value.tenantid = "test-tenant"
mock_settings.return_value.agentappid = "test-client"
mock_load_settings.return_value = {
"environmentid": "test-env",
"schemaname": None,
"tenantid": "test-tenant",
"agentappid": "test-client",
}
with pytest.raises(ServiceInitializationError, match="agent identifier"):
CopilotStudioAgent()
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
@patch("agent_framework_copilotstudio._agent.CopilotStudioSettings")
def test_init_missing_tenant_id(self, mock_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
@patch("agent_framework_copilotstudio._agent.load_settings")
def test_init_missing_tenant_id(self, mock_load_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
mock_acquire_token.return_value = "fake-token"
mock_settings.return_value.environmentid = "test-env"
mock_settings.return_value.schemaname = "test-bot"
mock_settings.return_value.tenantid = None
mock_settings.return_value.agentappid = "test-client"
mock_load_settings.return_value = {
"environmentid": "test-env",
"schemaname": "test-bot",
"tenantid": None,
"agentappid": "test-client",
}
with pytest.raises(ServiceInitializationError, match="tenant ID is required"):
CopilotStudioAgent()
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
@patch("agent_framework_copilotstudio._agent.CopilotStudioSettings")
def test_init_missing_client_id(self, mock_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
@patch("agent_framework_copilotstudio._agent.load_settings")
def test_init_missing_client_id(self, mock_load_settings: MagicMock, mock_acquire_token: MagicMock) -> None:
mock_acquire_token.return_value = "fake-token"
mock_settings.return_value.environmentid = "test-env"
mock_settings.return_value.schemaname = "test-bot"
mock_settings.return_value.tenantid = "test-tenant"
mock_settings.return_value.agentappid = None
mock_load_settings.return_value = {
"environmentid": "test-env",
"schemaname": "test-bot",
"tenantid": "test-tenant",
"agentappid": None,
}
with pytest.raises(ServiceInitializationError, match="client ID is required"):
CopilotStudioAgent()
@@ -93,11 +101,13 @@ class TestCopilotStudioAgent:
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
def test_init_empty_environment_id(self, mock_acquire_token: MagicMock) -> None:
mock_acquire_token.return_value = "fake-token"
with patch("agent_framework_copilotstudio._agent.CopilotStudioSettings") as mock_settings:
mock_settings.return_value.environmentid = ""
mock_settings.return_value.schemaname = "test-bot"
mock_settings.return_value.tenantid = "test-tenant"
mock_settings.return_value.agentappid = "test-client"
with patch("agent_framework_copilotstudio._agent.load_settings") as mock_load_settings:
mock_load_settings.return_value = {
"environmentid": "",
"schemaname": "test-bot",
"tenantid": "test-tenant",
"agentappid": "test-client",
}
with pytest.raises(ServiceInitializationError, match="environment ID is required"):
CopilotStudioAgent()
@@ -105,11 +115,13 @@ class TestCopilotStudioAgent:
@patch("agent_framework_copilotstudio._acquire_token.acquire_token")
def test_init_empty_schema_name(self, mock_acquire_token: MagicMock) -> None:
mock_acquire_token.return_value = "fake-token"
with patch("agent_framework_copilotstudio._agent.CopilotStudioSettings") as mock_settings:
mock_settings.return_value.environmentid = "test-env"
mock_settings.return_value.schemaname = ""
mock_settings.return_value.tenantid = "test-tenant"
mock_settings.return_value.agentappid = "test-client"
with patch("agent_framework_copilotstudio._agent.load_settings") as mock_load_settings:
mock_load_settings.return_value = {
"environmentid": "test-env",
"schemaname": "",
"tenantid": "test-tenant",
"agentappid": "test-client",
}
with pytest.raises(ServiceInitializationError, match="agent identifier"):
CopilotStudioAgent()
@@ -1,70 +0,0 @@
# Copyright (c) Microsoft. All rights reserved.
from __future__ import annotations
from typing import Annotated, Any, ClassVar, TypeVar
from pydantic import Field, UrlConstraints
from pydantic.networks import AnyUrl
from pydantic_settings import BaseSettings, SettingsConfigDict
HTTPsUrl = Annotated[AnyUrl, UrlConstraints(max_length=2083, allowed_schemes=["https"])]
__all__ = ["AFBaseSettings", "HTTPsUrl"]
SettingsT = TypeVar("SettingsT", bound="AFBaseSettings")
class AFBaseSettings(BaseSettings):
"""Base class for all settings classes in the Agent Framework.
A subclass creates it's fields and overrides the env_prefix class variable
with the prefix for the environment variables.
In the case where a value is specified for the same Settings field in multiple ways,
the selected value is determined as follows (in descending order of priority):
- Arguments passed to the Settings class initializer.
- Environment variables, e.g. my_prefix_special_function as described above.
- Variables loaded from a dotenv (.env) file.
- Variables loaded from the secrets directory.
- The default field values for the Settings model.
"""
env_prefix: ClassVar[str] = ""
env_file_path: str | None = Field(default=None, exclude=True)
env_file_encoding: str | None = Field(default="utf-8", exclude=True)
model_config = SettingsConfigDict(
extra="ignore",
case_sensitive=False,
)
def __init__(
self,
**kwargs: Any,
) -> None:
"""Initialize the settings class."""
# Remove any None values from the kwargs so that defaults are used.
kwargs = {k: v for k, v in kwargs.items() if v is not None}
super().__init__(**kwargs)
def __new__(cls: type[SettingsT], *args: Any, **kwargs: Any) -> SettingsT:
"""Override the __new__ method to set the env_prefix."""
# for both, if supplied but None, set to default
if "env_file_encoding" in kwargs and kwargs["env_file_encoding"] is not None:
env_file_encoding = kwargs["env_file_encoding"]
else:
env_file_encoding = "utf-8"
if "env_file_path" in kwargs and kwargs["env_file_path"] is not None:
env_file_path = kwargs["env_file_path"]
else:
env_file_path = ".env"
cls.model_config.update( # type: ignore
env_prefix=cls.env_prefix,
env_file=env_file_path,
env_file_encoding=env_file_encoding,
)
cls.model_rebuild()
return super().__new__(cls) # type: ignore[return-value]
@@ -0,0 +1,262 @@
# Copyright (c) Microsoft. All rights reserved.
"""Generic settings loader with environment variable resolution.
This module provides a ``load_settings()`` function that populates a ``TypedDict``
from environment variables, ``.env`` files, and explicit overrides. It replaces
the previous pydantic-settings-based ``AFBaseSettings`` with a lighter-weight,
function-based approach that has no pydantic-settings dependency.
Usage::
class MySettings(TypedDict, total=False):
api_key: str | None # optional — resolves to None if not set
model_id: str | None # optional by default
# Make model_id required at call time:
settings = load_settings(
MySettings,
env_prefix="MY_APP_",
required_fields=["model_id"],
model_id="gpt-4",
)
settings["api_key"] # type-checked dict access
settings["model_id"] # str | None per type, but guaranteed not None at runtime
"""
from __future__ import annotations
import os
import sys
from collections.abc import Callable, Sequence
from contextlib import suppress
from typing import Any, Union, get_args, get_origin, get_type_hints
from dotenv import load_dotenv
from .exceptions import SettingNotFoundError
if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
else:
from typing_extensions import TypeVar # type: ignore # pragma: no cover
__all__ = ["SecretString", "load_settings"]
SettingsT = TypeVar("SettingsT", default=dict[str, Any])
class SecretString(str):
"""A string subclass that masks its value in repr() to prevent accidental exposure.
SecretString behaves exactly like a regular string in all operations,
but its repr() shows '**********' instead of the actual value.
This helps prevent secrets from being accidentally logged or displayed.
It also provides a ``get_secret_value()`` method for backward compatibility
with code that previously used ``pydantic.SecretStr``.
Example:
```python
api_key = SecretString("sk-secret-key")
print(api_key) # sk-secret-key (normal string behavior)
print(repr(api_key)) # SecretString('**********')
print(f"Key: {api_key}") # Key: sk-secret-key
print(api_key.get_secret_value()) # sk-secret-key
```
"""
def __repr__(self) -> str:
"""Return a masked representation to prevent secret exposure."""
return "SecretString('**********')"
def get_secret_value(self) -> str:
"""Return the underlying string value.
Provided for backward compatibility with ``pydantic.SecretStr``.
Since SecretString *is* a str, this simply returns ``str(self)``.
"""
return str(self)
def _coerce_value(value: str, target_type: type) -> Any:
"""Coerce a string value to the target type."""
origin = get_origin(target_type)
args = get_args(target_type)
# Handle Union types (e.g., str | None) — try each non-None arm
if origin is type(None):
return None
if args and type(None) in args:
for arg in args:
if arg is not type(None):
with suppress(ValueError, TypeError):
return _coerce_value(value, arg)
return value
# Handle SecretString
if target_type is SecretString or (isinstance(target_type, type) and issubclass(target_type, SecretString)):
return SecretString(value)
# Handle basic types
if target_type is str:
return value
if target_type is int:
return int(value)
if target_type is float:
return float(value)
if target_type is bool:
return value.lower() in ("true", "1", "yes", "on")
return value
def _check_override_type(value: Any, field_type: type, field_name: str) -> None:
"""Validate that *value* is compatible with *field_type*.
Raises ``ServiceInitializationError`` when the override is clearly
incompatible (e.g. a ``dict`` passed where ``str`` is expected).
Callable values and ``None`` are always accepted.
"""
if value is None:
return
# Callables are always allowed (e.g. lazy token providers)
if callable(value) and not isinstance(value, (str, bytes)):
return
# Collect the concrete types that *field_type* allows
origin = get_origin(field_type)
args = get_args(field_type)
allowed: tuple[type, ...]
if origin is Union or origin is type(int | str):
allowed = tuple(a for a in args if isinstance(a, type) and a is not type(None))
# If any arm is a Callable, allow anything callable
if any(get_origin(a) is Callable or a is Callable for a in args):
return
elif isinstance(field_type, type):
allowed = (field_type,)
else:
return # complex / unknown annotation — skip check
if not allowed:
return
if not isinstance(value, allowed):
# Allow str for SecretString fields (will be coerced)
if isinstance(value, str) and any(isinstance(a, type) and issubclass(a, str) for a in allowed):
return
# Allow int for float fields (standard numeric promotion)
if isinstance(value, int) and float in allowed:
return
from .exceptions import ServiceInitializationError
allowed_names = ", ".join(t.__name__ for t in allowed)
raise ServiceInitializationError(
f"Invalid type for setting '{field_name}': expected {allowed_names}, got {type(value).__name__}."
)
def load_settings(
settings_type: type[SettingsT],
*,
env_prefix: str = "",
env_file_path: str | None = None,
env_file_encoding: str | None = None,
required_fields: Sequence[str] | None = None,
**overrides: Any,
) -> SettingsT:
"""Load settings from environment variables, a ``.env`` file, and explicit overrides.
The *settings_type* must be a ``TypedDict`` subclass. Values are resolved in
this order (highest priority first):
1. Explicit keyword *overrides* (``None`` values are filtered out).
2. Environment variables (``<env_prefix><FIELD_NAME>``).
3. A ``.env`` file (loaded via ``python-dotenv``; existing env vars take precedence).
4. Default values fields with class-level defaults on the TypedDict, or
``None`` for optional fields.
Fields listed in *required_fields* are validated after resolution. If any
required field resolves to ``None``, a ``SettingNotFoundError`` is raised.
This allows callers to decide which fields are required based on runtime
context (e.g. ``endpoint`` is only required when no pre-built client is
provided).
Args:
settings_type: A ``TypedDict`` class describing the settings schema.
env_prefix: Prefix for environment variable lookup (e.g. ``"OPENAI_"``).
env_file_path: Path to ``.env`` file. Defaults to ``".env"`` when omitted.
env_file_encoding: Encoding for reading the ``.env`` file. Defaults to ``"utf-8"``.
required_fields: Field names that must resolve to a non-``None`` value.
**overrides: Field values. ``None`` values are ignored so that callers can
forward optional parameters without masking env-var / default resolution.
Returns:
A populated dict matching *settings_type*.
Raises:
SettingNotFoundError: If a required field could not be resolved from any source.
ServiceInitializationError: If an override value has an incompatible type.
"""
encoding = env_file_encoding or "utf-8"
# Load .env file if it exists (existing env vars take precedence by default)
env_path = env_file_path or ".env"
if os.path.isfile(env_path):
load_dotenv(dotenv_path=env_path, encoding=encoding)
# Filter out None overrides so defaults / env vars are preserved
overrides = {k: v for k, v in overrides.items() if v is not None}
# Get field type hints from the TypedDict
hints = get_type_hints(settings_type)
required: set[str] = set(required_fields) if required_fields else set()
result: dict[str, Any] = {}
for field_name, field_type in hints.items():
# 1. Explicit override wins
if field_name in overrides:
override_value = overrides[field_name]
_check_override_type(override_value, field_type, field_name)
# Coerce plain str → SecretString if the annotation expects it
if isinstance(override_value, str) and not isinstance(override_value, SecretString):
with suppress(ValueError, TypeError):
coerced = _coerce_value(override_value, field_type)
if isinstance(coerced, SecretString):
override_value = coerced
result[field_name] = override_value
continue
# 2. Environment variable
env_var_name = f"{env_prefix}{field_name.upper()}"
env_value = os.getenv(env_var_name)
if env_value is not None:
try:
result[field_name] = _coerce_value(env_value, field_type)
except (ValueError, TypeError):
result[field_name] = env_value
continue
# 3. Default from TypedDict class-level defaults, or None for optional fields
if hasattr(settings_type, field_name):
result[field_name] = getattr(settings_type, field_name)
else:
result[field_name] = None
# Validate required fields after all resolution
if required:
for field_name in required:
if result.get(field_name) is None:
env_var_name = f"{env_prefix}{field_name.upper()}"
raise SettingNotFoundError(
f"Required setting '{field_name}' was not provided. "
f"Set it via the '{field_name}' parameter or the "
f"'{env_var_name}' environment variable."
)
return result # type: ignore[return-value]
@@ -7,12 +7,13 @@ from collections.abc import Mapping
from typing import TYPE_CHECKING, Any, ClassVar, Generic
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
from pydantic import ValidationError
from .._settings import load_settings
from ..exceptions import ServiceInitializationError
from ..openai import OpenAIAssistantsClient
from ..openai._assistants_client import OpenAIAssistantsOptions
from ._shared import AzureOpenAISettings
from ._entra_id_authentication import get_entra_auth_token
from ._shared import DEFAULT_AZURE_TOKEN_ENDPOINT, AzureOpenAISettings, _apply_azure_defaults
if TYPE_CHECKING:
from azure.core.credentials import TokenCredential
@@ -137,23 +138,21 @@ class AzureOpenAIAssistantsClient(
client: AzureOpenAIAssistantsClient[MyOptions] = AzureOpenAIAssistantsClient()
response = await client.get_response("Hello", options={"my_custom_option": "value"})
"""
try:
azure_openai_settings = AzureOpenAISettings(
# pydantic settings will see if there is a value, if not, will try the env var or .env file
api_key=api_key, # type: ignore
base_url=base_url, # type: ignore
endpoint=endpoint, # type: ignore
chat_deployment_name=deployment_name,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
token_endpoint=token_endpoint,
default_api_version=self.DEFAULT_AZURE_API_VERSION,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Azure OpenAI settings.", ex) from ex
azure_openai_settings = load_settings(
AzureOpenAISettings,
env_prefix="AZURE_OPENAI_",
api_key=api_key,
base_url=base_url,
endpoint=endpoint,
chat_deployment_name=deployment_name,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
token_endpoint=token_endpoint,
)
_apply_azure_defaults(azure_openai_settings, default_api_version=self.DEFAULT_AZURE_API_VERSION)
if not azure_openai_settings.chat_deployment_name:
if not azure_openai_settings["chat_deployment_name"]:
raise ServiceInitializationError(
"Azure OpenAI deployment name is required. Set via 'deployment_name' parameter "
"or 'AZURE_OPENAI_CHAT_DEPLOYMENT_NAME' environment variable."
@@ -162,40 +161,41 @@ class AzureOpenAIAssistantsClient(
# Handle authentication: try API key first, then AD token, then Entra ID
if (
not async_client
and not azure_openai_settings.api_key
and not azure_openai_settings["api_key"]
and not ad_token
and not ad_token_provider
and azure_openai_settings.token_endpoint
and azure_openai_settings["token_endpoint"]
and credential
):
ad_token = azure_openai_settings.get_azure_auth_token(credential)
token_ep = azure_openai_settings["token_endpoint"] or DEFAULT_AZURE_TOKEN_ENDPOINT
ad_token = get_entra_auth_token(credential, token_ep)
if not async_client and not azure_openai_settings.api_key and not ad_token and not ad_token_provider:
if not async_client and not azure_openai_settings["api_key"] and not ad_token and not ad_token_provider:
raise ServiceInitializationError("The Azure OpenAI API key, ad_token, or ad_token_provider is required.")
# Create Azure client if not provided
if not async_client:
client_params: dict[str, Any] = {
"api_version": azure_openai_settings.api_version,
"api_version": azure_openai_settings["api_version"],
"default_headers": default_headers,
}
if azure_openai_settings.api_key:
client_params["api_key"] = azure_openai_settings.api_key.get_secret_value()
if azure_openai_settings["api_key"]:
client_params["api_key"] = azure_openai_settings["api_key"].get_secret_value()
elif ad_token:
client_params["azure_ad_token"] = ad_token
elif ad_token_provider:
client_params["azure_ad_token_provider"] = ad_token_provider
if azure_openai_settings.base_url:
client_params["base_url"] = str(azure_openai_settings.base_url)
elif azure_openai_settings.endpoint:
client_params["azure_endpoint"] = str(azure_openai_settings.endpoint)
if azure_openai_settings["base_url"]:
client_params["base_url"] = str(azure_openai_settings["base_url"])
elif azure_openai_settings["endpoint"]:
client_params["azure_endpoint"] = str(azure_openai_settings["endpoint"])
async_client = AsyncAzureOpenAI(**client_params)
super().__init__(
model_id=azure_openai_settings.chat_deployment_name,
model_id=azure_openai_settings["chat_deployment_name"],
assistant_id=assistant_id,
assistant_name=assistant_name,
assistant_description=assistant_description,
@@ -12,7 +12,7 @@ from azure.core.credentials import TokenCredential
from openai.lib.azure import AsyncAzureADTokenProvider, AsyncAzureOpenAI
from openai.types.chat.chat_completion import Choice
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from agent_framework import (
Annotation,
@@ -28,9 +28,11 @@ from agent_framework.observability import ChatTelemetryLayer
from agent_framework.openai import OpenAIChatOptions
from agent_framework.openai._chat_client import RawOpenAIChatClient
from .._settings import load_settings
from ._shared import (
AzureOpenAIConfigMixin,
AzureOpenAISettings,
_apply_azure_defaults,
)
if sys.version_info >= (3, 13):
@@ -247,37 +249,35 @@ class AzureOpenAIChatClient( # type: ignore[misc]
client: AzureOpenAIChatClient[MyOptions] = AzureOpenAIChatClient()
response = await client.get_response("Hello", options={"my_custom_option": "value"})
"""
try:
# Filter out any None values from the arguments
azure_openai_settings = AzureOpenAISettings(
# pydantic settings will see if there is a value, if not, will try the env var or .env file
api_key=api_key, # type: ignore
base_url=base_url, # type: ignore
endpoint=endpoint, # type: ignore
chat_deployment_name=deployment_name,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
token_endpoint=token_endpoint,
)
except ValidationError as exc:
raise ServiceInitializationError(f"Failed to validate settings: {exc}") from exc
azure_openai_settings = load_settings(
AzureOpenAISettings,
env_prefix="AZURE_OPENAI_",
api_key=api_key,
base_url=base_url,
endpoint=endpoint,
chat_deployment_name=deployment_name,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
token_endpoint=token_endpoint,
)
_apply_azure_defaults(azure_openai_settings)
if not azure_openai_settings.chat_deployment_name:
if not azure_openai_settings["chat_deployment_name"]:
raise ServiceInitializationError(
"Azure OpenAI deployment name is required. Set via 'deployment_name' parameter "
"or 'AZURE_OPENAI_CHAT_DEPLOYMENT_NAME' environment variable."
)
super().__init__(
deployment_name=azure_openai_settings.chat_deployment_name,
endpoint=azure_openai_settings.endpoint,
base_url=azure_openai_settings.base_url,
api_version=azure_openai_settings.api_version, # type: ignore
api_key=azure_openai_settings.api_key.get_secret_value() if azure_openai_settings.api_key else None,
deployment_name=azure_openai_settings["chat_deployment_name"],
endpoint=azure_openai_settings["endpoint"],
base_url=azure_openai_settings["base_url"],
api_version=azure_openai_settings["api_version"], # type: ignore
api_key=azure_openai_settings["api_key"].get_secret_value() if azure_openai_settings["api_key"] else None,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
token_endpoint=azure_openai_settings.token_endpoint,
token_endpoint=azure_openai_settings["token_endpoint"],
credential=credential,
default_headers=default_headers,
client=async_client,
@@ -5,15 +5,15 @@ from __future__ import annotations
import sys
from collections.abc import Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic
from urllib.parse import urljoin
from urllib.parse import urljoin, urlparse
from azure.ai.projects.aio import AIProjectClient
from azure.core.credentials import TokenCredential
from openai import AsyncOpenAI
from openai.lib.azure import AsyncAzureADTokenProvider
from pydantic import ValidationError
from .._middleware import ChatMiddlewareLayer
from .._settings import load_settings
from .._telemetry import AGENT_FRAMEWORK_USER_AGENT
from .._tools import FunctionInvocationConfiguration, FunctionInvocationLayer
from ..exceptions import ServiceInitializationError
@@ -22,6 +22,7 @@ from ..openai._responses_client import RawOpenAIResponsesClient
from ._shared import (
AzureOpenAIConfigMixin,
AzureOpenAISettings,
_apply_azure_defaults,
)
if sys.version_info >= (3, 13):
@@ -82,7 +83,8 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
env_file_encoding: str | None = None,
instruction_role: str | None = None,
middleware: Sequence[MiddlewareTypes] | None = None,
function_invocation_configuration: FunctionInvocationConfiguration | None = None,
function_invocation_configuration: FunctionInvocationConfiguration
| None = None,
**kwargs: Any,
) -> None:
"""Initialize an Azure OpenAI Responses client.
@@ -188,54 +190,58 @@ class AzureOpenAIResponsesClient( # type: ignore[misc]
deployment_name = str(model_id)
# Project client path: create OpenAI client from an Azure AI Foundry project
if async_client is None and (project_client is not None or project_endpoint is not None):
if async_client is None and (
project_client is not None or project_endpoint is not None
):
async_client = self._create_client_from_project(
project_client=project_client,
project_endpoint=project_endpoint,
credential=credential,
)
try:
azure_openai_settings = AzureOpenAISettings(
# pydantic settings will see if there is a value, if not, will try the env var or .env file
api_key=api_key, # type: ignore
base_url=base_url, # type: ignore
endpoint=endpoint, # type: ignore
responses_deployment_name=deployment_name,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
token_endpoint=token_endpoint,
default_api_version="preview",
azure_openai_settings = load_settings(
AzureOpenAISettings,
env_prefix="AZURE_OPENAI_",
api_key=api_key,
base_url=base_url,
endpoint=endpoint,
responses_deployment_name=deployment_name,
api_version=api_version,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
token_endpoint=token_endpoint,
)
_apply_azure_defaults(azure_openai_settings, default_api_version="preview")
# TODO(peterychang): This is a temporary hack to ensure that the base_url is set correctly
# while this feature is in preview.
# But we should only do this if we're on azure. Private deployments may not need this.
if (
not azure_openai_settings.get("base_url")
and azure_openai_settings.get("endpoint")
and (hostname := urlparse(str(azure_openai_settings["endpoint"])).hostname)
and hostname.endswith(".openai.azure.com")
):
azure_openai_settings["base_url"] = urljoin(
str(azure_openai_settings["endpoint"]), "/openai/v1/"
)
# TODO(peterychang): This is a temporary hack to ensure that the base_url is set correctly
# while this feature is in preview.
# But we should only do this if we're on azure. Private deployments may not need this.
if (
not azure_openai_settings.base_url
and azure_openai_settings.endpoint
and azure_openai_settings.endpoint.host
and azure_openai_settings.endpoint.host.endswith(".openai.azure.com")
):
azure_openai_settings.base_url = urljoin(str(azure_openai_settings.endpoint), "/openai/v1/") # type: ignore
except ValidationError as exc:
raise ServiceInitializationError(f"Failed to validate settings: {exc}") from exc
if not azure_openai_settings.responses_deployment_name:
if not azure_openai_settings["responses_deployment_name"]:
raise ServiceInitializationError(
"Azure OpenAI deployment name is required. Set via 'deployment_name' parameter "
"or 'AZURE_OPENAI_RESPONSES_DEPLOYMENT_NAME' environment variable."
)
super().__init__(
deployment_name=azure_openai_settings.responses_deployment_name,
endpoint=azure_openai_settings.endpoint,
base_url=azure_openai_settings.base_url,
api_version=azure_openai_settings.api_version, # type: ignore
api_key=azure_openai_settings.api_key.get_secret_value() if azure_openai_settings.api_key else None,
deployment_name=azure_openai_settings["responses_deployment_name"],
endpoint=azure_openai_settings["endpoint"],
base_url=azure_openai_settings["base_url"],
api_version=azure_openai_settings["api_version"], # type: ignore
api_key=azure_openai_settings["api_key"].get_secret_value()
if azure_openai_settings["api_key"]
else None,
ad_token=ad_token,
ad_token_provider=ad_token_provider,
token_endpoint=azure_openai_settings.token_endpoint,
token_endpoint=azure_openai_settings["token_endpoint"],
credential=credential,
default_headers=default_headers,
client=async_client,
@@ -11,28 +11,26 @@ from typing import Any, ClassVar, Final
from azure.core.credentials import TokenCredential
from openai import AsyncOpenAI
from openai.lib.azure import AsyncAzureOpenAI
from pydantic import SecretStr, model_validator
from .._pydantic import AFBaseSettings, HTTPsUrl
from .._settings import SecretString
from .._telemetry import APP_INFO, prepend_agent_framework_to_user_agent
from ..exceptions import ServiceInitializationError
from ..openai._shared import OpenAIBase
from ._entra_id_authentication import get_entra_auth_token
if sys.version_info >= (3, 11):
from typing import Self # pragma: no cover
else:
from typing_extensions import Self # pragma: no cover
logger: logging.Logger = logging.getLogger(__name__)
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
DEFAULT_AZURE_API_VERSION: Final[str] = "2024-10-21"
DEFAULT_AZURE_TOKEN_ENDPOINT: Final[str] = "https://cognitiveservices.azure.com/.default" # noqa: S105
class AzureOpenAISettings(AFBaseSettings):
class AzureOpenAISettings(TypedDict, total=False):
"""AzureOpenAI model settings.
The settings are first loaded from environment variables with the prefix 'AZURE_OPENAI_'.
@@ -62,7 +60,7 @@ class AzureOpenAISettings(AFBaseSettings):
found in the Keys & Endpoint section when examining your resource in
the Azure portal. You can use either KEY1 or KEY2.
Can be set via environment variable AZURE_OPENAI_API_KEY.
api_version: The API version to use. The default value is `default_api_version`.
api_version: The API version to use. The default value is `DEFAULT_AZURE_API_VERSION`.
Can be set via environment variable AZURE_OPENAI_API_VERSION.
base_url: The url of the Azure deployment. This value
can be found in the Keys & Endpoint section when examining
@@ -71,14 +69,8 @@ class AzureOpenAISettings(AFBaseSettings):
use endpoint if you only want to supply the endpoint.
Can be set via environment variable AZURE_OPENAI_BASE_URL.
token_endpoint: The token endpoint to use to retrieve the authentication token.
The default value is `default_token_endpoint`.
The default value is `DEFAULT_AZURE_TOKEN_ENDPOINT`.
Can be set via environment variable AZURE_OPENAI_TOKEN_ENDPOINT.
default_api_version: The default API version to use if not specified.
The default value is "2024-10-21".
default_token_endpoint: The default token endpoint to use if not specified.
The default value is "https://cognitiveservices.azure.com/.default".
env_file_path: The path to the .env file to load settings from.
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
Examples:
.. code-block:: python
@@ -89,60 +81,46 @@ class AzureOpenAISettings(AFBaseSettings):
# Set AZURE_OPENAI_ENDPOINT=https://your-endpoint.openai.azure.com
# Set AZURE_OPENAI_CHAT_DEPLOYMENT_NAME=gpt-4
# Set AZURE_OPENAI_API_KEY=your-key
settings = AzureOpenAISettings()
settings = load_settings(AzureOpenAISettings, env_prefix="AZURE_OPENAI_")
# Or passing parameters directly
settings = AzureOpenAISettings(
endpoint="https://your-endpoint.openai.azure.com", chat_deployment_name="gpt-4", api_key="your-key"
settings = load_settings(
AzureOpenAISettings,
env_prefix="AZURE_OPENAI_",
endpoint="https://your-endpoint.openai.azure.com",
chat_deployment_name="gpt-4",
api_key="your-key",
)
# Or loading from a .env file
settings = AzureOpenAISettings(env_file_path="path/to/.env")
settings = load_settings(AzureOpenAISettings, env_prefix="AZURE_OPENAI_", env_file_path="path/to/.env")
"""
env_prefix: ClassVar[str] = "AZURE_OPENAI_"
chat_deployment_name: str | None
responses_deployment_name: str | None
endpoint: str | None
base_url: str | None
api_key: SecretString | None
api_version: str | None
token_endpoint: str | None
chat_deployment_name: str | None = None
responses_deployment_name: str | None = None
endpoint: HTTPsUrl | None = None
base_url: HTTPsUrl | None = None
api_key: SecretStr | None = None
api_version: str | None = None
token_endpoint: str | None = None
default_api_version: str = DEFAULT_AZURE_API_VERSION
default_token_endpoint: str = DEFAULT_AZURE_TOKEN_ENDPOINT
def get_azure_auth_token(
self, credential: TokenCredential, token_endpoint: str | None = None, **kwargs: Any
) -> str | None:
"""Retrieve a Microsoft Entra Auth Token for a given token endpoint for the use with Azure OpenAI.
def _apply_azure_defaults(
settings: AzureOpenAISettings,
default_api_version: str = DEFAULT_AZURE_API_VERSION,
default_token_endpoint: str = DEFAULT_AZURE_TOKEN_ENDPOINT,
) -> None:
"""Apply default values for api_version and token_endpoint after loading settings.
The required role for the token is `Cognitive Services OpenAI Contributor`.
The token endpoint may be specified as an environment variable, via the .env
file or as an argument. If the token endpoint is not provided, the default is None.
The `token_endpoint` argument takes precedence over the `token_endpoint` attribute.
Args:
credential: The Azure AD credential to use.
token_endpoint: The token endpoint to use. Defaults to `https://cognitiveservices.azure.com/.default`.
Keyword Args:
**kwargs: Additional keyword arguments to pass to the token retrieval method.
Returns:
The Azure token or None if the token could not be retrieved.
Raises:
ServiceInitializationError: If the token endpoint is not provided.
"""
endpoint_to_use = token_endpoint or self.token_endpoint or self.default_token_endpoint
return get_entra_auth_token(credential, endpoint_to_use, **kwargs)
@model_validator(mode="after")
def _validate_fields(self) -> Self:
self.api_version = self.api_version or self.default_api_version
self.token_endpoint = self.token_endpoint or self.default_token_endpoint
return self
Args:
settings: The loaded Azure OpenAI settings dict.
default_api_version: The default API version to use if not set.
default_token_endpoint: The default token endpoint to use if not set.
"""
if not settings.get("api_version"):
settings["api_version"] = default_api_version
if not settings.get("token_endpoint"):
settings["token_endpoint"] = default_token_endpoint
class AzureOpenAIConfigMixin(OpenAIBase):
@@ -154,8 +132,8 @@ class AzureOpenAIConfigMixin(OpenAIBase):
def __init__(
self,
deployment_name: str,
endpoint: HTTPsUrl | None = None,
base_url: HTTPsUrl | None = None,
endpoint: str | None = None,
base_url: str | None = None,
api_version: str = DEFAULT_AZURE_API_VERSION,
api_key: str | None = None,
ad_token: str | None = None,
@@ -170,7 +148,7 @@ class AzureOpenAIConfigMixin(OpenAIBase):
"""Internal class for configuring a connection to an Azure OpenAI service.
The `validate_call` decorator is used with a configuration that allows arbitrary types.
This is necessary for types like `HTTPsUrl` and `OpenAIModelTypes`.
This is necessary for types like `str` and `OpenAIModelTypes`.
Args:
deployment_name: Name of the deployment.
@@ -146,3 +146,9 @@ class ContentError(AgentFrameworkException):
"""An error occurred while processing content."""
pass
class SettingNotFoundError(AgentFrameworkException):
"""A required setting could not be resolved from any source."""
pass
@@ -18,11 +18,10 @@ from opentelemetry import metrics, trace
from opentelemetry.sdk.resources import Resource
from opentelemetry.semconv.attributes import service_attributes
from opentelemetry.semconv_ai import Meters, SpanAttributes
from pydantic import PrivateAttr
from . import __version__ as version_info
from ._logging import get_logger
from ._pydantic import AFBaseSettings
from ._settings import load_settings
if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
@@ -566,7 +565,16 @@ def create_metric_views() -> list[View]:
]
class ObservabilitySettings(AFBaseSettings):
class _ObservabilitySettingsData(TypedDict, total=False):
"""TypedDict schema for observability settings fields."""
enable_instrumentation: bool | None
enable_sensitive_data: bool | None
enable_console_exporters: bool | None
vs_code_extension_port: int | None
class ObservabilitySettings:
"""Settings for Agent Framework Observability.
If the environment variables are not found, the settings can
@@ -603,23 +611,27 @@ class ObservabilitySettings(AFBaseSettings):
settings = ObservabilitySettings(enable_instrumentation=True, enable_console_exporters=True)
"""
env_prefix: ClassVar[str] = ""
enable_instrumentation: bool = False
enable_sensitive_data: bool = False
enable_console_exporters: bool = False
vs_code_extension_port: int | None = None
_resource: Resource = PrivateAttr()
_executed_setup: bool = PrivateAttr(default=False)
def __init__(self, **kwargs: Any) -> None:
"""Initialize the settings and create the resource."""
super().__init__(**kwargs)
# Create resource with env file settings
self._resource = create_resource(
env_file_path=self.env_file_path,
env_file_encoding=self.env_file_encoding,
env_file_path = kwargs.pop("env_file_path", None)
env_file_encoding = kwargs.pop("env_file_encoding", None)
data = load_settings(
_ObservabilitySettingsData,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
**kwargs,
)
self.enable_instrumentation: bool = data.get("enable_instrumentation") or False
self.enable_sensitive_data: bool = data.get("enable_sensitive_data") or False
self.enable_console_exporters: bool = data.get("enable_console_exporters") or False
self.vs_code_extension_port: int | None = data.get("vs_code_extension_port")
self.env_file_path = env_file_path
self.env_file_encoding = env_file_encoding
self._resource = create_resource(
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
self._executed_setup = False
@property
def ENABLED(self) -> bool:
@@ -8,7 +8,9 @@ from typing import TYPE_CHECKING, Any, Generic, cast
from openai import AsyncOpenAI
from openai.types.beta.assistant import Assistant
from pydantic import BaseModel, SecretStr, ValidationError
from pydantic import BaseModel
from agent_framework._settings import SecretString, load_settings
from .._agents import Agent
from .._memory import ContextProvider
@@ -107,7 +109,7 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
self,
client: AsyncOpenAI | None = None,
*,
api_key: str | SecretStr | Callable[[], str | Awaitable[str]] | None = None,
api_key: str | SecretString | Callable[[], str | Awaitable[str]] | None = None,
org_id: str | None = None,
base_url: str | None = None,
env_file_path: str | None = None,
@@ -147,35 +149,34 @@ class OpenAIAssistantProvider(Generic[OptionsCoT]):
if client is None:
# Load settings and create client
try:
settings = OpenAISettings(
api_key=api_key, # type: ignore[reportArgumentType]
org_id=org_id,
base_url=base_url,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
settings = load_settings(
OpenAISettings,
env_prefix="OPENAI_",
api_key=api_key,
org_id=org_id,
base_url=base_url,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
if not settings.api_key:
if not settings["api_key"]:
raise ServiceInitializationError(
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
)
# Get API key value
api_key_value: str | Callable[[], str | Awaitable[str]] | None
if isinstance(settings.api_key, SecretStr):
api_key_value = settings.api_key.get_secret_value()
if isinstance(settings["api_key"], SecretString):
api_key_value = settings["api_key"].get_secret_value()
else:
api_key_value = settings.api_key
api_key_value = settings["api_key"]
# Create client
client_args: dict[str, Any] = {"api_key": api_key_value}
if settings.org_id:
client_args["organization"] = settings.org_id
if settings.base_url:
client_args["base_url"] = settings.base_url
if settings["org_id"]:
client_args["organization"] = settings["org_id"]
if settings["base_url"]:
client_args["base_url"] = settings["base_url"]
self._client = AsyncOpenAI(**client_args)
@@ -27,10 +27,11 @@ from openai.types.beta.threads import (
from openai.types.beta.threads.run_create_params import AdditionalMessage
from openai.types.beta.threads.run_submit_tool_outputs_params import ToolOutput
from openai.types.beta.threads.runs import RunStep
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from .._clients import BaseChatClient
from .._middleware import ChatMiddlewareLayer
from .._settings import load_settings
from .._tools import (
FunctionInvocationConfiguration,
FunctionInvocationLayer,
@@ -343,35 +344,34 @@ class OpenAIAssistantsClient( # type: ignore[misc]
client: OpenAIAssistantsClient[MyOptions] = OpenAIAssistantsClient(model_id="gpt-4")
response = await client.get_response("Hello", options={"my_custom_option": "value"})
"""
try:
openai_settings = OpenAISettings(
api_key=api_key, # type: ignore[reportArgumentType]
base_url=base_url,
org_id=org_id,
chat_model_id=model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
openai_settings = load_settings(
OpenAISettings,
env_prefix="OPENAI_",
api_key=api_key,
base_url=base_url,
org_id=org_id,
chat_model_id=model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
if not async_client and not openai_settings.api_key:
if not async_client and not openai_settings["api_key"]:
raise ServiceInitializationError(
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
)
if not openai_settings.chat_model_id:
if not openai_settings["chat_model_id"]:
raise ServiceInitializationError(
"OpenAI model ID is required. "
"Set via 'model_id' parameter or 'OPENAI_CHAT_MODEL_ID' environment variable."
)
super().__init__(
model_id=openai_settings.chat_model_id,
api_key=self._get_api_key(openai_settings.api_key),
org_id=openai_settings.org_id,
model_id=openai_settings["chat_model_id"],
api_key=self._get_api_key(openai_settings["api_key"]),
org_id=openai_settings["org_id"],
default_headers=default_headers,
client=async_client,
base_url=openai_settings.base_url,
base_url=openai_settings["base_url"],
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
)
@@ -17,11 +17,12 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from openai.types.chat.chat_completion_chunk import Choice as ChunkChoice
from openai.types.chat.chat_completion_message_custom_tool_call import ChatCompletionMessageCustomToolCall
from openai.types.chat.completion_create_params import WebSearchOptions
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import ChatAndFunctionMiddlewareTypes, ChatMiddlewareLayer
from .._settings import load_settings
from .._tools import (
FunctionInvocationConfiguration,
FunctionInvocationLayer,
@@ -718,33 +719,32 @@ class OpenAIChatClient( # type: ignore[misc]
client: OpenAIChatClient[MyOptions] = OpenAIChatClient(model_id="<model name>")
response = await client.get_response("Hello", options={"my_custom_option": "value"})
"""
try:
openai_settings = OpenAISettings(
api_key=api_key, # type: ignore[reportArgumentType]
base_url=base_url,
org_id=org_id,
chat_model_id=model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
openai_settings = load_settings(
OpenAISettings,
env_prefix="OPENAI_",
api_key=api_key,
base_url=base_url,
org_id=org_id,
chat_model_id=model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
if not async_client and not openai_settings.api_key:
if not async_client and not openai_settings["api_key"]:
raise ServiceInitializationError(
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
)
if not openai_settings.chat_model_id:
if not openai_settings["chat_model_id"]:
raise ServiceInitializationError(
"OpenAI model ID is required. "
"Set via 'model_id' parameter or 'OPENAI_CHAT_MODEL_ID' environment variable."
)
super().__init__(
model_id=openai_settings.chat_model_id,
api_key=self._get_api_key(openai_settings.api_key),
base_url=openai_settings.base_url if openai_settings.base_url else None,
org_id=openai_settings.org_id,
model_id=openai_settings["chat_model_id"],
api_key=self._get_api_key(openai_settings["api_key"]),
base_url=openai_settings["base_url"] if openai_settings["base_url"] else None,
org_id=openai_settings["org_id"],
default_headers=default_headers,
client=async_client,
instruction_role=instruction_role,
@@ -33,11 +33,12 @@ from openai.types.responses.tool_param import (
Mcp,
)
from openai.types.responses.web_search_tool_param import WebSearchToolParam
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
from .._clients import BaseChatClient
from .._logging import get_logger
from .._middleware import ChatMiddlewareLayer
from .._settings import load_settings
from .._tools import (
FunctionInvocationConfiguration,
FunctionInvocationLayer,
@@ -1810,36 +1811,35 @@ class OpenAIResponsesClient( # type: ignore[misc]
client: OpenAIResponsesClient[MyOptions] = OpenAIResponsesClient(model_id="gpt-4o")
response = await client.get_response("Hello", options={"my_custom_option": "value"})
"""
try:
openai_settings = OpenAISettings(
api_key=api_key, # type: ignore[reportArgumentType]
org_id=org_id,
base_url=base_url,
responses_model_id=model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create OpenAI settings.", ex) from ex
openai_settings = load_settings(
OpenAISettings,
env_prefix="OPENAI_",
api_key=api_key,
org_id=org_id,
base_url=base_url,
responses_model_id=model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
if not async_client and not openai_settings.api_key:
if not async_client and not openai_settings["api_key"]:
raise ServiceInitializationError(
"OpenAI API key is required. Set via 'api_key' parameter or 'OPENAI_API_KEY' environment variable."
)
if not openai_settings.responses_model_id:
if not openai_settings["responses_model_id"]:
raise ServiceInitializationError(
"OpenAI model ID is required. "
"Set via 'model_id' parameter or 'OPENAI_RESPONSES_MODEL_ID' environment variable."
)
super().__init__(
model_id=openai_settings.responses_model_id,
api_key=self._get_api_key(openai_settings.api_key),
org_id=openai_settings.org_id,
model_id=openai_settings["responses_model_id"],
api_key=self._get_api_key(openai_settings["api_key"]),
org_id=openai_settings["org_id"],
default_headers=default_headers,
client=async_client,
instruction_role=instruction_role,
base_url=openai_settings.base_url,
base_url=openai_settings["base_url"],
middleware=middleware,
function_invocation_configuration=function_invocation_configuration,
**kwargs,
@@ -3,6 +3,7 @@
from __future__ import annotations
import logging
import sys
from collections.abc import Awaitable, Callable, Mapping, MutableMapping, Sequence
from copy import copy
from typing import Any, ClassVar, Union
@@ -20,11 +21,10 @@ from openai.types.images_response import ImagesResponse
from openai.types.responses.response import Response
from openai.types.responses.response_stream_event import ResponseStreamEvent
from packaging.version import parse
from pydantic import SecretStr
from .._logging import get_logger
from .._pydantic import AFBaseSettings
from .._serialization import SerializationMixin
from .._settings import SecretString
from .._telemetry import APP_INFO, USER_AGENT_KEY, prepend_agent_framework_to_user_agent
from .._tools import FunctionTool
from ..exceptions import ServiceInitializationError
@@ -47,6 +47,11 @@ RESPONSE_TYPE = Union[
OPTION_TYPE = dict[str, Any]
if sys.version_info >= (3, 11):
from typing import TypedDict # type: ignore # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
__all__ = ["OpenAISettings"]
@@ -74,7 +79,7 @@ def _check_openai_version_for_callable_api_key() -> None:
logger.warning(f"Could not check OpenAI version for callable API key support: {e}")
class OpenAISettings(AFBaseSettings):
class OpenAISettings(TypedDict, total=False):
"""OpenAI environment settings.
The settings are first loaded from environment variables with the prefix 'OPENAI_'.
@@ -93,8 +98,6 @@ class OpenAISettings(AFBaseSettings):
Can be set via environment variable OPENAI_CHAT_MODEL_ID.
responses_model_id: The OpenAI responses model ID to use, for example, gpt-4o or o1.
Can be set via environment variable OPENAI_RESPONSES_MODEL_ID.
env_file_path: The path to the .env file to load settings from.
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
Examples:
.. code-block:: python
@@ -104,22 +107,20 @@ class OpenAISettings(AFBaseSettings):
# Using environment variables
# Set OPENAI_API_KEY=sk-...
# Set OPENAI_CHAT_MODEL_ID=gpt-4
settings = OpenAISettings()
settings = load_settings(OpenAISettings, env_prefix="OPENAI_")
# Or passing parameters directly
settings = OpenAISettings(api_key="sk-...", chat_model_id="gpt-4")
settings = load_settings(OpenAISettings, env_prefix="OPENAI_", api_key="sk-...", chat_model_id="gpt-4")
# Or loading from a .env file
settings = OpenAISettings(env_file_path="path/to/.env")
settings = load_settings(OpenAISettings, env_prefix="OPENAI_", env_file_path="path/to/.env")
"""
env_prefix: ClassVar[str] = "OPENAI_"
api_key: SecretStr | Callable[[], str | Awaitable[str]] | None = None
base_url: str | None = None
org_id: str | None = None
chat_model_id: str | None = None
responses_model_id: str | None = None
api_key: SecretString | Callable[[], str | Awaitable[str]] | None
base_url: str | None
org_id: str | None
chat_model_id: str | None
responses_model_id: str | None
class OpenAIBase(SerializationMixin):
@@ -181,19 +182,18 @@ class OpenAIBase(SerializationMixin):
return self.client
def _get_api_key(
self, api_key: str | SecretStr | Callable[[], str | Awaitable[str]] | None
self, api_key: str | SecretString | Callable[[], str | Awaitable[str]] | None
) -> str | Callable[[], str | Awaitable[str]] | None:
"""Get the appropriate API key value for client initialization.
Args:
api_key: The API key parameter which can be a string, SecretStr, callable, or None.
api_key: The API key parameter which can be a string, SecretString, callable, or None.
Returns:
For callable API keys: returns the callable directly.
For SecretStr API keys: returns the string value.
For string/None API keys: returns as-is.
For SecretString/string/None API keys: returns as-is (SecretString is a str subclass).
"""
if isinstance(api_key, SecretStr):
if isinstance(api_key, SecretString):
return api_key.get_secret_value()
# Check version compatibility for callable API keys
+1 -1
View File
@@ -26,7 +26,7 @@ dependencies = [
# utilities
"typing-extensions",
"pydantic>=2,<3",
"pydantic-settings>=2,<3",
"python-dotenv>=1,<2",
# telemetry
"opentelemetry-api>=1.39.0",
"opentelemetry-sdk>=1.39.0",
@@ -19,6 +19,7 @@ from agent_framework import (
SupportsChatGetResponse,
tool,
)
from agent_framework._settings import SecretString
from agent_framework.azure import AzureOpenAIAssistantsClient
from agent_framework.exceptions import ServiceInitializationError
@@ -556,19 +557,21 @@ def test_azure_assistants_client_entra_id_authentication() -> None:
mock_credential = MagicMock()
with (
patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class,
patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings,
patch("agent_framework.azure._assistants_client.get_entra_auth_token") as mock_get_token,
patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client,
patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None),
):
mock_settings = MagicMock()
mock_settings.chat_deployment_name = "test-deployment"
mock_settings.api_key = None # No API key to trigger Entra ID path
mock_settings.token_endpoint = "https://login.microsoftonline.com/test"
mock_settings.get_azure_auth_token.return_value = "entra-token-12345"
mock_settings.api_version = "2024-05-01-preview"
mock_settings.endpoint = "https://test-endpoint.openai.azure.com"
mock_settings.base_url = None
mock_settings_class.return_value = mock_settings
mock_load_settings.return_value = {
"chat_deployment_name": "test-deployment",
"responses_deployment_name": None,
"api_key": None,
"token_endpoint": "https://login.microsoftonline.com/test",
"api_version": "2024-05-01-preview",
"endpoint": "https://test-endpoint.openai.azure.com",
"base_url": None,
}
mock_get_token.return_value = "entra-token-12345"
client = AzureOpenAIAssistantsClient(
deployment_name="test-deployment",
@@ -579,7 +582,7 @@ def test_azure_assistants_client_entra_id_authentication() -> None:
)
# Verify Entra ID token was requested
mock_settings.get_azure_auth_token.assert_called_once_with(mock_credential)
mock_get_token.assert_called_once_with(mock_credential, "https://login.microsoftonline.com/test")
# Verify client was created with the token
mock_azure_client.assert_called_once()
@@ -592,12 +595,16 @@ def test_azure_assistants_client_entra_id_authentication() -> None:
def test_azure_assistants_client_no_authentication_error() -> None:
"""Test authentication validation error when no auth provided."""
with patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class:
mock_settings = MagicMock()
mock_settings.chat_deployment_name = "test-deployment"
mock_settings.api_key = None # No API key
mock_settings.token_endpoint = None # No token endpoint
mock_settings_class.return_value = mock_settings
with patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings:
mock_load_settings.return_value = {
"chat_deployment_name": "test-deployment",
"responses_deployment_name": None,
"api_key": None,
"token_endpoint": None,
"api_version": "2024-05-01-preview",
"endpoint": "https://test-endpoint.openai.azure.com",
"base_url": None,
}
# Test missing authentication raises error
with pytest.raises(ServiceInitializationError, match="API key, ad_token, or ad_token_provider is required"):
@@ -611,17 +618,19 @@ def test_azure_assistants_client_no_authentication_error() -> None:
def test_azure_assistants_client_ad_token_authentication() -> None:
"""Test ad_token authentication client parameter path."""
with (
patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class,
patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings,
patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client,
patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None),
):
mock_settings = MagicMock()
mock_settings.chat_deployment_name = "test-deployment"
mock_settings.api_key = None # No API key
mock_settings.api_version = "2024-05-01-preview"
mock_settings.endpoint = "https://test-endpoint.openai.azure.com"
mock_settings.base_url = None
mock_settings_class.return_value = mock_settings
mock_load_settings.return_value = {
"chat_deployment_name": "test-deployment",
"responses_deployment_name": None,
"api_key": None,
"token_endpoint": None,
"api_version": "2024-05-01-preview",
"endpoint": "https://test-endpoint.openai.azure.com",
"base_url": None,
}
client = AzureOpenAIAssistantsClient(
deployment_name="test-deployment",
@@ -645,17 +654,19 @@ def test_azure_assistants_client_ad_token_provider_authentication() -> None:
mock_token_provider = MagicMock(spec=AsyncAzureADTokenProvider)
with (
patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class,
patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings,
patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client,
patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None),
):
mock_settings = MagicMock()
mock_settings.chat_deployment_name = "test-deployment"
mock_settings.api_key = None # No API key
mock_settings.api_version = "2024-05-01-preview"
mock_settings.endpoint = "https://test-endpoint.openai.azure.com"
mock_settings.base_url = None
mock_settings_class.return_value = mock_settings
mock_load_settings.return_value = {
"chat_deployment_name": "test-deployment",
"responses_deployment_name": None,
"api_key": None,
"token_endpoint": None,
"api_version": "2024-05-01-preview",
"endpoint": "https://test-endpoint.openai.azure.com",
"base_url": None,
}
client = AzureOpenAIAssistantsClient(
deployment_name="test-deployment",
@@ -675,17 +686,19 @@ def test_azure_assistants_client_ad_token_provider_authentication() -> None:
def test_azure_assistants_client_base_url_configuration() -> None:
"""Test base_url client parameter path."""
with (
patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class,
patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings,
patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client,
patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None),
):
mock_settings = MagicMock()
mock_settings.chat_deployment_name = "test-deployment"
mock_settings.api_key.get_secret_value.return_value = "test-api-key"
mock_settings.base_url = "https://custom-base-url.com"
mock_settings.endpoint = None # No endpoint, should use base_url
mock_settings.api_version = "2024-05-01-preview"
mock_settings_class.return_value = mock_settings
mock_load_settings.return_value = {
"chat_deployment_name": "test-deployment",
"responses_deployment_name": None,
"api_key": SecretString("test-api-key"),
"token_endpoint": None,
"api_version": "2024-05-01-preview",
"endpoint": None,
"base_url": "https://custom-base-url.com",
}
client = AzureOpenAIAssistantsClient(
deployment_name="test-deployment", api_key="test-api-key", base_url="https://custom-base-url.com"
@@ -704,17 +717,19 @@ def test_azure_assistants_client_base_url_configuration() -> None:
def test_azure_assistants_client_azure_endpoint_configuration() -> None:
"""Test azure_endpoint client parameter path."""
with (
patch("agent_framework.azure._assistants_client.AzureOpenAISettings") as mock_settings_class,
patch("agent_framework.azure._assistants_client.load_settings") as mock_load_settings,
patch("agent_framework.azure._assistants_client.AsyncAzureOpenAI") as mock_azure_client,
patch("agent_framework.openai.OpenAIAssistantsClient.__init__", return_value=None),
):
mock_settings = MagicMock()
mock_settings.chat_deployment_name = "test-deployment"
mock_settings.api_key.get_secret_value.return_value = "test-api-key"
mock_settings.base_url = None # No base_url
mock_settings.endpoint = "https://test-endpoint.openai.azure.com"
mock_settings.api_version = "2024-05-01-preview"
mock_settings_class.return_value = mock_settings
mock_load_settings.return_value = {
"chat_deployment_name": "test-deployment",
"responses_deployment_name": None,
"api_key": SecretString("test-api-key"),
"token_endpoint": None,
"api_version": "2024-05-01-preview",
"endpoint": "https://test-endpoint.openai.azure.com",
"base_url": None,
}
client = AzureOpenAIAssistantsClient(
deployment_name="test-deployment",
@@ -109,8 +109,11 @@ def test_init_with_empty_endpoint_and_base_url(azure_openai_unit_test_env: dict[
@pytest.mark.parametrize("override_env_param_dict", [{"AZURE_OPENAI_ENDPOINT": "http://test.com"}], indirect=True)
def test_init_with_invalid_endpoint(azure_openai_unit_test_env: dict[str, str]) -> None:
with pytest.raises(ServiceInitializationError):
AzureOpenAIChatClient()
# Note: URL scheme validation was previously handled by pydantic's HTTPsUrl type.
# After migrating to load_settings with TypedDict, endpoint is a plain string and no longer
# validated at the settings level. The Azure OpenAI SDK may reject invalid URLs at runtime.
client = AzureOpenAIChatClient()
assert client is not None
@pytest.mark.parametrize("exclude_list", [["AZURE_OPENAI_BASE_URL"]], indirect=True)
@@ -0,0 +1,238 @@
# Copyright (c) Microsoft. All rights reserved.
"""Tests for load_settings() function."""
import os
import tempfile
from typing import TypedDict
import pytest
from agent_framework._settings import SecretString, load_settings
class SimpleSettings(TypedDict, total=False):
api_key: str | None
timeout: int | None
enabled: bool | None
rate_limit: float | None
class RequiredFieldSettings(TypedDict, total=False):
name: str | None
optional_field: str | None
class SecretSettings(TypedDict, total=False):
api_key: SecretString | None
username: str | None
class TestLoadSettingsBasic:
"""Test basic load_settings functionality."""
def test_fields_are_none_when_unset(self) -> None:
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
assert settings["api_key"] is None
assert settings["timeout"] is None
assert settings["enabled"] is None
assert settings["rate_limit"] is None
def test_overrides(self) -> None:
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", timeout=60, enabled=False)
assert settings["timeout"] == 60
assert settings["enabled"] is False
def test_none_overrides_are_filtered(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("TEST_APP_TIMEOUT", "120")
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", timeout=None)
# timeout=None is filtered, so env var wins
assert settings["timeout"] == 120
def test_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("TEST_APP_API_KEY", "test-key-123")
monkeypatch.setenv("TEST_APP_TIMEOUT", "120")
monkeypatch.setenv("TEST_APP_ENABLED", "false")
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
assert settings["api_key"] == "test-key-123"
assert settings["timeout"] == 120
assert settings["enabled"] is False
def test_overrides_beat_env_vars(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("TEST_APP_TIMEOUT", "120")
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", timeout=60)
assert settings["timeout"] == 60
def test_no_prefix(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("API_KEY", "no-prefix-key")
settings = load_settings(SimpleSettings, api_key=None)
assert settings["api_key"] == "no-prefix-key"
class TestDotenvFile:
"""Test .env file loading."""
def test_load_from_dotenv(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("TEST_APP_API_KEY", raising=False)
monkeypatch.delenv("TEST_APP_TIMEOUT", raising=False)
with tempfile.NamedTemporaryFile(mode="w", suffix=".env", delete=False) as f:
f.write("TEST_APP_API_KEY=dotenv-key\n")
f.write("TEST_APP_TIMEOUT=90\n")
f.flush()
env_path = f.name
try:
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", env_file_path=env_path)
assert settings["api_key"] == "dotenv-key"
assert settings["timeout"] == 90
finally:
os.unlink(env_path)
def test_env_vars_override_dotenv(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("TEST_APP_API_KEY", "real-env-key")
with tempfile.NamedTemporaryFile(mode="w", suffix=".env", delete=False) as f:
f.write("TEST_APP_API_KEY=dotenv-key\n")
f.flush()
env_path = f.name
try:
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", env_file_path=env_path)
assert settings["api_key"] == "real-env-key"
finally:
os.unlink(env_path)
def test_missing_dotenv_file(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.delenv("TEST_APP_API_KEY", raising=False)
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_", env_file_path="/nonexistent/.env")
assert settings["api_key"] is None
class TestSecretString:
"""Test SecretString type handling."""
def test_secretstring_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("SECRET_API_KEY", "secret-value")
settings = load_settings(SecretSettings, env_prefix="SECRET_")
assert isinstance(settings["api_key"], SecretString)
assert settings["api_key"] == "secret-value"
def test_secretstring_from_override(self) -> None:
settings = load_settings(SecretSettings, env_prefix="SECRET_", api_key="kwarg-secret")
assert isinstance(settings["api_key"], SecretString)
assert settings["api_key"] == "kwarg-secret"
def test_secretstring_masked_in_repr(self) -> None:
s = SecretString("my-secret")
assert "my-secret" not in repr(s)
assert "**********" in repr(s)
def test_get_secret_value_compat(self) -> None:
s = SecretString("my-secret")
assert s.get_secret_value() == "my-secret"
assert isinstance(s.get_secret_value(), str)
class TestTypeCoercion:
"""Test type coercion from string values."""
def test_int_coercion(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("TEST_APP_TIMEOUT", "42")
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
assert settings["timeout"] == 42
assert isinstance(settings["timeout"], int)
def test_float_coercion(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("TEST_APP_RATE_LIMIT", "2.5")
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
assert settings["rate_limit"] == 2.5
assert isinstance(settings["rate_limit"], float)
def test_bool_coercion_true_values(self, monkeypatch: pytest.MonkeyPatch) -> None:
for true_val in ["true", "True", "TRUE", "1", "yes", "on"]:
monkeypatch.setenv("TEST_APP_ENABLED", true_val)
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
assert settings["enabled"] is True, f"Failed for {true_val}"
def test_bool_coercion_false_values(self, monkeypatch: pytest.MonkeyPatch) -> None:
for false_val in ["false", "False", "FALSE", "0", "no", "off"]:
monkeypatch.setenv("TEST_APP_ENABLED", false_val)
settings = load_settings(SimpleSettings, env_prefix="TEST_APP_")
assert settings["enabled"] is False, f"Failed for {false_val}"
class TestRequiredFields:
"""Test required field validation."""
def test_required_field_provided(self) -> None:
settings = load_settings(
RequiredFieldSettings,
env_prefix="TEST_",
required_fields=["name"],
name="my-app",
)
assert settings["name"] == "my-app"
assert settings["optional_field"] is None
def test_required_field_from_env(self, monkeypatch: pytest.MonkeyPatch) -> None:
monkeypatch.setenv("TEST_NAME", "env-app")
settings = load_settings(RequiredFieldSettings, env_prefix="TEST_", required_fields=["name"])
assert settings["name"] == "env-app"
def test_required_field_missing_raises(self) -> None:
from agent_framework.exceptions import SettingNotFoundError
with pytest.raises(SettingNotFoundError, match="Required setting 'name'"):
load_settings(RequiredFieldSettings, env_prefix="TEST_", required_fields=["name"])
def test_without_required_fields_param_allows_none(self) -> None:
settings = load_settings(RequiredFieldSettings, env_prefix="TEST_")
assert settings["name"] is None
class TestOverrideTypeValidation:
"""Test override type validation."""
def test_invalid_type_raises(self) -> None:
from agent_framework.exceptions import ServiceInitializationError
with pytest.raises(ServiceInitializationError, match="Invalid type for setting 'api_key'"):
load_settings(SimpleSettings, env_prefix="TEST_", api_key={"bad": "type"})
def test_valid_types_accepted(self) -> None:
settings = load_settings(SimpleSettings, env_prefix="TEST_", timeout=42, enabled=True)
assert settings["timeout"] == 42
assert settings["enabled"] is True
def test_str_accepted_for_secretstring(self) -> None:
settings = load_settings(SecretSettings, env_prefix="TEST_", api_key="plain-string")
assert isinstance(settings["api_key"], SecretString)
assert settings["api_key"] == "plain-string"
@@ -131,9 +131,15 @@ class TestOpenAIAssistantProviderInit:
"""Test initialization fails without API key when settings return None."""
from unittest.mock import patch
# Mock OpenAISettings to return None for api_key
with patch("agent_framework.openai._assistant_provider.OpenAISettings") as mock_settings:
mock_settings.return_value.api_key = None
# Mock load_settings to return a dict with None for api_key
with patch("agent_framework.openai._assistant_provider.load_settings") as mock_load:
mock_load.return_value = {
"api_key": None,
"org_id": None,
"base_url": None,
"chat_model_id": None,
"responses_model_id": None,
}
with pytest.raises(ServiceInitializationError) as exc_info:
OpenAIAssistantProvider()
@@ -146,7 +146,7 @@ def test_init_auto_create_client(
def test_init_validation_fail() -> None:
"""Test OpenAIAssistantsClient initialization with validation failure."""
with pytest.raises(ServiceInitializationError):
# Force failure by providing invalid model ID type - this should cause validation to fail
# Force failure by providing invalid model ID type
OpenAIAssistantsClient(model_id=123, api_key="valid-key") # type: ignore
@@ -4,7 +4,7 @@ from __future__ import annotations
import sys
from collections.abc import Sequence
from typing import Any, ClassVar, Generic
from typing import Any, Generic
from agent_framework import (
ChatAndFunctionMiddlewareTypes,
@@ -13,7 +13,7 @@ from agent_framework import (
FunctionInvocationConfiguration,
FunctionInvocationLayer,
)
from agent_framework._pydantic import AFBaseSettings
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError
from agent_framework.observability import ChatTelemetryLayer
from agent_framework.openai._chat_client import RawOpenAIChatClient
@@ -115,25 +115,19 @@ FoundryLocalChatOptionsT = TypeVar(
# endregion
class FoundryLocalSettings(AFBaseSettings):
class FoundryLocalSettings(TypedDict, total=False):
"""Foundry local model settings.
The settings are first loaded from environment variables with the prefix 'FOUNDRY_LOCAL_'.
If the environment variables are not found, the settings can be loaded from a .env file
with the encoding 'utf-8'. If the settings are not found in the .env file, the settings
are ignored; however, validation will fail alerting that the settings are missing.
with the encoding 'utf-8'.
Attributes:
Keys:
model_id: The name of the model deployment to use.
(Env var FOUNDRY_LOCAL_MODEL_ID)
Parameters:
env_file_path: If provided, the .env settings are read from this file path location.
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
"""
env_prefix: ClassVar[str] = "FOUNDRY_LOCAL_"
model_id: str
model_id: str | None
class FoundryLocalClient(
@@ -247,21 +241,27 @@ class FoundryLocalClient(
type that is not supported by the model, it will not be found.
"""
settings = FoundryLocalSettings(
model_id=model_id, # type: ignore
settings = load_settings(
FoundryLocalSettings,
env_prefix="FOUNDRY_LOCAL_",
required_fields=["model_id"],
model_id=model_id,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
manager = FoundryLocalManager(bootstrap=bootstrap, timeout=timeout)
model_info = manager.get_model_info(
alias_or_model_id=settings.model_id,
alias_or_model_id=settings["model_id"],
device=device,
)
if model_info is None:
message = (
f"Model with ID or alias '{settings.model_id}:{device.value}' not found in Foundry Local."
f"Model with ID or alias '{settings['model_id']}:{device.value}' not found in Foundry Local."
if device
else f"Model with ID or alias '{settings.model_id}' for your current device not found in Foundry Local."
else (
f"Model with ID or alias '{settings['model_id']}' for your current device "
"not found in Foundry Local."
)
)
raise ServiceInitializationError(message)
if prepare_model:
@@ -4,8 +4,8 @@ from unittest.mock import MagicMock, patch
import pytest
from agent_framework import SupportsChatGetResponse
from agent_framework.exceptions import ServiceInitializationError
from pydantic import ValidationError
from agent_framework._settings import load_settings
from agent_framework.exceptions import ServiceInitializationError, SettingNotFoundError
from agent_framework_foundry_local import FoundryLocalClient
from agent_framework_foundry_local._foundry_local_client import FoundryLocalSettings
@@ -15,31 +15,43 @@ from agent_framework_foundry_local._foundry_local_client import FoundryLocalSett
def test_foundry_local_settings_init_from_env(foundry_local_unit_test_env: dict[str, str]) -> None:
"""Test FoundryLocalSettings initialization from environment variables."""
settings = FoundryLocalSettings(env_file_path="test.env")
settings = load_settings(FoundryLocalSettings, env_prefix="FOUNDRY_LOCAL_", env_file_path="test.env")
assert settings.model_id == foundry_local_unit_test_env["FOUNDRY_LOCAL_MODEL_ID"]
assert settings["model_id"] == foundry_local_unit_test_env["FOUNDRY_LOCAL_MODEL_ID"]
def test_foundry_local_settings_init_with_explicit_values() -> None:
"""Test FoundryLocalSettings initialization with explicit values."""
settings = FoundryLocalSettings(model_id="custom-model-id", env_file_path="test.env")
settings = load_settings(
FoundryLocalSettings,
env_prefix="FOUNDRY_LOCAL_",
model_id="custom-model-id",
env_file_path="test.env",
)
assert settings.model_id == "custom-model-id"
assert settings["model_id"] == "custom-model-id"
@pytest.mark.parametrize("exclude_list", [["FOUNDRY_LOCAL_MODEL_ID"]], indirect=True)
def test_foundry_local_settings_missing_model_id(foundry_local_unit_test_env: dict[str, str]) -> None:
"""Test FoundryLocalSettings when model_id is missing raises ValidationError."""
with pytest.raises(ValidationError):
FoundryLocalSettings(env_file_path="test.env")
"""Test FoundryLocalSettings when model_id is missing raises error."""
with pytest.raises(SettingNotFoundError, match="Required setting 'model_id'"):
load_settings(
FoundryLocalSettings,
env_prefix="FOUNDRY_LOCAL_",
required_fields=["model_id"],
env_file_path="test.env",
)
def test_foundry_local_settings_explicit_overrides_env(foundry_local_unit_test_env: dict[str, str]) -> None:
"""Test that explicit values override environment variables."""
settings = FoundryLocalSettings(model_id="override-model-id", env_file_path="test.env")
settings = load_settings(
FoundryLocalSettings, env_prefix="FOUNDRY_LOCAL_", model_id="override-model-id", env_file_path="test.env"
)
assert settings.model_id == "override-model-id"
assert settings.model_id != foundry_local_unit_test_env["FOUNDRY_LOCAL_MODEL_ID"]
assert settings["model_id"] == "override-model-id"
assert settings["model_id"] != foundry_local_unit_test_env["FOUNDRY_LOCAL_MODEL_ID"]
# Client Initialization Tests
@@ -21,9 +21,10 @@ from agent_framework import (
ResponseStream,
normalize_messages,
)
from agent_framework._settings import load_settings
from agent_framework._tools import FunctionTool
from agent_framework._types import normalize_tools
from agent_framework.exceptions import ServiceException, ServiceInitializationError
from agent_framework.exceptions import ServiceException
from copilot import CopilotClient, CopilotSession
from copilot.generated.session_events import SessionEvent, SessionEventType
from copilot.types import (
@@ -38,7 +39,6 @@ from copilot.types import (
ToolResult,
)
from copilot.types import Tool as CopilotTool
from pydantic import ValidationError
from ._settings import GitHubCopilotSettings
@@ -207,17 +207,16 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
on_permission_request: PermissionHandlerType | None = opts.pop("on_permission_request", None)
mcp_servers: dict[str, MCPServerConfig] | None = opts.pop("mcp_servers", None)
try:
self._settings = GitHubCopilotSettings(
cli_path=cli_path,
model=model,
timeout=timeout,
log_level=log_level,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create GitHub Copilot settings.", ex) from ex
self._settings = load_settings(
GitHubCopilotSettings,
env_prefix="GITHUB_COPILOT_",
cli_path=cli_path,
model=model,
timeout=timeout,
log_level=log_level,
env_file_path=env_file_path,
env_file_encoding=env_file_encoding,
)
self._tools = normalize_tools(tools)
self._permission_handler = on_permission_request
@@ -249,10 +248,10 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
if self._client is None:
client_options: CopilotClientOptions = {}
if self._settings.cli_path:
client_options["cli_path"] = self._settings.cli_path
if self._settings.log_level:
client_options["log_level"] = self._settings.log_level # type: ignore[typeddict-item]
if self._settings["cli_path"]:
client_options["cli_path"] = self._settings["cli_path"]
if self._settings["log_level"]:
client_options["log_level"] = self._settings["log_level"] # type: ignore[typeddict-item]
self._client = CopilotClient(client_options if client_options else None)
@@ -355,7 +354,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
thread = self.get_new_thread()
opts: dict[str, Any] = dict(options) if options else {}
timeout = opts.pop("timeout", None) or self._settings.timeout or DEFAULT_TIMEOUT_SECONDS
timeout = opts.pop("timeout", None) or self._settings["timeout"] or DEFAULT_TIMEOUT_SECONDS
session = await self._get_or_create_session(thread, streaming=False, runtime_options=opts)
input_messages = normalize_messages(messages)
@@ -578,7 +577,7 @@ class GitHubCopilotAgent(BaseAgent, Generic[OptionsT]):
opts = runtime_options or {}
config: SessionConfig = {"streaming": streaming}
model = opts.get("model") or self._settings.model
model = opts.get("model") or self._settings["model"]
if model:
config["model"] = model # type: ignore[typeddict-item]
@@ -1,19 +1,16 @@
# Copyright (c) Microsoft. All rights reserved.
from typing import ClassVar
from agent_framework._pydantic import AFBaseSettings
from typing import TypedDict
class GitHubCopilotSettings(AFBaseSettings):
class GitHubCopilotSettings(TypedDict, total=False):
"""GitHub Copilot model settings.
The settings are first loaded from environment variables with the prefix 'GITHUB_COPILOT_'.
If the environment variables are not found, the settings can be loaded from a .env file
with the encoding 'utf-8'. If the settings are not found in the .env file, the settings
are ignored; however, validation will fail alerting that the settings are missing.
with the encoding 'utf-8'.
Keyword Args:
Keys:
cli_path: Path to the Copilot CLI executable.
Can be set via environment variable GITHUB_COPILOT_CLI_PATH.
model: Model to use (e.g., "gpt-5", "claude-sonnet-4").
@@ -22,28 +19,9 @@ class GitHubCopilotSettings(AFBaseSettings):
Can be set via environment variable GITHUB_COPILOT_TIMEOUT.
log_level: CLI log level.
Can be set via environment variable GITHUB_COPILOT_LOG_LEVEL.
env_file_path: If provided, the .env settings are read from this file path location.
env_file_encoding: The encoding of the .env file, defaults to 'utf-8'.
Examples:
.. code-block:: python
from agent_framework_github_copilot import GitHubCopilotSettings
# Using environment variables
# Set GITHUB_COPILOT_MODEL=gpt-5
settings = GitHubCopilotSettings()
# Or passing parameters directly
settings = GitHubCopilotSettings(model="claude-sonnet-4", timeout=120)
# Or loading from a .env file
settings = GitHubCopilotSettings(env_file_path="path/to/.env")
"""
env_prefix: ClassVar[str] = "GITHUB_COPILOT_"
cli_path: str | None = None
model: str | None = None
timeout: float | None = None
log_level: str | None = None
cli_path: str | None
model: str | None
timeout: float | None
log_level: str | None
@@ -122,8 +122,8 @@ class TestGitHubCopilotAgentInit:
agent: GitHubCopilotAgent[GitHubCopilotOptions] = GitHubCopilotAgent(
default_options={"model": "claude-sonnet-4", "timeout": 120}
)
assert agent._settings.model == "claude-sonnet-4" # type: ignore
assert agent._settings.timeout == 120 # type: ignore
assert agent._settings["model"] == "claude-sonnet-4" # type: ignore
assert agent._settings["timeout"] == 120 # type: ignore
def test_init_with_tools(self) -> None:
"""Test initialization with function tools."""
@@ -30,9 +30,8 @@ from agent_framework import (
UsageDetails,
get_logger,
)
from agent_framework._pydantic import AFBaseSettings
from agent_framework._settings import load_settings
from agent_framework.exceptions import (
ServiceInitializationError,
ServiceInvalidRequestError,
ServiceResponseException,
)
@@ -42,7 +41,7 @@ from ollama import AsyncClient
# Rename imported types to avoid naming conflicts with Agent Framework types
from ollama._types import ChatResponse as OllamaChatResponse
from ollama._types import Message as OllamaMessage
from pydantic import BaseModel, ValidationError
from pydantic import BaseModel
if sys.version_info >= (3, 13):
from typing import TypeVar # type: ignore # pragma: no cover
@@ -275,13 +274,11 @@ OllamaChatOptionsT = TypeVar("OllamaChatOptionsT", bound=TypedDict, default="Oll
# endregion
class OllamaSettings(AFBaseSettings):
class OllamaSettings(TypedDict, total=False):
"""Ollama settings."""
env_prefix: ClassVar[str] = "OLLAMA_"
host: str | None = None
model_id: str | None = None
host: str | None
model_id: str | None
logger = get_logger("agent_framework.ollama")
@@ -322,23 +319,19 @@ class OllamaChatClient(
env_file_encoding: The encoding to use when reading the dotenv (.env) file. Defaults to 'utf-8'.
**kwargs: Additional keyword arguments passed to BaseChatClient.
"""
try:
ollama_settings = OllamaSettings(
host=host,
model_id=model_id,
env_file_encoding=env_file_encoding,
env_file_path=env_file_path,
)
except ValidationError as ex:
raise ServiceInitializationError("Failed to create Ollama settings.", ex) from ex
ollama_settings = load_settings(
OllamaSettings,
env_prefix="OLLAMA_",
required_fields=["model_id"],
host=host,
model_id=model_id,
env_file_encoding=env_file_encoding,
env_file_path=env_file_path,
)
if ollama_settings.model_id is None:
raise ServiceInitializationError(
"Ollama chat model ID must be provided via model_id or OLLAMA_MODEL_ID environment variable."
)
self.model_id = ollama_settings.model_id
self.client = client or AsyncClient(host=ollama_settings.host)
self.model_id = ollama_settings["model_id"]
# we can just pass in None for the host, the default is set by the Ollama package.
self.client = client or AsyncClient(host=ollama_settings.get("host"))
# Save Host URL for serialization with to_dict()
self.host = str(self.client._client.base_url) # pyright: ignore[reportUnknownMemberType,reportPrivateUsage,reportUnknownArgumentType]
@@ -15,9 +15,9 @@ from agent_framework import (
tool,
)
from agent_framework.exceptions import (
ServiceInitializationError,
ServiceInvalidRequestError,
ServiceResponseException,
SettingNotFoundError,
)
from ollama import AsyncClient
from ollama._types import ChatResponse as OllamaChatResponse
@@ -182,7 +182,7 @@ def test_init_client(ollama_unit_test_env: dict[str, str]) -> None:
@pytest.mark.parametrize("exclude_list", [["OLLAMA_MODEL_ID"]], indirect=True)
def test_with_invalid_settings(ollama_unit_test_env: dict[str, str]) -> None:
with pytest.raises(ServiceInitializationError):
with pytest.raises(SettingNotFoundError, match="Required setting 'model_id'"):
OllamaChatClient(
host="http://localhost:12345",
model_id=None,
@@ -9,7 +9,7 @@ from ._exceptions import (
PurviewServiceError,
)
from ._middleware import PurviewChatPolicyMiddleware, PurviewPolicyMiddleware
from ._settings import PurviewAppLocation, PurviewLocationType, PurviewSettings
from ._settings import PurviewAppLocation, PurviewLocationType, PurviewSettings, get_purview_scopes
__all__ = [
"CacheProvider",
@@ -23,4 +23,5 @@ __all__ = [
"PurviewRequestError",
"PurviewServiceError",
"PurviewSettings",
"get_purview_scopes",
]
@@ -31,7 +31,7 @@ from ._models import (
ProtectionScopesRequest,
ProtectionScopesResponse,
)
from ._settings import PurviewSettings
from ._settings import PurviewSettings, get_purview_scopes
logger = get_logger("agent_framework.purview")
@@ -52,7 +52,7 @@ class PurviewClient:
):
self._credential: TokenCredential | AsyncTokenCredential = credential
self._settings = settings
self._graph_uri = settings.graph_base_uri.rstrip("/")
self._graph_uri = (settings.get("graph_base_uri") or "https://graph.microsoft.com/v1.0/").rstrip("/")
self._timeout = timeout
self._client = httpx.AsyncClient(timeout=timeout)
@@ -61,7 +61,7 @@ class PurviewClient:
async def _get_token(self, *, tenant_id: str | None = None) -> str:
"""Acquire an access token using either async or sync credential."""
scopes = self._settings.get_scopes()
scopes = get_purview_scopes(self._settings)
cred = self._credential
token = cred.get_token(*scopes, tenant_id=tenant_id)
token = await token if inspect.isawaitable(token) else token
@@ -167,7 +167,7 @@ class PurviewClient:
if resp.status_code in (401, 403):
raise PurviewAuthenticationError(f"Auth failure {resp.status_code}: {resp.text}")
if resp.status_code == 402:
if self._settings.ignore_payment_required:
if self._settings.get("ignore_payment_required", False):
return response_type() # type: ignore[call-arg, no-any-return]
raise PurviewPaymentRequiredError(f"Payment required {resp.status_code}: {resp.text}")
if resp.status_code == 429:
@@ -78,18 +78,22 @@ class PurviewPolicyMiddleware(AgentMiddleware):
from agent_framework import AgentResponse, Message
context.result = AgentResponse(
messages=[Message(role="system", text=self._settings.blocked_prompt_message)]
messages=[
Message(
role="system", text=self._settings.get("blocked_prompt_message", "Prompt blocked by policy")
)
]
)
raise MiddlewareTermination
except MiddlewareTermination:
raise
except PurviewPaymentRequiredError as ex:
logger.error(f"Purview payment required error in policy pre-check: {ex}")
if not self._settings.ignore_payment_required:
if not self._settings.get("ignore_payment_required", False):
raise
except Exception as ex:
logger.error(f"Error in Purview policy pre-check: {ex}")
if not self._settings.ignore_exceptions:
if not self._settings.get("ignore_exceptions", False):
raise
await call_next()
@@ -111,18 +115,23 @@ class PurviewPolicyMiddleware(AgentMiddleware):
from agent_framework import AgentResponse, Message
context.result = AgentResponse(
messages=[Message(role="system", text=self._settings.blocked_response_message)]
messages=[
Message(
role="system",
text=self._settings.get("blocked_response_message", "Response blocked by policy"),
)
]
)
else:
# Streaming responses are not supported for post-checks
logger.debug("Streaming responses are not supported for Purview policy post-checks")
except PurviewPaymentRequiredError as ex:
logger.error(f"Purview payment required error in policy post-check: {ex}")
if not self._settings.ignore_payment_required:
if not self._settings.get("ignore_payment_required", False):
raise
except Exception as ex:
logger.error(f"Error in Purview policy post-check: {ex}")
if not self._settings.ignore_exceptions:
if not self._settings.get("ignore_exceptions", False):
raise
@@ -173,18 +182,20 @@ class PurviewChatPolicyMiddleware(ChatMiddleware):
if should_block_prompt:
from agent_framework import ChatResponse, Message
blocked_message = Message(role="system", text=self._settings.blocked_prompt_message)
blocked_message = Message(
role="system", text=self._settings.get("blocked_prompt_message", "Prompt blocked by policy")
)
context.result = ChatResponse(messages=[blocked_message])
raise MiddlewareTermination
except MiddlewareTermination:
raise
except PurviewPaymentRequiredError as ex:
logger.error(f"Purview payment required error in policy pre-check: {ex}")
if not self._settings.ignore_payment_required:
if not self._settings.get("ignore_payment_required", False):
raise
except Exception as ex:
logger.error(f"Error in Purview policy pre-check: {ex}")
if not self._settings.ignore_exceptions:
if not self._settings.get("ignore_exceptions", False):
raise
await call_next()
@@ -205,15 +216,18 @@ class PurviewChatPolicyMiddleware(ChatMiddleware):
if should_block_response:
from agent_framework import ChatResponse, Message
blocked_message = Message(role="system", text=self._settings.blocked_response_message)
blocked_message = Message(
role="system",
text=self._settings.get("blocked_response_message", "Response blocked by policy"),
)
context.result = ChatResponse(messages=[blocked_message])
else:
logger.debug("Streaming responses are not supported for Purview policy post-checks")
except PurviewPaymentRequiredError as ex:
logger.error(f"Purview payment required error in policy post-check: {ex}")
if not self._settings.ignore_payment_required:
if not self._settings.get("ignore_payment_required", False):
raise
except Exception as ex:
logger.error(f"Error in Purview policy post-check: {ex}")
if not self._settings.ignore_exceptions:
if not self._settings.get("ignore_exceptions", False):
raise
@@ -57,8 +57,11 @@ class ScopedContentProcessor:
def __init__(self, client: PurviewClient, settings: PurviewSettings, cache_provider: CacheProvider | None = None):
self._client = client
self._settings = settings
cache_ttl = settings.get("cache_ttl_seconds")
max_cache = settings.get("max_cache_size_bytes")
self._cache: CacheProvider = cache_provider or InMemoryCacheProvider(
default_ttl_seconds=settings.cache_ttl_seconds, max_size_bytes=settings.max_cache_size_bytes
default_ttl_seconds=cache_ttl if cache_ttl is not None else 14400,
max_size_bytes=max_cache if max_cache is not None else 200 * 1024 * 1024,
)
self._background_tasks: set[asyncio.Task[Any]] = set()
@@ -116,10 +119,10 @@ class ScopedContentProcessor:
results: list[ProcessContentRequest] = []
token_info = None
if not (self._settings.tenant_id and self._settings.purview_app_location):
token_info = await self._client.get_user_info_from_token(tenant_id=self._settings.tenant_id)
if not (self._settings.get("tenant_id") and self._settings.get("purview_app_location")):
token_info = await self._client.get_user_info_from_token(tenant_id=self._settings.get("tenant_id"))
tenant_id = (token_info or {}).get("tenant_id") or self._settings.tenant_id
tenant_id = (token_info or {}).get("tenant_id") or self._settings.get("tenant_id")
if not tenant_id or not _is_valid_guid(tenant_id):
raise ValueError("Tenant id required or must be inferable from credential")
@@ -159,10 +162,11 @@ class ScopedContentProcessor:
)
activity_meta = ActivityMetadata(activity=activity)
if self._settings.purview_app_location:
purview_app_location = self._settings.get("purview_app_location")
if purview_app_location:
policy_location = PolicyLocation(
data_type=self._settings.purview_app_location.get_policy_location()["@odata.type"],
value=self._settings.purview_app_location.location_value,
data_type=purview_app_location.get_policy_location()["@odata.type"],
value=purview_app_location.location_value,
)
elif token_info and token_info.get("client_id"):
policy_location = PolicyLocation(
@@ -172,13 +176,14 @@ class ScopedContentProcessor:
else:
raise ValueError("App location not provided or inferable")
app_version = self._settings.app_version or "Unknown"
protected_app = ProtectedAppMetadata(
name=self._settings.app_name,
version=app_version,
name=self._settings["app_name"],
version=self._settings.get("app_version", "Unknown"),
application_location=policy_location,
)
integrated_app = IntegratedAppMetadata(name=self._settings.app_name, version=app_version)
integrated_app = IntegratedAppMetadata(
name=self._settings["app_name"], version=self._settings.get("app_version", "Unknown")
)
device_meta = DeviceMetadata(
operating_system_specifications=OperatingSystemSpecifications(
operating_system_platform="Unknown", operating_system_version="Unknown"
@@ -229,11 +234,13 @@ class ScopedContentProcessor:
ps_resp = cached_ps_resp
else:
try:
ttl = self._settings.get("cache_ttl_seconds")
ttl_seconds = ttl if ttl is not None else 14400
ps_resp = await self._client.get_protection_scopes(ps_req)
await self._cache.set(cache_key, ps_resp, ttl_seconds=self._settings.cache_ttl_seconds)
await self._cache.set(cache_key, ps_resp, ttl_seconds=ttl_seconds)
except PurviewPaymentRequiredError as ex:
# Cache the exception at tenant level so all subsequent requests for this tenant fail fast
await self._cache.set(tenant_payment_cache_key, ex, ttl_seconds=self._settings.cache_ttl_seconds)
await self._cache.set(tenant_payment_cache_key, ex, ttl_seconds=ttl_seconds)
raise
if ps_resp.scope_identifier:
@@ -1,10 +1,14 @@
# Copyright (c) Microsoft. All rights reserved.
import sys
from enum import Enum
from agent_framework._pydantic import AFBaseSettings
from pydantic import BaseModel, Field
from pydantic_settings import SettingsConfigDict
from pydantic import BaseModel
if sys.version_info >= (3, 11):
from typing import TypedDict # pragma: no cover
else:
from typing_extensions import TypedDict # type: ignore # pragma: no cover
class PurviewLocationType(str, Enum):
@@ -18,8 +22,8 @@ class PurviewLocationType(str, Enum):
class PurviewAppLocation(BaseModel):
"""Identifier representing the app's location for Purview policy evaluation."""
location_type: PurviewLocationType = Field(..., description="The location type.")
location_value: str = Field(..., description="The location value.")
location_type: PurviewLocationType
location_value: str
def get_policy_location(self) -> dict[str, str]:
ns = "microsoft.graph"
@@ -34,8 +38,8 @@ class PurviewAppLocation(BaseModel):
return {"@odata.type": dt, "value": self.location_value}
class PurviewSettings(AFBaseSettings):
"""Settings for Purview integration.
class PurviewSettings(TypedDict, total=False):
"""Settings for Purview integration mirroring .NET PurviewSettings.
Attributes:
app_name: Public app name.
@@ -51,40 +55,30 @@ class PurviewSettings(AFBaseSettings):
max_cache_size_bytes: Maximum cache size in bytes (default 200MB).
"""
app_name: str = Field(...)
app_version: str | None = Field(default=None)
tenant_id: str | None = Field(default=None)
purview_app_location: PurviewAppLocation | None = Field(default=None)
graph_base_uri: str = Field(default="https://graph.microsoft.com/v1.0/")
blocked_prompt_message: str = Field(
default="Prompt blocked by policy",
description="Message to return when a prompt is blocked by policy.",
)
blocked_response_message: str = Field(
default="Response blocked by policy",
description="Message to return when a response is blocked by policy.",
)
ignore_exceptions: bool = Field(
default=False,
description="If True, all Purview exceptions will be logged but not thrown in middleware.",
)
ignore_payment_required: bool = Field(
default=False,
description="If True, 402 payment required errors will be logged but not thrown.",
)
cache_ttl_seconds: int = Field(
default=14400,
description="Time to live for cache entries in seconds (default 14400 = 4 hours).",
)
max_cache_size_bytes: int = Field(
default=200 * 1024 * 1024,
description="Maximum cache size in bytes (default 200MB).",
)
app_name: str | None
app_version: str | None
tenant_id: str | None
purview_app_location: PurviewAppLocation | None
graph_base_uri: str | None
blocked_prompt_message: str | None
blocked_response_message: str | None
ignore_exceptions: bool | None
ignore_payment_required: bool | None
cache_ttl_seconds: int | None
max_cache_size_bytes: int | None
model_config = SettingsConfigDict(populate_by_name=True, validate_assignment=True)
def get_scopes(self) -> list[str]:
from urllib.parse import urlparse
def get_purview_scopes(settings: PurviewSettings) -> list[str]:
"""Get the OAuth scopes for the Purview Graph API.
host = urlparse(self.graph_base_uri).hostname or "graph.microsoft.com"
return [f"https://{host}/.default"]
Args:
settings: The Purview settings containing graph_base_uri.
Returns:
A list of OAuth scope strings.
"""
from urllib.parse import urlparse
graph_base_uri = settings.get("graph_base_uri", "https://graph.microsoft.com/v1.0/")
host = urlparse(str(graph_base_uri)).hostname or "graph.microsoft.com"
return [f"https://{host}/.default"]
@@ -125,7 +125,7 @@ class TestPurviewChatPolicyMiddleware:
) -> None:
"""Test that exceptions in post-check are logged but don't affect result when ignore_exceptions=True."""
# Set ignore_exceptions to True to test exception suppression
middleware._settings.ignore_exceptions = True
middleware._settings["ignore_exceptions"] = True
call_count = 0
@@ -119,7 +119,7 @@ class TestPurviewPolicyMiddleware:
) -> None:
"""Test middleware handles result that doesn't have messages attribute."""
# Set ignore_exceptions to True so AttributeError is caught and logged
middleware._settings.ignore_exceptions = True
middleware._settings["ignore_exceptions"] = True
context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Hello")])
@@ -216,7 +216,7 @@ class TestPurviewPolicyMiddleware:
self, middleware: PurviewPolicyMiddleware, mock_agent: MagicMock
) -> None:
"""Test that post-check exceptions are propagated when ignore_exceptions=False."""
middleware._settings.ignore_exceptions = False
middleware._settings["ignore_exceptions"] = False
context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Hello")])
@@ -242,7 +242,7 @@ class TestPurviewPolicyMiddleware:
) -> None:
"""Test that exceptions in pre-check are logged but don't stop processing when ignore_exceptions=True."""
# Set ignore_exceptions to True
middleware._settings.ignore_exceptions = True
middleware._settings["ignore_exceptions"] = True
context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Test")])
@@ -265,7 +265,7 @@ class TestPurviewPolicyMiddleware:
) -> None:
"""Test that exceptions in post-check are logged but don't affect result when ignore_exceptions=True."""
# Set ignore_exceptions to True
middleware._settings.ignore_exceptions = True
middleware._settings["ignore_exceptions"] = True
context = AgentContext(agent=mock_agent, messages=[Message(role="user", text="Test")])
@@ -636,7 +636,6 @@ class TestScopedContentProcessorCaching:
return PurviewSettings(
app_name="Test App",
tenant_id="12345678-1234-1234-1234-123456789012",
default_user_id="12345678-1234-1234-1234-123456789012",
purview_app_location=location,
)
@@ -47,7 +47,7 @@ class TestPurviewClient:
@pytest.fixture
def settings(self) -> PurviewSettings:
"""Create test settings."""
return PurviewSettings(app_name="Test App", tenant_id="test-tenant", default_user_id="test-user")
return PurviewSettings(app_name="Test App", tenant_id="test-tenant")
@pytest.fixture
async def client(
@@ -4,7 +4,7 @@
import pytest
from agent_framework_purview import PurviewAppLocation, PurviewLocationType, PurviewSettings
from agent_framework_purview import PurviewAppLocation, PurviewLocationType, PurviewSettings, get_purview_scopes
class TestPurviewSettings:
@@ -14,10 +14,10 @@ class TestPurviewSettings:
"""Test PurviewSettings with default values."""
settings = PurviewSettings(app_name="Test App")
assert settings.app_name == "Test App"
assert settings.graph_base_uri == "https://graph.microsoft.com/v1.0/"
assert settings.tenant_id is None
assert settings.purview_app_location is None
assert settings["app_name"] == "Test App"
assert settings.get("graph_base_uri") is None
assert settings.get("tenant_id") is None
assert settings.get("purview_app_location") is None
def test_settings_with_custom_values(self) -> None:
"""Test PurviewSettings with custom values."""
@@ -30,9 +30,9 @@ class TestPurviewSettings:
purview_app_location=app_location,
)
assert settings.graph_base_uri == "https://graph.microsoft-ppe.com"
assert settings.tenant_id == "test-tenant-id"
assert settings.purview_app_location.location_value == "app-123"
assert settings["graph_base_uri"] == "https://graph.microsoft-ppe.com"
assert settings["tenant_id"] == "test-tenant-id"
assert settings["purview_app_location"].location_value == "app-123"
@pytest.mark.parametrize(
"graph_uri,expected_scope",
@@ -44,7 +44,7 @@ class TestPurviewSettings:
def test_get_scopes(self, graph_uri: str, expected_scope: str) -> None:
"""Test get_scopes returns correct scope for different URIs."""
settings = PurviewSettings(app_name="Test App", graph_base_uri=graph_uri)
scopes = settings.get_scopes()
scopes = get_purview_scopes(settings)
assert len(scopes) == 1
assert expected_scope in scopes
+2 -2
View File
@@ -331,7 +331,7 @@ dependencies = [
{ name = "opentelemetry-semantic-conventions-ai", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "packaging", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "pydantic", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "pydantic-settings", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "python-dotenv", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
{ name = "typing-extensions", marker = "sys_platform == 'darwin' or sys_platform == 'linux' or sys_platform == 'win32'" },
]
@@ -386,7 +386,7 @@ requires-dist = [
{ name = "opentelemetry-semantic-conventions-ai", specifier = ">=0.4.13" },
{ name = "packaging", specifier = ">=24.1" },
{ name = "pydantic", specifier = ">=2,<3" },
{ name = "pydantic-settings", specifier = ">=2,<3" },
{ name = "python-dotenv", specifier = ">=1,<2" },
{ name = "typing-extensions" },
]
provides-extras = ["all"]